diff --git a/README.md b/README.md index d45e7afa1..2f3b7d2fd 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,6 @@ # pie-modules -PyTorch -Lightning -PyTorch-IE
+PythonIE
[![PyPI](https://img.shields.io/pypi/v/pie-modules.svg)][pypi status] [![Tests](https://github.com/arnebinder/pie-modules/workflows/Tests/badge.svg)][tests] @@ -10,31 +8,16 @@ [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)][pre-commit] [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)][black] -Model-, taskmodule-, and metric-implementations as well as document processing utilities for [PyTorch-IE](https://github.com/ChristophAlt/pytorch-ie). +Annotation-, document- and metric implementations as well as utilities for [Python-IE](https://github.com/ArneBinder/pie-core). -Available models: +Available annotation types: see [here](src/pie_modules/annotations.py). -- [SimpleSequenceClassificationModel](src/pie_modules/models/simple_sequence_classification.py) -- [SequenceClassificationModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py) -- [SequencePairSimilarityModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py) -- [SimpleTokenClassificationModel](src/pie_modules/models/simple_token_classification.py) -- [TokenClassificationModelWithSeq2SeqEncoderAndCrf](src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py) -- [SimpleExtractiveQuestionAnsweringModel](src/pie_modules/models/simple_extractive_question_answering.py) -- [SimpleGenerativeModel](src/pie_modules/models/simple_generative.py) -- [SpanTupleClassificationModel](src/pie_modules/models/span_tuple_classification.py) - -Available taskmodules: - -- [RETextClassificationWithIndicesTaskModule](src/pie_modules/taskmodules/re_text_classification_with_indices.py) -- [CrossTextBinaryCorefTaskModule](src/pie_modules/taskmodules/cross_text_binary_coref.py) -- [LabeledSpanExtractionByTokenClassificationTaskModule](src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py) -- [ExtractiveQuestionAnsweringTaskModule](src/pie_modules/taskmodules/extractive_question_answering.py) -- [TextToTextTaskModule](src/pie_modules/taskmodules/text_to_text.py) -- [PointerNetworkTaskModuleForEnd2EndRE](src/pie_modules/taskmodules/pointer_network_for_end2end_re.py) -- [RESpanPairClassificationTaskModule](src/pie_modules/taskmodules/re_span_pair_classification.py) +Available document types: see [here](src/pie_modules/documents.py). Available metrics: +- [F1Metric](src/pie_modules/metrics/f1.py) +- [ConfusionMatrix](src/pie_modules/metrics/confusion_matrix.py) - [SpanLengthCollector](src/pie_modules/metrics/span_length_collector.py) - [RelationArgumentDistanceCollector](src/pie_modules/metrics/relation_argument_distance_collector.py) - [SpanCoverageCollector](src/pie_modules/metrics/span_coverage_collector.py) @@ -48,7 +31,7 @@ Document processing utilities: - [RelationArgumentSorter](src/pie_modules/document/processing/relation_argument_sorter.py) - [SentenceSplitter](src/pie_modules/document/processing/sentence_splitter.py) - [TextSpanTrimmer](src/pie_modules/document/processing/text_span_trimmer.py) -- [tokenize_document](src/pie_modules/document/processing/tokenization.py) +- [tokenization utils](src/pie_modules/document/processing/tokenization.py), e.g., `text_based_document_to_token_based` and `token_based_document_to_text_based` ## Setup diff --git a/poetry.lock b/poetry.lock index 911cda7bf..5fbd4d80e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,17 +1,5 @@ # This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. -[[package]] -name = "absl-py" -version = "1.4.0" -description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." -optional = false -python-versions = ">=3.6" -groups = ["main"] -files = [ - {file = "absl-py-1.4.0.tar.gz", hash = "sha256:d2c244d01048ba476e7c080bd2c6df5e141d211de80223460d5b3b8a2a58433d"}, - {file = "absl_py-1.4.0-py3-none-any.whl", hash = "sha256:0d3fe606adfa4f7db64792dd4c7aee4ee0c38ab75dfd353b7a83ed3e957fcb47"}, -] - [[package]] name = "accelerate" version = "0.32.1" @@ -44,151 +32,6 @@ test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "py test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"] testing = ["bitsandbytes", "datasets", "diffusers", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"] -[[package]] -name = "aiohttp" -version = "3.9.5" -description = "Async http client/server framework (asyncio)" -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fcde4c397f673fdec23e6b05ebf8d4751314fa7c24f93334bf1f1364c1c69ac7"}, - {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d6b3f1fabe465e819aed2c421a6743d8debbde79b6a8600739300630a01bf2c"}, - {file = "aiohttp-3.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae79c1bc12c34082d92bf9422764f799aee4746fd7a392db46b7fd357d4a17a"}, - {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d3ebb9e1316ec74277d19c5f482f98cc65a73ccd5430540d6d11682cd857430"}, - {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84dabd95154f43a2ea80deffec9cb44d2e301e38a0c9d331cc4aa0166fe28ae3"}, - {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a02fbeca6f63cb1f0475c799679057fc9268b77075ab7cf3f1c600e81dd46b"}, - {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72"}, - {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:714d4e5231fed4ba2762ed489b4aec07b2b9953cf4ee31e9871caac895a839c0"}, - {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7a6a8354f1b62e15d48e04350f13e726fa08b62c3d7b8401c0a1314f02e3558"}, - {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c413016880e03e69d166efb5a1a95d40f83d5a3a648d16486592c49ffb76d0db"}, - {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ff84aeb864e0fac81f676be9f4685f0527b660f1efdc40dcede3c251ef1e867f"}, - {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ad7f2919d7dac062f24d6f5fe95d401597fbb015a25771f85e692d043c9d7832"}, - {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:702e2c7c187c1a498a4e2b03155d52658fdd6fda882d3d7fbb891a5cf108bb10"}, - {file = "aiohttp-3.9.5-cp310-cp310-win32.whl", hash = "sha256:67c3119f5ddc7261d47163ed86d760ddf0e625cd6246b4ed852e82159617b5fb"}, - {file = "aiohttp-3.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:471f0ef53ccedec9995287f02caf0c068732f026455f07db3f01a46e49d76bbb"}, - {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ae53e33ee7476dd3d1132f932eeb39bf6125083820049d06edcdca4381f342"}, - {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c088c4d70d21f8ca5c0b8b5403fe84a7bc8e024161febdd4ef04575ef35d474d"}, - {file = "aiohttp-3.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:639d0042b7670222f33b0028de6b4e2fad6451462ce7df2af8aee37dcac55424"}, - {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f26383adb94da5e7fb388d441bf09c61e5e35f455a3217bfd790c6b6bc64b2ee"}, - {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66331d00fb28dc90aa606d9a54304af76b335ae204d1836f65797d6fe27f1ca2"}, - {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff550491f5492ab5ed3533e76b8567f4b37bd2995e780a1f46bca2024223233"}, - {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f22eb3a6c1080d862befa0a89c380b4dafce29dc6cd56083f630073d102eb595"}, - {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a81b1143d42b66ffc40a441379387076243ef7b51019204fd3ec36b9f69e77d6"}, - {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f64fd07515dad67f24b6ea4a66ae2876c01031de91c93075b8093f07c0a2d93d"}, - {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:93e22add827447d2e26d67c9ac0161756007f152fdc5210277d00a85f6c92323"}, - {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:55b39c8684a46e56ef8c8d24faf02de4a2b2ac60d26cee93bc595651ff545de9"}, - {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4715a9b778f4293b9f8ae7a0a7cef9829f02ff8d6277a39d7f40565c737d3771"}, - {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afc52b8d969eff14e069a710057d15ab9ac17cd4b6753042c407dcea0e40bf75"}, - {file = "aiohttp-3.9.5-cp311-cp311-win32.whl", hash = "sha256:b3df71da99c98534be076196791adca8819761f0bf6e08e07fd7da25127150d6"}, - {file = "aiohttp-3.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:88e311d98cc0bf45b62fc46c66753a83445f5ab20038bcc1b8a1cc05666f428a"}, - {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"}, - {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"}, - {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"}, - {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"}, - {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"}, - {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"}, - {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"}, - {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"}, - {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"}, - {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"}, - {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"}, - {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"}, - {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"}, - {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"}, - {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"}, - {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:694d828b5c41255e54bc2dddb51a9f5150b4eefa9886e38b52605a05d96566e8"}, - {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0605cc2c0088fcaae79f01c913a38611ad09ba68ff482402d3410bf59039bfb8"}, - {file = "aiohttp-3.9.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4558e5012ee03d2638c681e156461d37b7a113fe13970d438d95d10173d25f78"}, - {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dbc053ac75ccc63dc3a3cc547b98c7258ec35a215a92bd9f983e0aac95d3d5b"}, - {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4109adee842b90671f1b689901b948f347325045c15f46b39797ae1bf17019de"}, - {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6ea1a5b409a85477fd8e5ee6ad8f0e40bf2844c270955e09360418cfd09abac"}, - {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3c2890ca8c59ee683fd09adf32321a40fe1cf164e3387799efb2acebf090c11"}, - {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3916c8692dbd9d55c523374a3b8213e628424d19116ac4308e434dbf6d95bbdd"}, - {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8d1964eb7617907c792ca00b341b5ec3e01ae8c280825deadbbd678447b127e1"}, - {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d5ab8e1f6bee051a4bf6195e38a5c13e5e161cb7bad83d8854524798bd9fcd6e"}, - {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:52c27110f3862a1afbcb2af4281fc9fdc40327fa286c4625dfee247c3ba90156"}, - {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:7f64cbd44443e80094309875d4f9c71d0401e966d191c3d469cde4642bc2e031"}, - {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8b4f72fbb66279624bfe83fd5eb6aea0022dad8eec62b71e7bf63ee1caadeafe"}, - {file = "aiohttp-3.9.5-cp38-cp38-win32.whl", hash = "sha256:6380c039ec52866c06d69b5c7aad5478b24ed11696f0e72f6b807cfb261453da"}, - {file = "aiohttp-3.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:da22dab31d7180f8c3ac7c7635f3bcd53808f374f6aa333fe0b0b9e14b01f91a"}, - {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1732102949ff6087589408d76cd6dea656b93c896b011ecafff418c9661dc4ed"}, - {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c6021d296318cb6f9414b48e6a439a7f5d1f665464da507e8ff640848ee2a58a"}, - {file = "aiohttp-3.9.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:239f975589a944eeb1bad26b8b140a59a3a320067fb3cd10b75c3092405a1372"}, - {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b7b30258348082826d274504fbc7c849959f1989d86c29bc355107accec6cfb"}, - {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2adf5c87ff6d8b277814a28a535b59e20bfea40a101db6b3bdca7e9926bc24"}, - {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a3d838441bebcf5cf442700e3963f58b5c33f015341f9ea86dcd7d503c07e2"}, - {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3a1ae66e3d0c17cf65c08968a5ee3180c5a95920ec2731f53343fac9bad106"}, - {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c69e77370cce2d6df5d12b4e12bdcca60c47ba13d1cbbc8645dd005a20b738b"}, - {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf56238f4bbf49dab8c2dc2e6b1b68502b1e88d335bea59b3f5b9f4c001475"}, - {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d1469f228cd9ffddd396d9948b8c9cd8022b6d1bf1e40c6f25b0fb90b4f893ed"}, - {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:45731330e754f5811c314901cebdf19dd776a44b31927fa4b4dbecab9e457b0c"}, - {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3fcb4046d2904378e3aeea1df51f697b0467f2aac55d232c87ba162709478c46"}, - {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8cf142aa6c1a751fcb364158fd710b8a9be874b81889c2bd13aa8893197455e2"}, - {file = "aiohttp-3.9.5-cp39-cp39-win32.whl", hash = "sha256:7b179eea70833c8dee51ec42f3b4097bd6370892fa93f510f76762105568cf09"}, - {file = "aiohttp-3.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:38d80498e2e169bc61418ff36170e0aad0cd268da8b38a17c4cf29d254a8b3f1"}, - {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"}, -] - -[package.dependencies] -aiosignal = ">=1.1.2" -async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} -attrs = ">=17.3.0" -frozenlist = ">=1.1.1" -multidict = ">=4.5,<7.0" -yarl = ">=1.0,<2.0" - -[package.extras] -speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""] - -[[package]] -name = "aiosignal" -version = "1.3.1" -description = "aiosignal: a list of registered asynchronous callbacks" -optional = false -python-versions = ">=3.7" -groups = ["main"] -files = [ - {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, - {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"}, -] - -[package.dependencies] -frozenlist = ">=1.1.0" - -[[package]] -name = "async-timeout" -version = "4.0.3" -description = "Timeout context manager for asyncio programs" -optional = false -python-versions = ">=3.7" -groups = ["main"] -markers = "python_version < \"3.11\"" -files = [ - {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, - {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, -] - -[[package]] -name = "attrs" -version = "23.2.0" -description = "Classes Without Boilerplate" -optional = false -python-versions = ">=3.7" -groups = ["main"] -files = [ - {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, - {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, -] - -[package.extras] -cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] -dev = ["attrs[tests]", "pre-commit"] -docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] -tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-mypy = ["mypy (>=1.6) ; platform_python_implementation == \"CPython\" and python_version >= \"3.8\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.8\""] -tests-no-zope = ["attrs[tests-mypy]", "cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"] - [[package]] name = "beautifulsoup4" version = "4.12.3" @@ -760,93 +603,6 @@ ufo = ["fs (>=2.2.0,<3)"] unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""] woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"] -[[package]] -name = "frozenlist" -version = "1.4.1" -description = "A list-like structure which implements collections.abc.MutableSequence" -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, - {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, - {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, - {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, - {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, - {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, - {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, - {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, - {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, - {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, - {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, - {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, - {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, - {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, - {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, - {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, - {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, - {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, - {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, - {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, - {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, - {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, - {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, - {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, - {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, - {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, - {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, - {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, - {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, - {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, - {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, - {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, - {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, - {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, - {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, - {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, - {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, - {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, - {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, - {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, - {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, - {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, - {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, - {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, - {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, - {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, - {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, - {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, - {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, - {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, - {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, - {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, - {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, - {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, - {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, - {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, - {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, - {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, - {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, - {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, - {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, - {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, - {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, - {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, - {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, - {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, - {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, - {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, - {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, - {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, - {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, - {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, - {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, - {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, - {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, - {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, - {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, -] - [[package]] name = "fsspec" version = "2023.6.0" @@ -859,10 +615,6 @@ files = [ {file = "fsspec-2023.6.0.tar.gz", hash = "sha256:d0b2f935446169753e7a5c5c55681c54ea91996cc67be93c39a154fb3a2742af"}, ] -[package.dependencies] -aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""} -requests = {version = "*", optional = true, markers = "extra == \"http\""} - [package.extras] abfs = ["adlfs"] adl = ["adlfs"] @@ -1064,22 +816,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "intel-openmp" -version = "2021.4.0" -description = "Intel OpenMP* Runtime Library" -optional = false -python-versions = "*" -groups = ["main", "dev"] -markers = "platform_system == \"Windows\"" -files = [ - {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, -] - [[package]] name = "janome" version = "0.5.0" @@ -1263,28 +999,6 @@ files = [ [package.dependencies] six = "*" -[[package]] -name = "lightning-utilities" -version = "0.11.2" -description = "Lightning toolbox for across the our ecosystem." -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "lightning-utilities-0.11.2.tar.gz", hash = "sha256:adf4cf9c5d912fe505db4729e51d1369c6927f3a8ac55a9dff895ce5c0da08d9"}, - {file = "lightning_utilities-0.11.2-py3-none-any.whl", hash = "sha256:541f471ed94e18a28d72879338c8c52e873bb46f4c47644d89228faeb6751159"}, -] - -[package.dependencies] -packaging = ">=17.1" -setuptools = "*" -typing-extensions = "*" - -[package.extras] -cli = ["fire"] -docs = ["requests (>=2.0.0)"] -typing = ["mypy (>=1.0.0)", "types-setuptools"] - [[package]] name = "lxml" version = "5.2.2" @@ -1580,26 +1294,6 @@ python-dateutil = ">=2.7" [package.extras] dev = ["meson-python (>=0.13.1,<0.17.0)", "numpy (>=1.25)", "pybind11 (>=2.6,!=2.13.3)", "setuptools (>=64)", "setuptools_scm (>=7)"] -[[package]] -name = "mkl" -version = "2021.4.0" -description = "Intel® oneAPI Math Kernel Library" -optional = false -python-versions = "*" -groups = ["main", "dev"] -markers = "platform_system == \"Windows\"" -files = [ - {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"}, - {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"}, - {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"}, - {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"}, - {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"}, -] - -[package.dependencies] -intel-openmp = "==2021.*" -tbb = "==2021.*" - [[package]] name = "more-itertools" version = "10.3.0" @@ -1646,106 +1340,6 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""] tests = ["pytest (>=4.6)"] -[[package]] -name = "multidict" -version = "6.0.5" -description = "multidict implementation" -optional = false -python-versions = ">=3.7" -groups = ["main"] -files = [ - {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, - {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, - {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, - {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, - {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, - {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, - {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, - {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, - {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, - {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, - {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, - {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, - {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, - {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, - {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, - {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, - {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, - {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, - {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, - {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, - {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, - {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, - {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, - {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, - {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, - {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, - {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, - {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, - {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, - {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, - {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, - {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, - {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, - {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, - {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, - {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, - {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, - {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, - {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, - {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, - {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, - {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, -] - [[package]] name = "networkx" version = "3.2.1" @@ -1849,6 +1443,228 @@ files = [ {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"}, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.6.4.1" +description = "CUBLAS native runtime libraries" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb"}, + {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668"}, + {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.6.80" +description = "CUDA profiling tools runtime libs." +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:166ee35a3ff1587f2490364f90eeeb8da06cd867bd5b701bf7f9a02b78bc63fc"}, + {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.whl", hash = "sha256:358b4a1d35370353d52e12f0a7d1769fc01ff74a191689d3870b2123156184c4"}, + {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6768bad6cab4f19e8292125e5f1ac8aa7d1718704012a0e3272a6f61c4bce132"}, + {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73"}, + {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.6.77" +description = "NVRTC native runtime libraries" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5847f1d6e5b757f1d2b3991a01082a44aad6f10ab3c5c0213fa3e25bddc25a13"}, + {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53"}, + {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:f7007dbd914c56bd80ea31bc43e8e149da38f68158f423ba845fc3292684e45a"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.6.77" +description = "CUDA Runtime native Libraries" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6116fad3e049e04791c0256a9778c16237837c08b27ed8c8401e2e45de8d60cd"}, + {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d461264ecb429c84c8879a7153499ddc7b19b5f8d84c204307491989a365588e"}, + {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ba3b56a4f896141e25e19ab287cd71e52a6a0f4b29d0d31609f60e3b4d5219b7"}, + {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8"}, + {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f"}, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.5.1.17" +description = "cuDNN runtime libraries" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9fd4584468533c61873e5fda8ca41bac3a38bcb2d12350830c69b0a96a7e4def"}, + {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2"}, + {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.0.4" +description = "CUFFT native runtime libraries" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d16079550df460376455cba121db6564089176d9bac9e4f360493ca4741b22a6"}, + {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8510990de9f96c803a051822618d42bf6cb8f069ff3f48d93a8486efdacb48fb"}, + {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5"}, + {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca"}, + {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.11.1.6" +description = "cuFile GPUDirect libraries" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159"}, + {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:8f57a0051dcf2543f6dc2b98a98cb2719c37d3cee1baba8965d57f3bbc90d4db"}, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.7.77" +description = "CURAND native runtime libraries" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6e82df077060ea28e37f48a3ec442a8f47690c7499bff392a5938614b56c98d8"}, + {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf"}, + {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:99f1a32f1ac2bd134897fc7a203f779303261268a65762a623bf30cc9fe79117"}, + {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7b2ed8e95595c3591d984ea3603dd66fe6ce6812b886d59049988a712ed06b6e"}, + {file = "nvidia_curand_cu12-10.3.7.77-py3-none-win_amd64.whl", hash = "sha256:6d6d935ffba0f3d439b7cd968192ff068fafd9018dbf1b85b37261b13cfc9905"}, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.1.2" +description = "CUDA solver native runtime libraries" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0ce237ef60acde1efc457335a2ddadfd7610b892d94efee7b776c64bb1cac9e0"}, + {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c"}, + {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6cf28f17f64107a0c4d7802be5ff5537b2130bfc112f25d5a30df227058ca0e6"}, + {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e"}, + {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.4.2" +description = "CUSPARSE native runtime libraries" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d25b62fb18751758fe3c93a4a08eff08effedfe4edf1c6bb5afd0890fe88f887"}, + {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7aa32fa5470cf754f72d1116c7cbc300b4e638d3ae5304cfa4a638a5b87161b1"}, + {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73"}, + {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f"}, + {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.6.3" +description = "NVIDIA cuSPARSELt" +optional = false +python-versions = "*" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8371549623ba601a06322af2133c4a44350575f5a3108fb75f3ef20b822ad5f1"}, + {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46"}, + {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-win_amd64.whl", hash = "sha256:3b325bcbd9b754ba43df5a311488fca11a6b5dc3d11df4d190c000cf1a0765c7"}, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.26.2" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c196e95e832ad30fbbb50381eb3cbd1fadd5675e587a548563993609af19522"}, + {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6"}, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.6.85" +description = "Nvidia JIT LTO Library" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a"}, + {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41"}, + {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c"}, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.6.77" +description = "NVIDIA Tools Extension" +optional = false +python-versions = ">=3" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f44f8d86bb7d5629988d61c8d3ae61dddb2015dee142740536bc7481b022fe4b"}, + {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:adcaabb9d436c9761fca2b13959a2d237c5f9fd406c8e4b723c695409ff88059"}, + {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b90bed3df379fa79afbd21be8e04a0314336b8ae16768b58f2d34cb1d04cd7d2"}, + {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1"}, + {file = "nvidia_nvtx_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:2fb11a4af04a5e6c84073e6404d26588a34afd35379f0855a99797897efa75c0"}, +] + [[package]] name = "packaging" version = "24.0" @@ -2248,72 +2064,6 @@ files = [ [package.dependencies] six = ">=1.5" -[[package]] -name = "pytorch-crf" -version = "0.7.2" -description = "Conditional random field in PyTorch" -optional = false -python-versions = ">=3.6, <4" -groups = ["dev"] -files = [ - {file = "pytorch-crf-0.7.2.tar.gz", hash = "sha256:e6456e22ccfc99a3d4fe1e03e996103b1b39e9830bf3c7e12e7a9077d3be866d"}, - {file = "pytorch_crf-0.7.2-py3-none-any.whl", hash = "sha256:1b2d7d5eea3255f6e0cac09ab8b645472e76ff70d9333bc88762cf7317a4992d"}, -] - -[[package]] -name = "pytorch-ie" -version = "0.31.9" -description = "State-of-the-art Information Extraction in PyTorch" -optional = false -python-versions = "<4.0,>=3.9" -groups = ["main"] -files = [ - {file = "pytorch_ie-0.31.9-py3-none-any.whl", hash = "sha256:002eab323d529022e13a1ed1a7effc43e1bc172bcc11abe58c46501a8c37eb54"}, - {file = "pytorch_ie-0.31.9.tar.gz", hash = "sha256:bd516817ce759c059fcbe61c8d420367366c82014d3ba49938a61cb564102610"}, -] - -[package.dependencies] -absl-py = ">=1.0.0,<2.0.0" -fsspec = "<2023.9.0" -pandas = ">=2.0.0,<3.0.0" -pie-core = ">=0.2.0,<0.3.0" -pytorch-lightning = ">=2,<3" -torch = ">=1.10" -torchmetrics = ">=1,<2" -transformers = ">=4.18,<5.0" - -[[package]] -name = "pytorch-lightning" -version = "2.2.5" -description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate." -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "pytorch-lightning-2.2.5.tar.gz", hash = "sha256:8d06d0166e2204f82864f5d2b53a367c2c375d9cd5a7f6174434b2dffeaef7e9"}, - {file = "pytorch_lightning-2.2.5-py3-none-any.whl", hash = "sha256:67a7800863326914f68f6afd68f427855ef2315b4f00d554be8ea4c0f0557fd8"}, -] - -[package.dependencies] -fsspec = {version = ">=2022.5.0", extras = ["http"]} -lightning-utilities = ">=0.8.0" -numpy = ">=1.17.2" -packaging = ">=20.0" -PyYAML = ">=5.4" -torch = ">=1.13.0" -torchmetrics = ">=0.7.0" -tqdm = ">=4.57.0" -typing-extensions = ">=4.4.0" - -[package.extras] -all = ["bitsandbytes (==0.41.0)", "deepspeed (>=0.8.2,<=0.9.3) ; platform_system != \"Windows\"", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.27.7)", "lightning-utilities (>=0.8.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "requests (<2.32.0)", "rich (>=12.3.0)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.14.0)"] -deepspeed = ["deepspeed (>=0.8.2,<=0.9.3) ; platform_system != \"Windows\""] -dev = ["bitsandbytes (==0.41.0)", "cloudpickle (>=1.3)", "coverage (==7.3.1)", "deepspeed (>=0.8.2,<=0.9.3) ; platform_system != \"Windows\"", "fastapi", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.27.7)", "lightning-utilities (>=0.8.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "requests (<2.32.0)", "rich (>=12.3.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.14.0)", "uvicorn"] -examples = ["gym[classic-control] (>=0.17.0)", "ipython[all] (<8.15.0)", "lightning-utilities (>=0.8.0)", "requests (<2.32.0)", "torchmetrics (>=0.10.0)", "torchvision (>=0.14.0)"] -extra = ["bitsandbytes (==0.41.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.27.7)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)"] -strategies = ["deepspeed (>=0.8.2,<=0.9.3) ; platform_system != \"Windows\""] -test = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "fastapi", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn"] - [[package]] name = "pytorch-revgrad" version = "0.2.0" @@ -2413,7 +2163,7 @@ version = "2024.5.15" description = "Alternative regular expression module, to replace re." optional = false python-versions = ">=3.8" -groups = ["main", "dev"] +groups = ["dev"] files = [ {file = "regex-2024.5.15-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a81e3cfbae20378d75185171587cbf756015ccb14840702944f014e0d93ea09f"}, {file = "regex-2024.5.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b59138b219ffa8979013be7bc85bb60c6f7b7575df3d56dc1e403a438c7a3f6"}, @@ -2543,7 +2293,7 @@ version = "0.4.3" description = "" optional = false python-versions = ">=3.7" -groups = ["main", "dev"] +groups = ["dev"] files = [ {file = "safetensors-0.4.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:dcf5705cab159ce0130cd56057f5f3425023c407e170bca60b4868048bae64fd"}, {file = "safetensors-0.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bb4f8c5d0358a31e9a08daeebb68f5e161cdd4018855426d3f0c23bb51087055"}, @@ -2845,7 +2595,8 @@ version = "70.0.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\"" files = [ {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"}, {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"}, @@ -2918,18 +2669,21 @@ files = [ [[package]] name = "sympy" -version = "1.12.1" +version = "1.14.0" description = "Computer algebra system (CAS) in Python" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main", "dev"] files = [ - {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, - {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, + {file = "sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5"}, + {file = "sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517"}, ] [package.dependencies] -mpmath = ">=1.1.0,<1.4.0" +mpmath = ">=1.1.0,<1.4" + +[package.extras] +dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] [[package]] name = "tabulate" @@ -2946,21 +2700,6 @@ files = [ [package.extras] widechars = ["wcwidth"] -[[package]] -name = "tbb" -version = "2021.12.0" -description = "Intel® oneAPI Threading Building Blocks (oneTBB)" -optional = false -python-versions = "*" -groups = ["main", "dev"] -markers = "platform_system == \"Windows\"" -files = [ - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, - {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, - {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, -] - [[package]] name = "threadpoolctl" version = "3.5.0" @@ -2979,7 +2718,7 @@ version = "0.15.2" description = "" optional = false python-versions = ">=3.7" -groups = ["main", "dev"] +groups = ["dev"] files = [ {file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"}, {file = "tokenizers-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:054c1cc9c6d68f7ffa4e810b3d5131e0ba511b6e4be34157aa08ee54c2f8d9ee"}, @@ -3116,71 +2855,65 @@ files = [ [[package]] name = "torch" -version = "2.3.0+cpu" +version = "2.7.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false -python-versions = ">=3.8.0" +python-versions = ">=3.9.0" groups = ["main", "dev"] files = [ - {file = "torch-2.3.0+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:e3c220702d82c7596924150e0499fbbffcf62a88a59adc860fa357cd8dc1c302"}, - {file = "torch-2.3.0+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:ab0c05525195b8fecdf2ea75968ed32ccd87dff16381b6e13249babb4a9596ff"}, - {file = "torch-2.3.0+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:97a38b25ee0e3d020691e7846efbca62a3d8a57645c027dcb5ba0adfec36fe55"}, - {file = "torch-2.3.0+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:a8ac195974be6f067245bae8156b8c06fb0a723b0eed8f2e244b5dd58c7e2a49"}, - {file = "torch-2.3.0+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:a8982e52185771591dad577a124a7770f72f288f8ae5833317b1e329c0d2f07e"}, - {file = "torch-2.3.0+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:483131a7997995d867313ee902743084e844e830ab2a0c5e079c61ec2da3cd17"}, - {file = "torch-2.3.0+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:8c52484880d5fbe511cffc255dd34847ddeced3f94334c6bf7eb2b0445f10cb4"}, - {file = "torch-2.3.0+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:28a11bcc0d709b397d675cff689707019b8cc122e6bf328b57b900f47c36f156"}, - {file = "torch-2.3.0+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:1e86e225e472392440ace378ba3165b5e87648e8b5fbf16adc41c0df881c38b8"}, - {file = "torch-2.3.0+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:5c2afdff80203eaabf4c223a294c2f465020b3360e8e87f76b52ace9c5801ebe"}, + {file = "torch-2.7.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a103b5d782af5bd119b81dbcc7ffc6fa09904c423ff8db397a1e6ea8fd71508f"}, + {file = "torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:fe955951bdf32d182ee8ead6c3186ad54781492bf03d547d31771a01b3d6fb7d"}, + {file = "torch-2.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:885453d6fba67d9991132143bf7fa06b79b24352f4506fd4d10b309f53454162"}, + {file = "torch-2.7.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:d72acfdb86cee2a32c0ce0101606f3758f0d8bb5f8f31e7920dc2809e963aa7c"}, + {file = "torch-2.7.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:236f501f2e383f1cb861337bdf057712182f910f10aeaf509065d54d339e49b2"}, + {file = "torch-2.7.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:06eea61f859436622e78dd0cdd51dbc8f8c6d76917a9cf0555a333f9eac31ec1"}, + {file = "torch-2.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:8273145a2e0a3c6f9fd2ac36762d6ee89c26d430e612b95a99885df083b04e52"}, + {file = "torch-2.7.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:aea4fc1bf433d12843eb2c6b2204861f43d8364597697074c8d38ae2507f8730"}, + {file = "torch-2.7.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:27ea1e518df4c9de73af7e8a720770f3628e7f667280bce2be7a16292697e3fa"}, + {file = "torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:c33360cfc2edd976c2633b3b66c769bdcbbf0e0b6550606d188431c81e7dd1fc"}, + {file = "torch-2.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:d8bf6e1856ddd1807e79dc57e54d3335f2b62e6f316ed13ed3ecfe1fc1df3d8b"}, + {file = "torch-2.7.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:787687087412c4bd68d315e39bc1223f08aae1d16a9e9771d95eabbb04ae98fb"}, + {file = "torch-2.7.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:03563603d931e70722dce0e11999d53aa80a375a3d78e6b39b9f6805ea0a8d28"}, + {file = "torch-2.7.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:d632f5417b6980f61404a125b999ca6ebd0b8b4bbdbb5fbbba44374ab619a412"}, + {file = "torch-2.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:23660443e13995ee93e3d844786701ea4ca69f337027b05182f5ba053ce43b38"}, + {file = "torch-2.7.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:0da4f4dba9f65d0d203794e619fe7ca3247a55ffdcbd17ae8fb83c8b2dc9b585"}, + {file = "torch-2.7.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:e08d7e6f21a617fe38eeb46dd2213ded43f27c072e9165dc27300c9ef9570934"}, + {file = "torch-2.7.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:30207f672328a42df4f2174b8f426f354b2baa0b7cca3a0adb3d6ab5daf00dc8"}, + {file = "torch-2.7.1-cp313-cp313t-win_amd64.whl", hash = "sha256:79042feca1c634aaf6603fe6feea8c6b30dfa140a6bbc0b973e2260c7e79a22e"}, + {file = "torch-2.7.1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:988b0cbc4333618a1056d2ebad9eb10089637b659eb645434d0809d8d937b946"}, + {file = "torch-2.7.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:e0d81e9a12764b6f3879a866607c8ae93113cbcad57ce01ebde63eb48a576369"}, + {file = "torch-2.7.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8394833c44484547ed4a47162318337b88c97acdb3273d85ea06e03ffff44998"}, + {file = "torch-2.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:df41989d9300e6e3c19ec9f56f856187a6ef060c3662fe54f4b6baf1fc90bd19"}, + {file = "torch-2.7.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:a737b5edd1c44a5c1ece2e9f3d00df9d1b3fb9541138bee56d83d38293fb6c9d"}, ] [package.dependencies] filelock = "*" fsspec = "*" jinja2 = "*" -mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""} networkx = "*" -sympy = "*" -typing-extensions = ">=4.8.0" +nvidia-cublas-cu12 = {version = "12.6.4.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.6.80", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.6.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.6.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "9.5.1.17", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.3.0.4", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufile-cu12 = {version = "1.11.1.6", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.7.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.7.1.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.5.4.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparselt-cu12 = {version = "0.6.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.26.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvjitlink-cu12 = {version = "12.6.85", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.6.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +setuptools = {version = "*", markers = "python_version >= \"3.12\""} +sympy = ">=1.13.3" +triton = {version = "3.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +typing-extensions = ">=4.10.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.9.1)"] - -[package.source] -type = "legacy" -url = "https://download.pytorch.org/whl/cpu" -reference = "pytorch" - -[[package]] -name = "torchmetrics" -version = "1.4.0.post0" -description = "PyTorch native Metrics" -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "torchmetrics-1.4.0.post0-py3-none-any.whl", hash = "sha256:ab234216598e3fbd8d62ee4541a0e74e7e8fc935d099683af5b8da50f745b3c8"}, - {file = "torchmetrics-1.4.0.post0.tar.gz", hash = "sha256:ab9bcfe80e65dbabbddb6cecd9be21f1f1d5207bb74051ef95260740f2762358"}, -] - -[package.dependencies] -lightning-utilities = ">=0.8.0" -numpy = ">1.20.0" -packaging = ">17.1" -torch = ">=1.10.0" - -[package.extras] -all = ["SciencePlots (>=2.0.0)", "ipadic (>=1.0.0)", "matplotlib (>=3.3.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.9.0)", "nltk (>=3.6)", "piq (<=0.8.0)", "pretty-errors (>=1.2.0)", "pycocotools (>2.0.0)", "pystoi (>=0.3.0)", "regex (>=2021.9.24)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.3.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=0.10.0)", "torchvision (>=0.8)", "tqdm (>=4.41.0)", "transformers (>4.4.0)", "transformers (>=4.10.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] -audio = ["pystoi (>=0.3.0)", "torchaudio (>=0.10.0)"] -debug = ["pretty-errors (>=1.2.0)"] -detection = ["pycocotools (>2.0.0)", "torchvision (>=0.8)"] -dev = ["SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (<=0.7.5)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.3.3)", "huggingface-hub (<0.23)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "lpips (<=0.1.4)", "matplotlib (>=3.3.0)", "mecab-ko (>=1.0.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.0)", "mypy (==1.9.0)", "netcal (>1.0.0)", "nltk (>=3.6)", "numpy (<1.27.0)", "pandas (>1.0.0)", "pandas (>=1.4.0)", "piq (<=0.8.0)", "pretty-errors (>=1.2.0)", "pycocotools (>2.0.0)", "pystoi (>=0.3.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.3.0)", "torch-complex (<=0.4.3)", "torch-fidelity (<=0.4.0)", "torchaudio (>=0.10.0)", "torchvision (>=0.8)", "tqdm (>=4.41.0)", "transformers (>4.4.0)", "transformers (>=4.10.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] -image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.8)"] -multimodal = ["piq (<=0.8.0)", "transformers (>=4.10.0)"] -text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>=3.6)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (>=4.41.0)", "transformers (>4.4.0)"] -typing = ["mypy (==1.9.0)", "torch (==2.3.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"] -visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.3.0)"] +optree = ["optree (>=0.13.0)"] [[package]] name = "tqdm" @@ -3226,7 +2959,7 @@ version = "4.36.2" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" -groups = ["main", "dev"] +groups = ["dev"] files = [ {file = "transformers-4.36.2-py3-none-any.whl", hash = "sha256:462066c4f74ee52516f12890dcc9ec71d1a5e97998db621668455117a54330f6"}, {file = "transformers-4.36.2.tar.gz", hash = "sha256:d8068e897e47793281501e547d2bbdfc5b8556409c2cb6c3d9e2ca77d4c0b4ec"}, @@ -3293,6 +3026,31 @@ torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] +[[package]] +name = "triton" +version = "3.3.1" +description = "A language and compiler for custom Deep Learning operations" +optional = false +python-versions = "*" +groups = ["main", "dev"] +markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +files = [ + {file = "triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b74db445b1c562844d3cfad6e9679c72e93fdfb1a90a24052b03bb5c49d1242e"}, + {file = "triton-3.3.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b31e3aa26f8cb3cc5bf4e187bf737cbacf17311e1112b781d4a059353dfd731b"}, + {file = "triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9999e83aba21e1a78c1f36f21bce621b77bcaa530277a50484a7cb4a822f6e43"}, + {file = "triton-3.3.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b89d846b5a4198317fec27a5d3a609ea96b6d557ff44b56c23176546023c4240"}, + {file = "triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42"}, + {file = "triton-3.3.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f6139aeb04a146b0b8e0fbbd89ad1e65861c57cfed881f21d62d3cb94a36bab7"}, +] + +[package.dependencies] +setuptools = ">=40.8.0" + +[package.extras] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "isort", "llnl-hatchet", "numpy", "pytest", "pytest-forked", "pytest-xdist", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + [[package]] name = "typing-extensions" version = "4.12.0" @@ -3462,110 +3220,6 @@ files = [ {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"}, ] -[[package]] -name = "yarl" -version = "1.9.4" -description = "Yet another URL library" -optional = false -python-versions = ">=3.7" -groups = ["main"] -files = [ - {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, - {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, - {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, - {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, - {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, - {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, - {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, - {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, - {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, - {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, - {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, - {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, - {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, - {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, - {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, - {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, - {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, - {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, - {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, - {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, - {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, - {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, - {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, - {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, - {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, - {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, - {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, - {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, - {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, - {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, - {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, - {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, - {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, - {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, - {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, - {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, - {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, - {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, - {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, - {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, - {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, - {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, - {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, - {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, - {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, - {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, - {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, - {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, - {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, - {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, - {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, - {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, - {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, - {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, - {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, - {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, - {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, - {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, - {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, - {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, - {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, - {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, - {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, - {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, - {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, - {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, - {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, - {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, - {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, - {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, - {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, - {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, - {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, - {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, - {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, - {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, - {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, - {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, - {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, - {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, - {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, - {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, - {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, - {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, - {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, - {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, - {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, - {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, - {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, - {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, -] - -[package.dependencies] -idna = ">=2.0" -multidict = ">=4.0" - [[package]] name = "zipp" version = "3.19.2" @@ -3586,4 +3240,4 @@ test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.funct [metadata] lock-version = "2.1" python-versions = "^3.9" -content-hash = "360e7128e1296a81a16070db02a7145bf9e807f3795bcfe56f4e1e453c976f6b" +content-hash = "084124893c8d1054d2f99e0c3a8faa3346695641ecca8e97f2af90d38f842a84" diff --git a/pyproject.toml b/pyproject.toml index 9af0c843c..1b27be77c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "pie-modules" version = "0.15.9" -description = "Model and Taskmodule implementations for PyTorch-IE" +description = "Utility modules for Python-IE" authors = ["Arne Binder "] readme = "README.md" homepage = "https://github.com/arnebinder/pie-modules" @@ -24,22 +24,15 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" -# TODO: remove and use pie-core instead -pytorch-ie = ">=0.31.9,<0.32.0" -pytorch-lightning = "^2.1.0" -torchmetrics = "^1" -# >=4.35 because of BartModelWithDecoderPositionIds, <4.37 because of generation config -# created from model config in BartAsPointerNetwork -transformers = ">=4.35.0,<4.37.0" +pie-core = ">=0.2.0,<0.3.0" +# for show_as_markdown in metrics +pandas = ">=2.0.3,<3.0.0" [tool.poetry.group.dev.dependencies] -torch = {version = "^2.1.0+cpu", source = "pytorch"} pytest = "^7.4.2" pytest-cov = "^4.1.0" pre-commit = "^3.4.0" tabulate = "^0.9" -# for TokenClassificationModelWithSeq2SeqEncoderAndCrf -pytorch-crf = ">=0.7.2" # for rouge metric (tests only) and for NltkSentenceSplitter nltk = "^3.8.1" # for NltkSentenceSplitter @@ -60,7 +53,6 @@ name = "pre-release" url = "https://test.pypi.org/simple/" priority = "explicit" - [tool.pytest.ini_options] addopts = [ "--color=yes", diff --git a/src/pie_modules/annotations.py b/src/pie_modules/annotations.py index d7710595f..dfdbb586a 100644 --- a/src/pie_modules/annotations.py +++ b/src/pie_modules/annotations.py @@ -1,22 +1,192 @@ import dataclasses -from typing import Optional +from dataclasses import dataclass, field +from typing import Any, Optional, Tuple from pie_core import Annotation -# re-export all annotations from pytorch_ie to have a single entry point -from pytorch_ie.annotations import ( - BinaryRelation, - Label, - LabeledMultiSpan, - LabeledSpan, - MultiLabel, - MultiLabeledBinaryRelation, - MultiLabeledSpan, - MultiSpan, - NaryRelation, - Span, - _post_init_single_label, -) + +def _post_init_single_label(self): + if not isinstance(self.label, str): + raise ValueError("label must be a single string.") + + if not isinstance(self.score, float): + raise ValueError("score must be a single float.") + + +def _post_init_multi_label(self): + if self.score is None: + score = tuple([1.0] * len(self.label)) + object.__setattr__(self, "score", score) + + if not isinstance(self.label, tuple): + object.__setattr__(self, "label", tuple(self.label)) + + if not isinstance(self.score, tuple): + object.__setattr__(self, "score", tuple(self.score)) + + if len(self.label) != len(self.score): + raise ValueError( + f"Number of labels ({len(self.label)}) and scores ({len(self.score)}) must be equal." + ) + + +def _post_init_multi_span(self): + if isinstance(self.slices, list): + object.__setattr__(self, "slices", tuple(tuple(s) for s in self.slices)) + + +def _post_init_arguments_and_roles(self): + if len(self.arguments) != len(self.roles): + raise ValueError( + f"Number of arguments ({len(self.arguments)}) and roles ({len(self.roles)}) must be equal" + ) + if not isinstance(self.arguments, tuple): + object.__setattr__(self, "arguments", tuple(self.arguments)) + if not isinstance(self.roles, tuple): + object.__setattr__(self, "roles", tuple(self.roles)) + + +@dataclass(eq=True, frozen=True) +class Label(Annotation): + label: str + score: float = field(default=1.0, compare=False) + + def __post_init__(self) -> None: + _post_init_single_label(self) + + def resolve(self) -> Any: + return self.label + + +@dataclass(eq=True, frozen=True) +class MultiLabel(Annotation): + label: Tuple[str, ...] + score: Optional[Tuple[float, ...]] = field(default=None, compare=False) + + def __post_init__(self) -> None: + _post_init_multi_label(self) + + def resolve(self) -> Any: + return self.label + + +@dataclass(eq=True, frozen=True) +class Span(Annotation): + start: int + end: int + + def __str__(self) -> str: + if not self.is_attached: + return super().__str__() + return str(self.target[self.start : self.end]) + + def resolve(self) -> Any: + if self.is_attached: + return self.target[self.start : self.end] + else: + raise ValueError(f"{self} is not attached to a target.") + + +@dataclass(eq=True, frozen=True) +class LabeledSpan(Span): + label: str + score: float = field(default=1.0, compare=False) + + def __post_init__(self) -> None: + _post_init_single_label(self) + + def resolve(self) -> Any: + return self.label, super().resolve() + + +@dataclass(eq=True, frozen=True) +class MultiLabeledSpan(Span): + label: Tuple[str, ...] + score: Optional[Tuple[float, ...]] = field(default=None, compare=False) + + def __post_init__(self) -> None: + _post_init_multi_label(self) + + def resolve(self) -> Any: + return self.label, super().resolve() + + +@dataclass(eq=True, frozen=True) +class MultiSpan(Annotation): + slices: Tuple[Tuple[int, int], ...] + + def __post_init__(self) -> None: + _post_init_multi_span(self) + + def __str__(self) -> str: + if not self.is_attached: + return super().__str__() + return str(tuple(self.target[start:end] for start, end in self.slices)) + + def resolve(self) -> Any: + if self.is_attached: + return tuple(self.target[start:end] for start, end in self.slices) + else: + raise ValueError(f"{self} is not attached to a target.") + + +@dataclass(eq=True, frozen=True) +class LabeledMultiSpan(MultiSpan): + label: str + score: float = field(default=1.0, compare=False) + + def __post_init__(self) -> None: + super().__post_init__() + _post_init_single_label(self) + + def resolve(self) -> Any: + return self.label, super().resolve() + + +@dataclass(eq=True, frozen=True) +class BinaryRelation(Annotation): + head: Annotation + tail: Annotation + label: str + score: float = field(default=1.0, compare=False) + + def __post_init__(self) -> None: + _post_init_single_label(self) + + def resolve(self) -> Any: + return self.label, (self.head.resolve(), self.tail.resolve()) + + +@dataclass(eq=True, frozen=True) +class MultiLabeledBinaryRelation(Annotation): + head: Annotation + tail: Annotation + label: Tuple[str, ...] + score: Optional[Tuple[float, ...]] = field(default=None, compare=False) + + def __post_init__(self) -> None: + _post_init_multi_label(self) + + def resolve(self) -> Any: + return self.label, (self.head.resolve(), self.tail.resolve()) + + +@dataclass(eq=True, frozen=True) +class NaryRelation(Annotation): + arguments: Tuple[Annotation, ...] + roles: Tuple[str, ...] + label: str + score: float = field(default=1.0, compare=False) + + def __post_init__(self) -> None: + _post_init_arguments_and_roles(self) + _post_init_single_label(self) + + def resolve(self) -> Any: + return ( + self.label, + tuple((role, arg.resolve()) for arg, role in zip(self.arguments, self.roles)), + ) @dataclasses.dataclass(eq=True, frozen=True) diff --git a/src/pie_modules/document/processing/__init__.py b/src/pie_modules/document/processing/__init__.py index 296944c63..28d15ba1a 100644 --- a/src/pie_modules/document/processing/__init__.py +++ b/src/pie_modules/document/processing/__init__.py @@ -7,5 +7,4 @@ from .tokenization import ( text_based_document_to_token_based, token_based_document_to_text_based, - tokenize_document, ) diff --git a/src/pie_modules/document/processing/text_span_trimmer.py b/src/pie_modules/document/processing/text_span_trimmer.py index 412e01110..eb27a58f4 100644 --- a/src/pie_modules/document/processing/text_span_trimmer.py +++ b/src/pie_modules/document/processing/text_span_trimmer.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TypeVar +from typing import Any, Dict, TypeVar from pie_core import AnnotationLayer, Document @@ -45,6 +45,7 @@ def trim_text_spans( text = spans.target + original_kwargs: dict[str, Any] for span in spans: if isinstance(span, Span): starts_and_ends = [(span.start, span.end)] @@ -99,6 +100,7 @@ def trim_text_spans( ) removed_span_ids.append(span._id) continue + new_kwargs: dict[str, Any] if isinstance(span, Span): if not len(new_starts_and_ends) == 1: raise ValueError(f"Expected one span, got {len(new_starts_and_ends)}") diff --git a/src/pie_modules/document/processing/tokenization.py b/src/pie_modules/document/processing/tokenization.py index 12a8a6be3..c45b6ebc8 100644 --- a/src/pie_modules/document/processing/tokenization.py +++ b/src/pie_modules/document/processing/tokenization.py @@ -1,8 +1,6 @@ -import functools -import json import logging from collections import defaultdict -from copy import copy, deepcopy +from copy import deepcopy from typing import ( Callable, Dict, @@ -18,7 +16,6 @@ from pie_core import Annotation from pie_core.utils.hydra import resolve_type -from transformers import PreTrainedTokenizer from pie_modules.annotations import MultiSpan, Span from pie_modules.documents import TextBasedDocument, TokenBasedDocument @@ -105,13 +102,13 @@ def char_span_to_token_span( f"The first target of a text targeting span must be a string, but found {type(base_text)} as first " f"target type. Can not convert the span {span}." ) - stripped_slices = [ - get_stripped_offsets(start, end, base_text) for start, end in span.slices - ] + stripped_slices = tuple( + [get_stripped_offsets(start, end, base_text) for start, end in span.slices] + ) else: stripped_slices = span.slices # remove empty and invalid slices - stripped_slices = [(start, end) for start, end in stripped_slices if start < end] + stripped_slices = tuple([(start, end) for start, end in stripped_slices if start < end]) if len(stripped_slices) == 0: return None slices_inclusive_end = [ @@ -449,123 +446,3 @@ def token_based_document_to_text_based( added_annotations.setdefault(layer_name, {}).update(annotation_mapping) return result - - -def tokenize_document( - doc: TextBasedDocument, - tokenizer: PreTrainedTokenizer, - result_document_type: Type[ToD], - partition_layer: Optional[str] = None, - strip_spans: bool = False, - strict_span_conversion: bool = True, - added_annotations: Optional[List[Dict[str, Dict[Annotation, Annotation]]]] = None, - verbose: bool = True, - **tokenize_kwargs, -) -> List[ToD]: - """Tokenize a document with a given tokenizer and return a list of token based documents. The - document is tokenized in partitions if a partition layer is provided. The annotations that - target the text are converted to target the tokens and also all dependent annotations are - converted. - - Args: - doc (TextBasedDocument): The document to tokenize. - tokenizer (PreTrainedTokenizer): The tokenizer. - result_document_type (Type[ToD]): The exact type of the token based documents. - partition_layer (Optional[str], optional): The layer to use for partitioning the document. If None, the whole - document is tokenized. Defaults to None. - strip_spans (bool, optional): If True, strip the whitespace from the character spans before converting them to - token spans. Defaults to False. - strict_span_conversion (bool, optional): If True, raise an error if not all annotations can be converted to - token based documents. Defaults to True. - added_annotations (Optional[List[Dict[str, Dict[Annotation, Annotation]]]], optional): Pass an empty list to - collect the added annotations. Defaults to None. - verbose (bool, optional): If True, log warnings if annotations can not be converted. Defaults to True. - - Returns: - List[ToD]: The token based documents of type result_document_type with the converted annotations. - """ - - added_annotation_lists: Dict[str, List[Annotation]] = defaultdict(list) - result = [] - partitions: Iterable[Span] - if partition_layer is None: - partitions = [Span(start=0, end=len(doc.text))] - else: - partitions = doc[partition_layer] - for partition in partitions: - text = doc.text[partition.start : partition.end] - current_tokenize_kwargs = copy(tokenize_kwargs) - if "text" in tokenize_kwargs: - current_tokenize_kwargs["text_pair"] = text - sequence_index = 1 - else: - current_tokenize_kwargs["text"] = text - sequence_index = 0 - tokenized_text = tokenizer(**current_tokenize_kwargs) - for batch_encoding in tokenized_text.encodings: - token_offset_mapping = batch_encoding.offsets - char_to_token: Optional[Callable[[int], Optional[int]]] - char_to_token = functools.partial( - batch_encoding.char_to_token, sequence_index=sequence_index - ) - token_offset_mapping = [ - offsets if s_id == sequence_index else (0, 0) - for s_id, offsets in zip(batch_encoding.sequence_ids, token_offset_mapping) - ] - if partition.start > 0: - token_offset_mapping = [ - (start + partition.start, end + partition.start) - for start, end in token_offset_mapping - ] - char_to_token = None - current_added_annotations: Dict[str, Dict[Annotation, Annotation]] = defaultdict(dict) - tokenized_document = text_based_document_to_token_based( - doc, - tokens=batch_encoding.tokens, - result_document_type=result_document_type, - token_offset_mapping=token_offset_mapping, - char_to_token=char_to_token, - strict_span_conversion=False, - strip_spans=strip_spans, - verbose=False, - added_annotations=current_added_annotations, - ) - tokenized_document.metadata["tokenizer_encoding"] = batch_encoding - result.append(tokenized_document) - for k, v in current_added_annotations.items(): - added_annotation_lists[k].extend(v) - if added_annotations is not None: - added_annotations.append(current_added_annotations) - - missed_annotations = defaultdict(set) - if strict_span_conversion or verbose: - # We check the annotations with respect to the layers of the result_document_type. - # Note that the original document may have more layers, but since result documents - # are of type result_document_type, we only check the layers of this type. - for annotation_field in result_document_type.annotation_fields(): - # do not check the partition layer because the partitions are not required later on - # and entries get quite probably removed when windowing is applied, so this just pollutes the logs - if annotation_field.name != partition_layer: - current_missed_annotations = set(doc[annotation_field.name]) - set( - added_annotation_lists[annotation_field.name] - ) - if len(current_missed_annotations) > 0: - missed_annotations[annotation_field.name] = current_missed_annotations - - if len(missed_annotations) > 0: - missed_annotations_simplified = {k: str(v) for k, v in missed_annotations.items()} - if strict_span_conversion: - raise ValueError( - f"could not convert all annotations from document with id={doc.id} to token based documents, " - f"but strict_span_conversion is True, so raise an error, " - f"missed annotations:\n{json.dumps(missed_annotations_simplified, sort_keys=True, indent=2)}" - ) - else: - if verbose: - logger.warning( - f"could not convert all annotations from document with id={doc.id} to token based documents, " - f"missed annotations (disable this message with verbose=False):\n" - f"{json.dumps(missed_annotations_simplified, sort_keys=True, indent=2)}" - ) - - return result diff --git a/src/pie_modules/documents.py b/src/pie_modules/documents.py index bf07de932..326621cdb 100644 --- a/src/pie_modules/documents.py +++ b/src/pie_modules/documents.py @@ -1,29 +1,8 @@ import dataclasses +from typing import Any, Dict, Optional, Tuple -from pie_core import AnnotationLayer, annotation_field - -# re-export all documents from pytorch_ie to have a single entry point -from pytorch_ie.documents import ( - TextBasedDocument, - TextDocumentWithLabel, - TextDocumentWithLabeledMultiSpans, - TextDocumentWithLabeledMultiSpansAndBinaryRelations, - TextDocumentWithLabeledMultiSpansAndLabeledPartitions, - TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions, - TextDocumentWithLabeledPartitions, - TextDocumentWithLabeledSpans, - TextDocumentWithLabeledSpansAndBinaryRelations, - TextDocumentWithLabeledSpansAndLabeledPartitions, - TextDocumentWithLabeledSpansAndSentences, - TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, - TextDocumentWithMultiLabel, - TextDocumentWithSentences, - TextDocumentWithSpans, - TextDocumentWithSpansAndBinaryRelations, - TextDocumentWithSpansAndLabeledPartitions, - TextDocumentWithSpansBinaryRelationsAndLabeledPartitions, - TokenBasedDocument, -) +from pie_core import AnnotationLayer, Document, annotation_field +from typing_extensions import TypeAlias from pie_modules.annotations import ( AbstractiveSummary, @@ -31,12 +10,172 @@ BinaryRelation, ExtractiveAnswer, GenerativeAnswer, + Label, LabeledMultiSpan, LabeledSpan, + MultiLabel, Question, + Span, ) +@dataclasses.dataclass +class WithMetadata: + id: Optional[str] = None + metadata: Dict[str, Any] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class WithTokens: + tokens: Tuple[str, ...] + + +@dataclasses.dataclass +class WithText: + text: str + + +@dataclasses.dataclass +class TextBasedDocument(WithMetadata, WithText, Document): + pass + + +@dataclasses.dataclass +class TokenBasedDocument(WithMetadata, WithTokens, Document): + def __post_init__(self) -> None: + + # When used in a dataset, the document gets serialized to json like structure which does not know tuples, + # so they get converted to lists. This is a workaround to automatically convert the "tokens" back to tuples + # when the document is created from a dataset. + if isinstance(self.tokens, list): + object.__setattr__(self, "tokens", tuple(self.tokens)) + elif not isinstance(self.tokens, tuple): + raise ValueError("tokens must be a tuple.") + + # Call the default document construction code + super().__post_init__() + + +# backwards compatibility +TextDocument: TypeAlias = TextBasedDocument + + +@dataclasses.dataclass +class DocumentWithLabel(Document): + label: AnnotationLayer[Label] = annotation_field() + + +@dataclasses.dataclass +class DocumentWithMultiLabel(Document): + label: AnnotationLayer[MultiLabel] = annotation_field() + + +@dataclasses.dataclass +class TextDocumentWithLabel(DocumentWithLabel, TextBasedDocument): + pass + + +@dataclasses.dataclass +class TextDocumentWithMultiLabel(DocumentWithMultiLabel, TextBasedDocument): + pass + + +@dataclasses.dataclass +class TextDocumentWithLabeledPartitions(TextBasedDocument): + labeled_partitions: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + + +@dataclasses.dataclass +class TextDocumentWithSentences(TextBasedDocument): + sentences: AnnotationLayer[Span] = annotation_field(target="text") + + +@dataclasses.dataclass +class TextDocumentWithSpans(TextBasedDocument): + spans: AnnotationLayer[Span] = annotation_field(target="text") + + +@dataclasses.dataclass +class TextDocumentWithLabeledSpans(TextBasedDocument): + labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + + +@dataclasses.dataclass +class TextDocumentWithLabeledSpansAndLabeledPartitions( + TextDocumentWithLabeledSpans, TextDocumentWithLabeledPartitions +): + pass + + +@dataclasses.dataclass +class TextDocumentWithLabeledSpansAndSentences( + TextDocumentWithLabeledSpans, TextDocumentWithSentences +): + pass + + +@dataclasses.dataclass +class TextDocumentWithLabeledSpansAndBinaryRelations(TextDocumentWithLabeledSpans): + binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="labeled_spans") + + +@dataclasses.dataclass +class TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( + TextDocumentWithLabeledSpansAndLabeledPartitions, + TextDocumentWithLabeledSpansAndBinaryRelations, + TextDocumentWithLabeledPartitions, +): + pass + + +@dataclasses.dataclass +class TextDocumentWithSpansAndBinaryRelations(TextDocumentWithSpans): + binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="spans") + + +@dataclasses.dataclass +class TextDocumentWithSpansAndLabeledPartitions( + TextDocumentWithSpans, TextDocumentWithLabeledPartitions +): + pass + + +@dataclasses.dataclass +class TextDocumentWithSpansBinaryRelationsAndLabeledPartitions( + TextDocumentWithSpansAndLabeledPartitions, + TextDocumentWithSpansAndBinaryRelations, + TextDocumentWithLabeledPartitions, +): + pass + + +@dataclasses.dataclass +class TextDocumentWithLabeledMultiSpans(TextBasedDocument): + labeled_multi_spans: AnnotationLayer[LabeledMultiSpan] = annotation_field(target="text") + + +@dataclasses.dataclass +class TextDocumentWithLabeledMultiSpansAndLabeledPartitions( + TextDocumentWithLabeledMultiSpans, TextDocumentWithLabeledPartitions +): + pass + + +@dataclasses.dataclass +class TextDocumentWithLabeledMultiSpansAndBinaryRelations(TextDocumentWithLabeledMultiSpans): + binary_relations: AnnotationLayer[BinaryRelation] = annotation_field( + target="labeled_multi_spans" + ) + + +@dataclasses.dataclass +class TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions( + TextDocumentWithLabeledMultiSpansAndLabeledPartitions, + TextDocumentWithLabeledMultiSpansAndBinaryRelations, +): + pass + + @dataclasses.dataclass class TextDocumentWithQuestionsAndExtractiveAnswers(TextBasedDocument): """A text based PIE document with annotations for extractive question answering.""" @@ -66,7 +205,6 @@ class TokenDocumentWithQuestionsAndExtractiveAnswers(TokenBasedDocument): # backwards compatibility ExtractiveQADocument = TextDocumentWithQuestionsAndExtractiveAnswers TokenizedExtractiveQADocument = TokenDocumentWithQuestionsAndExtractiveAnswers -TextDocument = TextBasedDocument @dataclasses.dataclass diff --git a/src/pie_modules/metrics/__init__.py b/src/pie_modules/metrics/__init__.py index 7038ade07..674d2ef0c 100644 --- a/src/pie_modules/metrics/__init__.py +++ b/src/pie_modules/metrics/__init__.py @@ -8,5 +8,4 @@ FieldLengthCollector, LabelCountCollector, SubFieldLengthCollector, - TokenCountCollector, ) diff --git a/src/pie_modules/metrics/relation_argument_distance_collector.py b/src/pie_modules/metrics/relation_argument_distance_collector.py index 935c11cbc..138d3ede2 100644 --- a/src/pie_modules/metrics/relation_argument_distance_collector.py +++ b/src/pie_modules/metrics/relation_argument_distance_collector.py @@ -1,13 +1,9 @@ from collections import defaultdict -from typing import Any, Dict, List, Optional, Type, Union +from typing import Dict, List from pie_core import Document, DocumentStatistic -from pie_core.utils.hydra import resolve_target -from transformers import AutoTokenizer, PreTrainedTokenizer from pie_modules.annotations import BinaryRelation, NaryRelation, Span -from pie_modules.document.processing import tokenize_document -from pie_modules.documents import TextBasedDocument, TokenBasedDocument from pie_modules.utils.span import distance @@ -33,10 +29,6 @@ def __init__( self, layer: str, distance_type: str = "outer", - tokenize: bool = False, - tokenize_kwargs: Optional[Dict[str, Any]] = None, - tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, - tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None, key_all: str = "ALL", **kwargs, ): @@ -44,45 +36,16 @@ def __init__( self.layer = layer self.distance_type = distance_type self.key_all = key_all - self.tokenize = tokenize - self.tokenize_kwargs = tokenize_kwargs or {} - if self.tokenize: - if tokenizer is None: - raise ValueError( - "tokenizer must be provided to calculate distance in means of tokens" - ) - if isinstance(tokenizer, str): - tokenizer = AutoTokenizer.from_pretrained(tokenizer) - self.tokenizer = tokenizer - if tokenized_document_type is None: - raise ValueError( - "tokenized_document_type must be provided to calculate distance in means of tokens" - ) - self.tokenized_document_type: Type[TokenBasedDocument] = resolve_target( - tokenized_document_type - ) def _collect(self, doc: Document) -> Dict[str, List[float]]: - if self.tokenize: - if not isinstance(doc, TextBasedDocument): - raise ValueError( - "doc must be a TextBasedDocument to calculate distance in means of tokens" - ) - docs = tokenize_document( - doc, - tokenizer=self.tokenizer, - result_document_type=self.tokenized_document_type, - **self.tokenize_kwargs, - ) - else: - docs = [doc] + docs = [doc] values: Dict[str, List[float]] = defaultdict(list) for doc in docs: layer_obj = getattr(doc, self.layer) for binary_relation in layer_obj: if isinstance(binary_relation, BinaryRelation): - args = [binary_relation.head, binary_relation.tail] + args = (binary_relation.head, binary_relation.tail) label = binary_relation.label elif isinstance(binary_relation, NaryRelation): args = binary_relation.arguments diff --git a/src/pie_modules/metrics/span_coverage_collector.py b/src/pie_modules/metrics/span_coverage_collector.py index df783f228..eb9338386 100644 --- a/src/pie_modules/metrics/span_coverage_collector.py +++ b/src/pie_modules/metrics/span_coverage_collector.py @@ -1,13 +1,9 @@ import logging -from typing import Any, Dict, List, Optional, Set, Type, Union +from typing import List, Optional, Set, Union from pie_core import Document, DocumentStatistic -from pie_core.utils.hydra import resolve_type -from transformers import AutoTokenizer, PreTrainedTokenizer from pie_modules.annotations import LabeledMultiSpan, Span -from pie_modules.document.processing import tokenize_document -from pie_modules.documents import TextBasedDocument, TokenBasedDocument logger = logging.getLogger(__name__) @@ -36,57 +32,16 @@ class SpanCoverageCollector(DocumentStatistic): def __init__( self, layer: str, - tokenize: bool = False, - tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, - tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None, labels: Optional[Union[List[str], str]] = None, label_attribute: str = "label", - tokenize_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ): super().__init__(**kwargs) self.layer = layer self.labels = labels self.label_field = label_attribute - self.tokenize = tokenize - if self.tokenize: - if tokenizer is None: - raise ValueError( - "tokenizer must be provided to calculate the span coverage in means of tokens" - ) - if isinstance(tokenizer, str): - tokenizer = AutoTokenizer.from_pretrained(tokenizer) - self.tokenizer = tokenizer - if tokenized_document_type is None: - raise ValueError( - "tokenized_document_type must be provided to calculate the span coverage in means of tokens" - ) - self.tokenized_document_type = resolve_type( - tokenized_document_type, expected_super_type=TokenBasedDocument - ) - self.tokenize_kwargs = tokenize_kwargs or {} def _collect(self, doc: Document) -> float: - docs: Union[List[Document], List[TokenBasedDocument]] - if self.tokenize: - if not isinstance(doc, TextBasedDocument): - raise ValueError( - "doc must be a TextBasedDocument to calculate the span coverage in means of tokens" - ) - docs = tokenize_document( - doc, - tokenizer=self.tokenizer, - result_document_type=self.tokenized_document_type, - **self.tokenize_kwargs, - ) - if len(docs) != 1: - raise ValueError( - "tokenization of a single document must result in a single document to calculate the " - "span coverage correctly. Please check your tokenization settings, especially that " - "no windowing is applied because of max input length restrictions." - ) - doc = docs[0] - layer_obj = getattr(doc, self.layer) target = layer_obj.target covered_indices: Set[int] = set() diff --git a/src/pie_modules/metrics/span_length_collector.py b/src/pie_modules/metrics/span_length_collector.py index 8acac1399..649a8e801 100644 --- a/src/pie_modules/metrics/span_length_collector.py +++ b/src/pie_modules/metrics/span_length_collector.py @@ -1,14 +1,10 @@ import logging from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Dict, List, Optional, Union from pie_core import Document, DocumentStatistic -from pie_core.utils.hydra import resolve_type -from transformers import AutoTokenizer, PreTrainedTokenizer from pie_modules.annotations import Span -from pie_modules.document.processing import tokenize_document -from pie_modules.documents import TextBasedDocument, TokenBasedDocument logger = logging.getLogger(__name__) @@ -26,12 +22,8 @@ class SpanLengthCollector(DocumentStatistic): def __init__( self, layer: str, - tokenize: bool = False, - tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, - tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None, labels: Optional[Union[List[str], str]] = None, label_attribute: str = "label", - tokenize_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ): super().__init__(**kwargs) @@ -40,39 +32,9 @@ def __init__( raise ValueError("labels must be a list of strings or 'INFERRED'") self.labels = labels self.label_field = label_attribute - self.tokenize = tokenize - if self.tokenize: - if tokenizer is None: - raise ValueError( - "tokenizer must be provided to calculate the span length in means of tokens" - ) - if isinstance(tokenizer, str): - tokenizer = AutoTokenizer.from_pretrained(tokenizer) - self.tokenizer = tokenizer - if tokenized_document_type is None: - raise ValueError( - "tokenized_document_type must be provided to calculate the span length in means of tokens" - ) - self.tokenized_document_type = resolve_type( - tokenized_document_type, expected_super_type=TokenBasedDocument - ) - self.tokenize_kwargs = tokenize_kwargs or {} def _collect(self, doc: Document) -> Union[List[int], Dict[str, List[int]]]: - docs: Union[List[Document], List[TokenBasedDocument]] - if self.tokenize: - if not isinstance(doc, TextBasedDocument): - raise ValueError( - "doc must be a TextBasedDocument to calculate the span length in means of tokens" - ) - docs = tokenize_document( - doc, - tokenizer=self.tokenizer, - result_document_type=self.tokenized_document_type, - **self.tokenize_kwargs, - ) - else: - docs = [doc] + docs = [doc] values: Dict[str, List[int]] if isinstance(self.labels, str): diff --git a/src/pie_modules/metrics/statistics.py b/src/pie_modules/metrics/statistics.py index 16b9d41d3..9802b6356 100644 --- a/src/pie_modules/metrics/statistics.py +++ b/src/pie_modules/metrics/statistics.py @@ -1,46 +1,12 @@ import logging from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Union from pie_core import Document, DocumentStatistic -from transformers import AutoTokenizer, PreTrainedTokenizer - -from pie_modules.documents import TextBasedDocument logger = logging.getLogger(__name__) -class TokenCountCollector(DocumentStatistic): - """Collects the token count of a field when tokenizing its content with a Huggingface - tokenizer. - - The content of the field should be a string. - """ - - def __init__( - self, - tokenizer: Union[str, PreTrainedTokenizer], - text_field: str = "text", - tokenizer_kwargs: Optional[Dict[str, Any]] = None, - document_type: Optional[Type[Document]] = None, - **kwargs, - ): - if document_type is None and text_field == "text": - document_type = TextBasedDocument - super().__init__(document_type=document_type, **kwargs) - self.tokenizer = ( - AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer - ) - self.tokenizer_kwargs = tokenizer_kwargs or {} - self.text_field = text_field - - def _collect(self, doc: Document) -> int: - text = getattr(doc, self.text_field) - encodings = self.tokenizer(text, **self.tokenizer_kwargs) - tokens = encodings.tokens() - return len(tokens) - - class FieldLengthCollector(DocumentStatistic): """Collects the length of a field, e.g. to collect the number the characters in the input text. diff --git a/src/pie_modules/models/__init__.py b/src/pie_modules/models/__init__.py deleted file mode 100644 index df8f4a035..000000000 --- a/src/pie_modules/models/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from .sequence_classification_with_pooler import ( - SequenceClassificationModelWithPooler, - SequencePairSimilarityModelWithPooler, -) -from .simple_extractive_question_answering import SimpleExtractiveQuestionAnsweringModel -from .simple_generative import SimpleGenerativeModel -from .simple_sequence_classification import SimpleSequenceClassificationModel -from .simple_token_classification import SimpleTokenClassificationModel -from .span_tuple_classification import SpanTupleClassificationModel -from .token_classification_with_seq2seq_encoder_and_crf import ( - TokenClassificationModelWithSeq2SeqEncoderAndCrf, -) diff --git a/src/pie_modules/models/base_models/__init__.py b/src/pie_modules/models/base_models/__init__.py deleted file mode 100644 index 8bc2cf097..000000000 --- a/src/pie_modules/models/base_models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .bart_as_pointer_network import BartAsPointerNetwork -from .bart_with_decoder_position_ids import BartModelWithDecoderPositionIds diff --git a/src/pie_modules/models/base_models/bart_as_pointer_network.py b/src/pie_modules/models/base_models/bart_as_pointer_network.py index b17540263..e69de29bb 100644 --- a/src/pie_modules/models/base_models/bart_as_pointer_network.py +++ b/src/pie_modules/models/base_models/bart_as_pointer_network.py @@ -1,476 +0,0 @@ -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union - -import torch.utils.checkpoint -from torch import nn -from torch.nn import Parameter -from torch.optim import Optimizer -from transformers import BartConfig, BartModel, BartPreTrainedModel, GenerationConfig -from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput -from transformers.models.bart.modeling_bart import shift_tokens_right -from transformers.utils import logging - -from pie_modules.models.base_models.bart_with_decoder_position_ids import ( - BartModelWithDecoderPositionIds, -) -from pie_modules.models.components.pointer_head import PointerHead - -logger = logging.get_logger(__name__) - - -def get_layer_norm_parameters( - named_parameters: Iterator[Tuple[str, Parameter]], -) -> Iterator[Parameter]: - return ( - param for name, param in named_parameters if "layernorm" in name or "layer_norm" in name - ) - - -def get_non_layer_norm_parameters( - named_parameters: Iterator[Tuple[str, Parameter]], -) -> Iterator[Parameter]: - return ( - param - for name, param in named_parameters - if not ("layernorm" in name or "layer_norm" in name) - ) - - -class BartAsPointerNetworkConfig(BartConfig): - def __init__( - self, - # respective token ids for the label-, eos-, and pad ids. Can be used as a mapping from the - # target ids to the token ids. - target_token_ids: Optional[List[int]] = None, - # token id mapping to better initialize the label embedding weights - embedding_weight_mapping: Optional[Dict[Union[int, str], List[int]]] = None, - # special decoder position id handling - decoder_position_id_mode: Optional[str] = None, - decoder_position_id_pattern: Optional[List[int]] = None, - decoder_position_id_mapping: Optional[Dict[int, int]] = None, - # other parameters - use_encoder_mlp: bool = True, - use_constraints_encoder_mlp: bool = False, - # optimizer - lr: float = 5e-5, - task_lr: Optional[float] = None, - weight_decay: float = 1e-2, - head_decay: Optional[float] = None, - shared_decay: Optional[float] = None, - encoder_layer_norm_decay: Optional[float] = 0.001, - decoder_layer_norm_decay: Optional[float] = None, - # other BartConfig parameters - **kwargs, - ): - super().__init__(**kwargs) - - self.target_token_ids = target_token_ids - - self.embedding_weight_mapping = embedding_weight_mapping - - self.use_encoder_mlp = use_encoder_mlp - self.use_constraints_encoder_mlp = use_constraints_encoder_mlp - - self.decoder_position_id_mode = decoder_position_id_mode - self.decoder_position_id_pattern = decoder_position_id_pattern - self.decoder_position_id_mapping = decoder_position_id_mapping - - self.lr = lr - self.task_lr = task_lr - self.weight_decay = weight_decay - self.head_decay = head_decay - self.shared_decay = shared_decay - self.encoder_layer_norm_decay = encoder_layer_norm_decay - self.decoder_layer_norm_decay = decoder_layer_norm_decay - - -class BartAsPointerNetwork(BartPreTrainedModel): - config_class = BartAsPointerNetworkConfig - base_model_prefix = "model" - _tied_weights_keys = [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - ] - - def __init__(self, config: BartAsPointerNetworkConfig): - super().__init__(config) - if self.config.decoder_position_id_mode is not None: - self.model = BartModelWithDecoderPositionIds(config) - else: - self.model = BartModel(config) - - self.pointer_head = PointerHead( - # target space ids - bos_id=self.model.config.bos_token_id, - eos_id=self.model.config.eos_token_id, - pad_id=self.model.config.pad_token_id, - # decoder-input token ids - target_token_ids=self.model.config.target_token_ids, - # embeddings - embeddings=self.model.decoder.embed_tokens, - embedding_weight_mapping=self.model.config.embedding_weight_mapping, - # other parameters - use_encoder_mlp=self.model.config.use_encoder_mlp, - use_constraints_encoder_mlp=self.model.config.use_constraints_encoder_mlp, - decoder_position_id_mode=self.model.config.decoder_position_id_mode, - decoder_position_id_pattern=self.model.config.decoder_position_id_pattern, - decoder_position_id_mapping=self.model.config.decoder_position_id_mapping, - ) - - # Initialize weights and apply final processing - self.post_init() - - @classmethod - def _load_pretrained_model( - cls, - *args, - **kwargs, - ): - ( - model, - missing_keys, - unexpected_keys, - mismatched_keys, - offload_index, - error_msgs, - ) = super()._load_pretrained_model(*args, **kwargs) - # adjust the model after loading the original model (e.g. vanilla BartModel) - model.adjust_after_loading_original_model() - return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs - - def resize_token_embeddings( - self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None - ) -> nn.Embedding: - new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) - # we also need to update the embeddings in the pointer head - self.pointer_head.set_embeddings(new_embeddings) - return new_embeddings - - def adjust_after_loading_original_model(self): - # target_token_ids contains all new target tokens for the labels and new tokens were added to the end - # of the vocabulary, so we can use its maximum to resize the embedding weights - self.resize_token_embeddings(new_num_tokens=max(self.config.target_token_ids) + 1) - # initialize the newly added embeddings for the labels with better weights from the original embeddings - self.pointer_head.overwrite_embeddings_with_mapping() - - # adjust generation settings - # set the correct decoder_start_token_id - self.config.decoder_start_token_id = self.config.bos_token_id - # disable ForcedBOSTokenLogitsProcessor - self.config.forced_bos_token_id = None - # disable ForcedEOSTokenLogitsProcessor - self.config.forced_eos_token_id = None - # update the generation config accordingly - self.generation_config = GenerationConfig.from_model_config(self.config) - - def base_model_named_params(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - yield from self.model.named_parameters(prefix=prefix + self.base_model_prefix) - - def head_named_params(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - base_model_param_names = { - name for name, param in self.base_model_named_params(prefix=prefix) - } - for name, param in self.named_parameters(prefix=prefix): - if name not in base_model_param_names: - yield name, param - - def encoder_only_named_params(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - shared_params = set(dict(self.encoder_decoder_shared_named_params(prefix=prefix)).values()) - for name, param in self.model.encoder.named_parameters( - prefix=prefix + self.base_model_prefix + ".encoder" - ): - if param not in shared_params: - yield name, param - - def decoder_only_named_params(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - shared_params = set(dict(self.encoder_decoder_shared_named_params(prefix=prefix)).values()) - for name, param in self.model.decoder.named_parameters( - prefix=prefix + self.base_model_prefix + ".decoder" - ): - if param not in shared_params: - yield name, param - - def encoder_decoder_shared_named_params( - self, prefix: str = "" - ) -> Iterator[Tuple[str, Parameter]]: - encoder_params = set(self.model.encoder.parameters()) - decoder_params = set(self.model.decoder.parameters()) - for name, param in self.base_model_named_params(prefix=prefix): - if param in encoder_params and param in decoder_params: - yield name, param - - def get_encoder(self): - return self.model.get_encoder() - - def get_decoder(self): - return self.model.get_decoder() - - # @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) - # @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) - # @add_end_docstrings(BART_GENERATION_EXAMPLE) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - decoder_position_ids: Optional[torch.LongTensor] = None, - constraints: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Seq2SeqLMOutput]: - r"""Labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels - for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` - are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., - config.vocab_size]`. - - Returns: - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if labels is not None: - if use_cache: - logger.warning( - "The `use_cache` argument is changed to `False` since `labels` is provided." - ) - use_cache = False - if decoder_input_ids is None and decoder_inputs_embeds is None: - decoder_input_ids = shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - if decoder_input_ids is None: - # we can not create the decoder_input_ids from input_ids, because we need the - # encoder_input_ids for the pointer network - raise ValueError("decoder_input_ids has to be set!") - - # this adjusts the input_ids and, if available, the position_ids - decoder_inputs = self.pointer_head.prepare_decoder_inputs( - input_ids=decoder_input_ids, - # in the case of generation (with past_key_values) the position_ids are already prepared - position_ids=decoder_position_ids, - encoder_input_ids=input_ids, - ) - - model_inputs = dict( - input_ids=input_ids, - encoder_outputs=encoder_outputs, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - ) - for k, v in decoder_inputs.items(): - model_inputs[f"decoder_{k}"] = v - outputs = self.model(**model_inputs) - - if not isinstance(outputs, Seq2SeqModelOutput): - raise ValueError( - "Inconsistent output: The output of the model forward should be of type " - f"`Seq2SeqLMOutput`, but is of type `{type(outputs)}`." - ) - logits, loss = self.pointer_head( - last_hidden_state=outputs.last_hidden_state, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_input_ids=input_ids, - encoder_attention_mask=attention_mask, - labels=labels, - decoder_attention_mask=decoder_attention_mask, - constraints=constraints, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return Seq2SeqLMOutput( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) - - def prepare_inputs_for_generation( - self, - decoder_input_ids, - encoder_input_ids, # added for pointer network - encoder_attention_mask, # added for pointer network - past_key_values=None, - attention_mask=None, - decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, - use_cache=None, - encoder_outputs=None, - **kwargs, - ): - result = {} - if self.pointer_head.use_prepared_position_ids: - # we need to prepare the position ids for the decoder here, because later we do not have the full - # input_ids anymore - result["decoder_position_ids"] = self.pointer_head.prepare_decoder_position_ids( - input_ids=decoder_input_ids - ) - - # cut decoder_input_ids if past_key_values is used - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if decoder_input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = decoder_input_ids.shape[1] - 1 - - decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] - - if "decoder_position_ids" in result: - result["decoder_position_ids"] = result["decoder_position_ids"][ - :, remove_prefix_length: - ] - - result.update( - { - "input_ids": encoder_input_ids, - "encoder_outputs": encoder_outputs, - "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, - "attention_mask": encoder_attention_mask, - "decoder_attention_mask": decoder_attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, - "use_cache": use_cache, # change this to avoid caching (presumably for debugging) - } - ) - return result - - def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): - return shift_tokens_right( - labels, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - # cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past[:2] - ) - + layer_past[2:], - ) - return reordered_past - - def _prepare_encoder_decoder_kwargs_for_generation( - self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None - ) -> Dict[str, Any]: - result = super()._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor=inputs_tensor, - model_kwargs=model_kwargs, - model_input_name=model_input_name, - ) - # add items that are needed for pointer network - result["encoder_input_ids"] = inputs_tensor - result["encoder_attention_mask"] = result["attention_mask"] - return result - - def configure_optimizer(self) -> Optimizer: - parameters = [] - - # head parameters - head_decay = ( - self.config.head_decay - if self.config.head_decay is not None - else self.config.weight_decay - ) - params = { - "lr": self.config.task_lr if self.config.task_lr is not None else self.config.lr, - "weight_decay": head_decay, - "params": dict(self.head_named_params()).values(), - } - parameters.append(params) - - # decoder only layer norm parameters - decoder_layer_norm_decay = ( - self.config.decoder_layer_norm_decay - if self.config.decoder_layer_norm_decay is not None - else self.config.weight_decay - ) - params = { - "lr": self.config.lr, - "weight_decay": decoder_layer_norm_decay, - "params": get_layer_norm_parameters(self.decoder_only_named_params()), - } - parameters.append(params) - - # decoder only other parameters - params = { - "lr": self.config.lr, - "weight_decay": self.config.weight_decay, - "params": get_non_layer_norm_parameters(self.decoder_only_named_params()), - } - parameters.append(params) - - # encoder only layer norm parameters - encoder_layer_norm_decay = ( - self.config.encoder_layer_norm_decay - if self.config.encoder_layer_norm_decay is not None - else self.config.weight_decay - ) - params = { - "lr": self.config.lr, - "weight_decay": encoder_layer_norm_decay, - "params": get_layer_norm_parameters(self.encoder_only_named_params()), - } - parameters.append(params) - - # encoder only other parameters - params = { - "lr": self.config.lr, - "weight_decay": self.config.weight_decay, - "params": get_non_layer_norm_parameters(self.encoder_only_named_params()), - } - parameters.append(params) - - # encoder-decoder shared parameters - shared_decay = ( - self.config.shared_decay - if self.config.shared_decay is not None - else self.config.weight_decay - ) - params = { - "lr": self.config.lr, - "weight_decay": shared_decay, - "params": dict(self.encoder_decoder_shared_named_params()).values(), - } - parameters.append(params) - - return torch.optim.AdamW(parameters) diff --git a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py b/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py deleted file mode 100644 index 03ca5dea6..000000000 --- a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py +++ /dev/null @@ -1,536 +0,0 @@ -# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch BART model, but the decoder accepts predefined position ids. If not provided, the -original logic is used to create the position ids. - -The model is based on the BartModel from Transformers 4.35.0, -i.e. https://github.com/huggingface/transformers/blob/v4.35.0/src/transformers/models/bart/modeling_bart.py. - -Note: This also contains some minor modifications to make the code mypy (v1.4.1) compliant. -. -""" -import math -from typing import Any, List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_attention_mask, - _prepare_4d_causal_attention_mask, -) -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPastAndCrossAttentions, - Seq2SeqModelOutput, -) -from transformers.models.bart import BartConfig -from transformers.models.bart.modeling_bart import ( - _CHECKPOINT_FOR_DOC, - _CONFIG_FOR_DOC, - _EXPECTED_OUTPUT_SHAPE, - BART_INPUTS_DOCSTRING, - BART_START_DOCSTRING, - BartDecoderLayer, - BartEncoder, - BartPreTrainedModel, - shift_tokens_right, -) -from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, -) - -logger = logging.get_logger(__name__) - - -class BartLearnedPositionalEmbeddingWithPositionIds(nn.Embedding): - """This module learns positional embeddings up to a fixed maximum size.""" - - def __init__(self, num_embeddings: int, embedding_dim: int): - # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) - - def forward( - self, - input_ids: torch.Tensor, - past_key_values_length: int = 0, - position_ids: Optional[torch.Tensor] = None, - ): - """`input_ids' shape is expected to be [bsz x seqlen].""" - - if position_ids is None: - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, - past_key_values_length + seq_len, - dtype=torch.long, - device=self.weight.device, - ).expand(bsz, -1) - else: - positions = position_ids - - return super().forward(positions + self.offset) - - -class BartDecoderWithPositionIds(BartPreTrainedModel): - """Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a - [`BartDecoderLayer`] - - Args: - config: BartConfig - embed_tokens (nn.Embedding): output embedding - """ - - def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): - super().__init__(config) - self.dropout = config.dropout - self.layerdrop = config.decoder_layerdrop - self.padding_idx = config.pad_token_id - self.max_target_positions = config.max_position_embeddings - self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) - - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - - self.embed_positions = BartLearnedPositionalEmbeddingWithPositionIds( - config.max_position_embeddings, - config.d_model, - ) - self.layers = nn.ModuleList( - [BartDecoderLayer(config) for _ in range(config.decoder_layers)] - ) - self.layernorm_embedding = nn.LayerNorm(config.d_model) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - def forward( - self, - input_ids: torch.LongTensor = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you - provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Position indices for each input sequence token that are used to create the position embedding - of the sequence. If `None` (default), position ids are automatically created as sequential - integers (takes previous `past_key_values` into account, if provided). - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention - of the decoder. - encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): - Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values - selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): - Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing - cross-attention on hidden heads. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded - representation. This is useful if you want more control over how to convert `input_ids` indices - into associated vectors than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - input = input_ids - input_shape = input.shape - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - input = inputs_embeds[:, :, -1] - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - # past_key_values_length - past_key_values_length = ( - past_key_values[0][0].shape[2] if past_key_values is not None else 0 - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input) * self.embed_scale - - if getattr(self.config, "_flash_attn_2_enabled", False): - # 2d mask is passed through the layers - attention_mask = ( - attention_mask if (attention_mask is not None and 0 in attention_mask) else None - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if getattr(self.config, "_flash_attn_2_enabled", False): - encoder_attention_mask = ( - encoder_attention_mask if 0 in encoder_attention_mask else None - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - - # embed positions - if position_ids is not None and position_ids.shape != input_shape: - raise ValueError( - f"Position IDs shape {position_ids.shape} does not match input ids shape {input_shape}." - ) - positions = self.embed_positions(input, past_key_values_length, position_ids) - positions = positions.to(inputs_embeds.device) - - hidden_states = inputs_embeds + positions - hidden_states = self.layernorm_embedding(hidden_states) - - hidden_states = nn.functional.dropout( - hidden_states, p=self.dropout, training=self.training - ) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states: Optional[Tuple[Any, ...]] = () if output_hidden_states else None - all_self_attns: Optional[Tuple[Any, ...]] = () if output_attentions else None - all_cross_attentions: Optional[Tuple[Any, ...]] = ( - () if (output_attentions and encoder_hidden_states is not None) else None - ) - next_decoder_cache: Optional[Tuple[Any, ...]] = () if use_cache else None - - # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired - for attn_mask, mask_name in zip( - [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"] - ): - if attn_mask is not None: - if attn_mask.size()[0] != (len(self.layers)): - raise ValueError( - f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" - f" {attn_mask.size()[0]}." - ) - - for idx, decoder_layer in enumerate(self.layers): - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - if all_hidden_states is not None: - all_hidden_states += (hidden_states,) - if self.training: - dropout_probability = torch.rand([]) - if dropout_probability < self.layerdrop: - continue - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = layer_outputs[0] - - if next_decoder_cache is not None: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - - if all_self_attns is not None: - all_self_attns += (layer_outputs[1],) - - if all_cross_attentions is not None: - all_cross_attentions += (layer_outputs[2],) - - # add hidden states from the last decoder layer - if all_hidden_states is not None: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - ) - - -@add_start_docstrings( - "The bare BART Model outputting raw hidden-states without any specific head on top.", - BART_START_DOCSTRING, -) -class BartModelWithDecoderPositionIds(BartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - - def __init__(self, config: BartConfig): - super().__init__(config) - - padding_idx, vocab_size = config.pad_token_id, config.vocab_size - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) - - self.encoder = BartEncoder(config, self.shared) - self.decoder = BartDecoderWithPositionIds(config, self.shared) - - # Initialize weights and apply final processing - self.post_init() - - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) - self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) - - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, value): - self.shared = value - self.encoder.embed_tokens = self.shared - self.decoder.embed_tokens = self.shared - - def get_encoder(self): - return self.encoder - - def get_decoder(self): - return self.decoder - - @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=Seq2SeqModelOutput, - config_class=_CONFIG_FOR_DOC, - expected_output=_EXPECTED_OUTPUT_SHAPE, - ) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - decoder_position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Seq2SeqModelOutput]: - # different to other models, Bart automatically creates decoder_input_ids from - # input_ids if no decoder_input_ids are provided - if decoder_input_ids is None and decoder_inputs_embeds is None: - if input_ids is None: - raise ValueError( - "If no `decoder_input_ids` or `decoder_inputs_embeds` are " - "passed, `input_ids` cannot be `None`. Please pass either " - "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." - ) - - decoder_input_ids = shift_tokens_right( - input_ids, self.config.pad_token_id, self.config.decoder_start_token_id - ) - - output_attentions = ( - output_attentions if output_attentions is not None else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if encoder_outputs is None: - encoder_outputs = self.encoder( - input_ids=input_ids, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, - ) - - if not ( - isinstance(encoder_outputs, BaseModelOutput) or isinstance(encoder_outputs, tuple) - ): - raise ValueError( - "Inconsistent output: The output of the model encoder should be of type " - f"`BaseModelOutput` or tuple, but is of type `{type(encoder_outputs)}`." - ) - - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) - decoder_outputs = self.decoder( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, - encoder_hidden_states=encoder_outputs[0], - encoder_attention_mask=attention_mask, - head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - past_key_values=past_key_values, - inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - if not return_dict: - return decoder_outputs + encoder_outputs - - return Seq2SeqModelOutput( - last_hidden_state=decoder_outputs.last_hidden_state, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, - ) diff --git a/src/pie_modules/models/common/__init__.py b/src/pie_modules/models/common/__init__.py deleted file mode 100644 index b021af996..000000000 --- a/src/pie_modules/models/common/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .has_taskmodule import HasTaskmodule -from .model_with_boilerplate import ModelWithBoilerplate -from .model_with_metrics_from_taskmodule import ModelWithMetricsFromTaskModule -from .stages import TESTING, TRAINING, VALIDATION diff --git a/src/pie_modules/models/common/has_taskmodule.py b/src/pie_modules/models/common/has_taskmodule.py deleted file mode 100644 index 03cd8154c..000000000 --- a/src/pie_modules/models/common/has_taskmodule.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any, Dict, Optional - -from pie_core import AutoTaskModule, TaskModule - -from pie_modules.models.interface import RequiresTaskmoduleConfig - - -class HasTaskmodule(RequiresTaskmoduleConfig): - """A mixin class for models that have a taskmodule. - - Args: - taskmodule_config: The config for the taskmodule which can be obtained from the - taskmodule.config property. - """ - - def __init__(self, taskmodule_config: Optional[Dict[str, Any]] = None, **kwargs): - super().__init__(**kwargs) - self.taskmodule: Optional[TaskModule] = None - if taskmodule_config is not None: - self.taskmodule = AutoTaskModule.from_config(taskmodule_config) diff --git a/src/pie_modules/models/common/model_with_boilerplate.py b/src/pie_modules/models/common/model_with_boilerplate.py deleted file mode 100644 index a58d51e5a..000000000 --- a/src/pie_modules/models/common/model_with_boilerplate.py +++ /dev/null @@ -1,80 +0,0 @@ -import logging -from typing import Generic, Optional, Tuple, TypeVar - -from typing_extensions import TypeAlias - -from .model_with_metrics_from_taskmodule import ModelWithMetricsFromTaskModule -from .stages import TESTING, TRAINING, VALIDATION - -InputType = TypeVar("InputType") -OutputType = TypeVar("OutputType") -TargetType = TypeVar("TargetType") -StepInputType: TypeAlias = Tuple[InputType, TargetType] -StepOutputType = TypeVar("StepOutputType") - -logger = logging.getLogger(__name__) - - -class ModelWithBoilerplate( - ModelWithMetricsFromTaskModule[InputType, TargetType, OutputType], - Generic[InputType, OutputType, TargetType, StepOutputType], -): - """A PyTorchIEModel that adds boilerplate code for training, validation, and testing. - - Especially, it handles updating the metrics and logging of losses and metric results. Also see - ModelWithMetricsFromTaskModule for more details on how metrics are handled. - """ - - def get_loss_from_outputs(self, outputs: OutputType) -> StepOutputType: - if hasattr(outputs, "loss"): - return outputs.loss - else: - raise ValueError( - f"The model {self.__class__.__name__} does not define a 'loss' attribute in its output, " - "so the loss cannot be automatically extracted from the outputs. Please override the" - "get_loss_from_outputs() method for this model." - ) - - def log_loss(self, stage: str, loss: StepOutputType) -> None: - # show loss on each step only during training - self.log( - f"loss/{stage}", - loss, - on_step=(stage == TRAINING), - on_epoch=True, - prog_bar=True, - sync_dist=True, - ) - - def _step(self, stage: str, batch: StepInputType) -> StepOutputType: - inputs, targets = batch - outputs = self(inputs=inputs, targets=targets) - loss = self.get_loss_from_outputs(outputs=outputs) - self.log_loss(stage=stage, loss=loss) - self.update_metric(inputs=inputs, outputs=outputs, targets=targets, stage=stage) - - return loss - - def training_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType: - return self._step(stage=TRAINING, batch=batch) - - def validation_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType: - return self._step(stage=VALIDATION, batch=batch) - - def test_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType: - return self._step(stage=TESTING, batch=batch) - - def predict_step( - self, batch: StepInputType, batch_idx: int, dataloader_idx: int = 0 - ) -> TargetType: - inputs, targets = batch - return self.predict(inputs=inputs) - - def on_train_epoch_end(self) -> None: - self.log_metric(stage=TRAINING) - - def on_validation_epoch_end(self) -> None: - self.log_metric(stage=VALIDATION) - - def on_test_epoch_end(self) -> None: - self.log_metric(stage=TESTING) diff --git a/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py b/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py deleted file mode 100644 index ea65ffd35..000000000 --- a/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py +++ /dev/null @@ -1,152 +0,0 @@ -import logging -from typing import Dict, Generic, List, Optional, Set, TypeVar, Union - -from pie_core.utils.dictionary import flatten_dict_s -from pytorch_ie import PyTorchIEModel -from torchmetrics import Metric, MetricCollection - -from .has_taskmodule import HasTaskmodule -from .stages import TESTING, TRAINING, VALIDATION - -InputType = TypeVar("InputType") -TargetType = TypeVar("TargetType") -OutputType = TypeVar("OutputType") - -logger = logging.getLogger(__name__) - - -class ModelWithMetricsFromTaskModule( - HasTaskmodule, PyTorchIEModel, Generic[InputType, TargetType, OutputType] -): - """A PyTorchIEModel that adds metrics from a taskmodule. - - The metrics are added to the model as attributes with the names metric_{stage} via - setup_metrics method, where stage is one of "train", "val", or "test". The metrics are updated - with the update_metric method and logged with the log_metric method. - - Args: - metric_stages: The stages for which to set up metrics. Must be one of "train", "val", or - "test". - metric_intervals: A dict mapping metric stages to the number of steps between metric - calculation. If not provided, the metrics are calculated at the end of each epoch. - metric_call_predict: Whether to call predict() and use its result for metric calculation - instead of the (decoded) model output. This is useful, for instance, for generative models - that define special logic to produce predictions, e.g. beam search, which requires multiple - passes through the model. If True, predict() is called for all metric stages. If False (default), - the model outputs are passed to decode() and that is used for all metric stages. If a list of - metric stages is provided, predict() is called for these stages and the (decoded) model - outputs for the remaining stages. - """ - - def __init__( - self, - metric_stages: List[str] = [TRAINING, VALIDATION, TESTING], - metric_intervals: Optional[Dict[str, int]] = None, - metric_call_predict: Union[bool, List[str]] = False, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.setup_metrics(metric_stages=metric_stages) - - self.metric_intervals = metric_intervals or {} - missed_stages = set(self.metric_intervals) - set(metric_stages) - if len(missed_stages) > 0: - logger.warning( - f"There are stages in metric_intervals that are not in metric_stages: " - f"{missed_stages}. Available metric stages: {metric_stages}." - ) - - self.use_prediction_for_metrics: Set[str] - if isinstance(metric_call_predict, bool): - self.metric_call_predict = set(metric_stages) if metric_call_predict else set() - else: - self.metric_call_predict = set(metric_call_predict) - missed_stages = self.metric_call_predict - set(metric_stages) - if len(missed_stages) > 0: - logger.warning( - f"There are stages in metric_call_predict that are not in metric_stages: " - f"{missed_stages}. Available metric stages: {metric_stages}." - ) - - def setup_metrics(self, metric_stages: List[str]) -> None: - """Set up metrics for the given stages if a taskmodule is available. - - Args: - metric_stages: The stages for which to set up metrics. Must be one of "train", "val", or - "test". - """ - if self.taskmodule is not None: - for stage in metric_stages: - metric = self.taskmodule.configure_model_metric(stage=stage) - if metric is not None: - self._set_metric(stage=stage, metric=metric) - else: - logger.warning( - f"The taskmodule {self.taskmodule.__class__.__name__} does not define a metric for stage " - f"'{stage}'." - ) - elif len(metric_stages) > 0: - logger.warning( - "No taskmodule is available, so no metrics are set up. " - "Please provide a taskmodule_config to enable metrics for stages " - f"{metric_stages}." - ) - - def _get_metric( - self, stage: str, batch_idx: int = 0 - ) -> Optional[Union[Metric, MetricCollection]]: - metric_interval = self.metric_intervals.get(stage, 1) - if (batch_idx + 1) % metric_interval == 0: - return getattr(self, f"metric_{stage}", None) - else: - return None - - def _set_metric(self, stage: str, metric: Optional[Union[Metric, MetricCollection]]) -> None: - setattr(self, f"metric_{stage}", metric) - - def update_metric( - self, - stage: str, - inputs: InputType, - targets: TargetType, - outputs: OutputType, - ) -> None: - """Update the metric for the given stage. If outputs is provided, the predictions are - decoded from the outputs. Otherwise, the predictions are obtained by directly calling the - predict method with the inputs (note that this causes the model to be called a second - time). Finally, the metric is updated with the predictions and targets. - - Args: - stage: The stage for which to update the metric. Must be one of "train", "val", or "test". - inputs: The inputs to the model. - targets: The targets for the inputs. - outputs: The outputs of the model. They are decoded into predictions if provided. If - outputs is None, the predictions are obtained by directly calling the predict method - on the inputs. - """ - - metric = self._get_metric(stage=stage) - if metric is not None: - if stage in self.metric_call_predict: - predictions = self.predict(inputs=inputs) - else: - predictions = self.decode(inputs=inputs, outputs=outputs) - metric.update(predictions, targets) - - def log_metric(self, stage: str, reset: bool = True) -> None: - """Log the metric for the given stage and reset it.""" - - metric = self._get_metric(stage=stage) - if metric is not None: - values = metric.compute() - log_kwargs = {"on_step": False, "on_epoch": True, "sync_dist": True} - if isinstance(values, dict): - values_flat = flatten_dict_s(values, sep="/") - for key, value in values_flat.items(): - self.log(f"metric/{key}/{stage}", value, **log_kwargs) - else: - metric_name = getattr(metric, "name", None) or type(metric).__name__ - self.log(f"metric/{metric_name}/{stage}", values, **log_kwargs) - if reset: - metric.reset() diff --git a/src/pie_modules/models/common/stages.py b/src/pie_modules/models/common/stages.py deleted file mode 100644 index 7299e2092..000000000 --- a/src/pie_modules/models/common/stages.py +++ /dev/null @@ -1,3 +0,0 @@ -TRAINING = "train" -VALIDATION = "val" -TESTING = "test" diff --git a/src/pie_modules/models/components/__init__.py b/src/pie_modules/models/components/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/pie_modules/models/components/pointer_head.py b/src/pie_modules/models/components/pointer_head.py deleted file mode 100644 index e8a02acfa..000000000 --- a/src/pie_modules/models/components/pointer_head.py +++ /dev/null @@ -1,357 +0,0 @@ -from typing import Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class PointerHead(torch.nn.Module): - # Copy and generate, - def __init__( - self, - # (decoder) input space - target_token_ids: List[int], - # output space (targets) - bos_id: int, - eos_id: int, - pad_id: int, - # embeddings - embeddings: nn.Embedding, - embedding_weight_mapping: Optional[Dict[Union[int, str], List[int]]] = None, - # other parameters - use_encoder_mlp: bool = False, - use_constraints_encoder_mlp: bool = False, - decoder_position_id_mode: Optional[nn.Module] = None, - decoder_position_id_pattern: Optional[List[int]] = None, - decoder_position_id_mapping: Optional[Dict[str, int]] = None, - ): - super().__init__() - - self.embeddings = embeddings - - self.pointer_offset = len(target_token_ids) - - # check that bos, eos, and pad are not out of bounds - for target_id, target_id_name in zip( - [bos_id, eos_id, pad_id], ["bos_id", "eos_id", "pad_id"] - ): - if target_id >= len(target_token_ids): - raise ValueError( - f"{target_id_name} [{target_id}] must be smaller than the number of target token ids " - f"[{len(target_token_ids)}]!" - ) - - self.bos_id = bos_id - self.eos_id = eos_id - self.pad_id = pad_id - # all ids that are not bos, eos or pad are label ids - self.label_ids = [ - target_id - for target_id in range(len(target_token_ids)) - if target_id not in [self.bos_id, self.eos_id, self.pad_id] - ] - - target2token_id = torch.LongTensor(target_token_ids) - self.register_buffer("target2token_id", target2token_id) - self.label_token_ids = self.target2token_id[self.label_ids] - self.eos_token_id = target_token_ids[self.eos_id] - self.pad_token_id = target_token_ids[self.pad_id] - - hidden_size = self.embeddings.embedding_dim - if use_encoder_mlp: - self.encoder_mlp = nn.Sequential( - nn.Linear(hidden_size, hidden_size), - nn.Dropout(0.3), - nn.ReLU(), - nn.Linear(hidden_size, hidden_size), - ) - if use_constraints_encoder_mlp: - self.constraints_encoder_mlp = nn.Sequential( - nn.Linear(hidden_size, hidden_size), - nn.Dropout(0.3), - nn.ReLU(), - nn.Linear(hidden_size, hidden_size), - ) - - self.embedding_weight_mapping = None - if embedding_weight_mapping is not None: - # Because of config serialization, the keys may be strings. Convert them back to ints. - self.embedding_weight_mapping = { - int(k): v for k, v in embedding_weight_mapping.items() - } - - self.decoder_position_id_mode = decoder_position_id_mode - self.decoder_position_id_mapping = decoder_position_id_mapping - if self.decoder_position_id_mode is None: - pass - elif self.decoder_position_id_mode in ["pattern", "pattern_with_increment"]: - if decoder_position_id_pattern is None: - raise ValueError( - "decoder_position_id_pattern must be provided when using " - 'decoder_position_id_mode="pattern" or "pattern_with_increment"!' - ) - self.register_buffer( - "decoder_position_id_pattern", torch.tensor(decoder_position_id_pattern) - ) - elif self.decoder_position_id_mode == "mapping": - if self.decoder_position_id_mapping is None: - raise ValueError( - 'decoder_position_id_mode="mapping" requires decoder_position_id_mapping to be provided!' - ) - else: - raise ValueError( - f'decoder_position_id_mode="{self.decoder_position_id_mode}" is not supported, ' - 'use one of "pattern", "pattern_with_increment", or "mapping"!' - ) - - @property - def use_prepared_position_ids(self): - return self.decoder_position_id_mode is not None - - def set_embeddings(self, embedding: nn.Embedding) -> None: - self.embeddings = embedding - - def overwrite_embeddings_with_mapping(self) -> None: - """Overwrite individual embeddings with embeddings for other tokens. - - This is useful, for instance, if the label vocabulary is a subset of the source vocabulary. - In this case, this method can be used to initialize each label embedding with one or - multiple (averaged) source embeddings. - """ - if self.embedding_weight_mapping is not None: - for special_token_index, source_indices in self.embedding_weight_mapping.items(): - self.embeddings.weight.data[special_token_index] = self.embeddings.weight.data[ - source_indices - ].mean(dim=0) - - def prepare_decoder_input_ids( - self, - input_ids: torch.LongTensor, - encoder_input_ids: torch.LongTensor, - ) -> torch.LongTensor: - mapping_token_mask = input_ids.lt(self.pointer_offset) - mapped_tokens = input_ids.masked_fill(input_ids.ge(self.pointer_offset), 0) - tag_mapped_tokens = self.target2token_id[mapped_tokens] - - encoder_input_ids_index = input_ids - self.pointer_offset - encoder_input_ids_index = encoder_input_ids_index.masked_fill( - encoder_input_ids_index.lt(0), 0 - ) - encoder_input_length = encoder_input_ids.size(1) - if encoder_input_ids_index.max() >= encoder_input_length: - raise ValueError( - f"encoder_input_ids_index.max() [{encoder_input_ids_index.max()}] must be smaller " - f"than encoder_input_length [{encoder_input_length}]!" - ) - - word_mapped_tokens = encoder_input_ids.gather(index=encoder_input_ids_index, dim=1) - - decoder_input_ids = torch.where( - mapping_token_mask, tag_mapped_tokens, word_mapped_tokens - ).to(torch.long) - - # Note: we do not need to explicitly handle the padding (via a decoder attention mask) because - # it gets automatically mapped to the pad token id - - return decoder_input_ids - - def prepare_decoder_position_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor: - if self.decoder_position_id_mode in ["pattern", "pattern_with_increment"]: - bsz, tokens_len = input_ids.size() - pattern_len = len(self.decoder_position_id_pattern) - # the number of full and partly records. note that tokens_len includes the bos token - repeat_num = (tokens_len - 2) // pattern_len + 1 - position_ids = self.decoder_position_id_pattern.repeat(bsz, repeat_num) - - if self.decoder_position_id_mode == "pattern_with_increment": - position_ids_reshaped = position_ids.view(bsz, -1, pattern_len) - add_shift_pos = ( - torch.arange(0, repeat_num, device=position_ids_reshaped.device) - .repeat(bsz) - .view(bsz, -1) - .unsqueeze(-1) - ) - # multiply by the highest position id in the pattern so that the position ids are unique - # for any decoder_position_id_pattern across all records - add_shift_pos *= max(self.decoder_position_id_pattern) + 1 - position_ids_reshaped = add_shift_pos + position_ids_reshaped - position_ids = position_ids_reshaped.view(bsz, -1).long() - # use start_position_id=0 - start_pos = torch.zeros(bsz, 1, dtype=position_ids.dtype, device=position_ids.device) - # shift by 2 to account for start_position_id=0 and pad_position_id=1 - all_position_ids = torch.cat([start_pos, position_ids + 2], dim=-1) - all_position_ids_truncated = all_position_ids[:bsz, :tokens_len] - - # mask the padding tokens - mask_invalid = input_ids.eq(self.pad_id) - all_position_ids_truncated_masked = all_position_ids_truncated.masked_fill( - mask_invalid, 1 - ) - - return all_position_ids_truncated_masked - elif self.decoder_position_id_mode == "mapping": - # we ignor the typing issue here because we ensure that the mapping is not None in the __init__ - mapping: Dict[str, int] = self.decoder_position_id_mapping # type: ignore - if "default" not in mapping: - raise ValueError( - f"mapping must contain a default entry, but only contains {list(mapping)}!" - ) - position_ids = input_ids.new_full(input_ids.size(), fill_value=mapping["default"]) - # ensure that values for all vocab entries are set first - if "vocab" in mapping: - position_ids[input_ids.lt(self.pointer_offset)] = mapping["vocab"] - already_set: Dict[int, Tuple[str, int]] = {} - for key, value in mapping.items(): - if key in ["default", "vocab"]: - continue - elif key == "bos": - input_id = self.bos_id - elif key == "eos": - input_id = self.eos_id - elif key == "pad": - input_id = self.pad_id - else: - raise ValueError(f"Mapping contains unknown key '{key}' (mapping: {mapping}).") - if already_set.get(input_id, (key, value))[1] != value: - previous_key, previous_value = already_set[input_id] - raise ValueError( - f"Can not set the position ids for '{key}' to {value} because it was already " - f"set to {previous_value} by key '{previous_key}'. Note that both, '{key}' and " - f"'{previous_key}', have the same id ({input_id}), so their position_ids need to " - f"be also the same (position id mapping: {mapping})." - ) - position_ids[input_ids.eq(input_id)] = value - already_set[input_id] = key, value - return position_ids - else: - raise ValueError( - f"decoder_position_id_mode={self.decoder_position_id_mode} not supported!" - ) - - def prepare_decoder_inputs( - self, - input_ids: torch.LongTensor, - encoder_input_ids: torch.LongTensor, - position_ids: Optional[torch.LongTensor] = None, - ) -> Dict[str, torch.Tensor]: - inputs = {} - if self.use_prepared_position_ids: - if position_ids is None: - position_ids = self.prepare_decoder_position_ids(input_ids=input_ids) - inputs["position_ids"] = position_ids - - inputs["input_ids"] = self.prepare_decoder_input_ids( - input_ids=input_ids, - encoder_input_ids=encoder_input_ids, - ) - return inputs - - def forward( - self, - last_hidden_state, - encoder_input_ids, - encoder_last_hidden_state, - encoder_attention_mask, - labels: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - constraints: Optional[torch.LongTensor] = None, - ): - # assemble the logits - logits = last_hidden_state.new_full( - ( - last_hidden_state.size(0), - last_hidden_state.size(1), - self.pointer_offset + encoder_input_ids.size(-1), - ), - fill_value=-1e24, - ) - - # eos and label scores depend only on the decoder output - # bsz x max_len x 1 - eos_scores = F.linear(last_hidden_state, self.embeddings.weight[[self.eos_token_id]]) - label_embeddings = self.embeddings.weight[self.label_token_ids] - # bsz x max_len x num_class - label_scores = F.linear(last_hidden_state, label_embeddings) - - # the pointer depends on the src token embeddings, the encoder output and the decoder output - # bsz x max_bpe_len x hidden_size - src_outputs = encoder_last_hidden_state - if getattr(self, "encoder_mlp", None) is not None: - src_outputs = self.encoder_mlp(src_outputs) - - # bsz x max_word_len x hidden_size - input_embed = self.embeddings(encoder_input_ids) - - # bsz x max_len x max_word_len - word_scores = torch.einsum("blh,bnh->bln", last_hidden_state, src_outputs) - gen_scores = torch.einsum("blh,bnh->bln", last_hidden_state, input_embed) - avg_word_scores = (gen_scores + word_scores) / 2 - - # never point to the padding or the eos token in the encoder input - # TODO: why not excluding the bos token? seems to give worse results, but not tested extensively - mask_invalid = encoder_attention_mask.eq(0) | encoder_input_ids.eq(self.eos_token_id) - avg_word_scores = avg_word_scores.masked_fill(mask_invalid.unsqueeze(1), -1e32) - - # Note: the remaining row in logits contains the score for the bos token which should be never generated! - logits[:, :, [self.eos_id]] = eos_scores - logits[:, :, self.label_ids] = label_scores - logits[:, :, self.pointer_offset :] = avg_word_scores - - loss = None - # compute the loss if labels are provided - if labels is not None: - loss_fct = CrossEntropyLoss() - logits_resized = logits.reshape(-1, logits.size(-1)) - labels_resized = labels.reshape(-1) - if decoder_attention_mask is None: - raise ValueError("decoder_attention_mask must be provided to compute the loss!") - mask_resized = decoder_attention_mask.reshape(-1) - labels_masked = labels_resized.masked_fill( - ~mask_resized.to(torch.bool), loss_fct.ignore_index - ) - loss = loss_fct(logits_resized, labels_masked) - - # compute the constraints loss if constraints are provided - if constraints is not None: - if getattr(self, "constraints_encoder_mlp", None) is not None: - # TODO: is it fine to apply constraints_encoder_mlp to both src_outputs and label_embeddings? - # This is what the original code seems to do, but this is different from the usage of encoder_mlp. - constraints_src_outputs = self.constraints_encoder_mlp(src_outputs) - constraints_label_embeddings = self.constraints_encoder_mlp(label_embeddings) - else: - constraints_src_outputs = src_outputs - constraints_label_embeddings = label_embeddings - constraints_label_scores = F.linear(last_hidden_state, constraints_label_embeddings) - # bsz x max_len x max_word_len - constraints_word_scores = torch.einsum( - "blh,bnh->bln", last_hidden_state, constraints_src_outputs - ) - constraints_logits = last_hidden_state.new_full( - ( - last_hidden_state.size(0), - last_hidden_state.size(1), - self.pointer_offset + encoder_input_ids.size(-1), - ), - fill_value=-1e24, - ) - constraints_logits[:, :, self.label_ids] = constraints_label_scores - constraints_logits[:, :, self.pointer_offset :] = constraints_word_scores - - mask = constraints >= 0 - constraints_logits_valid = constraints_logits[mask] - constraints_valid = constraints[mask] - loss_c = F.binary_cross_entropy( - torch.sigmoid(constraints_logits_valid), constraints_valid.float() - ) - - if loss is None: - loss = loss_c - else: - loss += loss_c - - return logits, loss diff --git a/src/pie_modules/models/components/pooler.py b/src/pie_modules/models/components/pooler.py index b835e2f9c..e69de29bb 100644 --- a/src/pie_modules/models/components/pooler.py +++ b/src/pie_modules/models/components/pooler.py @@ -1,274 +0,0 @@ -import logging -from typing import Any, Callable, Dict, Tuple, Union - -import torch -from torch import Tensor, cat, nn - -# possible pooler types -CLS_TOKEN = "cls_token" # CLS token -START_TOKENS = "start_tokens" # MTB start tokens concat -MENTION_POOLING = "mention_pooling" # mention token pooling and concat - - -logger = logging.getLogger(__name__) - - -def pool_cls(hidden_state: Tensor, **kwargs) -> Tensor: - return hidden_state[:, 0, :] - - -class AtIndexPooler(nn.Module): - """Pooler that takes the hidden states at given indices. If the index is negative, a learned - embedding is used. - - The indices are expected to have the shape [batch_size, num_indices]. The resulting embeddings are concatenated, - so the output shape is [batch_size, num_indices * input_dim]. - - Args: - input_dim: The input dimension of the hidden state. - num_indices: The number of indices to pool. - offset: An offset to add to the indices. This can be useful if the input is prepared with special - tokens at the beginning / at the end of indexed sequences, and we want to use the hidden state of this - token instead of the first / last token of the sequence. - - Returns: - The pooled hidden states with shape [batch_size, num_indices * input_dim]. - """ - - def __init__(self, input_dim: int, num_indices: int = 2, offset: int = 0, **kwargs): - super().__init__(**kwargs) - self.input_dim = input_dim - self.num_indices = num_indices - self.offset = offset - self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim)) - nn.init.normal_(self.missing_embeddings) - - def forward(self, hidden_state: Tensor, indices: Tensor, **kwargs) -> Tensor: - batch_size, seq_len, hidden_size = hidden_state.shape - if indices.shape[1] != self.num_indices: - raise ValueError( - f"number of indices [{indices.shape[1]}] has to be the same as num_types [{self.num_indices}]" - ) - - # respect the offset - indices = indices + self.offset - - # times num_types due to concat - result = torch.zeros( - batch_size, hidden_size * self.num_indices, device=hidden_state.device - ) - for batch_idx, current_indices in enumerate(indices): - current_embeddings = [ - ( - hidden_state[batch_idx, current_indices[i], :] - if current_indices[i] >= 0 - else self.missing_embeddings[i] - ) - for i in range(self.num_indices) - ] - result[batch_idx] = cat(current_embeddings, 0) - return result - - @property - def output_dim(self) -> int: - return self.input_dim * self.num_indices - - -class ArgumentWrappedPooler(nn.Module): - """Wraps a pooler and maps the arguments to the pooler. - - Args: - pooler: The pooler to wrap. - argument_mapping: A mapping from the arguments of the forward method to the arguments of the pooler. - """ - - def __init__( - self, pooler: Union[nn.Module, Callable], argument_mapping: Dict[str, str], **kwargs - ): - super().__init__(**kwargs) - self.pooler = pooler - self.argument_mapping = argument_mapping - - def forward(self, hidden_state: Tensor, **kwargs) -> Tensor: - pooler_kwargs = {} - for k, v in kwargs.items(): - if k in self.argument_mapping: - pooler_kwargs[self.argument_mapping[k]] = v - return self.pooler(hidden_state, **pooler_kwargs) - - -class SpanMaxPooler(nn.Module): - """Pooler that takes the max hidden state over spans. If the start or end index is negative, a - learned. - - embedding is used. The indices are expected to have the shape [batch_size, num_indices]. The resulting embeddings - are concatenated, so the output shape is [batch_size, num_indices * input_dim]. - - Args: - input_dim: The input dimension of the hidden state. - num_indices: The number of indices to pool. - - Returns: - The pooled hidden states with shape [batch_size, num_indices * input_dim]. - """ - - def __init__(self, input_dim: int, num_indices: int = 2, **kwargs): - super().__init__(**kwargs) - self.input_dim = input_dim - self.num_indices = num_indices - self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim)) - nn.init.normal_(self.missing_embeddings) - - def forward( - self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs - ) -> Tensor: - batch_size, seq_len, hidden_size = hidden_state.shape - if start_indices.shape[1] != self.num_indices: - raise ValueError( - f"number of start indices [{start_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]" - ) - - if end_indices.shape[1] != self.num_indices: - raise ValueError( - f"number of end indices [{end_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]" - ) - - # check that start_indices are before end_indices - mask_both_positive = (start_indices >= 0) & (end_indices >= 0) - mask_start_before_end = start_indices < end_indices - mask_valid = mask_start_before_end | ~mask_both_positive - if not torch.all(mask_valid): - raise ValueError( - f"values in start_indices have to be smaller than respective values in " - f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}" - ) - - # times num_indices due to concat - result = torch.zeros( - batch_size, hidden_size * self.num_indices, device=hidden_state.device - ) - for batch_idx in range(batch_size): - current_start_indices = start_indices[batch_idx] - current_end_indices = end_indices[batch_idx] - current_embeddings = [ - ( - torch.amax( - hidden_state[ - batch_idx, current_start_indices[i] : current_end_indices[i], : - ], - 0, - ) - if current_start_indices[i] >= 0 and current_end_indices[i] >= 0 - else self.missing_embeddings[i] - ) - for i in range(self.num_indices) - ] - result[batch_idx] = cat(current_embeddings, 0) - - return result - - @property - def output_dim(self) -> int: - return self.input_dim * self.num_indices - - -class SpanMeanPooler(nn.Module): - """Pooler that takes the mean hidden state over spans. If the start or end index is negative, a - learned embedding is used. The indices are expected to have the shape [batch_size, - num_indices]. - - The resulting embeddings are concatenated, so the output shape is [batch_size, num_indices * input_dim]. - Note this a slightly modified version of the pie_modules.models.components.pooler.SpanMaxPooler, - i.e. we changed the aggregation method from torch.amax to torch.mean. - - Args: - input_dim: The input dimension of the hidden state. - num_indices: The number of indices to pool. - - Returns: - The pooled hidden states with shape [batch_size, num_indices * input_dim]. - """ - - def __init__(self, input_dim: int, num_indices: int = 2, **kwargs): - super().__init__(**kwargs) - self.input_dim = input_dim - self.num_indices = num_indices - self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim)) - nn.init.normal_(self.missing_embeddings) - - def forward( - self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs - ) -> Tensor: - batch_size, seq_len, hidden_size = hidden_state.shape - if start_indices.shape[1] != self.num_indices: - raise ValueError( - f"number of start indices [{start_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]" - ) - - if end_indices.shape[1] != self.num_indices: - raise ValueError( - f"number of end indices [{end_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]" - ) - - # check that start_indices are before end_indices - mask_both_positive = (start_indices >= 0) & (end_indices >= 0) - mask_start_before_end = start_indices < end_indices - mask_valid = mask_start_before_end | ~mask_both_positive - if not torch.all(mask_valid): - raise ValueError( - f"values in start_indices have to be smaller than respective values in " - f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}" - ) - - # times num_indices due to concat - result = torch.zeros( - batch_size, hidden_size * self.num_indices, device=hidden_state.device - ) - for batch_idx in range(batch_size): - current_start_indices = start_indices[batch_idx] - current_end_indices = end_indices[batch_idx] - current_embeddings = [ - ( - torch.mean( - hidden_state[ - batch_idx, current_start_indices[i] : current_end_indices[i], : - ], - dim=0, - ) - if current_start_indices[i] >= 0 and current_end_indices[i] >= 0 - else self.missing_embeddings[i] - ) - for i in range(self.num_indices) - ] - result[batch_idx] = cat(current_embeddings, 0) - - return result - - @property - def output_dim(self) -> int: - return self.input_dim * self.num_indices - - -def get_pooler_and_output_size(config: Dict[str, Any], input_dim: int) -> Tuple[Callable, int]: - pooler_config = dict(config) - pooler_type = pooler_config.pop("type", CLS_TOKEN) - if pooler_type == CLS_TOKEN: - return pool_cls, input_dim - elif pooler_type == START_TOKENS: - pooler = AtIndexPooler(input_dim=input_dim, offset=-1, **pooler_config) - pooler_wrapped = ArgumentWrappedPooler( - pooler=pooler, argument_mapping={"start_indices": "indices"} - ) - return pooler_wrapped, pooler.output_dim - elif pooler_type == MENTION_POOLING: - aggregate = pooler_config.pop("aggregate", "max") - if aggregate == "max": - pooler = SpanMaxPooler(input_dim=input_dim, **pooler_config) - return pooler, pooler.output_dim - elif aggregate == "mean": - pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config) - return pooler, pooler.output_dim - else: - raise ValueError(f'Unknown aggregation method for mention pooling: "{aggregate}"') - else: - raise ValueError(f'Unknown pooler type "{pooler_type}"') diff --git a/src/pie_modules/models/components/seq2seq_encoder.py b/src/pie_modules/models/components/seq2seq_encoder.py deleted file mode 100644 index 52866be9b..000000000 --- a/src/pie_modules/models/components/seq2seq_encoder.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -from copy import copy -from typing import Any, Dict, List, Optional, Tuple - -from torch import Tensor, nn - -logger = logging.getLogger(__name__) - -RNN_TYPE2CLASS = {"lstm": nn.LSTM, "gru": nn.GRU, "rnn": nn.RNN} -ACTIVATION_TYPE2CLASS = { - "relu": nn.ReLU, - "tanh": nn.Tanh, - "sigmoid": nn.Sigmoid, - "gelu": nn.GELU, -} - - -class RNNWrapper(nn.Module): - def __init__(self, rnn: nn.Module): - super().__init__() - self.rnn = rnn - - def forward(self, *args, **kwargs) -> Tensor: - return self.rnn(*args, **kwargs)[0] - - @property - def output_size(self) -> int: - if self.rnn.bidirectional: - return self.rnn.hidden_size * 2 - else: - return self.rnn.hidden_size - - -def build_seq2seq_encoder( - config: Dict[str, Any], input_size: int -) -> Tuple[Optional[nn.Module], int]: - # copy the config to avoid side effects - config = copy(config) - seq2seq_encoder_type = config.pop("type", None) - if seq2seq_encoder_type is None: - logger.warning( - f"seq2seq_encoder_type is not specified in the seq2seq_encoder: {config}. " - f"Do not build this seq2seq_encoder." - ) - return None, input_size - - if seq2seq_encoder_type == "sequential": - modules: List[nn.Module] = [] - output_size = input_size - for key, subconfig in config.items(): - module, output_size = build_seq2seq_encoder(subconfig, input_size) - if module is not None: - modules.append(module) - input_size = output_size - - seq2seq_encoder = nn.Sequential(*modules) - elif seq2seq_encoder_type in RNN_TYPE2CLASS: - rnn_class = RNN_TYPE2CLASS[seq2seq_encoder_type] - seq2seq_encoder = RNNWrapper(rnn_class(input_size=input_size, batch_first=True, **config)) - output_size = seq2seq_encoder.output_size - elif seq2seq_encoder_type == "linear": - seq2seq_encoder = nn.Linear(in_features=input_size, **config) - output_size = seq2seq_encoder.out_features - elif seq2seq_encoder_type in ACTIVATION_TYPE2CLASS: - activation_class = ACTIVATION_TYPE2CLASS[seq2seq_encoder_type] - seq2seq_encoder = activation_class(**config) - output_size = input_size - elif seq2seq_encoder_type == "dropout": - seq2seq_encoder = nn.Dropout(**config) - output_size = input_size - elif seq2seq_encoder_type == "none": - seq2seq_encoder = None - output_size = input_size - else: - raise ValueError(f"Unknown seq2seq_encoder_type: {seq2seq_encoder_type}") - - return seq2seq_encoder, output_size diff --git a/src/pie_modules/models/interface.py b/src/pie_modules/models/interface.py deleted file mode 100644 index a5bc0bfad..000000000 --- a/src/pie_modules/models/interface.py +++ /dev/null @@ -1,12 +0,0 @@ -class RequiresMaxInputLength: - """Any class inheriting from this class should require a constructor parameter - 'max_input_length'.""" - - pass - - -class RequiresTaskmoduleConfig: - """Any class inheriting from this class should require a constructor parameter - 'taskmodule_config'.""" - - pass diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py deleted file mode 100644 index bb7d72785..000000000 --- a/src/pie_modules/models/sequence_classification_with_pooler.py +++ /dev/null @@ -1,362 +0,0 @@ -import logging -from abc import ABC, abstractmethod -from typing import ( - Any, - Callable, - Dict, - Iterator, - List, - MutableMapping, - Optional, - Tuple, - TypeVar, - Union, -) - -import torch -from pytorch_ie import PyTorchIEModel -from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses -from torch import FloatTensor, LongTensor, nn -from torch.nn import Parameter -from torch.optim import AdamW -from transformers import ( - AutoConfig, - AutoModel, - PreTrainedModel, - get_linear_schedule_with_warmup, -) -from transformers.modeling_outputs import SequenceClassifierOutput -from typing_extensions import TypeAlias - -from .common import ModelWithBoilerplate -from .components.pooler import get_pooler_and_output_size - -# model inputs / outputs / targets -InputType: TypeAlias = MutableMapping[str, LongTensor] -OutputType: TypeAlias = SequenceClassifierOutput -TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]] -# step inputs (batch) / outputs (loss) -StepInputType: TypeAlias = Tuple[InputType, TargetType] -StepOutputType: TypeAlias = FloatTensor - - -HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE = { - "albert": "classifier_dropout_prob", - "distilbert": "seq_classif_dropout", -} - -logger = logging.getLogger(__name__) - -T = TypeVar("T") - - -def separate_arguments_by_prefix( - arguments: MutableMapping[str, T], prefixes: List[str] -) -> Dict[str, Dict[str, T]]: - result: Dict[str, Dict[str, T]] = {prefix: {} for prefix in prefixes + ["remaining"]} - for k, v in arguments.items(): - found = False - for prefix in prefixes: - if k.startswith(prefix): - result[prefix][k[len(prefix) :]] = v - found = True - break - if not found: - result["remaining"][k] = v - return result - - -class SequenceClassificationModelWithPoolerBase( - ABC, - ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType], - RequiresModelNameOrPath, -): - """Abstract base model for sequence classification with a pooler. - - Args: - model_name_or_path: The name or path of the HuggingFace model to use. - tokenizer_vocab_size: The size of the tokenizer vocabulary. If provided, the model's - tokenizer embeddings are resized to this size. - classifier_dropout: The dropout probability for the classifier. If not provided, the - dropout probability is taken from the Huggingface model config. - learning_rate: The learning rate for the optimizer. - task_learning_rate: The learning rate for the task-specific parameters. If None, the - learning rate for all parameters is set to `learning_rate`. - warmup_proportion: The proportion of steps to warm up the learning rate. - pooler: The pooler configuration. If None, CLS token pooling is used. - freeze_base_model: If True, the base model parameters are frozen. - base_model_prefix: The prefix of the base model parameters when using a task_learning_rate - or freeze_base_model. If None, the base_model_prefix of the model is used. - **kwargs: Additional keyword arguments passed to the parent class, - see :class:`ModelWithBoilerplate`. - """ - - def __init__( - self, - model_name_or_path: str, - tokenizer_vocab_size: Optional[int] = None, - classifier_dropout: Optional[float] = None, - learning_rate: float = 1e-5, - task_learning_rate: Optional[float] = None, - warmup_proportion: float = 0.1, - pooler: Optional[Union[Dict[str, Any], str]] = None, - freeze_base_model: bool = False, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.save_hyperparameters() - - self.learning_rate = learning_rate - self.task_learning_rate = task_learning_rate - self.warmup_proportion = warmup_proportion - self.freeze_base_model = freeze_base_model - self.model_name_or_path = model_name_or_path - - self.model = self.setup_base_model() - - if tokenizer_vocab_size is not None: - self.model.resize_token_embeddings(tokenizer_vocab_size) - - if self.freeze_base_model: - for param in self.model.parameters(): - param.requires_grad = False - - if classifier_dropout is None: - # Get the classifier dropout value from the Huggingface model config. - # This is a bit of a mess since some Configs use different variable names or change the semantics - # of the dropout (e.g. DistilBert has one dropout prob for QA and one for Seq classification, and a - # general one for embeddings, encoder and pooler). - classifier_dropout_attr = HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE.get( - self.model.config.model_type, "classifier_dropout" - ) - classifier_dropout = getattr(self.model.config, classifier_dropout_attr) or 0.0 - self.dropout = nn.Dropout(classifier_dropout) - - if isinstance(pooler, str): - pooler = {"type": pooler} - self.pooler_config = pooler or {} - self.pooler, pooler_output_dim = self.setup_pooler(input_dim=self.model.config.hidden_size) - self.classifier = self.setup_classifier(pooler_output_dim=pooler_output_dim) - self.loss_fct = self.setup_loss_fct() - - def setup_base_model(self) -> PreTrainedModel: - config = AutoConfig.from_pretrained(self.model_name_or_path) - if self.is_from_pretrained: - return AutoModel.from_config(config=config) - else: - return AutoModel.from_pretrained(self.model_name_or_path, config=config) - - @abstractmethod - def setup_classifier(self, pooler_output_dim: int) -> Callable: - pass - - @abstractmethod - def setup_loss_fct(self) -> Callable: - pass - - def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]: - """Set up the pooler. The pooler is used to get a representation of the input sequence(s) - that can be used by the classifier. It is a callable that takes the hidden states of the - base model (and additional model inputs that are prefixed with "pooler_") and returns the - pooled output. - - Args: - input_dim: The input dimension of the pooler, i.e. the hidden size of the base model. - - Returns: - A tuple with the pooler and the output dimension of the pooler. - """ - return get_pooler_and_output_size(config=self.pooler_config, input_dim=input_dim) - - def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor: - output = self.model(**model_inputs) - hidden_state = output.last_hidden_state - pooled_output = self.pooler(hidden_state, **pooler_inputs) - pooled_output = self.dropout(pooled_output) - return pooled_output - - def forward( - self, - inputs: InputType, - targets: Optional[TargetType] = None, - return_hidden_states: bool = False, - ) -> OutputType: - sanitized_inputs = separate_arguments_by_prefix(arguments=inputs, prefixes=["pooler_"]) - - pooled_output = self.get_pooled_output( - model_inputs=sanitized_inputs["remaining"], pooler_inputs=sanitized_inputs["pooler_"] - ) - - logits = self.classifier(pooled_output) - - result = {"logits": logits} - if targets is not None: - labels = targets["labels"] - loss = self.loss_fct(logits, labels) - result["loss"] = loss - if return_hidden_states: - raise NotImplementedError("return_hidden_states is not yet implemented") - - return SequenceClassifierOutput(**result) - - @abstractmethod - def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - pass - - def base_model_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - if prefix: - prefix = f"{prefix}." - return self.model.named_parameters(prefix=f"{prefix}model") - - def task_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - if prefix: - prefix = f"{prefix}." - base_model_parameter_names = dict(self.base_model_named_parameters(prefix=prefix)).keys() - for name, param in self.named_parameters(prefix=prefix): - if name not in base_model_parameter_names: - yield name, param - - def configure_optimizers(self): - if self.task_learning_rate is not None: - base_model_params = (param for name, param in self.base_model_named_parameters()) - task_params = (param for name, param in self.task_named_parameters()) - optimizer = AdamW( - [ - {"params": base_model_params, "lr": self.learning_rate}, - {"params": task_params, "lr": self.task_learning_rate}, - ] - ) - else: - optimizer = AdamW(self.parameters(), lr=self.learning_rate) - - if self.warmup_proportion > 0.0: - stepping_batches = self.trainer.estimated_stepping_batches - scheduler = get_linear_schedule_with_warmup( - optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches - ) - return [optimizer], [{"scheduler": scheduler, "interval": "step"}] - else: - return optimizer - - -@PyTorchIEModel.register() -class SequenceClassificationModelWithPooler( - SequenceClassificationModelWithPoolerBase, - RequiresNumClasses, -): - """A sequence classification model that uses a pooler to get a representation of the input - sequence and then applies a linear classifier to that representation. The pooler can be - configured via the `pooler` argument, see :func:`get_pooler_and_output_size` for details. - - Args: - num_classes: The number of classes for the classification task. - multi_label: If True, the model is trained as a multi-label classifier. - multi_label_threshold: The threshold for the multi-label classifier, i.e. the probability - above which a class is predicted. - **kwargs - """ - - def __init__( - self, - num_classes: int, - multi_label: bool = False, - multi_label_threshold: float = 0.5, - **kwargs, - ): - # set num_classes and multi_label before call to super init because they are used there - # in setup_classifier and setup_loss_fct - self.num_classes = num_classes - self.multi_label = multi_label - super().__init__(**kwargs) - - self.multi_label_threshold = multi_label_threshold - - def setup_classifier(self, pooler_output_dim: int) -> Callable: - return nn.Linear(pooler_output_dim, self.num_classes) - - def setup_loss_fct(self) -> Callable: - return nn.BCEWithLogitsLoss() if self.multi_label else nn.CrossEntropyLoss() - - def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - if not self.multi_label: - labels = torch.argmax(outputs.logits, dim=-1).to(torch.long) - probabilities = torch.softmax(outputs.logits, dim=-1) - else: - probabilities = torch.sigmoid(outputs.logits) - labels = (probabilities > self.multi_label_threshold).to(torch.long) - return {"labels": labels, "probabilities": probabilities} - - -@PyTorchIEModel.register() -class SequencePairSimilarityModelWithPooler( - SequenceClassificationModelWithPoolerBase, -): - """A span pair similarity model to detect of two spans occurring in different texts are - similar. It uses an encoder to independently calculate contextualized embeddings of both texts, - then uses a pooler to get representations of the spans and, finally, calculates the cosine to - get the similarity scores. - - Args: - label_threshold: The threshold above which score the spans are considered as similar. - pooler: The pooler identifier or config, see :func:`get_pooler_and_output_size` for details. - Defaults to "mention_pooling" (max pooling over the span token embeddings). - **kwargs - """ - - def __init__( - self, - pooler: Optional[Union[Dict[str, Any], str]] = None, - **kwargs, - ): - if pooler is None: - # use (max) mention pooling per default - pooler = {"type": "mention_pooling", "num_indices": 1} - super().__init__(pooler=pooler, **kwargs) - - def setup_classifier( - self, pooler_output_dim: int - ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]: - return torch.nn.functional.cosine_similarity - - def setup_loss_fct(self) -> Callable: - return nn.BCELoss() - - def forward( - self, - inputs: InputType, - targets: Optional[TargetType] = None, - return_hidden_states: bool = False, - ) -> OutputType: - sanitized_inputs = separate_arguments_by_prefix( - # Note that the order of the prefixes is important because one is a prefix of the other, - # so we need to start with the longer! - arguments=inputs, - prefixes=["pooler_pair_", "pooler_"], - ) - - pooled_output = self.get_pooled_output( - model_inputs=sanitized_inputs["remaining"]["encoding"], - pooler_inputs=sanitized_inputs["pooler_"], - ) - pooled_output_pair = self.get_pooled_output( - model_inputs=sanitized_inputs["remaining"]["encoding_pair"], - pooler_inputs=sanitized_inputs["pooler_pair_"], - ) - - logits = self.classifier(pooled_output, pooled_output_pair) - - result = {"logits": logits} - if targets is not None: - labels = targets["scores"] - loss = self.loss_fct(logits, labels) - result["loss"] = loss - if return_hidden_states: - raise NotImplementedError("return_hidden_states is not yet implemented") - - return SequenceClassifierOutput(**result) - - def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - # probabilities = torch.sigmoid(outputs.logits) - scores = outputs.logits - return {"scores": scores} diff --git a/src/pie_modules/models/simple_extractive_question_answering.py b/src/pie_modules/models/simple_extractive_question_answering.py deleted file mode 100644 index d54436515..000000000 --- a/src/pie_modules/models/simple_extractive_question_answering.py +++ /dev/null @@ -1,165 +0,0 @@ -from typing import Any, Dict, MutableMapping, Optional, Tuple - -from pytorch_ie import PyTorchIEModel -from pytorch_ie.models.interface import RequiresModelNameOrPath -from pytorch_lightning.utilities.types import OptimizerLRScheduler -from torch import Tensor -from torch.nn import ModuleDict, functional -from torch.optim import Adam -from torchmetrics import F1Score -from transformers import ( - AutoConfig, - AutoModelForQuestionAnswering, - BatchEncoding, - get_linear_schedule_with_warmup, -) -from transformers.modeling_outputs import QuestionAnsweringModelOutput -from typing_extensions import TypeAlias - -from pie_modules.models.interface import RequiresMaxInputLength - -BatchOutput: TypeAlias = Dict[str, Any] - -# The input to the forward method of this model. It is passed to -# the base transformer model. -ModelInputType: TypeAlias = MutableMapping[str, Any] -# The output of the forward method of this model. -ModelOutputType: TypeAlias = QuestionAnsweringModelOutput -# The input to the step methods, i.e. training_step, validation_step, test_step. -# It contains the input and target tensors for a single training step. -StepBatchEncoding: TypeAlias = Tuple[ - ModelInputType, - Optional[Dict[str, Tensor]], -] - - -TRAINING = "train" -VALIDATION = "val" -TEST = "test" - - -@PyTorchIEModel.register() -class SimpleExtractiveQuestionAnsweringModel( - PyTorchIEModel, RequiresModelNameOrPath, RequiresMaxInputLength -): - """A PIE model for extractive question answering. It is a simple Pytorch-Lightning module that - wraps around a question answering model from the Huggingface transformers library. The - ExtractiveQuestionAnsweringTaskModule can be used create the input and target encodings as well - as to decode the model output. - - Args: - model_name_or_path: The name (Huggingface Hub model identifier) or local path of the model to use. - max_input_length: The maximum length of the input sequence. Required for metric calculation. - learning_rate: The learning rate to use for training. Defaults to 1e-5. - """ - - def __init__( - self, - model_name_or_path: str, - max_input_length: int, - learning_rate: float = 1e-5, - warmup_proportion: float = 0.0, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.save_hyperparameters() - - self.learning_rate = learning_rate - self.warmup_proportion = warmup_proportion - self.max_input_length = max_input_length - - config = AutoConfig.from_pretrained(model_name_or_path) - if self.is_from_pretrained: - self.model = AutoModelForQuestionAnswering.from_config(config=config) - else: - self.model = AutoModelForQuestionAnswering.from_pretrained( - model_name_or_path, config=config - ) - - self.f1_start: Dict[str, F1Score] = ModuleDict( - { - f"stage_{stage}": F1Score(task="multiclass", num_classes=max_input_length) - for stage in [TRAINING, VALIDATION, TEST] - } - ) - self.f1_end: Dict[str, F1Score] = ModuleDict( - { - f"stage_{stage}": F1Score(task="multiclass", num_classes=max_input_length) - for stage in [TRAINING, VALIDATION, TEST] - } - ) - - def forward(self, inputs: BatchEncoding) -> ModelOutputType: - return self.model(**inputs) - - def step( - self, - stage: str, - batch: StepBatchEncoding, - ) -> Tensor: - inputs, targets = batch - if targets is None: - raise ValueError("targets has to be available for training, but it is None") - - output = self({**inputs, **targets}) - - loss = output.loss - # show loss on each step only during training - self.log(f"{stage}/loss", loss, on_step=(stage == TRAINING), on_epoch=True, prog_bar=True) - - start_positions = targets["start_positions"] - end_positions = targets["end_positions"] - start_logits = output.start_logits - end_logits = output.end_logits - - sequence_length = inputs["input_ids"].size(1) - f1_start = self.f1_end[f"stage_{stage}"] - # We need to pad the logits to the max_input_length, otherwise the F1 metric complains - # that the shape does not match the num_classes. - start_logits_padded = functional.pad( - start_logits, (0, self.max_input_length - sequence_length), value=float("-inf") - ) - f1_start(start_logits_padded, start_positions) - self.log( - f"{stage}/f1_start", - f1_start, - on_step=(stage == TRAINING), - on_epoch=True, - prog_bar=True, - ) - f1_end = self.f1_end[f"stage_{stage}"] - # We need to pad the logits to the max_input_length, otherwise the F1 metric complains - # that the shape does not match the num_classes. - end_logits_padded = functional.pad( - end_logits, (0, self.max_input_length - sequence_length), value=float("-inf") - ) - f1_end(end_logits_padded, end_positions) - self.log( - f"{stage}/f1_end", f1_end, on_step=(stage == TRAINING), on_epoch=True, prog_bar=True - ) - # log f1 as simple average of start and end f1. we need to call compute() on the metric to get - # the actual value, otherwise lightning complains that there is no model attribute with name "f1" - f1_value = (f1_start.compute() + f1_end.compute()) / 2 - self.log(f"{stage}/f1", f1_value, on_step=False, on_epoch=True, prog_bar=True) - return loss - - def training_step(self, batch: StepBatchEncoding, batch_idx: int) -> Tensor: - return self.step(stage=TRAINING, batch=batch) - - def validation_step(self, batch: StepBatchEncoding, batch_idx: int) -> Tensor: - return self.step(stage=VALIDATION, batch=batch) - - def test_step(self, batch: StepBatchEncoding, batch_idx: int) -> Tensor: - return self.step(stage=TEST, batch=batch) - - def configure_optimizers(self) -> OptimizerLRScheduler: - optimizer = Adam(self.parameters(), lr=self.learning_rate) - - if self.warmup_proportion > 0.0: - stepping_batches = self.trainer.estimated_stepping_batches - scheduler = get_linear_schedule_with_warmup( - optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches - ) - return [optimizer], [{"scheduler": scheduler, "interval": "step"}] - else: - return optimizer diff --git a/src/pie_modules/models/simple_generative.py b/src/pie_modules/models/simple_generative.py deleted file mode 100644 index e5d7d4549..000000000 --- a/src/pie_modules/models/simple_generative.py +++ /dev/null @@ -1,196 +0,0 @@ -import copy -import logging -from typing import Any, Dict, Optional, Tuple, Type, Union - -import torch -from pie_core.utils.hydra import resolve_type -from pytorch_ie import PyTorchIEModel -from pytorch_lightning.utilities.types import OptimizerLRScheduler -from torch import FloatTensor, LongTensor -from torch.optim import Optimizer -from transformers import PreTrainedModel, SchedulerType, get_scheduler -from transformers.modeling_outputs import Seq2SeqLMOutput -from typing_extensions import TypeAlias - -from pie_modules.models.common import ModelWithBoilerplate - -logger = logging.getLogger(__name__) - -# model inputs / outputs / targets -InputType: TypeAlias = Dict[str, LongTensor] -OutputType: TypeAlias = Seq2SeqLMOutput -TargetType: TypeAlias = Dict[str, LongTensor] -# step inputs (batch) / outputs (loss) -StepInputType: TypeAlias = Tuple[InputType, TargetType] -StepOutputType: TypeAlias = FloatTensor - - -@PyTorchIEModel.register() -class SimpleGenerativeModel( - ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType], -): - """This model is a simple wrapper around a generative model from Huggingface transformers. That - means, its predict() and predict_step() methods will call the generate() method of the base - model. - - If a taskmodule config is provided, the taskmodule will be instantiated and used to create metrics and - a generation config with its configure_model_metric() and configure_model_generation() methods, - respectively. - - If the base model has a configure_optimizer() method, this will be used to create the optimizer. Otherwise, - the optimizer_type and learning_rate will be used to create an optimizer. - - Args: - base_model_type: The type of the base model, e.g. "transformers.AutoModelForSeq2SeqLM". It should have a - from_pretrained() method. - base_model_config: A dictionary with the keyword arguments that will be passed to the from_pretrained() - method of the base model. - override_generation_kwargs: The generation config for the base model. This will override the generation config - from the taskmodule, if one is provided. - warmup_proportion: The proportion of the training steps that will be used for the warmup of the learning rate - scheduler. - learning_rate: The learning rate for the optimizer. If the base model has a configure_optimizer() method, this - will be ignored. - optimizer_type: The type of the optimizer. If the base model has a configure_optimizer() method, this will be - ignored. - **kwargs: Additional keyword arguments that will be passed to the PyTorchIEModel constructor. - """ - - def __init__( - self, - # base model setup - base_model: Optional[Dict[str, Any]] = None, - # old setup - base_model_type: Optional[str] = None, - base_model_config: Optional[Dict[str, Any]] = None, - # generation - override_generation_kwargs: Optional[Dict[str, Any]] = None, - # optimizer / schedular - # important: the following entries (optimizer_type and learning_rate) are only used - # if the base model does not have a configure_optimizer method! - optimizer_type: Optional[Union[str, Type[Optimizer]]] = None, - learning_rate: Optional[float] = None, - warmup_proportion: float = 0.0, - scheduler_name: Optional[Union[str, SchedulerType]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ): - super().__init__(**kwargs) - - if base_model is None: - if base_model_type is None: - raise ValueError( - "Either base_model or base_model_type must be provided. If base_model is not provided, " - "base_model_type must be a valid model type, e.g. 'transformers.AutoModelForSeq2SeqLM'." - ) - logger.warning( - "The base_model_type and base_model_config arguments are deprecated. Please use base_model. " - "You can use the following code to create the base_model argument: " - "base_model = {'_type_': base_model_type, **base_model_config}" - ) - base_model = {"_type_": base_model_type, **(base_model_config or {})} - - if scheduler_name is None and warmup_proportion > 0.0: - logger.warning( - "warmup_proportion is set to a value > 0.0, but scheduler_name is not set. " - "Setting scheduler_name to 'linear' by default." - ) - scheduler_name = "linear" - - self.save_hyperparameters(ignore=["base_model_type", "base_model_config"]) - - # optimizer / scheduler - self.learning_rate = learning_rate - self.optimizer_type = optimizer_type - self.scheduler_name = scheduler_name - self.warmup_proportion = warmup_proportion - self.scheduler_kwargs = scheduler_kwargs or {} - - self.model = self.setup_base_model(config=base_model) - self.generation_config = self.configure_generation(**(override_generation_kwargs or {})) - - def setup_base_model(self, config: Dict[str, Any]) -> PreTrainedModel: - config = copy.copy(config) - resolved_base_model_type: Type[PreTrainedModel] = resolve_type(config.pop("_type_")) - return resolved_base_model_type.from_pretrained(**config) - - def configure_generation(self, **kwargs) -> Dict[str, Any]: - if self.taskmodule is not None: - # get the generation config from the taskmodule - generation_config = self.taskmodule.configure_model_generation() - else: - logger.warning( - "No taskmodule is available, so no generation config will be created. Consider " - "setting taskmodule_config to a valid taskmodule config to use specific setup for generation." - ) - generation_config = {} - generation_config.update(kwargs) - return generation_config - - def predict(self, inputs, **kwargs) -> TargetType: - is_training = self.training - self.eval() - - generation_kwargs = copy.deepcopy(self.generation_config) - generation_kwargs.update(kwargs) - outputs = self.model.generate(**inputs, **generation_kwargs) - - if is_training: - self.train() - - # TODO: move into base model? or does this work for "all" generative models? - # strip the bos_id - if isinstance(outputs, torch.Tensor): - return {"labels": outputs[:, 1:]} - else: - raise ValueError(f"Unsupported output type: {type(outputs)}") - - def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType: - kwargs = {**inputs, **(targets or {})} - return self.model(**kwargs) - - def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - # construct prediction from the model output - logits = outputs.logits - # get the indices (these are without the initial bos_ids, see above) - prediction = torch.argmax(logits, dim=-1) - return {"labels": prediction.to(torch.long)} - - def configure_optimizers(self) -> OptimizerLRScheduler: - if hasattr(self.model, "configure_optimizer") and callable(self.model.configure_optimizer): - if self.learning_rate is not None: - raise ValueError( - f"learning_rate is set to {self.learning_rate}, but the *base model* ({type(self.model)}) has a " - f"configure_optimizer method. Please set learning_rate to None and configure the optimizer " - f"inside the *base model*." - ) - optimizer = self.model.configure_optimizer() - else: - logger.warning( - f"The model does not have a configure_optimizer method. Creating an optimizer of " - f"optimizer_type={self.optimizer_type} with the learning_rate={self.learning_rate} instead." - ) - if self.optimizer_type is None: - raise ValueError( - f"optimizer_type is None, but the *base model* ({type(self.model)}) does not have a " - f"configure_optimizer method. Please set the optimizer_type to a valid optimizer type, " - f"e.g. optimizer_type=torch.optim.Adam." - ) - resolved_optimizer_type = resolve_type( - self.optimizer_type, expected_super_type=Optimizer - ) - optimizer = resolved_optimizer_type(self.parameters(), lr=self.learning_rate) - - if self.scheduler_name is not None: - num_training_steps = self.trainer.estimated_stepping_batches - num_warmup_steps = int(num_training_steps * self.warmup_proportion) - scheduler = get_scheduler( - name=self.scheduler_name, - optimizer=optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, - scheduler_specific_kwargs=self.scheduler_kwargs, - ) - return [optimizer], [{"scheduler": scheduler, "interval": "step"}] - else: - return optimizer diff --git a/src/pie_modules/models/simple_sequence_classification.py b/src/pie_modules/models/simple_sequence_classification.py deleted file mode 100644 index 651a043a2..000000000 --- a/src/pie_modules/models/simple_sequence_classification.py +++ /dev/null @@ -1,140 +0,0 @@ -import logging -from typing import Iterator, MutableMapping, Optional, Tuple, Union - -import torch.nn -from pytorch_ie import PyTorchIEModel -from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses -from torch import FloatTensor, LongTensor -from torch.nn import Parameter -from torch.optim import AdamW -from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - get_linear_schedule_with_warmup, -) -from transformers.modeling_outputs import SequenceClassifierOutput -from typing_extensions import TypeAlias - -from pie_modules.models.common import ModelWithBoilerplate - -# model inputs / outputs / targets -InputType: TypeAlias = MutableMapping[str, LongTensor] -OutputType: TypeAlias = SequenceClassifierOutput -TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]] -# step inputs (batch) / outputs (loss) -StepInputType: TypeAlias = Tuple[InputType, TargetType] -StepOutputType: TypeAlias = FloatTensor - - -logger = logging.getLogger(__name__) - - -@PyTorchIEModel.register() -class SimpleSequenceClassificationModel( - ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType], - RequiresModelNameOrPath, - RequiresNumClasses, -): - """A simple sequence classification model. It wraps a HuggingFace - AutoModelForSequenceClassification and adds boilerplate code for training and inference. - - Args: - model_name_or_path: The name or path of the HuggingFace model to use. - num_classes: The number of classes for the classification task. - tokenizer_vocab_size: The size of the tokenizer vocabulary. If provided, the model's - tokenizer embeddings are resized to this size. - learning_rate: The learning rate for the optimizer. - task_learning_rate: The learning rate for the task-specific parameters. If None, the - learning rate for all parameters is set to `learning_rate`. - warmup_proportion: The proportion of steps to warm up the learning rate. - freeze_base_model: If True, the base model parameters are frozen. - base_model_prefix: The prefix of the base model parameters when using a task_learning_rate - or freeze_base_model. If None, the base_model_prefix of the model is used. - **kwargs: Additional keyword arguments passed to the parent class, - see :class:`ModelWithBoilerplate`. - """ - - def __init__( - self, - model_name_or_path: str, - num_classes: int, - tokenizer_vocab_size: Optional[int] = None, - learning_rate: float = 1e-5, - task_learning_rate: Optional[float] = None, - warmup_proportion: float = 0.1, - freeze_base_model: bool = False, - base_model_prefix: Optional[str] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.save_hyperparameters() - - self.learning_rate = learning_rate - self.task_learning_rate = task_learning_rate - self.warmup_proportion = warmup_proportion - self.freeze_base_model = freeze_base_model - - config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_classes) - if self.is_from_pretrained: - self.model = AutoModelForSequenceClassification.from_config(config=config) - else: - self.model = AutoModelForSequenceClassification.from_pretrained( - model_name_or_path, config=config - ) - - self.base_model_prefix = base_model_prefix or self.model.base_model_prefix - - if tokenizer_vocab_size is not None: - self.model.resize_token_embeddings(tokenizer_vocab_size) - - if self.freeze_base_model: - for name, param in self.base_model_named_parameters(): - param.requires_grad = False - - def base_model_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - base_model: torch.nn.Module = getattr(self.model, self.base_model_prefix, None) - if base_model is None: - raise ValueError( - f"Base model with prefix '{self.base_model_prefix}' not found in {type(self.model).__name__}" - ) - if prefix: - prefix = f"{prefix}." - return base_model.named_parameters(prefix=f"{prefix}model.{self.base_model_prefix}") - - def task_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - base_model_parameter_names = dict(self.base_model_named_parameters(prefix=prefix)).keys() - for name, param in self.named_parameters(prefix=prefix): - if name not in base_model_parameter_names: - yield name, param - - def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType: - kwargs = {**inputs, **(targets or {})} - return self.model(**kwargs) - - def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - labels = torch.argmax(outputs.logits, dim=-1).to(torch.long) - probabilities = torch.softmax(outputs.logits, dim=-1) - return {"labels": labels, "probabilities": probabilities} - - def configure_optimizers(self): - if self.task_learning_rate is not None: - base_model_params = [param for name, param in self.base_model_named_parameters()] - task_params = [param for name, param in self.task_named_parameters()] - optimizer = AdamW( - [ - {"params": base_model_params, "lr": self.learning_rate}, - {"params": task_params, "lr": self.task_learning_rate}, - ] - ) - else: - optimizer = AdamW(self.parameters(), lr=self.learning_rate) - - if self.warmup_proportion > 0.0: - stepping_batches = self.trainer.estimated_stepping_batches - scheduler = get_linear_schedule_with_warmup( - optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches - ) - return [optimizer], [{"scheduler": scheduler, "interval": "step"}] - else: - return optimizer diff --git a/src/pie_modules/models/simple_token_classification.py b/src/pie_modules/models/simple_token_classification.py deleted file mode 100644 index 7879aae04..000000000 --- a/src/pie_modules/models/simple_token_classification.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging -from typing import MutableMapping, Optional, Tuple, Union - -import torch -from pytorch_ie import PyTorchIEModel -from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses -from pytorch_lightning.utilities.types import OptimizerLRScheduler -from torch import FloatTensor, LongTensor -from transformers import AutoConfig, AutoModelForTokenClassification, BatchEncoding -from transformers.modeling_outputs import TokenClassifierOutput -from typing_extensions import TypeAlias - -from pie_modules.models.common import ModelWithBoilerplate - -# model inputs / outputs / targets -InputType: TypeAlias = BatchEncoding -OutputType: TypeAlias = TokenClassifierOutput -TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]] -# step inputs (batch) / outputs (loss) -StepInputType: TypeAlias = Tuple[InputType, TargetType] -StepOutputType: TypeAlias = FloatTensor - - -logger = logging.getLogger(__name__) - - -@PyTorchIEModel.register() -class SimpleTokenClassificationModel( - ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType], - RequiresModelNameOrPath, - RequiresNumClasses, -): - """A simple token classification model that wraps a (pretrained) model loaded with - AutoModelForTokenClassification from the transformers library. - - The model is trained with a cross-entropy loss function and uses the Adam optimizer. - - Note that for training, the labels for the special tokens (as well as for padding tokens) - are expected to have the value label_pad_id (-100 by default, which is the default ignore_index - value for the CrossEntropyLoss). The predictions for these tokens are also replaced with - label_pad_id to match the training labels for correct metric calculation. Therefore, the model - requires the special_tokens_mask and attention_mask (for padding) to be passed as inputs. - - Args: - model_name_or_path: The name or path of the pretrained transformer model to use. - num_classes: The number of classes to predict. - learning_rate: The learning rate to use for training. - label_pad_id: The label id to use for padding labels (at the padding token positions - as well as for the special tokens). - """ - - def __init__( - self, - model_name_or_path: str, - num_classes: int, - learning_rate: float = 1e-5, - label_pad_id: int = -100, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.save_hyperparameters() - - self.learning_rate = learning_rate - self.label_pad_id = label_pad_id - self.num_classes = num_classes - - config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_classes) - if self.is_from_pretrained: - self.model = AutoModelForTokenClassification.from_config(config=config) - else: - self.model = AutoModelForTokenClassification.from_pretrained( - model_name_or_path, config=config - ) - - def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType: - inputs_without_special_tokens_mask = { - k: v for k, v in inputs.items() if k != "special_tokens_mask" - } - return self.model(**inputs_without_special_tokens_mask, **(targets or {})) - - def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - # get the max index for each token from the logits - tags_tensor = torch.argmax(outputs.logits, dim=-1).to(torch.long) - - # mask out the padding and special tokens - tags_tensor = tags_tensor.masked_fill(inputs["attention_mask"] == 0, self.label_pad_id) - - # mask out the special tokens - tags_tensor = tags_tensor.masked_fill( - inputs["special_tokens_mask"] == 1, self.label_pad_id - ) - probabilities = torch.softmax(outputs.logits, dim=-1) - - return {"labels": tags_tensor, "probabilities": probabilities} - - def configure_optimizers(self) -> OptimizerLRScheduler: - return torch.optim.Adam(self.parameters(), lr=self.learning_rate) diff --git a/src/pie_modules/models/span_tuple_classification.py b/src/pie_modules/models/span_tuple_classification.py deleted file mode 100644 index bbdbf4938..000000000 --- a/src/pie_modules/models/span_tuple_classification.py +++ /dev/null @@ -1,457 +0,0 @@ -import logging -from dataclasses import dataclass -from typing import Iterator, List, MutableMapping, Optional, Tuple, TypeVar, Union - -import torch -from pytorch_ie import PyTorchIEModel -from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses -from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn -from torch.nn import Dropout, Parameter -from torch.optim import AdamW -from transformers import AutoConfig, AutoModel, get_linear_schedule_with_warmup -from transformers.utils import ModelOutput -from typing_extensions import TypeAlias - -from .common import ModelWithBoilerplate - - -class MLP(nn.Module): - def __init__(self, n_in, n_out, dropout=0, activation=nn.GELU()): - super().__init__() - self.linear = nn.Linear(n_in, n_out) - self.f = activation - self.dropout = Dropout(p=dropout) - self.reset_parameters() - - def reset_parameters(self): - nn.init.xavier_normal_(self.linear.weight) - nn.init.zeros_(self.linear.bias) - - def forward(self, x): - x = self.f(self.linear(x)) - x = self.dropout(x) - return x - - -@dataclass -class SpanPairClassifierOutput(ModelOutput): - """Base class for outputs of span pair classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : - Classification loss. - logits (`torch.FloatTensor` of shape `(num_valid_input_pairs_in_batch, config.num_labels)`): - Classification scores (before SoftMax). - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): - The last hidden state of the transformer model. Returned if `return_embeddings=True`. - span_embeddings (`torch.FloatTensor` of shape `(batch_size, num_spans, span_embedding_dim)`, *optional*): - The embeddings of the spans. Returned if `return_embeddings=True`. - tuple_embeddings (`torch.FloatTensor` of shape `(num_valid_input_pairs_in_batch, tuple_embedding_dim)`, *optional*): - The embeddings of the tuples. Returned if `return_embeddings=True`. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - last_hidden_state: Optional[torch.FloatTensor] = None - span_embeddings: Optional[torch.FloatTensor] = None - tuple_embeddings: Optional[torch.FloatTensor] = None - - -# model inputs / outputs / targets -InputType: TypeAlias = MutableMapping[str, LongTensor] -OutputType: TypeAlias = SpanPairClassifierOutput -TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]] -# step inputs (batch) / outputs (loss) -StepInputType: TypeAlias = Tuple[InputType, TargetType] -StepOutputType: TypeAlias = FloatTensor - - -HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE = { - "albert": "classifier_dropout_prob", - "distilbert": "seq_classif_dropout", -} - -logger = logging.getLogger(__name__) - -T = TypeVar("T", bound=Tensor) - - -def get_embeddings_at_indices(embeddings: T, indices: LongTensor) -> T: - # embeddings: (bs, seq_len, hidden_size) - # indices: (bs, num_indices) - hidden_size = embeddings.size(-1) - # Expand dimensions of start_marker_positions to match hidden_states - indices_expanded = indices.unsqueeze(-1).expand(-1, -1, hidden_size) - # result: (bs, num_indices, hidden_size) - result = embeddings.gather(1, indices_expanded) - return result - - -@PyTorchIEModel.register() -class SpanTupleClassificationModel( - ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType], - RequiresModelNameOrPath, - RequiresNumClasses, -): - """A span tuple classification model that uses a pooler to get a representation of the input - spans and then applies a linear classifier to that representation. The pooler can be configured - via the `span_embedding_mode` and `tuple_embedding_mode` arguments. It expects the input to - contain the indices of the start and end tokens of the spans (for the span pooler) and the - indices of the spans in the tuples to classify (for the tuple pooler). - - Args: - model_name_or_path: The name or path of the HuggingFace model to use. - num_classes: The number of classes for the classification task. - span_embedding_mode: The mode to pool the hidden states for the spans. One of "start_token", - "end_token", "start_and_end_token". - tuple_embedding_mode: The mode to pool the span embeddings for the tuples. Possible values are - "concat" (concatenate the embeddings of the tuple entries), "multiply2_and_concat" - (multiply the embeddings of the first two entries and concatenate them with the - embeddings of the first two entries) and "index_{idx}" (use the embedding of the entry - at index {idx} as the tuple embedding). Note that "multiply2_and_concat" requires - `num_tuple_entries=2`. Default: "multiply2_and_concat". - num_tuple_entries: The number of entries in the tuples. - tuple_entry_hidden_dim: If provided, the tuple entries (i.e. the span embeddings at the tuple indices) - are mapped to this dimensionality before combining them. Default: 768. - tokenizer_vocab_size: The size of the tokenizer vocabulary. If provided, the model's - tokenizer embeddings are resized to this size. - classifier_dropout: The dropout probability for the classifier. If not provided, the - dropout probability is taken from the Huggingface model config. - learning_rate: The learning rate for the optimizer. - task_learning_rate: The learning rate for the task-specific parameters. If None, the - learning rate for all parameters is set to `learning_rate`. - warmup_proportion: The proportion of steps to warm up the learning rate. - multi_label: If True, the model is trained as a multi-label classifier. - multi_label_threshold: The threshold for the multi-label classifier, i.e. the probability - above which a class is predicted. - freeze_base_model: If True, the base model parameters are frozen. - label_pad_value: The padding value for the labels. - probability_pad_value: The padding value for the probabilities. - **kwargs: Additional keyword arguments passed to the parent class, - see :class:`ModelWithBoilerplate`. - """ - - def __init__( - self, - model_name_or_path: str, - num_classes: int, - span_embedding_mode: str = "start_and_end_token", - tuple_embedding_mode: str = "multiply2_and_concat", - num_tuple_entries: int = 2, - tuple_entry_hidden_dim: Optional[int] = 768, - tokenizer_vocab_size: Optional[int] = None, - classifier_dropout: Optional[float] = None, - learning_rate: float = 1e-5, - task_learning_rate: Optional[float] = None, - warmup_proportion: float = 0.1, - multi_label: bool = False, - multi_label_threshold: float = 0.5, - freeze_base_model: bool = False, - label_pad_value: int = -100, - probability_pad_value: float = -1.0, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.save_hyperparameters() - - self.learning_rate = learning_rate - self.task_learning_rate = task_learning_rate - self.warmup_proportion = warmup_proportion - self.freeze_base_model = freeze_base_model - self.label_pad_value = label_pad_value - self.probability_pad_value = probability_pad_value - - config = AutoConfig.from_pretrained(model_name_or_path) - if self.is_from_pretrained: - self.model = AutoModel.from_config(config=config) - else: - self.model = AutoModel.from_pretrained(model_name_or_path, config=config) - - if tokenizer_vocab_size is not None: - self.model.resize_token_embeddings(tokenizer_vocab_size) - - if self.freeze_base_model: - for param in self.model.parameters(): - param.requires_grad = False - - if classifier_dropout is None: - # Get the classifier dropout value from the Huggingface model config. - # This is a bit of a mess since some Configs use different variable names or change the semantics - # of the dropout (e.g. DistilBert has one dropout prob for QA and one for Seq classification, and a - # general one for embeddings, encoder and pooler). - classifier_dropout_attr = HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE.get( - config.model_type, "classifier_dropout" - ) - classifier_dropout = getattr(config, classifier_dropout_attr) or 0.0 - self.dropout = nn.Dropout(classifier_dropout) - - # embedder for the spans - self.span_embedding_mode = span_embedding_mode - if self.span_embedding_mode in ["start_token", "end_token"]: - self.span_embedding_dim = self.model.config.hidden_size - elif self.span_embedding_mode in ["start_and_end_token"]: - self.span_embedding_dim = self.model.config.hidden_size * 2 - else: - raise ValueError(f"Invalid value for span_embedding_mode: {self.span_embedding_mode}") - - # embedder for the tuples - self.num_tuple_entries = num_tuple_entries - self.tuple_entry_hidden_dim = tuple_entry_hidden_dim - if self.tuple_entry_hidden_dim is not None: - self.tuple_entry_embedders = nn.ModuleList( - [ - MLP(self.span_embedding_dim, self.tuple_entry_hidden_dim) - for _ in range(num_tuple_entries) - ] - ) - tuple_entry_dim = self.tuple_entry_hidden_dim - else: - self.tuple_entry_embedders = None - tuple_entry_dim = self.span_embedding_dim - self.tuple_embedding_mode = tuple_embedding_mode - if self.tuple_embedding_mode == "concat": - tuple_embedding_dim = tuple_entry_dim * self.num_tuple_entries - elif self.tuple_embedding_mode == "multiply2_and_concat": - if self.num_tuple_entries != 2: - raise ValueError( - "tuple_embedding_mode='multiply2_and_concat' requires num_tuple_entries=2" - ) - tuple_embedding_dim = tuple_entry_dim * 3 - elif self.tuple_embedding_mode.startswith("index_"): - idx = int(self.tuple_embedding_mode.split("_")[1]) - if idx >= self.num_tuple_entries: - raise ValueError( - f"Invalid index IDX={idx} for tuple_embedding_mode='index_IDX'. " - f"Number of entries in tuple: {self.num_tuple_entries}" - ) - tuple_embedding_dim = tuple_entry_dim - else: - raise ValueError( - f"Invalid value for tuple_embedding_mode: {self.tuple_embedding_mode}" - ) - - # classifier - # TODO: do sth more sophisticated here - self.classifier = nn.Linear(tuple_embedding_dim, num_classes) - - self.multi_label = multi_label - self.multi_label_threshold = multi_label_threshold - self.loss_fct = nn.BCEWithLogitsLoss() if self.multi_label else nn.CrossEntropyLoss() - - def span_embedder( - self, - hidden_state: FloatTensor, - span_start_indices: LongTensor, - span_end_indices: LongTensor, - ) -> FloatTensor: - """Create the span embeddings from the hidden states and the span start and end indices. - - Args: - hidden_state: The last hidden state from the transformer model. shape: (batch_size, seq_len, hidden_size) - span_start_indices: The indices of the start tokens of the spans. shape: (batch_size, num_spans) - span_end_indices: The indices of the end tokens of the spans. shape: (batch_size, num_spans) - - Returns: - The pooled span embeddings. shape: (batch_size, num_spans, hidden_size) - """ - - if self.span_embedding_mode == "start_token": - span_embeddings = get_embeddings_at_indices(hidden_state, span_start_indices) - elif self.span_embedding_mode == "end_token": - span_embeddings = get_embeddings_at_indices(hidden_state, span_end_indices) - elif self.span_embedding_mode == "start_and_end_token": - span_embeddings = torch.cat( - [ - get_embeddings_at_indices(hidden_state, span_start_indices), - get_embeddings_at_indices(hidden_state, span_end_indices), - ], - dim=-1, - ) - else: - raise ValueError(f"Invalid value for span_embedding_mode: {self.span_embedding_mode}") - - return span_embeddings - - def tuple_embedder( - self, - span_embeddings: FloatTensor, - tuple_indices: LongTensor, - tuple_indices_mask: BoolTensor, - ) -> FloatTensor: - """Create the tuple embeddings from the span embeddings and the tuple indices. - - Args: - span_embeddings: The span embeddings. shape: (batch_size, num_spans, span_embedding_size) - tuple_indices: The indices of the spans in the tuples. shape: (batch_size, num_tuples, num_tuple_entries) - tuple_indices_mask: A mask indicating which tuples are valid. shape: (batch_size, num_tuples) - - Returns: - The pooled tuple embeddings. shape: (num_tuples_in_batch, num_tuple_entries * span_embedding_size) - """ - - if not tuple_indices.shape[-1] == self.num_tuple_entries: - raise ValueError( - f"Number of entries in tuple_indices should be equal to num_tuple_entries={self.num_tuple_entries}" - ) - batch_size, max_num_spans = span_embeddings.shape[:2] - # we need to add the batch offsets to the tuple indices to get the correct indices in the - # flattened span_embeddings - batch_offsets = ( - torch.arange(batch_size, device=tuple_indices.device).unsqueeze(-1).unsqueeze(-1) - * max_num_spans - ) - tuple_indices_with_offsets = tuple_indices + batch_offsets - # shape: (num_tuples_in_batch, num_entries) - valid_tuple_indices_flat = tuple_indices_with_offsets[tuple_indices_mask] - - # we need to flatten the span_embeddings to get the embeddings at the tuple indices - # shape: (batch_size * num_spans, span_embedding_size) - span_embeddings_flat = span_embeddings.view(-1, span_embeddings.size(-1)) - - # map the span embeddings individually for each tuple entry - # each entry has the shape: (batch_size * num_spans, tuple_entry_dim) - if self.tuple_entry_embedders is not None: - span_embeddings_mapped = [ - mlp(span_embeddings_flat) for mlp in self.tuple_entry_embedders - ] - else: - span_embeddings_mapped = [span_embeddings_flat] * self.num_tuple_entries - - tuple_embeddings_list: List[FloatTensor] = [] - for i in range(self.num_tuple_entries): - # shape: (num_tuples_in_batch) - current_tuple_indices = valid_tuple_indices_flat[:, i] - # get the embeddings that were mapped with the mlp for the current entry - # shape: (batch_size * num_spans, tuple_entry_dim) - span_embeddings_mapped_for_entry = span_embeddings_mapped[i] - # shape: (num_tuples_in_batch, tuple_entry_dim) - current_embeddings = span_embeddings_mapped_for_entry[current_tuple_indices] - tuple_embeddings_list.append(current_embeddings) - if self.tuple_embedding_mode == "concat": - tuple_embeddings = torch.cat(tuple_embeddings_list, dim=-1).to(span_embeddings.dtype) - elif self.tuple_embedding_mode == "multiply2_and_concat": - tuple_embeddings = torch.cat( - [ - tuple_embeddings_list[0] * tuple_embeddings_list[1], - tuple_embeddings_list[0], - tuple_embeddings_list[1], - ], - dim=-1, - ) - elif self.tuple_embedding_mode.startswith("index_"): - index = int(self.tuple_embedding_mode.split("_")[1]) - tuple_embeddings = tuple_embeddings_list[index] - else: - raise ValueError( - f"Invalid value for tuple_embedding_mode: {self.tuple_embedding_mode}" - ) - return tuple_embeddings - - def forward( - self, - inputs: InputType, - targets: Optional[TargetType] = None, - return_embeddings: bool = False, - ) -> OutputType: - span_embedder_inputs = {} - tuple_embedder_inputs = {} - base_model_inputs = {} - for k, v in inputs.items(): - if k.startswith("span_"): - span_embedder_inputs[k] = v - elif k.startswith("tuple_"): - tuple_embedder_inputs[k] = v - else: - base_model_inputs[k] = v - - output = self.model(**base_model_inputs) - last_hidden_state = self.dropout(output.last_hidden_state) - - # get the span embeddings from the hidden states and the start and end marker positions - span_embeddings = self.span_embedder( - hidden_state=last_hidden_state, **span_embedder_inputs - ) - # get the tuple embeddings from the span embeddings and the tuple indices - # Note that this flattens the batch dimension to not compute embeddings for padding tuples! - tuple_embeddings_flat = self.tuple_embedder( - span_embeddings=span_embeddings, **tuple_embedder_inputs - ) - - logits_valid = self.classifier(tuple_embeddings_flat) - - result = {"logits": logits_valid} - if targets is not None: - labels = targets["labels"] - mask = inputs["tuple_indices_mask"] - valid_labels = labels[mask] - loss = self.loss_fct(logits_valid, valid_labels) - result["loss"] = loss - - if return_embeddings: - result["last_hidden_state"] = last_hidden_state - result["tuple_embeddings"] = tuple_embeddings_flat - result["span_embeddings"] = span_embeddings - - return SpanPairClassifierOutput(**result) - - def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - if not self.multi_label: - labels_flat = torch.argmax(outputs.logits, dim=-1).to(torch.long) - probabilities_flat = torch.softmax(outputs.logits, dim=-1) - else: - probabilities_flat = torch.sigmoid(outputs.logits) - labels_flat = (probabilities_flat > self.multi_label_threshold).to(torch.long) - - # re-construct the original shape - mask = inputs["tuple_indices_mask"] - # create "empty" labels and probabilities tensors - labels = ( - torch.ones(mask.shape, dtype=torch.long, device=labels_flat.device) - * self.label_pad_value - ) - prob_shape = list(mask.shape) + [probabilities_flat.shape[-1]] - probabilities = ( - torch.ones(prob_shape, dtype=torch.float, device=probabilities_flat.device) - * self.probability_pad_value - ) - # fill in the valid values - labels[mask] = labels_flat - probabilities[mask] = probabilities_flat - - return {"labels": labels, "probabilities": probabilities} - - def base_model_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - if prefix: - prefix = f"{prefix}." - return self.model.named_parameters(prefix=f"{prefix}model") - - def task_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]: - if prefix: - prefix = f"{prefix}." - base_model_parameter_names = dict(self.base_model_named_parameters(prefix=prefix)).keys() - for name, param in self.named_parameters(prefix=prefix): - if name not in base_model_parameter_names: - yield name, param - - def configure_optimizers(self): - if self.task_learning_rate is not None: - base_model_params = (param for name, param in self.base_model_named_parameters()) - task_params = (param for name, param in self.task_named_parameters()) - optimizer = AdamW( - [ - {"params": base_model_params, "lr": self.learning_rate}, - {"params": task_params, "lr": self.task_learning_rate}, - ] - ) - else: - optimizer = AdamW(self.parameters(), lr=self.learning_rate) - - if self.warmup_proportion > 0.0: - stepping_batches = self.trainer.estimated_stepping_batches - scheduler = get_linear_schedule_with_warmup( - optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches - ) - return [optimizer], [{"scheduler": scheduler, "interval": "step"}] - else: - return optimizer diff --git a/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py b/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py deleted file mode 100644 index e5b0b6e16..000000000 --- a/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py +++ /dev/null @@ -1,247 +0,0 @@ -import logging -from typing import Any, Dict, MutableMapping, Optional, Tuple, Union - -import torch -from pytorch_ie import PyTorchIEModel -from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses -from pytorch_lightning.utilities.types import OptimizerLRScheduler -from torch import FloatTensor, LongTensor, nn -from transformers import ( - AutoConfig, - AutoModel, - BatchEncoding, - get_linear_schedule_with_warmup, -) -from transformers.modeling_outputs import TokenClassifierOutput -from typing_extensions import TypeAlias - -from pie_modules.models.common import ModelWithBoilerplate -from pie_modules.models.components.seq2seq_encoder import build_seq2seq_encoder - -# model inputs / outputs / targets -InputType: TypeAlias = BatchEncoding -OutputType: TypeAlias = TokenClassifierOutput -TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]] -# step inputs (batch) / outputs (loss) -StepInputType: TypeAlias = Tuple[InputType, TargetType] -StepOutputType: TypeAlias = FloatTensor - -HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE = { - "bert": "hidden_dropout_prob", - "roberta": "hidden_dropout_prob", - "albert": "classifier_dropout_prob", - "distilbert": "seq_classif_dropout", - "deberta-v2": "hidden_dropout_prob", - "longformer": "hidden_dropout_prob", -} - -logger = logging.getLogger(__name__) - - -@PyTorchIEModel.register() -class TokenClassificationModelWithSeq2SeqEncoderAndCrf( - ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType], - RequiresNumClasses, - RequiresModelNameOrPath, -): - """A token classification model that wraps a (pretrained) model loaded with AutoModel from the - transformers library. The model can optionally be followed by a seq2seq encoder (e.g. an LSTM). - Finally, Conditional Random Fields (CRFs) can be used to decode the predictions. - - The model is trained with a cross-entropy loss function and uses the Adam optimizer. - - Note that for training, the labels for the special tokens (as well as for padding tokens) - are expected to have the value label_pad_id (-100 by default, which is the default ignore_index - value for the CrossEntropyLoss). The predictions for these tokens are also replaced with - label_pad_id to match the training labels for correct metric calculation. Therefore, the model - requires the special_tokens_mask and attention_mask (for padding) to be passed as inputs. - - Args: - model_name_or_path: The name or path of the (pretrained) transformer model to use. - num_classes: The number of classes to predict. - learning_rate: The learning rate to use for training. - task_learning_rate: The learning rate to use for the task-specific parameters, i.e. - for the sequence-to-sequence encoder, classification head, and CRF. If None, the - learning_rate is used for all parameters. - use_crf: Whether to use a CRF to decode the predictions. - label_pad_id: The label id to use for padding labels (at the padding token positions - as well as for the special tokens). - special_token_label_id: The label id to use for special tokens (e.g. [CLS], [SEP]). This - is used to replace the targets for special tokens with the label_pad_id before passing - them to the CRF because the CRF does not allow the first token to be masked out. - classifier_dropout: The dropout probability to use for the classification head. - freeze_base_model: Whether to freeze the base model (i.e. the transformer) during training. - warmup_proportion: The proportion of training steps to use for the linear warmup. - seq2seq_encoder: A dictionary with the configuration for the seq2seq encoder. If None, no - seq2seq encoder is used. See ./components/seq2seq_encoder.py for further information. - """ - - def __init__( - self, - model_name_or_path: str, - num_classes: int, - learning_rate: float = 1e-5, - task_learning_rate: Optional[float] = None, - use_crf: bool = True, - label_pad_id: int = -100, - special_token_label_id: int = 0, - classifier_dropout: Optional[float] = None, - freeze_base_model: bool = False, - warmup_proportion: float = 0.1, - seq2seq_encoder: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.save_hyperparameters() - - self.special_token_label_id = special_token_label_id - - self.learning_rate = learning_rate - self.warmup_proportion = warmup_proportion - self.task_learning_rate = task_learning_rate - self.label_pad_id = label_pad_id - self.num_classes = num_classes - - config = AutoConfig.from_pretrained(model_name_or_path) - if self.is_from_pretrained: - self.model = AutoModel.from_config(config=config) - else: - self.model = AutoModel.from_pretrained(model_name_or_path, config=config) - - if freeze_base_model: - self.model.requires_grad_(False) - - hidden_size = config.hidden_size - self.seq2seq_encoder = None - if seq2seq_encoder is not None: - self.seq2seq_encoder, hidden_size = build_seq2seq_encoder( - config=seq2seq_encoder, input_size=hidden_size - ) - - if classifier_dropout is None: - # Get the classifier dropout value from the Huggingface model config. - # This is a bit of a mess since some Configs use different variable names or change the semantics - # of the dropout (e.g. DistilBert has one dropout prob for QA and one for Seq classification, and a - # general one for embeddings, encoder and pooler). - classifier_dropout_attr = HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE.get( - config.model_type, "classifier_dropout" - ) - if hasattr(config, classifier_dropout_attr): - classifier_dropout = getattr(config, classifier_dropout_attr) - else: - raise ValueError( - f"The config {type(config),__name__} loaded from {model_name_or_path} has no attribute " - f"{classifier_dropout_attr}" - ) - self.dropout = nn.Dropout(classifier_dropout) - - self.classifier = nn.Linear(hidden_size, num_classes) - - if use_crf: - try: - from torchcrf import CRF - except ImportError: - raise ImportError( - "To use CRFs, the torchcrf package must be installed. " - "You can install it with `pip install pytorch-crf`." - ) - - self.crf = CRF(num_tags=num_classes, batch_first=True) - else: - self.crf = None - - def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - result = {} - logits = outputs.logits - attention_mask = inputs["attention_mask"] - special_tokens_mask = inputs["special_tokens_mask"] - attention_mask_bool = attention_mask.to(torch.bool) - if self.crf is not None: - decoded_tags = self.crf.decode(emissions=logits, mask=attention_mask_bool) - # pad the decoded tags to the length of the logits to have the same shape as when not using the crf - seq_len = logits.shape[1] - padded_tags = [ - tags + [self.label_pad_id] * (seq_len - len(tags)) for tags in decoded_tags - ] - tags_tensor = torch.tensor(padded_tags, device=logits.device).to(torch.long) - else: - # get the max index for each token from the logits - tags_tensor = torch.argmax(logits, dim=-1).to(torch.long) - # set the padding and special tokens to the label_pad_id - mask = attention_mask_bool & ~special_tokens_mask.to(torch.bool) - tags_tensor = tags_tensor.masked_fill(~mask, self.label_pad_id) - - result["labels"] = tags_tensor - # TODO: is it correct to use this also in the case of the crf? - result["probabilities"] = torch.softmax(logits, dim=-1) - - return result - - def forward( - self, inputs: InputType, targets: Optional[TargetType] = None - ) -> TokenClassifierOutput: - inputs_without_special_tokens_mask = { - k: v for k, v in inputs.items() if k != "special_tokens_mask" - } - outputs = self.model(**inputs_without_special_tokens_mask) - sequence_output = outputs[0] - - if self.seq2seq_encoder is not None: - sequence_output = self.seq2seq_encoder(sequence_output) - - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - loss = None - if targets is not None: - labels = targets["labels"] - if self.crf is not None: - # Overwrite the padding labels with ignore_index. Note that this is different from the - # attention_mask, because the attention_mask includes special tokens, whereas the labels - # are set to label_pad_id also for special tokens (e.g. [CLS]). We need handle all - # occurrences of label_pad_id because usually that index is out of range with respect to - # the number of logits in which case the crf would complain. However, we can not simply - # pass a mask to the crf that also masks out the special tokens, because the crf does not - # allow the first token to be masked out. - mask_pad_or_special = labels == self.label_pad_id - labels_valid = labels.masked_fill(mask_pad_or_special, self.special_token_label_id) - # the crf expects a bool mask - if "attention_mask" in inputs: - mask_bool = inputs["attention_mask"].to(torch.bool) - else: - mask_bool = None - log_likelihood = self.crf(emissions=logits, tags=labels_valid, mask=mask_bool) - loss = -log_likelihood - else: - loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_id) - loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1)) - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def configure_optimizers(self) -> OptimizerLRScheduler: - if self.task_learning_rate is not None: - all_params = dict(self.named_parameters()) - base_model_params = dict(self.model.named_parameters(prefix="model")) - task_params = {k: v for k, v in all_params.items() if k not in base_model_params} - optimizer = torch.optim.AdamW( - [ - {"params": base_model_params.values(), "lr": self.learning_rate}, - {"params": task_params.values(), "lr": self.task_learning_rate}, - ] - ) - else: - optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate) - - if self.warmup_proportion > 0.0: - stepping_batches = self.trainer.estimated_stepping_batches - scheduler = get_linear_schedule_with_warmup( - optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches - ) - return [optimizer], [{"scheduler": scheduler, "interval": "step"}] - else: - return optimizer diff --git a/src/pie_modules/taskmodules/__init__.py b/src/pie_modules/taskmodules/__init__.py deleted file mode 100644 index 46d3766ae..000000000 --- a/src/pie_modules/taskmodules/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .cross_text_binary_coref import CrossTextBinaryCorefTaskModule -from .extractive_question_answering import ExtractiveQuestionAnsweringTaskModule -from .labeled_span_extraction_by_token_classification import ( - LabeledSpanExtractionByTokenClassificationTaskModule, -) -from .pointer_network_for_end2end_re import PointerNetworkTaskModuleForEnd2EndRE -from .re_span_pair_classification import RESpanPairClassificationTaskModule -from .re_text_classification_with_indices import ( - RETextClassificationWithIndicesTaskModule, -) -from .text_to_text import TextToTextTaskModule diff --git a/src/pie_modules/taskmodules/common/__init__.py b/src/pie_modules/taskmodules/common/__init__.py index e95a3f006..e69de29bb 100644 --- a/src/pie_modules/taskmodules/common/__init__.py +++ b/src/pie_modules/taskmodules/common/__init__.py @@ -1,4 +0,0 @@ -from .interfaces import AnnotationEncoderDecoder, DecodingException -from .mixins import BatchableMixin, RelationStatisticsMixin, StatisticsMixin -from .taskmodule_with_document_converter import TaskModuleWithDocumentConverter -from .utils import get_first_occurrence_index diff --git a/src/pie_modules/taskmodules/common/interfaces.py b/src/pie_modules/taskmodules/common/interfaces.py deleted file mode 100644 index 388b60ae9..000000000 --- a/src/pie_modules/taskmodules/common/interfaces.py +++ /dev/null @@ -1,63 +0,0 @@ -import abc -from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar - -from pie_core import Annotation - -# Annotation Encoding type: encoding for a single annotation -AE = TypeVar("AE") -# Annotation type -A = TypeVar("A", bound=Annotation) -# Annotation Collection Encoding type: encoding for a collection of annotations, -# e.g. all relevant annotations for a document -ACE = TypeVar("ACE") - - -class DecodingException(Exception, Generic[AE], abc.ABC): - """Exception raised when decoding fails.""" - - identifier: str - - def __init__(self, message: str, encoding: AE): - self.message = message - self.encoding = encoding - - -class AnnotationEncoderDecoder(abc.ABC, Generic[A, AE]): - """Base class for annotation encoders and decoders.""" - - @abc.abstractmethod - def encode(self, annotation: A, metadata: Optional[Dict[str, Any]] = None) -> AE: - pass - - @abc.abstractmethod - def decode(self, encoding: AE, metadata: Optional[Dict[str, Any]] = None) -> A: - pass - - def build_decoding_constraints( - self, partial_encoding: AE - ) -> Tuple[Optional[Any], Optional[Any]]: - """Given a partial encoding, build the constraints for the next encoding step. - - Returns: - - A tuple of two elements: - - The first element is a set of positive constraints for the decoder. - - The second element is a set of negative constraints for the decoder. - """ - raise NotImplementedError( - "build_decoder_constraints is not implemented for this encoder/decoder." - ) - - def parse(self, encoding: AE) -> Tuple[List[A], Dict[str, int], AE]: - """Parse the encoding and return a list of annotations. This should be error tolerant and - return all annotations that can be parsed and the remaining encoding. - - Args: - encoding: The encoding to parse. Can be incomplete. - - Returns: - - A tuple of three elements: - - A list of encoded annotations. - - A dictionary mapping error types to their counts. - - The remaining encoding after parsing. - """ - raise NotImplementedError("parse is not implemented for this encoder/decoder.") diff --git a/src/pie_modules/taskmodules/common/mixins.py b/src/pie_modules/taskmodules/common/mixins.py deleted file mode 100644 index 0e2909893..000000000 --- a/src/pie_modules/taskmodules/common/mixins.py +++ /dev/null @@ -1,297 +0,0 @@ -import dataclasses -import logging -from abc import ABC, abstractmethod -from collections import defaultdict -from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar - -import pandas as pd -import torch -import torch.nn.functional as F -from pie_core import Annotation -from torch import Tensor - -logger = logging.getLogger(__name__) - - -def _pad_tensor(tensor: Tensor, target_shape: List[int], pad_value: float) -> Tensor: - """Pad a tensor to a target shape. - - Args: - tensor: The tensor to pad. - target_shape: The target shape. - pad_value: The value to use for padding. - - Returns: The padded tensor. - """ - - shape = tensor.shape - pad: List[int] = [] - for i, s in enumerate(shape): - pad = [0, target_shape[i] - s] + pad - result = F.pad(tensor, pad=pad, value=pad_value) - - return result - - -def maybe_pad_values( - values: Any, pad_value: Optional[Any] = None, strategy: str = "longest" -) -> Any: - """If an iterable of values is passed and a pad value is given, pad the values to the same - length and create a tensor from them. Otherwise, return the values unchanged. - - Note that the padding is done on all dimensions. - - Args: - values: The values to pad. - pad_value: The value to use for padding. - strategy: The padding strategy. Currently only "longest" is supported. - - Returns: The padded values. - """ - - if pad_value is None: - return values - if not isinstance(values, Iterable): - raise TypeError(f"values must be iterable to pad them, but got {type(values)}") - if strategy != "longest": - raise ValueError(f"unknown padding strategy: {strategy}") - tensor_list = [torch.tensor(value_list) for value_list in values] - shape_lists = list(zip(*[t.shape for t in tensor_list])) - max_shape = [max(dims) for dims in shape_lists] - padded = [ - _pad_tensor(tensor=t, target_shape=max_shape, pad_value=pad_value) - for i, t in enumerate(tensor_list) - ] - return torch.stack(padded) - - -def maybe_to_tensor( - values: Iterable[Any], dtype: Optional[torch.dtype] = None, pad_value: Optional[Any] = None -) -> Any: - """If an iterable of values is passed and a dtype is given, convert the values to a tensor of - the given type. - - Args: - values: The values to convert. - dtype: A dtype to convert the values to. - pad_value: A pad value to use if the values are padded. - - Returns: A tensor or the values unchanged. - """ - - if all(v is None for v in values): - return None - if dtype is None: - return values - maybe_padded = maybe_pad_values(values=values, pad_value=pad_value) - if not isinstance(maybe_padded, torch.Tensor): - maybe_padded = torch.Tensor(maybe_padded) - tensor = maybe_padded.to(dtype=dtype) - return tensor - - -class BatchableMixin: - """A mixin class that provides a batch method to batch a list of instances of the class. All - attributes, but also property methods, are batched. The batch method returns a dictionary with - all attribute / property names as keys. The values are tensors created from the stacked values - of the attributes / properties. The tensors are padded to the length of the longest instance in - the batch and converted to the given dtype. - - Example: - >>> import dataclasses - >>> from typing import List, Dict - >>> import torch - >>> - >>> @dataclasses.dataclass - >>> class Foo(BatchableMixin): - >>> a: List[int] - >>> - >>> @property - >>> def len_a(self): - >>> return len(self.a) - >>> - >>> x = Foo(a=[1, 2, 3]) - >>> y = Foo(a=[4, 5]) - >>> - >>> Foo.batch(values=[x, y], dtypes={"a": torch.int64, "len_a": torch.int64}, pad_values={"a": 0}) - {'a': tensor([[1, 2, 3],[4, 5, 0]]), 'len_a': tensor([3, 2])} - """ - - @classmethod - def get_property_names(cls) -> List[str]: - return [name for name in cls.__dict__ if isinstance(getattr(cls, name), property)] - - @classmethod - def get_dataclass_field_names(cls) -> List[str]: - if dataclasses.is_dataclass(cls): - return [f.name for f in dataclasses.fields(cls)] - else: - return [] - - @classmethod - def get_attribute_names(cls) -> List[str]: - return cls.get_property_names() + cls.get_dataclass_field_names() - - @classmethod - def batch( - cls, - values: List[Any], - dtypes: Dict[str, torch.dtype], - pad_values: Dict[str, Any], - ) -> Dict[str, Any]: - attribute_names = cls.get_attribute_names() - return { - k: maybe_to_tensor( - values=[getattr(x, k) for x in values], - dtype=dtypes.get(k, None), - pad_value=pad_values.get(k, None), - ) - for k in attribute_names - # Only batch attributes that are not None for any of the values. - if not all(getattr(x, k) is None for x in values) - } - - -T = TypeVar("T") - - -def increase_counter( - key: Tuple[Any, ...], - statistics: Dict[Tuple[Any, ...], int], - value: int = 1, -): - key_s = tuple(str(k) for k in key) - statistics[key_s] += value - - -class StatisticsMixin(ABC, Generic[T]): - """A mixin class that provides methods to collect and format statistics. - - Args: - collect_statistics: Control whether statistics should be collected. - If `False`, the mixin will not show any statistics when calling - `show_statistics`. Further effects depend on the implementation - of the mixin. - **kwargs: Additional keyword arguments to pass to the parent class. - """ - - def __init__(self, collect_statistics: bool = False, **kwargs): - super().__init__(**kwargs) - self.collect_statistics = collect_statistics - self.reset_statistics() - - @abstractmethod - def reset_statistics(self): - """Reset the statistics collected by this mixin (state).""" - pass - - @abstractmethod - def get_statistics(self) -> T: - """Get the statistics collected by this mixin. - - This should *not* modify the state of the mixin, repeated calls should return the same - result! - """ - pass - - def format_statistics(self, statistics: T) -> str: - """Format the statistics collected by this mixin as string for display (usually on - console).""" - raise NotImplementedError( - f"format_statistics is not implemented for {self.__class__.__name__}. " - "Please implement this method to show formatted statistics." - ) - - def show_statistics(self): - if self.collect_statistics: - logger.info(f"statistics:\n{self.format_statistics(self.get_statistics())}") - - -class RelationStatisticsMixin(StatisticsMixin[Dict[Tuple[str, str], int]]): - """A mixin class that provides methods to collect and format statistics about relations. - - This mixin collects statistics about relations, such as the number of available, used, and - skipped relations. - """ - - def get_none_label_for_statistics(self) -> str: - if not hasattr(self, "_statistics_none_label"): - if hasattr(self, "none_label"): - # If the mixin has a `none_label` attribute, use it as the label for "no relation". - self._statistics_none_label = self.none_label - else: - self._statistics_none_label = "no_relation" - logger.warning( - f"{type(self).__name__} does not have a `none_label` attribute. " - "Using default value 'no_relation'. " - "`none_label` is used as the label for relations with score 0 in statistics and " - "all relations with label different from `none_label` will be summarized to 'all_relations'. " - "Set the `none_label` attribute before using statistics or " - "overwrite `get_none_label_for_statistics()` function to get rid of this message." - ) - - return self._statistics_none_label - - def reset_statistics(self): - self._collected_relations: Dict[str, List[Annotation]] = defaultdict(list) - - def collect_relation(self, kind: str, relation: Annotation): - if self.collect_statistics: - self._collected_relations[kind].append(relation) - - def collect_all_relations(self, kind: str, relations: Iterable[Annotation]): - if self.collect_statistics: - self._collected_relations[kind].extend(relations) - - def get_statistics(self) -> Dict[Tuple[str, str], int]: - if self.collect_statistics: - # create statistics from the collected relations - statistics: Dict[Tuple[str, str], int] = defaultdict(int) - all_relations = set(self._collected_relations["available"]) - used_relations = set(self._collected_relations["used"]) - skipped_other = all_relations - used_relations - for key, rels in self._collected_relations.items(): - rels_set = set(rels) - if key.startswith("skipped_"): - skipped_other -= rels_set - elif key.startswith("used_"): - pass - elif key in ["available", "used"]: - pass - else: - raise ValueError(f"unknown key: {key}") - for rel in rels_set: - # Set `none_label` as label when the score is zero. We encode negative relations - # in such a way in the case of multi-label or binary (similarity for coref). - label = rel.label if rel.score > 0 else self.get_none_label_for_statistics() - increase_counter(key=(key, label), statistics=statistics) - for rel in skipped_other: - increase_counter(key=("skipped_other", rel.label), statistics=statistics) - - return dict(statistics) - else: - return {} - - def format_statistics(self, statistics: Dict[Tuple[str, str], int]) -> str: - if len(statistics) > 0: - to_show_series = pd.Series(statistics) - # unstack index to have relation labels as column names - to_show = to_show_series.unstack() - else: - # If there were no statistics, create an empty dummy dataframe. - to_show = pd.DataFrame(pd.Series(dict())) - # fill missing values with 0 and convert back to int (unstacking may introduce NaNs which are float type) - to_show = to_show.fillna(0).astype(int) - if to_show.columns.size > 1: - to_show["all_relations"] = to_show.loc[ - :, to_show.columns != self.get_none_label_for_statistics() - ].sum(axis=1) - - # transpose - # to have the labels (which may be a lot) as index for improved readability and - # to allow to keep counts as int columns (dtypes are per-column, not per-row) - to_show = to_show.T - if "used" in to_show.columns and "available" in to_show.columns: - to_show["used %"] = (100 * to_show["used"] / to_show["available"]).round() - - return to_show.to_markdown() diff --git a/src/pie_modules/taskmodules/common/taskmodule_with_document_converter.py b/src/pie_modules/taskmodules/common/taskmodule_with_document_converter.py deleted file mode 100644 index 69f8e7668..000000000 --- a/src/pie_modules/taskmodules/common/taskmodule_with_document_converter.py +++ /dev/null @@ -1,117 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Generic, Iterable, Iterator, Optional, Sequence, Type, TypeVar, Union - -from pie_core import ( - Document, - IterableTaskEncodingDataset, - TaskEncoding, - TaskEncodingDataset, - TaskEncodingSequence, - TaskModule, -) -from typing_extensions import TypeAlias - -DocumentType = TypeVar("DocumentType", bound=Document) -ConvertedDocumentType = TypeVar("ConvertedDocumentType", bound=Document) -InputEncodingType = TypeVar("InputEncodingType") -TargetEncodingType = TypeVar("TargetEncodingType") -# TaskEncoding: defined below -TaskBatchEncodingType = TypeVar("TaskBatchEncodingType") -# ModelBatchEncoding: defined in models -ModelBatchOutputType = TypeVar("ModelBatchOutputType") -TaskOutputType = TypeVar("TaskOutputType") - -TaskEncodingType: TypeAlias = TaskEncoding[ - DocumentType, - InputEncodingType, - TargetEncodingType, -] - - -class TaskModuleWithDocumentConverter( - TaskModule, - ABC, - Generic[ - ConvertedDocumentType, - DocumentType, - InputEncodingType, - TargetEncodingType, - TaskBatchEncodingType, - ModelBatchOutputType, - TaskOutputType, - ], -): - @property - def document_type(self) -> Optional[Type[Document]]: - if super().document_type is not None: - raise NotImplementedError(f"please overwrite document_type for {type(self).__name__}") - else: - return None - - @abstractmethod - def _convert_document(self, document: DocumentType) -> ConvertedDocumentType: - """Convert a document of the taskmodule document type to the expected document type of the - wrapped taskmodule. - - Args: - document: the input document - - Returns: the converted document - """ - pass - - def _prepare(self, documents: Sequence[DocumentType]) -> None: - # use an iterator for lazy processing - documents_converted = (self._convert_document(doc) for doc in documents) - super()._prepare(documents=documents_converted) - - def convert_document(self, document: DocumentType) -> ConvertedDocumentType: - converted_document = self._convert_document(document) - if "original_document" in converted_document.metadata: - raise ValueError( - f"metadata of converted_document has already and entry 'original_document', " - f"this is not allowed. Please adjust '{type(self).__name__}._convert_document()' " - f"to produce documents without that key in metadata." - ) - converted_document.metadata["original_document"] = document - return converted_document - - def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs) -> Union[ - Sequence[TaskEncodingType], - TaskEncodingSequence[TaskEncodingType, DocumentType], - Iterator[TaskEncodingType], - TaskEncodingDataset[TaskEncodingType], - IterableTaskEncodingDataset[TaskEncodingType], - ]: - converted_documents: Union[DocumentType, Iterable[DocumentType]] - if isinstance(documents, Document): - converted_documents = self.convert_document(documents) - else: - converted_documents = [self.convert_document(doc) for doc in documents] - return super().encode(documents=converted_documents, **kwargs) - - def decode(self, **kwargs) -> Sequence[DocumentType]: - decoded_documents = super().decode(**kwargs) - result = [] - for doc in decoded_documents: - original_document = doc.metadata["original_document"] - self._integrate_predictions_from_converted_document( - converted_document=doc, document=original_document - ) - result.append(original_document) - return result - - @abstractmethod - def _integrate_predictions_from_converted_document( - self, - document: DocumentType, - converted_document: ConvertedDocumentType, - ) -> None: - """Convert the predictions at the respective layers of the converted_document and add them - to the original document predictions. - - Args: - document: document to attach the converted predictions to - converted_document: the document returned by the wrapped taskmodule, including predictions - """ - pass diff --git a/src/pie_modules/taskmodules/common/utils.py b/src/pie_modules/taskmodules/common/utils.py deleted file mode 100644 index cc4daf626..000000000 --- a/src/pie_modules/taskmodules/common/utils.py +++ /dev/null @@ -1,32 +0,0 @@ -import logging -from typing import Union - -import torch - -logger = logging.getLogger(__name__) - - -def get_first_occurrence_index( - tensor: Union[torch.FloatTensor, torch.LongTensor], value: Union[float, int] -) -> torch.LongTensor: - """Returns the index of the first occurrence of `value` in each row of `tensor`. If `value` is - not found, seq_len is returned. - - Args: - tensor: the tensor of shape (bsz, seq_len) to search in - value: the value to search for - - Returns: a tensor of shape (bsz,) containing the index of the first occurrence of `value` in each row of `tensor`. - """ - - mask_value = tensor.eq(value) - # count matching positions from the end - value_counts_to_end = mask_value.flip(dims=[1]).cumsum(dim=1).flip(dims=[1]) - # at the first position stands the number of total matches - total_matches = value_counts_to_end[:, 0] - # the sum of all positions where the number of matches is equal to the total number of matches - # is the index *after* the first occurrence - result = value_counts_to_end.eq(total_matches.unsqueeze(-1)).sum(dim=1) - 1 - # set result to seq_len if no match was found - result[total_matches == 0] = tensor.size(1) - return result diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index cff2fa139..e69de29bb 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -1,292 +0,0 @@ -import copy -import logging -from typing import ( - Any, - Dict, - Iterable, - Iterator, - List, - Optional, - Sequence, - Tuple, - TypedDict, - TypeVar, - Union, -) - -import torch -from pie_core import Annotation, TaskEncoding, TaskModule -from pie_core.utils.dictionary import list_of_dicts2dict_of_lists -from pytorch_ie.utils.window import get_window_around_slice -from torchmetrics import MetricCollection -from torchmetrics.classification import ( - BinaryAUROC, - BinaryAveragePrecision, - BinaryF1Score, -) -from transformers import AutoTokenizer, BatchEncoding -from typing_extensions import TypeAlias - -from pie_modules.annotations import Span -from pie_modules.documents import ( - TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, -) -from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin -from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction -from pie_modules.utils.tokenization import ( - SpanNotAlignedWithTokenException, - get_aligned_token_span, -) - -logger = logging.getLogger(__name__) - -InputEncodingType: TypeAlias = Dict[str, Any] -TargetEncodingType: TypeAlias = Sequence[float] -DocumentType: TypeAlias = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations - -TaskEncodingType: TypeAlias = TaskEncoding[ - DocumentType, - InputEncodingType, - TargetEncodingType, -] - - -class TaskOutputType(TypedDict, total=False): - score: float - is_similar: bool - - -ModelInputType: TypeAlias = Dict[str, torch.Tensor] -ModelTargetType: TypeAlias = Dict[str, torch.Tensor] -ModelOutputType: TypeAlias = Dict[str, torch.Tensor] - -TaskModuleType: TypeAlias = TaskModule[ - # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput - DocumentType, - InputEncodingType, - TargetEncodingType, - Tuple[ModelInputType, Optional[ModelTargetType]], - ModelTargetType, - TaskOutputType, -] - - -class SpanDoesNotFitIntoAvailableWindow(Exception): - def __init__(self, span): - self.span = span - - -def _get_labels(model_output: ModelTargetType, label_threshold: float) -> torch.Tensor: - return (model_output["scores"] > label_threshold).to(torch.int) - - -def _get_scores(model_output: ModelTargetType) -> torch.Tensor: - return model_output["scores"] - - -S = TypeVar("S", bound=Span) - - -def shift_span(span: S, offset: int) -> S: - return span.copy(start=span.start + offset, end=span.end + offset) - - -@TaskModule.register() -class CrossTextBinaryCorefTaskModule(RelationStatisticsMixin, TaskModuleType): - """This taskmodule processes documents of type - TextPairDocumentWithLabeledSpansAndBinaryCorefRelations in preparation for a - SequencePairSimilarityModelWithPooler.""" - - DOCUMENT_TYPE = DocumentType - - def __init__( - self, - tokenizer_name_or_path: str, - similarity_threshold: float = 0.9, - max_window: Optional[int] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.save_hyperparameters() - - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - self.similarity_threshold = similarity_threshold - self.max_window = max_window if max_window is not None else self.tokenizer.model_max_length - self.available_window = self.max_window - self.tokenizer.num_special_tokens_to_add() - self.num_special_tokens_before = len(self._get_special_tokens_before_input()) - - def _get_special_tokens_before_input(self) -> List[int]: - dummy_ids = self.tokenizer.build_inputs_with_special_tokens(token_ids_0=[-1]) - return dummy_ids[: dummy_ids.index(-1)] - - def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs): - self.reset_statistics() - result = super().encode(documents=documents, **kwargs) - self.show_statistics() - return result - - def truncate_encoding_around_span( - self, encoding: BatchEncoding, char_span: Span - ) -> Tuple[Dict[str, List[int]], Span]: - input_ids = copy.deepcopy(encoding["input_ids"]) - - token_span = get_aligned_token_span(encoding=encoding, char_span=char_span) - - # truncate input_ids and shift token_start and token_end - if len(input_ids) > self.available_window: - window_slice = get_window_around_slice( - slice=(token_span.start, token_span.end), - max_window_size=self.available_window, - available_input_length=len(input_ids), - ) - if window_slice is None: - raise SpanDoesNotFitIntoAvailableWindow(span=token_span) - window_start, window_end = window_slice - input_ids = input_ids[window_start:window_end] - token_span = shift_span(token_span, offset=-window_start) - - truncated_encoding = self.tokenizer.prepare_for_model(ids=input_ids) - # shift indices because we added special tokens to the input_ids - token_span = shift_span(token_span, offset=self.num_special_tokens_before) - - return truncated_encoding, token_span - - def encode_input( - self, - document: DocumentType, - is_training: bool = False, - ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: - self.collect_all_relations(kind="available", relations=document.binary_coref_relations) - tokenizer_kwargs = dict( - padding=False, - truncation=False, - add_special_tokens=False, - ) - encoding = self.tokenizer(text=document.text, **tokenizer_kwargs) - encoding_pair = self.tokenizer(text=document.text_pair, **tokenizer_kwargs) - - task_encodings = [] - for coref_rel in document.binary_coref_relations: - # TODO: This can miss instances if both texts are the same. We could check that - # coref_rel.head is in document.labeled_spans (same for the tail), but would this - # slow down the encoding? - if not ( - coref_rel.head.target == document.text - or coref_rel.tail.target == document.text_pair - ): - raise ValueError( - f"It is expected that coref relations go from (head) spans over 'text' " - f"to (tail) spans over 'text_pair', but this is not the case for this " - f"relation (i.e. it points into the other direction): {coref_rel.resolve()}" - ) - try: - current_encoding, token_span = self.truncate_encoding_around_span( - encoding=encoding, char_span=coref_rel.head - ) - current_encoding_pair, token_span_pair = self.truncate_encoding_around_span( - encoding=encoding_pair, char_span=coref_rel.tail - ) - except SpanNotAlignedWithTokenException as e: - logger.warning( - f"Could not get token offsets for argument ({e.span}) of coref relation: " - f"{coref_rel.resolve()}. Skip it." - ) - self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel) - continue - except SpanDoesNotFitIntoAvailableWindow as e: - logger.warning( - f"Argument span [{e.span}] does not fit into available token window " - f"({self.available_window}). Skip it." - ) - self.collect_relation( - kind="skipped_span_does_not_fit_into_window", relation=coref_rel - ) - continue - - task_encodings.append( - TaskEncoding( - document=document, - inputs={ - "encoding": current_encoding, - "encoding_pair": current_encoding_pair, - "pooler_start_indices": token_span.start, - "pooler_end_indices": token_span.end, - "pooler_pair_start_indices": token_span_pair.start, - "pooler_pair_end_indices": token_span_pair.end, - }, - metadata={"candidate_annotation": coref_rel}, - ) - ) - self.collect_relation("used", coref_rel) - return task_encodings - - def encode_target( - self, - task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType], - ) -> Optional[TargetEncodingType]: - return task_encoding.metadata["candidate_annotation"].score - - def collate( - self, - task_encodings: Sequence[ - TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType] - ], - ) -> Tuple[ModelInputType, Optional[ModelTargetType]]: - inputs_dict = list_of_dicts2dict_of_lists( - [task_encoding.inputs for task_encoding in task_encodings] - ) - - inputs = { - k: ( - self.tokenizer.pad(v, return_tensors="pt").data - if k in ["encoding", "encoding_pair"] - else torch.tensor(v) - ) - for k, v in inputs_dict.items() - } - for k, v in inputs.items(): - if k.startswith("pooler_") and k.endswith("_indices"): - inputs[k] = v.unsqueeze(-1) - - if not task_encodings[0].has_targets: - return inputs, None - targets = { - "scores": torch.tensor([task_encoding.targets for task_encoding in task_encodings]) - } - return inputs, targets - - def configure_model_metric(self, stage: str) -> MetricCollection: - return MetricCollection( - metrics={ - "continuous": WrappedMetricWithPrepareFunction( - metric=MetricCollection( - { - "auroc": BinaryAUROC(), - "avg-P": BinaryAveragePrecision(validate_args=False), - # "roc": BinaryROC(validate_args=False), - # "PRCurve": BinaryPrecisionRecallCurve(validate_args=False), - "f1": BinaryF1Score(threshold=self.similarity_threshold), - } - ), - prepare_function=_get_scores, - ), - } - ) - - def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]: - is_similar = (model_output["scores"] > self.similarity_threshold).detach().cpu().tolist() - scores = model_output["scores"].detach().cpu().tolist() - result: List[TaskOutputType] = [ - {"is_similar": is_sim, "score": prob} for is_sim, prob in zip(is_similar, scores) - ] - return result - - def create_annotations_from_output( - self, - task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType], - task_output: TaskOutputType, - ) -> Iterator[Tuple[str, Annotation]]: - if task_output["is_similar"]: - score = task_output["score"] - new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score) - yield "binary_coref_relations", new_coref_rel diff --git a/src/pie_modules/taskmodules/extractive_question_answering.py b/src/pie_modules/taskmodules/extractive_question_answering.py deleted file mode 100644 index d774cf18b..000000000 --- a/src/pie_modules/taskmodules/extractive_question_answering.py +++ /dev/null @@ -1,239 +0,0 @@ -import dataclasses -import logging -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union - -import numpy as np -import torch -from pie_core import Annotation, AnnotationLayer, TaskEncoding, TaskModule -from tokenizers import Encoding -from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizer -from transformers.modeling_outputs import QuestionAnsweringModelOutput -from typing_extensions import TypeAlias - -from pie_modules.annotations import ExtractiveAnswer, Question -from pie_modules.document.processing import tokenize_document -from pie_modules.documents import ( - TextBasedDocument, - TextDocumentWithQuestionsAndExtractiveAnswers, - TokenDocumentWithQuestionsAndExtractiveAnswers, -) - -logger = logging.getLogger(__name__) - - -DocumentType: TypeAlias = TextBasedDocument -InputEncoding: TypeAlias = Union[Dict[str, Any], BatchEncoding] - - -@dataclasses.dataclass -class TargetEncoding: - start_position: int - end_position: int - - -TaskEncodingType: TypeAlias = TaskEncoding[ - TextDocumentWithQuestionsAndExtractiveAnswers, - InputEncoding, - TargetEncoding, -] - -TaskBatchEncoding: TypeAlias = Tuple[BatchEncoding, Optional[Dict[str, Any]]] -ModelBatchOutput: TypeAlias = QuestionAnsweringModelOutput - - -@dataclasses.dataclass -class TaskOutput: - start: int - end: int - start_probability: float - end_probability: float - - -@TaskModule.register() -class ExtractiveQuestionAnsweringTaskModule(TaskModule): - """PIE task module for extractive question answering. - - This task module expects that the document is text based and contains an annotation layer for answers - and one for questions. - - The task module will create a task encoding for each question-answer pair. - The input encoding will be the tokenized document with the question as the second sequence. - The target encoding will be the start and end position of the answer in the context. - The task module will create a dummy target encoding where both start and end index are set to 0 (usually - the CLS token position), if there is no answer for the question. - - Args: - tokenizer_name_or_path: The name (Huggingface Hub identifier) or local path to a config of the tokenizer to use. - max_length: The maximum length of the input sequence in means of tokens. - answer_annotation: The name of the annotation layer for answers. Defaults to "answers". - question_annotation: The name of the annotation layer for questions. Defaults to "questions". - tokenize_kwargs: Additional keyword arguments for the tokenizer. Defaults to None. - """ - - DOCUMENT_TYPE = TextDocumentWithQuestionsAndExtractiveAnswers - - def __init__( - self, - tokenizer_name_or_path: str, - max_length: int, - answer_annotation: str = "answers", - question_annotation: str = "questions", - tokenize_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.save_hyperparameters() - - self.answer_annotation = answer_annotation - self.question_annotation = question_annotation - self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - self.max_length = max_length - self.tokenize_kwargs = tokenize_kwargs or {} - - def get_answer_layer(self, document: DocumentType) -> AnnotationLayer[ExtractiveAnswer]: - # we expect that each document have an annotation layer for answers - # where each entry is of type ExtractiveAnswer - return document[self.answer_annotation] - - def get_question_layer(self, document: DocumentType) -> AnnotationLayer[Question]: - answers = self.get_answer_layer(document) - # we expect that the answers annotation layer targets the questions annotation layer - # where each entry is of type Question - return answers.target_layers[self.question_annotation] - - def get_context(self, document: DocumentType) -> str: - answers = self.get_answer_layer(document) - # we expect that the answers annotation layer targets the text field - # which is a simple string - return answers.targets["text"] - - def encode_input( - self, - document: DocumentType, - is_training: bool = False, - ) -> Optional[ - Union[ - TaskEncoding[DocumentType, InputEncoding, TargetEncoding], - Sequence[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]], - ] - ]: - questions = self.get_question_layer(document) - task_encodings: List[TaskEncodingType] = [] - for question in questions: - tokenized_docs = tokenize_document( - document, - tokenizer=self.tokenizer, - text=question.text.strip(), - truncation="only_second", - max_length=self.max_length, - return_overflowing_tokens=True, - result_document_type=TokenDocumentWithQuestionsAndExtractiveAnswers, - strict_span_conversion=False, - verbose=False, - **self.tokenize_kwargs, - ) - for doc in tokenized_docs: - inputs = self.tokenizer.convert_tokens_to_ids(list(doc.tokens)) - task_encodings.append( - TaskEncodingType( - document=document, - inputs=inputs, - metadata=dict(question=question, tokenized_document=doc), - ) - ) - return task_encodings - - def encode_target( - self, - task_encoding: TaskEncodingType, - ) -> Optional[TargetEncoding]: - all_answers = self.get_answer_layer(task_encoding.metadata["tokenized_document"]) - # the document can contain multiple questions, so we filter the answers by the target question - answers = [ - answer - for answer in all_answers - if answer.question == task_encoding.metadata["question"] - ] - # if there is no answer for the target question, we return a dummy target encoding - if len(answers) == 0: - return TargetEncoding(0, 0) - if len(answers) > 1: - logger.warning( - f"The answers annotation layer is expected to have not more than one answer per question, " - f"but it has {len(answers)} answers. We take just the first one." - ) - answer = answers[0] - return TargetEncoding(answer.start, answer.end - 1) - - def collate( - self, task_encodings: Sequence[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]] - ) -> TaskBatchEncoding: - def task_encoding2input_features(task_encoding: TaskEncodingType) -> Dict[str, Any]: - encoding = task_encoding.metadata["tokenized_document"].metadata["tokenizer_encoding"] - return {"input_ids": encoding.ids, "token_type_ids": encoding.type_ids} - - input_features = [ - task_encoding2input_features(task_encoding) for task_encoding in task_encodings - ] - - # will contain: input_ids, token_type_ids, attention_mask - inputs: BatchEncoding = self.tokenizer.pad( - input_features, padding="longest", max_length=self.max_length, return_tensors="pt" - ) - - if not task_encodings[0].has_targets: - return inputs, None - - start_positions = torch.tensor( - [task_encoding.targets.start_position for task_encoding in task_encodings], - dtype=torch.int64, - ) - end_positions = torch.tensor( - [task_encoding.targets.end_position for task_encoding in task_encodings], - dtype=torch.int64, - ) - targets = {"start_positions": start_positions, "end_positions": end_positions} - - return inputs, targets - - def unbatch_output(self, model_output: ModelBatchOutput) -> Sequence[TaskOutput]: - batch_size = len(model_output.start_logits) - start_probs = torch.softmax(model_output.start_logits, dim=-1).detach().cpu().numpy() - end_probs = torch.softmax(model_output.end_logits, dim=-1).detach().cpu().numpy() - best_start = np.argmax(start_probs, axis=1) - best_end = np.argmax(end_probs, axis=1) - return [ - TaskOutput( - start=best_start[i], - end=best_end[i], - start_probability=start_probs[i, best_start[i]], - end_probability=end_probs[i, best_end[i]], - ) - for i in range(batch_size) - ] - - def create_annotations_from_output( - self, - task_encoding: TaskEncoding[DocumentType, InputEncoding, TargetEncoding], - task_output: TaskOutput, - ) -> Iterator[Tuple[str, Annotation]]: - tokenizer_encoding: Encoding = task_encoding.metadata["tokenized_document"].metadata[ - "tokenizer_encoding" - ] - start_chars = tokenizer_encoding.token_to_chars(task_output.start) - end_chars = tokenizer_encoding.token_to_chars(task_output.end) - if start_chars is not None and end_chars is not None: - start_sequence_index = tokenizer_encoding.token_to_sequence(task_output.start) - end_sequence_index = tokenizer_encoding.token_to_sequence(task_output.end) - # the indices need to point into the context which is the second sequence - if start_sequence_index == 1 and end_sequence_index == 1: - start_char = start_chars[0] - end_char = end_chars[-1] - context = self.get_context(task_encoding.document) - if 0 <= start_char < end_char <= len(context): - yield self.answer_annotation, ExtractiveAnswer( - start=start_char, - end=end_char, - question=task_encoding.metadata["question"], - score=float(task_output.start_probability * task_output.end_probability), - ) diff --git a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py b/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py deleted file mode 100644 index ca566a024..000000000 --- a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py +++ /dev/null @@ -1,468 +0,0 @@ -""" -workflow: - Document - -> (InputEncoding, TargetEncoding) -> TaskEncoding - -> ModelStepInputType -> ModelBatchOutput - -> TaskOutput - -> Document -""" - -import logging -from functools import partial -from typing import ( - Any, - Dict, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - TypedDict, - Union, -) - -import torch -from pie_core import AnnotationLayer, TaskEncoding, TaskModule -from pie_core.utils.dictionary import list_of_dicts2dict_of_lists -from tokenizers import Encoding -from torchmetrics import F1Score, Metric, MetricCollection, Precision, Recall -from transformers import AutoTokenizer -from typing_extensions import TypeAlias - -from pie_modules.annotations import LabeledSpan -from pie_modules.document.processing import ( - token_based_document_to_text_based, - tokenize_document, -) -from pie_modules.documents import ( - TextBasedDocument, - TextDocumentWithLabeledSpans, - TextDocumentWithLabeledSpansAndLabeledPartitions, - TokenDocumentWithLabeledSpans, - TokenDocumentWithLabeledSpansAndLabeledPartitions, -) -from pie_modules.models.simple_token_classification import InputType as ModelInputType -from pie_modules.models.simple_token_classification import TargetType as ModelTargetType -from pie_modules.taskmodules.metrics import ( - PrecisionRecallAndF1ForLabeledAnnotations, - WrappedMetricWithPrepareFunction, -) -from pie_modules.utils.sequence_tagging import tag_sequence_to_token_spans - -DocumentType: TypeAlias = TextBasedDocument - -InputEncodingType: TypeAlias = Encoding -TargetEncodingType: TypeAlias = Sequence[int] -TaskEncodingType: TypeAlias = TaskEncoding[ - DocumentType, - InputEncodingType, - TargetEncodingType, -] -ModelStepInputType: TypeAlias = Tuple[ - ModelInputType, - Optional[ModelTargetType], -] -ModelOutputType: TypeAlias = ModelTargetType - - -class TaskOutputType(TypedDict, total=False): - labels: torch.LongTensor - probabilities: torch.FloatTensor - - -TaskModuleType: TypeAlias = TaskModule[ - DocumentType, - InputEncodingType, - TargetEncodingType, - ModelStepInputType, - ModelOutputType, - TaskOutputType, -] - -logger = logging.getLogger(__name__) - - -def _get_label_ids_from_model_output( - model_output: ModelTargetType, -) -> torch.LongTensor: - return model_output["labels"] - - -def unbatch_and_decode_annotations( - model_output: ModelOutputType, - taskmodule: "LabeledSpanExtractionByTokenClassificationTaskModule", -) -> List[Sequence[LabeledSpan]]: - task_outputs = taskmodule.unbatch_output(model_output) - annotations = [ - taskmodule.decode_annotations(task_output)["labeled_spans"] for task_output in task_outputs - ] - return annotations - - -@TaskModule.register() -class LabeledSpanExtractionByTokenClassificationTaskModule(TaskModuleType): - """Taskmodule for span prediction (e.g. NER) as token classification. - - This taskmodule expects the input documents to be of TextBasedDocument with an annotation layer of - labeled spans (e.g. TextDocumentWithLabeledSpans). The text is tokenized using the provided tokenizer and - the labels are converted to BIO tags. - - To handle long documents, the text can be windowed using the respective parameters for the tokenizer, - i.e. max_length (and stride). Note, that this requires to set return_overflowing_tokens=True, otherwise just - the first window of input tokens is considered. The windowing is done in a way that the spans are not split - across windows. If a span is split across windows, it is ignored during training and evaluation. Thus, if you - have long spans in your data, it is recommended to set a stride that is as large as the average span length - to avoid missing many spans. - - If a partition annotation is provided, the taskmodule expects the input documents to be of - TextBasedDocument with two annotation layers of labeled spans, one for the spans and one for the partitions - (e.g. TextDocumentWithLabeledSpansAndLabeledPartitions). Then, the text is tokenized and fed to the model - individually per partition (e.g. per sentence). This is useful for long documents that can not be processed - by the model as a whole, but where a natural partitioning exists (e.g. sentences or paragraphs) and, thus, - windowing is not necessary (or a combination of both can be used). - - If labels are not provided, they are collected from the data during the prepare() step. If provided, they act as - whitelist, i.e. spans with labels that are not in the labels are ignored during training and evaluation. - - Args: - tokenizer_name_or_path: Name or path of the HuggingFace tokenizer to use. - span_annotation: Name of the annotation layer that contains the labeled spans. Default: "labeled_spans". - partition_annotation: Name of the annotation layer that contains the labeled partitions. If provided, the - text is tokenized individually per partition. Default: None. - label_pad_id: ID of the padding tag label. The model should ignore this for training. Default: -100. - labels: List of labels to use. If not provided, the labels are collected from the labeled span annotations - in the data during the prepare() step. Default: None. - include_ill_formed_predictions: Whether to include ill-formed predictions in the output. If False, the - predictions are corrected to be well-formed. Default: True. - tokenize_kwargs: Keyword arguments to pass to the tokenizer during tokenization. Default: None. - pad_kwargs: Keyword arguments to pass to the tokenizer during padding. Note, that this is used to pad the - token ids *and* the tag ids, if available (i.e. during training or evaluation). Default: None. - combine_token_scores_method: Method to combine the token scores to a span score. Options are "mean", "max", - "min", and "product". Default: "mean". - log_precision_recall_metrics: Whether to log precision and recall metrics (in addition to F1) for the - spans. Default: True. - """ - - # list of attribute names that need to be set by _prepare() - PREPARED_ATTRIBUTES: List[str] = ["labels"] - - def __init__( - self, - tokenizer_name_or_path: str, - span_annotation: str = "labeled_spans", - partition_annotation: Optional[str] = None, - label_pad_id: int = -100, - labels: Optional[List[str]] = None, - include_ill_formed_predictions: bool = True, - tokenize_kwargs: Optional[Dict[str, Any]] = None, - pad_kwargs: Optional[Dict[str, Any]] = None, - combine_token_scores_method: str = "mean", - log_precision_recall_metrics: bool = True, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.save_hyperparameters() - - self.span_annotation = span_annotation - self.partition_annotation = partition_annotation - self.labels = labels - self.label_pad_id = label_pad_id - self.include_ill_formed_predictions = include_ill_formed_predictions - self.tokenize_kwargs = tokenize_kwargs or {} - self.pad_kwargs = pad_kwargs or {} - self.log_precision_recall_metrics = log_precision_recall_metrics - self.combine_token_scores_method = combine_token_scores_method - - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - - @property - def document_type(self) -> Optional[Type[TextBasedDocument]]: - dt: Type[TextBasedDocument] - errors = [] - if self.span_annotation != "labeled_spans": - errors.append( - f"span_annotation={self.span_annotation} is not the default value ('labeled_spans')" - ) - if self.partition_annotation is None: - dt = TextDocumentWithLabeledSpans - else: - if self.partition_annotation != "labeled_partitions": - errors.append( - f"partition_annotation={self.partition_annotation} is not the default value " - f"('labeled_partitions')" - ) - dt = TextDocumentWithLabeledSpansAndLabeledPartitions - - if len(errors) == 0: - return dt - else: - logger.warning( - f"{' and '.join(errors)}, so the taskmodule {type(self).__name__} can not request " - f"the usual document type ({dt.__name__}) for auto-conversion because this has the bespoken default " - f"value as layer name(s) instead of the provided one(s)." - ) - return None - - def get_span_layer(self, document: DocumentType) -> AnnotationLayer[LabeledSpan]: - return document[self.span_annotation] - - def _prepare(self, documents: Sequence[DocumentType]) -> None: - # collect all possible labels - labels: Set[str] = set() - for document in documents: - spans: AnnotationLayer[LabeledSpan] = self.get_span_layer(document) - - for span in spans: - labels.add(span.label) - - self.labels = sorted(labels) - logger.info(f"Collected {len(self.labels)} labels from the data: {self.labels}") - - def _post_prepare(self): - # create the real token labels (BIO scheme) from the labels - self.label_to_id = {"O": 0} - current_id = 1 - for label in sorted(self.labels): - for prefix in ["B", "I"]: - self.label_to_id[f"{prefix}-{label}"] = current_id - current_id += 1 - - self.id_to_label = {v: k for k, v in self.label_to_id.items()} - - def encode_input( - self, - document: TextBasedDocument, - ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: - if self.partition_annotation is None: - tokenized_document_type = TokenDocumentWithLabeledSpans - casted_document_type = TextDocumentWithLabeledSpans - field_mapping = {self.span_annotation: "labeled_spans"} - else: - tokenized_document_type = TokenDocumentWithLabeledSpansAndLabeledPartitions - casted_document_type = TextDocumentWithLabeledSpansAndLabeledPartitions - field_mapping = { - self.span_annotation: "labeled_spans", - self.partition_annotation: "labeled_partitions", - } - casted_document = document.as_type(casted_document_type, field_mapping=field_mapping) - tokenized_docs = tokenize_document( - casted_document, - tokenizer=self.tokenizer, - result_document_type=tokenized_document_type, - partition_layer=( - "labeled_partitions" if self.partition_annotation is not None else None - ), - strict_span_conversion=False, - **self.tokenize_kwargs, - ) - - task_encodings: List[TaskEncodingType] = [] - for tokenized_doc in tokenized_docs: - task_encodings.append( - TaskEncoding( - document=document, - inputs=tokenized_doc.metadata["tokenizer_encoding"], - metadata={"tokenized_document": tokenized_doc}, - ) - ) - - return task_encodings - - def encode_target( - self, - task_encoding: TaskEncodingType, - ) -> Optional[TargetEncodingType]: - metadata = task_encoding.metadata - tokenized_document = metadata["tokenized_document"] - tokenizer_encoding: Encoding = tokenized_document.metadata["tokenizer_encoding"] - - tag_sequence = [ - None if tokenizer_encoding.special_tokens_mask[j] else "O" - for j in range(len(tokenizer_encoding.ids)) - ] - if self.labels is None: - raise ValueError( - "'labels' must be set before calling encode_target(). Was prepare() called on the taskmodule?" - ) - sorted_spans = sorted(tokenized_document.labeled_spans, key=lambda s: (s.start, s.end)) - for span in sorted_spans: - if span.label not in self.labels: - continue - start = span.start - end = span.end - if any(tag != "O" for tag in tag_sequence[start:end]): - logger.warning(f"tag already assigned (current span has an overlap: {span}).") - continue - - tag_sequence[start] = f"B-{span.label}" - for j in range(start + 1, end): - tag_sequence[j] = f"I-{span.label}" - - targets = [ - self.label_to_id[tag] if tag is not None else self.label_pad_id for tag in tag_sequence - ] - - return targets - - def collate(self, task_encodings: Sequence[TaskEncodingType]) -> ModelStepInputType: - input_encodings = [ - { - "input_ids": task_encoding.inputs.ids, - "attention_mask": task_encoding.inputs.attention_mask, - "special_tokens_mask": task_encoding.inputs.special_tokens_mask, - } - for task_encoding in task_encodings - ] - inputs = self.tokenizer.pad( - list_of_dicts2dict_of_lists(input_encodings), return_tensors="pt", **self.pad_kwargs - ) - - if not task_encodings[0].has_targets: - return inputs, None - - tag_ids = [task_encoding.targets for task_encoding in task_encodings] - targets = self.tokenizer.pad( - {"input_ids": tag_ids}, return_tensors="pt", **self.pad_kwargs - )["input_ids"] - - # set the padding label to the label_pad_token_id - pad_mask = inputs["input_ids"] == self.tokenizer.pad_token_id - targets[pad_mask] = self.label_pad_id - - return inputs, {"labels": targets} - - def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]: - labels = model_output["labels"] - probabilities = model_output.get("probabilities", None) - batch_size = labels.shape[0] - task_outputs: List[TaskOutputType] = [] - for batch_idx in range(batch_size): - task_output: TaskOutputType = {"labels": labels[batch_idx]} - if probabilities is not None: - task_output["probabilities"] = probabilities[batch_idx] - task_outputs.append(task_output) - return task_outputs - - def decode_annotations(self, encoding: TaskOutputType) -> Dict[str, Sequence[LabeledSpan]]: - labels = encoding["labels"] - tag_sequence = [ - "O" if tag_id == self.label_pad_id else self.id_to_label[tag_id] - for tag_id in labels.tolist() - ] - labeled_spans: List[LabeledSpan] = [] - for label, (start, end_inclusive) in tag_sequence_to_token_spans( - tag_sequence, - coding_scheme="IOB2", - include_ill_formed=self.include_ill_formed_predictions, - ): - end = end_inclusive + 1 - # do not set the score if the probabilities are not available - annotation_kwargs = {} - if encoding.get("probabilities") is not None: - span_probabilities = encoding["probabilities"][start:end] - span_label_ids = labels[start:end] - # get the probabilities at the label indices - span_label_probs = torch.stack( - [span_probabilities[i, l] for i, l in enumerate(span_label_ids)] - ) - if self.combine_token_scores_method == "mean": - # use mean probability of the span as score - annotation_kwargs["score"] = span_label_probs.mean().item() - elif self.combine_token_scores_method == "max": - # use max probability of the span as score - annotation_kwargs["score"] = span_label_probs.max().item() - elif self.combine_token_scores_method == "min": - # use min probability of the span as score - annotation_kwargs["score"] = span_label_probs.min().item() - elif self.combine_token_scores_method == "product": - # use product of probabilities of the span as score - annotation_kwargs["score"] = span_label_probs.prod().item() - else: - raise ValueError( - f"combine_token_scores_method={self.combine_token_scores_method} is not supported." - ) - labeled_span = LabeledSpan(label=label, start=start, end=end, **annotation_kwargs) - labeled_spans.append(labeled_span) - return {"labeled_spans": labeled_spans} - - def create_annotations_from_output( - self, - task_encoding: TaskEncodingType, - task_output: TaskOutputType, - ) -> Iterator[Tuple[str, LabeledSpan]]: - tokenized_document = task_encoding.metadata["tokenized_document"] - decoded_annotations = self.decode_annotations(task_output) - - # Note: token_based_document_to_text_based() does not yet consider predictions, so we need to clear - # the main annotations and attach the predictions to that - for layer_name, annotations in decoded_annotations.items(): - tokenized_document[layer_name].clear() - for annotation in annotations: - tokenized_document[layer_name].append(annotation) - - # we can not use self.document_type here because that may be None if self.span_annotation or - # self.partition_annotation is not the default value - document_type = ( - TextDocumentWithLabeledSpansAndLabeledPartitions - if self.partition_annotation - else TextDocumentWithLabeledSpans - ) - untokenized_document: Union[ - TextDocumentWithLabeledSpans, TextDocumentWithLabeledSpansAndLabeledPartitions - ] = token_based_document_to_text_based( - tokenized_document, result_document_type=document_type - ) - - for span in untokenized_document.labeled_spans: - # need to copy the span because it can be attached to only one document - yield self.span_annotation, span.copy() - - def configure_model_metric(self, stage: str) -> Union[Metric, MetricCollection]: - common_metric_kwargs = { - "num_classes": len(self.label_to_id), - "task": "multiclass", - "ignore_index": self.label_pad_id, - } - token_scores = MetricCollection( - { - "token/macro/f1": WrappedMetricWithPrepareFunction( - metric=F1Score(average="macro", **common_metric_kwargs), - prepare_function=_get_label_ids_from_model_output, - ), - "token/micro/f1": WrappedMetricWithPrepareFunction( - metric=F1Score(average="micro", **common_metric_kwargs), - prepare_function=_get_label_ids_from_model_output, - ), - "token/macro/precision": WrappedMetricWithPrepareFunction( - metric=Precision(average="macro", **common_metric_kwargs), - prepare_function=_get_label_ids_from_model_output, - ), - "token/macro/recall": WrappedMetricWithPrepareFunction( - metric=Recall(average="macro", **common_metric_kwargs), - prepare_function=_get_label_ids_from_model_output, - ), - "token/micro/precision": WrappedMetricWithPrepareFunction( - metric=Precision(average="micro", **common_metric_kwargs), - prepare_function=_get_label_ids_from_model_output, - ), - "token/micro/recall": WrappedMetricWithPrepareFunction( - metric=Recall(average="micro", **common_metric_kwargs), - prepare_function=_get_label_ids_from_model_output, - ), - } - ) - - span_scores = PrecisionRecallAndF1ForLabeledAnnotations( - flatten_result_with_sep="/", - prefix="span/", - return_recall_and_precision=self.log_precision_recall_metrics, - ) - span_scores_wrapped = WrappedMetricWithPrepareFunction( - metric=span_scores, - prepare_function=partial(unbatch_and_decode_annotations, taskmodule=self), - prepare_does_unbatch=True, - ) - - return MetricCollection([token_scores, span_scores_wrapped]) diff --git a/src/pie_modules/taskmodules/metrics/__init__.py b/src/pie_modules/taskmodules/metrics/__init__.py deleted file mode 100644 index 8d23c8efb..000000000 --- a/src/pie_modules/taskmodules/metrics/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .precision_recall_and_f1_for_labeled_annotations import ( - PrecisionRecallAndF1ForLabeledAnnotations, -) -from .wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function import ( - WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction, -) -from .wrapped_metric_with_prepare_function import WrappedMetricWithPrepareFunction diff --git a/src/pie_modules/taskmodules/metrics/common.py b/src/pie_modules/taskmodules/metrics/common.py deleted file mode 100644 index 18c8361eb..000000000 --- a/src/pie_modules/taskmodules/metrics/common.py +++ /dev/null @@ -1,38 +0,0 @@ -import logging -from abc import ABC -from typing import Dict, Optional - -import torch -from torch import LongTensor, Tensor -from torchmetrics import Metric - -logger = logging.getLogger(__name__) - - -class MetricWithArbitraryCounts(Metric, ABC): - """A metric that hold counts for arbitrary keys.""" - - def inc_counts(self, counts: LongTensor, key: Optional[str], prefix: str = "counts_"): - full_key = prefix - if key is not None: - full_key += key - - if not hasattr(self, full_key): - self.add_state(full_key, default=torch.zeros_like(counts), dist_reduce_fx="sum") - - prev_value = getattr(self, full_key) - setattr(self, full_key, prev_value + counts) - - def get_counts(self, key_prefix: str = "counts_") -> Dict[Optional[str], LongTensor]: - result = {} - for k, v in self.metric_state.items(): - if k.startswith(key_prefix): - if not isinstance(v, Tensor): - raise ValueError( - f"Expected metric state for key {k} to be a LongTensor, but got {type(v)}." - ) - if not isinstance(v, LongTensor): - v = v.long() - key = k[len(key_prefix) :] or None - result[key] = v - return result diff --git a/src/pie_modules/taskmodules/metrics/precision_recall_and_f1_for_labeled_annotations.py b/src/pie_modules/taskmodules/metrics/precision_recall_and_f1_for_labeled_annotations.py deleted file mode 100644 index 2bb83eb4c..000000000 --- a/src/pie_modules/taskmodules/metrics/precision_recall_and_f1_for_labeled_annotations.py +++ /dev/null @@ -1,137 +0,0 @@ -import logging -from collections import Counter -from typing import Any, Collection, Dict, Iterable, Optional, Union - -import torch -from pie_core import Annotation -from pie_core.utils.dictionary import flatten_dict_s -from torch import LongTensor - -from .common import MetricWithArbitraryCounts - -logger = logging.getLogger(__name__) - - -class PrecisionRecallAndF1ForLabeledAnnotations(MetricWithArbitraryCounts): - """Computes precision, recall and F1 for labeled annotations. Inputs and targets are lists of - annotations. True positives are counted as the number of annotations that are the same in both - inputs and targets calculated as exact matches via set operation, false positives and false - negatives accordingly. The annotations are deduplicated for each instance. But if the same - annotation occurs in different instances, it is counted as two separate annotations. - - Args: - label_mapping: A dictionary mapping annotation labels to human-readable labels. If None, - the annotation labels are used as they are. Can be used to map label ids to string labels. - key_micro: The key to use for the micro-average in the metric result dictionary. - in_percent: Whether to return the results in percent, i.e. values between 0 and 100 instead of - between 0 and 1. - flatten_result_with_sep: If not None, the result dictionary is flattened and the keys of the - different nesting levels are concatenated with the given separator. - prefix: If not None, the most outer keys of the result dictionary are prefixed with this string. - return_recall_and_precision: Whether to return recall and precision in addition to F1. - """ - - def __init__( - self, - label_mapping: Optional[Dict[Any, str]] = None, - key_micro: Optional[str] = "micro", - key_macro: Optional[str] = "macro", - in_percent: bool = False, - flatten_result_with_sep: Optional[str] = None, - prefix: Optional[str] = None, - return_recall_and_precision: bool = True, - ): - super().__init__() - self.label_mapping = label_mapping - self.key_micro = key_micro - self.key_macro = key_macro - self.in_percent = in_percent - self.flatten_result_with_sep = flatten_result_with_sep - self.prefix = prefix - self.return_recall_and_precision = return_recall_and_precision - - def update(self, gold: Iterable[Annotation], predicted: Iterable[Annotation]) -> None: - # remove duplicates within each list - gold_set = set(gold) - predicted_set = set(predicted) - new_counts = self.calculate_counts(gold_set, predicted_set, gold_set & predicted_set) - for k, v in new_counts.items(): - self.inc_counts(counts=v, key=k) - - def get_precision_recall_f1(self, n_gold_predicted_correct: LongTensor) -> Dict[str, float]: - n_gold = n_gold_predicted_correct[0] - n_predicted = n_gold_predicted_correct[1] - n_correct = n_gold_predicted_correct[2] - zero = torch.tensor(0.0).to(self.device) - recall = zero if n_gold == 0 else (n_correct / n_gold) - precision = zero if n_predicted == 0 else (n_correct / n_predicted) - f1 = zero if recall + precision == 0 else (2 * precision * recall) / (precision + recall) - - result = {"f1": f1} - if self.return_recall_and_precision: - result["recall"] = recall - result["precision"] = precision - - if self.in_percent: - result = {k: v * 100 for k, v in result.items()} - return result - - def get_label(self, annotation: Annotation) -> Optional[str]: - label: Optional[str] = getattr(annotation, "label", None) - if self.label_mapping is not None: - return self.label_mapping[label] - return label - - def calculate_counts( - self, - gold: Collection[Annotation], - predicted: Collection[Annotation], - correct: Collection[Annotation], - ) -> Dict[Optional[str], LongTensor]: - result = {} - # per class - gold_counter = Counter([self.get_label(ann) for ann in gold]) - predicted_counter = Counter([self.get_label(ann) for ann in predicted]) - correct_counter = Counter([self.get_label(ann) for ann in correct]) - for label in gold_counter.keys() | predicted_counter.keys(): - if self.key_micro is not None and label == self.key_micro: - raise ValueError( - f"The key '{self.key_micro}' was used as an annotation label, but it is reserved for " - f"the micro average. You can change which key is used for that with the 'key_micro' argument." - ) - result[label] = torch.tensor( - [ - gold_counter.get(label, 0), - predicted_counter.get(label, 0), - correct_counter.get(label, 0), - ] - ).to(device=self.device) - - # overall - if self.key_micro is not None: - result[self.key_micro] = torch.tensor([len(gold), len(predicted), len(correct)]).to( - device=self.device - ) - return result - - def compute(self) -> Union[Dict[str, Any], Dict[Optional[str], dict[str, float]]]: - counts = self.get_counts() - result = {label: self.get_precision_recall_f1(counts[label]) for label in counts.keys()} - if self.key_macro is not None: - result_without_micro = { - k: v for k, v in result.items() if self.key_micro is None or k != self.key_micro - } - if len(result_without_micro) > 0: - sub_keys = list(result_without_micro.values())[0].keys() - result[self.key_macro] = { - k: torch.stack([v[k] for v in result_without_micro.values()]).mean() - for k in sub_keys - } - - if self.prefix is not None: - result = {f"{self.prefix}{k}": v for k, v in result.items()} - - if self.flatten_result_with_sep is not None: - return flatten_dict_s(result, sep=self.flatten_result_with_sep) - else: - return result diff --git a/src/pie_modules/taskmodules/metrics/wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py b/src/pie_modules/taskmodules/metrics/wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py deleted file mode 100644 index c7c45aaea..000000000 --- a/src/pie_modules/taskmodules/metrics/wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py +++ /dev/null @@ -1,147 +0,0 @@ -import logging -from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, TypeVar - -import torch -from torch.nn import ModuleDict -from torchmetrics import Metric - -from .common import MetricWithArbitraryCounts - -logger = logging.getLogger(__name__) -T = TypeVar("T") -U = TypeVar("U") - - -class WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction( - MetricWithArbitraryCounts, Generic[T, U] -): - """A wrapper around annotation layer metrics that can be used with batched encoded annotations. - - Args: - layer_metrics: A dictionary mapping layer names to annotation layer metrics. Each metric - should be a subclass of torchmetrics.Metric and should take two sets of annotations as - input. - unbatch_function: A function that takes a batched input and returns an iterable of - individual inputs. This is used to unbatch the input before passing it to the annotation - decoding function (decode_annotations_with_errors_function). - decode_layers_with_errors_function: A function that takes an annotation encoding and - returns a tuple of two dictionaries. The first dictionary maps layer names to a list of - annotations. The second dictionary maps error names to the number of errors that were - encountered while decoding the annotations. - round_precision: The number of digits to round the results to. If None, no rounding is - performed. - error_key_correct: The key in the error dictionary whose value should be the number of *correctly* - decoded annotations, so that the sum of all values in the error dictionary can be used to - normalize the error counts. If None, the total number of training examples is used to - normalize the error counts. - collect_exact_encoding_matches: Whether to collect the number of examples where the full target encoding - was predicted correctly (exact matches). - """ - - def __init__( - self, - layer_metrics: Dict[str, Metric], - unbatch_function: Callable[[T], Sequence[U]], - decode_layers_with_errors_function: Callable[[U], Tuple[Dict[str, Any], Dict[str, int]]], - round_precision: Optional[int] = 4, - error_key_correct: Optional[str] = None, - collect_exact_encoding_matches: bool = True, - ): - super().__init__() - - self.key_error_correct = error_key_correct - self.collect_exact_encoding_matches = collect_exact_encoding_matches - self.round_precision = round_precision - self.unbatch_function = unbatch_function - self.decode_layers_with_errors_function = decode_layers_with_errors_function - self.layer_metrics = ModuleDict(layer_metrics) - - # total number of encodings - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - # this contains the number of examples where the full target sequence was predicted correctly (exact matches) - self.add_state("exact_encoding_matches", default=torch.tensor(0), dist_reduce_fx="sum") - # note: the error counts are stored via the MetricWithArbitraryCounts base class - - def update(self, prediction, expected): - prediction_list = self.unbatch_function(prediction) - expected_list = self.unbatch_function(expected) - if len(prediction_list) != len(expected_list): - raise ValueError( - f"Number of predictions ({len(prediction_list)}) and targets ({len(expected_list)}) do not match." - ) - - for expected_encoding, prediction_encoding in zip(expected_list, prediction_list): - expected_layers, _ = self.decode_layers_with_errors_function(expected_encoding) - predicted_layers, predicted_errors = self.decode_layers_with_errors_function( - prediction_encoding - ) - for k, v in predicted_errors.items(): - self.inc_counts(counts=torch.tensor(v).to(self.device), key=k, prefix="errors_") - - for layer_name, metric in self.layer_metrics.items(): - metric.update(expected_layers[layer_name], predicted_layers[layer_name]) - - if self.collect_exact_encoding_matches: - if isinstance(expected_encoding, torch.Tensor) and isinstance( - prediction_encoding, torch.Tensor - ): - is_match = torch.equal(expected_encoding, prediction_encoding) - else: - is_match = expected_encoding == prediction_encoding - if is_match: - self.exact_encoding_matches += 1 - - self.total += 1 - - def reset(self): - super().reset() - - for metric in self.layer_metrics.values(): - metric.reset() - - def _nested_round(self, d: Dict[str, Any]) -> Dict[str, Any]: - if self.round_precision is None: - return d - res: Dict[str, Any] = {} - for k, v in d.items(): - if isinstance(v, dict): - res[k] = self._nested_round(v) - elif isinstance(v, float): - res[k] = round(v, self.round_precision) - else: - res[k] = v - return res - - def compute(self): - res = {} - if self.collect_exact_encoding_matches: - res["exact_encoding_matches"] = ( - self.exact_encoding_matches / self.total if self.total > 0 else 0.0 - ) - - errors = self.get_counts(key_prefix="errors_") - # if errors contains a "correct" key, use that to normalize, otherwise use the number of training examples - if self.key_error_correct in errors: - errors_total = sum(errors.values()) - else: - errors_total = self.total - res["decoding_errors"] = { - k: v / errors_total if errors_total > 0 else 0.0 for k, v in errors.items() - } - if "all" not in res["decoding_errors"]: - res["decoding_errors"]["all"] = ( - sum(v for k, v in errors.items() if k != self.key_error_correct) / errors_total - if errors_total > 0 - else 0.0 - ) - - for layer_name, metric in self.layer_metrics.items(): - if layer_name in res: - raise ValueError( - f"Layer name '{layer_name}' is already used in the metric result dictionary." - ) - res[layer_name] = metric.compute() - - res = self._nested_round(res) - - return res diff --git a/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py b/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py deleted file mode 100644 index 7daceb9dd..000000000 --- a/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py +++ /dev/null @@ -1,129 +0,0 @@ -import logging -from collections.abc import Collection, Sized -from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union - -from torch import Tensor -from torchmetrics import Metric, MetricCollection -from torchmetrics.wrappers.abstract import WrapperMetric - -logger = logging.getLogger(__name__) - -T = TypeVar("T") -T2 = TypeVar("T2") - - -class WrappedMetricWithPrepareFunction(WrapperMetric, Generic[T]): - """A wrapper around a metric that can be used with predictions and targets that are need to be - prepared (e.g. un-batched) before passing them to the metric. - - Args: - metric: The metric to wrap. It should be a subclass of torchmetrics.Metric. - prepare_function: A function that prepares the input for the metric. If provided, It is called with - the predictions as well as the targets (separately). - prepare_together_function: A function that prepares both the predictions and the targets together and - should return them as a tuple. If provided, it is called with the predictions and the targets as - arguments. - prepare_does_unbatch: If True, the prepare_function is expected to return an iterable of - individual inputs. This can be used to un-batch the input before passing it to the - wrapped metric. - """ - - def __init__( - self, - metric: Union[Metric, MetricCollection], - prepare_function: Optional[Callable[[T], Any]] = None, - prepare_together_function: Optional[Callable[[T, T], Tuple[Any, Any]]] = None, - prepare_does_unbatch: bool = False, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.metric = metric - self.prepare_function = prepare_function - self.prepare_both_function = prepare_together_function - self.prepare_does_unbatch = prepare_does_unbatch - - def _is_empty_batch(self, prediction: T2, target: T2) -> bool: - if isinstance(prediction, Sized) and isinstance(target, Sized): - pred_len = len(prediction) - target_len = len(target) - else: - raise ValueError( - "Both prediction and target need to be sized when prepare_does_unbatch=False." - ) - if pred_len != target_len: - raise ValueError( - f"Number of elements in prediction ({pred_len}) and target ({target_len}) do not match." - ) - if pred_len == 0: - return True - return False - - def forward(self, prediction: T, target: T) -> Any: - if self.prepare_function is not None: - prediction = self.prepare_function(prediction) - target = self.prepare_function(target) - if self.prepare_both_function is not None: - prediction, target = self.prepare_both_function(prediction, target) - if self.prepare_does_unbatch: - if not isinstance(prediction, Collection) or not isinstance(target, Collection): - raise ValueError( - "Both prediction and target need to be iterable and sized when prepare_does_unbatch=True." - ) - if len(prediction) != len(target): - raise ValueError( - f"Number of prepared predictions ({len(prediction)}) and targets " - f"({len(target)}) do not match." - ) - if len(prediction) == 0: - raise ValueError("Empty batch.") - results = [] - for prediction_str, target_str in zip(prediction, target): - current_result = self.metric(prediction_str, target_str) - results.append(current_result) - return results - else: - if not self._is_empty_batch(prediction, target): - return self.metric(prediction, target) - else: - return None - - def update(self, prediction: T, target: T) -> None: - if self.prepare_function is not None: - prediction = self.prepare_function(prediction) - target = self.prepare_function(target) - if self.prepare_both_function is not None: - prediction, target = self.prepare_both_function(prediction, target) - if self.prepare_does_unbatch: - if not isinstance(prediction, Collection) or not isinstance(target, Collection): - raise ValueError( - "Both prediction and target need to be iterable and sized when prepare_does_unbatch=True." - ) - if len(prediction) != len(target): - raise ValueError( - f"Number of prepared predictions ({len(prediction)}) and targets " - f"({len(target)}) do not match." - ) - if len(prediction) == 0: - raise ValueError("Empty batch.") - for prediction_str, target_str in zip(prediction, target): - self.metric.update(prediction_str, target_str) - else: - if not self._is_empty_batch(prediction, target): - self.metric.update(prediction, target) - - def compute(self) -> Any: - return self.metric.compute() - - def reset(self) -> None: - self.metric.reset() - - @property - def metric_state(self) -> Dict[str, Union[List[Tensor], Tensor]]: - if isinstance(self.metric, Metric): - return self.metric.metric_state - elif isinstance(self.metric, MetricCollection): - return { - metric_name: metric.metric_state for metric_name, metric in self.metric.items() - } - else: - raise ValueError(f"Unsupported metric type: {type(self.metric)}") diff --git a/src/pie_modules/taskmodules/pointer_network/__init__.py b/src/pie_modules/taskmodules/pointer_network/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/pie_modules/taskmodules/pointer_network/annotation_encoder_decoder.py b/src/pie_modules/taskmodules/pointer_network/annotation_encoder_decoder.py deleted file mode 100644 index 5d48629bd..000000000 --- a/src/pie_modules/taskmodules/pointer_network/annotation_encoder_decoder.py +++ /dev/null @@ -1,397 +0,0 @@ -import logging -from collections import defaultdict -from typing import Any, Dict, List, Optional, Set, Tuple - -from pie_modules.annotations import BinaryRelation, LabeledSpan, Span -from pie_modules.taskmodules.common import AnnotationEncoderDecoder, DecodingException - -logger = logging.getLogger(__name__) - - -class DecodingLengthException(DecodingException[List[int]]): - identifier = "len" - - -class DecodingOrderException(DecodingException[List[int]]): - identifier = "order" - - -class DecodingSpanOverlapException(DecodingException[List[int]]): - identifier = "overlap" - - -class DecodingLabelException(DecodingException[List[int]]): - identifier = "label" - - -class DecodingNegativeIndexException(DecodingException[List[int]]): - identifier = "index" - - -KEY_INVALID_CORRECT = "correct" - - -class SpanEncoderDecoder(AnnotationEncoderDecoder[Span, List[int]]): - def __init__(self, exclusive_end: bool = True): - self.exclusive_end = exclusive_end - - def encode(self, annotation: Span, metadata: Optional[Dict[str, Any]] = None) -> List[int]: - end_idx = annotation.end - if not self.exclusive_end: - end_idx -= 1 - return [annotation.start, end_idx] - - def decode(self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None) -> Span: - if len(encoding) != 2: - raise DecodingLengthException( - f"two values are required to decode as Span, but encoding has length {len(encoding)}", - encoding=encoding, - ) - end_idx = encoding[1] - if not self.exclusive_end: - end_idx += 1 - if end_idx < encoding[0]: - raise DecodingOrderException( - f"end index can not be smaller than start index, but got: start={encoding[0]}, " - f"end={end_idx}", - encoding=encoding, - ) - if any(idx < 0 for idx in encoding): - raise DecodingNegativeIndexException( - f"indices must be positive, but got: {encoding}", encoding=encoding - ) - return Span(start=encoding[0], end=end_idx) - - -class SpanEncoderDecoderWithOffset(SpanEncoderDecoder): - def __init__(self, offset: int, **kwargs): - super().__init__(**kwargs) - self.offset = offset - - def encode(self, annotation: Span, metadata: Optional[Dict[str, Any]] = None) -> List[int]: - encoding = super().encode(annotation=annotation, metadata=metadata) - return [x + self.offset for x in encoding] - - def decode(self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None) -> Span: - encoding = [x - self.offset for x in encoding] - return super().decode(encoding=encoding, metadata=metadata) - - -class LabeledSpanEncoderDecoder(AnnotationEncoderDecoder[LabeledSpan, List[int]]): - def __init__( - self, - span_encoder_decoder: AnnotationEncoderDecoder[Span, List[int]], - label2id: Dict[str, int], - mode: str, - ): - self.span_encoder_decoder = span_encoder_decoder - self.label2id = label2id - self.id2label = {idx: label for label, idx in self.label2id.items()} - self.mode = mode - - def encode( - self, annotation: LabeledSpan, metadata: Optional[Dict[str, Any]] = None - ) -> List[int]: - encoded_span = self.span_encoder_decoder.encode(annotation=annotation, metadata=metadata) - encoded_label = self.label2id[annotation.label] - if self.mode == "indices_label": - return encoded_span + [encoded_label] - elif self.mode == "label_indices": - return [encoded_label] + encoded_span - else: - raise ValueError(f"unknown mode: {self.mode}") - - def decode( - self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None - ) -> LabeledSpan: - if self.mode == "label_indices": - encoded_label = encoding[0] - encoded_span = encoding[1:] - elif self.mode == "indices_label": - encoded_label = encoding[-1] - encoded_span = encoding[:-1] - else: - raise ValueError(f"unknown mode: {self.mode}") - - decoded_span = self.span_encoder_decoder.decode(encoding=encoded_span, metadata=metadata) - if encoded_label not in self.id2label: - raise DecodingLabelException( - f"unknown label id: {encoded_label} (label2id: {self.label2id})", encoding=encoding - ) - result = LabeledSpan( - start=decoded_span.start, - end=decoded_span.end, - label=self.id2label[encoded_label], - ) - return result - - -class BinaryRelationEncoderDecoder(AnnotationEncoderDecoder[BinaryRelation, List[int]]): - def __init__( - self, - head_encoder_decoder: AnnotationEncoderDecoder[Span, List[int]], - tail_encoder_decoder: AnnotationEncoderDecoder[Span, List[int]], - label2id: Dict[str, int], - mode: str, - loop_dummy_relation_name: Optional[str] = None, - none_label: Optional[str] = None, - ): - self.head_encoder_decoder = head_encoder_decoder - self.tail_encoder_decoder = tail_encoder_decoder - self.loop_dummy_relation_name = loop_dummy_relation_name - self.none_label = none_label - self.label2id = label2id - self.id2label = {idx: label for label, idx in self.label2id.items()} - self.mode = mode - - def encode( - self, annotation: BinaryRelation, metadata: Optional[Dict[str, Any]] = None - ) -> List[int]: - encoded_head = self.head_encoder_decoder.encode(annotation=annotation.head) - encoded_tail = self.tail_encoder_decoder.encode(annotation=annotation.tail) - - if ( - self.loop_dummy_relation_name is not None - and annotation.label == self.loop_dummy_relation_name - ): - if annotation.head != annotation.tail: - raise ValueError( - f"expected head == tail for loop_dummy_relation, but got: {annotation.head}, " - f"{annotation.tail}" - ) - if self.none_label is None: - raise ValueError( - f"loop_dummy_relation_name is set, but none_label is not set: {self.none_label}" - ) - none_id = self.label2id[self.none_label] - encoded_none_argument = [none_id, none_id, none_id] - if self.mode == "head_tail_label": - return encoded_head + encoded_none_argument + [none_id] - elif self.mode == "tail_head_label": - return encoded_tail + encoded_none_argument + [none_id] - elif self.mode == "label_head_tail": - return [none_id] + encoded_head + encoded_none_argument - elif self.mode == "label_tail_head": - return [none_id] + encoded_tail + encoded_none_argument - else: - raise ValueError(f"unknown mode: {self.mode}") - else: - encoded_label = self.label2id[annotation.label] - if self.mode == "tail_head_label": - return encoded_tail + encoded_head + [encoded_label] - elif self.mode == "head_tail_label": - return encoded_head + encoded_tail + [encoded_label] - elif self.mode == "label_head_tail": - return [encoded_label] + encoded_head + encoded_tail - elif self.mode == "label_tail_head": - return [encoded_label] + encoded_tail + encoded_head - else: - raise ValueError(f"unknown mode: {self.mode}") - - def is_single_span_label(self, label: str) -> bool: - return self.none_label is not None and label == self.none_label - - def decode( - self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None - ) -> BinaryRelation: - if len(encoding) != 7: - raise DecodingLengthException( - f"seven values are required to decode as BinaryRelation, but the encoding has length {len(encoding)}", - encoding=encoding, - ) - if self.mode.endswith("_label"): - encoded_label = encoding[6] - encoded_arguments = encoding[:6] - argument_mode = self.mode[: -len("_label")] - elif self.mode.startswith("label_"): - encoded_label = encoding[0] - encoded_arguments = encoding[1:] - argument_mode = self.mode[len("label_") :] - else: - raise ValueError(f"unknown mode: {self.mode}") - if encoded_label not in self.id2label: - raise DecodingLabelException( - f"unknown label id: {encoded_label} (label2id: {self.label2id})", encoding=encoding - ) - label = self.id2label[encoded_label] - if self.is_single_span_label(label=label): - if argument_mode == "head_tail": - span_encoder = self.head_encoder_decoder - elif argument_mode == "tail_head": - span_encoder = self.tail_encoder_decoder - else: - raise ValueError(f"unknown argument mode: {argument_mode}") - encoded_span = encoded_arguments[:3] - span = span_encoder.decode(encoding=encoded_span, metadata=metadata) - if self.loop_dummy_relation_name is None: - raise ValueError( - f"loop_dummy_relation_name is not set, but none_label={self.none_label} " - f"was found in decoded encoding: {encoding} (label2id: {self.label2id}))" - ) - rel = BinaryRelation(head=span, tail=span, label=self.loop_dummy_relation_name) - else: - if argument_mode == "head_tail": - encoded_head = encoded_arguments[:3] - encoded_tail = encoded_arguments[3:] - elif argument_mode == "tail_head": - encoded_tail = encoded_arguments[:3] - encoded_head = encoded_arguments[3:] - else: - raise ValueError(f"unknown argument mode: {argument_mode}") - head = self.head_encoder_decoder.decode(encoding=encoded_head, metadata=metadata) - tail = self.tail_encoder_decoder.decode(encoding=encoded_tail, metadata=metadata) - rel = BinaryRelation(head=head, tail=tail, label=label) - - return rel - - def build_decoding_constraints( - self, partial_encoding: List[int] - ) -> Tuple[Optional[Set[int]], Optional[Set[int]]]: - """Given a partial encoding, build the constraints for the next encoding step. - - Returns: - Tuple[Optional[Set[int]], Optional[Set[int]]]: A tuple of two sets of integers representing the allowed - and disallowed next indices. The first set contains the allowed indices, and the second set contains - the disallowed indices. If no constraints are needed, both sets can be None. - """ - allowed = None - disallowed = None - - if self.mode != "tail_head_label": - raise NotImplementedError( - f"build_decoder_constraints is not implemented for mode {self.mode}" - ) - - if self.none_label not in self.label2id: - raise ValueError( - f"none_label not found in label2id: {self.label2id} (none_label: {self.none_label})" - ) - none_id = self.label2id[self.none_label] - if self.head_encoder_decoder != self.tail_encoder_decoder: - raise NotImplementedError( - "head and tail encoder/decoder must be the same for build_decoder_constraints" - ) - - if not isinstance(self.head_encoder_decoder, LabeledSpanEncoderDecoder): - raise NotImplementedError( - "head and tail encoder/decoder must be LabeledSpanEncoderDecoder for build_decoder_constraints" - ) - if not isinstance( - self.head_encoder_decoder.span_encoder_decoder, SpanEncoderDecoderWithOffset - ): - raise NotImplementedError( - "head and tail encoder/decoder must be SpanEncoderDecoderWithOffset for build_decoder_constraints" - ) - pointer_offset = self.head_encoder_decoder.span_encoder_decoder.offset - if self.head_encoder_decoder.mode != "indices_label": - raise NotImplementedError( - "head and tail encoder/decoder must be indices_label for build_decoder_constraints" - ) - if ( - not isinstance(self.head_encoder_decoder.span_encoder_decoder, SpanEncoderDecoder) - or self.head_encoder_decoder.span_encoder_decoder.exclusive_end - ): - raise NotImplementedError( - "head and tail encoder/decoder must be exclusive_end for build_decoder_constraints" - ) - span_ids = set(self.head_encoder_decoder.label2id.values()) - relation_ids = set(self.label2id.values()) - {self.label2id[self.none_label]} - contains_none = none_id in partial_encoding - idx = len(partial_encoding) - if idx == 0: # [] -> first span start or eos - # Disallow all labels: - disallowed = set(range(pointer_offset)) - elif idx == 1: # [14] -> first span end - # Allow all offsets greater than the span start. - span_start = partial_encoding[-1] - # result[span_start:] = 1 - disallowed = set(range(span_start)) - # Disallow the none label: - disallowed.add(none_id) - elif idx == 2: # [14,14] -> first span label - # Allow only span ids. - allowed = span_ids - elif idx == 3: # [14,14,s1] -> second span start or none - # Disallow overlap of first and second span: - first_span_start = partial_encoding[0] - first_span_end = partial_encoding[1] + 1 - disallowed = set(range(first_span_start, first_span_end)) - # Disallow all span labels: - disallowed.update(span_ids) - # Disallow all relation labels: - disallowed.update(relation_ids) - # But allow the none label: - disallowed.discard(none_id) - - elif idx == 4: # [14,14,s1,23] -> second span end or none - # if we have a none label, allow only none - if contains_none: - allowed = {none_id} - else: - - first_span_start = partial_encoding[0] - # first_span_end = partial_encoding[1] + 1 - second_span_start = partial_encoding[-1] - # if first span is after the second span, - if second_span_start < first_span_start: - # just allow the offsets between the two spans: - allowed = set(range(second_span_start, first_span_start)) - else: - # otherwise, disallow all offsets before the second span start: - disallowed = set(range(second_span_start)) - - # Disallow all span labels: - disallowed.update(span_ids) - # Disallow all relation labels: - disallowed.update(relation_ids) - - elif idx == 5: # [14,14,s1,23,25] -> second span label or none - # if we have a none label, allow only none - if contains_none: - # result[none_id] = 1 - allowed = {none_id} - else: - # allow only span ids - allowed = span_ids - elif idx == 6: # [14,14,s1,23,25,s2] -> relation label or none - # if we have a none label, allow only none - if contains_none: - allowed = {none_id} - else: - # allow only relation ids - allowed = relation_ids - else: - raise ValueError( - f"unknown partial encoding length: {len(partial_encoding)} (encoding: {partial_encoding})" - ) - - return allowed, disallowed - - def parse(self, encoding: List[int]) -> Tuple[List[BinaryRelation], Dict[str, int], List[int]]: - errors: Dict[str, int] = defaultdict(int) - if self.none_label is None: - raise ValueError( - f"none_label is not set, but is required for parsing: {self.none_label}" - ) - none_id = self.label2id[self.none_label] - relation_ids = set(self.label2id.values()) - {none_id} - encodings = [] - current_encoding: List[int] = [] - valid_encoding: BinaryRelation - if len(encoding): - for i in encoding: - current_encoding.append(i) - # An encoding is complete when it ends with a relation_id - # or when it contains a none_id and has a length of 7 - if i in relation_ids or (i == none_id and len(current_encoding) == 7): - # try to decode the current relation encoding - try: - valid_encoding = self.decode(encoding=current_encoding) - encodings.append(valid_encoding) - errors[KEY_INVALID_CORRECT] += 1 - except DecodingException as e: - errors[e.identifier] += 1 - - current_encoding = [] - - return encodings, dict(errors), current_encoding diff --git a/src/pie_modules/taskmodules/pointer_network/logits_processor.py b/src/pie_modules/taskmodules/pointer_network/logits_processor.py deleted file mode 100644 index 777fd9763..000000000 --- a/src/pie_modules/taskmodules/pointer_network/logits_processor.py +++ /dev/null @@ -1,67 +0,0 @@ -import math -from typing import Callable, List - -import torch -from transformers import LogitsProcessor, add_start_docstrings -from transformers.generation.logits_process import LOGITS_PROCESSOR_INPUTS_DOCSTRING - - -class PrefixConstrainedLogitsProcessorWithMaximum(LogitsProcessor): - r"""This is similar to [`PrefixConstrainedLogitsProcessor`] but the constraint function gets the - maximum possible index as input. This is useful for Pointer Network where the generated token - can be an index into the input which depends on the length of that input. - - Args: - prefix_allowed_tokens_fn (Callable[[int, torch.LongTensor, int], List[int]]): - Should return the list of token ids allowed at the next generation step, - given (`batch_id`, `input_ids_so_far`, `max_index`). - """ - - def __init__( - self, - prefix_allowed_tokens_fn: Callable[[int, torch.LongTensor, int], List[int]], - num_beams: int, - ): - self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn - self._num_beams = num_beams - - @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: - if not torch.isfinite(scores).all(): - raise ValueError( - "scores contains ±inf or NaN, which is not allowed by " - "PrefixConstrainedLogitsProcessorWithMaximum. " - "Insert FinitizeLogitsProcessor earlier to clean them." - ) - mask = torch.full_like(scores, -math.inf) - for batch_id, beam_sent in enumerate( - input_ids.view(-1, self._num_beams, input_ids.shape[-1]) - ): - for beam_id, sent in enumerate(beam_sent): - allowed_ids = self._prefix_allowed_tokens_fn(batch_id, sent, mask.size(1)) - if len(allowed_ids) == 0: - raise ValueError( - f"No allowed token ids for batch_id {batch_id}, beam_id {beam_id} with " - f"previous ids: {sent}. This would result in undefined behaviour, " - "so this is not allowed. Please adjust the prefix_allowed_tokens_fn " - "implementation." - ) - mask[batch_id * self._num_beams + beam_id, allowed_ids] = 0 - - return scores + mask - - -class FinitizeLogitsProcessor(LogitsProcessor): - r"""Replaces any `±inf` logits with the largest-magnitude finite values for the tensor’s dtype, - ensuring all logits are valid for downstream ops.""" - - @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: - finite_min = torch.finfo(scores.dtype).min - finite_max = torch.finfo(scores.dtype).max - # Use nan_to_num for a fast, fused replacement (PyTorch ≥ 1.8) - return torch.nan_to_num(scores, neginf=finite_min, posinf=finite_max) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py deleted file mode 100644 index 9ddb4254b..000000000 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ /dev/null @@ -1,865 +0,0 @@ -import dataclasses -import json -import logging -from collections import Counter, defaultdict -from functools import cmp_to_key -from typing import ( - Any, - Dict, - Iterable, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, -) - -import torch -from pie_core import ( - Annotation, - AnnotationLayer, - Document, - TaskEncoding, - TaskModule, -) -from pie_core.taskmodule import ( - InputEncoding, - ModelBatchOutput, - TargetEncoding, - TaskBatchEncoding, -) -from pie_core.utils.hydra import resolve_type -from torchmetrics import Metric -from transformers import AutoTokenizer, LogitsProcessorList, PreTrainedTokenizer -from typing_extensions import TypeAlias - -from pie_modules.annotations import BinaryRelation, LabeledSpan - -# import for backwards compatibility (don't remove!) -from pie_modules.documents import ( - TextBasedDocument, - TokenBasedDocument, - TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, -) - -from ..document.processing import token_based_document_to_text_based, tokenize_document -from .common import BatchableMixin, get_first_occurrence_index -from .metrics import ( - PrecisionRecallAndF1ForLabeledAnnotations, - WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction, -) -from .pointer_network.annotation_encoder_decoder import ( - KEY_INVALID_CORRECT, - BinaryRelationEncoderDecoder, - LabeledSpanEncoderDecoder, - SpanEncoderDecoderWithOffset, -) -from .pointer_network.logits_processor import ( - FinitizeLogitsProcessor, - PrefixConstrainedLogitsProcessorWithMaximum, -) - -logger = logging.getLogger(__name__) - - -DocumentType: TypeAlias = TextBasedDocument - - -@dataclasses.dataclass -class InputEncodingType(BatchableMixin): - input_ids: List[int] - attention_mask: List[int] - - -@dataclasses.dataclass -class LabelsAndOptionalConstraints(BatchableMixin): - labels: List[int] - constraints: Optional[List[List[int]]] = None - - @property - def decoder_attention_mask(self) -> List[int]: - return [1] * len(self.labels) - - -TargetEncodingType: TypeAlias = LabelsAndOptionalConstraints -TaskEncodingType: TypeAlias = TaskEncoding[ - DocumentType, - InputEncodingType, - TargetEncodingType, -] -TaskOutputType: TypeAlias = LabelsAndOptionalConstraints - - -def cmp_src_rel(v1: BinaryRelation, v2: BinaryRelation) -> int: - if not all(isinstance(ann, LabeledSpan) for ann in [v1.head, v1.tail, v2.head, v2.tail]): - raise Exception(f"expected LabeledSpan, but got: {v1}, {v2}") - if v1.head.start == v2.head.start: # v1[0]["from"] == v2[0]["from"]: - return v1.tail.start - v2.tail.start # v1[1]["from"] - v2[1]["from"] - return v1.head.start - v2.head.start # v1[0]["from"] - v2[0]["from"] - - -@TaskModule.register() -class PointerNetworkTaskModuleForEnd2EndRE( - TaskModule[ - DocumentType, - InputEncoding, - TargetEncoding, - TaskBatchEncoding, - ModelBatchOutput, - TaskOutputType, - ], -): - PREPARED_ATTRIBUTES = ["labels_per_layer"] - REVERSED_RELATION_LABEL_SUFFIX = "_reversed" - - def __init__( - self, - tokenizer_name_or_path: str, - # specific for this use case - document_type: str = "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions", - tokenized_document_type: str = "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions", - relation_layer_name: str = "binary_relations", - add_reversed_relations: bool = False, - symmetric_relations: Optional[List[str]] = None, - none_label: str = "none", - loop_dummy_relation_name: str = "loop", - constrained_generation: bool = False, - # generic pointer network - label_tokens: Optional[Dict[str, str]] = None, - label_representations: Optional[Dict[str, str]] = None, - labels_per_layer: Optional[Dict[str, List[str]]] = None, - exclude_labels_per_layer: Optional[Dict[str, List[str]]] = None, - # target encoding - create_constraints: bool = False, - # tokenization - tokenizer_init_kwargs: Optional[Dict[str, Any]] = None, - tokenizer_kwargs: Optional[Dict[str, Any]] = None, - partition_layer_name: Optional[str] = None, - annotation_field_mapping: Optional[Dict[str, str]] = None, - # logging - log_first_n_examples: Optional[int] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.save_hyperparameters() - - # tokenization - self._document_type: Type[TextBasedDocument] = resolve_type( - document_type, expected_super_type=TextBasedDocument - ) - self._tokenized_document_type: Type[TokenBasedDocument] = resolve_type( - tokenized_document_type, expected_super_type=TokenBasedDocument - ) - self.tokenizer_name_or_path = tokenizer_name_or_path - self.tokenizer_kwargs = tokenizer_kwargs or {} - self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( - tokenizer_name_or_path, - **(tokenizer_init_kwargs or {}), - ) - self.annotation_field_mapping = annotation_field_mapping or dict() - annotation_field_mapping_inv = {v: k for k, v in self.annotation_field_mapping.items()} - if len(self.annotation_field_mapping) != len(annotation_field_mapping_inv): - raise ValueError( - f"inverted annotation_field_mapping is not unique. annotation_field_mapping: " - f"{self.annotation_field_mapping}" - ) - self.partition_layer_name = partition_layer_name - - # for this specific use case: end-to-end relation extraction - self.relation_layer_name = relation_layer_name - relation_layer_mapped = self.annotation_field_mapping.get( - relation_layer_name, relation_layer_name - ) - relation_layer_target = self.document_type.target_name(relation_layer_mapped) - self.span_layer_name = annotation_field_mapping_inv.get( - relation_layer_target, relation_layer_target - ) - self.add_reversed_relations = add_reversed_relations - self.symmetric_relations = set(symmetric_relations or []) - self.none_label = none_label - self.loop_dummy_relation_name = loop_dummy_relation_name - self.constrained_generation = constrained_generation - # will be set in _post_prepare() - self.relation_encoder_decoder: BinaryRelationEncoderDecoder - - # collected in prepare(), if not passed in - self.labels_per_layer = labels_per_layer - self.exclude_labels_per_layer = exclude_labels_per_layer or {} - - # how to encode and decode the annotations - self.bos_token = self.tokenizer.bos_token - self.eos_token = self.tokenizer.eos_token - self.label_tokens = label_tokens or dict() - self.label_representations = label_representations or dict() - - # target encoding - self.create_constraints = create_constraints - self.pad_values = { - "input_ids": self.tokenizer.pad_token_id, - "attention_mask": 0, - "labels": self.target_pad_id, - "decoder_attention_mask": 0, - "constraints": -1, - } - self.dtypes = { - "input_ids": torch.int64, - "attention_mask": torch.int64, - "labels": torch.int64, - "decoder_attention_mask": torch.int64, - "constraints": torch.int64, - } - - # logging - self.log_first_n_examples = log_first_n_examples - - @property - def document_type(self) -> Type[TextBasedDocument]: - return self._document_type - - @property - def tokenized_document_type(self) -> Type[TokenBasedDocument]: - return self._tokenized_document_type - - @property - def layer_names(self) -> List[str]: - return [self.span_layer_name, self.relation_layer_name] - - @property - def special_targets(self) -> list[str]: - return [self.bos_token, self.eos_token] - - @property - def special_target2id(self) -> Dict[str, int]: - return {target: idx for idx, target in enumerate(self.special_targets)} - - @property - def target_pad_id(self) -> int: - return self.special_target2id[self.eos_token] - - def configure_model_generation(self) -> Dict[str, Any]: - result: Dict[str, Any] = {"no_repeat_ngram_size": 7} - if self.constrained_generation: - logits_processor = LogitsProcessorList() - # PrefixConstrainedLogitsProcessorWithMaximum requires finite logits - logits_processor.append(FinitizeLogitsProcessor()) - logits_processor.append( - PrefixConstrainedLogitsProcessorWithMaximum( - prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn_with_maximum, - # use dummy value of 1, this is fine because num_beams affects only the value of batch_id - # which is not used in _prefix_allowed_tokens_fn_with_maximum() - num_beams=1, - ) - ) - result["logits_processor"] = logits_processor - return result - - def _prefix_allowed_tokens_fn_with_maximum( - self, batch_id: int, input_ids: torch.LongTensor, maximum: int - ) -> List[int]: - # remove the first token (bos_token) and use unbatch_output to un-pad the label_ids - label_ids_without_bos = input_ids[1:] - if len(label_ids_without_bos) > 0: - unpadded_label_ids = self.unbatch_output( - {"labels": label_ids_without_bos.unsqueeze(0)} - )[0].labels - else: - unpadded_label_ids = [] - _, _, remaining = self.relation_encoder_decoder.parse(encoding=unpadded_label_ids) - # this is a binary mask - constraint = self._build_constraint( - previous_ids=remaining, input_len=maximum - self.pointer_offset - ) - # convert to indices - allowed_indices = torch.nonzero(constraint).squeeze(1) - # convert to a list - return allowed_indices.tolist() - - def add_reversed_relation_labels(self, relation_labels: Iterable[str]) -> Set[str]: - result = set(relation_labels) - for rel_label in set(relation_labels): - if rel_label not in self.symmetric_relations: - reversed_label = rel_label + self.REVERSED_RELATION_LABEL_SUFFIX - if reversed_label in result: - raise ValueError( - f"reversed relation label {reversed_label} already exists in relation layer labels" - ) - result.add(reversed_label) - return result - - def _prepare(self, documents: Sequence[DocumentType]) -> None: - # collect all labels - labels: Dict[str, Set[str]] = {layer_name: set() for layer_name in self.layer_names} - for doc in documents: - for layer_name in self.layer_names: - exclude_labels = self.exclude_labels_per_layer.get(layer_name, []) - labels[layer_name].update( - ac.label for ac in doc[layer_name] if ac.label not in exclude_labels - ) - - if self.add_reversed_relations: - labels[self.relation_layer_name] = self.add_reversed_relation_labels( - relation_labels=labels[self.relation_layer_name] - ) - - self.labels_per_layer = { - # sort labels to ensure deterministic order - layer_name: sorted(labels) - for layer_name, labels in labels.items() - } - - def construct_label_token(self, label: str) -> str: - return self.label_tokens.get(label, f"<<{label}>>") - - def get_label_representation(self, label: str) -> str: - return self.label_representations.get(label, label) - - def _post_prepare(self) -> None: - # set up labels - if self.labels_per_layer is None: - raise Exception("labels_per_layer is not defined. Call prepare() first or pass it in.") - self.labels: List[str] = [self.none_label] - for layer_name in self.layer_names: - self.labels.extend(self.labels_per_layer[layer_name]) - if len(set(self.labels)) != len(self.labels): - raise Exception(f"labels are not unique: {self.labels}") - - # set up targets and ids - self.targets: List[str] = self.special_targets + self.labels - self.target2id: Dict[str, int] = {target: idx for idx, target in enumerate(self.targets)} - - # generic ids - self.eos_id: int = self.target2id[self.eos_token] - self.bos_id: int = self.target2id[self.bos_token] - - # span and relation ids - self.span_ids: List[int] = [ - self.target2id[label] for label in self.labels_per_layer[self.span_layer_name] - ] - self.relation_ids: List[int] = [ - self.target2id[label] for label in self.labels_per_layer[self.relation_layer_name] - ] - # the none id is used for the dummy relation which models out-of-relation spans - self.none_id: int = self.target2id[self.none_label] - - # helpers (same as targets / target2id, but only for labels) - self.label2id: Dict[str, int] = {label: self.target2id[label] for label in self.labels} - self.id2label: Dict[int, str] = {idx: label for label, idx in self.label2id.items()} - self.label_ids: List[int] = [self.label2id[label] for label in self.labels] - - # annotation-encoder-decoders - span_encoder_decoder = SpanEncoderDecoderWithOffset( - offset=self.pointer_offset, exclusive_end=False - ) - span_labels = self.labels_per_layer[self.span_layer_name] - labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=span_encoder_decoder, - # restrict label2id to get better error messages - label2id={label: idx for label, idx in self.label2id.items() if label in span_labels}, - mode="indices_label", - ) - relation_labels = self.labels_per_layer[self.relation_layer_name] + [ - self.loop_dummy_relation_name, - self.none_label, - ] - self.relation_encoder_decoder = BinaryRelationEncoderDecoder( - head_encoder_decoder=labeled_span_encoder_decoder, - tail_encoder_decoder=labeled_span_encoder_decoder, - # restrict label2id to get better error messages - label2id={ - label: idx for label, idx in self.label2id.items() if label in relation_labels - }, - loop_dummy_relation_name=self.loop_dummy_relation_name, - none_label=self.none_label, - mode="tail_head_label", - ) - - label2token = {label: self.construct_label_token(label=label) for label in self.labels} - if len(set(label2token.values())) != len(label2token): - raise Exception(f"label2token values are not unique: {label2token}") - - already_in_vocab = [ - tok - for tok in label2token.values() - if self.tokenizer.convert_tokens_to_ids(tok) != self.tokenizer.unk_token_id - ] - if len(already_in_vocab) > 0: - raise Exception( - f"some special tokens to add (mapped label ids) are already in the tokenizer vocabulary, " - f"this is not allowed: {already_in_vocab}. You may want to adjust the label2special_token mapping" - ) - # sort by length, so that longer tokens are added first - label_tokens_sorted = sorted(label2token.values(), key=lambda x: len(x), reverse=True) - self.tokenizer.add_special_tokens( - special_tokens_dict={"additional_special_tokens": label_tokens_sorted} - ) - - # target tokens are the special tokens plus the mapped label tokens - self.target_tokens: List[str] = self.special_targets + [ - label2token[label] for label in self.labels - ] - self.target_token_ids: List[int] = self.tokenizer.convert_tokens_to_ids(self.target_tokens) - - # construct a mapping from label_token_id to token_ids that will be used to initialize the embeddings - # of the labels - self.label_embedding_weight_mapping = dict() - for label, label_token in label2token.items(): - label_token_indices = self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(label_token) - ) - # sanity check: label_tokens should not be split up - if len(label_token_indices) > 1: - raise RuntimeError(f"{label_token} wrong split") - else: - label_token_idx = label_token_indices[0] - - label_representation = self.get_label_representation(label) - source_indices = self.tokenizer.convert_tokens_to_ids( - self.tokenizer.tokenize(label_representation) - ) - if self.tokenizer.unk_token_id in source_indices: - raise RuntimeError( - f"tokenized label_token={label_token} [{source_indices}] contains unk_token" - ) - self.label_embedding_weight_mapping[label_token_idx] = source_indices - - @property - def pointer_offset(self) -> int: - return len(self.targets) - - @property - def target_ids(self) -> Set[int]: - return set(range(self.pointer_offset)) - - def configure_model_metric(self, stage: Optional[str] = None) -> Optional[Metric]: - layer_metrics = { - layer_name: PrecisionRecallAndF1ForLabeledAnnotations() - for layer_name in self.layer_names - } - - return WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction( - unbatch_function=self.unbatch_output, - decode_layers_with_errors_function=self.decode_annotations, - layer_metrics=layer_metrics, - error_key_correct=KEY_INVALID_CORRECT, - ) - - def reverse_relation(self, relation: Annotation) -> BinaryRelation: - if isinstance(relation, BinaryRelation): - reversed_label = relation.label - if ( - reversed_label not in self.symmetric_relations - and reversed_label != self.none_label - ): - reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX - reversed_rel = relation.copy( - head=relation.tail, tail=relation.head, label=reversed_label - ) - return reversed_rel - else: - raise Exception(f"reversing of relations of type {type(relation)} is not supported") - - def unreverse_relation(self, relation: Annotation) -> BinaryRelation: - if isinstance(relation, BinaryRelation): - head, tail, label = relation.head, relation.tail, relation.label - # if the relation is symmetric, we sort head and tail to ensure consistent order - if relation.label in self.symmetric_relations: - head, tail = sorted([head, tail], key=lambda x: (x.start, x.end)) - # if the relation was reversed, we need to reconstruct the original label and swap head and tail - elif label.endswith(self.REVERSED_RELATION_LABEL_SUFFIX): - # reconstruct the original label and swap head and tail - label = label[: -len(self.REVERSED_RELATION_LABEL_SUFFIX)] - head, tail = tail, head - return relation.copy(head=head, tail=tail, label=label) - else: - raise Exception(f"un-reversing of relations of type {type(relation)} is not supported") - - def encode_annotations( - self, layers: Dict[str, Iterable[Annotation]], metadata: Optional[Dict[str, Any]] = None - ) -> TaskOutputType: - if not set(layers.keys()) == set(self.layer_names): - raise Exception(f"unexpected layers: {layers.keys()}. expected: {self.layer_names}") - - if self.labels_per_layer is None: - raise Exception("labels_per_layer is not defined. Call prepare() first or pass it in.") - - # encode relations - all_relation_arguments = set() - relation_arguments2label: Dict[Tuple[Annotation, ...], str] = dict() - relation_encodings = dict() - for rel in layers[self.relation_layer_name]: - if not isinstance(rel, BinaryRelation): - raise Exception(f"expected BinaryRelation, but got: {rel}") - if rel.label in self.labels_per_layer[self.relation_layer_name]: - if (rel.head, rel.tail) in relation_arguments2label: - previous_label = relation_arguments2label[(rel.head, rel.tail)] - if previous_label != rel.label: - raise ValueError( - f"relation {rel.head} -> {rel.tail} already exists, but has another label: " - f"{previous_label} (current label: {rel.label})." - ) - continue - encoded_relation = self.relation_encoder_decoder.encode( - annotation=rel, metadata=metadata - ) - if encoded_relation is None: - raise Exception(f"failed to encode relation: {rel}") - relation_encodings[rel] = encoded_relation - all_relation_arguments.update([rel.head, rel.tail]) - relation_arguments2label[(rel.head, rel.tail)] = rel.label - - # encode spans that are not arguments of any relation - no_relation_spans = [ - span for span in layers[self.span_layer_name] if span not in all_relation_arguments - ] - for span in no_relation_spans: - dummy_relation = BinaryRelation( - head=span, tail=span, label=self.loop_dummy_relation_name - ) - encoded_relation = self.relation_encoder_decoder.encode( - annotation=dummy_relation, metadata=metadata - ) - if encoded_relation is not None: - relation_encodings[dummy_relation] = encoded_relation - - # sort relations by start indices of head and tail # TODO: is this correct? - sorted_relations = sorted(relation_encodings, key=cmp_to_key(cmp_src_rel)) - - # this should never be accessed as it is, so use negative pointer offset to provoke an error - input_len = -self.pointer_offset - 1 - if self.create_constraints: - if metadata is None or "src_len" not in metadata: - raise Exception("metadata with 'src_len' is required to create constraints") - input_len = metadata["src_len"] - - # build target_ids - target_ids = [] - constraints_list = [] - for rel in sorted_relations: - encoded_relation = relation_encodings[rel] - target_ids.extend(encoded_relation) - - if self.create_constraints: - # iterate over all prefixes of the relation encoding - for idx, t in enumerate(encoded_relation): - # get the constraints for the current prefix - current_constraints = self._build_constraint( - previous_ids=encoded_relation[:idx], input_len=input_len - ) - # sanity check - if current_constraints[t] == 0: - raise Exception( - f"current_constraints[{t}] is 0, but should be 1: {current_constraints}" - ) - # add the constraints to the list - constraints_list.append(current_constraints) - - target_ids.append(self.eos_id) - - if self.create_constraints: - # add constraints for the eos_id - eos_constraint = torch.zeros(input_len + self.pointer_offset, dtype=torch.int64) - eos_constraint[self.eos_id] = 1 - constraints_list.append(eos_constraint) - # combine all constraints - constraints = torch.stack(constraints_list).tolist() - else: - constraints = None - - # sanity check - _, encoding_errors, remaining = self.relation_encoder_decoder.parse(encoding=target_ids) - if ( - not all(v == 0 for k, v in encoding_errors.items() if k != "correct") - or len(remaining) > 0 - ): - decoded, invalid = self.decode_annotations(LabelsAndOptionalConstraints(target_ids)) - not_encoded = {} - for layer_name in layers: - # convert to dicts to make them comparable (original annotations are attached which breaks comparison) - decoded_dicts = [ann.asdict() for ann in decoded[layer_name]] - # filter annotations and convert to str to make them json serializable - filtered = { - str(ann) for ann in layers[layer_name] if ann.asdict() not in decoded_dicts - } - if len(filtered) > 0: - not_encoded[layer_name] = list(filtered) - if len(not_encoded) > 0: - logger.warning( - f"encoding errors: {encoding_errors}, skipped annotations:\n" - f"{json.dumps(not_encoded, sort_keys=True, indent=2)}" - ) - elif len([tag for tag in remaining if tag != self.eos_id]) > 0: - logger.warning( - f"encoding errors: {encoding_errors}, remaining encoding ids: {remaining}" - ) - - return LabelsAndOptionalConstraints(labels=target_ids, constraints=constraints) - - def decode_annotations( - self, encoding: TaskOutputType - ) -> Tuple[Dict[str, Iterable[Annotation]], Dict[str, int]]: - decoded_relations, errors, remaining = self.relation_encoder_decoder.parse( - encoding=encoding.labels - ) - relation_tuples: List[Tuple[Annotation, Annotation, str]] = [] - entity_labels: Dict[Annotation, List[str]] = defaultdict(list) - for rel in decoded_relations: - head_dummy = rel.head.copy(label="dummy") - entity_labels[head_dummy].append(rel.head.label) - - if rel.label != self.loop_dummy_relation_name: - tail_dummy = rel.tail.copy(label="dummy") - entity_labels[tail_dummy].append(rel.tail.label) - relation_tuples.append((head_dummy, tail_dummy, rel.label)) - else: - assert rel.head == rel.tail - - # It may happen that some spans take part in multiple relations, but got generated with different labels. - # In this case, we just create one span and take the most common label. - entities: Dict[Annotation, Annotation] = {} - for entity_dummy, labels in entity_labels.items(): - c = Counter(labels) - # if len(c) > 1: - # logger.warning(f"multiple labels for span, take the most common: {dict(c)}") - most_common_label = c.most_common(1)[0][0] - entities[entity_dummy] = entity_dummy.copy(label=most_common_label) - - entity_layer = list(entities.values()) - relation_layer = [ - BinaryRelation(head=entities[head_dummy], tail=entities[tail_dummy], label=label) - for head_dummy, tail_dummy, label in relation_tuples - ] - return { - self.span_layer_name: entity_layer, - self.relation_layer_name: relation_layer, - }, errors - - def _build_constraint( - self, - previous_ids: List[int], - input_len: int, - ) -> torch.LongTensor: - """Build a constraint for the decoder. The constraint is a binary mask that indicates which - ids are allowed to be predicted in the next decoding step. The mask is of size input_len + - pointer_offset, where input_len is the length of the input sequence and pointer_offset is - the number of labels and special tokens. Uses the relation_encoder_decoder to build the - actual constraints. - - Args: - previous_ids: previously decoded ids - input_len: length of the input sequence - - Returns: - A binary mask of size input_len + pointer_offset, where 1 indicates that the id is - allowed to be predicted next, and 0 indicates that the id is not allowed to be predicted next. - """ - result: torch.LongTensor = torch.zeros(input_len + self.pointer_offset, dtype=torch.int64) - if self.eos_id in previous_ids: - # once eos is predicted, only allow padding - result[self.target_pad_id] = 1 - return result - - allowed_ids, disallowed_ids = self.relation_encoder_decoder.build_decoding_constraints( - partial_encoding=previous_ids - ) - if allowed_ids is not None and disallowed_ids is not None: - raise Exception( - f"allowed_ids and disallowed_ids are both not None: {allowed_ids}, {disallowed_ids}" - ) - elif allowed_ids is not None: - for allowed_id in allowed_ids: - result[allowed_id] = 1 - elif disallowed_ids is not None: - for id in range(len(result)): - if id not in disallowed_ids: - result[id] = 1 - else: - raise Exception( - f"allowed_ids and disallowed_ids are both None: {allowed_ids}, {disallowed_ids}" - ) - if len(previous_ids) == 0: - # if there are no previous ids, we also allow the eos_id - result[self.eos_id] = 1 - else: - # if there are previous ids, we don't allow the eos_id - result[self.eos_id] = 0 - # never allow the bos_id - result[self.bos_id] = 0 - - return result - - def maybe_log_example( - self, - task_encoding: TaskEncodingType, - targets: Optional[TargetEncodingType] = None, - ): - if self.log_first_n_examples is not None and self.log_first_n_examples > 0: - tokenized_doc_id = task_encoding.metadata["tokenized_document"].id - inputs = task_encoding.inputs - targets = targets or task_encoding.targets - input_tokens = self.tokenizer.convert_ids_to_tokens(inputs.input_ids) - label_tokens = [ - ( - self.targets[target_id_or_offset] - if target_id_or_offset < self.pointer_offset - else str(target_id_or_offset) - + " {" - + str(input_tokens[target_id_or_offset - self.pointer_offset]) - + "}" - ) - for target_id_or_offset in targets.labels - ] - logger.info("*** Example ***") - logger.info(f"doc.id: {tokenized_doc_id}") - logger.info(f"input_ids: {' '.join([str(i) for i in inputs.input_ids])}") - logger.info(f"input_tokens: {' '.join(input_tokens)}") - logger.info(f"label_ids: {' '.join([str(i) for i in targets.labels])}") - logger.info(f"label_tokens: {' '.join(label_tokens)}") - if self.create_constraints: - # only show the shape because the content is not very readable - logger.info( - f"constraints: {torch.tensor(targets.constraints).shape} (content is omitted)" - ) - self.log_first_n_examples -= 1 - - def tokenize_document(self, document: DocumentType) -> List[TokenBasedDocument]: - field_mapping = dict(self.annotation_field_mapping) - if self.partition_layer_name is not None: - field_mapping[self.partition_layer_name] = "labeled_partitions" - partition_layer = "labeled_partitions" - else: - partition_layer = None - casted_document = document.as_type(self.document_type, field_mapping=field_mapping) - tokenized_docs = tokenize_document( - casted_document, - tokenizer=self.tokenizer, - result_document_type=self.tokenized_document_type, - partition_layer=partition_layer, - **self.tokenizer_kwargs, - ) - for idx, tokenized_doc in enumerate(tokenized_docs): - tokenized_doc.id = f"{document.id}-tokenized-{idx+1}-of-{len(tokenized_docs)}" - - return tokenized_docs - - def encode_input( - self, document: DocumentType, is_training: bool = False - ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: - tokenized_docs = self.tokenize_document(document) - task_encodings: List[TaskEncodingType] = [] - for tokenized_doc in tokenized_docs: - tokenizer_encoding = tokenized_doc.metadata["tokenizer_encoding"] - task_encodings.append( - TaskEncoding( - document=document, - inputs=InputEncodingType( - input_ids=tokenizer_encoding.ids, - attention_mask=tokenizer_encoding.attention_mask, - ), - metadata={"tokenized_document": tokenized_doc}, - ) - ) - - return task_encodings - - def get_mapped_layer(self, document: Document, layer_name: str) -> AnnotationLayer: - if layer_name in self.annotation_field_mapping: - layer_name = self.annotation_field_mapping[layer_name] - return document[layer_name] - - def encode_target(self, task_encoding: TaskEncodingType) -> Optional[TargetEncodingType]: - try: - document = task_encoding.metadata["tokenized_document"] - - layers = { - layer_name: self.get_mapped_layer(document, layer_name=layer_name) - for layer_name in self.layer_names - } - - if self.add_reversed_relations: - # create a copy to avoid modifying the annotation layer in the document - relations = list(layers[self.relation_layer_name]) - reversed_relations = [self.reverse_relation(rel) for rel in relations] - layers[self.relation_layer_name] = relations + reversed_relations - - result = self.encode_annotations( - layers=layers, - metadata={ - **task_encoding.metadata, - "src_len": len(task_encoding.inputs.input_ids), - }, - ) - - self.maybe_log_example(task_encoding=task_encoding, targets=result) - return result - except Exception as e: - logger.error(f"failed to encode target, it will be skipped: {e}") - return None - - def collate(self, task_encodings: Sequence[TaskEncodingType]) -> TaskBatchEncoding: - if len(task_encodings) == 0: - raise ValueError("no task_encodings available") - inputs = InputEncodingType.batch( - values=[x.inputs for x in task_encodings], - dtypes=self.dtypes, - pad_values=self.pad_values, - ) - - targets = None - if task_encodings[0].has_targets: - targets = TargetEncodingType.batch( - values=[x.targets for x in task_encodings], - dtypes=self.dtypes, - pad_values=self.pad_values, - ) - - return inputs, targets - - def unbatch_output(self, model_output: ModelBatchOutput) -> Sequence[TaskOutputType]: - labels = model_output["labels"] - batch_size = labels.size(0) - - # We use the position after the first eos token as the seq_len. - # Note that, if eos_id is not in model_output for a given batch item, the result will be - # model_output.size(1) + 1 (i.e. seq_len + 1) for that batch item. This is fine, because we use the - # seq_lengths just to truncate the output and want to keep everything if eos_id is not present. - seq_lengths = get_first_occurrence_index(labels, self.eos_id) + 1 - - result = [ - LabelsAndOptionalConstraints(labels[i, : seq_lengths[i]].to(device="cpu").tolist()) - for i in range(batch_size) - ] - return result - - def create_annotations_from_output( - self, - task_encoding: TaskEncodingType, - task_output: TaskOutputType, - ) -> Iterator[Tuple[str, Annotation]]: - layers, errors = self.decode_annotations( - encoding=task_output, # metadata=task_encoding.metadata - ) - tokenized_document = task_encoding.metadata["tokenized_document"] - - # Note: token_based_document_to_text_based() does not yet consider predictions, so we need to clear - # the main annotations and attach the predictions to that - for layer_name, annotations in layers.items(): - layer = self.get_mapped_layer(tokenized_document, layer_name=layer_name) - layer.clear() - layer.extend(annotations) - - untokenized_document = token_based_document_to_text_based( - tokenized_document, result_document_type=self.document_type - ) - - for layer_name in layers: - annotations = self.get_mapped_layer(untokenized_document, layer_name=layer_name) - for annotation in annotations: - # handle relations that may be reversed - if layer_name == self.relation_layer_name and self.add_reversed_relations: - unreversed_relation = self.unreverse_relation(annotation) - yield layer_name, unreversed_relation - else: - yield layer_name, annotation.copy() diff --git a/src/pie_modules/taskmodules/re_span_pair_classification.py b/src/pie_modules/taskmodules/re_span_pair_classification.py deleted file mode 100644 index ad259bc3d..000000000 --- a/src/pie_modules/taskmodules/re_span_pair_classification.py +++ /dev/null @@ -1,829 +0,0 @@ -""" -workflow: - Document - -> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding - -> ModelBatchEncoding -> ModelBatchOutput - -> TaskOutput - -> Document -""" - -import logging -from collections import defaultdict -from copy import deepcopy -from typing import ( - Any, - Dict, - Iterable, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - TypedDict, - Union, -) - -import pandas as pd -import torch -from pie_core import ( - Annotation, - AnnotationLayer, - Document, - TaskEncoding, - TaskModule, -) -from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize -from tokenizers import AddedToken -from torch import BoolTensor, LongTensor, Tensor -from torch.nn.utils.rnn import pad_sequence -from torchmetrics import ClasswiseWrapper, F1Score, Metric, MetricCollection -from transformers import AutoTokenizer -from typing_extensions import TypeAlias - -from pie_modules.annotations import ( - BinaryRelation, - LabeledSpan, - MultiLabeledBinaryRelation, - NaryRelation, -) -from pie_modules.document.processing import ( - token_based_document_to_text_based, - tokenize_document, -) -from pie_modules.documents import ( - TextBasedDocument, - TextDocumentWithLabeledPartitions, - TextDocumentWithLabeledSpansAndBinaryRelations, - TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, - TokenDocumentWithLabeledSpansAndBinaryRelations, - TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, -) -from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction -from pie_modules.utils.span import distance as get_span_distance - -PAD_VALUES = { - "input_ids": 0, - "attention_mask": 0, - "span_start_indices": 0, - "span_end_indices": 0, - "tuple_indices": -1, - "labels": -100, - "tuple_indices_mask": False, -} -DTYPES = { - "input_ids": torch.long, - "attention_mask": torch.long, - "span_start_indices": torch.long, - "span_end_indices": torch.long, - "tuple_indices": torch.long, - "labels": torch.long, - "tuple_indices_mask": torch.bool, -} - - -class InputEncodingType(TypedDict, total=False): - # shape: (sequence_length,) - input_ids: LongTensor - # shape: (sequence_length,) - attention_mask: LongTensor - # shape: (num_entities,) - span_start_indices: LongTensor - # shape: (num_entities,) - span_end_indices: LongTensor - # list of lists of argument indices: [[head_idx, tail_idx], ...] - # NOTE: these indices point into span_start_indices and span_end_indices! - tuple_indices: LongTensor - tuple_indices_mask: BoolTensor - - -class TargetEncodingType(TypedDict, total=False): - # list of label indices: [label_idx, ...] - labels: LongTensor - - -DocumentType: TypeAlias = TextBasedDocument -TaskEncodingType: TypeAlias = TaskEncoding[ - DocumentType, - InputEncodingType, - TargetEncodingType, -] - - -class TaskOutputType(TypedDict, total=False): - labels: Sequence[str] - probabilities: Sequence[float] - - -class ModelInputType(TypedDict, total=False): - input_ids: LongTensor - attention_mask: LongTensor - span_start_indices: LongTensor - span_end_indices: LongTensor - tuple_indices: LongTensor - tuple_indices_mask: BoolTensor - - -class ModelTargetType(TypedDict, total=False): - labels: LongTensor - probabilities: LongTensor - - -TaskModuleType: TypeAlias = TaskModule[ - # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput - DocumentType, - InputEncodingType, - TargetEncodingType, - Tuple[ModelInputType, Optional[ModelTargetType]], - ModelTargetType, - TaskOutputType, -] - - -HEAD = "head" -TAIL = "tail" -START = "start" -END = "end" - - -logger = logging.getLogger(__name__) - - -def _get_label_ids_from_model_output( - model_output: ModelTargetType, -) -> LongTensor: - return model_output["labels"] - - -def get_relation_argument_spans_and_roles( - relation: Annotation, -) -> Tuple[Tuple[str, Annotation], ...]: - if isinstance(relation, BinaryRelation): - return (HEAD, relation.head), (TAIL, relation.tail) - elif isinstance(relation, NaryRelation): - # create unique order by sorting the arguments by their start and end positions and role - sorted_args = sorted( - zip(relation.roles, relation.arguments), - key=lambda role_and_span: ( - role_and_span[1].start, - role_and_span[1].end, - role_and_span[0], - ), - ) - return tuple(sorted_args) - else: - raise NotImplementedError( - f"the taskmodule does not yet support getting relation arguments for type: {type(relation)}" - ) - - -def construct_argument_marker(pos: str, label: Optional[str] = None, role: str = "SPAN") -> str: - if pos not in [START, END]: - raise ValueError(f"pos must be one of {START} or {END}, but got: {pos}") - start_or_end_marker = "" if pos == START else "/" - if label is not None: - return f"[{start_or_end_marker}{role}:{label}]" - else: - return f"[{start_or_end_marker}{role}]" - - -def inject_markers_into_text( - text: str, positions_and_markers: List[Tuple[int, str]] -) -> Tuple[str, Dict[int, int]]: - offset = 0 - original2new_pos = dict() - for original_pos, marker in sorted(positions_and_markers): - text = text[: original_pos + offset] + marker + text[original_pos + offset :] - offset += len(marker) - original2new_pos[original_pos] = original_pos + offset - return text, original2new_pos - - -def to_tensor(key: str, value: Any) -> Tensor: - return torch.tensor(value, dtype=DTYPES[key]) - - -def pad_or_stack(key: str, values: List[LongTensor]) -> Tensor: - if key in PAD_VALUES: - max_last_dim = None - if key == "tuple_indices": - max_last_dim = max(v.shape[-1] for v in values if len(v.shape) == 2) - values = [v.reshape(-1) for v in values] - result = pad_sequence(values, batch_first=True, padding_value=PAD_VALUES[key]) - if key == "tuple_indices": - batch_size = len(values) - result = result.reshape(batch_size, -1, max_last_dim) - else: - result = torch.stack(values, dim=0) - return result - - -@TaskModule.register() -class RESpanPairClassificationTaskModule(TaskModuleType, ChangesTokenizerVocabSize): - """Task module for relation extraction as span pair classification. - - This task module frames relation extraction as a span pair classification task where all candidate - pairs in a given text are classified at once. The task module injects start and end markers for - each entity (i.e. "[SPAN]" and "[/SPAN]") into the text and tokenizes the text (the markers are - handled as special tokens, and thus, kept as they are). It then collects the start- and end-marker - positions for each entity and constructs a model input encoding from the tokenized text and these - positions. The model target encoding consists of a list of label indices and a list of tuples - (head and tail) of argument indices that point into the start- and end-marker positions from the - model inputs. The model output is expected to be of the same format as the model target encoding, - but with probabilities for each label. - - This means, that the model should return only positive relations (argument indices + label) and - discard all negative ones. - - Args: - tokenizer_name_or_path: The name or path of the tokenizer to use. - relation_annotation: The name of the annotation layer that contains the binary relations. - partition_annotation: The name of the annotation layer that contains the labeled partitions. - If provided, the task module expects the document to have a partition layer with the - given name containing LabeledSpans. These entries are used to split the text into - partitions, e.g. paragraphs or sentences, that are treated as separate documents during - tokenization. Defaults to None. - tokenize_kwargs: Additional keyword arguments passed to the tokenizer during tokenization. - create_candidate_relations: Whether to create candidate relations for training. If True, the - task module creates all possible pairs of entities in the text as candidate relations. - Defaults to False. - create_candidate_relations_kwargs: Additional keyword arguments passed to the method that - creates the candidate relations (e.g. max_argument_distance). Defaults to None. - labels: The list of relation labels. If not provided, the task module will collect the labels - from the documents during preparation. Defaults to None. - entity_labels: The list of entity labels. If not provided, the task module will collect the - entity labels from the documents during preparation. Defaults to None. - add_type_to_marker: Whether to add the entity type to the markers. If True, the markers will - look like this: "[SPAN:entity_type]" and "[/SPAN:entity_type]" where entity_type is the - type of the respective entity. Defaults to False. - log_first_n_examples: The number of examples to log during training. If 0, no examples are logged. - Defaults to 0. - collect_statistics: Whether to collect statistics during preparation. If True, the task module - will collect statistics about the available, used, and skipped relations. Defaults to False. - """ - - PREPARED_ATTRIBUTES = ["labels", "entity_labels"] - - def __init__( - self, - tokenizer_name_or_path: str, - relation_annotation: str = "binary_relations", - no_relation_label: str = "no_relation", - partition_annotation: Optional[str] = None, - tokenize_kwargs: Optional[Dict[str, Any]] = None, - create_candidate_relations: bool = False, - create_candidate_relations_kwargs: Optional[Dict[str, Any]] = None, - labels: Optional[List[str]] = None, - entity_labels: Optional[List[str]] = None, - add_type_to_marker: bool = True, - log_first_n_examples: int = 0, - collect_statistics: bool = False, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.save_hyperparameters() - - self.relation_annotation = relation_annotation - self.no_relation_label = no_relation_label - self.tokenize_kwargs = tokenize_kwargs or {} - self.create_candidate_relations = create_candidate_relations - self.create_candidate_relations_kwargs = create_candidate_relations_kwargs or {} - self.labels = labels - self.add_type_to_marker = add_type_to_marker - self.entity_labels = entity_labels - self.partition_annotation = partition_annotation - # overwrite None with 0 for backward compatibility - self.log_first_n_examples = log_first_n_examples or 0 - self.collect_statistics = collect_statistics - - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - - self.argument_markers = None - - self._logged_examples_counter = 0 - - self.reset_statistics() - - @property - def document_type(self) -> Optional[Type[DocumentType]]: - if self.partition_annotation is not None: - dt = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions - else: - dt = TextDocumentWithLabeledSpansAndBinaryRelations - if self.relation_annotation == "binary_relations": - return dt - else: - logger.warning( - f"relation_annotation={self.relation_annotation} is " - f"not the default value ('binary_relations'), so the taskmodule {type(self).__name__} can not request " - f"the usual document type for auto-conversion ({dt.__name__}) because this has the bespoken default " - f"value as layer name instead of the provided one." - ) - return None - - @property - def tokenized_document_type(self) -> Type[TokenDocumentWithLabeledSpansAndBinaryRelations]: - if self.partition_annotation is None: - return TokenDocumentWithLabeledSpansAndBinaryRelations - else: - return TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions - - @property - def normalized_document_type(self) -> Type[TextDocumentWithLabeledSpansAndBinaryRelations]: - if self.partition_annotation is None: - return TextDocumentWithLabeledSpansAndBinaryRelations - else: - return TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions - - def normalize_document(self, document) -> TextDocumentWithLabeledSpansAndBinaryRelations: - span_layer_name = self.get_span_layer_name(document) - field_mapping = { - span_layer_name: "labeled_spans", - self.relation_annotation: "binary_relations", - } - if self.partition_annotation is not None: - field_mapping[self.partition_annotation] = "labeled_partitions" - casted_document = document.as_type( - self.normalized_document_type, field_mapping=field_mapping - ) - return casted_document - - def get_relation_layer(self, document: Document) -> AnnotationLayer[BinaryRelation]: - return document[self.relation_annotation] - - def get_span_layer_name(self, document: Document) -> str: - return document[self.relation_annotation].target_name - - def get_entity_layer(self, document: Document) -> AnnotationLayer[LabeledSpan]: - relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document) - return relations.target_layer - - def _prepare(self, documents: Sequence[DocumentType]) -> None: - entity_labels: Set[str] = set() - relation_labels: Set[str] = set() - for document in documents: - relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document) - entities: AnnotationLayer[LabeledSpan] = self.get_entity_layer(document) - - for entity in entities: - entity_labels.add(entity.label) - - for relation in relations: - relation_labels.add(relation.label) - - if self.no_relation_label in relation_labels: - relation_labels.remove(self.no_relation_label) - - self.labels = sorted(relation_labels) - self.entity_labels = sorted(entity_labels) - - def reset_statistics(self): - self._statistics = defaultdict(int) - self._collected_relations: Dict[str, List[Annotation]] = defaultdict(list) - - def collect_relation(self, kind: str, relation: Annotation): - if self.collect_statistics: - self._collected_relations[kind].append(relation) - - def collect_all_relations(self, kind: str, relations: Iterable[Annotation]): - if self.collect_statistics: - self._collected_relations[kind].extend(relations) - - def finalize_statistics(self): - if self.collect_statistics: - all_relations = set(self._collected_relations["available_tokenized"]) - used_relations = set(self._collected_relations["used"]) - skipped_other = all_relations - used_relations - for key, rels in self._collected_relations.items(): - rels_set = set(rels) - if key.startswith("skipped_"): - skipped_other -= rels_set - elif key.startswith("used_"): - pass - elif key in ["available", "available_tokenized", "used"]: - pass - else: - raise ValueError(f"unknown key: {key}") - for rel in rels_set: - self.increase_counter(key=(key, rel.label)) - for rel in skipped_other: - self.increase_counter(key=("skipped_other", rel.label)) - - def show_statistics(self): - if self.collect_statistics: - self.finalize_statistics() - - to_show = pd.Series(self._statistics) - if len(to_show.index.names) > 1: - to_show = to_show.unstack() - logger.info(f"statistics:\n{to_show.to_markdown()}") - - def increase_counter(self, key: Tuple[Any, ...], value: Optional[int] = 1): - if self.collect_statistics: - key_str = tuple(str(k) for k in key) - self._statistics[key_str] += value - - def encode(self, *args, **kwargs): - self.reset_statistics() - res = super().encode(*args, **kwargs) - self.show_statistics() - return res - - def collect_argument_markers(self, entity_labels: Iterable[str]) -> List[str]: - argument_markers: Set[str] = set() - for arg_pos in [START, END]: - if self.add_type_to_marker: - for entity_label in entity_labels: - argument_markers.add( - construct_argument_marker(pos=arg_pos, label=entity_label) - ) - else: - argument_markers.add(construct_argument_marker(pos=arg_pos)) - - return sorted(list(argument_markers)) - - def _post_prepare(self): - self.label_to_id = {label: i + 1 for i, label in enumerate(self.labels)} - self.label_to_id[self.no_relation_label] = 0 - self.id_to_label = {v: k for k, v in self.label_to_id.items()} - - self.argument_markers = self.collect_argument_markers(self.entity_labels) - num_added = self.tokenizer.add_special_tokens( - {"additional_special_tokens": self.argument_markers} - ) - if len(self.argument_markers) != num_added: - logger.warning( - f"expected to add {len(self.argument_markers)} argument markers, but added {num_added}. It seems " - f"that the tokenizer already contains some of the argument markers." - ) - - self.argument_markers_to_id = { - marker: self.tokenizer.vocab[marker] for marker in self.argument_markers - } - - def _create_candidate_relations( - self, - document: TokenDocumentWithLabeledSpansAndBinaryRelations, - max_argument_distance: Optional[int] = None, - argument_distance_type: str = "inner", - ) -> Sequence[Annotation]: - # TODO: ensure that the relation layer type is BinaryRelation! - labeled_spans = document.labeled_spans - candidate_relations = [] - for i, head in enumerate(labeled_spans): - for j, tail in enumerate(labeled_spans): - if i == j: - continue - rel = BinaryRelation(head=head, tail=tail, label=self.no_relation_label) - if max_argument_distance is not None: - arg_distance = get_span_distance( - start_end=(head.start, head.end), - other_start_end=(tail.start, tail.end), - distance_type=argument_distance_type, - ) - if arg_distance > max_argument_distance: - self.collect_relation("skipped_argument_distance", rel) - continue - candidate_relations.append(rel) - return candidate_relations - - def inject_markers_for_labeled_spans( - self, - document: TextDocumentWithLabeledSpansAndBinaryRelations, - ) -> Tuple[TextDocumentWithLabeledSpansAndBinaryRelations, Dict[LabeledSpan, LabeledSpan]]: - # collect markers and injection positions - positions_and_markers = [] - for labeled_span in document.labeled_spans: - label_or_none = labeled_span.label if self.add_type_to_marker else None - start_marker = construct_argument_marker(pos=START, label=label_or_none) - positions_and_markers.append((labeled_span.start, start_marker)) - end_marker = construct_argument_marker(pos=END, label=label_or_none) - positions_and_markers.append((labeled_span.end, end_marker)) - - if isinstance(document, TextDocumentWithLabeledPartitions): - # create "dummy" markers for the partitions so that entries for these positions are created - # in original2new_pos - for labeled_partition in document.labeled_partitions: - positions_and_markers.append((labeled_partition.start, "")) - positions_and_markers.append((labeled_partition.end, "")) - - # inject markers into the text - marked_text, original2new_pos = inject_markers_into_text( - document.text, positions_and_markers - ) - - # construct new spans - old2new_spans = dict() - for labeled_span in document.labeled_spans: - start = original2new_pos[labeled_span.start] - end = original2new_pos[labeled_span.end] - new_span = LabeledSpan(start=start, end=end, label=labeled_span.label) - old2new_spans[labeled_span] = new_span - - # construct new relations - old2new_relations = dict() - for relation in document.binary_relations: - if isinstance(relation, BinaryRelation): - head = old2new_spans[relation.head] - tail = old2new_spans[relation.tail] - new_relation = BinaryRelation(head=head, tail=tail, label=relation.label) - else: - raise NotImplementedError( - f"the taskmodule does not yet support relations of type {type(relation)}" - ) - old2new_relations[relation] = new_relation - - # construct new document - new_document = type(document)( - id=document.id, - metadata=deepcopy(document.metadata), - text=marked_text, - ) - new_document.labeled_spans.extend(old2new_spans.values()) - new_document.binary_relations.extend(old2new_relations.values()) - if isinstance(document, TextDocumentWithLabeledPartitions): - for labeled_partition in document.labeled_partitions: - new_start = original2new_pos[labeled_partition.start] - new_end = original2new_pos[labeled_partition.end] - new_labeled_partitions = labeled_partition.copy(start=new_start, end=new_end) - new_document.labeled_partitions.append(new_labeled_partitions) - - new2old_spans = {new_span: old_span for old_span, new_span in old2new_spans.items()} - return new_document, new2old_spans - - def encode_input( - self, - document: DocumentType, - is_training: bool = False, - ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: - self.collect_all_relations("available", self.get_relation_layer(document)) - - # 1. inject start and end markers for each entity into the text - # - save mapping from new entities to original entities - # 2. tokenize the text - # - add the marker tokens to the tokenizer as special tokens - # - tokenize with tokenize_document() - # 3. get start- and end-token positions for each entity - # 4. construct task encoding from tokenized text and entity positions - - normalized_document = self.normalize_document(document) - document_with_markers, injected2original_spans = self.inject_markers_for_labeled_spans( - normalized_document - ) - all_added_annotations: List[Dict[str, Dict[Annotation, Annotation]]] = [] - tokenized_docs = tokenize_document( - document_with_markers, - tokenizer=self.tokenizer, - result_document_type=self.tokenized_document_type, - partition_layer=( - "labeled_partitions" if self.partition_annotation is not None else None - ), - added_annotations=all_added_annotations, - strict_span_conversion=False, - **self.tokenize_kwargs, - ) - - task_encodings: List[TaskEncodingType] = [] - for tokenized_doc, tokenized_annotations in zip(tokenized_docs, all_added_annotations): - self.collect_all_relations("available_tokenized", tokenized_doc.binary_relations) - # collect start- and end-token positions for each entity - span_start_indices = [] - span_end_indices = [] - for labeled_span in tokenized_doc.labeled_spans: - # the start marker is one token before the start of the span - span_start_indices.append(labeled_span.start - 1) - # the end marker is one token after the end of the span, but the end index is exclusive - span_end_indices.append(labeled_span.end) - - labeled_span2idx = {span: idx for idx, span in enumerate(tokenized_doc.labeled_spans)} - tuple_indices = [] # list of lists of argument indices: [[head_idx, tail_idx], ...] - if self.create_candidate_relations: - candidate_relations = self._create_candidate_relations( - tokenized_doc, **self.create_candidate_relations_kwargs - ) - else: - candidate_relations = tokenized_doc.binary_relations - - # if there are no candidate relations, skip the whole (tokenized) document - if len(candidate_relations) == 0: - continue - - for relation in candidate_relations: - current_args_indices = [] - for _, arg_span in get_relation_argument_spans_and_roles(relation): - arg_idx = labeled_span2idx[arg_span] - current_args_indices.append(arg_idx) - tuple_indices.append(current_args_indices) - - encoding = tokenized_doc.metadata["tokenizer_encoding"] - inputs = { - "input_ids": encoding.ids, - "attention_mask": encoding.attention_mask, - "span_start_indices": span_start_indices, - "span_end_indices": span_end_indices, - "tuple_indices": tuple_indices, - "tuple_indices_mask": [True] * len(tuple_indices), - } - inputs_tensors = {k: to_tensor(k, v) for k, v in inputs.items()} - task_encodings.append( - TaskEncoding( - document=document, - inputs=inputs_tensors, - metadata={ - "tokenized_document": tokenized_doc, - "injected2original_spans": injected2original_spans, - "candidate_relations": candidate_relations, - "tokenized_annotations": tokenized_annotations, - }, - ) - ) - - return task_encodings - - def encode_target( - self, - task_encoding: TaskEncodingType, - ) -> TargetEncodingType: - gold_relations = task_encoding.metadata["tokenized_document"].binary_relations - gold_roles_and_args2relation = defaultdict(list) - for relation in gold_relations: - # If we manually set the labels, we only consider relations with a label in the label_to_id mapping - # This allows us to ignore relations with certain labels during training. - if relation.label in self.label_to_id: - gold_roles_and_args2relation[ - get_relation_argument_spans_and_roles(relation) - ].append(relation) - label_indices = [] # list of label indices - candidate_relations = [] - for candidate_relation in task_encoding.metadata["candidate_relations"]: - candidate_roles_and_args = get_relation_argument_spans_and_roles(candidate_relation) - gold_relations = gold_roles_and_args2relation.get(candidate_roles_and_args, []) - if len(gold_relations) == 0: - label_idx = self.label_to_id[candidate_relation.label] - self.collect_relation("used", candidate_relation) - elif len(gold_relations) == 1: - label_idx = self.label_to_id[gold_relations[0].label] - self.collect_relation("used", gold_relations[0]) - else: - # TODO: or should we add all gold relations with the same arguments? - logger.warning( - f"skip the candidate relation because there are more than one gold relation " - f"for its args and roles: {gold_relations}" - ) - for gold_relation in gold_relations: - self.collect_relation("skipped_same_arguments", gold_relation) - label_idx = PAD_VALUES["labels"] - - label_indices.append(label_idx) - candidate_relations.append(candidate_relation) - - task_encoding.metadata["candidate_relations"] = candidate_relations - target: TargetEncodingType = {"labels": to_tensor("labels", label_indices)} - - self._maybe_log_example(task_encoding=task_encoding, target=target) - - return target - - def _maybe_log_example( - self, - task_encoding: TaskEncodingType, - target: TargetEncodingType, - ): - """Maybe log the example.""" - - # log the first n examples - if self._logged_examples_counter < self.log_first_n_examples: - input_ids = task_encoding.inputs["input_ids"] - tokens = self.tokenizer.convert_ids_to_tokens(input_ids) - logger.info("*** Example ***") - logger.info(f"doc id: {task_encoding.document.id}") - logger.info(f"tokens: {' '.join([x for x in tokens])}") - logger.info(f"input_ids: {' '.join([str(x) for x in input_ids.tolist()])}") - # target data - span_start_indices = task_encoding.inputs["span_start_indices"] - span_end_indices = task_encoding.inputs["span_end_indices"] - labels = [self.id_to_label[label] for label in target["labels"].tolist()] - for i, (label, tuple_indices) in enumerate( - zip(labels, task_encoding.inputs["tuple_indices"]) - ): - logger.info(f"relation {i}: {label}") - for j, arg_idx in enumerate(tuple_indices): - arg_tokens = tokens[span_start_indices[arg_idx] : span_end_indices[arg_idx]] - logger.info(f"\targ {j}: {' '.join([str(x) for x in arg_tokens])}") - - self._logged_examples_counter += 1 - - def collate( - self, task_encodings: Sequence[TaskEncodingType] - ) -> Tuple[ModelInputType, Optional[ModelTargetType]]: - input_keys = task_encodings[0].inputs.keys() - inputs: ModelInputType = { # type: ignore - key: pad_or_stack(key, [task_encoding.inputs[key] for task_encoding in task_encodings]) - for key in input_keys - } - - targets: Optional[ModelTargetType] = None - if task_encodings[0].has_targets: - target_keys = task_encodings[0].targets.keys() - targets: ModelTargetType = { # type: ignore - key: pad_or_stack( - key, [task_encoding.targets[key] for task_encoding in task_encodings] - ) - for key in target_keys - } - - return inputs, targets - - def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]: - # shape: (batch_size, num_candidates) - label_ids = model_output["labels"].detach().cpu().tolist() - # shape: (batch_size, num_candidates, num_labels) - all_probabilities = model_output["probabilities"].detach().cpu().tolist() - unbatched_output = [] - for batch_idx in range(len(label_ids)): - labels = [] - probabilities = [] - for label_id, probs in zip(label_ids[batch_idx], all_probabilities[batch_idx]): - labels.append(self.id_to_label[label_id]) - probabilities.append(probs[label_id]) - entry: TaskOutputType = { - "labels": labels, - "probabilities": probabilities, - } - unbatched_output.append(entry) - - return unbatched_output - - def decode_annotations( - self, - task_output: TaskOutputType, - task_encoding: TaskEncodingType, - ) -> Dict[str, List[Annotation]]: - char2token_spans = task_encoding.metadata["tokenized_annotations"]["labeled_spans"] - token2char_spans = {v: k for k, v in char2token_spans.items()} - injected2original_spans = task_encoding.metadata["injected2original_spans"] - new_relations = [] - for candidate_relation, label, probability, is_valid in zip( - task_encoding.metadata["candidate_relations"], - task_output["labels"], - task_output["probabilities"], - task_encoding.inputs["tuple_indices_mask"], - ): - # exclude - # - padding entries (is_valid=False) - # - negative relations (if we have added them) - if is_valid and ( - label != self.no_relation_label or not self.create_candidate_relations - ): - token_head, token_tail = candidate_relation.head, candidate_relation.tail - char_head = token2char_spans[token_head] - char_tail = token2char_spans[token_tail] - original_head = injected2original_spans[char_head] - original_tail = injected2original_spans[char_tail] - new_annotation = candidate_relation.copy( - head=original_head, tail=original_tail, label=label, score=probability - ) - new_relations.append(new_annotation) - - return {"binary_relations": new_relations} - - def create_annotations_from_output( - self, - task_encoding: TaskEncodingType, - task_output: TaskOutputType, - ) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation, NaryRelation]]]: - decoded_annotations = self.decode_annotations( - task_output=task_output, task_encoding=task_encoding - ) - - for relation in decoded_annotations["binary_relations"]: - yield self.relation_annotation, relation - - def configure_model_metric(self, stage: str) -> Metric: - if self.label_to_id is None: - raise ValueError( - "The taskmodule has not been prepared yet, so label_to_id is not known. " - "Please call taskmodule.prepare(documents) before configuring the model metric " - "or pass the labels to the taskmodule constructor an call taskmodule.post_prepare()." - ) - labels = [self.id_to_label[i] for i in range(len(self.label_to_id))] - common_metric_kwargs = { - "num_classes": len(labels), - "task": "multiclass", - "ignore_index": PAD_VALUES["labels"], - } - return WrappedMetricWithPrepareFunction( - metric=MetricCollection( - { - "micro/f1": F1Score(average="micro", **common_metric_kwargs), - "macro/f1": F1Score(average="macro", **common_metric_kwargs), - "f1_per_label": ClasswiseWrapper( - F1Score(average=None, **common_metric_kwargs), - labels=labels, - postfix="/f1", - ), - } - ), - prepare_function=_get_label_ids_from_model_output, - ) diff --git a/src/pie_modules/taskmodules/re_text_classification_with_indices.py b/src/pie_modules/taskmodules/re_text_classification_with_indices.py deleted file mode 100644 index 61fa41f25..000000000 --- a/src/pie_modules/taskmodules/re_text_classification_with_indices.py +++ /dev/null @@ -1,1508 +0,0 @@ -""" -workflow: - Document - -> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding - -> ModelBatchEncoding -> ModelBatchOutput - -> TaskOutput - -> Document -""" - -import logging -from collections import defaultdict -from functools import partial -from typing import ( - Any, - Dict, - Iterable, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - TypedDict, - Union, -) - -import numpy as np -import torch -from pie_core import ( - Annotation, - AnnotationLayer, - Document, - TaskEncoding, - TaskModule, -) -from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize -from pytorch_ie.utils.window import get_window_around_slice -from torch import LongTensor -from torchmetrics import ClasswiseWrapper, F1Score, MetricCollection -from transformers import AutoTokenizer -from transformers.file_utils import PaddingStrategy -from transformers.tokenization_utils_base import TruncationStrategy -from typing_extensions import TypeAlias, TypeVar - -from pie_modules.annotations import ( - BinaryRelation, - LabeledSpan, - MultiLabeledBinaryRelation, - NaryRelation, - Span, -) -from pie_modules.documents import ( - TextBasedDocument, - TextDocumentWithLabeledSpansAndBinaryRelations, - TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, -) -from pie_modules.models.simple_sequence_classification import ( - InputType as ModelInputType, -) -from pie_modules.models.simple_sequence_classification import ( - TargetType as ModelTargetType, -) -from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin -from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction -from pie_modules.utils.span import distance as span_distance -from pie_modules.utils.span import is_contained_in -from pie_modules.utils.tokenization import ( - SpanNotAlignedWithTokenException, - get_aligned_token_span, -) - -InputEncodingType: TypeAlias = Dict[str, Any] -TargetEncodingType: TypeAlias = Sequence[int] -DocumentType: TypeAlias = TextBasedDocument - -TaskEncodingType: TypeAlias = TaskEncoding[ - DocumentType, - InputEncodingType, - TargetEncodingType, -] - - -class TaskOutputType(TypedDict, total=False): - labels: Sequence[str] - probabilities: Sequence[float] - - -TaskModuleType: TypeAlias = TaskModule[ - # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput - DocumentType, - InputEncodingType, - TargetEncodingType, - Tuple[ModelInputType, Optional[ModelTargetType]], - ModelTargetType, - TaskOutputType, -] - - -HEAD = "head" -TAIL = "tail" -START = "start" -END = "end" - - -logger = logging.getLogger(__name__) - - -def _get_labels(model_output: ModelTargetType) -> LongTensor: - return model_output["labels"] - - -def _get_labels_together_remove_none_label( - predictions: ModelTargetType, targets: ModelTargetType, none_idx: int -) -> Tuple[LongTensor, LongTensor]: - mask_not_both_none = (predictions["labels"] != none_idx) | (targets["labels"] != none_idx) - predictions_not_none = predictions["labels"][mask_not_both_none] - targets_not_none = targets["labels"][mask_not_both_none] - return predictions_not_none, targets_not_none - - -def find_sublist(sub: List, bigger: List) -> int: - if not bigger: - return -1 - if not sub: - return 0 - first, rest = sub[0], sub[1:] - pos = 0 - try: - while True: - pos = bigger.index(first, pos) + 1 - if not rest or bigger[pos : pos + len(rest)] == rest: - return pos - 1 - except ValueError: - return -1 - - -class MarkerFactory: - def __init__(self, role_to_marker: Dict[str, str]): - self.role_to_marker = role_to_marker - - def _get_role_marker(self, role: str) -> str: - return self.role_to_marker[role] - - def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str: - result = "[" - if not is_start: - result += "/" - result += self._get_role_marker(role) - if label is not None: - result += f":{label}" - result += "]" - return result - - def get_start_marker(self, role: str, label: Optional[str] = None) -> str: - return self._get_marker(role=role, is_start=True, label=label) - - def get_end_marker(self, role: str, label: Optional[str] = None) -> str: - return self._get_marker(role=role, is_start=False, label=label) - - def get_append_marker(self, role: str, label: Optional[str] = None) -> str: - role_marker = self._get_role_marker(role) - if label is None: - return f"[{role_marker}]" - else: - return f"[{role_marker}={label}]" - - @property - def all_roles(self) -> Set[str]: - return set(self.role_to_marker) - - def get_all_markers( - self, - entity_labels: List[str], - append_markers: bool = False, - add_type_to_marker: bool = False, - ) -> List[str]: - result: Set[str] = set() - if add_type_to_marker: - none_and_labels = [None] + entity_labels - else: - none_and_labels = [None] - for role in self.all_roles: - # create start and end markers without label and for all labels, if add_type_to_marker - for maybe_label in none_and_labels: - result.add(self.get_start_marker(role=role, label=maybe_label)) - result.add(self.get_end_marker(role=role, label=maybe_label)) - # create append markers for all labels - if append_markers: - for entity_label in entity_labels: - result.add(self.get_append_marker(role=role, label=entity_label)) - - # sort and convert to list - return sorted(result) - - -class RelationArgument: - def __init__( - self, - entity: LabeledSpan, - role: str, - token_span: Span, - add_type_to_marker: bool, - marker_factory: MarkerFactory, - ) -> None: - self.marker_factory = marker_factory - if role not in self.marker_factory.all_roles: - raise ValueError( - f"role='{role}' not in known roles={sorted(self.marker_factory.all_roles)} (did you " - f"initialise the taskmodule with the correct argument_role_to_marker dictionary?)" - ) - - self.entity = entity - - self.role = role - self.token_span = token_span - self.add_type_to_marker = add_type_to_marker - - @property - def maybe_label(self) -> Optional[str]: - return self.entity.label if self.add_type_to_marker else None - - @property - def as_start_marker(self) -> str: - return self.marker_factory.get_start_marker(role=self.role, label=self.maybe_label) - - @property - def as_end_marker(self) -> str: - return self.marker_factory.get_end_marker(role=self.role, label=self.maybe_label) - - @property - def as_append_marker(self) -> str: - # Note: we add the label in either case (we use self.entity.label instead of self.label) - return self.marker_factory.get_append_marker(role=self.role, label=self.entity.label) - - def shift_token_span(self, value: int): - self.token_span = Span( - start=self.token_span.start + value, end=self.token_span.end + value - ) - - -def get_relation_argument_spans_and_roles( - relation: Annotation, -) -> Tuple[Tuple[str, Annotation], ...]: - if isinstance(relation, BinaryRelation): - return (HEAD, relation.head), (TAIL, relation.tail) - elif isinstance(relation, NaryRelation): - # create unique order by sorting the arguments by their start and end positions and role - sorted_args = sorted( - zip(relation.roles, relation.arguments), - key=lambda role_and_span: ( - role_and_span[1].start, - role_and_span[1].end, - role_and_span[0], - ), - ) - return tuple(sorted_args) - else: - raise NotImplementedError( - f"the taskmodule does not yet support getting relation arguments for type: {type(relation)}" - ) - - -def construct_mask(input_ids: torch.LongTensor, positive_ids: List[Any]) -> torch.LongTensor: - """Construct a mask for the input_ids where all entries in mask_ids are 1.""" - masks = [torch.nonzero(input_ids == marker_token_id) for marker_token_id in positive_ids] - globs = torch.cat(masks) - value = torch.ones(globs.shape[0], dtype=int) - mask = torch.zeros(input_ids.shape, dtype=int) - mask.index_put_(tuple(globs.t()), value) - return mask - - -S = TypeVar("S", bound=Span) - - -def shift_span(span: S, offset: int) -> S: - return span.copy(start=span.start + offset, end=span.end + offset) - - -def bio_encode_spans( - spans: List[Tuple[int, int, str]], total_length: int, label2idx: Dict[str, int] -) -> List[int]: - # result = ["O"] * total_length - result = [0] * total_length - for start, end, label in spans: - # result[start] = f"B-{label}" - result[start] = label2idx[label] * 2 + 1 - for i in range(start + 1, end): - # result[i] = f"I-{label}" - result[i] = label2idx[label] * 2 + 2 - return result - - -@TaskModule.register() -class RETextClassificationWithIndicesTaskModule( - RelationStatisticsMixin, - TaskModuleType, - ChangesTokenizerVocabSize, -): - """Marker based relation extraction. This taskmodule prepares the input token ids in such a way - that before and after the candidate head and tail entities special marker tokens are inserted. - Then, the modified token ids can be simply passed into a transformer based text classifier - model. - - parameters: - - partition_annotation: str, optional. If specified, LabeledSpan annotations with this name are - expected to define partitions of the document that will be processed individually, e.g. sentences - or sections of the document text. - none_label: str, defaults to "no_relation". The relation label that indicate dummy/negative relations. - Predicted relations with that label will not be added to the document(s). - max_window: int, optional. If specified, use the tokens in a window of maximal this amount of tokens - around the center of head and tail entities and pass only that into the transformer. - create_relation_candidates: bool, defaults to False. If True, create relation candidates by pairwise - combining all entities in the document and assigning the none_label. If the document already contains - a relation with the entity pair, we do not add it again. If False, assume that the document already - contains relation annotations including negative examples (i.e. relations with the none_label). - handle_relations_with_same_arguments: str, defaults to "keep_none". If "keep_none", all relations that - share same arguments will be removed. If "keep_first", first occurred duplicate will be kept. - argument_type_whitelist: List[List[str]], optional, defaults to None. If set, only relations (candidates) - with given argument type tuples are created from document and by by `create_relation_candidates`. - This affects only model input. - argument_and_relation_type_whitelist: Union[Dict[str, List[List[str]]], List[List[str]]], optional, - defaults None. If set, only given relation types with given argument types will persist in - documents and generated by `create_relation_candidates`. This also affects predictions on - `decode()`, so it strictly filters both model input and output. Can also be passed as a list - of lists, where the first element is the relation type and the rest are the argument types. - """ - - PREPARED_ATTRIBUTES = ["labels", "entity_labels"] - - def __init__( - self, - tokenizer_name_or_path: str, - relation_annotation: str = "binary_relations", - add_candidate_relations: bool = False, - add_reversed_relations: bool = False, - partition_annotation: Optional[str] = None, - none_label: str = "no_relation", - padding: Union[bool, str, PaddingStrategy] = True, - truncation: Union[bool, str, TruncationStrategy] = True, - max_length: Optional[int] = None, - pad_to_multiple_of: Optional[int] = None, - multi_label: bool = False, - labels: Optional[List[str]] = None, - label_to_id: Optional[Dict[str, int]] = None, - add_type_to_marker: bool = False, - argument_role_to_marker: Optional[Dict[str, str]] = None, - single_argument_pair: bool = True, - append_markers: bool = False, - insert_markers: bool = True, - entity_labels: Optional[List[str]] = None, - reversed_relation_label_suffix: str = "_reversed", - symmetric_relations: Optional[List[str]] = None, - reverse_symmetric_relations: bool = True, - max_argument_distance: Optional[int] = None, - max_argument_distance_type: str = "inner", - max_argument_distance_tokens: Optional[int] = None, - max_argument_distance_type_tokens: str = "inner", - max_window: Optional[int] = None, - allow_discontinuous_text: bool = False, - log_first_n_examples: int = 0, - add_argument_indices_to_input: bool = False, - add_argument_tags_to_input: bool = False, - add_entity_tags_to_input: bool = False, - add_global_attention_mask_to_input: bool = False, - argument_type_whitelist: Optional[List[List[str]]] = None, - handle_relations_with_same_arguments: str = "keep_none", - argument_and_relation_type_whitelist: Optional[ - Union[Dict[str, List[List[str]]], List[List[str]]] - ] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - if label_to_id is not None: - logger.warning( - "The parameter label_to_id is deprecated and will be removed in a future version. " - "Please use labels instead." - ) - id_to_label = {v: k for k, v in label_to_id.items()} - # reconstruct labels from label_to_id. Note that we need to remove the none_label - labels = [ - id_to_label[i] for i in range(len(id_to_label)) if id_to_label[i] != none_label - ] - self.save_hyperparameters(ignore=["label_to_id"]) - - self.relation_annotation = relation_annotation - self.add_candidate_relations = add_candidate_relations - self.add_reversed_relations = add_reversed_relations - self.padding = padding - self.truncation = truncation - self.labels = labels - self.max_length = max_length - self.pad_to_multiple_of = pad_to_multiple_of - self.multi_label = multi_label - self.add_type_to_marker = add_type_to_marker - self.single_argument_pair = single_argument_pair - self.append_markers = append_markers - self.insert_markers = insert_markers - self.entity_labels = entity_labels - self.partition_annotation = partition_annotation - self.none_label = none_label - self.reversed_relation_label_suffix = reversed_relation_label_suffix - self.symmetric_relations = set(symmetric_relations or []) - self.reverse_symmetric_relations = reverse_symmetric_relations - self.max_argument_distance = max_argument_distance - self.max_argument_distance_type = max_argument_distance_type - self.max_argument_distance_tokens = max_argument_distance_tokens - self.max_argument_distance_type_tokens = max_argument_distance_type_tokens - self.max_window = max_window - self.allow_discontinuous_text = allow_discontinuous_text - self.handle_relations_with_same_arguments = handle_relations_with_same_arguments - self.argument_type_whitelist: Optional[Set[Tuple[str, ...]]] = None - self.argument_and_relation_type_whitelist: Optional[Dict[str, Set[Tuple[str, ...]]]] = None - - if argument_type_whitelist is not None: - # hydra does not support tuples, so we got lists and need to convert them - self.argument_type_whitelist = {tuple(types) for types in argument_type_whitelist} - if argument_and_relation_type_whitelist is not None: - # hydra does not support tuples, so we got lists and need to convert them - if isinstance(argument_and_relation_type_whitelist, list): - self.argument_and_relation_type_whitelist = defaultdict(set) - for types_list in argument_and_relation_type_whitelist: - if len(types_list) < 1: - raise ValueError( - "argument_and_relation_type_whitelist must be a list of lists with at least one element" - ) - self.argument_and_relation_type_whitelist[types_list[0]].add( - tuple(types_list[1:]) - ) - else: - self.argument_and_relation_type_whitelist = { - rel: {tuple(types) for types in types_list} - for rel, types_list in argument_and_relation_type_whitelist.items() - } - # overwrite None with 0 for backward compatibility - self.log_first_n_examples = log_first_n_examples or 0 - self.add_argument_indices_to_input = add_argument_indices_to_input - self.add_argument_tags_to_input = add_argument_tags_to_input - self.add_entity_tags_to_input = add_entity_tags_to_input - self.add_global_attention_mask_to_input = add_global_attention_mask_to_input - if argument_role_to_marker is None: - self.argument_role_to_marker = {HEAD: "H", TAIL: "T"} - else: - self.argument_role_to_marker = argument_role_to_marker - - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) - - # used when allow_discontinuous_text - self.glue_token_ids = self._get_glue_token_ids() - - self.argument_markers = None - - self._logged_examples_counter = 0 - - def _get_glue_token_ids(self): - dummy_ids = self.tokenizer.build_inputs_with_special_tokens( - token_ids_0=[-1], token_ids_1=[-2] - ) - return dummy_ids[dummy_ids.index(-1) + 1 : dummy_ids.index(-2)] - - @property - def document_type(self) -> Optional[Type[DocumentType]]: - if self.partition_annotation is not None: - dt = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions - else: - dt = TextDocumentWithLabeledSpansAndBinaryRelations - if self.relation_annotation == "binary_relations": - return dt - else: - logger.warning( - f"relation_annotation={self.relation_annotation} is " - f"not the default value ('binary_relations'), so the taskmodule {type(self).__name__} can not request " - f"the usual document type for auto-conversion ({dt.__name__}) because this has the bespoken default " - f"value as layer name instead of the provided one." - ) - return None - - def get_relation_layer(self, document: Document) -> AnnotationLayer[BinaryRelation]: - return document[self.relation_annotation] - - def get_entity_layer(self, document: Document) -> AnnotationLayer[LabeledSpan]: - relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document) - return relations.target_layer - - def get_marker_factory(self) -> MarkerFactory: - return MarkerFactory(role_to_marker=self.argument_role_to_marker) - - def _prepare(self, documents: Sequence[DocumentType]) -> None: - entity_labels: Set[str] = set() - relation_labels: Set[str] = set() - for document in documents: - relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document) - entities: AnnotationLayer[LabeledSpan] = self.get_entity_layer(document) - - for entity in entities: - entity_labels.add(entity.label) - - for relation in relations: - relation_labels.add(relation.label) - if self.add_reversed_relations: - if relation.label.endswith(self.reversed_relation_label_suffix): - raise ValueError( - f"doc.id={document.id}: the relation label '{relation.label}' already ends with " - f"the reversed_relation_label_suffix '{self.reversed_relation_label_suffix}', " - f"this is not allowed because we would not know if we should strip the suffix and " - f"revert the arguments during inference or not" - ) - if relation.label not in self.symmetric_relations: - relation_labels.add(relation.label + self.reversed_relation_label_suffix) - - if self.none_label in relation_labels: - relation_labels.remove(self.none_label) - - self.labels = sorted(relation_labels) - self.entity_labels = sorted(entity_labels) - - def encode(self, *args, **kwargs): - self.reset_statistics() - res = super().encode(*args, **kwargs) - self.show_statistics() - return res - - def _post_prepare(self): - self.label_to_id = {label: i + 1 for i, label in enumerate(self.labels)} - self.label_to_id[self.none_label] = 0 - self.id_to_label = {v: k for k, v in self.label_to_id.items()} - - self.marker_factory = self.get_marker_factory() - self.argument_markers = self.marker_factory.get_all_markers( - append_markers=self.append_markers, - add_type_to_marker=self.add_type_to_marker, - entity_labels=self.entity_labels, - ) - self.tokenizer.add_tokens(self.argument_markers, special_tokens=True) - - self.argument_markers_to_id = { - marker: self.tokenizer.vocab[marker] for marker in self.argument_markers - } - - self.argument_role2idx = { - role: i for i, role in enumerate(sorted(self.marker_factory.all_roles)) - } - - def _add_reversed_relations( - self, - arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation], - doc_id: Optional[str] = None, - ) -> None: - if self.add_reversed_relations: - for arguments, rel in list(arguments2relation.items()): - arg_roles, arg_spans = zip(*arguments) - if isinstance(rel, BinaryRelation): - label = rel.label - if label in self.symmetric_relations and not self.reverse_symmetric_relations: - continue - if label.endswith(self.reversed_relation_label_suffix): - raise ValueError( - f"doc.id={doc_id}: The relation has the label '{label}' which already ends with the " - f"reversed_relation_label_suffix='{self.reversed_relation_label_suffix}'. " - f"It looks like the relation is already reversed, which is not allowed." - ) - if rel.label not in self.symmetric_relations: - label += self.reversed_relation_label_suffix - - reversed_rel = BinaryRelation( - head=rel.tail, - tail=rel.head, - label=label, - score=rel.score, - ) - reversed_arguments = get_relation_argument_spans_and_roles(reversed_rel) - if reversed_arguments in arguments2relation: - prev_rel = arguments2relation[reversed_arguments] - prev_label = prev_rel.label - logger.warning( - f"doc.id={doc_id}: there is already a relation with reversed " - f"arguments={reversed_arguments} and label={prev_label}, so we do not add the reversed " - f"relation (with label {prev_label}) for these arguments" - ) - if self.collect_statistics: - self.collect_relation("skipped_reversed_same_arguments", reversed_rel) - continue - elif rel.label in self.symmetric_relations: - # warn if the original relation arguments were not sorted by their start and end positions - # in the case of symmetric relations - if not all(isinstance(arg_span, Span) for arg_span in arg_spans): - raise NotImplementedError( - f"doc.id={doc_id}: the taskmodule does not yet support adding reversed relations " - f"for symmetric relations with arguments that are no Spans: {arguments}" - ) - args_sorted = sorted( - [rel.head, rel.tail], key=lambda span: (span.start, span.end) - ) - if args_sorted != [rel.head, rel.tail]: - logger.warning( - f"doc.id={doc_id}: The symmetric relation with label '{label}' has arguments " - f"{arguments} which are not sorted by their start and end positions. " - f"This may lead to problems during evaluation because we assume that the " - f"arguments of symmetric relations were sorted in the beginning and, thus, interpret " - f"relations where this is not the case as reversed. All reversed relations will get " - f"their arguments swapped during inference in the case of add_reversed_relations=True " - f"to remove duplicates. You may consider adding reversed versions of the *symmetric* " - f"relations on your own and then setting *reverse_symmetric_relations* to False." - ) - if self.collect_statistics: - self.collect_relation( - "used_not_sorted_reversed_arguments", reversed_rel - ) - - arguments2relation[reversed_arguments] = reversed_rel - else: - raise NotImplementedError( - f"doc.id={doc_id}: the taskmodule does not yet support adding reversed relations for type: " - f"{type(rel)}" - ) - - def _filter_relations_by_argument_and_relation_type_whitelist( - self, - arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation], - doc_id: Optional[str] = None, - ) -> None: - if self.argument_and_relation_type_whitelist is not None: - for arguments, relation in list(arguments2relation.items()): - argument_labels = tuple(getattr(ann, "label") for role, ann in arguments) - relation_label = getattr(relation, "label") - if ( - relation_label not in self.argument_and_relation_type_whitelist - or argument_labels - not in self.argument_and_relation_type_whitelist[relation_label] - ): - rel = arguments2relation.pop(arguments) - self.collect_relation("skipped_argument_and_relation_type_whitelist", rel) - - def _filter_relations_by_argument_type_whitelist( - self, - arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation], - doc_id: Optional[str] = None, - ) -> None: - if self.argument_type_whitelist is not None: - for arguments, rel in list(arguments2relation.items()): - argument_labels = tuple(getattr(arg, "label") for _, arg in arguments) - if argument_labels not in self.argument_type_whitelist: - rel = arguments2relation.pop(arguments) - self.collect_relation("skipped_argument_type_whitelist", rel) - - def _add_candidate_relations( - self, - arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation], - entities: Iterable[Span], - arguments_blacklist: Optional[Set[Tuple[Tuple[str, Annotation], ...]]] = None, - doc_id: Optional[str] = None, - ) -> None: - if self.add_candidate_relations: - if self.marker_factory.all_roles == {HEAD, TAIL}: - # flatten argument_and_relation_type_whitelist values - arg_rel_whitelist_vals_set = ( - None - if self.argument_and_relation_type_whitelist is None - else {i for j in self.argument_and_relation_type_whitelist.values() for i in j} - ) - # iterate over all possible argument candidates - for head in entities: - for tail in entities: - if head == tail: - continue - - # Create a relation candidate with the none label. Otherwise, we use the existing relation. - new_relation = BinaryRelation( - head=head, tail=tail, label=self.none_label, score=1.0 - ) - new_relation_args = get_relation_argument_spans_and_roles(new_relation) - arg_roles, arg_spans = zip(*new_relation_args) - arg_labels = tuple(getattr(ann, "label") for ann in arg_spans) - - # Skip if argument_type_whitelist and/or argument_and_relation_type_whitelist - # are defined and current candidates do not fit. - if ( - self.argument_type_whitelist is not None - and arg_labels not in self.argument_type_whitelist - ) or ( - arg_rel_whitelist_vals_set is not None - and arg_labels not in arg_rel_whitelist_vals_set - ): - continue - - # check blacklist - if ( - arguments_blacklist is not None - and new_relation_args in arguments_blacklist - ): - continue - - # we use the new relation only if there is no existing relation with the same arguments - if new_relation_args not in arguments2relation: - arguments2relation[new_relation_args] = new_relation - else: - raise NotImplementedError( - f"doc.id={doc_id}: the taskmodule does not yet support adding relation candidates " - f"with argument roles other than 'head' and 'tail': {sorted(self.marker_factory.all_roles)}" - ) - - def _filter_relations_by_argument_distance( - self, - arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation], - doc_id: Optional[str] = None, - ) -> None: - if self.max_argument_distance is not None: - for arguments, rel in list(arguments2relation.items()): - if isinstance(rel, BinaryRelation): - if isinstance(rel.head, Span) and isinstance(rel.tail, Span): - dist = span_distance( - (rel.head.start, rel.head.end), - (rel.tail.start, rel.tail.end), - self.max_argument_distance_type, - ) - if dist > self.max_argument_distance: - arguments2relation.pop(arguments) - self.collect_relation("skipped_argument_distance", rel) - else: - raise NotImplementedError( - f"doc.id={doc_id}: the taskmodule does not yet support filtering relation candidates " - f"with arguments of type: {type(rel.head)} and {type(rel.tail)}" - ) - else: - raise NotImplementedError( - f"doc.id={doc_id}: the taskmodule does not yet support filtering relation candidates for " - f"type: {type(rel)}" - ) - - def encode_input( - self, - document: DocumentType, - is_training: bool = False, - ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: - all_relations: Sequence[Annotation] = self.get_relation_layer(document) - all_entities: Sequence[Span] = self.get_entity_layer(document) - self.collect_all_relations("available", all_relations) - - partitions: Sequence[Span] - if self.partition_annotation is not None: - partitions = document[self.partition_annotation] - if len(partitions) == 0: - logger.warning( - f"the document {document.id} has no '{self.partition_annotation}' partition entries, " - f"no inputs will be created!" - ) - else: - # use single dummy partition - partitions = [Span(start=0, end=len(document.text))] - - task_encodings: List[TaskEncodingType] = [] - for partition in partitions: - # get all entities that are contained in the current partition - entities: List[Span] = [ - entity - for entity in all_entities - if is_contained_in((entity.start, entity.end), (partition.start, partition.end)) - ] - - # Create a mapping from relation arguments to the respective relation objects. - # Note that the data can contain multiple relations with the same arguments. - entities_set = set(entities) - arguments2relations: Dict[Tuple[Tuple[str, Annotation], ...], List[Annotation]] = ( - defaultdict(list) - ) - for rel in all_relations: - # Skip relations with unknown labels. Use label_to_id because that contains the none_label - if rel.label not in self.label_to_id: - self.collect_relation("skipped_unknown_label", rel) - continue - - arguments = get_relation_argument_spans_and_roles(rel) - arg_roles, arg_spans = zip(*arguments) - - # filter out all relations that are completely outside the current partition - if all(arg_span not in entities_set for arg_span in arg_spans): - continue - - # filter relations that are only partially contained in the current partition, - # i.e. some arguments are in the partition and some are not - if any(arg_span not in entities_set for arg_span in arg_spans): - logger.warning( - f"doc.id={document.id}: there is a relation with label '{rel.label}' and arguments " - f"{arguments} that is only partially contained in the current partition. " - f"We skip this relation." - ) - self.collect_relation("skipped_partially_contained", rel) - continue - arguments2relations[arguments].append(rel) - - # resolve duplicates for same arguments - arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation] = {} - # we will never create an encoding for the relation candidates in arguments_blacklist - arguments_blacklist: Set[Tuple[Tuple[str, Annotation], ...]] = set() - for arguments, relations in arguments2relations.items(): - relations_set = set(relations) - # more than one unique relation with the same arguments - if len(relations_set) > 1: - arguments_resolved = tuple(map(lambda x: (x[0], x[1].resolve()), arguments)) - labels = [rel.label for rel in relations] - if self.handle_relations_with_same_arguments == "keep_first": - # keep only the first relation - arguments2relation[arguments] = relations[0] - for discard_rel in set(relations) - { - relations[0] - }: # remove all other relations - self.collect_relation("skipped_same_arguments", discard_rel) - if not self.collect_statistics: - # We show this warning only if statistics are disabled. - # We want to be informed if such skip occurs, but having it in statistics and - # getting lots of warnings in the same time seemed overwhelming. - logger.warning( - f"doc.id={document.id}: there are multiple relations with the same arguments " - f"{arguments_resolved}, but different labels: {labels}. We only keep the first " - f"occurring relation which has the label='{relations[0].label}'." - ) - elif self.handle_relations_with_same_arguments == "keep_none": - # add these arguments to the blacklist to not add them as 'no-relation's back again - arguments_blacklist.add(arguments) - # remove all relations with the same arguments - for discard_rel in relations_set: - self.collect_relation("skipped_same_arguments", discard_rel) - if not self.collect_statistics: - logger.warning( - f"doc.id={document.id}: there are multiple relations with the same arguments " - f"{arguments_resolved}, but different labels: {labels}. All relations will be removed." - ) - else: - raise ValueError( - f"'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', " - f"but got `{self.handle_relations_with_same_arguments}`." - ) - else: - arguments2relation[arguments] = relations[0] - # more than one duplicate relation (with the same arguments) - if len(relations) > 1: - # if 'collect_statistics=true' such duplicates won't be collected and are not counted in - # statistics if 'collect_statistics=true' either as 'available' or as 'skipped_same_arguments' - logger.warning( - f"doc.id={document.id}: Relation annotation `{rel.resolve()}` is duplicated. " - f"We keep only one of them. Duplicate won't appear in statistics either as 'available' " - f"or as skipped." - ) - - # We use this filter before adding reversed relations because we also don't want them to be reversed - self._filter_relations_by_argument_and_relation_type_whitelist( - arguments2relation=arguments2relation, doc_id=document.id - ) - self._add_reversed_relations(arguments2relation=arguments2relation, doc_id=document.id) - self._filter_relations_by_argument_type_whitelist( - arguments2relation=arguments2relation, doc_id=document.id - ) - self._add_candidate_relations( - arguments2relation=arguments2relation, - arguments_blacklist=arguments_blacklist, - entities=entities, - doc_id=document.id, - ) - - self._filter_relations_by_argument_distance( - arguments2relation=arguments2relation, doc_id=document.id - ) - - without_special_tokens = self.max_window is not None - text = document.text[partition.start : partition.end] - encoding = self.tokenizer( - text, - padding=False, - truncation=self.truncation if self.max_window is None else False, - max_length=self.max_length, - is_split_into_words=False, - return_offsets_mapping=False, - add_special_tokens=not without_special_tokens, - ) - - for arguments, rel in arguments2relation.items(): - arg_roles, arg_spans = zip(*arguments) - if not all(isinstance(arg, LabeledSpan) for arg in arg_spans): - # TODO: add test case for this - raise ValueError( - f"the taskmodule expects the relation arguments to be of type LabeledSpan, " - f"but got {[type(arg) for arg in arg_spans]}" - ) - - arg_spans_partition = [ - shift_span(span, offset=-partition.start) for span in arg_spans - ] - # map character spans to token spans - try: - arg_token_spans = [ - get_aligned_token_span( - encoding=encoding, - char_span=arg, - ) - for arg in arg_spans_partition - ] - # Check if the mapping was successful. It may fail (and is None) if any argument start or end does not - # match a token start or end, respectively. - except SpanNotAlignedWithTokenException as e: - span_original = shift_span(e.span, offset=partition.start) - # the span is not attached because we shifted it above, so we can not use str(e.span) - span_text = document.text[span_original.start : span_original.end] - logger.warning( - f"doc.id={document.id}: Skipping invalid example, cannot get argument token slice for " - f'{span_original}: "{span_text}"' - ) - self.collect_relation("skipped_args_not_aligned", rel) - continue - - # create the argument objects - args = [ - RelationArgument( - entity=span, - role=role, - token_span=token_span, - add_type_to_marker=self.add_type_to_marker, - marker_factory=self.marker_factory, - ) - for span, role, token_span in zip(arg_spans, arg_roles, arg_token_spans) - ] - - if self.max_argument_distance_tokens is not None: - token_distances = [] - for idx1 in range(len(args) - 1): - for idx in range(idx1 + 1, len(args)): - arg1 = args[idx1] - arg2 = args[idx] - dist = span_distance( - (arg1.token_span.start, arg1.token_span.end), - (arg2.token_span.start, arg2.token_span.end), - self.max_argument_distance_type_tokens, - ) - token_distances.append(dist) - if len(token_distances) > 0: - if self.max_argument_distance_type_tokens == "outer": - max_dist = max(token_distances) - elif self.max_argument_distance_type_tokens == "inner": - if len(args) > 2: - raise NotImplementedError( - f"max_argument_distance_type_tokens={self.max_argument_distance_type_tokens} " - f"is not supported for relations with more than 2 arguments" - ) - max_dist = max(token_distances) - else: - raise NotImplementedError( - f"max_argument_distance_type_tokens={self.max_argument_distance_type_tokens} " - f"is not supported" - ) - if max_dist > self.max_argument_distance_tokens: - self.collect_relation("skipped_argument_distance_tokens", rel) - continue - - input_ids = encoding["input_ids"] - - entity_tags = None - if self.add_entity_tags_to_input: - entity_spans_partition = [ - shift_span(span, offset=-partition.start) for span in entities - ] - entity_token_spans = [] - for span in entity_spans_partition: - try: - entity_token_spans.append( - get_aligned_token_span( - encoding=encoding, - char_span=span, - ) - ) - except SpanNotAlignedWithTokenException as e: - span_original = shift_span(e.span, offset=partition.start) - span_text = document.text[span_original.start : span_original.end] - logger.warning( - f"doc.id={document.id}: Skipping invalid example, cannot get entity token slice for " - f'{span_original}: "{span_text}"' - ) - self.collect_relation("skipped_entity_not_aligned", rel) - continue - - entity_tags = bio_encode_spans( - spans=[ - (span.start, span.end, getattr(span, "label", "ENTITY")) - for span in entity_token_spans - ], - total_length=len(input_ids), - label2idx={ - label: idx for idx, label in enumerate(self.entity_labels or []) - }, - ) - - # windowing: we restrict the input to a window of a maximal size (max_window) with the arguments - # of the candidate relation in the center (as much as possible) - if self.max_window is not None: - # The actual number of tokens needs to be lower than max_window because we add two - # marker tokens (before / after) each argument and the default special tokens - # (e.g. CLS and SEP). - max_tokens = self.max_window - self.tokenizer.num_special_tokens_to_add() - if self.insert_markers: - max_tokens -= len(args) * 2 - # if we add the markers also to the end, this decreases the available window again by - # two tokens (marker + sep) per argument - if self.append_markers: - # TODO: add test case for this - max_tokens -= len(args) * 2 - - if self.allow_discontinuous_text: - if entity_tags is not None: - raise NotImplementedError( - "allow_discontinuous_text=True is not yet supported with add_entity_tags_to_input=True" - ) - - max_tokens_per_argument = max_tokens // len(args) - max_tokens_per_argument -= len(self.glue_token_ids) - if any( - arg.token_span.end - arg.token_span.start > max_tokens_per_argument - for arg in args - ): - self.collect_relation("skipped_too_long_argument", rel) - continue - - mask = np.zeros_like(input_ids) - for arg in args: - # if the input is already fully covered by one argument frame, we keep everything - if len(input_ids) <= max_tokens_per_argument: - mask[:] = 1 - break - arg_center = (arg.token_span.end + arg.token_span.start) // 2 - arg_frame_start = arg_center - max_tokens_per_argument // 2 - # shift the frame to the right if it is out of bounds - if arg_frame_start < 0: - arg_frame_start = 0 - arg_frame_end = arg_frame_start + max_tokens_per_argument - # shift the frame to the left if it is out of bounds - # Note that this can not cause to have arg_frame_start < 0 because we already - # checked that the frame is not larger than the input. - if arg_frame_end > len(input_ids): - arg_frame_end = len(input_ids) - arg_frame_start = arg_frame_end - max_tokens_per_argument - # still, a sanity check - if arg_frame_start < 0: - raise ValueError( - f"arg_frame_start={arg_frame_start} < 0 after adjusting arg_frame_end={arg_frame_end}" - ) - mask[arg_frame_start:arg_frame_end] = 1 - offsets = np.cumsum(mask != 1) - arg_cluster_offset_values = set() - # sort by start indices - args_sorted = sorted(args, key=lambda x: x.token_span.start) - for arg in args_sorted: - offset = offsets[arg.token_span.start] - arg_cluster_offset_values.add(offset) - arg.shift_token_span(-offset) - # shift back according to inserted glue patterns - num_glues = len(arg_cluster_offset_values) - 1 - arg.shift_token_span(num_glues * len(self.glue_token_ids)) - - new_input_ids: List[int] = [] - for arg_cluster_offset_value in sorted(arg_cluster_offset_values): - if len(new_input_ids) > 0: - new_input_ids.extend(self.glue_token_ids) - segment_mask = offsets == arg_cluster_offset_value - segment_input_ids = [ - input_id - for input_id, keep in zip(input_ids, mask & segment_mask) - if keep - ] - new_input_ids.extend(segment_input_ids) - - input_ids = new_input_ids - else: - # the slice from the beginning of the first entity to the end of the second is required - slice_required = ( - min(arg.token_span.start for arg in args), - max(arg.token_span.end for arg in args), - ) - window_slice = get_window_around_slice( - slice=slice_required, - max_window_size=max_tokens, - available_input_length=len(input_ids), - ) - # this happens if slice_required (all arguments) does not fit into max_tokens (the available window) - if window_slice is None: - self.collect_relation("skipped_too_long", rel) - continue - - window_start, window_end = window_slice - input_ids = input_ids[window_start:window_end] - - if entity_tags is not None: - entity_tags = entity_tags[window_start:window_end] - - for arg in args: - arg.shift_token_span(-window_start) - - # collect all markers with their target positions, the source argument, and - marker_ids_with_positions = [] - for arg in args: - marker_ids_with_positions.append( - ( - self.argument_markers_to_id[arg.as_start_marker], - arg.token_span.start, - arg, - START, - ) - ) - marker_ids_with_positions.append( - ( - self.argument_markers_to_id[arg.as_end_marker], - arg.token_span.end, - arg, - END, - ) - ) - - # create new input ids with the markers inserted and collect new mention offsets - input_ids_with_markers = list(input_ids) - offset = 0 - arg_start_indices = [-1] * len(self.argument_role2idx) - arg_end_indices = [-1] * len(self.argument_role2idx) - marker_ids_with_positions_sorted = sorted( - marker_ids_with_positions, key=lambda id_pos: id_pos[1] - ) - for ( - marker_id, - token_position, - arg, - marker_type, - ) in marker_ids_with_positions_sorted: - if self.insert_markers: - input_ids_with_markers = ( - input_ids_with_markers[: token_position + offset] - + [marker_id] - + input_ids_with_markers[token_position + offset :] - ) - if entity_tags is not None: - entity_tags = ( - entity_tags[: token_position + offset] - + [0] - + entity_tags[token_position + offset :] - ) - offset += 1 - if self.add_argument_indices_to_input or self.add_argument_tags_to_input: - idx = self.argument_role2idx[arg.role] - if marker_type == START: - if arg_start_indices[idx] != -1: - # TODO: add test case for this - raise ValueError( - f"Trying to overwrite arg_start_indices[{idx}]={arg_start_indices[idx]} with " - f"{token_position + offset} for document {document.id}" - ) - arg_start_indices[idx] = token_position + offset - elif marker_type == END: - if arg_end_indices[idx] != -1: - # TODO: add test case for this - raise ValueError( - f"Trying to overwrite arg_start_indices[{idx}]={arg_end_indices[idx]} with " - f"{token_position + offset} for document {document.id}" - ) - # -1 to undo the additional offset for the end marker which does not - # affect the mention offset - arg_end_indices[idx] = ( - token_position + offset - (1 if self.insert_markers else 0) - ) - - if self.append_markers: - if self.tokenizer.sep_token is None: - # TODO: add test case for this - raise ValueError("append_markers is True, but tokenizer has no sep_token") - sep_token_id = self.tokenizer.vocab[self.tokenizer.sep_token] - for arg in args: - if without_special_tokens: - # TODO: add test case for this - input_ids_with_markers.append(sep_token_id) - input_ids_with_markers.append( - self.argument_markers_to_id[arg.as_append_marker] - ) - else: - input_ids_with_markers.append( - self.argument_markers_to_id[arg.as_append_marker] - ) - input_ids_with_markers.append(sep_token_id) - if entity_tags is not None: - entity_tags.append(0) - entity_tags.append(0) - - # when windowing is used, we have to add the special tokens manually - if without_special_tokens: - original_input_ids_with_markers = input_ids_with_markers - input_ids_with_markers = self.tokenizer.build_inputs_with_special_tokens( - token_ids_0=input_ids_with_markers - ) - if self.add_argument_indices_to_input or self.add_argument_tags_to_input: - # get the number of prefix tokens - index_offset = find_sublist( - sub=original_input_ids_with_markers, bigger=input_ids_with_markers - ) - if index_offset == -1: - raise ValueError( - f"Could not find the original tokens in the prefixed tokens for document {document.id}" - ) - arg_start_indices = [ - idx + index_offset if idx != -1 else -1 for idx in arg_start_indices - ] - arg_end_indices = [ - idx + index_offset if idx != -1 else -1 for idx in arg_end_indices - ] - if entity_tags is not None: - special_tokens_mask = self.tokenizer.get_special_tokens_mask( - token_ids_0=input_ids_with_markers, already_has_special_tokens=True - ) - entity_tags_with_special = self.tokenizer.build_inputs_with_special_tokens( - token_ids_0=entity_tags - ) - entity_tags = [ - tag if not is_special else 0 - for tag, is_special in zip( - entity_tags_with_special, special_tokens_mask - ) - ] - - inputs = {"input_ids": input_ids_with_markers} - if self.add_argument_indices_to_input: - inputs["pooler_start_indices"] = arg_start_indices - inputs["pooler_end_indices"] = arg_end_indices - if self.add_argument_tags_to_input: - # create bio-encoded tags for the arguments - # using arg_start_indices, arg_end_indices, and marker_ids_with_positions_sorted - argument_spans = [ - ( - arg_start_indices[self.argument_role2idx[arg.role]], - arg_end_indices[self.argument_role2idx[arg.role]], - arg.role, - ) - for marker_id, token_position, arg, marker_type in marker_ids_with_positions_sorted - ] - argument_tag_ids = bio_encode_spans( - spans=argument_spans, - total_length=len(input_ids_with_markers), - label2idx=self.argument_role2idx, - ) - inputs["argument_tags"] = argument_tag_ids - - if entity_tags is not None: - inputs["entity_tags"] = entity_tags - - task_encodings.append( - TaskEncoding( - document=document, - inputs=inputs, - metadata=({"candidate_annotation": rel}), - ) - ) - - self.collect_relation("used", rel) - - return task_encodings - - def _maybe_log_example( - self, - task_encoding: TaskEncodingType, - target: TargetEncodingType, - ): - """Maybe log the example.""" - - # log the first n examples - if self._logged_examples_counter < self.log_first_n_examples: - input_ids = task_encoding.inputs["input_ids"] - tokens = self.tokenizer.convert_ids_to_tokens(input_ids) - target_labels = [self.id_to_label[label_id] for label_id in target] - logger.info("*** Example ***") - logger.info("doc id: %s", task_encoding.document.id) - logger.info("tokens: %s", " ".join([str(x) for x in tokens])) - logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) - logger.info("Expected label: %s (ids = %s)", target_labels, target) - - self._logged_examples_counter += 1 - - def encode_target( - self, - task_encoding: TaskEncodingType, - ) -> TargetEncodingType: - candidate_annotation = task_encoding.metadata["candidate_annotation"] - if isinstance(candidate_annotation, (BinaryRelation, NaryRelation)): - labels = [candidate_annotation.label] - else: - raise NotImplementedError( - f"encoding the target with a candidate_annotation of another type than BinaryRelation or" - f"NaryRelation is not yet supported. candidate_annotation has the type: " - f"{type(candidate_annotation)}" - ) - target = [self.label_to_id[label] for label in labels] - - self._maybe_log_example(task_encoding=task_encoding, target=target) - - return target - - def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]: - unbatched_output = [] - if self.multi_label: - raise NotImplementedError - else: - label_ids = model_output["labels"].detach().cpu().tolist() - probabilities = model_output["probabilities"].detach().cpu().tolist() - for batch_idx in range(len(label_ids)): - label_id = label_ids[batch_idx] - result: TaskOutputType = { - "labels": [self.id_to_label[label_id]], - "probabilities": [probabilities[batch_idx][label_id]], - } - unbatched_output.append(result) - - return unbatched_output - - def create_annotations_from_output( - self, - task_encoding: TaskEncodingType, - task_output: TaskOutputType, - ) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation, NaryRelation]]]: - candidate_annotation = task_encoding.metadata["candidate_annotation"] - new_annotation: Union[BinaryRelation, MultiLabeledBinaryRelation, NaryRelation] - if self.multi_label: - raise NotImplementedError - else: - label = task_output["labels"][0] - probability = ( - task_output["probabilities"][0] if "probabilities" in task_output else 1.0 - ) - if isinstance(candidate_annotation, BinaryRelation): - head = candidate_annotation.head - tail = candidate_annotation.tail - # Reverse predicted reversed relations back. Serialization will remove any duplicated relations. - if self.add_reversed_relations: - # TODO: add test case for this - if label.endswith(self.reversed_relation_label_suffix): - label = label[: -len(self.reversed_relation_label_suffix)] - head, tail = tail, head - # If the predicted label is symmetric, we sort the arguments by its center. - elif label in self.symmetric_relations and self.reverse_symmetric_relations: - if not (isinstance(head, Span) and isinstance(tail, Span)): - raise ValueError( - f"the taskmodule expects the relation arguments of the candidate_annotation" - f"to be of type Span, but got head of type: {type(head)} and tail of type: " - f"{type(tail)}" - ) - # use a unique order for the arguments: sort by start and end positions - head, tail = sorted([head, tail], key=lambda span: (span.start, span.end)) - new_annotation = BinaryRelation( - head=head, tail=tail, label=label, score=probability - ) - elif isinstance(candidate_annotation, NaryRelation): - # TODO: add test case for this - if self.add_reversed_relations: - raise ValueError("can not reverse a NaryRelation") - new_annotation = NaryRelation( - arguments=candidate_annotation.arguments, - roles=candidate_annotation.roles, - label=label, - score=probability, - ) - else: - raise NotImplementedError( - f"creating a new annotation from a candidate_annotation of another type than BinaryRelation is " - f"not yet supported. candidate_annotation has the type: {type(candidate_annotation)}" - ) - - new_annotation_args = get_relation_argument_spans_and_roles(new_annotation) - arg_roles, arg_spans = zip(*new_annotation_args) - arg_labels = tuple(getattr(ann, "label") for ann in arg_spans) - - # Create annotation only if 1. and 2. are fulfilled: - if ( - # 1. the label is not the no-relation-label, - label != self.none_label - # or we did not create candidate relations, - or not self.add_candidate_relations - ) and ( - # 2. the argument_and_relation_type_whitelist is not set, - self.argument_and_relation_type_whitelist is None - # or the label and argument types are in the whitelist - or arg_labels in self.argument_and_relation_type_whitelist.get(label, {}) - ): - yield self.relation_annotation, new_annotation - - def _get_global_attention(self, input_ids: torch.LongTensor) -> torch.LongTensor: - # we want to have global attention on all marker tokens and the cls token - positive_token_ids = list(self.argument_markers_to_id.values()) + [ - self.tokenizer.cls_token_id - ] - global_attention_mask = construct_mask( - input_ids=input_ids, positive_ids=positive_token_ids - ) - return global_attention_mask - - def collate( - self, task_encodings: Sequence[TaskEncodingType] - ) -> Tuple[ModelInputType, Optional[ModelTargetType]]: - input_features = [ - {"input_ids": task_encoding.inputs["input_ids"]} for task_encoding in task_encodings - ] - - inputs: Dict[str, torch.LongTensor] = self.tokenizer.pad( - input_features, - padding=self.padding, - max_length=self.max_length, - pad_to_multiple_of=self.pad_to_multiple_of, - return_tensors="pt", - ) - if self.add_argument_tags_to_input: - argument_tags = [ - {"input_ids": task_encoding.inputs["argument_tags"]} - for task_encoding in task_encodings - ] - argument_tags_padded = self.tokenizer.pad( - argument_tags, - padding=self.padding, - max_length=self.max_length, - pad_to_multiple_of=self.pad_to_multiple_of, - return_tensors="pt", - ) - # increase all values by 1 because 0 is used for padding - inputs["argument_tags"] = argument_tags_padded["input_ids"] + 1 - # overwrite padding with 0 - inputs["argument_tags"][argument_tags_padded["attention_mask"] == 0] = 0 - - if self.add_entity_tags_to_input: - entity_tags = [ - {"input_ids": task_encoding.inputs["entity_tags"]} - for task_encoding in task_encodings - ] - entity_tags_padded = self.tokenizer.pad( - entity_tags, - padding=self.padding, - max_length=self.max_length, - pad_to_multiple_of=self.pad_to_multiple_of, - return_tensors="pt", - ) - # increase all values by 1 because 0 is used for padding - inputs["entity_tags"] = entity_tags_padded["input_ids"] + 1 - # overwrite padding with 0 - inputs["entity_tags"][entity_tags_padded["attention_mask"] == 0] = 0 - - if self.add_argument_indices_to_input: - inputs["pooler_start_indices"] = torch.tensor( - [task_encoding.inputs["pooler_start_indices"] for task_encoding in task_encodings] - ).to(torch.long) - inputs["pooler_end_indices"] = torch.tensor( - [task_encoding.inputs["pooler_end_indices"] for task_encoding in task_encodings] - ).to(torch.long) - - if self.add_global_attention_mask_to_input: - inputs["global_attention_mask"] = self._get_global_attention( - input_ids=inputs["input_ids"] - ) - - if not task_encodings[0].has_targets: - return inputs, None - - target_list: List[TargetEncodingType] = [ - task_encoding.targets for task_encoding in task_encodings - ] - targets = torch.tensor(target_list, dtype=torch.int64) - - if not self.multi_label: - targets = targets.flatten() - - return inputs, {"labels": targets} - - def configure_model_metric(self, stage: str) -> MetricCollection: - if self.label_to_id is None: - raise ValueError( - "The taskmodule has not been prepared yet, so label_to_id is not known. " - "Please call taskmodule.prepare(documents) before configuring the model metric " - "or pass the labels to the taskmodule constructor an call taskmodule.post_prepare()." - ) - # we use the length of label_to_id because that contains the none_label (in contrast to labels) - labels = [self.id_to_label[i] for i in range(len(self.label_to_id))] - common_metric_kwargs = { - "num_classes": len(labels), - "task": "multilabel" if self.multi_label else "multiclass", - } - return MetricCollection( - { - "with_tn": WrappedMetricWithPrepareFunction( - metric=MetricCollection( - { - "micro/f1": F1Score(average="micro", **common_metric_kwargs), - "macro/f1": F1Score(average="macro", **common_metric_kwargs), - "f1_per_label": ClasswiseWrapper( - F1Score(average=None, **common_metric_kwargs), - labels=labels, - postfix="/f1", - ), - } - ), - prepare_function=_get_labels, - ), - # We can not easily calculate the macro f1 here, because - # F1Score with average="macro" would still include the none_label. - "micro/f1_without_tn": WrappedMetricWithPrepareFunction( - metric=F1Score(average="micro", **common_metric_kwargs), - prepare_together_function=partial( - _get_labels_together_remove_none_label, - none_idx=self.label_to_id[self.none_label], - ), - ), - } - ) diff --git a/src/pie_modules/taskmodules/text_to_text.py b/src/pie_modules/taskmodules/text_to_text.py deleted file mode 100644 index 1a90a7854..000000000 --- a/src/pie_modules/taskmodules/text_to_text.py +++ /dev/null @@ -1,458 +0,0 @@ -import dataclasses -import logging -from functools import partial -from typing import ( - Any, - Dict, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, -) - -import torch -from pie_core import ( - Annotation, - AnnotationLayer, - Document, - TaskEncoding, - TaskModule, -) -from pie_core.taskmodule import ( - InputEncoding, - ModelBatchOutput, - TargetEncoding, - TaskBatchEncoding, -) -from pie_core.utils.hydra import resolve_type -from torchmetrics import Metric -from transformers import AutoTokenizer, PreTrainedTokenizer -from typing_extensions import TypeAlias - -from pie_modules.annotations import AnnotationWithText -from pie_modules.document.processing import ( - token_based_document_to_text_based, - tokenize_document, -) -from pie_modules.documents import TextBasedDocument, TokenBasedDocument - -from .common import BatchableMixin, get_first_occurrence_index -from .metrics import WrappedMetricWithPrepareFunction - -logger = logging.getLogger(__name__) - - -DocumentType: TypeAlias = TextBasedDocument - - -@dataclasses.dataclass -class InputEncodingType(BatchableMixin): - input_ids: List[int] - attention_mask: List[int] - - -@dataclasses.dataclass -class TargetEncodingType(BatchableMixin): - labels: List[int] - # this is optional because we use the same type for TaskOutputType, which does not have this field - decoder_attention_mask: Optional[List[int]] = None - - -TaskEncodingType: TypeAlias = TaskEncoding[ - DocumentType, - InputEncodingType, - TargetEncodingType, -] -TaskOutputType: TypeAlias = TargetEncodingType - - -# we use a custom un-batch function for metrics, because the text metrics such as ROUGEScore metric expects -# strings for input and target -def unbatch_and_untokenize( - batch: ModelBatchOutput, taskmodule: "TextToTextTaskModule" -) -> Sequence[str]: - unbatched = taskmodule.unbatch_output(batch) - texts = [ - taskmodule.tokenizer.decode(encoding.labels, skip_special_tokens=True) - for encoding in unbatched - ] - return texts - - -@TaskModule.register() -class TextToTextTaskModule( - TaskModule[ - DocumentType, - InputEncoding, - TargetEncoding, - TaskBatchEncoding, - ModelBatchOutput, - TaskOutputType, - ], -): - """A PIE task module for text-to-text tasks. It works with simple text annotations, e.g. - abstractive summaries, as target annotations. - - It can also be used with additional guidance annotations, e.g. questions for generative question answering, in - which case the text of the guidance annotation is prepended to the input text. - - Args: - tokenizer_name_or_path: The name (Huggingface Hub model identifier) or local path of the tokenizer to use. - document_type: The type of the input document. Must be a string that resolves to a subclass of - TextBasedDocument, e.g. "pie_modules.documents.TextDocumentWithAbstractiveSummary" for abstractive - summarization. - tokenized_document_type: The type of the tokenized document. Must be a string that resolves to a - subclass of TokenBasedDocument, e.g. "pie_modules.documents.TokenDocumentWithAbstractiveSummary" for - abstractive summarization. - target_layer: The name of the annotation layer that contains the target annotations, e.g. "abstractive_summary" - for abstractive summarization. - target_annotation_type: The type of the target annotations. Must be a string that resolves to a subclass - of AnnotationWithText, e.g. "pie_modules.annotations.AbstractiveSummary" for abstractive summarization. - guidance_layer: The name of the annotation layer that contains the guidance annotations. If set, the text of - the guidance annotation is prepended to the input text. - guidance_annotation_field: The name of the field in the target annotations that contains the guidance - annotation. Required if guidance_layer is defined to attach the guidance annotation to the newly created - target annotation. - text_metric_type: The type of the text metric to use for evaluation. Must be a string that resolves to a - subclass of Metric, e.g. "torchmetrics.text.ROUGEScore" for ROUGE score. - tokenizer_init_kwargs: Additional keyword arguments that are passed to the tokenizer constructor. - tokenizer_kwargs: Additional keyword arguments that are passed when calling the tokenizer. - partition_layer_name: The name of the annotation layer that contains the partitions. If set, the partitions - will be used to split the input text into multiple parts which are then tokenized separately. This can be - used to split long documents into multiple parts to avoid exceeding the maximum input length of the - tokenizer / model. - annotation_field_mapping: A mapping from input document annotation layer names to layer names defined in the - document_type / tokenized_document_type. This can be used if the actual input documents have different - annotation layer names than the provided document_type / tokenized_document_type. - log_first_n_examples: The number of examples to log. If set to a positive integer n, the first n examples will - be logged. This can be used to check if the input and target encodings are as expected. - """ - - def __init__( - self, - tokenizer_name_or_path: str, - document_type: str, - tokenized_document_type: str, - target_layer: str, - target_annotation_type: str, - guidance_layer: Optional[str] = None, - guidance_annotation_field: Optional[str] = None, - text_metric_type: Optional[str] = None, - tokenizer_init_kwargs: Optional[Dict[str, Any]] = None, - tokenizer_kwargs: Optional[Dict[str, Any]] = None, - partition_layer_name: Optional[str] = None, - annotation_field_mapping: Optional[Dict[str, str]] = None, - log_first_n_examples: Optional[int] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.save_hyperparameters() - - self.target_layer = target_layer - self.guidance_layer = guidance_layer - self.target_annotation_type: Type[AnnotationWithText] = resolve_type( - target_annotation_type, expected_super_type=AnnotationWithText - ) - self.guidance_annotation_field = guidance_annotation_field - self.text_metric_type: Optional[Metric] = None - if text_metric_type is not None: - self.text_metric_type = resolve_type(text_metric_type, expected_super_type=Metric) - - # tokenization - self._document_type: Type[TextBasedDocument] = resolve_type( - document_type, expected_super_type=TextBasedDocument - ) - self._tokenized_document_type: Type[TokenBasedDocument] = resolve_type( - tokenized_document_type, expected_super_type=TokenBasedDocument - ) - self.tokenizer_name_or_path = tokenizer_name_or_path - self.tokenizer_kwargs = tokenizer_kwargs or {} - self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( - tokenizer_name_or_path, - **(tokenizer_init_kwargs or {}), - ) - self.annotation_field_mapping = annotation_field_mapping or dict() - self.partition_layer_name = partition_layer_name - - # target encoding - self.pad_values = { - "input_ids": self.tokenizer.pad_token_id, - "attention_mask": 0, - "labels": self.tokenizer.pad_token_id, - "decoder_attention_mask": 0, - } - self.dtypes = { - "input_ids": torch.int64, - "attention_mask": torch.int64, - "labels": torch.int64, - "decoder_attention_mask": torch.int64, - } - - # logging - self.log_first_n_examples = log_first_n_examples - - @property - def document_type(self) -> Type[TextBasedDocument]: - return self._document_type - - @property - def tokenized_document_type(self) -> Type[TokenBasedDocument]: - return self._tokenized_document_type - - @property - def layer_names(self) -> List[str]: - return [self.target_layer] - - def get_mapped_layer(self, document: Document, layer_name: str) -> AnnotationLayer: - if layer_name in self.annotation_field_mapping: - layer_name = self.annotation_field_mapping[layer_name] - return document[layer_name] - - @property - def generation_config(self) -> Dict[str, Any]: - return {} - - def maybe_log_example( - self, - task_encoding: TaskEncodingType, - targets: Optional[TargetEncodingType] = None, - ) -> None: - if self.log_first_n_examples is not None and self.log_first_n_examples > 0: - inputs = task_encoding.inputs - - logger.info(f"input_ids: {inputs.input_ids}") - logger.info(f"attention_mask: {inputs.attention_mask}") - if targets is not None or task_encoding.has_targets: - targets = targets or task_encoding.targets - logger.info(f"labels: {targets.labels}") - self.log_first_n_examples -= 1 - - def warn_only_once(self, message: str) -> None: - if not hasattr(self, "_warned"): - self._warned: Set[str] = set() - if message not in self._warned: - logger.warning(f"{message} (This warning will only be shown once)") - self._warned.add(message) - - def encode_annotations( - self, - layers: Dict[str, AnnotationLayer], - metadata: Optional[Dict[str, Any]] = None, - ) -> TargetEncodingType: - target_annotations = [] - guidance_annotation = ( - metadata.get("guidance_annotation", None) if metadata is not None else None - ) - if guidance_annotation is not None: - if self.guidance_annotation_field is None: - raise ValueError( - "guidance_annotation is available, but guidance_annotation_field is not set" - ) - # filter annotations that belong to the guidance_annotation - for target_annotation in layers[self.target_layer]: - current_guidance_annotation = getattr( - target_annotation, self.guidance_annotation_field - ) - if current_guidance_annotation == guidance_annotation: - target_annotations.append(target_annotation) - else: - target_annotations = layers[self.target_layer] - - if len(target_annotations) == 0: - raise ValueError(f"target_annotations {self.target_layer} contains no annotation") - elif len(target_annotations) > 1: - self.warn_only_once( - f"target_annotations {self.target_layer} contains more than one annotation, " - f"but only the first one will be used" - ) - annotation = target_annotations[0] - if isinstance(annotation, self.target_annotation_type): - text = target_annotations[0].text - else: - raise ValueError( - f"target_annotations {self.target_layer} contains an annotation of type {type(annotation)}, " - f"but expected {self.target_annotation_type}" - ) - encoding = self.tokenizer(text) - return TargetEncodingType( - labels=encoding["input_ids"], decoder_attention_mask=encoding["attention_mask"] - ) - - def decode_annotations( - self, encoding: TaskOutputType, metadata: Optional[Dict[str, Any]] = None - ) -> Tuple[Dict[str, List[Annotation]], Any]: - text = self.tokenizer.decode(encoding.labels, skip_special_tokens=True) - annotation_kwargs = {} - if self.guidance_annotation_field is not None: - if metadata is None: - raise ValueError( - "metadata is required to decode annotations with guidance_annotation_field" - ) - guidance_annotation = metadata.get("guidance_annotation", None) - if guidance_annotation is not None: - if self.guidance_annotation_field is None: - raise ValueError( - "guidance_annotation is available, but guidance_annotation_field is not set" - ) - annotation_kwargs[self.guidance_annotation_field] = guidance_annotation - - decoded_layers = { - self.target_layer: [self.target_annotation_type(text=text, **annotation_kwargs)] - } - # no error collection yet - errors: Dict[str, Any] = {} - return decoded_layers, errors - - def tokenize_document( - self, document: DocumentType, source_text: Optional[str] = None - ) -> List[TokenBasedDocument]: - field_mapping = dict(self.annotation_field_mapping) - if self.partition_layer_name is not None: - field_mapping[self.partition_layer_name] = "labeled_partitions" - partition_layer = "labeled_partitions" - else: - partition_layer = None - casted_document = document.as_type(self.document_type, field_mapping=field_mapping) - - tokenizer_kwargs = dict(self.tokenizer_kwargs) - if source_text is not None: - tokenizer_kwargs["text"] = source_text - tokenized_docs = tokenize_document( - casted_document, - tokenizer=self.tokenizer, - result_document_type=self.tokenized_document_type, - partition_layer=partition_layer, - **tokenizer_kwargs, - ) - for idx, tokenized_doc in enumerate(tokenized_docs): - tokenized_doc.id = f"{document.id}-tokenized-{idx+1}-of-{len(tokenized_docs)}" - - return tokenized_docs - - def encode_input( - self, document: DocumentType, is_training: bool = False - ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]: - task_encodings: List[TaskEncodingType] = [] - if self.guidance_layer is None: - guidance_annotations = [None] - else: - guidance_annotations = document[self.guidance_layer] - for guidance_annotation in guidance_annotations: - source_text = None - if guidance_annotation is not None: - # Here could also more sophisticated logic be implemented - source_text = guidance_annotation.text - tokenized_docs = self.tokenize_document(document, source_text=source_text) - for tokenized_doc in tokenized_docs: - tokenizer_encoding = tokenized_doc.metadata["tokenizer_encoding"] - task_encodings.append( - TaskEncoding( - document=document, - inputs=InputEncodingType( - input_ids=tokenizer_encoding.ids, - attention_mask=tokenizer_encoding.attention_mask, - ), - metadata={ - "tokenized_document": tokenized_doc, - "guidance_annotation": guidance_annotation, - }, - ) - ) - - return task_encodings - - def encode_target(self, task_encoding: TaskEncodingType) -> Optional[TargetEncodingType]: - document = task_encoding.metadata["tokenized_document"] - guidance_annotation = task_encoding.metadata["guidance_annotation"] - - layers = { - layer_name: self.get_mapped_layer(document, layer_name=layer_name) - for layer_name in self.layer_names - } - result = self.encode_annotations( - layers=layers, - metadata={**task_encoding.metadata, "guidance_annotation": guidance_annotation}, - ) - - self.maybe_log_example(task_encoding=task_encoding, targets=result) - return result - - def collate(self, task_encodings: Sequence[TaskEncodingType]) -> TaskBatchEncoding: - if len(task_encodings) == 0: - raise ValueError("no task_encodings available") - inputs = InputEncodingType.batch( - values=[x.inputs for x in task_encodings], - dtypes=self.dtypes, - pad_values=self.pad_values, - ) - - targets = None - if task_encodings[0].has_targets: - targets = TargetEncodingType.batch( - values=[x.targets for x in task_encodings], - dtypes=self.dtypes, - pad_values=self.pad_values, - ) - - return inputs, targets - - def unbatch_output(self, model_output: ModelBatchOutput) -> Sequence[TaskOutputType]: - labels = model_output["labels"] - batch_size = labels.size(0) - - # We use the position after the first eos token as the seq_len. - # Note that, if eos_id is not in model_output for a given batch item, the result will be - # model_output.size(1) + 1 (i.e. seq_len + 1) for that batch item. This is fine, because we use the - # seq_lengths just to truncate the output and want to keep everything if eos_id is not present. - seq_lengths = get_first_occurrence_index(labels, self.tokenizer.eos_token_id) + 1 - - result = [ - TaskOutputType(labels[i, : seq_lengths[i]].to(device="cpu").tolist()) - for i in range(batch_size) - ] - return result - - def create_annotations_from_output( - self, - task_encoding: TaskEncodingType, - task_output: TaskOutputType, - ) -> Iterator[Tuple[str, Annotation]]: - layers, errors = self.decode_annotations( - encoding=task_output, metadata=task_encoding.metadata - ) - tokenized_document = task_encoding.metadata["tokenized_document"] - - # Note: token_based_document_to_text_based() does not yet consider predictions, so we need to clear - # the main annotations and attach the predictions to that - for layer_name, annotations in layers.items(): - layer = self.get_mapped_layer(tokenized_document, layer_name=layer_name) - layer.clear() - layer.extend(annotations) - - untokenized_document = token_based_document_to_text_based( - tokenized_document, result_document_type=self.document_type - ) - - for layer_name in layers: - annotations = self.get_mapped_layer(untokenized_document, layer_name=layer_name) - for annotation in annotations: - yield layer_name, annotation.copy() - - def configure_model_generation(self) -> Optional[Dict[str, Any]]: - # we do not set any overrides here, because we want to use the default generation config as - # it is derived from the Huggingface base model config.json - return {} - - def configure_model_metric(self, stage: str) -> Optional[Metric]: - if self.text_metric_type is None: - return None - - return WrappedMetricWithPrepareFunction( - metric=self.text_metric_type(), - prepare_function=partial(unbatch_and_untokenize, taskmodule=self), - prepare_does_unbatch=True, - ) diff --git a/src/pie_modules/utils/__init__.py b/src/pie_modules/utils/__init__.py index c8017711c..e69de29bb 100644 --- a/src/pie_modules/utils/__init__.py +++ b/src/pie_modules/utils/__init__.py @@ -1,3 +0,0 @@ -# backwards compatibility -from .dictionary import flatten_dict, list_of_dicts2dict_of_lists -from .hydra import resolve_type diff --git a/src/pie_modules/utils/dictionary.py b/src/pie_modules/utils/dictionary.py deleted file mode 100644 index 34cb20be0..000000000 --- a/src/pie_modules/utils/dictionary.py +++ /dev/null @@ -1,3 +0,0 @@ -# backwards compatibility -from pie_core.utils.dictionary import flatten_dict_s as flatten_dict -from pie_core.utils.dictionary import list_of_dicts2dict_of_lists diff --git a/src/pie_modules/utils/hydra.py b/src/pie_modules/utils/hydra.py deleted file mode 100644 index 21e55eeb9..000000000 --- a/src/pie_modules/utils/hydra.py +++ /dev/null @@ -1,2 +0,0 @@ -# backwards compatibility -from pie_core.utils.hydra import resolve_target, resolve_type diff --git a/src/pie_modules/utils/tokenization.py b/src/pie_modules/utils/tokenization.py deleted file mode 100644 index 85e89d448..000000000 --- a/src/pie_modules/utils/tokenization.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import TypeVar - -from transformers import BatchEncoding - -from pie_modules.annotations import Span - -S = TypeVar("S", bound=Span) - - -class SpanNotAlignedWithTokenException(Exception): - def __init__(self, span): - self.span = span - - -def get_aligned_token_span(encoding: BatchEncoding, char_span: S) -> S: - # find the start - token_start = None - token_end_before = None - char_start = None - for idx in range(char_span.start, char_span.end): - token_start = encoding.char_to_token(idx) - if token_start is not None: - char_start = idx - break - - if char_start is None: - raise SpanNotAlignedWithTokenException(span=char_span) - for idx in range(char_span.end - 1, char_start - 1, -1): - token_end_before = encoding.char_to_token(idx) - if token_end_before is not None: - break - - if token_start is None or token_end_before is None: - raise SpanNotAlignedWithTokenException(span=char_span) - - return char_span.copy(start=token_start, end=token_end_before + 1) diff --git a/tests/document/processing/test_relation_argument_sorter.py b/tests/document/processing/test_relation_argument_sorter.py index fec8e1200..2068384ea 100644 --- a/tests/document/processing/test_relation_argument_sorter.py +++ b/tests/document/processing/test_relation_argument_sorter.py @@ -104,7 +104,7 @@ def test_get_args_wrong_type(document_with_nary_relation): == "relation NaryRelation(arguments=(LabeledSpan(start=0, end=8, label='PER', score=1.0), " "LabeledSpan(start=18, end=19, label='ORG', score=1.0), LabeledSpan(start=33, end=34, " "label='ORG', score=1.0)), roles=('person', 'worksAt', 'founded'), label='event', score=1.0) " - "has unknown type [], cannot get arguments from it" + "has unknown type [], cannot get arguments from it" ) @@ -122,7 +122,7 @@ def test_construct_relation_with_new_args_wrong_type(document_with_nary_relation == "original relation NaryRelation(arguments=(LabeledSpan(start=0, end=8, label='PER', score=1.0), " "LabeledSpan(start=18, end=19, label='ORG', score=1.0), LabeledSpan(start=33, end=34, label='ORG', " "score=1.0)), roles=('person', 'worksAt', 'founded'), label='event', score=1.0) has unknown type " - "[], cannot reconstruct it with new arguments" + "[], cannot reconstruct it with new arguments" ) diff --git a/tests/document/processing/test_tokenization.py b/tests/document/processing/test_tokenization.py index c940be177..c6fdacd17 100644 --- a/tests/document/processing/test_tokenization.py +++ b/tests/document/processing/test_tokenization.py @@ -4,7 +4,6 @@ import pytest from pie_core import Annotation, AnnotationLayer, Document, annotation_field -from transformers import AutoTokenizer, PreTrainedTokenizer from pie_modules.annotations import ( BinaryRelation, @@ -16,7 +15,6 @@ from pie_modules.document.processing import ( text_based_document_to_token_based, token_based_document_to_text_based, - tokenize_document, ) from pie_modules.document.processing.tokenization import find_token_offset_mapping from pie_modules.documents import TextBasedDocument, TokenBasedDocument @@ -229,11 +227,6 @@ def _test_token_document_with_multi_spans(doc): ] -@pytest.fixture(scope="module") -def tokenizer() -> PreTrainedTokenizer: - return AutoTokenizer.from_pretrained("bert-base-cased") - - def test_find_token_offset_mapping(text_document, token_document): token_offset_mapping = find_token_offset_mapping( text=text_document.text, tokens=list(token_document.tokens) @@ -559,7 +552,7 @@ class WrongAnnotationType(TextBasedDocument): assert ( str(excinfo.value) == "can not convert layers that target the text but contain non-span annotations, " - "but found " + "but found " ) @@ -650,370 +643,5 @@ class WrongAnnotationType(TokenBasedDocument): assert ( str(excinfo.value) == "can not convert layers that target the tokens but contain non-span annotations, " - "but found " - ) - - -def test_tokenize_document(text_document, tokenizer): - added_annotations = [] - tokenized_docs = tokenize_document( - text_document, - tokenizer=tokenizer, - result_document_type=TokenizedTestDocument, - added_annotations=added_annotations, - ) - assert len(tokenized_docs) == 1 - tokenized_doc = tokenized_docs[0] - - # check (de-)serialization - tokenized_doc.copy() - - assert ( - tokenized_doc.metadata["text"] - == text_document.text - == "First sentence. Entity M works at N. And it founded O." - ) - assert tokenized_doc.tokens == ( - "[CLS]", - "First", - "sentence", - ".", - "En", - "##ti", - "##ty", - "M", - "works", - "at", - "N", - ".", - "And", - "it", - "founded", - "O", - ".", - "[SEP]", - ) - assert len(tokenized_doc.sentences) == len(text_document.sentences) == 3 - sentences = [str(sentence) for sentence in tokenized_doc.sentences] - assert sentences == [ - "('First', 'sentence', '.')", - "('En', '##ti', '##ty', 'M', 'works', 'at', 'N', '.')", - "('And', 'it', 'founded', 'O', '.')", - ] - assert len(tokenized_doc.entities) == len(text_document.entities) == 4 - entities = [str(entity) for entity in tokenized_doc.entities] - assert entities == ["('En', '##ti', '##ty', 'M')", "('N',)", "('it',)", "('O',)"] - assert len(tokenized_doc.relations) == len(text_document.relations) == 2 - relation_tuples = [ - (str(rel.head), rel.label, str(rel.tail)) for rel in tokenized_doc.relations - ] - assert relation_tuples == [ - ("('En', '##ti', '##ty', 'M')", "per:employee_of", "('N',)"), - ("('it',)", "per:founder", "('O',)"), - ] - - assert len(added_annotations) == 1 - first_added_annotations = added_annotations[0] - _assert_added_annotations(text_document, tokenized_doc, first_added_annotations) - - -def test_tokenize_document_max_length(text_document, tokenizer, caplog): - added_annotations = [] - caplog.clear() - with caplog.at_level("WARNING"): - tokenized_docs = tokenize_document( - text_document, - tokenizer=tokenizer, - result_document_type=TokenizedTestDocument, - # max_length is set to 10, so the document is split into two parts - strict_span_conversion=False, - max_length=10, - return_overflowing_tokens=True, - added_annotations=added_annotations, - ) - assert len(caplog.records) == 1 - assert ( - caplog.records[0].message - == "could not convert all annotations from document with id=None to token based documents, missed annotations " - "(disable this message with verbose=False):\n" - "{\n" - ' "relations": "{BinaryRelation(head=LabeledSpan(start=16, end=24, label=\'per\', score=1.0), ' - "tail=LabeledSpan(start=34, end=35, label='org', score=1.0), label='per:employee_of', score=1.0)}\",\n" - ' "sentences": "{Span(start=16, end=36)}"\n' - "}" - ) - assert len(tokenized_docs) == 2 - assert len(added_annotations) == 2 - tokenized_doc = tokenized_docs[0] - - # check (de-)serialization - tokenized_doc.copy() - - assert ( - tokenized_doc.metadata["text"] - == text_document.text - == "First sentence. Entity M works at N. And it founded O." - ) - assert tokenized_doc.tokens == ( - "[CLS]", - "First", - "sentence", - ".", - "En", - "##ti", - "##ty", - "M", - "works", - "[SEP]", - ) - assert len(tokenized_doc.sentences) == 1 - sentences = [str(sentence) for sentence in tokenized_doc.sentences] - assert sentences == ["('First', 'sentence', '.')"] - assert len(tokenized_doc.entities) == 1 - entities = [str(entity) for entity in tokenized_doc.entities] - assert entities == ["('En', '##ti', '##ty', 'M')"] - assert len(tokenized_doc.relations) == 0 - # check annotation mapping - current_added_annotations = added_annotations[0] - # no relations are added in the first tokenized document - assert set(current_added_annotations) == {"sentences", "entities"} - # check sentences - sentence_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["sentences"].items() - } - assert sentence_mapping == {"First sentence.": ("First", "sentence", ".")} - # check entities - entity_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["entities"].items() - } - assert entity_mapping == {("per", "Entity M"): ("per", ("En", "##ti", "##ty", "M"))} - - tokenized_doc = tokenized_docs[1] - - # check (de-)serialization - tokenized_doc.copy() - - assert ( - tokenized_doc.metadata["text"] - == text_document.text - == "First sentence. Entity M works at N. And it founded O." - ) - assert tokenized_doc.tokens == ( - "[CLS]", - "at", - "N", - ".", - "And", - "it", - "founded", - "O", - ".", - "[SEP]", - ) - assert len(tokenized_doc.sentences) == 1 - sentences = [str(sentence) for sentence in tokenized_doc.sentences] - assert sentences == ["('And', 'it', 'founded', 'O', '.')"] - assert len(tokenized_doc.entities) == 3 - entities = [str(entity) for entity in tokenized_doc.entities] - assert entities == ["('N',)", "('it',)", "('O',)"] - assert len(tokenized_doc.relations) == 1 - relation_tuples = [ - (str(rel.head), rel.label, str(rel.tail)) for rel in tokenized_doc.relations - ] - assert relation_tuples == [("('it',)", "per:founder", "('O',)")] - # check annotation mapping - current_added_annotations = added_annotations[1] - assert set(current_added_annotations) == {"sentences", "entities", "relations"} - # check sentences - sentence_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["sentences"].items() - } - assert sentence_mapping == {"And it founded O.": ("And", "it", "founded", "O", ".")} - # check entities - entity_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["entities"].items() - } - assert entity_mapping == { - ("org", "N"): ("org", ("N",)), - ("per", "it"): ("per", ("it",)), - ("org", "O"): ("org", ("O",)), - } - # check relations - relation_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["relations"].items() - } - assert relation_mapping == { - ("per:founder", (("per", "it"), ("org", "O"))): ( - "per:founder", - (("per", ("it",)), ("org", ("O",))), - ) - } - - -def test_tokenize_document_max_length_strict(text_document, tokenizer): - with pytest.raises(ValueError) as excinfo: - tokenize_document( - text_document, - tokenizer=tokenizer, - result_document_type=TokenizedTestDocument, - # max_length is set to 10, so the document is split into two parts - strict_span_conversion=True, - max_length=10, - return_overflowing_tokens=True, - ) - assert ( - str(excinfo.value) - == "could not convert all annotations from document with id=None to token based documents, " - "but strict_span_conversion is True, so raise an error, missed annotations:\n" - "{\n" - ' "relations": "{BinaryRelation(head=LabeledSpan(start=16, end=24, label=\'per\', score=1.0), ' - "tail=LabeledSpan(start=34, end=35, label='org', score=1.0), label='per:employee_of', score=1.0)}\",\n" - ' "sentences": "{Span(start=16, end=36)}"\n' - "}" - ) - - -def test_tokenize_document_partition(text_document, tokenizer): - added_annotations = [] - tokenized_docs = tokenize_document( - text_document, - tokenizer=tokenizer, - result_document_type=TokenizedTestDocument, - partition_layer="sentences", - added_annotations=added_annotations, - ) - assert len(tokenized_docs) == 3 - assert len(added_annotations) == 3 - tokenized_doc = tokenized_docs[0] - - # check (de-)serialization - tokenized_doc.copy() - - assert ( - tokenized_doc.metadata["text"] - == text_document.text - == "First sentence. Entity M works at N. And it founded O." + "but found " ) - assert tokenized_doc.tokens == ("[CLS]", "First", "sentence", ".", "[SEP]") - assert len(tokenized_doc.sentences) == 1 - assert len(tokenized_doc.entities) == 0 - assert len(tokenized_doc.relations) == 0 - - # check annotation mapping - current_added_annotations = added_annotations[0] - assert set(current_added_annotations) == {"sentences"} - # check sentences - sentence_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["sentences"].items() - } - assert sentence_mapping == {"First sentence.": ("First", "sentence", ".")} - - tokenized_doc = tokenized_docs[1] - - # check (de-)serialization - tokenized_doc.copy() - - assert ( - tokenized_doc.metadata["text"] - == text_document.text - == "First sentence. Entity M works at N. And it founded O." - ) - assert tokenized_doc.tokens == ( - "[CLS]", - "En", - "##ti", - "##ty", - "M", - "works", - "at", - "N", - ".", - "[SEP]", - ) - assert len(tokenized_doc.sentences) == 1 - sentences = [str(sentence) for sentence in tokenized_doc.sentences] - assert sentences == ["('En', '##ti', '##ty', 'M', 'works', 'at', 'N', '.')"] - assert len(tokenized_doc.entities) == 2 - entities = [str(entity) for entity in tokenized_doc.entities] - assert entities == ["('En', '##ti', '##ty', 'M')", "('N',)"] - assert len(tokenized_doc.relations) == 1 - relation_tuples = [ - (str(rel.head), rel.label, str(rel.tail)) for rel in tokenized_doc.relations - ] - assert relation_tuples == [("('En', '##ti', '##ty', 'M')", "per:employee_of", "('N',)")] - - # check annotation mapping - current_added_annotations = added_annotations[1] - assert set(current_added_annotations) == {"sentences", "entities", "relations"} - # check sentences - sentence_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["sentences"].items() - } - assert sentence_mapping == { - "Entity M works at N.": ("En", "##ti", "##ty", "M", "works", "at", "N", ".") - } - # check entities - entity_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["entities"].items() - } - assert entity_mapping == { - ("per", "Entity M"): ("per", ("En", "##ti", "##ty", "M")), - ("org", "N"): ("org", ("N",)), - } - # check relations - relation_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["relations"].items() - } - assert relation_mapping == { - ("per:employee_of", (("per", "Entity M"), ("org", "N"))): ( - "per:employee_of", - (("per", ("En", "##ti", "##ty", "M")), ("org", ("N",))), - ) - } - - tokenized_doc = tokenized_docs[2] - - # check (de-)serialization - tokenized_doc.copy() - - assert ( - tokenized_doc.metadata["text"] - == text_document.text - == "First sentence. Entity M works at N. And it founded O." - ) - assert tokenized_doc.tokens == ("[CLS]", "And", "it", "founded", "O", ".", "[SEP]") - assert len(tokenized_doc.sentences) == 1 - sentences = [str(sentence) for sentence in tokenized_doc.sentences] - assert sentences == ["('And', 'it', 'founded', 'O', '.')"] - assert len(tokenized_doc.entities) == 2 - entities = [str(entity) for entity in tokenized_doc.entities] - assert entities == ["('it',)", "('O',)"] - assert len(tokenized_doc.relations) == 1 - relation_tuples = [ - (str(rel.head), rel.label, str(rel.tail)) for rel in tokenized_doc.relations - ] - assert relation_tuples == [("('it',)", "per:founder", "('O',)")] - - # check annotation mapping - current_added_annotations = added_annotations[2] - assert set(current_added_annotations) == {"sentences", "entities", "relations"} - # check sentences - sentence_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["sentences"].items() - } - assert sentence_mapping == {"And it founded O.": ("And", "it", "founded", "O", ".")} - # check entities - entity_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["entities"].items() - } - assert entity_mapping == {("per", "it"): ("per", ("it",)), ("org", "O"): ("org", ("O",))} - # check relations - relation_mapping = { - k.resolve(): v.resolve() for k, v in current_added_annotations["relations"].items() - } - assert relation_mapping == { - ("per:founder", (("per", "it"), ("org", "O"))): ( - "per:founder", - (("per", ("it",)), ("org", ("O",))), - ) - } diff --git a/tests/metrics/test_relation_argument_distance_collector.py b/tests/metrics/test_relation_argument_distance_collector.py index f51d6e819..2dc05ce90 100644 --- a/tests/metrics/test_relation_argument_distance_collector.py +++ b/tests/metrics/test_relation_argument_distance_collector.py @@ -82,107 +82,7 @@ def test_relation_argument_distance_collector_with_n_ary_relation(): } -def test_relation_argument_distance_collector_with_tokenize(): - doc = TestDocument( - text="This is the first entity. This is the second entity. And, this is the third entity." - ) - - doc.entities.append(LabeledSpan(start=0, end=25, label="entity")) - doc.entities.append(LabeledSpan(start=26, end=52, label="entity")) - doc.entities.append(LabeledSpan(start=53, end=83, label="entity")) - doc.relations.append( - BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="relation_label_1") - ) - doc.relations.append( - BinaryRelation(head=doc.entities[1], tail=doc.entities[2], label="relation_label_2") - ) - - @dataclasses.dataclass - class TokenizedTestDocument(TokenBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens") - relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") - - statistic = RelationArgumentDistanceCollector( - layer="relations", - tokenize=True, - tokenizer="bert-base-uncased", - tokenized_document_type=TokenizedTestDocument, - ) - values = statistic(doc) - assert values == { - "ALL": {"len": 4, "mean": 13.0, "std": 1.0, "min": 12.0, "max": 14.0}, - "relation_label_1": {"len": 2, "mean": 12.0, "std": 0.0, "min": 12.0, "max": 12.0}, - "relation_label_2": {"len": 2, "mean": 14.0, "std": 0.0, "min": 14.0, "max": 14.0}, - } - - -def test_relation_argument_distance_collector_with_tokenize_missing_tokenizer(): - with pytest.raises(ValueError) as excinfo: - RelationArgumentDistanceCollector( - layer="relations", - tokenize=True, - tokenized_document_type=TokenBasedDocument, - ) - assert ( - str(excinfo.value) == "tokenizer must be provided to calculate distance in means of tokens" - ) - - -def test_relation_argument_distance_collector_with_tokenize_missing_tokenized_document_type(): - with pytest.raises(ValueError) as excinfo: - RelationArgumentDistanceCollector( - layer="relations", - tokenize=True, - tokenizer="bert-base-uncased", - ) - assert ( - str(excinfo.value) - == "tokenized_document_type must be provided to calculate distance in means of tokens" - ) - - -def test_relation_argument_distance_collector_with_tokenize_wrong_document_type(): - @dataclasses.dataclass - class TestDocument(Document): - data: str - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="data") - relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") - - doc = TestDocument( - data="This is the first entity. This is the second entity. This is the third entity." - ) - - doc.entities.append(LabeledSpan(start=0, end=25, label="entity")) - doc.entities.append(LabeledSpan(start=26, end=52, label="entity")) - doc.entities.append(LabeledSpan(start=53, end=78, label="entity")) - doc.relations.append( - BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="relation_label_1") - ) - doc.relations.append( - BinaryRelation(head=doc.entities[1], tail=doc.entities[2], label="relation_label_2") - ) - - @dataclasses.dataclass - class TokenizedTestDocument(TokenBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens") - relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") - - statistic = RelationArgumentDistanceCollector( - layer="relations", - tokenize=True, - tokenizer="bert-base-uncased", - tokenized_document_type=TokenizedTestDocument, - ) - - with pytest.raises(ValueError) as excinfo: - statistic(doc) - assert ( - str(excinfo.value) - == "doc must be a TextBasedDocument to calculate distance in means of tokens" - ) - - -def test_relation_argument_distance_collector_with_tokenize_wrong_span_annotation_type(): +def test_relation_argument_distance_collector_with_wrong_span_annotation_type(): @dataclasses.dataclass(eq=True, frozen=True) class UnknownSpan(Annotation): start: int @@ -216,7 +116,7 @@ class TestDocument(TextBasedDocument): ) -def test_relation_argument_distance_collector_with_tokenize_wrong_relation_annotation_type(): +def test_relation_argument_distance_collector_with_wrong_relation_annotation_type(): @dataclasses.dataclass(eq=True, frozen=True) class UnknownRelation(Annotation): head: Annotation diff --git a/tests/metrics/test_span_coverage_collector.py b/tests/metrics/test_span_coverage_collector.py index 026ec28b7..109352df3 100644 --- a/tests/metrics/test_span_coverage_collector.py +++ b/tests/metrics/test_span_coverage_collector.py @@ -1,10 +1,10 @@ import dataclasses import pytest -from pie_core import Annotation, AnnotationLayer, Document, annotation_field +from pie_core import Annotation, AnnotationLayer, annotation_field from pie_modules.annotations import LabeledMultiSpan, LabeledSpan -from pie_modules.documents import TextBasedDocument, TokenBasedDocument +from pie_modules.documents import TextBasedDocument from pie_modules.metrics import SpanCoverageCollector @@ -53,85 +53,7 @@ def test_span_coverage_collector_with_labels(): assert values == {"len": 1, "max": 0.125, "mean": 0.125, "min": 0.125, "std": 0.0} -def test_span_coverage_collector_with_tokenize(): - doc = TestDocument(text="A and O.") - doc.entities.append(LabeledSpan(start=0, end=1, label="entity")) - doc.entities.append(LabeledSpan(start=6, end=7, label="entity")) - - @dataclasses.dataclass - class TokenizedTestDocument(TokenBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens") - - statistic = SpanCoverageCollector( - layer="entities", - tokenize=True, - tokenizer="bert-base-uncased", - tokenized_document_type=TokenizedTestDocument, - ) - values = statistic(doc) - assert values == { - "len": 1, - "max": 0.3333333333333333, - "mean": 0.3333333333333333, - "min": 0.3333333333333333, - "std": 0.0, - } - - -def test_span_coverage_collector_with_tokenize_missing_tokenizer(): - with pytest.raises(ValueError) as excinfo: - SpanCoverageCollector( - layer="entities", - tokenize=True, - tokenized_document_type=TokenBasedDocument, - ) - assert ( - str(excinfo.value) - == "tokenizer must be provided to calculate the span coverage in means of tokens" - ) - - -def test_span_coverage_collector_with_tokenize_missing_tokenized_document_type(): - with pytest.raises(ValueError) as excinfo: - SpanCoverageCollector( - layer="entities", - tokenize=True, - tokenizer="bert-base-uncased", - ) - assert ( - str(excinfo.value) - == "tokenized_document_type must be provided to calculate the span coverage in means of tokens" - ) - - -def test_span_coverage_collector_with_tokenize_wrong_document_type(): - @dataclasses.dataclass - class TestDocument(Document): - data: str - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="data") - - doc = TestDocument(data="A and O") - - @dataclasses.dataclass - class TokenizedTestDocument(TokenBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens") - - statistic = SpanCoverageCollector( - layer="entities", - tokenize=True, - tokenizer="bert-base-uncased", - tokenized_document_type=TokenizedTestDocument, - ) - - with pytest.raises(ValueError) as excinfo: - statistic(doc) - assert ( - str(excinfo.value) - == "doc must be a TextBasedDocument to calculate the span coverage in means of tokens" - ) - - -def test_span_coverage_collector_with_tokenize_wrong_annotation_type(): +def test_span_coverage_collector_with_wrong_annotation_type(): @dataclasses.dataclass(eq=True, frozen=True) class UnknownSpan(Annotation): start: int diff --git a/tests/metrics/test_span_length_collector.py b/tests/metrics/test_span_length_collector.py index dcd69af01..f482b1066 100644 --- a/tests/metrics/test_span_length_collector.py +++ b/tests/metrics/test_span_length_collector.py @@ -78,83 +78,7 @@ def test_span_length_collector_wrong_label_value(): assert str(excinfo.value) == "labels must be a list of strings or 'INFERRED'" -def test_span_length_collector_with_tokenize(documents): - @dataclasses.dataclass - class TokenizedTestDocument(TokenBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens") - - statistic = SpanLengthCollector( - layer="entities", - tokenize=True, - tokenizer="bert-base-uncased", - tokenized_document_type=TokenizedTestDocument, - ) - values = statistic(documents) - assert values == { - "len": 7, - "max": 3, - "mean": 1.8571428571428572, - "min": 1, - "std": 0.8329931278350429, - } - - -def test_span_length_collector_with_tokenize_missing_tokenizer(): - with pytest.raises(ValueError) as excinfo: - SpanLengthCollector( - layer="entities", - tokenize=True, - tokenized_document_type=TokenBasedDocument, - ) - assert ( - str(excinfo.value) - == "tokenizer must be provided to calculate the span length in means of tokens" - ) - - -def test_span_length_collector_with_tokenize_missing_tokenized_document_type(): - with pytest.raises(ValueError) as excinfo: - SpanLengthCollector( - layer="entities", - tokenize=True, - tokenizer="bert-base-uncased", - ) - assert ( - str(excinfo.value) - == "tokenized_document_type must be provided to calculate the span length in means of tokens" - ) - - -def test_span_length_collector_with_tokenize_wrong_document_type(): - @dataclasses.dataclass - class TestDocument(Document): - data: str - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="data") - - doc = TestDocument(data="First sentence. Entity M works at N. And it founded O.") - doc.entities.append(LabeledSpan(start=16, end=24, label="per")) - assert str(doc.entities[0]) == "Entity M" - - @dataclasses.dataclass - class TokenizedTestDocument(TokenBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens") - - statistic = SpanLengthCollector( - layer="entities", - tokenize=True, - tokenizer="bert-base-uncased", - tokenized_document_type=TokenizedTestDocument, - ) - - with pytest.raises(ValueError) as excinfo: - statistic(doc) - assert ( - str(excinfo.value) - == "doc must be a TextBasedDocument to calculate the span length in means of tokens" - ) - - -def test_span_length_collector_with_tokenize_wrong_annotation_type(): +def test_span_length_collector_with_wrong_annotation_type(): @dataclasses.dataclass class TestDocument(TextBasedDocument): label: AnnotationLayer[Label] = annotation_field() @@ -168,5 +92,5 @@ class TestDocument(TextBasedDocument): statistic(doc) assert ( str(excinfo.value) - == "span length calculation is not yet supported for " + == "span length calculation is not yet supported for " ) diff --git a/tests/metrics/test_statistics.py b/tests/metrics/test_statistics.py index 4d1325795..951dbd57c 100644 --- a/tests/metrics/test_statistics.py +++ b/tests/metrics/test_statistics.py @@ -3,7 +3,6 @@ FieldLengthCollector, LabelCountCollector, SubFieldLengthCollector, - TokenCountCollector, ) @@ -71,17 +70,3 @@ def test_statistics(document_dataset): "val": {"mean": 3.0, "std": 0.0, "min": 3, "max": 3}, "train": {"mean": 3.0, "std": 0.0, "min": 3, "max": 3}, } - - -def test_statistics_with_tokenize(document_dataset): - statistic = TokenCountCollector( - text_field="text", - tokenizer="bert-base-uncased", - tokenizer_kwargs=dict(add_special_tokens=False), - ) - values = statistic(document_dataset) - assert values == { - "test": {"max": 13, "mean": 8.5, "min": 4, "std": 4.5}, - "train": {"max": 14, "mean": 8.285714285714286, "min": 4, "std": 3.5742845723419436}, - "val": {"max": 13, "mean": 8.5, "min": 4, "std": 4.5}, - } diff --git a/tests/models/__init__.py b/tests/models/__init__.py deleted file mode 100644 index 1e3e58002..000000000 --- a/tests/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -def trunc_number(x: float, n: int) -> float: - return int(x * 10**n) / 10**n diff --git a/tests/models/base_models/__init__.py b/tests/models/base_models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/models/base_models/test_bart_as_pointer_network.py b/tests/models/base_models/test_bart_as_pointer_network.py deleted file mode 100644 index 554e52108..000000000 --- a/tests/models/base_models/test_bart_as_pointer_network.py +++ /dev/null @@ -1,983 +0,0 @@ -import pytest -import torch -from transformers import ( - BartModel, - BeamSearchScorer, - LogitsProcessorList, - MinLengthLogitsProcessor, -) -from transformers.generation import BeamSearchEncoderDecoderOutput - -from pie_modules.models.base_models import ( - BartAsPointerNetwork, - BartModelWithDecoderPositionIds, -) -from tests import _config_to_str -from tests.models import trunc_number - -# this is a small model that can be used for testing -MODEL_NAME_OR_PATH = "sshleifer/bart-tiny-random" -DECODER_POSITION_ID_PATTERN = [0, 0, 1, 0, 0, 1, 1] -CONFIGS = [{}, {"decoder_position_id_mode": "pattern"}] -CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} - - -@pytest.fixture(scope="module", params=CONFIG_DICT.keys()) -def config_str(request): - return request.param - - -@pytest.fixture(scope="module") -def config(config_str): - return CONFIG_DICT[config_str] - - -@pytest.fixture(scope="module") -def document(): - from pie_modules.annotations import BinaryRelation, LabeledSpan - from pie_modules.documents import ( - TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, - ) - - doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions( - text="This is a dummy text about nothing. Trust me." - ) - span1 = LabeledSpan(start=10, end=20, label="content") - span2 = LabeledSpan(start=27, end=34, label="topic") - span3 = LabeledSpan(start=42, end=44, label="person") - doc.labeled_spans.extend([span1, span2, span3]) - assert str(span1) == "dummy text" - assert str(span2) == "nothing" - assert str(span3) == "me" - rel = BinaryRelation(head=span1, tail=span2, label="is_about") - doc.binary_relations.append(rel) - assert str(rel.label) == "is_about" - assert str(rel.head) == "dummy text" - assert str(rel.tail) == "nothing" - - sent1 = LabeledSpan(start=0, end=35, label="1") - sent2 = LabeledSpan(start=36, end=45, label="2") - doc.labeled_partitions.extend([sent1, sent2]) - assert str(sent1) == "This is a dummy text about nothing." - assert str(sent2) == "Trust me." - return doc - - -@pytest.fixture(scope="module") -def taskmodule(document): - from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE - - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path=MODEL_NAME_OR_PATH, - partition_layer_name="labeled_partitions", - create_constraints=True, - ) - - taskmodule.prepare(documents=[document]) - - return taskmodule - - -@pytest.fixture(scope="module") -def model(config) -> BartAsPointerNetwork: - model_name_or_path = MODEL_NAME_OR_PATH - - torch.random.manual_seed(42) - model = BartAsPointerNetwork.from_pretrained( - model_name_or_path, - # label id space - bos_token_id=0, # taskmodule.bos_id, - eos_token_id=1, # taskmodule.eos_id, - pad_token_id=1, # taskmodule.eos_id, - # target token id space - target_token_ids=[0, 2, 50266, 50269, 50268, 50265, 50267], # taskmodule.target_token_ids, - # mapping to better initialize the label embedding weights - # taken from taskmodule.label_embedding_weight_mapping - embedding_weight_mapping={ - 50266: [39763], - 50269: [10166], - 50268: [5970], - 50265: [45260], - 50267: [354, 1215, 9006], - }, - decoder_position_id_pattern=DECODER_POSITION_ID_PATTERN, - **config, - ) - - return model - - -def test_model(model, config): - assert model is not None - named_parameters = dict(model.named_parameters()) - parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()} - parameter_means_expected = { - "model.shared.weight": -1.41e-05, - "model.encoder.embed_positions.weight": -0.0001324, - "model.encoder.layers.0.self_attn.k_proj.weight": -0.0004574, - "model.encoder.layers.0.self_attn.k_proj.bias": 0.0, - "model.encoder.layers.0.self_attn.v_proj.weight": -0.0005457, - "model.encoder.layers.0.self_attn.v_proj.bias": 0.0, - "model.encoder.layers.0.self_attn.q_proj.weight": -0.0009775, - "model.encoder.layers.0.self_attn.q_proj.bias": 0.0, - "model.encoder.layers.0.self_attn.out_proj.weight": -0.0001075, - "model.encoder.layers.0.self_attn.out_proj.bias": 0.0, - "model.encoder.layers.0.self_attn_layer_norm.weight": 1.0, - "model.encoder.layers.0.self_attn_layer_norm.bias": 0.0, - "model.encoder.layers.0.fc1.weight": -0.0008655, - "model.encoder.layers.0.fc1.bias": 0.0, - "model.encoder.layers.0.fc2.weight": 0.0015535, - "model.encoder.layers.0.fc2.bias": 0.0, - "model.encoder.layers.0.final_layer_norm.weight": 1.0, - "model.encoder.layers.0.final_layer_norm.bias": 0.0, - "model.encoder.layers.1.self_attn.k_proj.weight": -0.0007831, - "model.encoder.layers.1.self_attn.k_proj.bias": 0.0, - "model.encoder.layers.1.self_attn.v_proj.weight": 0.0001186, - "model.encoder.layers.1.self_attn.v_proj.bias": 0.0, - "model.encoder.layers.1.self_attn.q_proj.weight": 0.0006847, - "model.encoder.layers.1.self_attn.q_proj.bias": 0.0, - "model.encoder.layers.1.self_attn.out_proj.weight": 0.0011724, - "model.encoder.layers.1.self_attn.out_proj.bias": 0.0, - "model.encoder.layers.1.self_attn_layer_norm.weight": 1.0, - "model.encoder.layers.1.self_attn_layer_norm.bias": 0.0, - "model.encoder.layers.1.fc1.weight": 0.0007757, - "model.encoder.layers.1.fc1.bias": 0.0, - "model.encoder.layers.1.fc2.weight": -0.0002014, - "model.encoder.layers.1.fc2.bias": 0.0, - "model.encoder.layers.1.final_layer_norm.weight": 1.0, - "model.encoder.layers.1.final_layer_norm.bias": 0.0, - "model.encoder.layernorm_embedding.weight": 1.0, - "model.encoder.layernorm_embedding.bias": 0.0, - "model.decoder.embed_positions.weight": -0.0001275, - "model.decoder.layers.0.self_attn.k_proj.weight": -0.0010682, - "model.decoder.layers.0.self_attn.k_proj.bias": 0.0, - "model.decoder.layers.0.self_attn.v_proj.weight": 0.0005057, - "model.decoder.layers.0.self_attn.v_proj.bias": 0.0, - "model.decoder.layers.0.self_attn.q_proj.weight": 0.0003248, - "model.decoder.layers.0.self_attn.q_proj.bias": 0.0, - "model.decoder.layers.0.self_attn.out_proj.weight": -0.0002014, - "model.decoder.layers.0.self_attn.out_proj.bias": 0.0, - "model.decoder.layers.0.self_attn_layer_norm.weight": 1.0, - "model.decoder.layers.0.self_attn_layer_norm.bias": 0.0, - "model.decoder.layers.0.encoder_attn.k_proj.weight": -0.0004254, - "model.decoder.layers.0.encoder_attn.k_proj.bias": 0.0, - "model.decoder.layers.0.encoder_attn.v_proj.weight": -0.0004049, - "model.decoder.layers.0.encoder_attn.v_proj.bias": 0.0, - "model.decoder.layers.0.encoder_attn.q_proj.weight": -0.0003516, - "model.decoder.layers.0.encoder_attn.q_proj.bias": 0.0, - "model.decoder.layers.0.encoder_attn.out_proj.weight": 0.0009908, - "model.decoder.layers.0.encoder_attn.out_proj.bias": 0.0, - "model.decoder.layers.0.encoder_attn_layer_norm.weight": 1.0, - "model.decoder.layers.0.encoder_attn_layer_norm.bias": 0.0, - "model.decoder.layers.0.fc1.weight": 0.0008378, - "model.decoder.layers.0.fc1.bias": 0.0, - "model.decoder.layers.0.fc2.weight": -2e-05, - "model.decoder.layers.0.fc2.bias": 0.0, - "model.decoder.layers.0.final_layer_norm.weight": 1.0, - "model.decoder.layers.0.final_layer_norm.bias": 0.0, - "model.decoder.layers.1.self_attn.k_proj.weight": -0.0007669, - "model.decoder.layers.1.self_attn.k_proj.bias": 0.0, - "model.decoder.layers.1.self_attn.v_proj.weight": -0.0007123, - "model.decoder.layers.1.self_attn.v_proj.bias": 0.0, - "model.decoder.layers.1.self_attn.q_proj.weight": 0.0012958, - "model.decoder.layers.1.self_attn.q_proj.bias": 0.0, - "model.decoder.layers.1.self_attn.out_proj.weight": -0.0006818, - "model.decoder.layers.1.self_attn.out_proj.bias": 0.0, - "model.decoder.layers.1.self_attn_layer_norm.weight": 1.0, - "model.decoder.layers.1.self_attn_layer_norm.bias": 0.0, - "model.decoder.layers.1.encoder_attn.k_proj.weight": -0.0006906, - "model.decoder.layers.1.encoder_attn.k_proj.bias": 0.0, - "model.decoder.layers.1.encoder_attn.v_proj.weight": -0.0009213, - "model.decoder.layers.1.encoder_attn.v_proj.bias": 0.0, - "model.decoder.layers.1.encoder_attn.q_proj.weight": -0.000842, - "model.decoder.layers.1.encoder_attn.q_proj.bias": 0.0, - "model.decoder.layers.1.encoder_attn.out_proj.weight": 0.0008073, - "model.decoder.layers.1.encoder_attn.out_proj.bias": 0.0, - "model.decoder.layers.1.encoder_attn_layer_norm.weight": 1.0, - "model.decoder.layers.1.encoder_attn_layer_norm.bias": 0.0, - "model.decoder.layers.1.fc1.weight": 0.0015493, - "model.decoder.layers.1.fc1.bias": 0.0, - "model.decoder.layers.1.fc2.weight": -0.0009827, - "model.decoder.layers.1.fc2.bias": 0.0, - "model.decoder.layers.1.final_layer_norm.weight": 1.0, - "model.decoder.layers.1.final_layer_norm.bias": 0.0, - "model.decoder.layernorm_embedding.weight": 1.0, - "model.decoder.layernorm_embedding.bias": 0.0, - "pointer_head.encoder_mlp.0.weight": 0.0004805, - "pointer_head.encoder_mlp.0.bias": 0.0, - "pointer_head.encoder_mlp.3.weight": 0.0001837, - "pointer_head.encoder_mlp.3.bias": 0.0, - } - assert parameter_means == parameter_means_expected - assert isinstance(model, BartAsPointerNetwork) - if config == {}: - assert isinstance(model.model, BartModel) - elif config == {"decoder_position_id_mode": "pattern"}: - assert isinstance(model.model, BartModelWithDecoderPositionIds) - else: - raise ValueError(f"Unknown config: {config}") - - -@pytest.fixture(scope="module") -def batch(): - inputs = { - "input_ids": torch.tensor( - [ - [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 2], - [0, 18823, 162, 4, 2, 1, 1, 1, 1, 1], - ] - ), - "attention_mask": torch.tensor( - [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]] - ), - } - targets = { - "labels": torch.tensor([[14, 14, 5, 11, 12, 3, 6, 1], [9, 9, 4, 2, 2, 2, 2, 1]]), - "decoder_attention_mask": torch.tensor( - [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]] - ), - } - return inputs, targets - - -@pytest.fixture(scope="module") -def batch_with_constraints(batch): - constraints = torch.tensor( - [ - [ - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - [ - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, -1, -1, -1, -1, -1], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, -1, -1, -1, -1, -1], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - ], - ] - ) - targets_with_constraints = {**batch[1], "constraints": constraints} - return batch[0], targets_with_constraints - - -@pytest.mark.skip(reason="This is just to show how to create the batch.") -def test_batch_with_constraints(batch_with_constraints, taskmodule, document): - inputs, targets = batch_with_constraints - task_encodings = taskmodule.encode([document], encode_target=True) - batch_from_documents = taskmodule.collate(task_encodings) - inputs_from_documents, targets_from_documents = batch_from_documents - for key in inputs: - torch.testing.assert_close(inputs[key], inputs_from_documents[key]) - - for key in targets: - torch.testing.assert_close(targets[key], targets_from_documents[key]) - - -@pytest.fixture(scope="module") -def decoder_input_ids(model): - # taken from batch[1]["labels"] - labels = torch.tensor([[14, 14, 5, 11, 12, 3, 6, 1], [9, 9, 4, 2, 2, 2, 2, 1]]) - decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=labels) - return decoder_input_ids - - -def test_prepare_decoder_input_ids_from_labels(decoder_input_ids): - assert decoder_input_ids.shape == (2, 8) - torch.testing.assert_close( - decoder_input_ids, - torch.tensor([[0, 14, 14, 5, 11, 12, 3, 6], [0, 9, 9, 4, 2, 2, 2, 2]]), - ) - - -def test_forward(model, batch, decoder_input_ids, config): - inputs, targets = batch - torch.manual_seed(42) - outputs = model(**inputs, decoder_input_ids=decoder_input_ids) - assert outputs.loss is None - assert outputs.logits is not None - # shape: (batch_size, output_seq_len, target_size=num_target_ids+num_offsets) - assert outputs.logits.shape == (2, 8, 17) - # check exact values only for the first sequence output - torch.testing.assert_close( - outputs.logits[:, 0, :], - torch.tensor( - [ - [ - -1.0000000138484279e24, - -0.23238050937652588, - 0.2958170175552368, - 0.05529244244098663, - 0.04253090173006058, - 0.10081345587968826, - -0.07145103067159653, - 0.12317530065774918, - -0.06861806660890579, - 0.07819556444883347, - 0.006490768864750862, - -0.040455855429172516, - 0.03176971897482872, - 0.05362509936094284, - 0.04528001323342323, - -0.0684177577495575, - -1.0000000331813535e32, - ], - [ - -1.0000000138484279e24, - -0.23274855315685272, - 0.2960396707057953, - 0.05556505173444748, - 0.04273710399866104, - 0.10071954131126404, - -0.071356862783432, - 0.12314081937074661, - 0.06498698145151138, - 0.07938676327466965, - -0.07943986356258392, - -1.0000000331813535e32, - -1.0000000331813535e32, - -1.0000000331813535e32, - -1.0000000331813535e32, - -1.0000000331813535e32, - -1.0000000331813535e32, - ], - ] - ), - ) - # check the sum of all logits - if config == {}: - torch.testing.assert_close( - outputs.logits.sum(0).sum(0), - torch.tensor( - [ - -1.6000000221574846e25, - -0.9064984321594238, - 1.189674735069275, - 0.9796359539031982, - 0.1837124526500702, - 1.3070943355560303, - -0.1210818886756897, - 0.5316579937934875, - -0.12306825071573257, - 0.6218758225440979, - -0.4374474287033081, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -1.6000000530901656e33, - ] - ), - ) - elif config == {"decoder_position_id_mode": "pattern"}: - torch.testing.assert_close( - outputs.logits.sum(0).sum(0), - torch.tensor( - [ - -1.6000000221574846e25, - -0.5539568662643433, - 0.7004716396331787, - 1.5720455646514893, - -0.3760950267314911, - 0.7738710641860962, - -0.1090446263551712, - 0.287150502204895, - -0.04344810172915459, - 0.3674442768096924, - -0.6838937997817993, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -1.6000000530901656e33, - ] - ), - ) - else: - raise ValueError(f"Unknown config: {config}") - - -def test_forward_with_labels(model, batch, config): - inputs, targets = batch - targets_without_constraints = { - key: value for key, value in targets.items() if key != "constraints" - } - assert set(inputs) == {"input_ids", "attention_mask"} - assert set(targets_without_constraints) == {"labels", "decoder_attention_mask"} - torch.manual_seed(42) - outputs = model(**inputs, **targets_without_constraints) - loss = outputs.loss - if config == {}: - torch.testing.assert_close(loss, torch.tensor(2.4516539573669434)) - elif config == {"decoder_position_id_mode": "pattern"}: - torch.testing.assert_close(loss, torch.tensor(2.4184868335723877)) - else: - raise ValueError(f"Unknown config: {config}") - - -def test_forward_with_labels_and_constraints(model, batch_with_constraints, config): - inputs, targets = batch_with_constraints - assert set(inputs) == {"input_ids", "attention_mask"} - assert set(targets) == {"labels", "decoder_attention_mask", "constraints"} - torch.manual_seed(42) - outputs = model(**inputs, **targets) - loss = outputs.loss - if config == {}: - torch.testing.assert_close(loss, torch.tensor(4.776531219482422)) - elif config == {"decoder_position_id_mode": "pattern"}: - torch.testing.assert_close(loss, torch.tensor(4.742183685302734)) - else: - raise ValueError(f"Unknown model type {type(model.model)}") - - -@pytest.fixture(scope="module") -def empty_decoder_input_ids(batch, model): - inputs, targets = batch - batch_size, seq_len = inputs["input_ids"].shape - decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long) * model.config.bos_token_id - torch.testing.assert_close( - decoder_input_ids, - torch.tensor([[0], [0]]), - ) - return decoder_input_ids - - -@pytest.fixture(scope="module") -def encoder_outputs(model, batch): - inputs, targets = batch - torch.manual_seed(42) - encoder_outputs = model.get_encoder()( - input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"] - ) - return encoder_outputs - - -@pytest.fixture(scope="module") -def prepared_encoder_decoder_kwargs_for_generation( - model, batch, empty_decoder_input_ids, encoder_outputs -): - model_kwargs = { - "attention_mask": batch[0]["attention_mask"], - "output_attentions": False, - "output_hidden_states": False, - "use_cache": True, - } - torch.manual_seed(42) - prepared_kwargs = model._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor=batch[0]["input_ids"], - model_kwargs=model_kwargs, - model_input_name="input_ids", - ) - return prepared_kwargs - - -def test_prepare_encoder_decoder_kwargs_for_generation( - prepared_encoder_decoder_kwargs_for_generation, batch, encoder_outputs -): - model_kwargs = { - "attention_mask": batch[0]["attention_mask"], - "output_attentions": False, - "output_hidden_states": False, - "use_cache": True, - } - - assert set(prepared_encoder_decoder_kwargs_for_generation) == set(model_kwargs) | { - "encoder_input_ids", - "encoder_attention_mask", - "encoder_outputs", - } - torch.testing.assert_close( - prepared_encoder_decoder_kwargs_for_generation["encoder_input_ids"], - batch[0]["input_ids"], - ) - torch.testing.assert_close( - prepared_encoder_decoder_kwargs_for_generation["encoder_attention_mask"], - batch[0]["attention_mask"], - ) - torch.testing.assert_close( - prepared_encoder_decoder_kwargs_for_generation["encoder_outputs"].last_hidden_state, - encoder_outputs.last_hidden_state, - ) - - -def test_prepare_inputs_for_generation( - model, - prepared_encoder_decoder_kwargs_for_generation, - empty_decoder_input_ids, - batch, - encoder_outputs, - config, -): - result = model.prepare_inputs_for_generation( - decoder_input_ids=empty_decoder_input_ids, **prepared_encoder_decoder_kwargs_for_generation - ) - result_keys = { - "input_ids", - "attention_mask", - "encoder_outputs", - "decoder_input_ids", - "decoder_attention_mask", - "past_key_values", - "use_cache", - "head_mask", - "decoder_head_mask", - "cross_attn_head_mask", - } - if model.pointer_head.use_prepared_position_ids: - result_keys.add("decoder_position_ids") - assert set(result) == result_keys - torch.testing.assert_close( - result["input_ids"], - batch[0]["input_ids"], - ) - torch.testing.assert_close( - result["attention_mask"], - batch[0]["attention_mask"], - ) - torch.testing.assert_close( - result["encoder_outputs"].last_hidden_state, - encoder_outputs.last_hidden_state, - ) - torch.testing.assert_close( - result["decoder_input_ids"], - empty_decoder_input_ids, - ) - assert result["decoder_attention_mask"] is None - assert result["past_key_values"] is None - assert result["use_cache"] is True - assert result["head_mask"] is None - assert result["decoder_head_mask"] is None - assert result["cross_attn_head_mask"] is None - if config == {}: - assert "decoder_position_ids" not in result - elif config == {"decoder_position_id_mode": "pattern"}: - torch.testing.assert_close(result["decoder_position_ids"], torch.tensor([[0], [0]])) - else: - raise ValueError(f"Unknown config: {config}") - - -def test_prepare_inputs_for_generation_with_past_key_values( - model, - prepared_encoder_decoder_kwargs_for_generation, - batch, - encoder_outputs, - config, -): - # shallow copy to avoid changing the original dict - kwargs = dict(prepared_encoder_decoder_kwargs_for_generation) - kwargs["decoder_input_ids"] = torch.tensor( - [ - [0, 8, 9], - [0, 8, 10], - [0, 8, 15], - [0, 8, 8], - [0, 9, 10], - [0, 8, 12], - [0, 8, 9], - [0, 8, 10], - [0, 9, 10], - [0, 8, 8], - [0, 9, 9], - [0, 8, 6], - ] - ) - # 12 is batch_size (2) * num_beams (6), - # 16 is number of encoder / decoder attention heads, - # 2 is the length of already generated tokens / 10 is the length of the encoder input, - # 64 seems to be the size of the hidden states - dummy_past_key_values = ( - torch.zeros((12, 16, 2, 64)), - torch.zeros((12, 16, 2, 64)), - torch.zeros((12, 16, 10, 64)), - torch.zeros((12, 16, 10, 64)), - ) - - result = model.prepare_inputs_for_generation(past_key_values=dummy_past_key_values, **kwargs) - if config == {}: - assert len(result) == 10 - elif config == {"decoder_position_id_mode": "pattern"}: - assert len(result) == 11 - else: - raise ValueError(f"Unknown config: {config}") - torch.testing.assert_close( - result["input_ids"], - batch[0]["input_ids"], - ) - torch.testing.assert_close( - result["attention_mask"], - batch[0]["attention_mask"], - ) - torch.testing.assert_close( - result["encoder_outputs"].last_hidden_state, - encoder_outputs.last_hidden_state, - ) - torch.testing.assert_close( - result["decoder_input_ids"], - # just the last id for each entry - torch.tensor([[9], [10], [15], [8], [10], [12], [9], [10], [10], [8], [9], [6]]), - ) - assert result["decoder_attention_mask"] is None - assert result["past_key_values"] is dummy_past_key_values - assert result["use_cache"] is True - assert result["head_mask"] is None - assert result["decoder_head_mask"] is None - assert result["cross_attn_head_mask"] is None - if "decoder_position_ids" in result: - torch.testing.assert_close( - result["decoder_position_ids"], - # originally this was 0 from the pattern, but got shifted for the position-bos and position-pad indices - torch.tensor([[2], [2], [2], [2], [2], [2], [2], [2], [2], [2], [2], [2]]), - ) - - -def test_generate(model, batch, empty_decoder_input_ids, config): - inputs, targets = batch - batch_size, seq_len = inputs["input_ids"].shape - torch.manual_seed(42) - outputs = model.generate(**inputs) - if config == {}: - assert outputs.shape == (batch_size, 20) # note that 20 is the model.config.max_length - torch.testing.assert_close( - outputs, - torch.tensor( - [ - [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], - [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], - ] - ), - ) - elif config == {"decoder_position_id_mode": "pattern"}: - assert outputs.shape == (batch_size, 20) # note that 20 is the model.config.max_length - torch.testing.assert_close( - outputs, - torch.tensor( - [ - [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], - [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], - ] - ), - ) - else: - raise ValueError(f"Unknown config: {config}") - - -def test_head_named_params(model): - parameter_shapes = {name: tuple(param.shape) for name, param in model.head_named_params()} - assert parameter_shapes == { - "pointer_head.encoder_mlp.0.bias": (24,), - "pointer_head.encoder_mlp.0.weight": (24, 24), - "pointer_head.encoder_mlp.3.bias": (24,), - "pointer_head.encoder_mlp.3.weight": (24, 24), - } - - -def test_encoder_only_named_params(model): - parameter_shapes = { - name: tuple(param.shape) for name, param in model.encoder_only_named_params() - } - assert len(parameter_shapes) == 35 - assert parameter_shapes == { - "model.encoder.embed_positions.weight": (1026, 24), - "model.encoder.layernorm_embedding.bias": (24,), - "model.encoder.layernorm_embedding.weight": (24,), - "model.encoder.layers.0.fc1.bias": (16,), - "model.encoder.layers.0.fc1.weight": (16, 24), - "model.encoder.layers.0.fc2.bias": (24,), - "model.encoder.layers.0.fc2.weight": (24, 16), - "model.encoder.layers.0.final_layer_norm.bias": (24,), - "model.encoder.layers.0.final_layer_norm.weight": (24,), - "model.encoder.layers.0.self_attn.k_proj.bias": (24,), - "model.encoder.layers.0.self_attn.k_proj.weight": (24, 24), - "model.encoder.layers.0.self_attn.out_proj.bias": (24,), - "model.encoder.layers.0.self_attn.out_proj.weight": (24, 24), - "model.encoder.layers.0.self_attn.q_proj.bias": (24,), - "model.encoder.layers.0.self_attn.q_proj.weight": (24, 24), - "model.encoder.layers.0.self_attn.v_proj.bias": (24,), - "model.encoder.layers.0.self_attn.v_proj.weight": (24, 24), - "model.encoder.layers.0.self_attn_layer_norm.bias": (24,), - "model.encoder.layers.0.self_attn_layer_norm.weight": (24,), - "model.encoder.layers.1.fc1.bias": (16,), - "model.encoder.layers.1.fc1.weight": (16, 24), - "model.encoder.layers.1.fc2.bias": (24,), - "model.encoder.layers.1.fc2.weight": (24, 16), - "model.encoder.layers.1.final_layer_norm.bias": (24,), - "model.encoder.layers.1.final_layer_norm.weight": (24,), - "model.encoder.layers.1.self_attn.k_proj.bias": (24,), - "model.encoder.layers.1.self_attn.k_proj.weight": (24, 24), - "model.encoder.layers.1.self_attn.out_proj.bias": (24,), - "model.encoder.layers.1.self_attn.out_proj.weight": (24, 24), - "model.encoder.layers.1.self_attn.q_proj.bias": (24,), - "model.encoder.layers.1.self_attn.q_proj.weight": (24, 24), - "model.encoder.layers.1.self_attn.v_proj.bias": (24,), - "model.encoder.layers.1.self_attn.v_proj.weight": (24, 24), - "model.encoder.layers.1.self_attn_layer_norm.bias": (24,), - "model.encoder.layers.1.self_attn_layer_norm.weight": (24,), - } - - -def test_decoder_only_named_params(model): - parameter_shapes = { - name: tuple(param.shape) for name, param in model.decoder_only_named_params() - } - assert len(parameter_shapes) == 55 - assert parameter_shapes == { - "model.decoder.embed_positions.weight": (1026, 24), - "model.decoder.layernorm_embedding.bias": (24,), - "model.decoder.layernorm_embedding.weight": (24,), - "model.decoder.layers.0.encoder_attn.k_proj.bias": (24,), - "model.decoder.layers.0.encoder_attn.k_proj.weight": (24, 24), - "model.decoder.layers.0.encoder_attn.out_proj.bias": (24,), - "model.decoder.layers.0.encoder_attn.out_proj.weight": (24, 24), - "model.decoder.layers.0.encoder_attn.q_proj.bias": (24,), - "model.decoder.layers.0.encoder_attn.q_proj.weight": (24, 24), - "model.decoder.layers.0.encoder_attn.v_proj.bias": (24,), - "model.decoder.layers.0.encoder_attn.v_proj.weight": (24, 24), - "model.decoder.layers.0.encoder_attn_layer_norm.bias": (24,), - "model.decoder.layers.0.encoder_attn_layer_norm.weight": (24,), - "model.decoder.layers.0.fc1.bias": (16,), - "model.decoder.layers.0.fc1.weight": (16, 24), - "model.decoder.layers.0.fc2.bias": (24,), - "model.decoder.layers.0.fc2.weight": (24, 16), - "model.decoder.layers.0.final_layer_norm.bias": (24,), - "model.decoder.layers.0.final_layer_norm.weight": (24,), - "model.decoder.layers.0.self_attn.k_proj.bias": (24,), - "model.decoder.layers.0.self_attn.k_proj.weight": (24, 24), - "model.decoder.layers.0.self_attn.out_proj.bias": (24,), - "model.decoder.layers.0.self_attn.out_proj.weight": (24, 24), - "model.decoder.layers.0.self_attn.q_proj.bias": (24,), - "model.decoder.layers.0.self_attn.q_proj.weight": (24, 24), - "model.decoder.layers.0.self_attn.v_proj.bias": (24,), - "model.decoder.layers.0.self_attn.v_proj.weight": (24, 24), - "model.decoder.layers.0.self_attn_layer_norm.bias": (24,), - "model.decoder.layers.0.self_attn_layer_norm.weight": (24,), - "model.decoder.layers.1.encoder_attn.k_proj.bias": (24,), - "model.decoder.layers.1.encoder_attn.k_proj.weight": (24, 24), - "model.decoder.layers.1.encoder_attn.out_proj.bias": (24,), - "model.decoder.layers.1.encoder_attn.out_proj.weight": (24, 24), - "model.decoder.layers.1.encoder_attn.q_proj.bias": (24,), - "model.decoder.layers.1.encoder_attn.q_proj.weight": (24, 24), - "model.decoder.layers.1.encoder_attn.v_proj.bias": (24,), - "model.decoder.layers.1.encoder_attn.v_proj.weight": (24, 24), - "model.decoder.layers.1.encoder_attn_layer_norm.bias": (24,), - "model.decoder.layers.1.encoder_attn_layer_norm.weight": (24,), - "model.decoder.layers.1.fc1.bias": (16,), - "model.decoder.layers.1.fc1.weight": (16, 24), - "model.decoder.layers.1.fc2.bias": (24,), - "model.decoder.layers.1.fc2.weight": (24, 16), - "model.decoder.layers.1.final_layer_norm.bias": (24,), - "model.decoder.layers.1.final_layer_norm.weight": (24,), - "model.decoder.layers.1.self_attn.k_proj.bias": (24,), - "model.decoder.layers.1.self_attn.k_proj.weight": (24, 24), - "model.decoder.layers.1.self_attn.out_proj.bias": (24,), - "model.decoder.layers.1.self_attn.out_proj.weight": (24, 24), - "model.decoder.layers.1.self_attn.q_proj.bias": (24,), - "model.decoder.layers.1.self_attn.q_proj.weight": (24, 24), - "model.decoder.layers.1.self_attn.v_proj.bias": (24,), - "model.decoder.layers.1.self_attn.v_proj.weight": (24, 24), - "model.decoder.layers.1.self_attn_layer_norm.bias": (24,), - "model.decoder.layers.1.self_attn_layer_norm.weight": (24,), - } - - -def test_encoder_decoder_shared_named_params(model): - parameter_shapes = { - name: tuple(param.shape) for name, param in model.encoder_decoder_shared_named_params() - } - assert len(parameter_shapes) == 1 - assert parameter_shapes == {"model.shared.weight": (50270, 24)} - - -def test_base_model_named_params(model): - parameter_shapes = { - name: tuple(param.shape) for name, param in model.base_model_named_params() - } - assert len(parameter_shapes) == 91 - encoder_only_parameter_shapes = { - name: tuple(param.shape) for name, param in model.encoder_only_named_params() - } - decoder_only_parameter_shapes = { - name: tuple(param.shape) for name, param in model.decoder_only_named_params() - } - shared_parameter_shapes = { - name: tuple(param.shape) for name, param in model.encoder_decoder_shared_named_params() - } - expected_parameter_shapes = { - **encoder_only_parameter_shapes, - **decoder_only_parameter_shapes, - **shared_parameter_shapes, - } - - assert parameter_shapes == expected_parameter_shapes - - -def test_configure_optimizer(model): - optimizer = model.configure_optimizer() - assert isinstance(optimizer, torch.optim.AdamW) - assert optimizer.defaults["lr"] == 0.001 - assert optimizer.defaults["weight_decay"] == model.config.weight_decay == 0.01 - assert len(optimizer.param_groups) == 6 - assert all(param_group["lr"] == model.config.lr for param_group in optimizer.param_groups) - - # head parameters - assert optimizer.param_groups[0]["weight_decay"] == model.config.weight_decay == 0.01 - # decoder only layer norm parameters - assert optimizer.param_groups[1]["weight_decay"] == model.config.weight_decay == 0.01 - # decoder only other parameters - assert optimizer.param_groups[2]["weight_decay"] == model.config.weight_decay == 0.01 - # encoder only layer norm parameters - assert ( - optimizer.param_groups[3]["weight_decay"] == model.config.encoder_layer_norm_decay == 0.001 - ) - # encoder only other parameters - assert optimizer.param_groups[4]["weight_decay"] == model.config.weight_decay == 0.01 - # encoder-decoder shared parameters - assert optimizer.param_groups[5]["weight_decay"] == model.config.weight_decay == 0.01 - - all_optimized_parameters = set() - for param_group in optimizer.param_groups: - all_optimized_parameters.update(set(param_group["params"])) - assert len(all_optimized_parameters) > 0 - # check that all model parameters are covered - all_model_parameters = {param for name, param in model.named_parameters()} - assert all_optimized_parameters == all_model_parameters - - -# note that this is only used for the tests below which are marked as slow -# and are primarily meant to show how beam search works -@pytest.fixture(scope="module") -def pretrained_model() -> BartAsPointerNetwork: - torch.random.manual_seed(42) - model = BartAsPointerNetwork.from_pretrained( - "sshleifer/distilbart-xsum-12-1", - # label id space - bos_token_id=0, # taskmodule.bos_id, - eos_token_id=1, # taskmodule.eos_id, - pad_token_id=1, # taskmodule.eos_id, - # target token id space - target_token_ids=[0, 2, 50266, 50269, 50268, 50265, 50267], # taskmodule.target_token_ids, - # mapping to better initialize the label embedding weights - # taken from taskmodule.label_embedding_weight_mapping - embedding_weight_mapping={ - 50266: [39763], - 50269: [10166], - 50268: [5970], - 50265: [45260], - 50267: [354, 1215, 9006], - }, - decoder_position_id_mode="pattern", - decoder_position_id_pattern=[0, 0, 1, 0, 0, 1, 1], - ) - - return model - - -ARTICLE_TO_SUMMARIZE = ( - "PG&E stated it scheduled the blackouts in response to forecasts for high winds " - "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " - "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." -) - - -@pytest.mark.slow -def test_bart_pointer_network_beam_search(pretrained_model, taskmodule): - model = pretrained_model - encoder_input_str = ARTICLE_TO_SUMMARIZE # "translate English to German: How old are you?" - encoder_input_tokenized = taskmodule.tokenizer(encoder_input_str, return_tensors="pt") - - # lets run beam search using 3 beams - num_beams = 3 - # define decoder start token ids - decoder_input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - decoder_input_ids = decoder_input_ids * model.config.decoder_start_token_id - - # add encoder_outputs to model keyword arguments - encoder = model.get_encoder() - encoder_input_ids = encoder_input_tokenized.input_ids.repeat_interleave(num_beams, dim=0) - encoder_attention_mask = encoder_input_tokenized.attention_mask.repeat_interleave( - num_beams, dim=0 - ) - torch.manual_seed(42) - encoder_outputs = encoder(encoder_input_ids, return_dict=True) - model_kwargs = { - "encoder_outputs": encoder_outputs, - "encoder_input_ids": encoder_input_ids, - "encoder_attention_mask": encoder_attention_mask, - } - - # instantiate beam scorer - beam_scorer = BeamSearchScorer( - batch_size=1, - num_beams=num_beams, - device=model.device, - ) - - # instantiate logits processors - logits_processor = LogitsProcessorList( - [ - MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ] - ) - - torch.manual_seed(42) - outputs = model.beam_search( - decoder_input_ids, - beam_scorer, - logits_processor=logits_processor, - pad_token_id=model.config.pad_token_id, - eos_token_id=model.config.eos_token_id, - max_length=20, - **model_kwargs, - ) - - torch.testing.assert_close( - outputs, - torch.tensor( - [[0, 10, 30, 53, 54, 45, 15, 16, 17, 33, 33, 33, 35, 33, 58, 39, 41, 35, 33, 35]] - ), - ) - - # result = tokenizer.batch_decode(outputs, skip_special_tokens=True) - # assert result == [ - # " power lines in California have been shut down after a power provider said it was due to high winds." - # ] - - -@pytest.mark.slow -def test_bart_pointer_network_generate_with_scores(pretrained_model, taskmodule): - model = pretrained_model - encoder_input_str = ARTICLE_TO_SUMMARIZE # "translate English to German: How old are you?" - inputs = taskmodule.tokenizer(encoder_input_str, max_length=1024, return_tensors="pt") - - torch.manual_seed(42) - outputs = model.generate( - inputs["input_ids"], - num_beams=3, - min_length=5, - max_length=20, - return_dict_in_generate=True, - output_scores=True, - ) - assert isinstance(outputs, BeamSearchEncoderDecoderOutput) - torch.testing.assert_close(outputs.sequences_scores, torch.tensor([-8.088160514831543])) - torch.testing.assert_close( - outputs.sequences, - torch.tensor( - [[0, 10, 30, 53, 54, 45, 15, 16, 17, 33, 33, 33, 35, 33, 58, 39, 41, 35, 33, 35]] - ), - ) - - # result = tokenizer.batch_decode( - # summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - # ) - # assert result == [" power lines in California have been shut down on Friday."] diff --git a/tests/models/base_models/test_bart_with_decoder_position_ids.py b/tests/models/base_models/test_bart_with_decoder_position_ids.py deleted file mode 100644 index 0b814c64b..000000000 --- a/tests/models/base_models/test_bart_with_decoder_position_ids.py +++ /dev/null @@ -1,301 +0,0 @@ -import pytest -import torch -from torch.nn import Embedding -from transformers import BartConfig -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.bart.modeling_bart import BartEncoder - -from pie_modules.models.base_models import BartModelWithDecoderPositionIds -from pie_modules.models.base_models.bart_with_decoder_position_ids import ( - BartDecoderWithPositionIds, - BartLearnedPositionalEmbeddingWithPositionIds, -) - - -def test_bart_learned_positional_embedding_with_position_ids(): - # Arrange - torch.manual_seed(42) - model = BartLearnedPositionalEmbeddingWithPositionIds(10, 6) - input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) - position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]) - position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 2]]) - - # Act - original = model(input_ids=input_ids) - replaced_original = model(input_ids=input_ids, position_ids=position_ids_original) - replaced_different = model(input_ids=input_ids, position_ids=position_ids_different) - - # Assert - assert original.shape == (1, 10, 6) - assert replaced_original.shape == (1, 10, 6) - torch.testing.assert_close(original, replaced_original) - assert replaced_different.shape == (1, 10, 6) - assert not torch.allclose(original, replaced_different) - - -@pytest.fixture(scope="module") -def bart_config(): - return BartConfig( - vocab_size=30, - d_model=10, - encoder_layers=1, - decoder_layers=1, - encoder_attention_heads=2, - decoder_attention_heads=2, - encoder_ffn_dim=20, - decoder_ffn_dim=20, - max_position_embeddings=10, - ) - - -@pytest.fixture(scope="module") -def bart_decoder_with_position_ids(bart_config): - return BartDecoderWithPositionIds(config=bart_config) - - -def test_bart_decoder_with_position_ids(bart_decoder_with_position_ids): - assert bart_decoder_with_position_ids is not None - - -def test_bart_decoder_with_position_ids_get_input_embeddings(bart_decoder_with_position_ids): - input_embeddings = bart_decoder_with_position_ids.get_input_embeddings() - assert input_embeddings is not None - assert isinstance(input_embeddings, Embedding) - assert input_embeddings.embedding_dim == 10 - assert input_embeddings.num_embeddings == 30 - - -def test_bart_decoder_with_position_ids_set_input_embeddings(bart_decoder_with_position_ids): - original_input_embeddings = bart_decoder_with_position_ids.get_input_embeddings() - torch.manual_seed(42) - new_input_embeddings = Embedding( - original_input_embeddings.num_embeddings, original_input_embeddings.embedding_dim - ) - bart_decoder_with_position_ids.set_input_embeddings(new_input_embeddings) - input_embeddings = bart_decoder_with_position_ids.get_input_embeddings() - assert input_embeddings == new_input_embeddings - assert input_embeddings is not original_input_embeddings - # recover original input embeddings - bart_decoder_with_position_ids.set_input_embeddings(original_input_embeddings) - - -def test_bart_decoder_with_position_ids_forward(bart_decoder_with_position_ids): - # Arrange - model = bart_decoder_with_position_ids - input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) - position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) - position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2]]) - - # Act - torch.manual_seed(42) - original = model(input_ids=input_ids) - torch.manual_seed(42) - replaced_original = model(input_ids=input_ids, position_ids=position_ids_original) - torch.manual_seed(42) - replaced_different = model(input_ids=input_ids, position_ids=position_ids_different) - - # Assert - assert isinstance(original, BaseModelOutputWithPastAndCrossAttentions) - assert original.last_hidden_state.shape == (1, 8, 10) - assert isinstance(replaced_original, BaseModelOutputWithPastAndCrossAttentions) - torch.testing.assert_close(original.last_hidden_state, replaced_original.last_hidden_state) - - assert isinstance(replaced_different, BaseModelOutputWithPastAndCrossAttentions) - assert replaced_different.last_hidden_state.shape == (1, 8, 10) - assert not torch.allclose(original.last_hidden_state, replaced_different.last_hidden_state) - - -def test_bart_decoder_with_position_ids_forward_with_inputs_embeds(bart_decoder_with_position_ids): - # Arrange - model = bart_decoder_with_position_ids - inputs_embeds = torch.randn(1, 8, 10) - position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) - position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2]]) - - # Act - torch.manual_seed(42) - original = model(inputs_embeds=inputs_embeds) - torch.manual_seed(42) - replaced_original = model(inputs_embeds=inputs_embeds, position_ids=position_ids_original) - torch.manual_seed(42) - replaced_different = model(inputs_embeds=inputs_embeds, position_ids=position_ids_different) - - # Assert - assert isinstance(original, BaseModelOutputWithPastAndCrossAttentions) - assert original.last_hidden_state.shape == (1, 8, 10) - assert isinstance(replaced_original, BaseModelOutputWithPastAndCrossAttentions) - torch.testing.assert_close(original.last_hidden_state, replaced_original.last_hidden_state) - - assert isinstance(replaced_different, BaseModelOutputWithPastAndCrossAttentions) - assert replaced_different.last_hidden_state.shape == (1, 8, 10) - assert not torch.allclose(original.last_hidden_state, replaced_different.last_hidden_state) - - -def test_bart_decoder_with_position_ids_forward_wrong_position_ids_shape( - bart_decoder_with_position_ids, -): - # Arrange - model = bart_decoder_with_position_ids - input_ids = torch.tensor([[0, 1, 2, 3]]) - position_ids_wrong_shape = torch.tensor([[0, 1, 2]]) - - # Act - torch.manual_seed(42) - with pytest.raises(ValueError) as excinfo: - model(input_ids=input_ids, position_ids=position_ids_wrong_shape) - assert ( - str(excinfo.value) - == "Position IDs shape torch.Size([1, 3]) does not match input ids shape torch.Size([1, 4])." - ) - - -@pytest.fixture(scope="module") -def bart_model_with_decoder_position_ids(bart_config): - torch.manual_seed(42) - model = BartModelWithDecoderPositionIds(config=bart_config) - model.train() - return model - - -def test_bart_model_with_decoder_position_ids(bart_model_with_decoder_position_ids): - assert bart_model_with_decoder_position_ids is not None - - -def test_bart_model_with_decoder_position_ids_get_input_embeddings( - bart_model_with_decoder_position_ids, -): - input_embeddings = bart_model_with_decoder_position_ids.get_input_embeddings() - assert input_embeddings is not None - assert isinstance(input_embeddings, Embedding) - assert input_embeddings.embedding_dim == 10 - assert input_embeddings.num_embeddings == 30 - - -def test_bart_model_with_decoder_position_ids_set_input_embeddings( - bart_model_with_decoder_position_ids, -): - original_input_embeddings = bart_model_with_decoder_position_ids.get_input_embeddings() - torch.manual_seed(42) - new_input_embeddings = Embedding( - original_input_embeddings.num_embeddings, original_input_embeddings.embedding_dim - ) - bart_model_with_decoder_position_ids.set_input_embeddings(new_input_embeddings) - input_embeddings = bart_model_with_decoder_position_ids.get_input_embeddings() - assert input_embeddings == new_input_embeddings - assert input_embeddings is not original_input_embeddings - # recover original input embeddings - bart_model_with_decoder_position_ids.set_input_embeddings(original_input_embeddings) - - -def test_bart_model_with_decoder_position_ids_get_encoder(bart_model_with_decoder_position_ids): - encoder = bart_model_with_decoder_position_ids.get_encoder() - assert encoder is not None - assert isinstance(encoder, BartEncoder) - - -def test_bart_model_with_decoder_position_ids_get_decoder(bart_model_with_decoder_position_ids): - decoder = bart_model_with_decoder_position_ids.get_decoder() - assert decoder is not None - assert isinstance(decoder, BartDecoderWithPositionIds) - - -@pytest.mark.parametrize( - "return_dict, prepare_encoder_outputs, output_everything", - [(True, True, True), (False, False, False)], -) -def test_bart_model_with_decoder_position_forward( - bart_model_with_decoder_position_ids, return_dict, prepare_encoder_outputs, output_everything -): - model = bart_model_with_decoder_position_ids - - # Arrange - model.eval() - input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) - position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) - position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2]]) - common_kwargs = {"input_ids": input_ids, "return_dict": return_dict} - if prepare_encoder_outputs: - common_kwargs["encoder_outputs"] = bart_model_with_decoder_position_ids.get_encoder()( - input_ids=input_ids, return_dict=False - ) - else: - common_kwargs["encoder_outputs"] = None - if output_everything: - common_kwargs["output_attentions"] = True - common_kwargs["output_hidden_states"] = True - - # Act - original = model(**common_kwargs)[0] - replaced_original = model( - decoder_position_ids=position_ids_original, - **common_kwargs, - )[0] - replaced_different = model(decoder_position_ids=position_ids_different, **common_kwargs)[0] - - # Assert - assert isinstance(original, torch.FloatTensor) - assert original.shape == (1, 8, 10) - torch.testing.assert_close( - original[0, :5, :3], - torch.tensor( - [ - [0.7589594721794128, 1.0452316999435425, 0.7063764333724976], - [-0.12192550301551819, -0.9932114481925964, -0.722382664680481], - [0.24711951613426208, -0.291597843170166, -1.0466505289077759], - [1.1228691339492798, -0.0873560905456543, 1.534016728401184], - [-1.1132177114486694, 0.2277398556470871, 1.6456809043884277], - ] - ), - ) - torch.testing.assert_close( - original.sum(dim=-1), - torch.tensor( - [ - [ - 0.0, - -1.1920928955078125e-07, - -1.1920928955078125e-07, - -2.682209014892578e-07, - 5.960464477539063e-08, - 5.960464477539063e-08, - 2.384185791015625e-07, - -5.960464477539063e-08, - ] - ] - ), - ) - assert isinstance(replaced_original, torch.FloatTensor) - torch.testing.assert_close(original, replaced_original) - - assert isinstance(replaced_different, torch.FloatTensor) - assert replaced_different.shape == (1, 8, 10) - torch.testing.assert_close( - replaced_different[0, :5, :3], - torch.tensor( - [ - [0.7589594721794128, 1.0452316999435425, 0.7063764333724976], - [-0.0127173513174057, -0.8127143383026123, -1.256797194480896], - [1.0517312288284302, 0.037927787750959396, -0.28661563992500305], - [0.5884698629379272, 0.9930593371391296, 1.3842554092407227], - [0.6132885813713074, -1.0105736255645752, 2.361264228820801], - ] - ), - ) - torch.testing.assert_close( - replaced_different.sum(dim=-1), - torch.tensor( - [ - [ - 0.0, - -2.384185791015625e-07, - -1.7881393432617188e-07, - 2.5331974029541016e-07, - 1.4901161193847656e-07, - 1.1920928955078125e-07, - -1.1920928955078125e-07, - -1.7881393432617188e-07, - ] - ] - ), - ) - assert not torch.allclose(replaced_different, original) diff --git a/tests/models/components/__init__.py b/tests/models/components/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/models/components/test_pointer_head.py b/tests/models/components/test_pointer_head.py deleted file mode 100644 index eb2650f68..000000000 --- a/tests/models/components/test_pointer_head.py +++ /dev/null @@ -1,713 +0,0 @@ -import pytest -import torch -from torch import nn - -from pie_modules.models.components.pointer_head import PointerHead - - -def get_pointer_head(num_embeddings=120, embedding_dim=3, eos_id=1, pad_id=2, **kwargs): - torch.manual_seed(42) - return PointerHead( - embeddings=nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim), - # bos, eos, pad, 3 x label ids - target_token_ids=[100, 101, 102, 110, 111, 112], - bos_id=0, # -> 100 - eos_id=eos_id, # 1 (default) -> 101 - pad_id=pad_id, # 2 (default) -> 102 - embedding_weight_mapping={ - "110": [20, 21], - "111": [30], - }, - use_encoder_mlp=True, - use_constraints_encoder_mlp=True, - **kwargs, - ) - - -def test_get_pointer_head(): - pointer_head = get_pointer_head() - assert pointer_head is not None - assert not pointer_head.use_prepared_position_ids - - -def test_set_embeddings(): - pointer_head = get_pointer_head() - original_embeddings = pointer_head.embeddings - new_embeddings = nn.Embedding( - original_embeddings.num_embeddings, original_embeddings.embedding_dim - ) - pointer_head.set_embeddings(new_embeddings) - assert pointer_head.embeddings is not None - assert pointer_head.embeddings != original_embeddings - assert pointer_head.embeddings == new_embeddings - - -def test_overwrite_embeddings_with_mapping(): - pointer_head = get_pointer_head() - original_embeddings_weight = pointer_head.embeddings.weight.clone() - pointer_head.overwrite_embeddings_with_mapping() - assert pointer_head.embeddings is not None - assert not torch.equal(pointer_head.embeddings.weight, original_embeddings_weight) - torch.testing.assert_close( - pointer_head.embeddings.weight[110], original_embeddings_weight[[20, 21]].mean(dim=0) - ) - torch.testing.assert_close( - pointer_head.embeddings.weight[111], original_embeddings_weight[[30]].mean(dim=0) - ) - - -@pytest.mark.parametrize( - "use_attention_mask", - [True, False], -) -def test_prepare_decoder_input_ids(use_attention_mask): - pointer_head = get_pointer_head() - encoder_input_ids = torch.tensor( - [ - [10, 11, 12, 13, 14, 15], - [20, 21, 22, 23, 24, 0], - ] - ).to(torch.long) - # we have 3 special tokens (bos, eos, pad) and 3 labels, so the offset is 6 - input_ids = torch.tensor( - [ - # bos, offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6) - [0, 6, 7, 3, 4, 8], - # bos, label (3), offset (3=9-6), eos, pad, pad - [0, 3, 9, 1, 2, 2], - ] - ).to(torch.long) - # this is the attention mask for the (decoder) input_ids, not the encoder_input_ids - attention_mask = ( - torch.tensor( - [ - [1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 0, 0], - ] - ).to(torch.long) - if use_attention_mask - else None - ) - - prepared_decoder_input_ids = pointer_head.prepare_decoder_input_ids( - input_ids=input_ids, - encoder_input_ids=encoder_input_ids, - ) - assert prepared_decoder_input_ids is not None - assert prepared_decoder_input_ids.shape == input_ids.shape - # to recap, the target2token_id mapping is (bos, eos, pad, 3 x label ids) - torch.testing.assert_close( - pointer_head.target2token_id, torch.tensor([100, 101, 102, 110, 111, 112]) - ) - # 3 labels + bos / pad - assert pointer_head.pointer_offset == 6 - assert prepared_decoder_input_ids.tolist() == [ - # bos (0), offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6) - [100, 10, 11, 110, 111, 12], - # bos (0), label (3), offset (3=9-6), eos (1), pad (2), pad (2) - [100, 110, 23, 101, 102, 102], - ] - - -def test_prepare_decoder_input_ids_out_of_bounds(): - pointer_head = get_pointer_head() - # 3 labels + bos / pad - assert pointer_head.pointer_offset == 6 - encoder_input_ids = torch.tensor( - [ - [100, 101, 102], - ] - ).to(torch.long) - input_ids = torch.tensor( - [ - # 9 is out of bounds: > pointer_head.pointer_offset + len(encoder_input_ids) - [0, 9], - ] - ).to(torch.long) - - with pytest.raises(ValueError) as excinfo: - pointer_head.prepare_decoder_input_ids( - input_ids=input_ids, encoder_input_ids=encoder_input_ids - ) - assert str(excinfo.value) == ( - "encoder_input_ids_index.max() [3] must be smaller than encoder_input_length [3]!" - ) - - -@pytest.mark.parametrize( - "decoder_position_id_mode", - ["pattern", "pattern_with_increment", "mapping"], -) -def test_prepare_decoder_position_ids(decoder_position_id_mode): - pointer_head = get_pointer_head( - decoder_position_id_mode=decoder_position_id_mode, - decoder_position_id_pattern=[0, 1, 1, 2], - decoder_position_id_mapping={"default": 3, "vocab": 2, "bos": 0, "eos": 0, "pad": 1}, - ) - input_ids = torch.tensor( - [ - # bos, offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6) - [0, 6, 7, 3, 4, 8], - # bos, label (3), offset (3=9-6), eos, pad, pad - [0, 3, 9, 1, 2, 2], - ] - ).to(torch.long) - - prepared_decoder_position_ids = pointer_head.prepare_decoder_position_ids(input_ids=input_ids) - assert prepared_decoder_position_ids is not None - assert prepared_decoder_position_ids.shape == input_ids.shape - if decoder_position_id_mode == "pattern": - assert prepared_decoder_position_ids.tolist() == [ - [0, 2, 3, 3, 4, 2], - [0, 2, 3, 3, 1, 1], - ] - elif decoder_position_id_mode == "pattern_with_increment": - # the position ids (except for position-bos=0 and position-pad=1) get increased by 3 per record - # (which has length 4) - assert prepared_decoder_position_ids.tolist() == [ - [0, 2, 3, 3, 4, 5], - [0, 2, 3, 3, 1, 1], - ] - elif decoder_position_id_mode == "mapping": - assert prepared_decoder_position_ids.tolist() == [ - [0, 3, 3, 2, 2, 3], - [0, 2, 3, 0, 1, 1], - ] - else: - raise ValueError(f"unknown decoder_position_id_mode={decoder_position_id_mode}") - - -def test_prepare_decoder_position_ids_unknown_mode(): - with pytest.raises(ValueError) as excinfo: - get_pointer_head(decoder_position_id_mode="unknown") - assert str(excinfo.value) == ( - 'decoder_position_id_mode="unknown" is not supported, use one of "pattern", ' - '"pattern_with_increment", or "mapping"!' - ) - - -@pytest.mark.parametrize( - "decoder_position_id_mode", - ["pattern", "pattern_with_increment", "mapping"], -) -def test_prepare_decoder_position_ids_missing_parameter(decoder_position_id_mode): - with pytest.raises(ValueError) as excinfo: - get_pointer_head(decoder_position_id_mode=decoder_position_id_mode) - if decoder_position_id_mode in ["pattern", "pattern_with_increment"]: - assert ( - str(excinfo.value) == "decoder_position_id_pattern must be provided when using " - 'decoder_position_id_mode="pattern" or "pattern_with_increment"!' - ) - elif decoder_position_id_mode == "mapping": - assert ( - str(excinfo.value) - == 'decoder_position_id_mode="mapping" requires decoder_position_id_mapping to be provided!' - ) - else: - raise ValueError(f"unknown decoder_position_id_mode={decoder_position_id_mode}") - - -def test_prepare_decoder_position_ids_with_wrong_mapping(): - input_ids = torch.tensor( - [ - # bos, offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6) - [0, 6, 7, 3, 4, 8], - # bos, label (3), offset (3=9-6), eos, pad, pad - [0, 3, 9, 1, 2, 2], - ] - ).to(torch.long) - - # missing default - pointer_head = get_pointer_head( - decoder_position_id_mode="mapping", - decoder_position_id_mapping={"vocab": 2, "bos": 0, "eos": 0, "pad": 1}, - ) - with pytest.raises(ValueError) as excinfo: - pointer_head.prepare_decoder_position_ids(input_ids=input_ids) - assert ( - str(excinfo.value) - == "mapping must contain a default entry, but only contains ['vocab', 'bos', 'eos', 'pad']!" - ) - - # unknown key - pointer_head = get_pointer_head( - decoder_position_id_mode="mapping", - decoder_position_id_mapping={ - "default": 3, - "vocab": 2, - "bos": 0, - "eos": 0, - "pad": 1, - "unknown": 4, - }, - ) - with pytest.raises(ValueError) as excinfo: - pointer_head.prepare_decoder_position_ids(input_ids=input_ids) - assert ( - str(excinfo.value) == "Mapping contains unknown key 'unknown' " - "(mapping: {'default': 3, 'vocab': 2, 'bos': 0, 'eos': 0, 'pad': 1, 'unknown': 4})." - ) - - # multiple values for same input id - pointer_head = get_pointer_head( - # same id for eos and pad - eos_id=1, - pad_id=1, - decoder_position_id_mode="mapping", - decoder_position_id_mapping={ - "default": 3, - "vocab": 2, - "bos": 0, - # different position ids for eos and pad, this is not allowed when eos and pad have the same id - "eos": 0, - "pad": 1, - }, - ) - with pytest.raises(ValueError) as excinfo: - pointer_head.prepare_decoder_position_ids(input_ids=input_ids) - assert ( - str(excinfo.value) - == "Can not set the position ids for 'pad' to 1 because it was already set to 0 by key 'eos'. " - "Note that both, 'pad' and 'eos', have the same id (1), so their position_ids need to be " - "also the same (position id mapping: {'default': 3, 'vocab': 2, 'bos': 0, 'eos': 0, 'pad': 1})." - ) - - -def test_prepare_decoder_inputs(): - pointer_head = get_pointer_head( - decoder_position_id_mode="pattern", decoder_position_id_pattern=[0, 1, 1, 2] - ) - encoder_input_ids = torch.tensor( - [ - [10, 11, 12, 13, 14, 15], - [20, 21, 22, 23, 24, 0], - ] - ).to(torch.long) - input_ids = torch.tensor( - [ - # bos, offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6) - [0, 6, 7, 3, 4, 8], - # bos, label (3), offset (3=9-6), eos, pad, pad - [0, 3, 9, 1, 2, 2], - ] - ).to(torch.long) - - decoder_inputs = pointer_head.prepare_decoder_inputs( - input_ids=input_ids, - encoder_input_ids=encoder_input_ids, - ) - assert set(decoder_inputs.keys()) == {"input_ids", "position_ids"} - assert decoder_inputs["input_ids"].shape == input_ids.shape - assert decoder_inputs["position_ids"].shape == input_ids.shape - # to recap, the target2token_id mapping is (bos, eos, pad, 3 x label ids) - torch.testing.assert_close( - pointer_head.target2token_id, torch.tensor([100, 101, 102, 110, 111, 112]) - ) - # 3 labels + bos / pad - assert pointer_head.pointer_offset == 6 - assert decoder_inputs["input_ids"].tolist() == [ - # bos (0), offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6) - [100, 10, 11, 110, 111, 12], - # bos (0), label (3), offset (3=9-6), eos (1), pad (2), pad (2) - [100, 110, 23, 101, 102, 102], - ] - assert decoder_inputs["position_ids"].tolist() == [ - [0, 2, 3, 3, 4, 2], - [0, 2, 3, 3, 1, 1], - ] - - -def test_forward(): - pointer_head = get_pointer_head() - # shape: (batch_size=2, input_sequence_length=5) - encoder_input_ids = torch.tensor( - [ - [10, 11, 12, 13, 14], - [20, 21, 22, 23, 0], - ] - ).to(torch.long) - encoder_attention_mask = torch.tensor( - [ - [1, 1, 1, 1, 1], - [1, 1, 1, 1, 0], - ] - ).to(torch.long) - # shape: (batch_size=2, input_sequence_length=5, hidden_size=3) - encoder_last_hidden_state = pointer_head.embeddings(encoder_input_ids) - # shape: (batch_size=2, target_sequence_length=4) - prepared_input_ids = torch.tensor( - [ - # bos (0), offset (0=6-6), offset (1=7-6), label (3) - [100, 10, 11, 110], - # bos (0), label (3), offset (3=9-6), eos (1) - [100, 110, 23, 101], - ] - ).to(torch.long) - # shape: (batch_size=2, target_sequence_length=4) - last_hidden_state = pointer_head.embeddings(prepared_input_ids) - - torch.manual_seed(42) - logits, loss = pointer_head( - encoder_input_ids=encoder_input_ids, - encoder_attention_mask=encoder_attention_mask, - encoder_last_hidden_state=encoder_last_hidden_state, - last_hidden_state=last_hidden_state, - ) - assert loss is None - assert logits is not None - # shape: (batch_size=2, target_sequence_length=4, num_targets+num_offsets=6+5==11) - assert logits.shape == (2, 4, 11) - torch.testing.assert_close( - logits, - torch.tensor( - [ - [ - [ - -1.0000000138484279e24, - -0.9407045245170593, - -1.0000000138484279e24, - 0.5535521507263184, - 0.04295700043439865, - 1.0467679500579834, - -1.110795497894287, - 1.1652655601501465, - 0.09444020688533783, - 0.43052661418914795, - -1.0437036752700806, - ], - [ - -1.0000000138484279e24, - 1.1563994884490967, - -1.0000000138484279e24, - -0.8941665887832642, - -0.6862093806266785, - -1.154745101928711, - 1.6984729766845703, - -1.3889904022216797, - -0.4076152741909027, - -1.0112841129302979, - 0.9846026301383972, - ], - [ - -1.0000000138484279e24, - -1.9377808570861816, - -1.0000000138484279e24, - 2.437451124191284, - 0.041493892669677734, - 0.5383729338645935, - -1.5238577127456665, - 1.6700562238693237, - -0.07231226563453674, - 1.0911093950271606, - -0.9189060926437378, - ], - [ - -1.0000000138484279e24, - -1.880744218826294, - -1.0000000138484279e24, - 3.8719429969787598, - 0.07287894189357758, - -1.3378281593322754, - -0.653921365737915, - 0.783344566822052, - -0.3344290256500244, - 1.3571363687515259, - 0.5505899786949158, - ], - ], - [ - [ - -1.0000000138484279e24, - -0.9407045245170593, - -1.0000000138484279e24, - 0.5535521507263184, - 0.04295700043439865, - 1.0467679500579834, - -1.0019789934158325, - 0.6891120672225952, - -0.002076566219329834, - 0.7561025619506836, - -1.0000000331813535e32, - ], - [ - -1.0000000138484279e24, - -1.880744218826294, - -1.0000000138484279e24, - 3.8719429969787598, - 0.07287894189357758, - -1.3378281593322754, - -1.3875324726104736, - -2.124865770339966, - -2.559859275817871, - 0.5425653457641602, - -1.0000000331813535e32, - ], - [ - -1.0000000138484279e24, - -1.479057788848877, - -1.0000000138484279e24, - 1.7857770919799805, - 0.6723557114601135, - 0.6378745436668396, - -2.262815475463867, - -0.1536862850189209, - -0.5338708758354187, - 1.3628911972045898, - -1.0000000331813535e32, - ], - [ - -1.0000000138484279e24, - 1.1815755367279053, - -1.0000000138484279e24, - -1.880744218826294, - -0.10646091401576996, - 0.1437276005744934, - 1.0795626640319824, - 0.6434042453765869, - 1.0681594610214233, - -0.5814396142959595, - -1.0000000331813535e32, - ], - ], - ] - ), - ) - - -@pytest.mark.parametrize( - "with_constraints", - [True, False], -) -def test_forward_with_labels(with_constraints): - pointer_head = get_pointer_head(num_embeddings=300, embedding_dim=3) - - # shape: (batch_size=2, input_sequence_length=5) - encoder_input_ids = torch.tensor( - [ - [10, 11, 12, 13, 14], - [20, 21, 22, 0, 0], - ] - ).to(torch.long) - encoder_attention_mask = torch.tensor( - [ - [1, 1, 1, 1, 1], - [1, 1, 1, 0, 0], - ] - ).to(torch.long) - # shape: (batch_size=2, input_sequence_length=5, hidden_size=3) - # encoder_last_hidden_state = pointer_head.embeddings(encoder_input_ids) - # shape: (batch_size=2, target_sequence_length=4) - prepared_input_ids = torch.tensor( - [ - # bos (0), offset (0=6-6), offset (1=7-6), label (3) - [100, 10, 11, 110], - # bos (0), label (3), offset (3=9-6), eos (1) - [100, 110, 23, 101], - ] - ).to(torch.long) - # shape: (batch_size=2, target_sequence_length=4) - # last_hidden_state = pointer_head.embeddings(prepared_input_ids) - labels = torch.tensor( - [ - # offset (0=6-6), offset (1=7-6), label (3), label (4) - [6, 7, 3, 4], - # label (3), offset (3=9-6), eos, pad, pad - [3, 9, 1, 2], - ] - ).to(torch.long) - decoder_attention_mask = torch.tensor( - [ - [1, 1, 1, 1], - [1, 1, 1, 0], - ] - ).to(torch.long) - - # shape: (batch_size=2, target_sequence_length=4, num_targets+num_offsets=6+5==11) - constraints = ( - # recap: the target2token_id mapping is (bos, eos, pad, 3 x label ids) - torch.tensor( - [ - [ - # allow all labels - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0], - # allow all offsets different from previous label id (3) - [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0], - # allow all offsets different from previous label ids (3, 4) - [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], - # allow all offsets - [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], - ], - [ - # allow all labels - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0], - # allow all offsets - [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], - # allow all offsets equal or bigger than previous one (9) or eos - [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1], - # allow only pad - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], - ], - ] - ).to(torch.long) - ) - - torch.manual_seed(42) - # shape: (batch_size=2, input_sequence_length=6, hidden_size=3) - encoder_last_hidden_state = pointer_head.embeddings(encoder_input_ids) - last_hidden_state = pointer_head.embeddings(prepared_input_ids) - _, loss = pointer_head( - encoder_input_ids=encoder_input_ids, - encoder_attention_mask=encoder_attention_mask, - encoder_last_hidden_state=encoder_last_hidden_state, - last_hidden_state=last_hidden_state, - labels=labels, - decoder_attention_mask=decoder_attention_mask, - constraints=constraints if with_constraints else None, - ) - assert loss is not None - maybe_gradients = torch.autograd.grad(loss, pointer_head.parameters(), allow_unused=True) - gradients = [g for g in maybe_gradients if g is not None] - if not with_constraints: - # embeddings.weight, 2 x (encoder_mlp.weight, encoder_mlp.bias) - assert len(gradients) == 5 - # embeddings.weight (just check entries for special tokens and labels) - torch.testing.assert_close( - gradients[0][100:113], - torch.tensor( - [ - [0.29642319679260254, 0.012336060404777527, 0.14099650084972382], - [0.015981415286660194, 0.17855659127235413, -0.21089009940624237], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [-0.8812153935432434, -0.43322375416755676, 0.07359108328819275], - [0.22255337238311768, 0.09604272246360779, 0.017692387104034424], - [-0.021408570930361748, -0.01747075282037258, 0.15882402658462524], - ] - ), - ) - # first encoder_mlp.weight - torch.testing.assert_close( - gradients[1], - torch.tensor( - [ - [6.044770998414606e-05, -0.001140016596764326, 0.0007320810691453516], - [0.014351745136082172, 0.01521987747400999, -0.028653975576162338], - [0.011420723050832748, 0.0070406426675617695, -0.030101824551820755], - ] - ), - ) - # first encoder_mlp.bias - torch.testing.assert_close( - gradients[2], - torch.tensor([-0.0006180311902426183, -0.023118967190384865, -0.024205176159739494]), - ) - # second encoder_mlp.weight - torch.testing.assert_close( - gradients[3], - torch.tensor( - [ - [-0.0005463349516503513, -0.016356423497200012, 0.01958528161048889], - [-0.0005303063080646098, -0.029644077643752098, -0.1391362100839615], - [0.0028533015865832567, 0.08096987009048462, 0.28279614448547363], - ] - ), - ) - # second encoder_mlp.bias - torch.testing.assert_close( - gradients[4], - torch.tensor([-0.030467912554740906, -0.045307278633117676, 0.06145985424518585]), - ) - else: - # embeddings.weight, 2 x (encoder_mlp.weight, encoder_mlp.bias), 2 x (constraints_encoder_mlp.weight, constraints_encoder_mlp.bias) - assert len(gradients) == 9 - # embeddings.weight (just check entries for special tokens and labels) - torch.testing.assert_close( - gradients[0][100:113], - torch.tensor( - [ - [0.2915953993797302, 0.009700030088424683, 0.1484404355287552], - [0.02216985821723938, 0.15251068770885468, -0.21624334156513214], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [-0.8804605007171631, -0.4300656318664551, 0.0664108395576477], - [0.21543428301811218, 0.093157559633255, 0.013825103640556335], - [-0.021408570930361748, -0.01747075282037258, 0.15882402658462524], - ] - ), - ) - # first encoder_mlp.weight - torch.testing.assert_close( - gradients[1], - torch.tensor( - [ - [-0.0003244421095587313, 0.006118832156062126, -0.003929311875253916], - [0.013681752607226372, 0.013532182201743126, -0.027564184740185738], - [0.012365758419036865, 0.00791379064321518, -0.02969365194439888], - ] - ), - ) - # first encoder_mlp.bias - torch.testing.assert_close( - gradients[2], - torch.tensor([0.003317170077934861, -0.021803036332130432, -0.023893579840660095]), - ) - # second encoder_mlp.weight - torch.testing.assert_close( - gradients[3], - torch.tensor( - [ - [-0.004014550242573023, -0.018573174253106117, 0.019694898277521133], - [-0.0019358742283657193, -0.030542463064193726, -0.13909178972244263], - [0.0009692738531157374, 0.0797656774520874, 0.28285568952560425], - ] - ), - ) - # second encoder_mlp.bias - torch.testing.assert_close( - gradients[4], - torch.tensor([-0.046919066458940506, -0.05197446048259735, 0.05252313241362572]), - ) - # first constraints_encoder_mlp.weight - torch.testing.assert_close( - gradients[5], - torch.tensor( - [ - [0.010755524039268494, -0.009512078016996384, -0.007983260788023472], - [0.004236628767102957, 0.002073169220238924, -0.0010695274686440825], - [-0.008700753562152386, -0.00425766222178936, 0.002196485875174403], - ] - ), - ) - # first constraints_encoder_mlp.bias - torch.testing.assert_close( - gradients[6], - torch.tensor([0.05254765599966049, -0.0024578727316111326, 0.005047726910561323]), - ) - # second constraints_encoder_mlp.weight - torch.testing.assert_close( - gradients[7], - torch.tensor( - [ - [0.004190368112176657, -0.01078515499830246, -0.015312351286411285], - [0.001505501102656126, -0.006679146084934473, -0.009482797235250473], - [0.02189277485013008, -0.010388202033936977, -0.014748772606253624], - ] - ), - ) - # second constraints_encoder_mlp.bias - torch.testing.assert_close( - gradients[8], - torch.tensor([0.016296036541461945, -0.00018996000289916992, 0.05888192355632782]), - ) diff --git a/tests/models/components/test_pooler.py b/tests/models/components/test_pooler.py index a63b36a68..e69de29bb 100644 --- a/tests/models/components/test_pooler.py +++ b/tests/models/components/test_pooler.py @@ -1,218 +0,0 @@ -import pytest -import torch - -from pie_modules.models.components.pooler import ( - CLS_TOKEN, - MENTION_POOLING, - START_TOKENS, - ArgumentWrappedPooler, - AtIndexPooler, - SpanMaxPooler, - SpanMeanPooler, - get_pooler_and_output_size, - pool_cls, -) - - -@pytest.mark.parametrize( - "pooler_type", - [ - CLS_TOKEN, - START_TOKENS, - MENTION_POOLING, - ], -) -def test_get_pooler_and_output_size(pooler_type): - pooler, output_size = get_pooler_and_output_size(config={"type": pooler_type}, input_dim=20) - assert pooler is not None - if pooler_type == CLS_TOKEN: - assert output_size == 20 - elif pooler_type in (START_TOKENS, MENTION_POOLING): - # pre default, num_indices is 2 - assert output_size == 20 * 2 - else: - raise ValueError(f"Unknown pooler type {pooler_type}") - - -@pytest.mark.parametrize("aggregate", ["max", "mean"]) -def test_get_pooler_and_output_size_mention(aggregate): - pooler, output_size = get_pooler_and_output_size( - config={"type": MENTION_POOLING, "aggregate": aggregate}, input_dim=20 - ) - assert pooler is not None - assert output_size == 20 * 2 - if aggregate == "max": - assert isinstance(pooler, SpanMaxPooler) - elif aggregate == "mean": - assert isinstance(pooler, SpanMeanPooler) - else: - raise ValueError(f"Unknown aggregate type {aggregate}") - - -def test_get_pooler_and_output_size_mention_unknown_aggregate(): - with pytest.raises(ValueError) as excinfo: - get_pooler_and_output_size( - config={"type": MENTION_POOLING, "aggregate": "unknown"}, input_dim=20 - ) - assert str(excinfo.value) == 'Unknown aggregation method for mention pooling: "unknown"' - - -def test_get_pooler_and_output_size_wrong_type(): - with pytest.raises(ValueError) as excinfo: - get_pooler_and_output_size(config={"type": "wrong_type"}, input_dim=20) - assert str(excinfo.value) == 'Unknown pooler type "wrong_type"' - - -@pytest.fixture(scope="session") -def hidde_state(): - result = torch.tensor( - [ - [[0.00, 0.01], [0.10, 0.11], [0.20, 0.21], [0.30, 0.31]], - [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21], [1.30, 1.31]], - ] - ) - # batch_size x sequence_length x hidden_size - assert result.shape == (2, 4, 2) - return result - - -def test_pool_cls(hidde_state): - pooler = pool_cls - output = pooler(hidden_state=hidde_state) - assert output is not None - batch_size = hidde_state.shape[0] - hidden_size = hidde_state.shape[-1] - assert output.shape == (batch_size, hidden_size) - torch.testing.assert_close(output, hidde_state[:, 0, :]) - torch.testing.assert_close(output, torch.tensor([[0.00, 0.01], [1.00, 1.01]])) - - -def test_at_index_pooler(hidde_state): - pooler = AtIndexPooler(input_dim=hidde_state.shape[-1], num_indices=2) - indices = torch.tensor([[2, 0], [1, 0]]) - output = pooler(hidden_state=hidde_state, indices=indices) - assert output is not None - batch_size = hidde_state.shape[0] - hidden_size = hidde_state.shape[-1] - # times num_indices (=2) due to concat - assert output.shape == (batch_size, hidden_size * 2) - torch.testing.assert_close( - output, torch.tensor([[0.20, 0.21, 0.00, 0.01], [1.10, 1.11, 1.00, 1.01]]) - ) - - -def test_at_index_pooler_with_offset(hidde_state): - # set the seed to make sure that we get the same missing embeddings - torch.manual_seed(42) - pooler = AtIndexPooler(input_dim=hidde_state.shape[-1], num_indices=2, offset=-1) - indices = torch.tensor([[2, 1], [0, -10]]) - output = pooler(hidden_state=hidde_state, indices=indices) - assert output is not None - batch_size = hidde_state.shape[0] - hidden_size = hidde_state.shape[-1] - # times num_indices (=2) due to concat - assert output.shape == (batch_size, hidden_size * 2) - # the second batch element has out of bounds indices, so we expect the missing embeddings - # it needs to be flattened, because the output is concatenated - torch.testing.assert_close(output[1], pooler.missing_embeddings.view(-1)) - torch.testing.assert_close( - output, - torch.tensor( - [ - [0.10, 0.11, 0.00, 0.01], - [ - 0.33669036626815796, - 0.12880940735340118, - 0.23446236550807953, - 0.23033303022384644, - ], - ] - ), - ) - - -def test_at_index_pooler_wrong_indices_shapes(hidde_state): - pooler = AtIndexPooler(input_dim=hidde_state.shape[-1], num_indices=2) - indices = torch.tensor([[2, 0, 1], [1, 0, 0]]) - with pytest.raises(ValueError) as excinfo: - pooler(hidden_state=hidde_state, indices=indices) - assert str(excinfo.value) == "number of indices [3] has to be the same as num_types [2]" - - -def test_argument_wrapped_pooler(hidde_state): - def dummy_pooler(hidden_state, y): - return hidden_state[:, 0, :] - - pooler = ArgumentWrappedPooler(pooler=dummy_pooler, argument_mapping={"x": "y"}) - output = pooler(hidden_state=hidde_state, x="dummy") - assert output is not None - batch_size = hidde_state.shape[0] - hidden_size = hidde_state.shape[-1] - assert output.shape == (batch_size, hidden_size) - torch.testing.assert_close(output, hidde_state[:, 0, :]) - - -def test_span_max_pooler(hidde_state): - pooler = SpanMaxPooler(input_dim=hidde_state.shape[-1], num_indices=2) - start_indices = torch.tensor([[2, 0], [0, 1]]) - end_indices = torch.tensor([[3, 3], [1, 2]]) - output = pooler(hidden_state=hidde_state, start_indices=start_indices, end_indices=end_indices) - assert output is not None - batch_size = hidde_state.shape[0] - hidden_size = hidde_state.shape[-1] - # times num_indices (=2) due to concat - assert output.shape == (batch_size, hidden_size * 2) - torch.testing.assert_close( - output, torch.tensor([[0.20, 0.21, 0.20, 0.21], [1.00, 1.01, 1.10, 1.11]]) - ) - - -def test_span_max_pooler_wrong_start_indices_shape(hidde_state): - pooler = SpanMaxPooler(input_dim=hidde_state.shape[-1], num_indices=2) - start_indices = torch.tensor([[2, 0, 1], [0, 1, 0]]) - end_indices = torch.tensor([[3, 3], [1, 2]]) - with pytest.raises(ValueError) as excinfo: - pooler(hidden_state=hidde_state, start_indices=start_indices, end_indices=end_indices) - assert str(excinfo.value) == ( - "number of start indices [3] has to be the same as num_types [2]" - ) - - -def test_span_max_pooler_wrong_end_indices_shape(hidde_state): - pooler = SpanMaxPooler(input_dim=hidde_state.shape[-1], num_indices=2) - start_indices = torch.tensor([[2, 0], [0, 1]]) - end_indices = torch.tensor([[3, 3, 3], [1, 2, 1]]) - with pytest.raises(ValueError) as excinfo: - pooler(hidden_state=hidde_state, start_indices=start_indices, end_indices=end_indices) - assert str(excinfo.value) == ("number of end indices [3] has to be the same as num_types [2]") - - -def test_span_max_pooler_start_indices_bigger_than_end_indices(hidde_state): - pooler = SpanMaxPooler(input_dim=hidde_state.shape[-1], num_indices=2) - start_indices = torch.tensor([[2, 0], [0, 1]]) - end_indices = torch.tensor([[1, 3], [1, 2]]) - with pytest.raises(ValueError) as excinfo: - pooler(hidden_state=hidde_state, start_indices=start_indices, end_indices=end_indices) - assert str(excinfo.value) == ( - "values in start_indices have to be smaller than respective values in end_indices, but start_indices=\n" - "tensor([[2, 0],\n" - " [0, 1]])\n " - "and end_indices=\n" - "tensor([[1, 3],\n" - " [1, 2]])" - ) - - -def test_span_mean_pooler(hidde_state): - pooler = SpanMeanPooler(input_dim=hidde_state.shape[-1], num_indices=2) - start_indices = torch.tensor([[2, 0], [0, 1]]) - end_indices = torch.tensor([[3, 3], [1, 2]]) - output = pooler(hidden_state=hidde_state, start_indices=start_indices, end_indices=end_indices) - assert output is not None - batch_size = hidde_state.shape[0] - hidden_size = hidde_state.shape[-1] - # times num_indices (=2) due to concat - assert output.shape == (batch_size, hidden_size * 2) - torch.testing.assert_close( - output, torch.tensor([[0.20, 0.21, 0.10, 0.11], [1.00, 1.01, 1.10, 1.11]]) - ) diff --git a/tests/models/components/test_seq2seq_encoder.py b/tests/models/components/test_seq2seq_encoder.py deleted file mode 100644 index 58e2fc0c5..000000000 --- a/tests/models/components/test_seq2seq_encoder.py +++ /dev/null @@ -1,93 +0,0 @@ -import pytest -import torch - -from pie_modules.models.components.seq2seq_encoder import ( - ACTIVATION_TYPE2CLASS, - RNN_TYPE2CLASS, - build_seq2seq_encoder, -) - - -def test_no_encoder(): - seq2seq_dict = {} - input_size = 10 - encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size) - assert encoder is None - assert output_size == input_size - - seq2seq_dict = { - "type": "sequential", - "rnn_layer": { - "type": "none", - }, - } - input_size = 10 - encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size) - assert len(encoder) == 0 - assert output_size == input_size - - -@pytest.mark.parametrize("seq2seq_enc_type", list(RNN_TYPE2CLASS)) -@pytest.mark.parametrize("bidirectional", [True, False]) -def test_rnn_encoder(seq2seq_enc_type, bidirectional): - hidden_size = 99 - seq2seq_dict = { - "type": seq2seq_enc_type, - "hidden_size": hidden_size, - "bidirectional": bidirectional, - } - input_size = 10 - encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size) - assert encoder is not None - assert isinstance(encoder.rnn, RNN_TYPE2CLASS[seq2seq_enc_type]) - - expected_output_size = hidden_size * 2 if bidirectional else hidden_size - assert output_size is not None - assert output_size == expected_output_size - - -@pytest.mark.parametrize("activation_type", list(ACTIVATION_TYPE2CLASS)) -def test_activations(activation_type): - seq2seq_dict = { - "type": activation_type, - } - input_size = 10 - encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size) - assert encoder is not None - assert isinstance(encoder, ACTIVATION_TYPE2CLASS[activation_type]) - assert output_size == input_size - - -def test_dropout(): - seq2seq_dict = { - "type": "dropout", - "p": 0.5, - } - input_size = 10 - encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size) - assert encoder is not None - assert isinstance(encoder, torch.nn.Dropout) - assert output_size == input_size - - -def test_linear(): - out_features = 99 - seq2seq_dict = { - "type": "linear", - "out_features": out_features, - } - - input_size = 10 - encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size) - assert encoder is not None - assert isinstance(encoder, torch.nn.Linear) - assert output_size == out_features - - -def test_unknown_rnn_type(): - seq2seq_dict = { - "type": "unknown", - } - with pytest.raises(ValueError) as exc_info: - build_seq2seq_encoder(seq2seq_dict, 10) - assert str(exc_info.value) == "Unknown seq2seq_encoder_type: unknown" diff --git a/tests/models/test_extractive_question_answering.py b/tests/models/test_extractive_question_answering.py deleted file mode 100644 index 9d8b6c1cb..000000000 --- a/tests/models/test_extractive_question_answering.py +++ /dev/null @@ -1,245 +0,0 @@ -import json - -import pytest -import torch -import transformers -from pytorch_lightning import Trainer - -from pie_modules.annotations import ExtractiveAnswer, Question -from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers -from pie_modules.models.simple_extractive_question_answering import ( - SimpleExtractiveQuestionAnsweringModel, -) -from pie_modules.taskmodules.extractive_question_answering import ( - ExtractiveQuestionAnsweringTaskModule, -) -from tests import DUMP_FIXTURE_DATA, FIXTURES_ROOT - -FIXTURES_TASKMODULE_DATA_PATH = FIXTURES_ROOT / "taskmodules" / "extractive_question_answering" - - -@pytest.fixture -def documents(): - document0 = TextDocumentWithQuestionsAndExtractiveAnswers( - text="This is a test document", id="doc0" - ) - document0.questions.append(Question(text="What is the first word?")) - document0.answers.append(ExtractiveAnswer(question=document0.questions[0], start=0, end=3)) - - document1 = TextDocumentWithQuestionsAndExtractiveAnswers( - text="Oranges are orange in color.", id="doc1" - ) - document1.questions.append(Question(text="What color are oranges?")) - document1.answers.append(ExtractiveAnswer(question=document1.questions[0], start=23, end=27)) - - document2 = TextDocumentWithQuestionsAndExtractiveAnswers( - text="This is a test document that has two questions attached to it.", id="doc2" - ) - document2.questions.append(Question(text="What type of document is this?")) - document2.questions.append(Question(text="How many questions are attached to this document?")) - document2.answers.append(ExtractiveAnswer(question=document2.questions[0], start=11, end=14)) - document2.answers.append(ExtractiveAnswer(question=document2.questions[1], start=34, end=36)) - - documents = [document0, document1, document2] - return documents - - -@pytest.mark.skipif( - condition=not DUMP_FIXTURE_DATA, - reason="Only need to dump the data if taskmodule has changed", -) -def test_dump_fixtures(documents): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = ExtractiveQuestionAnsweringTaskModule( - tokenizer_name_or_path=tokenizer_name_or_path, - max_length=512, - ) - - task_encodings = taskmodule.encode(documents, encode_target=True) - batch_encoding = taskmodule.collate(task_encodings) - - FIXTURES_TASKMODULE_DATA_PATH.mkdir(parents=True, exist_ok=True) - filepath = FIXTURES_TASKMODULE_DATA_PATH / "batch_encoding_inputs.json" - - inputs = {key: tensor.tolist() for key, tensor in batch_encoding[0].items()} - targets = {key: tensor.tolist() for key, tensor in batch_encoding[1].items()} - converted_batch_encoding = { - "inputs": inputs, - "targets": targets, - } - - with open(filepath, "w") as f: - json.dump(converted_batch_encoding, f) - return converted_batch_encoding - - -@pytest.fixture -def batch(): - filepath = FIXTURES_TASKMODULE_DATA_PATH / "batch_encoding_inputs.json" - with open(filepath) as f: - batch_encoding = json.load(f) - - inputs = {key: torch.LongTensor(tensor) for key, tensor in batch_encoding["inputs"].items()} - targets = {key: torch.LongTensor(tensor) for key, tensor in batch_encoding["targets"].items()} - return inputs, targets - - -def get_model( - monkeypatch, - model_type, - batch_size, - seq_len, - add_dummy_linear=False, - **model_kwargs, -): - class MockConfig: - def __init__( - self, - hidden_size: int = 10, - model_type=model_type, - ) -> None: - self.hidden_size = hidden_size - self.model_type = model_type - - class MockModel(torch.nn.Module): - def __init__(self, batch_size, seq_len, hidden_size, add_dummy_linear) -> None: - super().__init__() - self.batch_size = batch_size - self.seq_len = seq_len - self.hidden_size = hidden_size - if add_dummy_linear: - self.dummy_linear = torch.nn.Linear(self.hidden_size, 99) - - def __call__(self, *args, **kwargs): - torch.manual_seed(42) - start_logits = torch.FloatTensor(torch.rand(self.batch_size, self.seq_len)) - end_logits = torch.FloatTensor(torch.rand(self.batch_size, self.seq_len)) - loss = torch.FloatTensor(torch.rand(1)) - return transformers.modeling_outputs.QuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - loss=loss, - ) - - hidden_size = 10 - - monkeypatch.setattr( - transformers.AutoConfig, - "from_pretrained", - lambda model_name_or_path: MockConfig(hidden_size=hidden_size, model_type=model_type), - ) - monkeypatch.setattr( - transformers.AutoModelForQuestionAnswering, - "from_pretrained", - lambda model_name_or_path, config: MockModel( - batch_size=batch_size, - seq_len=seq_len, - hidden_size=hidden_size, - add_dummy_linear=add_dummy_linear, - ), - ) - - # set seed to make the classifier deterministic - torch.manual_seed(42) - result = SimpleExtractiveQuestionAnsweringModel( - model_name_or_path=model_type, - max_input_length=seq_len, - **model_kwargs, - ) - assert not result.is_from_pretrained - - return result - - -@pytest.fixture -def model(monkeypatch, batch): - inputs, targets = batch - model = get_model( - monkeypatch=monkeypatch, - model_type="bert", - batch_size=inputs["input_ids"].shape[0], - seq_len=inputs["input_ids"].shape[1], - add_dummy_linear=True, - ) - return model - - -def test_get_model(monkeypatch, model): - assert model is not None - assert isinstance(model, SimpleExtractiveQuestionAnsweringModel) - - -def test_forward(batch, model): - inputs, targets = batch - batch_size, seq_len = inputs["input_ids"].shape - - # set seed to make sure the output is deterministic - torch.manual_seed(42) - output = model.forward(inputs) - assert set(output) == {"start_logits", "end_logits", "loss"} - start_logits = output["start_logits"] - end_logits = output["end_logits"] - loss = output["loss"] - assert start_logits.shape == (batch_size, seq_len) - assert end_logits.shape == (batch_size, seq_len) - assert loss.shape == (1,) - expected_loss = torch.FloatTensor([0.04587]) - torch.testing.assert_close(output["loss"], expected_loss) - - -def test_step(batch, model): - torch.manual_seed(42) - loss = model.step("train", batch) - assert loss is not None - expected_loss = torch.FloatTensor([0.04587]) - torch.testing.assert_close(loss, expected_loss) - - -def test_training_step(batch, model): - loss = model.training_step(batch, batch_idx=0) - assert loss is not None - expected_loss = torch.FloatTensor([0.04587]) - torch.testing.assert_close(loss, expected_loss) - - -def test_validation_step(batch, model): - loss = model.validation_step(batch, batch_idx=0) - assert loss is not None - expected_loss = torch.FloatTensor([0.04587]) - torch.testing.assert_close(loss, expected_loss) - - -def test_test_step(batch, model): - loss = model.test_step(batch, batch_idx=0) - assert loss is not None - expected_loss = torch.FloatTensor([0.04587]) - torch.testing.assert_close(loss, expected_loss) - - -def test_optim(model): - optimizer = model.configure_optimizers() - assert optimizer is not None - assert isinstance(optimizer, torch.optim.Adam) - assert optimizer.defaults["lr"] == 1e-05 - - -def test_optim_with_warmup_proportion(monkeypatch, batch): - inputs, targets = batch - model = get_model( - monkeypatch=monkeypatch, - model_type="bert", - batch_size=inputs["input_ids"].shape[0], - seq_len=inputs["input_ids"].shape[1], - add_dummy_linear=True, - warmup_proportion=0.1, - ) - model.trainer = Trainer(max_epochs=10) - optimizers_and_schedulars = model.configure_optimizers() - assert optimizers_and_schedulars is not None - assert isinstance(optimizers_and_schedulars, tuple) and len(optimizers_and_schedulars) == 2 - - optimizers, schedulers = optimizers_and_schedulars - assert isinstance(optimizers[0], torch.optim.Optimizer) - assert set(schedulers[0]) == {"scheduler", "interval"} - schedular = schedulers[0]["scheduler"] - assert isinstance(schedular, torch.optim.lr_scheduler.LRScheduler) diff --git a/tests/models/test_sequence_classification_with_pooler.py b/tests/models/test_sequence_classification_with_pooler.py deleted file mode 100644 index 07f9e10cd..000000000 --- a/tests/models/test_sequence_classification_with_pooler.py +++ /dev/null @@ -1,578 +0,0 @@ -from typing import Dict - -import pytest -import torch -from pytorch_lightning import Trainer -from torch import LongTensor, tensor -from torch.optim.lr_scheduler import LambdaLR -from transformers.modeling_outputs import SequenceClassifierOutput - -from pie_modules.models import SequenceClassificationModelWithPooler -from pie_modules.models.sequence_classification_with_pooler import OutputType -from tests.models import trunc_number - -NUM_CLASSES = 4 -POOLER = "start_tokens" - - -@pytest.fixture -def inputs() -> Dict[str, LongTensor]: - result_dict = { - "input_ids": torch.tensor( - [ - [ - 101, - 28998, - 13832, - 3121, - 2340, - 138, - 28996, - 1759, - 1120, - 28999, - 139, - 28997, - 119, - 102, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 28998, - 13832, - 3121, - 2340, - 144, - 28996, - 1759, - 1120, - 28999, - 145, - 28997, - 119, - 1262, - 1771, - 146, - 119, - 102, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 28998, - 13832, - 3121, - 2340, - 144, - 28996, - 1759, - 1120, - 145, - 119, - 1262, - 1771, - 28999, - 146, - 28997, - 119, - 102, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 13832, - 3121, - 2340, - 144, - 1759, - 1120, - 28999, - 145, - 28997, - 119, - 1262, - 1771, - 28998, - 146, - 28996, - 119, - 102, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 28998, - 13832, - 3121, - 2340, - 150, - 28996, - 1759, - 1120, - 28999, - 151, - 28997, - 119, - 1262, - 1122, - 1771, - 152, - 119, - 102, - ], - [ - 101, - 1752, - 5650, - 119, - 13832, - 3121, - 2340, - 150, - 1759, - 1120, - 151, - 119, - 1262, - 28998, - 1122, - 28996, - 1771, - 28999, - 152, - 28997, - 119, - 102, - ], - [ - 101, - 1752, - 5650, - 119, - 13832, - 3121, - 2340, - 150, - 1759, - 1120, - 151, - 119, - 1262, - 28999, - 1122, - 28997, - 1771, - 28998, - 152, - 28996, - 119, - 102, - ], - ] - ).to(torch.long), - "attention_mask": torch.tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - ).to(torch.long), - "pooler_start_indices": torch.tensor( - [[2, 10], [5, 13], [5, 17], [17, 11], [5, 13], [14, 18], [18, 14]] - ).to(torch.long), - "pooler_end_indices": torch.tensor( - [[6, 11], [9, 14], [9, 18], [18, 12], [9, 14], [15, 19], [19, 15]] - ).to(torch.long), - } - - return result_dict - - -@pytest.fixture -def targets() -> Dict[str, LongTensor]: - return {"labels": torch.tensor([0, 1, 2, 3, 1, 2, 3]).to(torch.long)} - - -@pytest.fixture -def model() -> SequenceClassificationModelWithPooler: - torch.manual_seed(42) - result = SequenceClassificationModelWithPooler( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - pooler=POOLER, - ) - return result - - -def test_model(model): - assert model is not None - named_parameters = dict(model.named_parameters()) - parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()} - parameter_means_expected = { - "classifier.bias": -0.0253964, - "classifier.weight": -0.000511, - "model.embeddings.LayerNorm.bias": -0.0294608, - "model.embeddings.LayerNorm.weight": 1.312345, - "model.embeddings.position_embeddings.weight": 5.5e-05, - "model.embeddings.token_type_embeddings.weight": -0.0015419, - "model.embeddings.word_embeddings.weight": 0.0031152, - "model.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714, - "model.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831, - "model.encoder.layer.0.attention.output.dense.bias": 0.0007209, - "model.encoder.layer.0.attention.output.dense.weight": 3.01e-05, - "model.encoder.layer.0.attention.self.key.bias": 0.0020557, - "model.encoder.layer.0.attention.self.key.weight": 0.0003863, - "model.encoder.layer.0.attention.self.query.bias": 0.0185744, - "model.encoder.layer.0.attention.self.query.weight": -0.0003949, - "model.encoder.layer.0.attention.self.value.bias": 0.0065417, - "model.encoder.layer.0.attention.self.value.weight": 4.22e-05, - "model.encoder.layer.0.intermediate.dense.bias": -0.1219958, - "model.encoder.layer.0.intermediate.dense.weight": -0.0011731, - "model.encoder.layer.0.output.LayerNorm.bias": 0.005295, - "model.encoder.layer.0.output.LayerNorm.weight": 1.2419648, - "model.encoder.layer.0.output.dense.bias": -0.0013031, - "model.encoder.layer.0.output.dense.weight": -0.0002212, - "model.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237, - "model.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343, - "model.encoder.layer.1.attention.output.dense.bias": 0.0041446, - "model.encoder.layer.1.attention.output.dense.weight": -2.43e-05, - "model.encoder.layer.1.attention.self.key.bias": 0.0045062, - "model.encoder.layer.1.attention.self.key.weight": 0.0001333, - "model.encoder.layer.1.attention.self.query.bias": -0.0358397, - "model.encoder.layer.1.attention.self.query.weight": -0.0007321, - "model.encoder.layer.1.attention.self.value.bias": -0.0007094, - "model.encoder.layer.1.attention.self.value.weight": 0.0001012, - "model.encoder.layer.1.intermediate.dense.bias": -0.1247257, - "model.encoder.layer.1.intermediate.dense.weight": -0.001344, - "model.encoder.layer.1.output.LayerNorm.bias": -0.0474442, - "model.encoder.layer.1.output.LayerNorm.weight": 1.017162, - "model.encoder.layer.1.output.dense.bias": 0.000677, - "model.encoder.layer.1.output.dense.weight": -5.32e-05, - "model.pooler.dense.bias": -0.0052078, - "model.pooler.dense.weight": 0.0001295, - "pooler.pooler.missing_embeddings": 0.0630417, - } - assert parameter_means == parameter_means_expected - - -def test_model_pickleable(model): - import pickle - - pickle.dumps(model) - - -@pytest.fixture -def model_output(model, inputs) -> OutputType: - # set seed to make sure the output is deterministic - torch.manual_seed(42) - return model(inputs) - - -def test_forward_logits(model_output, inputs): - batch_size, seq_len = inputs["input_ids"].shape - - assert isinstance(model_output, SequenceClassifierOutput) - - logits = model_output.logits - - assert logits.shape == (batch_size, NUM_CLASSES) - - torch.testing.assert_close( - logits, - torch.tensor( - [ - [ - -0.29492446780204773, - -0.804599940776825, - -0.19659805297851562, - -1.0868580341339111, - ], - [ - -0.3601434826850891, - -0.7454482316970825, - 0.4882846474647522, - -1.0253472328186035, - ], - [ - -0.40172430872917175, - -1.2119399309158325, - 0.5856620669364929, - -1.0999149084091187, - ], - [ - -0.09260234981775284, - -1.0742112398147583, - 0.3299948275089264, - -0.5182554125785828, - ], - [ - -0.40149545669555664, - -0.7731614708900452, - 0.4616103768348694, - -1.0583568811416626, - ], - [ - -0.14356234669685364, - -1.2634986639022827, - 0.31660008430480957, - -0.7487461566925049, - ], - [ - -0.11717557162046432, - -0.971996009349823, - 0.3477852940559387, - -0.5993944406509399, - ], - ] - ), - ) - - -def test_decode(model, model_output, inputs): - decoded = model.decode(inputs=inputs, outputs=model_output) - assert isinstance(decoded, dict) - assert set(decoded) == {"labels", "probabilities"} - labels = decoded["labels"] - assert labels.shape == (inputs["input_ids"].shape[0],) - torch.testing.assert_close( - labels, - torch.tensor([2, 2, 2, 2, 2, 2, 2]), - ) - probabilities = decoded["probabilities"] - assert probabilities.shape == (inputs["input_ids"].shape[0], NUM_CLASSES) - torch.testing.assert_close( - probabilities.round(decimals=4), - torch.tensor( - [ - [0.3168, 0.1903, 0.3495, 0.1435], - [0.2207, 0.1502, 0.5156, 0.1135], - [0.2161, 0.0961, 0.5802, 0.1075], - [0.2814, 0.1054, 0.4294, 0.1838], - [0.2184, 0.1506, 0.5177, 0.1132], - [0.2893, 0.0944, 0.4583, 0.1580], - [0.2751, 0.1170, 0.4380, 0.1699], - ] - ), - ) - - -def test_decode_with_multi_label(model_output, inputs): - torch.manual_seed(42) - model = SequenceClassificationModelWithPooler( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - pooler=POOLER, - multi_label=True, - ) - decoded = model.decode(inputs=inputs, outputs=model_output) - assert isinstance(decoded, dict) - assert set(decoded) == {"labels", "probabilities"} - labels = decoded["labels"] - assert labels.shape == (inputs["input_ids"].shape[0], NUM_CLASSES) - torch.testing.assert_close( - labels, - torch.tensor( - [ - [0, 0, 0, 0], - [0, 0, 1, 0], - [0, 0, 1, 0], - [0, 0, 1, 0], - [0, 0, 1, 0], - [0, 0, 1, 0], - [0, 0, 1, 0], - ] - ), - ) - probabilities = decoded["probabilities"] - assert probabilities.shape == (inputs["input_ids"].shape[0], NUM_CLASSES) - torch.testing.assert_close( - probabilities.round(decimals=4), - torch.tensor( - [ - [0.4268, 0.3090, 0.4510, 0.2522], - [0.4109, 0.3218, 0.6197, 0.2640], - [0.4009, 0.2294, 0.6424, 0.2498], - [0.4769, 0.2546, 0.5818, 0.3733], - [0.4010, 0.3158, 0.6134, 0.2576], - [0.4642, 0.2204, 0.5785, 0.3211], - [0.4707, 0.2745, 0.5861, 0.3545], - ] - ), - ) - - -@pytest.fixture -def batch(inputs, targets): - return inputs, targets - - -def test_training_step(batch, model): - # set the seed to make sure the loss is deterministic - torch.manual_seed(42) - loss = model.training_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(1.3899686336517334)) - - -def test_validation_step(batch, model): - # set the seed to make sure the loss is deterministic - torch.manual_seed(42) - loss = model.validation_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(1.3899686336517334)) - - -def test_test_step(batch, model): - # set the seed to make sure the loss is deterministic - torch.manual_seed(42) - loss = model.test_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(1.3899686336517334)) - - -def test_base_model_named_parameters(model): - base_model_named_parameters = dict(model.base_model_named_parameters()) - assert set(base_model_named_parameters) == { - "model.pooler.dense.bias", - "model.encoder.layer.0.intermediate.dense.weight", - "model.encoder.layer.0.intermediate.dense.bias", - "model.encoder.layer.1.attention.output.dense.weight", - "model.encoder.layer.1.attention.output.LayerNorm.weight", - "model.encoder.layer.1.attention.self.query.weight", - "model.encoder.layer.1.output.dense.weight", - "model.encoder.layer.0.output.dense.bias", - "model.encoder.layer.1.intermediate.dense.bias", - "model.encoder.layer.1.attention.self.value.bias", - "model.encoder.layer.0.attention.output.dense.weight", - "model.encoder.layer.0.attention.self.query.bias", - "model.encoder.layer.0.attention.self.value.bias", - "model.encoder.layer.1.output.dense.bias", - "model.encoder.layer.1.attention.self.query.bias", - "model.encoder.layer.1.attention.output.LayerNorm.bias", - "model.encoder.layer.0.attention.self.query.weight", - "model.encoder.layer.0.attention.output.LayerNorm.bias", - "model.encoder.layer.0.attention.self.key.bias", - "model.encoder.layer.1.intermediate.dense.weight", - "model.encoder.layer.1.output.LayerNorm.bias", - "model.encoder.layer.1.output.LayerNorm.weight", - "model.encoder.layer.0.attention.self.key.weight", - "model.encoder.layer.1.attention.output.dense.bias", - "model.encoder.layer.0.attention.output.dense.bias", - "model.embeddings.LayerNorm.bias", - "model.encoder.layer.0.attention.self.value.weight", - "model.encoder.layer.0.attention.output.LayerNorm.weight", - "model.embeddings.token_type_embeddings.weight", - "model.encoder.layer.0.output.LayerNorm.weight", - "model.embeddings.position_embeddings.weight", - "model.encoder.layer.1.attention.self.key.bias", - "model.embeddings.LayerNorm.weight", - "model.encoder.layer.0.output.LayerNorm.bias", - "model.encoder.layer.1.attention.self.key.weight", - "model.pooler.dense.weight", - "model.encoder.layer.0.output.dense.weight", - "model.embeddings.word_embeddings.weight", - "model.encoder.layer.1.attention.self.value.weight", - } - - -def test_task_named_parameters(model): - task_named_parameters = dict(model.task_named_parameters()) - assert set(task_named_parameters) == { - "classifier.weight", - "pooler.pooler.missing_embeddings", - "classifier.bias", - } - - -def test_configure_optimizers_with_warmup(): - model = SequenceClassificationModelWithPooler( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - ) - model.trainer = Trainer(max_epochs=10) - optimizers_and_schedulers = model.configure_optimizers() - assert len(optimizers_and_schedulers) == 2 - optimizers, schedulers = optimizers_and_schedulers - assert len(optimizers) == 1 - assert len(schedulers) == 1 - optimizer = optimizers[0] - assert optimizer is not None - assert isinstance(optimizer, torch.optim.AdamW) - assert optimizer.defaults["lr"] == 1e-05 - assert optimizer.defaults["weight_decay"] == 0.01 - assert optimizer.defaults["eps"] == 1e-08 - - scheduler = schedulers[0] - assert isinstance(scheduler, dict) - assert set(scheduler) == {"scheduler", "interval"} - assert isinstance(scheduler["scheduler"], LambdaLR) - - -def test_configure_optimizers_with_task_learning_rate(monkeypatch): - model = SequenceClassificationModelWithPooler( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - learning_rate=1e-5, - task_learning_rate=1e-3, - # disable warmup to make sure the scheduler is not added which would set the learning rate - # to 0 - warmup_proportion=0.0, - ) - optimizer = model.configure_optimizers() - assert optimizer is not None - assert isinstance(optimizer, torch.optim.AdamW) - assert len(optimizer.param_groups) == 2 - # base model parameters - param_group = optimizer.param_groups[0] - assert len(param_group["params"]) == 39 - assert param_group["lr"] == 1e-5 - # classifier head parameters - param_group = optimizer.param_groups[1] - assert len(param_group["params"]) == 2 - assert param_group["lr"] == 1e-3 - # ensure that all parameters are covered - assert set(optimizer.param_groups[0]["params"] + optimizer.param_groups[1]["params"]) == set( - model.parameters() - ) - - -def test_freeze_base_model(monkeypatch, inputs, targets): - model = SequenceClassificationModelWithPooler( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - freeze_base_model=True, - # disable warmup to make sure the scheduler is not added which would set the learning rate - # to 0 - warmup_proportion=0.0, - ) - base_model_params = [param for name, param in model.base_model_named_parameters()] - task_params = [param for name, param in model.task_named_parameters()] - assert len(base_model_params) + len(task_params) == len(list(model.parameters())) - for param in base_model_params: - assert not param.requires_grad - for param in task_params: - assert param.requires_grad diff --git a/tests/models/test_sequence_pair_similarity_model_with_pooler.py b/tests/models/test_sequence_pair_similarity_model_with_pooler.py deleted file mode 100644 index 3e6871a11..000000000 --- a/tests/models/test_sequence_pair_similarity_model_with_pooler.py +++ /dev/null @@ -1,326 +0,0 @@ -from typing import Dict - -import pytest -import torch -from pytorch_lightning import Trainer -from torch import LongTensor, tensor -from torch.optim.lr_scheduler import LambdaLR -from transformers.modeling_outputs import SequenceClassifierOutput - -from pie_modules.models import SequencePairSimilarityModelWithPooler -from pie_modules.models.sequence_classification_with_pooler import OutputType -from tests.models import trunc_number - - -@pytest.fixture -def inputs() -> Dict[str, LongTensor]: - result_dict = { - "encoding": { - "input_ids": tensor( - [ - [101, 1262, 1131, 1771, 140, 119, 102], - [101, 1262, 1131, 1771, 140, 119, 102], - [101, 1262, 1131, 1771, 140, 119, 102], - [101, 1262, 1131, 1771, 140, 119, 102], - ] - ), - "token_type_ids": tensor( - [ - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - ] - ), - "attention_mask": tensor( - [ - [1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1], - ] - ), - }, - "encoding_pair": { - "input_ids": tensor( - [ - [101, 3162, 7871, 1117, 5855, 119, 102], - [101, 3162, 7871, 1117, 5855, 119, 102], - [101, 3162, 7871, 1117, 5855, 119, 102], - [101, 3162, 7871, 1117, 5855, 119, 102], - ] - ), - "token_type_ids": tensor( - [ - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - ] - ), - "attention_mask": tensor( - [ - [1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1], - ] - ), - }, - "pooler_start_indices": tensor([[2], [2], [4], [4]]), - "pooler_end_indices": tensor([[3], [3], [5], [5]]), - "pooler_pair_start_indices": tensor([[1], [3], [1], [3]]), - "pooler_pair_end_indices": tensor([[2], [5], [2], [5]]), - } - - return result_dict - - -@pytest.fixture -def targets() -> Dict[str, LongTensor]: - return {"scores": tensor([0.0, 0.0, 0.0, 0.0])} - - -@pytest.fixture -def model() -> SequencePairSimilarityModelWithPooler: - torch.manual_seed(42) - result = SequencePairSimilarityModelWithPooler( - model_name_or_path="prajjwal1/bert-tiny", - ) - return result - - -def test_model(model): - assert model is not None - named_parameters = dict(model.named_parameters()) - parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()} - parameter_means_expected = { - "model.embeddings.word_embeddings.weight": 0.0031152, - "model.embeddings.position_embeddings.weight": 5.5e-05, - "model.embeddings.token_type_embeddings.weight": -0.0015419, - "model.embeddings.LayerNorm.weight": 1.312345, - "model.embeddings.LayerNorm.bias": -0.0294608, - "model.encoder.layer.0.attention.self.query.weight": -0.0003949, - "model.encoder.layer.0.attention.self.query.bias": 0.0185744, - "model.encoder.layer.0.attention.self.key.weight": 0.0003863, - "model.encoder.layer.0.attention.self.key.bias": 0.0020557, - "model.encoder.layer.0.attention.self.value.weight": 4.22e-05, - "model.encoder.layer.0.attention.self.value.bias": 0.0065417, - "model.encoder.layer.0.attention.output.dense.weight": 3.01e-05, - "model.encoder.layer.0.attention.output.dense.bias": 0.0007209, - "model.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831, - "model.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714, - "model.encoder.layer.0.intermediate.dense.weight": -0.0011731, - "model.encoder.layer.0.intermediate.dense.bias": -0.1219958, - "model.encoder.layer.0.output.dense.weight": -0.0002212, - "model.encoder.layer.0.output.dense.bias": -0.0013031, - "model.encoder.layer.0.output.LayerNorm.weight": 1.2419648, - "model.encoder.layer.0.output.LayerNorm.bias": 0.005295, - "model.encoder.layer.1.attention.self.query.weight": -0.0007321, - "model.encoder.layer.1.attention.self.query.bias": -0.0358397, - "model.encoder.layer.1.attention.self.key.weight": 0.0001333, - "model.encoder.layer.1.attention.self.key.bias": 0.0045062, - "model.encoder.layer.1.attention.self.value.weight": 0.0001012, - "model.encoder.layer.1.attention.self.value.bias": -0.0007094, - "model.encoder.layer.1.attention.output.dense.weight": -2.43e-05, - "model.encoder.layer.1.attention.output.dense.bias": 0.0041446, - "model.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343, - "model.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237, - "model.encoder.layer.1.intermediate.dense.weight": -0.001344, - "model.encoder.layer.1.intermediate.dense.bias": -0.1247257, - "model.encoder.layer.1.output.dense.weight": -5.32e-05, - "model.encoder.layer.1.output.dense.bias": 0.000677, - "model.encoder.layer.1.output.LayerNorm.weight": 1.017162, - "model.encoder.layer.1.output.LayerNorm.bias": -0.0474442, - "model.pooler.dense.weight": 0.0001295, - "model.pooler.dense.bias": -0.0052078, - "pooler.missing_embeddings": 0.0812017, - } - assert parameter_means == parameter_means_expected - - -def test_model_pickleable(model): - import pickle - - pickle.dumps(model) - - -@pytest.fixture -def model_output(model, inputs) -> OutputType: - # set seed to make sure the output is deterministic - torch.manual_seed(42) - return model(inputs) - - -def test_forward_logits(model_output, inputs): - assert isinstance(model_output, SequenceClassifierOutput) - - logits = model_output.logits - - torch.testing.assert_close( - logits, - torch.tensor( - [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893] - ), - ) - - -def test_decode(model, model_output, inputs): - decoded = model.decode(inputs=inputs, outputs=model_output) - assert isinstance(decoded, dict) - assert set(decoded) == {"scores"} - scores = decoded["scores"] - torch.testing.assert_close( - scores, - torch.tensor( - [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893] - ), - ) - - -@pytest.fixture -def batch(inputs, targets): - return inputs, targets - - -def test_training_step(batch, model): - # set the seed to make sure the loss is deterministic - torch.manual_seed(42) - loss = model.training_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(0.8145309686660767)) - - -def test_validation_step(batch, model): - # set the seed to make sure the loss is deterministic - torch.manual_seed(42) - loss = model.validation_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(0.8145309686660767)) - - -def test_test_step(batch, model): - # set the seed to make sure the loss is deterministic - torch.manual_seed(42) - loss = model.test_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(0.8145309686660767)) - - -def test_base_model_named_parameters(model): - base_model_named_parameters = dict(model.base_model_named_parameters()) - assert set(base_model_named_parameters) == { - "model.pooler.dense.bias", - "model.encoder.layer.0.intermediate.dense.weight", - "model.encoder.layer.0.intermediate.dense.bias", - "model.encoder.layer.1.attention.output.dense.weight", - "model.encoder.layer.1.attention.output.LayerNorm.weight", - "model.encoder.layer.1.attention.self.query.weight", - "model.encoder.layer.1.output.dense.weight", - "model.encoder.layer.0.output.dense.bias", - "model.encoder.layer.1.intermediate.dense.bias", - "model.encoder.layer.1.attention.self.value.bias", - "model.encoder.layer.0.attention.output.dense.weight", - "model.encoder.layer.0.attention.self.query.bias", - "model.encoder.layer.0.attention.self.value.bias", - "model.encoder.layer.1.output.dense.bias", - "model.encoder.layer.1.attention.self.query.bias", - "model.encoder.layer.1.attention.output.LayerNorm.bias", - "model.encoder.layer.0.attention.self.query.weight", - "model.encoder.layer.0.attention.output.LayerNorm.bias", - "model.encoder.layer.0.attention.self.key.bias", - "model.encoder.layer.1.intermediate.dense.weight", - "model.encoder.layer.1.output.LayerNorm.bias", - "model.encoder.layer.1.output.LayerNorm.weight", - "model.encoder.layer.0.attention.self.key.weight", - "model.encoder.layer.1.attention.output.dense.bias", - "model.encoder.layer.0.attention.output.dense.bias", - "model.embeddings.LayerNorm.bias", - "model.encoder.layer.0.attention.self.value.weight", - "model.encoder.layer.0.attention.output.LayerNorm.weight", - "model.embeddings.token_type_embeddings.weight", - "model.encoder.layer.0.output.LayerNorm.weight", - "model.embeddings.position_embeddings.weight", - "model.encoder.layer.1.attention.self.key.bias", - "model.embeddings.LayerNorm.weight", - "model.encoder.layer.0.output.LayerNorm.bias", - "model.encoder.layer.1.attention.self.key.weight", - "model.pooler.dense.weight", - "model.encoder.layer.0.output.dense.weight", - "model.embeddings.word_embeddings.weight", - "model.encoder.layer.1.attention.self.value.weight", - } - - -def test_task_named_parameters(model): - task_named_parameters = dict(model.task_named_parameters()) - assert set(task_named_parameters) == { - "pooler.missing_embeddings", - } - - -def test_configure_optimizers_with_warmup(): - model = SequencePairSimilarityModelWithPooler( - model_name_or_path="prajjwal1/bert-tiny", - ) - model.trainer = Trainer(max_epochs=10) - optimizers_and_schedulers = model.configure_optimizers() - assert len(optimizers_and_schedulers) == 2 - optimizers, schedulers = optimizers_and_schedulers - assert len(optimizers) == 1 - assert len(schedulers) == 1 - optimizer = optimizers[0] - assert optimizer is not None - assert isinstance(optimizer, torch.optim.AdamW) - assert optimizer.defaults["lr"] == 1e-05 - assert optimizer.defaults["weight_decay"] == 0.01 - assert optimizer.defaults["eps"] == 1e-08 - - scheduler = schedulers[0] - assert isinstance(scheduler, dict) - assert set(scheduler) == {"scheduler", "interval"} - assert isinstance(scheduler["scheduler"], LambdaLR) - - -def test_configure_optimizers_with_task_learning_rate(monkeypatch): - model = SequencePairSimilarityModelWithPooler( - model_name_or_path="prajjwal1/bert-tiny", - learning_rate=1e-5, - task_learning_rate=1e-3, - # disable warmup to make sure the scheduler is not added which would set the learning rate - # to 0 - warmup_proportion=0.0, - ) - optimizer = model.configure_optimizers() - assert optimizer is not None - assert isinstance(optimizer, torch.optim.AdamW) - assert len(optimizer.param_groups) == 2 - # base model parameters - param_group = optimizer.param_groups[0] - assert len(param_group["params"]) == 39 - assert param_group["lr"] == 1e-5 - # classifier head parameters - there is just the default embedding (which is not used) - param_group = optimizer.param_groups[1] - assert len(param_group["params"]) == 1 - assert param_group["lr"] == 1e-3 - # ensure that all parameters are covered - assert set(optimizer.param_groups[0]["params"] + optimizer.param_groups[1]["params"]) == set( - model.parameters() - ) - - -def test_freeze_base_model(monkeypatch, inputs, targets): - model = SequencePairSimilarityModelWithPooler( - model_name_or_path="prajjwal1/bert-tiny", - freeze_base_model=True, - # disable warmup to make sure the scheduler is not added which would set the learning rate - # to 0 - warmup_proportion=0.0, - ) - base_model_params = [param for name, param in model.base_model_named_parameters()] - task_params = [param for name, param in model.task_named_parameters()] - assert len(base_model_params) + len(task_params) == len(list(model.parameters())) - for param in base_model_params: - assert not param.requires_grad - for param in task_params: - assert param.requires_grad diff --git a/tests/models/test_simple_generative.py b/tests/models/test_simple_generative.py deleted file mode 100644 index 147ce1f67..000000000 --- a/tests/models/test_simple_generative.py +++ /dev/null @@ -1,439 +0,0 @@ -import math -from typing import List, Optional - -import pytest -import torch -from pytorch_lightning import Trainer -from torch.optim import Optimizer - -from pie_modules.models import SimpleGenerativeModel -from pie_modules.models.common import TESTING, VALIDATION -from pie_modules.taskmodules import TextToTextTaskModule -from tests.models import trunc_number - -MODEL_ID = "google/t5-efficient-tiny-nl2" - - -@pytest.fixture(scope="module") -def taskmodule(): - return TextToTextTaskModule( - tokenizer_name_or_path=MODEL_ID, - document_type="pie_modules.documents.TextDocumentWithAbstractiveSummary", - target_layer="abstractive_summary", - target_annotation_type="pie_modules.annotations.AbstractiveSummary", - tokenized_document_type="pie_modules.documents.TokenDocumentWithAbstractiveSummary", - text_metric_type="torchmetrics.text.ROUGEScore", - ) - - -@pytest.fixture(scope="module") -def model(taskmodule) -> SimpleGenerativeModel: - return SimpleGenerativeModel( - base_model={ - "_type_": "transformers.AutoModelForSeq2SeqLM", - "pretrained_model_name_or_path": MODEL_ID, - }, - # only use predictions for metrics in test stage to cover all cases (default is all stages) - metric_call_predict=[TESTING], - taskmodule_config=taskmodule.config, - # use a strange learning rate to make sure it is passed through - learning_rate=13e-3, - optimizer_type="torch.optim.Adam", - ) - - -def test_model(model): - assert model is not None - assert model.model is not None - assert model.taskmodule is not None - named_parameters = dict(model.named_parameters()) - parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()} - parameter_means_expected = { - "model.shared.weight": -0.3906954, - "model.encoder.block.0.layer.0.SelfAttention.q.weight": 2.15e-05, - "model.encoder.block.0.layer.0.SelfAttention.k.weight": -0.0015166, - "model.encoder.block.0.layer.0.SelfAttention.v.weight": -0.0018635, - "model.encoder.block.0.layer.0.SelfAttention.o.weight": 0.000866, - "model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": -2.8229351, - "model.encoder.block.0.layer.0.layer_norm.weight": 0.226491, - "model.encoder.block.0.layer.1.DenseReluDense.wi.weight": 0.0034651, - "model.encoder.block.0.layer.1.DenseReluDense.wo.weight": 0.00017, - "model.encoder.block.0.layer.1.layer_norm.weight": 1.2047424, - "model.encoder.block.1.layer.0.SelfAttention.q.weight": -7.88e-05, - "model.encoder.block.1.layer.0.SelfAttention.k.weight": -0.0017292, - "model.encoder.block.1.layer.0.SelfAttention.v.weight": -0.0025692, - "model.encoder.block.1.layer.0.SelfAttention.o.weight": 0.000484, - "model.encoder.block.1.layer.0.layer_norm.weight": 0.4024209, - "model.encoder.block.1.layer.1.DenseReluDense.wi.weight": 0.0012148, - "model.encoder.block.1.layer.1.DenseReluDense.wo.weight": -0.000555, - "model.encoder.block.1.layer.1.layer_norm.weight": 1.9719848, - "model.encoder.final_layer_norm.weight": 1.3045949, - "model.decoder.block.0.layer.0.SelfAttention.q.weight": 4.21e-05, - "model.decoder.block.0.layer.0.SelfAttention.k.weight": 0.0006944, - "model.decoder.block.0.layer.0.SelfAttention.v.weight": -0.0001296, - "model.decoder.block.0.layer.0.SelfAttention.o.weight": 0.0020978, - "model.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": -0.5869011, - "model.decoder.block.0.layer.0.layer_norm.weight": 0.1958751, - "model.decoder.block.0.layer.1.EncDecAttention.q.weight": 7.8e-06, - "model.decoder.block.0.layer.1.EncDecAttention.k.weight": -0.0001409, - "model.decoder.block.0.layer.1.EncDecAttention.v.weight": -0.0010971, - "model.decoder.block.0.layer.1.EncDecAttention.o.weight": 0.0026751, - "model.decoder.block.0.layer.1.layer_norm.weight": 0.0658893, - "model.decoder.block.0.layer.2.DenseReluDense.wi.weight": 0.0012591, - "model.decoder.block.0.layer.2.DenseReluDense.wo.weight": 0.0033682, - "model.decoder.block.0.layer.2.layer_norm.weight": 2.9871673, - "model.decoder.block.1.layer.0.SelfAttention.q.weight": 6.16e-05, - "model.decoder.block.1.layer.0.SelfAttention.k.weight": 0.0004128, - "model.decoder.block.1.layer.0.SelfAttention.v.weight": -0.0003878, - "model.decoder.block.1.layer.0.SelfAttention.o.weight": -0.0040457, - "model.decoder.block.1.layer.0.layer_norm.weight": 1.1167399, - "model.decoder.block.1.layer.1.EncDecAttention.q.weight": -0.0001246, - "model.decoder.block.1.layer.1.EncDecAttention.k.weight": 0.0013352, - "model.decoder.block.1.layer.1.EncDecAttention.v.weight": -0.0024415, - "model.decoder.block.1.layer.1.EncDecAttention.o.weight": -9.83e-05, - "model.decoder.block.1.layer.1.layer_norm.weight": 0.0755381, - "model.decoder.block.1.layer.2.DenseReluDense.wi.weight": -0.0045786, - "model.decoder.block.1.layer.2.DenseReluDense.wo.weight": 0.0101685, - "model.decoder.block.1.layer.2.layer_norm.weight": 7.3835659, - "model.decoder.final_layer_norm.weight": 0.8366433, - } - assert parameter_means == parameter_means_expected - - -def test_model_pickleable(model): - import pickle - - pickle.dumps(model) - - -def test_model_without_taskmodule(caplog): - with caplog.at_level("WARNING"): - model = SimpleGenerativeModel( - base_model={ - "_type_": "transformers.AutoModelForSeq2SeqLM", - "pretrained_model_name_or_path": MODEL_ID, - }, - ) - assert model is not None - assert caplog.messages == [ - "No taskmodule is available, so no metrics are set up. Please provide a taskmodule_config " - "to enable metrics for stages ['train', 'val', 'test'].", - "No taskmodule is available, so no generation config will be created. Consider setting " - "taskmodule_config to a valid taskmodule config to use specific setup for generation.", - ] - - -def test_missing_base_model_and_type(): - with pytest.raises(ValueError) as excinfo: - SimpleGenerativeModel() - assert ( - str(excinfo.value) - == "Either base_model or base_model_type must be provided. If base_model is " - "not provided, base_model_type must be a valid model type, " - "e.g. 'transformers.AutoModelForSeq2SeqLM'." - ) - - -def test_model_with_deprecated_base_model_setup(caplog, taskmodule): - with caplog.at_level("WARNING"): - model = SimpleGenerativeModel( - base_model_type="transformers.AutoModelForSeq2SeqLM", - base_model_config=dict(pretrained_model_name_or_path=MODEL_ID), - taskmodule_config=taskmodule.config, - ) - assert model is not None - assert caplog.messages == [ - "The base_model_type and base_model_config arguments are deprecated. Please use base_model. " - "You can use the following code to create the base_model argument: " - "base_model = {'_type_': base_model_type, **base_model_config}", - ] - - -@pytest.fixture(scope="module") -def batch(model): - inputs = { - "input_ids": torch.tensor( - [ - [100, 19, 3, 9, 794, 1708, 1, 0, 0, 0, 0, 0], - [100, 19, 430, 794, 1708, 84, 19, 3, 9, 720, 1200, 1], - ] - ), - "attention_mask": torch.tensor( - [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] - ), - } - - targets = { - "labels": torch.tensor([[3, 9, 1708, 1, 0], [3, 9, 1200, 1708, 1]]), - "decoder_attention_mask": torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]), - } - - return inputs, targets - - -def test_batch(batch, taskmodule): - inputs, targets = batch - input_ids_tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"] - ] - assert input_ids_tokens == [ - [ - "▁This", - "▁is", - "▁", - "a", - "▁test", - "▁document", - "", - "", - "", - "", - "", - "", - ], - [ - "▁This", - "▁is", - "▁another", - "▁test", - "▁document", - "▁which", - "▁is", - "▁", - "a", - "▁bit", - "▁longer", - "", - ], - ] - - labels_tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(labels) for labels in targets["labels"] - ] - assert labels_tokens == [ - ["▁", "a", "▁document", "", ""], - ["▁", "a", "▁longer", "▁document", ""], - ] - - -def test_training_step(batch, model): - model.train() - torch.manual_seed(42) - metric = model._get_metric(VALIDATION, batch_idx=0) - metric.reset() - loss = model.training_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(8.98222827911377)) - - metric_values = metric.compute() - metric_values_float = {key: value.item() for key, value in metric_values.items()} - - # we do not collect metrics during training, so all entries should be NaN - assert len(metric_values_float) > 0 - assert all([math.isnan(value) for value in metric_values_float.values()]) - - model.on_train_epoch_end() - - -def test_validation_step(batch, model): - model.eval() - torch.manual_seed(42) - metric = model._get_metric(VALIDATION, batch_idx=0) - metric.reset() - loss = model.validation_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(10.146586418151855)) - - metric_values = metric.compute() - metric_values_float = {key: value.item() for key, value in metric_values.items()} - assert metric_values_float == { - "rouge1_fmeasure": 0.0, - "rouge1_precision": 0.0, - "rouge1_recall": 0.0, - "rouge2_fmeasure": 0.0, - "rouge2_precision": 0.0, - "rouge2_recall": 0.0, - "rougeL_fmeasure": 0.0, - "rougeL_precision": 0.0, - "rougeL_recall": 0.0, - "rougeLsum_fmeasure": 0.0, - "rougeLsum_precision": 0.0, - "rougeLsum_recall": 0.0, - } - - model.on_validation_epoch_end() - - -def test_test_step(batch, model): - model.eval() - torch.manual_seed(42) - metric = model._get_metric(TESTING, batch_idx=0) - metric.reset() - loss = model.test_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(10.146586418151855)) - - metric_values = metric.compute() - metric_values_float = {key: value.item() for key, value in metric_values.items()} - assert metric_values_float == { - "rouge1_fmeasure": 0.1111111119389534, - "rouge1_precision": 0.06666667014360428, - "rouge1_recall": 0.3333333432674408, - "rouge2_fmeasure": 0.0, - "rouge2_precision": 0.0, - "rouge2_recall": 0.0, - "rougeL_fmeasure": 0.1111111119389534, - "rougeL_precision": 0.06666667014360428, - "rougeL_recall": 0.3333333432674408, - "rougeLsum_fmeasure": 0.0555555559694767, - "rougeLsum_precision": 0.03333333507180214, - "rougeLsum_recall": 0.1666666716337204, - } - - model.on_test_epoch_end() - - -def test_predict_step(batch, model): - model.eval() - torch.manual_seed(42) - predictions = model.predict_step(batch, batch_idx=0) - labels = predictions["labels"] - assert labels is not None - torch.testing.assert_close( - labels, - torch.tensor( - [ - [32099, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [ - 32099, - 19, - 3, - 9, - 248, - 194, - 12, - 129, - 25, - 708, - 5, - 37, - 166, - 794, - 1708, - 19, - 3, - 9, - 794, - ], - ] - ), - ) - - predicted_tokens = [ - model.taskmodule.tokenizer.convert_ids_to_tokens(label) for label in labels - ] - assert predicted_tokens == [ - [ - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - ], - [ - "", - "▁is", - "▁", - "a", - "▁great", - "▁way", - "▁to", - "▁get", - "▁you", - "▁started", - ".", - "▁The", - "▁first", - "▁test", - "▁document", - "▁is", - "▁", - "a", - "▁test", - ], - ] - - -@pytest.fixture(scope="module") -def optimizer(model): - return model.configure_optimizers() - - -def test_optimizer(optimizer): - assert optimizer is not None - assert isinstance(optimizer, torch.optim.Adam) - assert optimizer.defaults["lr"] == 13e-3 - assert len(optimizer.param_groups) == 1 - param_group = optimizer.param_groups[0] - assert len(param_group["params"]) == 47 - - -def _assert_optimizer( - actual: Optimizer, - expected: Optimizer, - allow_mismatching_param_group_keys: Optional[List[str]] = None, -): - allow_mismatching_param_group_key_set = set(allow_mismatching_param_group_keys or []) - assert actual is not None - assert isinstance(actual, type(expected)) - assert actual.defaults == expected.defaults - assert len(actual.param_groups) == len(expected.param_groups) - for actual_param_group, expected_param_group in zip( - actual.param_groups, expected.param_groups - ): - actual_keys = set(actual_param_group) - allow_mismatching_param_group_key_set - expected_keys = set(expected_param_group) - allow_mismatching_param_group_key_set - assert actual_keys == expected_keys - for key in actual_keys: - # also include the key in the comparison to have it in the assertion error message - assert (key, actual_param_group[key]) == (key, expected_param_group[key]) - - -def test_configure_optimizers_with_warmup(model, optimizer): - warmup_proportion_backup_value = model.warmup_proportion - scheduler_name_backup_value = model.scheduler_name - model.warmup_proportion = 0.1 - model.scheduler_name = "linear" - model.trainer = Trainer(max_epochs=10) - optimizer_and_schedular = model.configure_optimizers() - assert optimizer_and_schedular is not None - assert isinstance(optimizer_and_schedular, tuple) - assert len(optimizer_and_schedular) == 2 - optimizers, schedulers = optimizer_and_schedular - assert len(optimizers) == 1 - _assert_optimizer( - optimizers[0], optimizer, allow_mismatching_param_group_keys=["initial_lr", "lr"] - ) - assert len(schedulers) == 1 - assert set(schedulers[0]) == {"scheduler", "interval"} - scheduler = schedulers[0]["scheduler"] - assert isinstance(scheduler, torch.optim.lr_scheduler.LambdaLR) - assert scheduler.optimizer is optimizers[0] - assert scheduler.base_lrs == [13e-3] - - model.warmup_proportion = warmup_proportion_backup_value - model.scheduler_name = scheduler_name_backup_value diff --git a/tests/models/test_simple_sequence_classification.py b/tests/models/test_simple_sequence_classification.py deleted file mode 100644 index 1fe740cd3..000000000 --- a/tests/models/test_simple_sequence_classification.py +++ /dev/null @@ -1,550 +0,0 @@ -from typing import Dict - -import pytest -import torch -from pytorch_lightning import Trainer -from torch.optim.lr_scheduler import LambdaLR -from transformers.modeling_outputs import SequenceClassifierOutput - -from pie_modules.models import SimpleSequenceClassificationModel -from pie_modules.models.simple_sequence_classification import OutputType -from tests.models import trunc_number - -NUM_CLASSES = 4 - - -@pytest.fixture -def inputs() -> Dict[str, torch.LongTensor]: - result_dict = { - "input_ids": torch.tensor( - [ - [ - 101, - 28998, - 13832, - 3121, - 2340, - 138, - 28996, - 1759, - 1120, - 28999, - 139, - 28997, - 119, - 102, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 28998, - 13832, - 3121, - 2340, - 144, - 28996, - 1759, - 1120, - 28999, - 145, - 28997, - 119, - 1262, - 1771, - 146, - 119, - 102, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 28998, - 13832, - 3121, - 2340, - 144, - 28996, - 1759, - 1120, - 145, - 119, - 1262, - 1771, - 28999, - 146, - 28997, - 119, - 102, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 13832, - 3121, - 2340, - 144, - 1759, - 1120, - 28999, - 145, - 28997, - 119, - 1262, - 1771, - 28998, - 146, - 28996, - 119, - 102, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 28998, - 13832, - 3121, - 2340, - 150, - 28996, - 1759, - 1120, - 28999, - 151, - 28997, - 119, - 1262, - 1122, - 1771, - 152, - 119, - 102, - ], - [ - 101, - 1752, - 5650, - 119, - 13832, - 3121, - 2340, - 150, - 1759, - 1120, - 151, - 119, - 1262, - 28998, - 1122, - 28996, - 1771, - 28999, - 152, - 28997, - 119, - 102, - ], - [ - 101, - 1752, - 5650, - 119, - 13832, - 3121, - 2340, - 150, - 1759, - 1120, - 151, - 119, - 1262, - 28999, - 1122, - 28997, - 1771, - 28998, - 152, - 28996, - 119, - 102, - ], - ] - ).to(torch.long), - "attention_mask": torch.tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - ).to(torch.long), - } - - return result_dict - - -@pytest.fixture -def targets() -> Dict[str, torch.LongTensor]: - return {"labels": torch.tensor([0, 1, 2, 3, 1, 2, 3]).to(torch.long)} - - -@pytest.fixture -def model() -> SimpleSequenceClassificationModel: - torch.manual_seed(42) - result = SimpleSequenceClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - ) - return result - - -def test_model(model): - assert model is not None - named_parameters = dict(model.named_parameters()) - parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()} - parameter_means_expected = { - "model.bert.embeddings.word_embeddings.weight": 0.0031152, - "model.bert.embeddings.position_embeddings.weight": 5.5e-05, - "model.bert.embeddings.token_type_embeddings.weight": -0.0015419, - "model.bert.embeddings.LayerNorm.weight": 1.312345, - "model.bert.embeddings.LayerNorm.bias": -0.0294608, - "model.bert.encoder.layer.0.attention.self.query.weight": -0.0003949, - "model.bert.encoder.layer.0.attention.self.query.bias": 0.0185744, - "model.bert.encoder.layer.0.attention.self.key.weight": 0.0003863, - "model.bert.encoder.layer.0.attention.self.key.bias": 0.0020557, - "model.bert.encoder.layer.0.attention.self.value.weight": 4.22e-05, - "model.bert.encoder.layer.0.attention.self.value.bias": 0.0065417, - "model.bert.encoder.layer.0.attention.output.dense.weight": 3.01e-05, - "model.bert.encoder.layer.0.attention.output.dense.bias": 0.0007209, - "model.bert.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831, - "model.bert.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714, - "model.bert.encoder.layer.0.intermediate.dense.weight": -0.0011731, - "model.bert.encoder.layer.0.intermediate.dense.bias": -0.1219958, - "model.bert.encoder.layer.0.output.dense.weight": -0.0002212, - "model.bert.encoder.layer.0.output.dense.bias": -0.0013031, - "model.bert.encoder.layer.0.output.LayerNorm.weight": 1.2419648, - "model.bert.encoder.layer.0.output.LayerNorm.bias": 0.005295, - "model.bert.encoder.layer.1.attention.self.query.weight": -0.0007321, - "model.bert.encoder.layer.1.attention.self.query.bias": -0.0358397, - "model.bert.encoder.layer.1.attention.self.key.weight": 0.0001333, - "model.bert.encoder.layer.1.attention.self.key.bias": 0.0045062, - "model.bert.encoder.layer.1.attention.self.value.weight": 0.0001012, - "model.bert.encoder.layer.1.attention.self.value.bias": -0.0007094, - "model.bert.encoder.layer.1.attention.output.dense.weight": -2.43e-05, - "model.bert.encoder.layer.1.attention.output.dense.bias": 0.0041446, - "model.bert.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343, - "model.bert.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237, - "model.bert.encoder.layer.1.intermediate.dense.weight": -0.001344, - "model.bert.encoder.layer.1.intermediate.dense.bias": -0.1247257, - "model.bert.encoder.layer.1.output.dense.weight": -5.32e-05, - "model.bert.encoder.layer.1.output.dense.bias": 0.000677, - "model.bert.encoder.layer.1.output.LayerNorm.weight": 1.017162, - "model.bert.encoder.layer.1.output.LayerNorm.bias": -0.0474442, - "model.bert.pooler.dense.weight": 0.0001295, - "model.bert.pooler.dense.bias": -0.0052078, - "model.classifier.weight": 0.0005538, - "model.classifier.bias": 0.0, - } - assert parameter_means == parameter_means_expected - - -def test_model_pickleable(model): - import pickle - - pickle.dumps(model) - - -@pytest.fixture -def model_output(model, inputs) -> OutputType: - # set seed to make sure the output is deterministic - torch.manual_seed(42) - return model(inputs) - - -def test_forward(model_output, inputs): - batch_size = inputs["input_ids"].shape[0] - assert isinstance(model_output, SequenceClassifierOutput) - assert set(model_output) == {"logits"} - logits = model_output["logits"] - - assert logits.shape == (batch_size, NUM_CLASSES) - - torch.testing.assert_close( - logits, - torch.tensor( - [ - [ - 0.16545572876930237, - 0.17544983327388763, - -0.011048287153244019, - 0.05337674915790558, - ], - [ - 0.14748695492744446, - 0.16249355673789978, - -0.058017998933792114, - 0.025398850440979004, - ], - [ - 0.14271709322929382, - 0.16188383102416992, - -0.061113521456718445, - 0.026494741439819336, - ], - [ - 0.15641027688980103, - 0.17225395143032074, - -0.05567866563796997, - 0.022433891892433167, - ], - [ - 0.15785054862499237, - 0.16935551166534424, - -0.054724469780921936, - 0.012338697910308838, - ], - [ - 0.16152460873126984, - 0.17789196968078613, - -0.053754448890686035, - 0.008724510669708252, - ], - [ - 0.16836002469062805, - 0.17842254042625427, - -0.052499815821647644, - 0.006823211908340454, - ], - ] - ), - ) - - -def test_decode(model, model_output, inputs): - decoded = model.decode(inputs=inputs, outputs=model_output) - assert isinstance(decoded, dict) - assert set(decoded) == {"labels", "probabilities"} - labels = decoded["labels"] - assert labels.shape == (inputs["input_ids"].shape[0],) - torch.testing.assert_close( - labels, - torch.tensor([1, 1, 1, 1, 1, 1, 1]), - ) - probabilities = decoded["probabilities"] - assert probabilities.shape == (inputs["input_ids"].shape[0], NUM_CLASSES) - torch.testing.assert_close( - probabilities, - torch.tensor( - [ - [ - 0.2672215402126312, - 0.26990556716918945, - 0.22398385405540466, - 0.23888900876045227, - ], - [ - 0.26922059059143066, - 0.27329114079475403, - 0.21920911967754364, - 0.23827917873859406, - ], - [0.2684398889541626, 0.2736346125602722, 0.21893969178199768, 0.23898591101169586], - [0.2703087329864502, 0.2746255099773407, 0.21865077316761017, 0.23641489446163177], - [0.2713961601257324, 0.2745365798473358, 0.21942369639873505, 0.2346435934305191], - [ - 0.27165648341178894, - 0.27613937854766846, - 0.21904107928276062, - 0.23316311836242676, - ], - [0.2730168402194977, 0.2757779359817505, 0.21891282498836517, 0.23229233920574188], - ] - ), - ) - - -@pytest.fixture -def batch(inputs, targets): - return inputs, targets - - -def test_training_step(batch, model): - # set the seed to make sure the loss is deterministic - torch.manual_seed(42) - loss = model.training_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(1.4069921970367432)) - - -def test_validation_step(batch, model): - # set the seed to make sure the loss is deterministic - torch.manual_seed(42) - loss = model.validation_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(1.4069921970367432)) - - -def test_test_step(batch, model): - # set the seed to make sure the loss is deterministic - torch.manual_seed(42) - loss = model.test_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(1.4069921970367432)) - - -def test_base_model_named_parameters(model): - base_model_named_parameters = dict(model.base_model_named_parameters()) - assert set(base_model_named_parameters) == { - "model.bert.pooler.dense.bias", - "model.bert.encoder.layer.0.intermediate.dense.weight", - "model.bert.encoder.layer.0.intermediate.dense.bias", - "model.bert.encoder.layer.1.attention.output.dense.weight", - "model.bert.encoder.layer.1.attention.output.LayerNorm.weight", - "model.bert.encoder.layer.1.attention.self.query.weight", - "model.bert.encoder.layer.1.output.dense.weight", - "model.bert.encoder.layer.0.output.dense.bias", - "model.bert.encoder.layer.1.intermediate.dense.bias", - "model.bert.encoder.layer.1.attention.self.value.bias", - "model.bert.encoder.layer.0.attention.output.dense.weight", - "model.bert.encoder.layer.0.attention.self.query.bias", - "model.bert.encoder.layer.0.attention.self.value.bias", - "model.bert.encoder.layer.1.output.dense.bias", - "model.bert.encoder.layer.1.attention.self.query.bias", - "model.bert.encoder.layer.1.attention.output.LayerNorm.bias", - "model.bert.encoder.layer.0.attention.self.query.weight", - "model.bert.encoder.layer.0.attention.output.LayerNorm.bias", - "model.bert.encoder.layer.0.attention.self.key.bias", - "model.bert.encoder.layer.1.intermediate.dense.weight", - "model.bert.encoder.layer.1.output.LayerNorm.bias", - "model.bert.encoder.layer.1.output.LayerNorm.weight", - "model.bert.encoder.layer.0.attention.self.key.weight", - "model.bert.encoder.layer.1.attention.output.dense.bias", - "model.bert.encoder.layer.0.attention.output.dense.bias", - "model.bert.embeddings.LayerNorm.bias", - "model.bert.encoder.layer.0.attention.self.value.weight", - "model.bert.encoder.layer.0.attention.output.LayerNorm.weight", - "model.bert.embeddings.token_type_embeddings.weight", - "model.bert.encoder.layer.0.output.LayerNorm.weight", - "model.bert.embeddings.position_embeddings.weight", - "model.bert.encoder.layer.1.attention.self.key.bias", - "model.bert.embeddings.LayerNorm.weight", - "model.bert.encoder.layer.0.output.LayerNorm.bias", - "model.bert.encoder.layer.1.attention.self.key.weight", - "model.bert.pooler.dense.weight", - "model.bert.encoder.layer.0.output.dense.weight", - "model.bert.embeddings.word_embeddings.weight", - "model.bert.encoder.layer.1.attention.self.value.weight", - } - - -def test_task_named_parameters(model): - task_named_parameters = dict(model.task_named_parameters()) - assert set(task_named_parameters) == { - "model.classifier.weight", - "model.classifier.bias", - } - - -def test_configure_optimizers_with_warmup(): - model = SimpleSequenceClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - ) - model.trainer = Trainer(max_epochs=10) - optimizers_and_schedulers = model.configure_optimizers() - assert len(optimizers_and_schedulers) == 2 - optimizers, schedulers = optimizers_and_schedulers - assert len(optimizers) == 1 - assert len(schedulers) == 1 - optimizer = optimizers[0] - assert optimizer is not None - assert isinstance(optimizer, torch.optim.AdamW) - assert optimizer.defaults["lr"] == 1e-05 - assert optimizer.defaults["weight_decay"] == 0.01 - assert optimizer.defaults["eps"] == 1e-08 - - scheduler = schedulers[0] - assert isinstance(scheduler, dict) - assert set(scheduler) == {"scheduler", "interval"} - assert isinstance(scheduler["scheduler"], LambdaLR) - - -def test_configure_optimizers_with_task_learning_rate(monkeypatch): - model = SimpleSequenceClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - learning_rate=1e-5, - task_learning_rate=1e-3, - # disable warmup to make sure the scheduler is not added which would set the learning rate - # to 0 - warmup_proportion=0.0, - ) - optimizer = model.configure_optimizers() - assert optimizer is not None - assert isinstance(optimizer, torch.optim.AdamW) - assert len(optimizer.param_groups) == 2 - # base model parameters - param_group = optimizer.param_groups[0] - assert len(param_group["params"]) == 39 - assert param_group["lr"] == 1e-5 - # classifier head parameters - param_group = optimizer.param_groups[1] - assert len(param_group["params"]) == 2 - assert param_group["lr"] == 1e-3 - # ensure that all parameters are covered - assert set(optimizer.param_groups[0]["params"] + optimizer.param_groups[1]["params"]) == set( - model.parameters() - ) - - -def test_freeze_base_model(monkeypatch, inputs, targets): - model = SimpleSequenceClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - freeze_base_model=True, - # disable warmup to make sure the scheduler is not added which would set the learning rate - # to 0 - warmup_proportion=0.0, - ) - base_model_params = [param for name, param in model.base_model_named_parameters()] - task_params = [param for name, param in model.task_named_parameters()] - assert len(base_model_params) + len(task_params) == len(list(model.parameters())) - for param in base_model_params: - assert not param.requires_grad - for param in task_params: - assert param.requires_grad - - -def test_base_model_named_parameters_wrong_prefix(monkeypatch): - model = SimpleSequenceClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - base_model_prefix="wrong_prefix", - ) - with pytest.raises(ValueError) as excinfo: - model.base_model_named_parameters() - assert ( - str(excinfo.value) - == "Base model with prefix 'wrong_prefix' not found in BertForSequenceClassification" - ) diff --git a/tests/models/test_simple_token_classification.py b/tests/models/test_simple_token_classification.py deleted file mode 100644 index f6b190f6a..000000000 --- a/tests/models/test_simple_token_classification.py +++ /dev/null @@ -1,453 +0,0 @@ -import pytest -import torch - -from pie_modules.models import SimpleTokenClassificationModel -from pie_modules.models.common import TESTING, TRAINING, VALIDATION -from pie_modules.taskmodules import LabeledSpanExtractionByTokenClassificationTaskModule -from tests import _config_to_str -from tests.models import trunc_number - -CONFIGS = [{}] -CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} - - -@pytest.fixture(scope="module", params=CONFIG_DICT.keys()) -def config_str(request): - return request.param - - -@pytest.fixture(scope="module") -def config(config_str): - return CONFIG_DICT[config_str] - - -@pytest.fixture -def taskmodule_config(): - return { - "taskmodule_type": "LabeledSpanExtractionByTokenClassificationTaskModule", - "tokenizer_name_or_path": "bert-base-cased", - "span_annotation": "entities", - "partition_annotation": None, - "label_pad_id": -100, - "labels": ["ORG", "PER"], - "include_ill_formed_predictions": True, - "tokenize_kwargs": None, - "pad_kwargs": None, - "combine_token_scores_method": "mean", - "log_precision_recall_metrics": True, - } - - -def test_taskmodule_config(documents, taskmodule_config): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule( - span_annotation="entities", - tokenizer_name_or_path=tokenizer_name_or_path, - ) - taskmodule.prepare(documents) - assert taskmodule.config == taskmodule_config - - -def test_batch(documents, batch, taskmodule_config): - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule.from_config( - taskmodule_config - ) - encodings = taskmodule.encode(documents, encode_target=True) - # just take the first 4 encodings - batch_from_documents = taskmodule.collate(encodings[:4]) - - inputs, targets = batch - inputs_from_documents, targets_from_documents = batch_from_documents - assert set(inputs) == set(inputs_from_documents) - torch.testing.assert_close(inputs["input_ids"], inputs_from_documents["input_ids"]) - torch.testing.assert_close(inputs["attention_mask"], inputs_from_documents["attention_mask"]) - torch.testing.assert_close(targets, targets_from_documents) - - -@pytest.fixture -def batch(): - inputs = { - "input_ids": torch.tensor( - [ - [101, 138, 1423, 5650, 119, 102, 0, 0, 0, 0, 0, 0], - [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102, 0, 0], - [101, 13832, 3121, 2340, 140, 1105, 141, 119, 102, 0, 0, 0], - [101, 1752, 5650, 119, 13832, 3121, 2340, 142, 1105, 143, 119, 102], - ] - ).to(torch.long), - "attention_mask": torch.tensor( - [ - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - ), - "special_tokens_mask": torch.tensor( - [ - [1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - ] - ), - } - targets = { - "labels": torch.tensor( - [ - [-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100], - [-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100], - [-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100], - [-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100], - ] - ) - } - return inputs, targets - - -@pytest.fixture -def model(monkeypatch, batch, config, taskmodule_config) -> SimpleTokenClassificationModel: - torch.manual_seed(42) - model = SimpleTokenClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=5, - taskmodule_config=taskmodule_config, - metric_stages=["val", "test"], - ) - return model - - -def test_model(model): - assert model is not None - named_parameters = dict(model.named_parameters()) - parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()} - parameter_means_expected = { - "model.bert.embeddings.word_embeddings.weight": 0.0031152, - "model.bert.embeddings.position_embeddings.weight": 5.5e-05, - "model.bert.embeddings.token_type_embeddings.weight": -0.0015419, - "model.bert.embeddings.LayerNorm.weight": 1.312345, - "model.bert.embeddings.LayerNorm.bias": -0.0294608, - "model.bert.encoder.layer.0.attention.self.query.weight": -0.0003949, - "model.bert.encoder.layer.0.attention.self.query.bias": 0.0185744, - "model.bert.encoder.layer.0.attention.self.key.weight": 0.0003863, - "model.bert.encoder.layer.0.attention.self.key.bias": 0.0020557, - "model.bert.encoder.layer.0.attention.self.value.weight": 4.22e-05, - "model.bert.encoder.layer.0.attention.self.value.bias": 0.0065417, - "model.bert.encoder.layer.0.attention.output.dense.weight": 3.01e-05, - "model.bert.encoder.layer.0.attention.output.dense.bias": 0.0007209, - "model.bert.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831, - "model.bert.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714, - "model.bert.encoder.layer.0.intermediate.dense.weight": -0.0011731, - "model.bert.encoder.layer.0.intermediate.dense.bias": -0.1219958, - "model.bert.encoder.layer.0.output.dense.weight": -0.0002212, - "model.bert.encoder.layer.0.output.dense.bias": -0.0013031, - "model.bert.encoder.layer.0.output.LayerNorm.weight": 1.2419648, - "model.bert.encoder.layer.0.output.LayerNorm.bias": 0.005295, - "model.bert.encoder.layer.1.attention.self.query.weight": -0.0007321, - "model.bert.encoder.layer.1.attention.self.query.bias": -0.0358397, - "model.bert.encoder.layer.1.attention.self.key.weight": 0.0001333, - "model.bert.encoder.layer.1.attention.self.key.bias": 0.0045062, - "model.bert.encoder.layer.1.attention.self.value.weight": 0.0001012, - "model.bert.encoder.layer.1.attention.self.value.bias": -0.0007094, - "model.bert.encoder.layer.1.attention.output.dense.weight": -2.43e-05, - "model.bert.encoder.layer.1.attention.output.dense.bias": 0.0041446, - "model.bert.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343, - "model.bert.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237, - "model.bert.encoder.layer.1.intermediate.dense.weight": -0.001344, - "model.bert.encoder.layer.1.intermediate.dense.bias": -0.1247257, - "model.bert.encoder.layer.1.output.dense.weight": -5.32e-05, - "model.bert.encoder.layer.1.output.dense.bias": 0.000677, - "model.bert.encoder.layer.1.output.LayerNorm.weight": 1.017162, - "model.bert.encoder.layer.1.output.LayerNorm.bias": -0.0474442, - "model.classifier.weight": 0.0005138, - "model.classifier.bias": 0.0, - } - assert parameter_means == parameter_means_expected - - -def test_model_pickleable(model): - import pickle - - pickle.dumps(model) - - -def test_forward(batch, model): - inputs, targets = batch - batch_size, seq_len = inputs["input_ids"].shape - num_classes = model.config["num_classes"] - - # set seed to make sure the output is deterministic - torch.manual_seed(42) - output = model.forward(inputs) - assert set(output) == {"logits"} - logits = output["logits"] - assert logits.shape == (batch_size, seq_len, num_classes) - # check the first batch entry - torch.testing.assert_close( - logits[0], - torch.tensor( - [ - [ - -0.13442197442054749, - -0.06983129680156708, - 0.17513807117938995, - -0.24002864956855774, - 0.08871676027774811, - ], - [ - -0.032687313854694366, - -0.2071131318807602, - 0.10695032775402069, - -0.05829116329550743, - -0.21174949407577515, - ], - [ - -0.17153336107730865, - -0.2230629324913025, - -0.11457862704992294, - 0.03658870607614517, - -0.242639422416687, - ], - [ - -0.07552017271518707, - -0.20950022339820862, - 0.041016221046447754, - -0.13453879952430725, - -0.09942213445901871, - ], - [ - -0.19299760460853577, - -0.2081824392080307, - 0.20880958437919617, - -0.028745755553245544, - -0.14375154674053192, - ], - [ - -0.20548884570598602, - -0.17012161016464233, - 0.0647551566362381, - -0.090476393699646, - -0.1362220048904419, - ], - [ - -0.09553629904985428, - -0.1303575187921524, - 0.2995688021183014, - -0.04689876735210419, - -0.17737819254398346, - ], - [ - -0.030023209750652313, - -0.12308696657419205, - 0.2582213580608368, - -0.04085375368595123, - -0.16487300395965576, - ], - [ - -0.04765648394823074, - -0.18347612023353577, - 0.24941012263298035, - 0.022468380630016327, - -0.19706891477108002, - ], - [ - -0.09828818589448929, - -0.18449409306049347, - 0.2711920738220215, - 0.044708192348480225, - -0.15743865072727203, - ], - [ - -0.13639293611049652, - -0.16482298076152802, - 0.3018418848514557, - 0.0815257728099823, - -0.15574774146080017, - ], - [ - -0.14846578240394592, - -0.17294010519981384, - 0.31513816118240356, - 0.10425455123186111, - -0.16388092935085297, - ], - ] - ), - ) - - # check the sums per sequence - torch.testing.assert_close( - logits.sum(1), - torch.tensor( - [ - [ - -1.3690122365951538, - -2.0469894409179688, - 2.1774630546569824, - -0.35028770565986633, - -1.7614551782608032, - ], - [ - -0.892522394657135, - -1.3144632577896118, - 2.683281898498535, - -1.4629074335098267, - -3.3516180515289307, - ], - [ - -1.3936796188354492, - 0.21844607591629028, - 4.501010417938232, - -0.15485064685344696, - -2.651848316192627, - ], - [ - -1.7388781309127808, - -0.7211084365844727, - 3.463726043701172, - -0.2992384433746338, - -2.65508770942688, - ], - ] - ), - ) - - -def test_training_step_and_on_epoch_end(batch, model, config): - assert model._get_metric(TRAINING) is None - loss = model.training_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(1.730902075767517)) - - model.on_train_epoch_end() - - -def test_validation_step_and_on_epoch_end(batch, model, config): - metric = model._get_metric(VALIDATION) - metric.reset() - loss = model.validation_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(1.730902075767517)) - metric_values = {k: v.item() for k, v in metric.compute().items()} - assert metric_values == { - "span/ORG/f1": 0.0, - "span/ORG/precision": 0.0, - "span/ORG/recall": 0.0, - "span/PER/f1": 0.0, - "span/PER/precision": 0.0, - "span/PER/recall": 0.0, - "span/macro/f1": 0.0, - "span/macro/precision": 0.0, - "span/macro/recall": 0.0, - "span/micro/f1": 0.0, - "span/micro/precision": 0.0, - "span/micro/recall": 0.0, - "token/macro/f1": 0.0, - "token/micro/f1": 0.0, - "token/macro/precision": 0.0, - "token/macro/recall": 0.0, - "token/micro/precision": 0.0, - "token/micro/recall": 0.0, - } - - model.on_validation_epoch_end() - - -def test_test_step_and_on_epoch_end(batch, model, config): - metric = model._get_metric(TESTING) - metric.reset() - loss = model.test_step(batch, batch_idx=0) - assert loss is not None - torch.testing.assert_close(loss, torch.tensor(1.730902075767517)) - metric_values = {k: v.item() for k, v in metric.compute().items()} - assert metric_values == { - "span/ORG/f1": 0.0, - "span/ORG/precision": 0.0, - "span/ORG/recall": 0.0, - "span/PER/f1": 0.0, - "span/PER/precision": 0.0, - "span/PER/recall": 0.0, - "span/macro/f1": 0.0, - "span/macro/precision": 0.0, - "span/macro/recall": 0.0, - "span/micro/f1": 0.0, - "span/micro/precision": 0.0, - "span/micro/recall": 0.0, - "token/macro/f1": 0.0, - "token/micro/f1": 0.0, - "token/macro/precision": 0.0, - "token/macro/recall": 0.0, - "token/micro/precision": 0.0, - "token/micro/recall": 0.0, - } - - model.on_test_epoch_end() - - -@pytest.mark.parametrize("test_step", [False, True]) -def test_predict_and_predict_step(model, batch, config, test_step): - torch.manual_seed(42) - if test_step: - predictions = model.predict_step(batch, batch_idx=0, dataloader_idx=0) - else: - predictions = model.predict(batch[0]) - assert set(predictions) == {"labels", "probabilities"} - - assert predictions["labels"].shape == batch[1]["labels"].shape - torch.testing.assert_close( - predictions["labels"], - torch.tensor( - [ - [-100, 2, 3, 2, 2, -100, -100, -100, -100, -100, -100, -100], - [-100, 2, 2, 2, 2, 2, 2, 2, 2, -100, -100, -100], - [-100, 2, 2, 2, 2, 2, 2, 2, -100, -100, -100, -100], - [-100, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, -100], - ] - ), - ) - torch.testing.assert_close( - # just check the first two batch entries - predictions["probabilities"][:2].round(decimals=4), - torch.tensor( - [ - [ - [0.1792, 0.1912, 0.2443, 0.1613, 0.2240], - [0.2083, 0.1750, 0.2395, 0.2030, 0.1742], - [0.1934, 0.1837, 0.2047, 0.2381, 0.1801], - [0.2034, 0.1779, 0.2285, 0.1917, 0.1986], - [0.1752, 0.1725, 0.2618, 0.2065, 0.1840], - [0.1805, 0.1870, 0.2365, 0.2025, 0.1935], - [0.1844, 0.1781, 0.2738, 0.1936, 0.1700], - [0.1958, 0.1784, 0.2612, 0.1937, 0.1711], - [0.1941, 0.1694, 0.2612, 0.2082, 0.1671], - [0.1831, 0.1680, 0.2650, 0.2113, 0.1726], - [0.1740, 0.1691, 0.2697, 0.2164, 0.1707], - [0.1713, 0.1672, 0.2723, 0.2205, 0.1687], - ], - [ - [0.1654, 0.1989, 0.2729, 0.1542, 0.2086], - [0.1787, 0.1511, 0.3093, 0.1968, 0.1641], - [0.1888, 0.1966, 0.2365, 0.2081, 0.1700], - [0.2092, 0.1935, 0.2428, 0.2034, 0.1511], - [0.2275, 0.1784, 0.2546, 0.1856, 0.1539], - [0.2254, 0.1959, 0.2377, 0.1873, 0.1536], - [0.2177, 0.1879, 0.2485, 0.1975, 0.1484], - [0.2227, 0.1906, 0.2541, 0.1906, 0.1420], - [0.2080, 0.2098, 0.2667, 0.1764, 0.1391], - [0.1815, 0.2015, 0.2600, 0.1852, 0.1718], - [0.1672, 0.1883, 0.3065, 0.1773, 0.1607], - [0.1750, 0.1846, 0.2911, 0.1862, 0.1630], - ], - ] - ), - ) - - -def test_configure_optimizers(model): - optimizer = model.configure_optimizers() - assert optimizer is not None - assert isinstance(optimizer, torch.optim.Adam) - assert optimizer.defaults["lr"] == 1e-05 - assert len(optimizer.param_groups) == 1 - assert len(optimizer.param_groups[0]["params"]) > 0 - assert set(optimizer.param_groups[0]["params"]) == set(model.parameters()) diff --git a/tests/models/test_span_tuple_classification.py b/tests/models/test_span_tuple_classification.py deleted file mode 100644 index f984fdbdb..000000000 --- a/tests/models/test_span_tuple_classification.py +++ /dev/null @@ -1,549 +0,0 @@ -import pytest -import torch -from pytorch_lightning import Trainer -from torch import tensor - -from pie_modules.models import SpanTupleClassificationModel -from pie_modules.models.common import TESTING, TRAINING, VALIDATION -from pie_modules.taskmodules import RESpanPairClassificationTaskModule -from tests import _config_to_str - -CONFIGS = [{}] -CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} -NUM_CLASSES = 4 - - -@pytest.fixture(scope="module", params=CONFIG_DICT.keys()) -def config_str(request): - return request.param - - -@pytest.fixture(scope="module") -def config(config_str): - return CONFIG_DICT[config_str] - - -@pytest.fixture -def taskmodule_config(): - return { - "taskmodule_type": "RESpanPairClassificationTaskModule", - "tokenizer_name_or_path": "bert-base-cased", - "relation_annotation": "relations", - "no_relation_label": "no_relation", - "partition_annotation": None, - "tokenize_kwargs": None, - "create_candidate_relations": False, - "create_candidate_relations_kwargs": None, - "labels": ["org:founded_by", "per:employee_of", "per:founder"], - "entity_labels": ["ORG", "PER"], - "add_type_to_marker": True, - "log_first_n_examples": 0, - "collect_statistics": False, - } - - -def test_taskmodule_config(documents, taskmodule_config): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RESpanPairClassificationTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - ) - taskmodule.prepare(documents) - assert taskmodule.config == taskmodule_config - assert len(taskmodule.id_to_label) == NUM_CLASSES - - -def test_batch(documents, batch, taskmodule_config): - taskmodule = RESpanPairClassificationTaskModule.from_config(taskmodule_config) - encodings = taskmodule.encode(documents, encode_target=True, as_dataset=True) - batch_from_documents = taskmodule.collate(encodings[:4]) - - inputs, targets = batch - inputs_from_documents, targets_from_documents = batch_from_documents - assert set(inputs) == set(inputs_from_documents) - for key in inputs: - torch.testing.assert_close(inputs[key], inputs_from_documents[key]) - assert set(targets) == set(targets_from_documents) - for key in targets: - torch.testing.assert_close(targets[key], targets_from_documents[key]) - - -@pytest.fixture -def batch(): - inputs = { - "input_ids": tensor( - [ - [ - 101, - 28996, - 13832, - 3121, - 2340, - 138, - 28998, - 1759, - 1120, - 28999, - 139, - 28997, - 119, - 102, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 28996, - 13832, - 3121, - 2340, - 144, - 28998, - 1759, - 1120, - 28999, - 145, - 28997, - 119, - 1262, - 1771, - 28999, - 146, - 28997, - 119, - 102, - 0, - 0, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 28996, - 13832, - 3121, - 2340, - 150, - 28998, - 1759, - 1120, - 28999, - 151, - 28997, - 119, - 1262, - 28996, - 1122, - 28998, - 1771, - 28999, - 152, - 28997, - 119, - 102, - ], - ] - ), - "attention_mask": tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - ), - "span_start_indices": tensor([[1, 9, 0, 0], [4, 12, 18, 0], [4, 12, 17, 21]]), - "span_end_indices": tensor([[7, 12, 0, 0], [10, 15, 21, 0], [10, 15, 20, 24]]), - "tuple_indices": tensor( - [[[0, 1], [-1, -1], [-1, -1]], [[0, 1], [0, 2], [2, 1]], [[0, 1], [2, 3], [3, 2]]] - ), - "tuple_indices_mask": tensor( - [[True, False, False], [True, True, True], [True, True, True]] - ), - } - targets = {"labels": tensor([[2, -100, -100], [2, 3, 1], [2, 3, 1]])} - return inputs, targets - - -@pytest.fixture -def model(batch, config, taskmodule_config) -> SpanTupleClassificationModel: - torch.manual_seed(42) - model = SpanTupleClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - taskmodule_config=taskmodule_config, - metric_stages=["val", "test"], - **config, - ) - return model - - -def test_model_pickleable(model): - import pickle - - pickle.dumps(model) - - -def test_freeze_base_model(): - model = SpanTupleClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - freeze_base_model=True, - ) - - base_model_params = dict(model.model.named_parameters(prefix="model")) - assert len(base_model_params) > 0 - for param in base_model_params.values(): - assert not param.requires_grad - task_params = { - name: param for name, param in model.named_parameters() if name not in base_model_params - } - assert len(task_params) > 0 - for param in task_params.values(): - assert param.requires_grad - - -def test_tune_base_model(): - model = SpanTupleClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=5, - ) - base_model_params = dict(model.model.named_parameters(prefix="model")) - assert len(base_model_params) > 0 - for param in base_model_params.values(): - assert param.requires_grad - task_params = { - name: param for name, param in model.named_parameters() if name not in base_model_params - } - assert len(task_params) > 0 - for param in task_params.values(): - assert param.requires_grad - - -@pytest.mark.parametrize( - "span_embedding_mode", ["start_and_end_token", "start_token", "end_token"] -) -@pytest.mark.parametrize( - "tuple_embedding_mode", ["concat", "multiply2_and_concat", "index_0", "index_1"] -) -def test_forward_embeddings(batch, taskmodule_config, span_embedding_mode, tuple_embedding_mode): - torch.manual_seed(42) - simple_model = SpanTupleClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - # disable the tuple mlp to allow for checking the intermediate embeddings via the indices - tuple_entry_hidden_dim=None, - taskmodule_config=taskmodule_config, - span_embedding_mode=span_embedding_mode, - tuple_embedding_mode=tuple_embedding_mode, - ) - - inputs, targets = batch - batch_size, seq_len = inputs["input_ids"].shape - - # set seed to make sure the output is deterministic - torch.manual_seed(42) - # return embeddings to check the logits - output = simple_model.forward(inputs, return_embeddings=True) - assert set(output) == {"logits", "last_hidden_state", "span_embeddings", "tuple_embeddings"} - logits_flat = output["logits"] - assert len(logits_flat.shape) == 2 - assert logits_flat.shape[-1] == NUM_CLASSES - - # check span_embeddings: they should be the entries of last_hidden_state at the - # span_start_indices and span_end_indices - for batch_idx in range(batch_size): - for j, (start, end) in enumerate( - zip(inputs["span_start_indices"][batch_idx], inputs["span_end_indices"][batch_idx]) - ): - current_expected_span_embedding_list = [] - if simple_model.span_embedding_mode == "start_and_end_token": - current_expected_span_embedding_list.append( - output["last_hidden_state"][batch_idx, start] - ) - current_expected_span_embedding_list.append( - output["last_hidden_state"][batch_idx, end] - ) - elif simple_model.span_embedding_mode == "start_token": - current_expected_span_embedding_list.append( - output["last_hidden_state"][batch_idx, start] - ) - elif simple_model.span_embedding_mode == "end_token": - current_expected_span_embedding_list.append( - output["last_hidden_state"][batch_idx, end] - ) - else: - raise ValueError( - f"Unknown span_embedding_mode: {simple_model.span_embedding_mode}" - ) - expected_current_span_embedding = torch.concat( - current_expected_span_embedding_list, dim=-1 - ) - current_span_embeddings = output["span_embeddings"][batch_idx, j] - torch.testing.assert_close(current_span_embeddings, expected_current_span_embedding) - - # check tuple_embeddings: they should be the entries of span_embeddings at the tuple_indices - tuple_idx = 0 - for batch_idx in range(batch_size): - for indices, is_valid in zip( - inputs["tuple_indices"][batch_idx], inputs["tuple_indices_mask"][batch_idx] - ): - if is_valid: - current_expected_tuple_embedding_list = [ - output["span_embeddings"][batch_idx, idx] for idx in indices - ] - if simple_model.tuple_embedding_mode == "concat": - expected_current_tuple_embedding = torch.concat( - current_expected_tuple_embedding_list, dim=-1 - ) - elif simple_model.tuple_embedding_mode == "multiply2_and_concat": - expected_current_tuple_embedding = torch.cat( - [ - current_expected_tuple_embedding_list[0] - * current_expected_tuple_embedding_list[1], - current_expected_tuple_embedding_list[0], - current_expected_tuple_embedding_list[1], - ], - dim=-1, - ) - elif simple_model.tuple_embedding_mode.startswith("index_"): - idx = int(simple_model.tuple_embedding_mode.split("_")[1]) - expected_current_tuple_embedding = current_expected_tuple_embedding_list[idx] - else: - raise ValueError( - f"Unknown tuple_embedding_mode: {simple_model.tuple_embedding_mode}" - ) - current_tuple_embedding = output["tuple_embeddings"][tuple_idx] - torch.testing.assert_close( - current_tuple_embedding, expected_current_tuple_embedding - ) - tuple_idx += 1 - - -def test_forward_logits(batch, model): - inputs, targets = batch - - # set seed to make sure the output is deterministic - torch.manual_seed(42) - # return embeddings to check the logits - output = model.forward(inputs) - assert set(output) == {"logits"} - logits_flat = output["logits"] - assert len(logits_flat.shape) == 2 - assert logits_flat.shape[-1] == NUM_CLASSES - # check the actual logits - torch.testing.assert_close( - logits_flat, - tensor( - [ - [ - -0.23075447976589203, - 0.08129829168319702, - -0.26441076397895813, - 0.3208361268043518, - ], - [ - -0.2247302085161209, - 0.21453489363193512, - -0.20609508454799652, - 0.2984844148159027, - ], - [ - -0.0552724152803421, - 0.18319237232208252, - -0.14115819334983826, - 0.23137536644935608, - ], - [ - -0.2897184491157532, - 0.17462071776390076, - -0.12004873156547546, - 0.1817789375782013, - ], - [ - -0.3101494312286377, - 0.18245069682598114, - -0.13525372743606567, - 0.28625163435935974, - ], - [ - -0.33728304505348206, - 0.22038179636001587, - -0.0482308566570282, - 0.25237396359443665, - ], - [ - -0.3835912048816681, - 0.20549766719341278, - 0.15333643555641174, - 0.23370930552482605, - ], - ] - ), - ) - - -def test_step(batch, model, config): - torch.manual_seed(42) - loss = model._step("train", batch) - assert loss is not None - if config == {}: - torch.testing.assert_close(loss, torch.tensor(1.3911350965499878)) - else: - raise ValueError(f"Unknown config: {config}") - - -def test_training_step_and_on_epoch_end(batch, model, config): - metric = model._get_metric(TRAINING, batch_idx=0) - assert metric is None - loss = model.training_step(batch, batch_idx=0) - assert loss is not None - if config == {}: - torch.testing.assert_close(loss, torch.tensor(1.3911350965499878)) - else: - raise ValueError(f"Unknown config: {config}") - - model.on_train_epoch_end() - - -def test_validation_step_and_on_epoch_end(batch, model, config): - metric = model._get_metric(VALIDATION, batch_idx=0) - metric.reset() - loss = model.validation_step(batch, batch_idx=0) - assert loss is not None - metric_values = {k: v.item() for k, v in metric.compute().items()} - if config == {}: - torch.testing.assert_close(loss, torch.tensor(1.3911350965499878)) - assert metric_values == { - "macro/f1": 0.14814814925193787, - "micro/f1": 0.2857142984867096, - "no_relation/f1": 0.0, - "org:founded_by/f1": 0.0, - "per:employee_of/f1": 0.0, - "per:founder/f1": 0.4444444477558136, - } - else: - raise ValueError(f"Unknown config: {config}") - - model.on_validation_epoch_end() - - -def test_test_step_and_on_epoch_end(batch, model, config): - metric = model._get_metric(TESTING, batch_idx=0) - metric.reset() - loss = model.test_step(batch, batch_idx=0) - assert loss is not None - metric_values = {k: v.item() for k, v in metric.compute().items()} - if config == {}: - torch.testing.assert_close(loss, torch.tensor(1.3911350965499878)) - assert metric_values == { - "macro/f1": 0.14814814925193787, - "micro/f1": 0.2857142984867096, - "no_relation/f1": 0.0, - "org:founded_by/f1": 0.0, - "per:employee_of/f1": 0.0, - "per:founder/f1": 0.4444444477558136, - } - else: - raise ValueError(f"Unknown config: {config}") - - model.on_test_epoch_end() - - -@pytest.mark.parametrize("test_step", [False, True]) -def test_predict_and_predict_step(model, batch, config, test_step): - torch.manual_seed(42) - if test_step: - predictions = model.predict_step(batch, batch_idx=0, dataloader_idx=0) - else: - predictions = model.predict(batch[0]) - - assert set(predictions) == {"labels", "probabilities"} - labels = predictions["labels"] - assert labels.shape == batch[1]["labels"].shape - probabilities = predictions["probabilities"] - if config == {}: - torch.testing.assert_close(labels, tensor([[3, -100, -100], [3, 3, 3], [3, 3, 3]])) - torch.testing.assert_close( - probabilities.round(decimals=4), - tensor( - [ - [ - [0.1973, 0.2695, 0.1907, 0.3425], - [-1.0000, -1.0000, -1.0000, -1.0000], - [-1.0000, -1.0000, -1.0000, -1.0000], - ], - [ - [0.1902, 0.2951, 0.1938, 0.3209], - [0.2213, 0.2809, 0.2031, 0.2947], - [0.1859, 0.2958, 0.2203, 0.2979], - ], - [ - [0.1772, 0.2900, 0.2111, 0.3217], - [0.1699, 0.2968, 0.2269, 0.3064], - [0.1571, 0.2831, 0.2687, 0.2912], - ], - ], - ), - ) - else: - raise ValueError(f"Unknown config: {config}") - - -def test_configure_optimizers(model): - model.trainer = Trainer(max_epochs=10) - optimizer_and_schedular = model.configure_optimizers() - assert optimizer_and_schedular is not None - optimizers, schedulers = optimizer_and_schedular - - assert len(optimizers) == 1 - optimizer = optimizers[0] - assert isinstance(optimizer, torch.optim.AdamW) - assert optimizer.defaults["lr"] == 1e-05 - assert optimizer.defaults["weight_decay"] == 0.01 - assert optimizer.defaults["eps"] == 1e-08 - - assert len(schedulers) == 1 - scheduler = schedulers[0] - assert isinstance(scheduler["scheduler"], torch.optim.lr_scheduler.LambdaLR) - - -def test_configure_optimizers_with_task_learning_rate(): - model = SpanTupleClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=NUM_CLASSES, - warmup_proportion=0.0, - task_learning_rate=1e-4, - ) - optimizer = model.configure_optimizers() - assert optimizer is not None - assert isinstance(optimizer, torch.optim.AdamW) - assert len(optimizer.param_groups) == 2 - # check that all parameters are in the optimizer - assert set(optimizer.param_groups[0]["params"]) | set( - optimizer.param_groups[1]["params"] - ) == set(model.parameters()) - - # base model parameters - param_group = optimizer.param_groups[0] - assert param_group["lr"] == 1e-05 - assert len(param_group["params"]) == 39 - - # task parameters - param_group = optimizer.param_groups[1] - assert param_group["lr"] == 1e-04 - assert len(param_group["params"]) == 6 diff --git a/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py b/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py deleted file mode 100644 index 5b31b2f13..000000000 --- a/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py +++ /dev/null @@ -1,639 +0,0 @@ -import pytest -import torch -from pytorch_lightning import Trainer - -from pie_modules.models import TokenClassificationModelWithSeq2SeqEncoderAndCrf -from pie_modules.models.common import TESTING, TRAINING, VALIDATION -from pie_modules.taskmodules import LabeledSpanExtractionByTokenClassificationTaskModule -from tests import _config_to_str -from tests.models import trunc_number - -CONFIGS = [{}, {"use_crf": False}] -CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} - - -@pytest.fixture(scope="module", params=CONFIG_DICT.keys()) -def config_str(request): - return request.param - - -@pytest.fixture(scope="module") -def config(config_str): - return CONFIG_DICT[config_str] - - -@pytest.fixture -def taskmodule_config(): - return { - "taskmodule_type": "LabeledSpanExtractionByTokenClassificationTaskModule", - "tokenizer_name_or_path": "bert-base-cased", - "span_annotation": "entities", - "partition_annotation": None, - "label_pad_id": -100, - "labels": ["ORG", "PER"], - "include_ill_formed_predictions": True, - "combine_token_scores_method": "mean", - "tokenize_kwargs": None, - "pad_kwargs": None, - "log_precision_recall_metrics": True, - } - - -def test_taskmodule_config(documents, taskmodule_config): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule( - span_annotation="entities", - tokenizer_name_or_path=tokenizer_name_or_path, - ) - taskmodule.prepare(documents) - assert taskmodule.config == taskmodule_config - - -def test_batch(documents, batch, taskmodule_config): - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule.from_config( - taskmodule_config - ) - encodings = taskmodule.encode(documents, encode_target=True, as_dataset=True) - batch_from_documents = taskmodule.collate(encodings[:4]) - - inputs, targets = batch - inputs_from_documents, targets_from_documents = batch_from_documents - torch.testing.assert_close(inputs["input_ids"], inputs_from_documents["input_ids"]) - torch.testing.assert_close(inputs["attention_mask"], inputs_from_documents["attention_mask"]) - torch.testing.assert_close(targets, targets_from_documents) - - -@pytest.fixture -def batch(): - inputs = { - "input_ids": torch.tensor( - [ - [101, 138, 1423, 5650, 119, 102, 0, 0, 0, 0, 0, 0], - [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102, 0, 0], - [101, 13832, 3121, 2340, 140, 1105, 141, 119, 102, 0, 0, 0], - [101, 1752, 5650, 119, 13832, 3121, 2340, 142, 1105, 143, 119, 102], - ] - ).to(torch.long), - "attention_mask": torch.tensor( - [ - [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - ), - "special_tokens_mask": torch.tensor( - [ - [1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - ] - ), - } - targets = { - "labels": torch.tensor( - [ - [-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100], - [-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100], - [-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100], - [-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100], - ] - ) - } - return inputs, targets - - -@pytest.fixture -def model(batch, config, taskmodule_config) -> TokenClassificationModelWithSeq2SeqEncoderAndCrf: - seq2seq_dict = { - "type": "linear", - "out_features": 10, - } - torch.manual_seed(42) - model = TokenClassificationModelWithSeq2SeqEncoderAndCrf( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=5, - seq2seq_encoder=seq2seq_dict, - taskmodule_config=taskmodule_config, - metric_stages=["val", "test"], - **config, - ) - return model - - -def test_model(model, config): - assert model is not None - named_parameters = dict(model.named_parameters()) - parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()} - parameter_means_expected = { - "model.embeddings.word_embeddings.weight": 0.0031152, - "model.embeddings.position_embeddings.weight": 5.5e-05, - "model.embeddings.token_type_embeddings.weight": -0.0015419, - "model.embeddings.LayerNorm.weight": 1.312345, - "model.embeddings.LayerNorm.bias": -0.0294608, - "model.encoder.layer.0.attention.self.query.weight": -0.0003949, - "model.encoder.layer.0.attention.self.query.bias": 0.0185744, - "model.encoder.layer.0.attention.self.key.weight": 0.0003863, - "model.encoder.layer.0.attention.self.key.bias": 0.0020557, - "model.encoder.layer.0.attention.self.value.weight": 4.22e-05, - "model.encoder.layer.0.attention.self.value.bias": 0.0065417, - "model.encoder.layer.0.attention.output.dense.weight": 3.01e-05, - "model.encoder.layer.0.attention.output.dense.bias": 0.0007209, - "model.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831, - "model.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714, - "model.encoder.layer.0.intermediate.dense.weight": -0.0011731, - "model.encoder.layer.0.intermediate.dense.bias": -0.1219958, - "model.encoder.layer.0.output.dense.weight": -0.0002212, - "model.encoder.layer.0.output.dense.bias": -0.0013031, - "model.encoder.layer.0.output.LayerNorm.weight": 1.2419648, - "model.encoder.layer.0.output.LayerNorm.bias": 0.005295, - "model.encoder.layer.1.attention.self.query.weight": -0.0007321, - "model.encoder.layer.1.attention.self.query.bias": -0.0358397, - "model.encoder.layer.1.attention.self.key.weight": 0.0001333, - "model.encoder.layer.1.attention.self.key.bias": 0.0045062, - "model.encoder.layer.1.attention.self.value.weight": 0.0001012, - "model.encoder.layer.1.attention.self.value.bias": -0.0007094, - "model.encoder.layer.1.attention.output.dense.weight": -2.43e-05, - "model.encoder.layer.1.attention.output.dense.bias": 0.0041446, - "model.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343, - "model.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237, - "model.encoder.layer.1.intermediate.dense.weight": -0.001344, - "model.encoder.layer.1.intermediate.dense.bias": -0.1247257, - "model.encoder.layer.1.output.dense.weight": -5.32e-05, - "model.encoder.layer.1.output.dense.bias": 0.000677, - "model.encoder.layer.1.output.LayerNorm.weight": 1.017162, - "model.encoder.layer.1.output.LayerNorm.bias": -0.0474442, - "model.pooler.dense.weight": 0.0001295, - "model.pooler.dense.bias": -0.0052078, - "seq2seq_encoder.weight": -0.0015382, - "seq2seq_encoder.bias": -0.0105704, - "classifier.weight": 0.0261459, - "classifier.bias": -0.0157966, - } - if config.get("use_crf", True): - parameter_means_expected.update( - { - "crf.start_transitions": -0.0341042, - "crf.end_transitions": 0.0140624, - "crf.transitions": 0.0056733, - } - ) - assert parameter_means == parameter_means_expected - - -def test_model_pickleable(model): - import pickle - - pickle.dumps(model) - - -def test_freeze_base_model(): - model = TokenClassificationModelWithSeq2SeqEncoderAndCrf( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=5, - freeze_base_model=True, - ) - - base_model_params = dict(model.model.named_parameters(prefix="model")) - assert len(base_model_params) > 0 - for param in base_model_params.values(): - assert not param.requires_grad - task_params = { - name: param for name, param in model.named_parameters() if name not in base_model_params - } - assert len(task_params) > 0 - for param in task_params.values(): - assert param.requires_grad - - -def test_tune_base_model(): - model = TokenClassificationModelWithSeq2SeqEncoderAndCrf( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=5, - ) - base_model_params = dict(model.model.named_parameters(prefix="model")) - assert len(base_model_params) > 0 - for param in base_model_params.values(): - assert param.requires_grad - task_params = { - name: param for name, param in model.named_parameters() if name not in base_model_params - } - assert len(task_params) > 0 - for param in task_params.values(): - assert param.requires_grad - - -def test_forward(batch, model): - inputs, targets = batch - batch_size, seq_len = inputs["input_ids"].shape - num_classes = int(torch.max(targets["labels"]) + 1) - - # set seed to make sure the output is deterministic - torch.manual_seed(42) - output = model.forward(inputs) - assert set(output) == {"logits"} - logits = output["logits"] - assert logits.shape == (batch_size, seq_len, num_classes) - # check the first batch entry - torch.testing.assert_close( - logits[0], - torch.tensor( - [ - [ - -1.065280795097351, - 0.22260898351669312, - -0.013371739536523819, - 1.0213487148284912, - -0.08737741410732269, - ], - [ - -1.092915415763855, - 0.07986105978488922, - 0.011286348104476929, - 0.7147902250289917, - -0.014343257993459702, - ], - [ - -1.0107779502868652, - 0.2041827142238617, - -0.06531291455030441, - 0.6551182270050049, - 0.04944971576333046, - ], - [ - -0.3324984312057495, - 0.27757787704467773, - 0.13295423984527588, - 0.26407280564308167, - -0.007371138781309128, - ], - [ - -0.6176304817199707, - 0.12915551662445068, - 0.268213152885437, - 0.43618908524513245, - -0.13303528726100922, - ], - [ - -0.5220450758934021, - 0.37291139364242554, - 0.2522115111351013, - 0.7383102178573608, - 0.1278681606054306, - ], - [ - -1.0737248659133911, - 0.0029090046882629395, - 0.06924695521593094, - 0.6680881977081299, - -0.15523286163806915, - ], - [ - -0.5176048278808594, - -0.01018303632736206, - 0.14543311297893524, - 0.5191693305969238, - -0.3461107611656189, - ], - [ - -0.9277648329734802, - 0.3154565095901489, - -0.07648143172264099, - 0.4210910201072693, - 0.2663896083831787, - ], - [ - -0.8864655494689941, - 0.2862459421157837, - -0.04168111830949783, - 0.4992614984512329, - 0.28455498814582825, - ], - [ - -0.9500657916069031, - 0.1869449019432068, - -0.005329027771949768, - 0.5908203721046448, - 0.06730394065380096, - ], - [ - -0.5336291193962097, - -0.053214408457279205, - 0.22038350999355316, - 0.48135989904403687, - -0.4338146448135376, - ], - ] - ), - ) - - # check the sums per sequence - torch.testing.assert_close( - logits.sum(1), - torch.tensor( - [ - [ - -9.530403137207031, - 2.0144565105438232, - 0.8975526690483093, - 7.009620189666748, - -0.3817189633846283, - ], - [ - -4.351415634155273, - 0.3694552183151245, - -0.8337129354476929, - 3.612205743789673, - 0.15454095602035522, - ], - [ - -6.173098564147949, - -2.6261491775512695, - 0.47521746158599854, - 3.344158172607422, - -5.086399078369141, - ], - [ - -9.28173542022705, - -1.6196215152740479, - 0.18393829464912415, - 5.492751121520996, - -4.148656845092773, - ], - ] - ), - ) - - -def test_step(batch, model, config): - torch.manual_seed(42) - loss = model._step("train", batch) - assert loss is not None - if config == {}: - torch.testing.assert_close(loss, torch.tensor(75.52511596679688)) - elif config == {"use_crf": False}: - torch.testing.assert_close(loss, torch.tensor(1.9434731006622314)) - else: - raise ValueError(f"Unknown config: {config}") - - -def test_training_step_and_on_epoch_end(batch, model, config): - metric = model._get_metric(TRAINING, batch_idx=0) - assert metric is None - loss = model.training_step(batch, batch_idx=0) - assert loss is not None - if config == {}: - torch.testing.assert_close(loss, torch.tensor(77.59623718261719)) - elif config == {"use_crf": False}: - torch.testing.assert_close(loss, torch.tensor(1.9865683317184448)) - else: - raise ValueError(f"Unknown config: {config}") - - model.on_train_epoch_end() - - -def test_training_step_without_attention_mask(batch, model, config): - inputs, targets = batch - inputs_without_attention_mask = {k: v for k, v in inputs.items() if k != "attention_mask"} - loss = model.training_step(batch=(inputs_without_attention_mask, targets), batch_idx=0) - assert loss is not None - if config == {}: - torch.testing.assert_close(loss, torch.tensor(103.0061264038086)) - elif config == {"use_crf": False}: - torch.testing.assert_close(loss, torch.tensor(1.9988830089569092)) - else: - raise ValueError(f"Unknown config: {config}") - - -def test_validation_step_and_on_epoch_end(batch, model, config): - metric = model._get_metric(VALIDATION, batch_idx=0) - metric.reset() - loss = model.validation_step(batch, batch_idx=0) - assert loss is not None - metric_values = {k: v.item() for k, v in metric.compute().items()} - if config == {}: - torch.testing.assert_close(loss, torch.tensor(77.59623718261719)) - assert metric_values == { - "token/macro/f1": 0.20666667819023132, - "token/micro/f1": 0.2068965584039688, - "token/macro/precision": 0.29019609093666077, - "token/macro/recall": 0.2666666805744171, - "token/micro/precision": 0.2068965584039688, - "token/micro/recall": 0.2068965584039688, - "span/ORG/f1": 0.3636363744735718, - "span/ORG/recall": 0.25, - "span/ORG/precision": 0.6666666865348816, - "span/PER/f1": 0.0, - "span/PER/recall": 0.0, - "span/PER/precision": 0.0, - "span/micro/f1": 0.12121212482452393, - "span/micro/recall": 0.07407407462596893, - "span/micro/precision": 0.3333333432674408, - "span/macro/f1": 0.1818181872367859, - "span/macro/recall": 0.125, - "span/macro/precision": 0.3333333432674408, - } - elif config == {"use_crf": False}: - torch.testing.assert_close(loss, torch.tensor(1.9865683317184448)) - assert metric_values == { - "token/macro/f1": 0.11717171967029572, - "token/micro/f1": 0.17241379618644714, - "token/macro/precision": 0.22500000894069672, - "token/macro/recall": 0.24444444477558136, - "token/micro/precision": 0.17241379618644714, - "token/micro/recall": 0.17241379618644714, - "span/ORG/f1": 0.0, - "span/ORG/recall": 0.0, - "span/ORG/precision": 0.0, - "span/PER/f1": 0.0, - "span/PER/recall": 0.0, - "span/PER/precision": 0.0, - "span/micro/f1": 0.0, - "span/micro/recall": 0.0, - "span/micro/precision": 0.0, - "span/macro/f1": 0.0, - "span/macro/recall": 0.0, - "span/macro/precision": 0.0, - } - else: - raise ValueError(f"Unknown config: {config}") - - model.on_validation_epoch_end() - - -def test_test_step_and_on_epoch_end(batch, model, config): - metric = model._get_metric(TESTING, batch_idx=0) - metric.reset() - loss = model.test_step(batch, batch_idx=0) - assert loss is not None - metric_values = {k: v.item() for k, v in metric.compute().items()} - if config == {}: - torch.testing.assert_close(loss, torch.tensor(77.59623718261719)) - assert metric_values == { - "token/macro/f1": 0.20666667819023132, - "token/micro/f1": 0.2068965584039688, - "token/macro/precision": 0.29019609093666077, - "token/macro/recall": 0.2666666805744171, - "token/micro/precision": 0.2068965584039688, - "token/micro/recall": 0.2068965584039688, - "span/ORG/f1": 0.3636363744735718, - "span/ORG/recall": 0.25, - "span/ORG/precision": 0.6666666865348816, - "span/PER/f1": 0.0, - "span/PER/recall": 0.0, - "span/PER/precision": 0.0, - "span/micro/f1": 0.12121212482452393, - "span/micro/recall": 0.07407407462596893, - "span/micro/precision": 0.3333333432674408, - "span/macro/f1": 0.1818181872367859, - "span/macro/recall": 0.125, - "span/macro/precision": 0.3333333432674408, - } - elif config == {"use_crf": False}: - torch.testing.assert_close(loss, torch.tensor(1.9865683317184448)) - assert metric_values == { - "token/macro/f1": 0.11717171967029572, - "token/micro/f1": 0.17241379618644714, - "token/macro/precision": 0.22500000894069672, - "token/macro/recall": 0.24444444477558136, - "token/micro/precision": 0.17241379618644714, - "token/micro/recall": 0.17241379618644714, - "span/ORG/f1": 0.0, - "span/ORG/recall": 0.0, - "span/ORG/precision": 0.0, - "span/PER/f1": 0.0, - "span/PER/recall": 0.0, - "span/PER/precision": 0.0, - "span/micro/f1": 0.0, - "span/micro/recall": 0.0, - "span/micro/precision": 0.0, - "span/macro/f1": 0.0, - "span/macro/recall": 0.0, - "span/macro/precision": 0.0, - } - else: - raise ValueError(f"Unknown config: {config}") - - model.on_test_epoch_end() - - -@pytest.mark.parametrize("test_step", [False, True]) -def test_predict_and_predict_step(model, batch, config, test_step): - torch.manual_seed(42) - if test_step: - predictions = model.predict_step(batch, batch_idx=0, dataloader_idx=0) - else: - predictions = model.predict(batch[0]) - - assert set(predictions) == {"labels", "probabilities"} - labels = predictions["labels"] - probabilities = predictions["probabilities"] - if config == {}: - torch.testing.assert_close( - labels, - torch.tensor( - [ - [-100, 3, 3, 1, 3, -100, -100, -100, -100, -100, -100, -100], - [-100, 3, 1, 4, 4, 3, 3, 3, 2, -100, -100, -100], - [-100, 3, 2, 2, 3, 3, 3, 2, -100, -100, -100, -100], - [-100, 3, 3, 3, 2, 3, 1, 4, 3, 3, 2, -100], - ] - ), - ) - elif config == {"use_crf": False}: - torch.testing.assert_close( - labels, - torch.tensor( - [ - [-100, 3, 3, 1, 3, -100, -100, -100, -100, -100, -100, -100], - [-100, 3, 3, 4, 4, 3, 3, 3, 3, -100, -100, -100], - [-100, 3, 2, 2, 3, 3, 3, 2, -100, -100, -100, -100], - [-100, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, -100], - ] - ), - ) - else: - raise ValueError(f"Unknown config: {config}") - - assert labels.shape == batch[1]["labels"].shape - torch.testing.assert_close( - probabilities[:2].round(decimals=4), - torch.tensor( - [ - [ - [0.0549, 0.1991, 0.1573, 0.4426, 0.1461], - [0.0614, 0.1984, 0.1853, 0.3744, 0.1806], - [0.0661, 0.2229, 0.1702, 0.3499, 0.1909], - [0.1310, 0.2411, 0.2087, 0.2379, 0.1813], - [0.0997, 0.2104, 0.2418, 0.2861, 0.1619], - [0.0904, 0.2213, 0.1961, 0.3189, 0.1732], - [0.0654, 0.1920, 0.2052, 0.3734, 0.1639], - [0.1162, 0.1929, 0.2254, 0.3276, 0.1379], - [0.0716, 0.2483, 0.1678, 0.2759, 0.2364], - [0.0726, 0.2344, 0.1689, 0.2901, 0.2340], - [0.0708, 0.2207, 0.1821, 0.3305, 0.1958], - [0.1162, 0.1879, 0.2470, 0.3206, 0.1284], - ], - [ - [0.1242, 0.1911, 0.1516, 0.3256, 0.2075], - [0.1291, 0.2089, 0.2046, 0.2890, 0.1684], - [0.2033, 0.2016, 0.1920, 0.2260, 0.1771], - [0.1793, 0.2191, 0.1800, 0.1889, 0.2328], - [0.1854, 0.2150, 0.1638, 0.1898, 0.2460], - [0.1363, 0.2007, 0.1738, 0.2887, 0.2005], - [0.1254, 0.2014, 0.1826, 0.2890, 0.2016], - [0.1305, 0.2056, 0.2056, 0.2590, 0.1993], - [0.1400, 0.2022, 0.2252, 0.2544, 0.1783], - [0.1299, 0.2051, 0.1933, 0.2751, 0.1966], - [0.1088, 0.2086, 0.1599, 0.2861, 0.2367], - [0.0910, 0.1794, 0.1840, 0.3793, 0.1663], - ], - ] - ), - ) - - -def test_configure_optimizers(model): - model.trainer = Trainer(max_epochs=10) - optimizer_and_schedular = model.configure_optimizers() - assert optimizer_and_schedular is not None - optimizers, schedulers = optimizer_and_schedular - - assert len(optimizers) == 1 - optimizer = optimizers[0] - assert isinstance(optimizer, torch.optim.AdamW) - assert optimizer.defaults["lr"] == 1e-05 - assert optimizer.defaults["weight_decay"] == 0.01 - assert optimizer.defaults["eps"] == 1e-08 - - assert len(schedulers) == 1 - scheduler = schedulers[0] - assert isinstance(scheduler["scheduler"], torch.optim.lr_scheduler.LambdaLR) - - -def test_configure_optimizers_with_task_learning_rate(): - model = TokenClassificationModelWithSeq2SeqEncoderAndCrf( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=5, - warmup_proportion=0.0, - task_learning_rate=1e-4, - ) - optimizer = model.configure_optimizers() - assert optimizer is not None - assert isinstance(optimizer, torch.optim.AdamW) - assert len(optimizer.param_groups) == 2 - # check that all parameters are in the optimizer - assert set(optimizer.param_groups[0]["params"]) | set( - optimizer.param_groups[1]["params"] - ) == set(model.parameters()) - - # base model parameters - param_group = optimizer.param_groups[0] - assert param_group["lr"] == 1e-05 - assert len(param_group["params"]) == 39 - - # task parameters - param_group = optimizer.param_groups[1] - assert param_group["lr"] == 1e-04 - assert len(param_group["params"]) == 5 diff --git a/tests/taskmodules/__init__.py b/tests/taskmodules/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/taskmodules/common/__init__.py b/tests/taskmodules/common/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/taskmodules/common/test_interfaces.py b/tests/taskmodules/common/test_interfaces.py deleted file mode 100644 index 5c65c797a..000000000 --- a/tests/taskmodules/common/test_interfaces.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any, Dict, List, Set, Tuple - -from pie_modules.annotations import Span -from pie_modules.taskmodules.common import AnnotationEncoderDecoder - - -def test_annotation_encoder_decoder(): - """Test the AnnotationEncoderDecoder class.""" - - class SpanAnnotationEncoderDecoder(AnnotationEncoderDecoder[Span, Tuple[int, int]]): - """A class that uses the AnnotationEncoderDecoder class.""" - - def encode(self, annotation: Span, **kwargs) -> Tuple[int, int]: - return annotation.start, annotation.end - - def decode(self, encoding: Tuple[int, int], **kwargs) -> Span: - return Span(start=encoding[0], end=encoding[1]) - - def validate_encoding(self, encoding: Tuple[int, int]) -> Set[str]: - return {"order"} if encoding[0] > encoding[1] else set() - - encoder_decoder = SpanAnnotationEncoderDecoder() - - assert encoder_decoder.encode(Span(start=1, end=2)) == (1, 2) - assert encoder_decoder.decode((1, 2)) == Span(start=1, end=2) - assert encoder_decoder.validate_encoding((1, 2)) == set() - assert encoder_decoder.validate_encoding((2, 1)) == {"order"} diff --git a/tests/taskmodules/common/test_mixins.py b/tests/taskmodules/common/test_mixins.py deleted file mode 100644 index b6b921bb2..000000000 --- a/tests/taskmodules/common/test_mixins.py +++ /dev/null @@ -1,166 +0,0 @@ -import dataclasses -import logging -from typing import List - -import torch -from pie_core import Annotation - -from pie_modules.taskmodules.common import BatchableMixin -from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin - - -def test_batchable_mixin(): - """Test the BatchableMixin class.""" - - @dataclasses.dataclass - class Foo(BatchableMixin): - """A class that uses the BatchableMixin class.""" - - a: List[int] - - @property - def len_a(self): - """Return the length of the list a.""" - return len(self.a) - - x = Foo(a=[1, 2, 3]) - y = Foo(a=[4, 5]) - - batch = Foo.batch( - values=[x, y], dtypes={"a": torch.int64, "len_a": torch.int64}, pad_values={"a": 0} - ) - torch.testing.assert_close(batch["a"], torch.tensor([[1, 2, 3], [4, 5, 0]])) - torch.testing.assert_close(batch["len_a"], torch.tensor([3, 2])) - - -def test_relation_statistics_mixin_show_statistics(caplog): - """Test the RelationStatisticsMixin class.""" - - class Foo(RelationStatisticsMixin): - """A class that uses the RelationStatisticsMixin class.""" - - pass - - @dataclasses.dataclass(eq=True, frozen=True) - class TestAnnotation(Annotation): - label: str - score: float = dataclasses.field(default=1.0, compare=False) - - x = Foo(collect_statistics=True) - - relations = [ - TestAnnotation(label="A", score=1.0), - TestAnnotation(label="B", score=0.5), - TestAnnotation(label="C", score=0.0), - TestAnnotation(label="D", score=0.3), - ] - # all available relations - x.collect_all_relations(kind="available", relations=relations) - # relations skipped for a reason ("test") - x.collect_relation(kind="skipped_test", relation=relations[1]) - # mark two relations as used, one of them is skipped for another (unknown) reason - x.collect_all_relations(kind="used", relations=[relations[0], relations[2]]) - - statistics = x.get_statistics() - - assert statistics == { - ("available", "A"): 1, - ("available", "B"): 1, - ("available", "D"): 1, - ("available", "no_relation"): 1, - ("skipped_other", "D"): 1, - ("skipped_test", "B"): 1, - ("used", "A"): 1, - ("used", "no_relation"): 1, - } - - with caplog.at_level(logging.INFO): - x.show_statistics() - assert caplog.messages[0] == ( - "Foo does not have a `none_label` attribute. " - "Using default value 'no_relation'. " - "`none_label` is used as the label for relations with score 0 in statistics and " - "all relations with label different from `none_label` will be summarized to 'all_relations'. " - "Set the `none_label` attribute before using statistics or " - "overwrite `get_none_label_for_statistics()` function to get rid of this message." - ) - assert caplog.messages[1] == ( - "statistics:\n" - "| | available | skipped_other | skipped_test | used | used % |\n" - "|:--------------|------------:|----------------:|---------------:|-------:|---------:|\n" - "| A | 1 | 0 | 0 | 1 | 100 |\n" - "| B | 1 | 0 | 1 | 0 | 0 |\n" - "| D | 1 | 1 | 0 | 0 | 0 |\n" - "| no_relation | 1 | 0 | 0 | 1 | 100 |\n" - "| all_relations | 3 | 1 | 1 | 1 | 33 |" - ) - - -def test_relation_statistics_mixin_show_statistics_no_relations(caplog): - """Test the RelationStatisticsMixin class with no predictions.""" - - class Foo(RelationStatisticsMixin): - """A class that uses the RelationStatisticsMixin class.""" - - pass - - x = Foo(collect_statistics=True) - - # Test with no relations collected - x.collect_all_relations(kind="available", relations=[]) - x.collect_all_relations(kind="used", relations=[]) - - statistics = x.get_statistics() - - assert statistics == {} - - with caplog.at_level(logging.INFO): - x.show_statistics() - assert caplog.messages[0] == "statistics:\n" "|--:|\n" "| 0 |" - - -def test_relation_statistics_mixin_show_statistics_custom_none_label(caplog): - """Test the RelationStatisticsMixin class with custom none_label.""" - - class Foo(RelationStatisticsMixin): - """A class that uses the RelationStatisticsMixin class. - - It also sets the `none_label` attribute which will be used by statistics. - """ - - def __init__(self, none_label: str = "no_relation", **kwargs): - super().__init__(**kwargs) - self.none_label = none_label - - @dataclasses.dataclass(eq=True, frozen=True) - class TestAnnotation(Annotation): - label: str - score: float = dataclasses.field(default=1.0, compare=False) - - x = Foo(collect_statistics=True, none_label="None_Label") - - relations = [ - TestAnnotation(label="A", score=1.0), - TestAnnotation(label="B", score=0.5), - TestAnnotation(label="C", score=0.0), - TestAnnotation(label="D", score=0.3), - ] - # all available relations - x.collect_all_relations(kind="available", relations=relations) - # relations skipped for a reason ("test") - x.collect_relation(kind="skipped_test", relation=relations[1]) - # mark two relations as used, one of them is skipped for another (unknown) reason - x.collect_all_relations(kind="used", relations=[relations[0], relations[2]]) - - with caplog.at_level(logging.INFO): - x.show_statistics() - assert caplog.messages[0] == ( - "statistics:\n" - "| | available | skipped_other | skipped_test | used | used % |\n" - "|:--------------|------------:|----------------:|---------------:|-------:|---------:|\n" - "| A | 1 | 0 | 0 | 1 | 100 |\n" - "| B | 1 | 0 | 1 | 0 | 0 |\n" - "| D | 1 | 1 | 0 | 0 | 0 |\n" - "| None_Label | 1 | 0 | 0 | 1 | 100 |\n" - "| all_relations | 3 | 1 | 1 | 1 | 33 |" - ) diff --git a/tests/taskmodules/common/test_taskmodule_with_document_converter.py b/tests/taskmodules/common/test_taskmodule_with_document_converter.py deleted file mode 100644 index 8540b5a3f..000000000 --- a/tests/taskmodules/common/test_taskmodule_with_document_converter.py +++ /dev/null @@ -1,163 +0,0 @@ -from typing import Optional, Type - -import pytest -from pie_core import Document -from typing_extensions import TypeAlias - -from pie_modules.documents import TextDocumentWithLabeledSpansAndBinaryRelations -from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule -from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter -from tests.conftest import TestDocument - -DocumentType: TypeAlias = TestDocument -ConvertedDocumentType: TypeAlias = TextDocumentWithLabeledSpansAndBinaryRelations - - -class MyRETaskModuleWithDocConverter( - TaskModuleWithDocumentConverter, RETextClassificationWithIndicesTaskModule -): - @property - def document_type(self) -> Optional[Type[Document]]: - return TestDocument - - def _convert_document(self, document: DocumentType) -> ConvertedDocumentType: - result = document.as_type( - TextDocumentWithLabeledSpansAndBinaryRelations, - field_mapping={"entities": "labeled_spans", "relations": "binary_relations"}, - ) - new2old_span = { - new_s: old_s for old_s, new_s in zip(document.entities, result.labeled_spans) - } - result.metadata["new2old_span"] = new2old_span - return result - - def _integrate_predictions_from_converted_document( - self, document: DocumentType, converted_document: ConvertedDocumentType - ) -> None: - new2old_span = converted_document.metadata["new2old_span"] - for rel in converted_document.binary_relations.predictions: - new_rel = rel.copy(head=new2old_span[rel.head], tail=new2old_span[rel.tail]) - document.relations.predictions.append(new_rel) - - -@pytest.fixture(scope="module") -def taskmodule(documents): - result = MyRETaskModuleWithDocConverter(tokenizer_name_or_path="bert-base-cased") - result.prepare(documents) - return result - - -def test_taskmodule(taskmodule): - assert taskmodule is not None - assert taskmodule.document_type == TestDocument - - -@pytest.fixture(scope="module") -def task_encodings(taskmodule, documents): - return taskmodule.encode(documents, encode_target=True) - - -def test_task_encodings(task_encodings): - assert len(task_encodings) == 7 - - -def test_decode(taskmodule, task_encodings): - label_indices = [0, 1, 3, 0, 0, 2, 0] - probabilities = [0.1738, 0.6643, 0.2101, 0.0801, 0.0319, 0.81, 0.3079] - task_outputs = [ - {"labels": [taskmodule.id_to_label[label_idx]], "probabilities": [prob]} - for label_idx, prob in zip(label_indices, probabilities) - ] - docs_with_predictions = taskmodule.decode( - task_encodings=task_encodings, task_outputs=task_outputs - ) - assert all(isinstance(doc, TestDocument) for doc in docs_with_predictions) - - all_gold_relations = [doc.relations.resolve() for doc in docs_with_predictions] - assert all_gold_relations == [ - [("per:employee_of", (("PER", "Entity A"), ("ORG", "B")))], - [ - ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), - ("per:founder", (("PER", "Entity G"), ("ORG", "I"))), - ("org:founded_by", (("ORG", "I"), ("ORG", "H"))), - ], - [ - ("per:employee_of", (("PER", "Entity M"), ("ORG", "N"))), - ("per:founder", (("PER", "it"), ("ORG", "O"))), - ("org:founded_by", (("ORG", "O"), ("PER", "it"))), - ], - ] - - all_predicted_relations = [ - doc.relations.predictions.resolve() for doc in docs_with_predictions - ] - assert all_predicted_relations == [ - [("no_relation", (("PER", "Entity A"), ("ORG", "B")))], - [ - ("org:founded_by", (("PER", "Entity G"), ("ORG", "H"))), - ("per:founder", (("PER", "Entity G"), ("ORG", "I"))), - ("no_relation", (("ORG", "I"), ("ORG", "H"))), - ], - [ - ("no_relation", (("PER", "Entity M"), ("ORG", "N"))), - ("per:employee_of", (("PER", "it"), ("ORG", "O"))), - ("no_relation", (("ORG", "O"), ("PER", "it"))), - ], - ] - - -class MyRETaskModuleWithDocConverterWithoutDocType( - TaskModuleWithDocumentConverter, RETextClassificationWithIndicesTaskModule -): - def _convert_document(self, document: DocumentType) -> ConvertedDocumentType: - pass - - def _integrate_predictions_from_converted_document( - self, document: DocumentType, converted_document: ConvertedDocumentType - ) -> None: - pass - - -def test_missing_document_type_overwrite(): - taskmodule = MyRETaskModuleWithDocConverterWithoutDocType( - tokenizer_name_or_path="bert-base-cased" - ) - - with pytest.raises(NotImplementedError) as e: - taskmodule.document_type - assert ( - str(e.value) - == "please overwrite document_type for MyRETaskModuleWithDocConverterWithoutDocType" - ) - - -class MyRETaskModuleWithWrongDocConverter( - TaskModuleWithDocumentConverter, RETextClassificationWithIndicesTaskModule -): - @property - def document_type(self) -> Optional[Type[Document]]: - return TestDocument - - def _convert_document(self, document: DocumentType) -> ConvertedDocumentType: - result = TextDocumentWithLabeledSpansAndBinaryRelations(text="dummy") - result.metadata["original_document"] = None - return result - - def _integrate_predictions_from_converted_document( - self, document: DocumentType, converted_document: ConvertedDocumentType - ) -> None: - pass - - -def test_wrong_doc_converter(documents): - taskmodule = MyRETaskModuleWithWrongDocConverter(tokenizer_name_or_path="bert-base-cased") - taskmodule.prepare(documents) - with pytest.raises(ValueError) as e: - taskmodule.encode(documents, encode_target=True) - assert ( - str(e.value) - == "metadata of converted_document has already and entry 'original_document', " - "this is not allowed. Please adjust " - "'MyRETaskModuleWithWrongDocConverter._convert_document()' to produce " - "documents without that key in metadata." - ) diff --git a/tests/taskmodules/common/test_utils.py b/tests/taskmodules/common/test_utils.py deleted file mode 100644 index b9a426402..000000000 --- a/tests/taskmodules/common/test_utils.py +++ /dev/null @@ -1,17 +0,0 @@ -import torch - -from pie_modules.taskmodules.common.utils import get_first_occurrence_index - - -def test_get_first_occurrence_index(): - tensor: torch.LongTensor = torch.tensor( - [ - [0, 1, 1, 1, 1, 1], # 1 - [0, 0, 1, 1, 1, 1], # 2 - [0, 1, 1, 0, 0, 1], # 1 - [1, 1, 1, 1, 1, 1], # 0 - [0, 0, 0, 0, 0, 0], # 6 (=size of input) because no 1s at all - ] - ).to(torch.long) - indices = get_first_occurrence_index(tensor, 1) - torch.testing.assert_close(indices, torch.tensor([1, 2, 1, 0, 6])) diff --git a/tests/taskmodules/metrics/__init__.py b/tests/taskmodules/metrics/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/taskmodules/metrics/test_precision_recall_and_f1_for_labeled_annotations.py b/tests/taskmodules/metrics/test_precision_recall_and_f1_for_labeled_annotations.py deleted file mode 100644 index 90cc93e59..000000000 --- a/tests/taskmodules/metrics/test_precision_recall_and_f1_for_labeled_annotations.py +++ /dev/null @@ -1,151 +0,0 @@ -import pytest -from torch import tensor - -from pie_modules.annotations import LabeledSpan -from pie_modules.taskmodules.metrics import PrecisionRecallAndF1ForLabeledAnnotations - - -def test_precision_recall_and_f1_for_labeled_annotations(): - metric = PrecisionRecallAndF1ForLabeledAnnotations() - assert metric.metric_state == {} - - metric.update( - gold=[LabeledSpan(start=0, end=1, label="a")], - predicted=[LabeledSpan(start=0, end=1, label="a")], - ) - metric_state = {k: v.tolist() for k, v in metric.metric_state.items()} - assert metric_state == {"counts_a": [1, 1, 1], "counts_micro": [1, 1, 1]} - value = metric.compute() - assert value == { - "a": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - "macro": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - "micro": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - } - - metric.reset() - metric.update( - gold=[LabeledSpan(start=0, end=1, label="a"), LabeledSpan(start=0, end=1, label="b")], - predicted=[LabeledSpan(start=0, end=1, label="b"), LabeledSpan(start=0, end=1, label="c")], - ) - metric_state = {k: v.tolist() for k, v in metric.metric_state.items()} - assert metric_state == { - "counts_a": [1, 0, 0], - "counts_b": [1, 1, 1], - "counts_c": [0, 1, 0], - "counts_micro": [2, 2, 1], - } - assert metric.compute() == { - "b": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - "a": {"recall": 0.0, "precision": 0.0, "f1": 0.0}, - "c": {"recall": 0.0, "precision": 0.0, "f1": 0.0}, - "macro": { - "f1": tensor(0.3333333432674408), - "precision": tensor(0.3333333432674408), - "recall": tensor(0.3333333432674408), - }, - "micro": {"recall": 0.5, "precision": 0.5, "f1": 0.5}, - } - - # check deduplication in same update - metric.reset() - metric.update( - gold=[ - LabeledSpan(start=0, end=1, label="a"), - LabeledSpan(start=0, end=1, label="a"), - LabeledSpan(start=0, end=1, label="b"), - ], - predicted=[ - LabeledSpan(start=0, end=1, label="b"), - LabeledSpan(start=0, end=1, label="b"), - LabeledSpan(start=0, end=1, label="c"), - ], - ) - metric_state = {k: v.tolist() for k, v in metric.metric_state.items()} - assert metric_state == { - "counts_a": [1, 0, 0], - "counts_b": [1, 1, 1], - "counts_c": [0, 1, 0], - "counts_micro": [2, 2, 1], - } - - # assert no deduplication over multiple updates - metric.reset() - metric.update( - gold=[LabeledSpan(start=0, end=1, label="a")], - predicted=[LabeledSpan(start=0, end=1, label="b")], - ) - metric.update( - gold=[LabeledSpan(start=0, end=1, label="b")], - predicted=[LabeledSpan(start=0, end=1, label="a")], - ) - metric_state = {k: v.tolist() for k, v in metric.metric_state.items()} - assert metric_state == { - "counts_a": [1, 1, 0], - "counts_b": [1, 1, 0], - "counts_c": [0, 0, 0], - "counts_micro": [2, 2, 0], - } - assert metric.compute() == { - "a": {"f1": 0.0, "precision": 0.0, "recall": 0.0}, - "b": {"f1": 0.0, "precision": 0.0, "recall": 0.0}, - "c": {"f1": 0.0, "precision": 0.0, "recall": 0.0}, - "macro": {"f1": 0.0, "precision": 0.0, "recall": 0.0}, - "micro": {"f1": 0.0, "precision": 0.0, "recall": 0.0}, - } - - -def test_precision_recall_and_f1_for_labeled_annotations_in_percent(): - metric = PrecisionRecallAndF1ForLabeledAnnotations( - in_percent=True, flatten_result_with_sep="/" - ) - - metric.update( - gold=[LabeledSpan(start=0, end=1, label="a")], - predicted=[LabeledSpan(start=0, end=1, label="a"), LabeledSpan(start=0, end=1, label="b")], - ) - values = {k: v.item() for k, v in metric.compute().items()} - assert values == { - "a/f1": 100.0, - "a/precision": 100.0, - "a/recall": 100.0, - "b/f1": 0.0, - "b/precision": 0.0, - "b/recall": 0.0, - "macro/f1": 50.0, - "macro/precision": 50.0, - "macro/recall": 50.0, - "micro/f1": 66.66667175292969, - "micro/precision": 50.0, - "micro/recall": 100.0, - } - - -def test_precision_recall_and_f1_for_labeled_annotations_with_label_mapping(): - metric = PrecisionRecallAndF1ForLabeledAnnotations( - label_mapping={"a": "label_a", "b": "label_b"} - ) - - metric.update( - gold=[LabeledSpan(start=0, end=1, label="a")], - predicted=[LabeledSpan(start=0, end=1, label="a"), LabeledSpan(start=0, end=1, label="b")], - ) - assert metric.compute() == { - "label_a": {"f1": 1.0, "precision": 1.0, "recall": 1.0}, - "label_b": {"f1": 0.0, "precision": 0.0, "recall": 0.0}, - "macro": {"f1": 0.5, "precision": 0.5, "recall": 0.5}, - "micro": {"f1": 0.6666666666666666, "precision": 0.5, "recall": 1.0}, - } - - -def test_precision_recall_and_f1_for_labeled_annotations_key_micro_error(): - metric = PrecisionRecallAndF1ForLabeledAnnotations() - with pytest.raises(ValueError) as excinfo: - metric.update( - gold=[LabeledSpan(start=0, end=1, label="micro")], - predicted=[], - ) - assert ( - str(excinfo.value) - == "The key 'micro' was used as an annotation label, but it is reserved for the micro average. " - "You can change which key is used for that with the 'key_micro' argument." - ) diff --git a/tests/taskmodules/metrics/test_wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py b/tests/taskmodules/metrics/test_wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py deleted file mode 100644 index b4e34ceb7..000000000 --- a/tests/taskmodules/metrics/test_wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py +++ /dev/null @@ -1,144 +0,0 @@ -import json -from typing import Any, Dict, Tuple - -import pytest -from torch import tensor -from torchmetrics import Metric - -from pie_modules.taskmodules.metrics import ( - WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction, -) - - -class TestMetric(Metric): - """A simple metric that computes the exact match ratio between predictions and targets.""" - - def __init__(self): - super().__init__() - self.add_state("matching", default=[]) - - def update(self, prediction: str, target: str): - self.matching.append(prediction == target) - - def compute(self): - # Note: returning NaN in the case of an empty list would be more correct, but - # returning 0.0 is more convenient for testing. - return sum(self.matching) / len(self.matching) if self.matching else 0.0 - - -@pytest.fixture(scope="module") -def wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function(): - def decode_with_errors_function(x: str) -> Tuple[Dict[str, Any], Dict[str, int]]: - if x == "error": - return {"entities": [], "relations": []}, {"dummy": 1} - else: - return json.loads(x), {"dummy": 0} - - layer_metrics = { - "entities": TestMetric(), - "relations": TestMetric(), - } - metric = WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction( - layer_metrics=layer_metrics, - unbatch_function=lambda x: x.split("\n"), - decode_layers_with_errors_function=decode_with_errors_function, - ) - return metric - - -def test_wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function( - wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function, -): - metric = wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function - assert metric is not None - assert metric.unbatch_function is not None - assert metric.decode_layers_with_errors_function is not None - assert metric.layer_metrics is not None - assert metric.metric_state == { - "total": tensor(0), - "exact_encoding_matches": tensor(0), - } - - values = metric.compute() - assert metric.metric_state - assert values == { - "decoding_errors": {"all": 0.0}, - "entities": 0.0, - "exact_encoding_matches": 0.0, - "relations": 0.0, - } - - metric.reset() - # Prediction and expected are the same. - metric.update( - prediction=json.dumps({"entities": ["E1"], "relations": ["R1"]}), - expected=json.dumps({"entities": ["E1"], "relations": ["R1"]}), - ) - assert metric.metric_state == { - "total": tensor(1), - "exact_encoding_matches": tensor(1), - "errors_dummy": tensor(0), - } - values = metric.compute() - assert values == { - "decoding_errors": {"all": 0.0, "dummy": 0.0}, - "entities": 1.0, - "exact_encoding_matches": 1.0, - "relations": 1.0, - } - - metric.reset() - # Prediction and expected are different and there are multiple entries. - # The first entry is an exact match, the second entry is not. - metric.update( - prediction=json.dumps({"entities": ["E1"], "relations": ["R1"]}) - + "\n" - + json.dumps({"entities": ["E1"], "relations": ["R1"]}), - expected=json.dumps({"entities": ["E1"], "relations": ["R1"]}) - + "\n" - + json.dumps({"entities": ["E1"], "relations": ["R2"]}), - ) - assert metric.metric_state == { - "total": tensor(2), - "exact_encoding_matches": tensor(1), - "errors_dummy": tensor(0), - } - values = metric.compute() - assert values == { - "decoding_errors": {"all": 0.0, "dummy": 0.0}, - "entities": 1.0, - "exact_encoding_matches": 0.5, - "relations": 0.5, - } - - metric.reset() - # Encoding error - metric.update( - prediction="error", - expected=json.dumps({"entities": ["E1"], "relations": []}), - ) - assert metric.metric_state == { - "total": tensor(1), - "exact_encoding_matches": tensor(0), - "errors_dummy": tensor(1), - } - values = metric.compute() - # In the case on an error, the decoding function returns adict with empty lists for entities and relations. - # Thus, we get a perfect match for entities and a 0.0 match for relations. - assert values == { - "decoding_errors": {"all": 1.0, "dummy": 1.0}, - "entities": 0.0, - "exact_encoding_matches": 0.0, - "relations": 1.0, - } - - # test mismatched number of predictions and targets - metric.reset() - with pytest.raises(ValueError) as excinfo: - metric.update( - prediction=json.dumps({"entities": ["E1"], "relations": ["R1"]}), - expected=json.dumps({"entities": ["E1"], "relations": ["R1"]}) - + "\n" - + json.dumps({"entities": ["E1"], "relations": ["R1"]}), - ) - assert str(excinfo.value) == "Number of predictions (1) and targets (2) do not match." diff --git a/tests/taskmodules/metrics/test_wrapped_metric_with_prepare_function.py b/tests/taskmodules/metrics/test_wrapped_metric_with_prepare_function.py deleted file mode 100644 index e06130b09..000000000 --- a/tests/taskmodules/metrics/test_wrapped_metric_with_prepare_function.py +++ /dev/null @@ -1,137 +0,0 @@ -from functools import partial - -import pytest -from torchmetrics import Metric - -from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction - - -class TestMetric(Metric): - """A simple metric that computes the exact match ratio between predictions and targets.""" - - def __init__(self): - super().__init__() - self.add_state("matching", default=[]) - - def update(self, prediction: str, target: str): - self.matching.append(prediction == target) - - def compute(self): - # Note: returning NaN in the case of an empty list would be more correct, but - # returning 0.0 is more convenient for testing. - return sum(self.matching) / len(self.matching) if self.matching else 0.0 - - -def test_metric(): - metric = WrappedMetricWithPrepareFunction( - metric=TestMetric(), prepare_function=lambda x: x.split()[0] - ) - - assert metric is not None - assert metric.prepare_function is not None - - assert metric.compute() == 0.0 - - metric.reset() - metric(prediction="abc", target="abc") - assert metric.compute() == 1.0 - - metric.reset() - metric(prediction="abc", target="def") - assert metric.compute() == 0.0 - - metric.reset() - metric(prediction="abc def", target="abc xyz") - # we consider just the first word, so this is still 1.0 - assert metric.compute() == 1.0 - - metric.reset() - metric(prediction="abc def", target="xyz def") - assert metric.compute() == 0.0 - - -def split_both_and_remove_where_both_match( - preds: str, targets: str, match: str -) -> tuple[list[str], list[str]]: - preds = preds.split() - targets = targets.split() - not_both_none_indices = [ - i for i, (p, t) in enumerate(zip(preds, targets)) if p != match or t != match - ] - preds = [preds[i] for i in not_both_none_indices] - targets = [targets[i] for i in not_both_none_indices] - return preds, targets - - -def test_wrapped_metric_with_prepare_both_function(): - metric = WrappedMetricWithPrepareFunction( - metric=TestMetric(), - prepare_together_function=partial(split_both_and_remove_where_both_match, match="none"), - prepare_does_unbatch=True, - ) - - assert metric is not None - assert metric.prepare_both_function is not None - - assert metric.compute() == 0.0 - - # none is removed from both, remaining is the same - metric.reset() - metric(prediction="abc none", target="abc none") - assert metric.compute() == 1.0 - - # none is removed from both, remaining is different - metric.reset() - metric(prediction="abc none", target="def none") - assert metric.compute() == 0.0 - - # none is not removed from both, remaining is partially the same - metric.reset() - metric(prediction="abc def", target="abc none") - assert metric.compute() == 0.5 - - # none is not removed from both, remaining is different - metric.reset() - metric(prediction="abc def", target="def none") - assert metric.compute() == 0.0 - - -@pytest.fixture(scope="module") -def wrapped_metric_with_unbatch_function(): - # just split the strings to unbatch the inputs - return WrappedMetricWithPrepareFunction( - metric=TestMetric(), prepare_function=lambda x: x.split(), prepare_does_unbatch=True - ) - - -def test_wrapped_metric_with_unbatch_function(wrapped_metric_with_unbatch_function): - metric = wrapped_metric_with_unbatch_function - assert metric is not None - - assert metric.compute() == 0.0 - - metric.reset() - metric(prediction="abc", target="abc") - assert metric.compute() == 1.0 - - metric.reset() - metric(prediction="abc", target="def") - assert metric.compute() == 0.0 - - metric.reset() - metric(prediction="abc def", target="abc def") - assert metric.compute() == 1.0 - - metric.reset() - metric(prediction="abc def", target="def abc") - assert metric.compute() == 0.0 - - metric.reset() - metric(prediction="abc xyz", target="def xyz") - assert metric.compute() == 0.5 - - -def test_wrapped_metric_with_unbatch_function_size_mismatch(wrapped_metric_with_unbatch_function): - with pytest.raises(ValueError) as excinfo: - wrapped_metric_with_unbatch_function(prediction="abc", target="abc def") - assert str(excinfo.value) == "Number of prepared predictions (1) and targets (2) do not match." diff --git a/tests/taskmodules/pointer_network/__init__.py b/tests/taskmodules/pointer_network/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/taskmodules/pointer_network/test_annotation_encoder_decoder.py b/tests/taskmodules/pointer_network/test_annotation_encoder_decoder.py deleted file mode 100644 index 4dbeaab8a..000000000 --- a/tests/taskmodules/pointer_network/test_annotation_encoder_decoder.py +++ /dev/null @@ -1,441 +0,0 @@ -import pytest - -from pie_modules.annotations import BinaryRelation, LabeledSpan, Span -from pie_modules.taskmodules.pointer_network.annotation_encoder_decoder import ( - BinaryRelationEncoderDecoder, - DecodingLabelException, - DecodingLengthException, - DecodingNegativeIndexException, - DecodingOrderException, - LabeledSpanEncoderDecoder, - SpanEncoderDecoder, - SpanEncoderDecoderWithOffset, -) - - -@pytest.mark.parametrize("exclusive_end", [True, False]) -def test_span_encoder_decoder(exclusive_end): - """Test the SimpleSpanEncoderDecoder class.""" - - encoder_decoder = SpanEncoderDecoder(exclusive_end) - if exclusive_end: - assert encoder_decoder.encode(Span(start=1, end=2)) == [1, 2] - assert encoder_decoder.decode([1, 2]) == Span(start=1, end=2) - else: - assert encoder_decoder.encode(Span(start=1, end=2)) == [1, 1] - assert encoder_decoder.decode([1, 1]) == Span(start=1, end=2) - - -def test_span_encoder_decoder_wrong_length(): - """Test the SimpleSpanEncoderDecoder class.""" - - encoder_decoder = SpanEncoderDecoder() - with pytest.raises(DecodingLengthException) as excinfo: - encoder_decoder.decode([1]) - assert ( - str(excinfo.value) - == "two values are required to decode as Span, but encoding has length 1" - ) - assert excinfo.value.identifier == "len" - - with pytest.raises(DecodingLengthException) as excinfo: - encoder_decoder.decode([1, 2, 3]) - assert ( - str(excinfo.value) - == "two values are required to decode as Span, but encoding has length 3" - ) - assert excinfo.value.identifier == "len" - - -def test_span_encoder_decoder_wrong_order(): - """Test the SimpleSpanEncoderDecoder class.""" - - encoder_decoder = SpanEncoderDecoder() - - with pytest.raises(DecodingOrderException) as excinfo: - encoder_decoder.decode([3, 2]) - assert ( - str(excinfo.value) - == "end index can not be smaller than start index, but got: start=3, end=2" - ) - assert excinfo.value.identifier == "order" - - # zero-length span - span = encoder_decoder.decode([1, 1]) - assert span is not None - - -def test_span_encoder_decoder_wrong_offset(): - """Test the SimpleSpanEncoderDecoder class.""" - - encoder_decoder = SpanEncoderDecoder() - - with pytest.raises(DecodingNegativeIndexException) as excinfo: - encoder_decoder.decode([-1, 2]) - assert str(excinfo.value) == "indices must be positive, but got: [-1, 2]" - assert excinfo.value.identifier == "index" - - -def test_span_encoder_decoder_with_offset(): - """Test the SpanEncoderDecoderWithOffset class.""" - - encoder_decoder = SpanEncoderDecoderWithOffset(offset=1) - - assert encoder_decoder.encode(Span(start=1, end=2)) == [2, 3] - assert encoder_decoder.decode([2, 3]) == Span(start=1, end=2) - - -@pytest.mark.parametrize("mode", ["indices_label", "label_indices"]) -def test_labeled_span_encoder_decoder(mode): - """Test the LabeledSpanEncoderDecoder class.""" - - label2id = {"A": 0, "B": 1} - encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), - label2id=label2id, - mode=mode, - ) - - if mode == "indices_label": - assert encoder_decoder.encode(LabeledSpan(start=1, end=2, label="A")) == [3, 4, 0] - assert encoder_decoder.decode([3, 4, 0]) == LabeledSpan(start=1, end=2, label="A") - elif mode == "label_indices": - assert encoder_decoder.encode(LabeledSpan(start=1, end=2, label="A")) == [0, 3, 4] - assert encoder_decoder.decode([0, 3, 4]) == LabeledSpan(start=1, end=2, label="A") - else: - raise ValueError(f"unknown mode: {mode}") - - -@pytest.mark.parametrize("mode", ["indices_label", "label_indices"]) -def test_labeled_span_encoder_decoder_wrong_label_encoding(mode): - """Test the LabeledSpanEncoderDecoder class.""" - - label2id = {"A": 0, "B": 1} - encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), - label2id=label2id, - mode=mode, - ) - - if mode == "indices_label": - with pytest.raises(DecodingLabelException) as excinfo: - encoder_decoder.decode([2, 3, 4]) - elif mode == "label_indices": - with pytest.raises(DecodingLabelException) as excinfo: - encoder_decoder.decode([4, 2, 3]) - assert str(excinfo.value) == "unknown label id: 4 (label2id: {'A': 0, 'B': 1})" - assert excinfo.value.identifier == "label" - - -def test_labeled_span_encoder_decoder_unknown_mode(): - """Test the LabeledSpanEncoderDecoder class.""" - - label2id = {"A": 0, "B": 1} - encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), - label2id=label2id, - mode="unknown", - ) - with pytest.raises(ValueError) as excinfo: - encoder_decoder.encode(LabeledSpan(start=1, end=2, label="A")) - assert str(excinfo.value) == "unknown mode: unknown" - - with pytest.raises(ValueError) as excinfo: - encoder_decoder.decode([0, 3, 4]) - assert str(excinfo.value) == "unknown mode: unknown" - - -@pytest.mark.parametrize( - "mode", ["head_tail_label", "tail_head_label", "label_head_tail", "label_tail_head"] -) -def test_binary_relation_encoder_decoder(mode): - """Test the BinaryRelationEncoderDecoder class.""" - - label2id = {"A": 0, "B": 1, "C": 2} - labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), - label2id=label2id, - mode="indices_label", - ) - encoder_decoder = BinaryRelationEncoderDecoder( - head_encoder_decoder=labeled_span_encoder_decoder, - tail_encoder_decoder=labeled_span_encoder_decoder, - label2id=label2id, - mode=mode, - ) - - if mode == "head_tail_label": - assert encoder_decoder.encode( - BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=3, end=4, label="B"), - label="C", - ) - ) == [4, 5, 0, 6, 7, 1, 2] - assert encoder_decoder.decode([4, 5, 0, 6, 7, 1, 2]) == BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=3, end=4, label="B"), - label="C", - ) - elif mode == "tail_head_label": - assert encoder_decoder.encode( - BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=3, end=4, label="B"), - label="C", - ) - ) == [6, 7, 1, 4, 5, 0, 2] - assert encoder_decoder.decode([6, 7, 1, 4, 5, 0, 2]) == BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=3, end=4, label="B"), - label="C", - ) - elif mode == "label_head_tail": - assert encoder_decoder.encode( - BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=3, end=4, label="B"), - label="C", - ) - ) == [2, 4, 5, 0, 6, 7, 1] - assert encoder_decoder.decode([2, 4, 5, 0, 6, 7, 1]) == BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=3, end=4, label="B"), - label="C", - ) - elif mode == "label_tail_head": - assert encoder_decoder.encode( - BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=3, end=4, label="B"), - label="C", - ) - ) == [2, 6, 7, 1, 4, 5, 0] - assert encoder_decoder.decode([2, 6, 7, 1, 4, 5, 0]) == BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=3, end=4, label="B"), - label="C", - ) - - -@pytest.mark.parametrize( - "mode", ["head_tail_label", "tail_head_label", "label_head_tail", "label_tail_head"] -) -def test_binary_relation_encoder_decoder_loop_relation(mode): - """Test the BinaryRelationEncoderDecoder class.""" - - # we use different label2id for head and tail to test the case where the head and tail - # have different label sets - head_encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=3), - label2id={"A": 1, "B": 2}, - mode="indices_label", - ) - tail_encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=3), - label2id={"A": -1, "B": -2}, - mode="indices_label", - ) - encoder_decoder = BinaryRelationEncoderDecoder( - head_encoder_decoder=head_encoder_decoder, - tail_encoder_decoder=tail_encoder_decoder, - label2id={"N": 3}, - mode=mode, - loop_dummy_relation_name="L", - none_label="N", - ) - - if mode == "head_tail_label": - assert encoder_decoder.encode( - BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=1, end=2, label="A"), - label="L", - ) - ) == [4, 5, 1, 3, 3, 3, 3] - assert encoder_decoder.decode([4, 5, 1, 3, 3, 3, 3]) == BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=1, end=2, label="A"), - label="L", - ) - elif mode == "tail_head_label": - assert encoder_decoder.encode( - BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=1, end=2, label="A"), - label="L", - ) - ) == [4, 5, -1, 3, 3, 3, 3] - assert encoder_decoder.decode([4, 5, -1, 3, 3, 3, 3]) == BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=1, end=2, label="A"), - label="L", - ) - elif mode == "label_head_tail": - assert encoder_decoder.encode( - BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=1, end=2, label="A"), - label="L", - ) - ) == [3, 4, 5, 1, 3, 3, 3] - assert encoder_decoder.decode([3, 4, 5, 1, 3, 3, 3]) == BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=1, end=2, label="A"), - label="L", - ) - elif mode == "label_tail_head": - assert encoder_decoder.encode( - BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=1, end=2, label="A"), - label="L", - ) - ) == [3, 4, 5, -1, 3, 3, 3] - assert encoder_decoder.decode([3, 4, 5, -1, 3, 3, 3]) == BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=1, end=2, label="A"), - label="L", - ) - else: - raise ValueError(f"unknown mode: {mode}") - - -@pytest.mark.parametrize( - "loop_dummy_relation_name,none_label", - [("L", None), (None, "N")], -) -def test_binary_relation_encoder_decoder_only_loop_or_none_label_provided( - loop_dummy_relation_name, none_label -): - """Test the BinaryRelationEncoderDecoder class.""" - - label2id = {"A": 0, "B": 1, "N": 2} - labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), - label2id=label2id, - mode="indices_label", - ) - encoder_decoder = BinaryRelationEncoderDecoder( - head_encoder_decoder=labeled_span_encoder_decoder, - tail_encoder_decoder=labeled_span_encoder_decoder, - label2id=label2id, - mode="head_tail_label", - loop_dummy_relation_name=loop_dummy_relation_name, - none_label=none_label, - ) - - if loop_dummy_relation_name is not None: - with pytest.raises(ValueError) as excinfo: - encoder_decoder.encode( - BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=1, end=2, label="A"), - label=loop_dummy_relation_name, - ) - ) - - assert ( - str(excinfo.value) - == "loop_dummy_relation_name is set, but none_label is not set: None" - ) - elif none_label is not None: - none_id = label2id[none_label] - with pytest.raises(ValueError) as excinfo: - encoder_decoder.decode([4, 5, 1, none_id, none_id, none_id, none_id]) - assert ( - str(excinfo.value) - == "loop_dummy_relation_name is not set, but none_label=N was found in decoded encoding: " - "[4, 5, 1, 2, 2, 2, 2] (label2id: {'A': 0, 'B': 1, 'N': 2}))" - ) - else: - raise ValueError("unknown setting") - - -@pytest.mark.parametrize( - "loop_dummy_relation_name,none_label", - [(None, None), ("L", "N")], -) -def test_binary_relation_encoder_decoder_unknown_mode(loop_dummy_relation_name, none_label): - """Test the BinaryRelationEncoderDecoder class.""" - - label2id = {"A": 0, "B": 1, "N": 2, "L": 3} - labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), - label2id=label2id, - mode="indices_label", - ) - encoder_decoder = BinaryRelationEncoderDecoder( - head_encoder_decoder=labeled_span_encoder_decoder, - tail_encoder_decoder=labeled_span_encoder_decoder, - label2id=label2id, - mode="unknown", - loop_dummy_relation_name=loop_dummy_relation_name, - none_label=none_label, - ) - with pytest.raises(ValueError) as excinfo: - encoder_decoder.encode( - BinaryRelation( - head=LabeledSpan(start=1, end=2, label="A"), - tail=LabeledSpan(start=1, end=2, label="A"), - label="L", - ) - ) - assert str(excinfo.value) == "unknown mode: unknown" - - with pytest.raises(ValueError) as excinfo: - encoder_decoder.decode([2, 2, 2, 2, 2, 2, 2]) - assert str(excinfo.value) == "unknown mode: unknown" - - -def test_binary_relation_encoder_decoder_wrong_encoding_size(): - """Test the BinaryRelationEncoderDecoder class.""" - - label2id = {"A": 0, "B": 1, "C": 2} - labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), - label2id=label2id, - mode="indices_label", - ) - encoder_decoder = BinaryRelationEncoderDecoder( - head_encoder_decoder=labeled_span_encoder_decoder, - tail_encoder_decoder=labeled_span_encoder_decoder, - label2id=label2id, - mode="head_tail_label", - ) - with pytest.raises(DecodingLengthException) as excinfo: - encoder_decoder.decode([1, 2, 3, 4, 5, 6]) - assert ( - str(excinfo.value) - == "seven values are required to decode as BinaryRelation, but the encoding has length 6" - ) - assert excinfo.value.identifier == "len" - - with pytest.raises(DecodingLengthException) as excinfo: - encoder_decoder.decode([1, 2, 3, 4, 5, 6, 7, 8]) - assert ( - str(excinfo.value) - == "seven values are required to decode as BinaryRelation, but the encoding has length 8" - ) - assert excinfo.value.identifier == "len" - - -def test_binary_relation_encoder_decoder_wrong_label_index(): - """Test the BinaryRelationEncoderDecoder class.""" - - label2id = {"A": 0, "B": 1, "C": 2} - labeled_span_encoder_decoder = LabeledSpanEncoderDecoder( - span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)), - label2id=label2id, - mode="indices_label", - ) - encoder_decoder = BinaryRelationEncoderDecoder( - head_encoder_decoder=labeled_span_encoder_decoder, - tail_encoder_decoder=labeled_span_encoder_decoder, - label2id=label2id, - mode="head_tail_label", - ) - with pytest.raises(DecodingLabelException) as excinfo: - encoder_decoder.decode([1, 2, 3, 4, 5, 6, 7]) - assert str(excinfo.value) == "unknown label id: 7 (label2id: {'A': 0, 'B': 1, 'C': 2})" - assert excinfo.value.identifier == "label" diff --git a/tests/taskmodules/pointer_network/test_logits_processor.py b/tests/taskmodules/pointer_network/test_logits_processor.py deleted file mode 100644 index 145fa1e4c..000000000 --- a/tests/taskmodules/pointer_network/test_logits_processor.py +++ /dev/null @@ -1,80 +0,0 @@ -import pytest -import torch - -from pie_modules.taskmodules.pointer_network.logits_processor import ( - FinitizeLogitsProcessor, - PrefixConstrainedLogitsProcessorWithMaximum, -) - - -def test_prefix_constrained_logits_processor_with_maximum(): - def allow_last_three(batch_id, sent, max_index): - return list(range(max_index - 3, max_index)) - - logits_processor = PrefixConstrainedLogitsProcessorWithMaximum( - prefix_allowed_tokens_fn=allow_last_three, num_beams=1 - ) - - input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7]]).to(dtype=torch.long) - scores = torch.tensor([[0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0]]).to(dtype=torch.float) - new_scores = logits_processor(input_ids, scores) - assert new_scores.shape == scores.shape - torch.testing.assert_close( - new_scores, - torch.tensor( - [[-float("inf"), -float("inf"), -float("inf"), -float("inf"), 0.9, 0.9, 0.0]] - ), - ) - - -def test_prefix_constrained_logits_processor_with_maximum_with_inf_scores(): - def allow_last_three(batch_id, sent, max_index): - return list(range(max_index - 3, max_index)) - - logits_processor = PrefixConstrainedLogitsProcessorWithMaximum( - prefix_allowed_tokens_fn=allow_last_three, num_beams=1 - ) - input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7]]).to(dtype=torch.long) - scores_with_pos_inf = torch.tensor([[0.9, 0.9, float("inf"), 0.9, 0.9, 0.9, 0.0]]).to( - dtype=torch.float - ) - scores_with_neg_inf = torch.tensor([[0.9, 0.9, -float("inf"), 0.9, 0.9, 0.9, 0.0]]).to( - dtype=torch.float - ) - - with pytest.raises(ValueError, match="scores contains ±inf or NaN"): - logits_processor(input_ids, scores_with_pos_inf) - - with pytest.raises(ValueError, match="scores contains ±inf or NaN"): - logits_processor(input_ids, scores_with_neg_inf) - - -def test_prefix_constrained_logits_processor_with_maximum_without_allowed_tokens(): - def allow_no_tokens(batch_id, sent, max_index): - return [] - - logits_processor = PrefixConstrainedLogitsProcessorWithMaximum( - prefix_allowed_tokens_fn=allow_no_tokens, num_beams=1 - ) - - input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7]]).to(dtype=torch.long) - scores = torch.tensor([[0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0]]).to(dtype=torch.float) - - with pytest.raises(ValueError, match="No allowed token ids for batch_id"): - logits_processor(input_ids, scores) - - -def test_finitize_logits_processor(): - logits_processor = FinitizeLogitsProcessor() - - input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7]]).to(dtype=torch.long) - scores = torch.tensor([[0.9, 0.9, float("inf"), 0.9, 0.9, -float("inf"), 0.0]]).to( - dtype=torch.float - ) - new_scores = logits_processor(input_ids, scores) - - assert new_scores.shape == scores.shape - torch.testing.assert_close( - new_scores, - torch.tensor([[0.9, 0.9, 3.4028235e38, 0.9, 0.9, -3.4028235e38, 0.0]]), - ) diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py deleted file mode 100644 index 065f9b519..000000000 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ /dev/null @@ -1,395 +0,0 @@ -import json -from typing import Any, Dict, Union - -import pytest -import torch.testing -from pie_core.utils.dictionary import flatten_dict_s, list_of_dicts2dict_of_lists -from torch import tensor -from torchmetrics import Metric, MetricCollection - -from pie_modules.annotations import LabeledSpan -from pie_modules.document.processing.text_pair import add_negative_coref_relations -from pie_modules.documents import ( - BinaryCorefRelation, - TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, -) -from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule -from tests import FIXTURES_ROOT, _config_to_str - -TOKENIZER_NAME_OR_PATH = "bert-base-cased" -DOC_IDX_WITH_TASK_ENCODINGS = 2 - -CONFIGS = [ - {}, -] -CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} - - -@pytest.fixture(scope="module", params=CONFIGS_DICT.keys()) -def config(request): - return CONFIGS_DICT[request.param] - - -@pytest.fixture(scope="module") -def positive_documents(): - doc1 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( - id="0", text="Entity A works at B.", text_pair="And she founded C." - ) - doc1.labeled_spans.append(LabeledSpan(start=0, end=8, label="PERSON")) - doc1.labeled_spans.append(LabeledSpan(start=18, end=19, label="COMPANY")) - doc1.labeled_spans_pair.append(LabeledSpan(start=4, end=7, label="PERSON")) - doc1.labeled_spans_pair.append(LabeledSpan(start=16, end=17, label="COMPANY")) - doc1.binary_coref_relations.append( - BinaryCorefRelation(head=doc1.labeled_spans[0], tail=doc1.labeled_spans_pair[0]) - ) - - doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( - id="0", text="Bob loves his cat.", text_pair="She sleeps a lot." - ) - doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON")) - doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL")) - doc2.labeled_spans_pair.append(LabeledSpan(start=0, end=3, label="ANIMAL")) - doc2.binary_coref_relations.append( - BinaryCorefRelation(head=doc2.labeled_spans[1], tail=doc2.labeled_spans_pair[0]) - ) - - return [doc1, doc2] - - -def test_positive_documents(positive_documents): - assert len(positive_documents) == 2 - doc1, doc2 = positive_documents - assert doc1.labeled_spans.resolve() == [("PERSON", "Entity A"), ("COMPANY", "B")] - assert doc1.labeled_spans_pair.resolve() == [("PERSON", "she"), ("COMPANY", "C")] - assert doc1.binary_coref_relations.resolve() == [ - ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))) - ] - - assert doc2.labeled_spans.resolve() == [("PERSON", "Bob"), ("ANIMAL", "his cat")] - assert doc2.labeled_spans_pair.resolve() == [("ANIMAL", "She")] - assert doc2.binary_coref_relations.resolve() == [ - ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) - ] - - -@pytest.fixture(scope="module") -def unprepared_taskmodule(config): - taskmodule = CrossTextBinaryCorefTaskModule( - tokenizer_name_or_path=TOKENIZER_NAME_OR_PATH, **config - ) - assert not taskmodule.is_from_pretrained - - return taskmodule - - -@pytest.fixture(scope="module") -def taskmodule(unprepared_taskmodule, positive_documents): - unprepared_taskmodule.prepare(positive_documents) - return unprepared_taskmodule - - -@pytest.fixture(scope="module") -def documents_with_negatives(taskmodule, positive_documents): - file_name = ( - FIXTURES_ROOT / "taskmodules" / "cross_text_binary_coref" / "documents_with_negatives.json" - ) - - # result = list(add_negative_relations(positive_documents)) - # result_json = [doc.asdict() for doc in result] - # with open(file_name, "w") as f: - # json.dump(result_json, f, indent=2) - - with open(file_name) as f: - result_json = json.load(f) - result = [ - TextPairDocumentWithLabeledSpansAndBinaryCorefRelations.fromdict(doc_json) - for doc_json in result_json - ] - - return result - - -@pytest.fixture(scope="module") -def task_encodings_without_target(taskmodule, documents_with_negatives): - task_encodings = taskmodule.encode_input(documents_with_negatives[DOC_IDX_WITH_TASK_ENCODINGS]) - return task_encodings - - -def test_encode_input(task_encodings_without_target, taskmodule): - task_encodings = task_encodings_without_target - convert_ids_to_tokens = taskmodule.tokenizer.convert_ids_to_tokens - - inputs_dict = list_of_dicts2dict_of_lists( - [task_encoding.inputs for task_encoding in task_encodings] - ) - tokens = [convert_ids_to_tokens(encoding["input_ids"]) for encoding in inputs_dict["encoding"]] - tokens_pair = [ - convert_ids_to_tokens(encoding["input_ids"]) for encoding in inputs_dict["encoding_pair"] - ] - assert tokens == [ - ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], - ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"], - ] - assert tokens_pair == [ - ["[CLS]", "En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]"], - ["[CLS]", "En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]"], - ] - span_tokens = [ - toks[start:end] - for toks, start, end in zip( - tokens, inputs_dict["pooler_start_indices"], inputs_dict["pooler_end_indices"] - ) - ] - span_tokens_pair = [ - toks[start:end] - for toks, start, end in zip( - tokens_pair, - inputs_dict["pooler_pair_start_indices"], - inputs_dict["pooler_pair_end_indices"], - ) - ] - assert span_tokens == [["she"], ["C"]] - assert span_tokens_pair == [["En", "##ti", "##ty", "A"], ["B"]] - - -def test_encode_target(task_encodings_without_target, taskmodule): - targets = [ - taskmodule.encode_target(task_encoding) for task_encoding in task_encodings_without_target - ] - assert targets == [1.0, 0.0] - - -def test_encode_with_collect_statistics(taskmodule, positive_documents): - documents_with_negatives = add_negative_coref_relations(positive_documents) - original_values = taskmodule.collect_statistics - taskmodule.collect_statistics = True - taskmodule.encode(documents_with_negatives, encode_target=True) - statistics = taskmodule.get_statistics() - taskmodule.collect_statistics = original_values - - assert statistics == { - ("available", "coref"): 4, - ("available", "no_relation"): 6, - ("used", "coref"): 4, - ("used", "no_relation"): 6, - } - - -def test_encode_with_windowing(documents_with_negatives): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = CrossTextBinaryCorefTaskModule( - tokenizer_name_or_path=tokenizer_name_or_path, - max_window=4, - collect_statistics=True, - ) - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents_with_negatives) - - assert len(documents_with_negatives) == 16 - - task_encodings = taskmodule.encode(documents_with_negatives) - statistics = taskmodule.get_statistics() - - assert statistics == { - ("available", "coref"): 4, - ("available", "no_relation"): 6, - ("skipped_span_does_not_fit_into_window", "coref"): 2, - ("skipped_span_does_not_fit_into_window", "no_relation"): 2, - ("used", "coref"): 2, - ("used", "no_relation"): 4, - } - - assert len(task_encodings) == 6 - for task_encoding in task_encodings: - for k, v in task_encoding.inputs["encoding"].items(): - assert len(v) <= taskmodule.max_window - for k, v in task_encoding.inputs["encoding_pair"].items(): - assert len(v) <= taskmodule.max_window - - -@pytest.fixture(scope="module") -def task_encodings(taskmodule, documents_with_negatives): - return taskmodule.encode( - documents_with_negatives[DOC_IDX_WITH_TASK_ENCODINGS], encode_target=True - ) - - -@pytest.fixture(scope="module") -def batch(taskmodule, task_encodings): - result = taskmodule.collate(task_encodings) - return result - - -def test_collate(batch, taskmodule): - assert batch is not None - inputs, targets = batch - assert inputs is not None - assert set(inputs) == { - "pooler_end_indices", - "encoding_pair", - "pooler_pair_end_indices", - "pooler_start_indices", - "encoding", - "pooler_pair_start_indices", - } - torch.testing.assert_close( - inputs["encoding"]["input_ids"], - torch.tensor( - [[101, 1262, 1131, 1771, 140, 119, 102], [101, 1262, 1131, 1771, 140, 119, 102]] - ), - ) - torch.testing.assert_close( - inputs["encoding"]["token_type_ids"], torch.zeros_like(inputs["encoding"]["input_ids"]) - ) - torch.testing.assert_close( - inputs["encoding"]["attention_mask"], torch.ones_like(inputs["encoding"]["input_ids"]) - ) - - torch.testing.assert_close( - inputs["encoding_pair"]["input_ids"], - torch.tensor( - [ - [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102], - [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102], - ] - ), - ) - torch.testing.assert_close( - inputs["encoding_pair"]["token_type_ids"], - torch.zeros_like(inputs["encoding_pair"]["input_ids"]), - ) - torch.testing.assert_close( - inputs["encoding_pair"]["attention_mask"], - torch.ones_like(inputs["encoding_pair"]["input_ids"]), - ) - - torch.testing.assert_close(inputs["pooler_start_indices"], torch.tensor([[2], [4]])) - torch.testing.assert_close(inputs["pooler_end_indices"], torch.tensor([[3], [5]])) - torch.testing.assert_close(inputs["pooler_pair_start_indices"], torch.tensor([[1], [7]])) - torch.testing.assert_close(inputs["pooler_pair_end_indices"], torch.tensor([[5], [8]])) - - torch.testing.assert_close(targets, {"scores": torch.tensor([1.0, 0.0])}) - - -@pytest.fixture(scope="module") -def unbatched_output(taskmodule): - model_output = { - "scores": torch.tensor([0.5338148474693298, 0.9866107940673828]), - } - return taskmodule.unbatch_output(model_output=model_output) - - -def test_unbatch_output(unbatched_output, taskmodule): - assert len(unbatched_output) == 2 - assert unbatched_output == [ - {"is_similar": False, "score": 0.5338148474693298}, - {"is_similar": True, "score": 0.9866107702255249}, - ] - - -def test_create_annotation_from_output(taskmodule, task_encodings, unbatched_output): - all_new_annotations = [] - for task_encoding, task_output in zip(task_encodings, unbatched_output): - for new_annotation in taskmodule.create_annotations_from_output( - task_encoding=task_encoding, task_output=task_output - ): - all_new_annotations.append(new_annotation) - assert all(layer_name == "binary_coref_relations" for layer_name, ann in all_new_annotations) - resolve_annotations_with_scores = [ - (round(ann.score, 4), ann.resolve()) for layer_name, ann in all_new_annotations - ] - assert resolve_annotations_with_scores == [ - (0.9866, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))), - ] - - -def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]: - if isinstance(metric_or_collection, Metric): - return flatten_dict_s(metric_or_collection.metric_state) - elif isinstance(metric_or_collection, MetricCollection): - return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()}) - else: - raise ValueError(f"unsupported type: {type(metric_or_collection)}") - - -def test_configure_metric(taskmodule, batch): - metric = taskmodule.configure_model_metric(stage="train") - - assert isinstance(metric, (Metric, MetricCollection)) - state = get_metric_state(metric) - torch.testing.assert_close( - state, - { - "continuous/auroc/preds": [], - "continuous/auroc/target": [], - "continuous/avg-P/preds": [], - "continuous/avg-P/target": [], - "continuous/f1/fn": tensor([0]), - "continuous/f1/fp": tensor([0]), - "continuous/f1/tn": tensor([0]), - "continuous/f1/tp": tensor([0]), - }, - ) - - # targets = batch[1] - targets = { - "scores": torch.tensor([0.0, 1.0, 0.0, 0.0]), - } - metric.update(targets, targets) - - state = get_metric_state(metric) - torch.testing.assert_close( - state, - { - "continuous/auroc/preds": [tensor([0.0, 1.0, 0.0, 0.0])], - "continuous/auroc/target": [tensor([0.0, 1.0, 0.0, 0.0])], - "continuous/avg-P/preds": [tensor([0.0, 1.0, 0.0, 0.0])], - "continuous/avg-P/target": [tensor([0.0, 1.0, 0.0, 0.0])], - "continuous/f1/tp": tensor([1]), - "continuous/f1/fp": tensor([0]), - "continuous/f1/tn": tensor([3]), - "continuous/f1/fn": tensor([0]), - }, - ) - - torch.testing.assert_close( - metric.compute(), - {"auroc": tensor(1.0), "avg-P": tensor(1.0), "f1": tensor(1.0)}, - ) - - # torch.rand_like(targets) - random_targets = { - "scores": torch.tensor([0.2703, 0.6812, 0.2582, 0.9030]), - } - metric.update(random_targets, targets) - state = get_metric_state(metric) - torch.testing.assert_close( - state, - { - "continuous/auroc/preds": [ - tensor([0.0, 1.0, 0.0, 0.0]), - tensor([0.2703, 0.6812, 0.2582, 0.9030]), - ], - "continuous/auroc/target": [ - tensor([0.0, 1.0, 0.0, 0.0]), - tensor([0.0, 1.0, 0.0, 0.0]), - ], - "continuous/avg-P/preds": [ - tensor([0.0, 1.0, 0.0, 0.0]), - tensor([0.2703, 0.6812, 0.2582, 0.9030]), - ], - "continuous/avg-P/target": [ - tensor([0.0, 1.0, 0.0, 0.0]), - tensor([0.0, 1.0, 0.0, 0.0]), - ], - "continuous/f1/tp": tensor([1]), - "continuous/f1/fp": tensor([1]), - "continuous/f1/tn": tensor([5]), - "continuous/f1/fn": tensor([1]), - }, - ) - - torch.testing.assert_close( - metric.compute(), - {"auroc": tensor(0.91666663), "avg-P": tensor(0.83333337), "f1": tensor(0.50000000)}, - ) diff --git a/tests/taskmodules/test_extractive_question_answering.py b/tests/taskmodules/test_extractive_question_answering.py deleted file mode 100644 index d38e532d6..000000000 --- a/tests/taskmodules/test_extractive_question_answering.py +++ /dev/null @@ -1,277 +0,0 @@ -import pytest -import torch -import transformers -from pie_core import AnnotationLayer - -from pie_modules.annotations import ExtractiveAnswer, Question -from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers -from pie_modules.taskmodules.extractive_question_answering import ( - ExtractiveQuestionAnsweringTaskModule, -) - - -@pytest.fixture() -def document(): - document = TextDocumentWithQuestionsAndExtractiveAnswers( - text="This is a test document", id="doc0" - ) - document.questions.append(Question(text="What is the first word?")) - document.answers.append(ExtractiveAnswer(question=document.questions[0], start=0, end=4)) - assert str(document.answers[0]) == "This" - return document - - -@pytest.fixture() -def document1(): - document1 = TextDocumentWithQuestionsAndExtractiveAnswers( - text="This is the second document", id="doc1" - ) - document1.questions.append(Question(text="Which document is this?")) - document1.answers.append(ExtractiveAnswer(question=document1.questions[0], start=13, end=18)) - assert str(document1.answers[0]) == "second" - return document1 - - -@pytest.fixture() -def document_with_no_answer(): - document = TextDocumentWithQuestionsAndExtractiveAnswers( - text="This is a test document", id="document_with_no_answer" - ) - document.questions.append(Question(text="What is the first word?")) - return document - - -@pytest.fixture() -def document_with_multiple_answers(): - document = TextDocumentWithQuestionsAndExtractiveAnswers( - text="This is a test document", id="document_with_multiple_answers" - ) - document.questions.append(Question(text="What is the first word?")) - document.answers.append(ExtractiveAnswer(question=document.questions[0], start=0, end=4)) - assert str(document.answers[0]) == "This" - document.answers.append(ExtractiveAnswer(question=document.questions[0], start=0, end=7)) - assert str(document.answers[1]) == "This is" - return document - - -@pytest.fixture() -def taskmodule(): - return ExtractiveQuestionAnsweringTaskModule( - tokenizer_name_or_path="bert-base-uncased", max_length=128 - ) - - -def test_encode_input( - taskmodule, document, document_with_no_answer, document_with_multiple_answers -): - inputs = taskmodule.encode_input(document) - assert inputs is not None - assert len(inputs) == 1 - expected_inputs = [ - 101, - 2054, - 2003, - 1996, - 2034, - 2773, - 1029, - 102, - 2023, - 2003, - 1037, - 3231, - 6254, - 102, - ] - assert inputs[0].inputs == expected_inputs - - inputs = taskmodule.encode_input(document_with_no_answer) - assert inputs is not None - assert len(inputs) == 1 - assert inputs[0].inputs == expected_inputs - - inputs = taskmodule.encode_input(document_with_multiple_answers) - assert inputs is not None - assert len(inputs) == 1 - assert inputs[0].inputs == expected_inputs - - -def test_encode_target(taskmodule, document, document_with_no_answer): - inputs = taskmodule.encode_input(document) - targets = taskmodule.encode_target(inputs[0]) - assert targets is not None - assert targets.start_position == 8 - assert targets.end_position == 8 - - inputs = taskmodule.encode_input(document_with_no_answer) - targets = taskmodule.encode_target(inputs[0]) - assert targets is not None - assert targets.start_position == 0 - assert targets.end_position == 0 - - -def test_get_question_layer(taskmodule, document, document_with_no_answer): - question_layer = taskmodule.get_question_layer(document) - assert question_layer is not None - assert len(question_layer) == 1 - assert type(question_layer) is AnnotationLayer - assert type(question_layer[0]) is Question - assert question_layer[0].text == "What is the first word?" - - question_layer = taskmodule.get_question_layer(document_with_no_answer) - assert question_layer is not None - assert len(question_layer) == 1 - assert type(question_layer) is AnnotationLayer - assert type(question_layer[0]) is Question - assert question_layer[0].text == "What is the first word?" - - -def test_get_answer_layer(taskmodule, document, document_with_no_answer): - answer_layer = taskmodule.get_answer_layer(document) - assert answer_layer is not None - assert len(answer_layer) == 1 - assert type(answer_layer) is AnnotationLayer - assert type(answer_layer[0]) is ExtractiveAnswer - assert answer_layer[0].question.text == "What is the first word?" - assert answer_layer[0].start == 0 - assert answer_layer[0].end == 4 - - answer_layer = taskmodule.get_answer_layer(document_with_no_answer) - assert answer_layer is not None - assert len(answer_layer) == 0 - assert type(answer_layer) is AnnotationLayer - - -def test_get_context(taskmodule, document, document_with_no_answer): - context = taskmodule.get_context(document) - assert context is not None - assert context == "This is a test document" - - context = taskmodule.get_context(document_with_no_answer) - assert context is not None - assert context == "This is a test document" - - -@pytest.fixture() -def documents(document, document_with_no_answer): - return [document, document_with_no_answer] - - -@pytest.fixture() -def batch_without_targets(taskmodule, documents): - task_encodings = taskmodule.encode(documents) - batch_encoding = taskmodule.collate(task_encodings) - return batch_encoding - - -def test_collate_without_targets(batch_without_targets): - assert batch_without_targets is not None - assert len(batch_without_targets) == 2 - inputs, targets = batch_without_targets - assert inputs is not None - assert targets is None - - -@pytest.fixture() -def task_encodings(taskmodule, documents): - task_encodings = taskmodule.encode(documents, encode_target=True) - return task_encodings - - -@pytest.fixture() -def batch(taskmodule, task_encodings): - batch_encoding = taskmodule.collate(task_encodings) - return batch_encoding - - -def test_collate_with_targets(batch): - assert batch is not None - assert len(batch) == 2 - inputs, targets = batch - assert inputs is not None - assert set(inputs.data) == {"input_ids", "token_type_ids", "attention_mask"} - assert inputs.data["input_ids"].shape == (2, 14) - assert inputs.data["token_type_ids"].shape == (2, 14) - assert inputs.data["attention_mask"].shape == (2, 14) - assert targets is not None - assert set(targets) == {"start_positions", "end_positions"} - assert targets["start_positions"].shape == (2,) - assert targets["end_positions"].shape == (2,) - - expected_inputs_ids = [ - [101, 2054, 2003, 1996, 2034, 2773, 1029, 102, 2023, 2003, 1037, 3231, 6254, 102], - [101, 2054, 2003, 1996, 2034, 2773, 1029, 102, 2023, 2003, 1037, 3231, 6254, 102], - ] - expected_token_type_ids = [ - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], - ] - expected_attention_mask = [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - assert inputs.data["input_ids"].tolist() == expected_inputs_ids - assert inputs.data["token_type_ids"].tolist() == expected_token_type_ids - assert inputs.data["attention_mask"].tolist() == expected_attention_mask - - expected_start_positions = [8, 0] - expected_end_positions = [8, 0] - assert targets["start_positions"].tolist() == expected_start_positions - assert targets["end_positions"].tolist() == expected_end_positions - - -@pytest.fixture() -def model_outputs(batch): - # create probabilities that "perfectly" model the batch targets - inputs, targets = batch - start_probs = torch.zeros_like(inputs.input_ids, dtype=torch.float) + 0.05 - end_probs = torch.zeros_like(inputs.input_ids, dtype=torch.float) + 0.05 - # set target positions to 0.95 as a dummy value - for idx, (start_position, end_position) in enumerate( - zip(targets["start_positions"], targets["end_positions"]) - ): - start_probs[idx, start_position] = 0.95 - end_probs[idx, end_position] = 0.95 - - # convert probs to logits - start_logits = torch.log(start_probs / (1 - start_probs)) - end_logits = torch.log(end_probs / (1 - end_probs)) - - model_outputs = transformers.modeling_outputs.QuestionAnsweringModelOutput( - start_logits=start_logits, - end_logits=end_logits, - ) - return model_outputs - - -@pytest.fixture() -def unbatched_output(taskmodule, model_outputs): - return taskmodule.unbatch_output(model_outputs) - - -def test_unbatch_output(unbatched_output): - assert unbatched_output is not None - assert len(unbatched_output) == 2 - # check first result - assert unbatched_output[0].start == 8 - assert unbatched_output[0].end == 8 - assert unbatched_output[0].start_probability == pytest.approx(0.9652407) - assert unbatched_output[0].end_probability == pytest.approx(0.9652407) - # check second result - assert unbatched_output[1].start == 0 - assert unbatched_output[1].end == 0 - assert unbatched_output[1].start_probability == pytest.approx(0.9652407) - assert unbatched_output[1].end_probability == pytest.approx(0.9652407) - - -def test_create_annotations_from_output(taskmodule, task_encodings, unbatched_output, documents): - taskmodule.combine_outputs(task_encodings, unbatched_output) - assert len(documents) > 0 - for doc in documents: - gold_annotations = doc.answers - predicted_annotations = doc.answers.predictions - assert len(predicted_annotations) == len(gold_annotations) - for predicted_annotation, gold_annotation in zip(predicted_annotations, gold_annotations): - # we did construct the predicted annotations from the gold annotations, so they should be equal - assert predicted_annotation == gold_annotation - assert predicted_annotation.score == pytest.approx(0.9316896200180054) diff --git a/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py b/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py deleted file mode 100644 index b05374a9b..000000000 --- a/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py +++ /dev/null @@ -1,883 +0,0 @@ -import logging -import pickle -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, Dict, List - -import pytest -import torch -from pie_core import AnnotationLayer, annotation_field -from torch import tensor -from transformers import BatchEncoding - -from pie_modules.annotations import LabeledSpan -from pie_modules.documents import ( - TextBasedDocument, - TextDocumentWithLabeledSpans, - TextDocumentWithLabeledSpansAndLabeledPartitions, -) -from pie_modules.taskmodules import LabeledSpanExtractionByTokenClassificationTaskModule -from pie_modules.taskmodules.labeled_span_extraction_by_token_classification import ( - ModelOutputType, -) - - -def _config_to_str(cfg: Dict[str, Any]) -> str: - # Converts a configuration dictionary to a string representation - result = "-".join([f"{k}={cfg[k]}" for k in sorted(cfg)]) - return result - - -CONFIG_DEFAULT = {} -CONFIG_MAX_WINDOW = { - "tokenize_kwargs": {"max_length": 8, "truncation": True, "return_overflowing_tokens": True} -} -CONFIG_MAX_WINDOW_WITH_STRIDE = { - "tokenize_kwargs": { - "max_length": 8, - "stride": 2, - "truncation": True, - "return_overflowing_tokens": True, - } -} -CONFIG_PARTITIONS = {"partition_annotation": "sentences"} - -CONFIGS: List[Dict[str, Any]] = [ - CONFIG_DEFAULT, - CONFIG_MAX_WINDOW, - CONFIG_MAX_WINDOW_WITH_STRIDE, - CONFIG_PARTITIONS, -] - -CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} - - -@pytest.fixture(scope="module", params=CONFIGS_DICT.keys()) -def config(request): - """ - - Provides clean and readable test configurations. - - Yields config dictionaries from the CONFIGS list to produce clean test case identifiers. - - """ - return CONFIGS_DICT[request.param] - - -@pytest.fixture(scope="module") -def config_str(config): - # Fixture returning a string representation of the config - return _config_to_str(config) - - -@pytest.fixture(scope="module") -def unprepared_taskmodule(config): - """ - - Prepares a task module with the specified tokenizer and configuration. - - Sets up the task module with a unprepared state for testing purposes. - - """ - return LabeledSpanExtractionByTokenClassificationTaskModule( - tokenizer_name_or_path="bert-base-uncased", span_annotation="entities", **config - ) - - -@dataclass -class ExampleDocument(TextBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - sentences: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - - -@pytest.fixture(scope="module") -def documents(): - """ - - Creates example documents with predefined texts. - - Assigns labels to the documents for testing purposes. - - """ - doc1 = ExampleDocument(text="Mount Everest is the highest peak in the world.", id="doc1") - doc1.entities.append(LabeledSpan(start=0, end=13, label="LOC")) - assert str(doc1.entities[0]) == "Mount Everest" - - doc2 = ExampleDocument(text="Alice loves reading books. Bob enjoys playing soccer.", id="doc2") - doc2.entities.append(LabeledSpan(start=0, end=5, label="PER")) - assert str(doc2.entities[0]) == "Alice" - doc2.entities.append(LabeledSpan(start=27, end=30, label="PER")) - assert str(doc2.entities[1]) == "Bob" - # we add just one sentence to doc2 that covers only Bob - doc2.sentences.append(LabeledSpan(start=27, end=53, label="sentence")) - assert str(doc2.sentences[0]) == "Bob enjoys playing soccer." - - return [doc1, doc2] - - -def test_taskmodule(unprepared_taskmodule): - assert unprepared_taskmodule is not None - - -@pytest.fixture(scope="module") -def taskmodule(unprepared_taskmodule, documents): - """ - - Prepares the task module with the given documents, i.e. collect available label values. - - Calls the necessary methods to prepare the task module with the documents. - - Calls _prepare(documents) and then _post_prepare() - - """ - unprepared_taskmodule.prepare(documents) - return unprepared_taskmodule - - -def test_prepare(taskmodule): - assert taskmodule is not None - assert taskmodule.is_prepared - assert taskmodule.label_to_id == {"B-LOC": 1, "B-PER": 3, "I-LOC": 2, "I-PER": 4, "O": 0} - assert taskmodule.id_to_label == {0: "O", 1: "B-LOC", 2: "I-LOC", 3: "B-PER", 4: "I-PER"} - - -def test_config(taskmodule): - config = taskmodule._config() - assert config["taskmodule_type"] == "LabeledSpanExtractionByTokenClassificationTaskModule" - assert "labels" in config - assert config["labels"] == ["LOC", "PER"] - - -@pytest.fixture(scope="module") -def task_encodings_without_targets(taskmodule, documents): - """ - - Generates task encodings for all the documents, but without associated targets. - """ - return taskmodule.encode(documents, encode_target=False) - - -def test_task_encodings_without_targets(task_encodings_without_targets, taskmodule, config): - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(task_encoding.inputs.ids) - for task_encoding in task_encodings_without_targets - ] - - # If config is empty - if config == CONFIG_DEFAULT: - assert tokens == [ - [ - "[CLS]", - "mount", - "everest", - "is", - "the", - "highest", - "peak", - "in", - "the", - "world", - ".", - "[SEP]", - ], - [ - "[CLS]", - "alice", - "loves", - "reading", - "books", - ".", - "bob", - "enjoys", - "playing", - "soccer", - ".", - "[SEP]", - ], - ] - - # If config has the specified values (max_window=8, window_overlap=2) - elif config == CONFIG_MAX_WINDOW_WITH_STRIDE: - for t in tokens: - assert len(t) <= 8 - - assert tokens == [ - ["[CLS]", "mount", "everest", "is", "the", "highest", "peak", "[SEP]"], - ["[CLS]", "highest", "peak", "in", "the", "world", ".", "[SEP]"], - ["[CLS]", "alice", "loves", "reading", "books", ".", "bob", "[SEP]"], - ["[CLS]", ".", "bob", "enjoys", "playing", "soccer", ".", "[SEP]"], - ] - - # If config has the specified value (max_window=8) - elif config == CONFIG_MAX_WINDOW: - for t in tokens: - assert len(t) <= 8 - - assert tokens == [ - ["[CLS]", "mount", "everest", "is", "the", "highest", "peak", "[SEP]"], - ["[CLS]", "in", "the", "world", ".", "[SEP]"], - ["[CLS]", "alice", "loves", "reading", "books", ".", "bob", "[SEP]"], - ["[CLS]", "enjoys", "playing", "soccer", ".", "[SEP]"], - ] - - # If config has the specified value (partition_annotation=sentences) - elif config == CONFIG_PARTITIONS: - assert tokens - - else: - raise ValueError(f"unknown config: {config}") - - -@pytest.fixture(scope="module") -def task_encodings(taskmodule, documents): - return taskmodule.encode(documents, encode_target=True) - - -def test_task_encodings(task_encodings, taskmodule, config): - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(task_encoding.inputs.ids) - for task_encoding in task_encodings - ] - labels_tokens = [ - [taskmodule.id_to_label[x] if x != -100 else "" for x in task_encoding.targets] - for task_encoding in task_encodings - ] - assert len(labels_tokens) == len(tokens) - - tokens_with_labels = list(zip(tokens, labels_tokens)) - - for tokens, labels in tokens_with_labels: - assert len(tokens) == len(labels) - - # If config is empty - if config == CONFIG_DEFAULT: - assert tokens_with_labels == [ - ( - [ - "[CLS]", - "mount", - "everest", - "is", - "the", - "highest", - "peak", - "in", - "the", - "world", - ".", - "[SEP]", - ], - ["", "B-LOC", "I-LOC", "O", "O", "O", "O", "O", "O", "O", "O", ""], - ), - ( - [ - "[CLS]", - "alice", - "loves", - "reading", - "books", - ".", - "bob", - "enjoys", - "playing", - "soccer", - ".", - "[SEP]", - ], - ["", "B-PER", "O", "O", "O", "O", "B-PER", "O", "O", "O", "O", ""], - ), - ] - - # If config has the specified values (max_window=8, window_overlap=2) - elif config == CONFIG_MAX_WINDOW_WITH_STRIDE: - for tokens, labels in tokens_with_labels: - assert len(tokens) <= 8 - - assert tokens_with_labels == [ - ( - ["[CLS]", "mount", "everest", "is", "the", "highest", "peak", "[SEP]"], - ["", "B-LOC", "I-LOC", "O", "O", "O", "O", ""], - ), - ( - ["[CLS]", "highest", "peak", "in", "the", "world", ".", "[SEP]"], - ["", "O", "O", "O", "O", "O", "O", ""], - ), - ( - ["[CLS]", "alice", "loves", "reading", "books", ".", "bob", "[SEP]"], - ["", "B-PER", "O", "O", "O", "O", "B-PER", ""], - ), - ( - ["[CLS]", ".", "bob", "enjoys", "playing", "soccer", ".", "[SEP]"], - ["", "O", "B-PER", "O", "O", "O", "O", ""], - ), - ] - - # If config has the specified value (max_window=8) - elif config == CONFIG_MAX_WINDOW: - for tokens, labels in tokens_with_labels: - assert len(tokens) <= 8 - - assert tokens_with_labels == [ - ( - ["[CLS]", "mount", "everest", "is", "the", "highest", "peak", "[SEP]"], - ["", "B-LOC", "I-LOC", "O", "O", "O", "O", ""], - ), - ( - ["[CLS]", "in", "the", "world", ".", "[SEP]"], - ["", "O", "O", "O", "O", ""], - ), - ( - ["[CLS]", "alice", "loves", "reading", "books", ".", "bob", "[SEP]"], - ["", "B-PER", "O", "O", "O", "O", "B-PER", ""], - ), - ( - ["[CLS]", "enjoys", "playing", "soccer", ".", "[SEP]"], - ["", "O", "O", "O", "O", ""], - ), - ] - - # If config has the specified value (partition_annotation=sentences) - elif config == CONFIG_PARTITIONS: - assert tokens_with_labels == [ - ( - ["[CLS]", "bob", "enjoys", "playing", "soccer", ".", "[SEP]"], - ["", "B-PER", "O", "O", "O", "O", ""], - ) - ] - - else: - raise ValueError(f"unknown config: {config}") - - -def test_encode_targets_with_overlap(caplog): - # setup taskmodule - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule( - tokenizer_name_or_path="bert-base-uncased", labels=["LOC", "PER"] - ) - taskmodule.post_prepare() - - # create a document with overlapping entities - doc = TextDocumentWithLabeledSpans( - text="Alice loves reading books. Bob enjoys playing soccer." - ) - doc.labeled_spans.append(LabeledSpan(start=0, end=5, label="PER")) - doc.labeled_spans.append(LabeledSpan(start=27, end=30, label="PER")) - doc.labeled_spans.append(LabeledSpan(start=27, end=37, label="PER")) - assert str(doc.labeled_spans[0]) == "Alice" - assert str(doc.labeled_spans[1]) == "Bob" - assert str(doc.labeled_spans[2]) == "Bob enjoys" - - # encode the document - with caplog.at_level(logging.WARNING): - task_encodings = taskmodule.encode([doc], encode_target=True) - assert len(caplog.records) == 1 - assert ( - caplog.messages[0] - == "tag already assigned (current span has an overlap: ('bob', 'enjoys'))." - ) - assert len(task_encodings) == 1 - assert task_encodings[0].targets == [-100, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, -100] - - -@pytest.fixture(scope="module") -def task_encodings_for_batch(task_encodings, config): - # just take everything we have - return task_encodings - - -@pytest.fixture(scope="module") -def batch(taskmodule, task_encodings_for_batch, config) -> BatchEncoding: - return taskmodule.collate(task_encodings_for_batch) - - -def test_collate(batch, config): - assert batch is not None - assert len(batch) == 2 - inputs, targets = batch - - assert set(inputs.data) == {"input_ids", "attention_mask", "special_tokens_mask"} - input_ids_list = inputs.input_ids.tolist() - attention_mask_list = inputs.attention_mask.tolist() - special_tokens_mask_list = inputs.special_tokens_mask.tolist() - assert set(targets) == {"labels"} - labels_list = targets["labels"].tolist() - - # If config is empty - if config == CONFIG_DEFAULT: - assert input_ids_list == [ - [101, 4057, 23914, 2003, 1996, 3284, 4672, 1999, 1996, 2088, 1012, 102], - [101, 5650, 7459, 3752, 2808, 1012, 3960, 15646, 2652, 4715, 1012, 102], - ] - assert attention_mask_list == [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - assert labels_list == [ - [-100, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, -100], - [-100, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, -100], - ] - assert special_tokens_mask_list == [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - ] - - # If config has the specified values (max_window=8, window_overlap=2) - elif config == CONFIG_MAX_WINDOW_WITH_STRIDE: - assert input_ids_list == [ - [101, 4057, 23914, 2003, 1996, 3284, 4672, 102], - [101, 3284, 4672, 1999, 1996, 2088, 1012, 102], - [101, 5650, 7459, 3752, 2808, 1012, 3960, 102], - [101, 1012, 3960, 15646, 2652, 4715, 1012, 102], - ] - assert attention_mask_list == [ - [1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1], - ] - assert labels_list == [ - [-100, 1, 2, 0, 0, 0, 0, -100], - [-100, 0, 0, 0, 0, 0, 0, -100], - [-100, 3, 0, 0, 0, 0, 3, -100], - [-100, 0, 3, 0, 0, 0, 0, -100], - ] - - # If config has the specified values (max_window=8) - elif config == CONFIG_MAX_WINDOW: - assert input_ids_list == [ - [101, 4057, 23914, 2003, 1996, 3284, 4672, 102], - [101, 1999, 1996, 2088, 1012, 102, 0, 0], - [101, 5650, 7459, 3752, 2808, 1012, 3960, 102], - [101, 15646, 2652, 4715, 1012, 102, 0, 0], - ] - assert attention_mask_list == [ - [1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 0, 0], - ] - assert labels_list == [ - [-100, 1, 2, 0, 0, 0, 0, -100], - [-100, 0, 0, 0, 0, -100, -100, -100], - [-100, 3, 0, 0, 0, 0, 3, -100], - [-100, 0, 0, 0, 0, -100, -100, -100], - ] - assert special_tokens_mask_list == [ - [1, 0, 0, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 1, 1, 1], - [1, 0, 0, 0, 0, 0, 0, 1], - [1, 0, 0, 0, 0, 1, 1, 1], - ] - - # If config has the specified value (partition_annotation=sentences) - elif config == CONFIG_PARTITIONS: - assert input_ids_list == [[101, 3960, 15646, 2652, 4715, 1012, 102]] - assert attention_mask_list == [[1, 1, 1, 1, 1, 1, 1]] - assert labels_list == [[-100, 3, 0, 0, 0, 0, -100]] - assert special_tokens_mask_list == [[1, 0, 0, 0, 0, 0, 1]] - - else: - raise ValueError(f"unknown config: {config}") - - inputs_expected = BatchEncoding( - data={ - "input_ids": torch.tensor(input_ids_list, dtype=torch.int64), - "attention_mask": torch.tensor(attention_mask_list, dtype=torch.int64), - "special_tokens_mask": torch.tensor(special_tokens_mask_list, dtype=torch.int64), - } - ) - assert set(inputs.data) == set(inputs_expected.data) - labels_expected = torch.tensor(labels_list, dtype=torch.int64) - assert torch.equal(targets["labels"], labels_expected) - - -# This is not used, but can be used to create a batch of task encodings with targets for the unbatched_outputs fixture. -@pytest.fixture(scope="module") -def real_model_output(batch, taskmodule): - from pytorch_ie.models import TransformerTokenClassificationModel - - model = TransformerTokenClassificationModel( - model_name_or_path="prajjwal1/bert-tiny", - num_classes=len(taskmodule.label_to_id), - ) - inputs, targets = batch - result = model(inputs) - return result - - -@pytest.fixture(scope="module") -def model_output(config, batch, taskmodule) -> ModelOutputType: - # create "perfect" output from targets - labels = batch[1]["labels"] - num_classes = len(taskmodule.label_to_id) - # create one-hot encoding from labels - labels_valid = labels.clone() - labels_valid[labels_valid == taskmodule.label_pad_id] = taskmodule.label_to_id["O"] - # create one-hot encoding from labels, but with 0.9 for the correct labels - probabilities = ( - torch.nn.functional.one_hot(labels_valid, num_classes=num_classes).to(torch.float32) * 0.9 - ) - return {"labels": labels, "probabilities": probabilities} - - -@pytest.fixture(scope="module") -def unbatched_outputs(taskmodule, model_output): - return taskmodule.unbatch_output(model_output) - - -@pytest.mark.parametrize("combine_token_scores_method", ["mean", "max", "product", "UNKNOWN"]) -def test_combine_token_scores_method(documents, combine_token_scores_method): - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule( - tokenizer_name_or_path="bert-base-uncased", - span_annotation="entities", - combine_token_scores_method=combine_token_scores_method, - ) - taskmodule.prepare(documents) - - task_encodings = taskmodule.encode(documents, encode_target=True) - batch = taskmodule.collate(task_encodings) - - # create "perfect" output from targets - labels = batch[1]["labels"] - num_classes = len(taskmodule.label_to_id) - # create one-hot encoding from labels - labels_valid = labels.clone() - labels_valid[labels_valid == taskmodule.label_pad_id] = taskmodule.label_to_id["O"] - # create one-hot encoding from labels, but with 0.9 for the correct labels - probabilities = ( - torch.nn.functional.one_hot(labels_valid, num_classes=num_classes).to(torch.float32) * 0.9 - ) - # stepwise decrease the "winning" probabilities per token to test the different combine_token_scores_methods - diff = 0.0 - for i in range(probabilities.size(1)): - probabilities[:, i] -= diff - diff += 0.01 - probabilities[probabilities < 0] = 0.0 - - model_output = {"labels": labels, "probabilities": probabilities} - - unbatched_outputs = taskmodule.unbatch_output(model_output) - - if combine_token_scores_method == "UNKNOWN": - with pytest.raises(ValueError) as excinfo: - taskmodule.decode_annotations(unbatched_outputs[0]) - assert str(excinfo.value) == "combine_token_scores_method=UNKNOWN is not supported." - else: - annotations = [] - scores = [] - for unbatched_output in unbatched_outputs: - decoded_annotations = taskmodule.decode_annotations(unbatched_output) - assert set(decoded_annotations.keys()) == {"labeled_spans"} - # Sort the annotations in each document by start and end position and label - sorted_annotations = sorted(decoded_annotations["labeled_spans"]) - annotations.append(sorted_annotations) - scores.append([round(ann.score, 5) for ann in sorted_annotations]) - - # input values are (before combination): [[0.89, 0.88], [[0.89], [0.84]]] - if combine_token_scores_method == "mean": - assert scores == [[(0.89 + 0.88) / 2], [0.89, 0.84]] - elif combine_token_scores_method == "max": - assert scores == [[0.89], [0.89, 0.84]] - elif combine_token_scores_method == "min": - assert scores == [[0.88], [0.89, 0.84]] - elif combine_token_scores_method == "product": - assert scores == [[0.89 * 0.88], [0.89, 0.84]] - else: - raise ValueError(f"unknown combine_token_scores_method: {combine_token_scores_method}") - - -def test_unbatched_output(unbatched_outputs, config): - assert unbatched_outputs is not None - - if config == CONFIG_DEFAULT: - assert len(unbatched_outputs) == 2 - torch.testing.assert_close( - unbatched_outputs[0]["labels"], - torch.tensor([-100, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, -100]), - ) - torch.testing.assert_close( - unbatched_outputs[1]["labels"], - torch.tensor([-100, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, -100]), - ) - elif config == CONFIG_MAX_WINDOW_WITH_STRIDE: - assert len(unbatched_outputs) == 4 - torch.testing.assert_close( - unbatched_outputs[0]["labels"], torch.tensor([-100, 1, 2, 0, 0, 0, 0, -100]) - ) - torch.testing.assert_close( - unbatched_outputs[1]["labels"], torch.tensor([-100, 0, 0, 0, 0, 0, 0, -100]) - ) - torch.testing.assert_close( - unbatched_outputs[2]["labels"], torch.tensor([-100, 3, 0, 0, 0, 0, 3, -100]) - ) - torch.testing.assert_close( - unbatched_outputs[3]["labels"], torch.tensor([-100, 0, 3, 0, 0, 0, 0, -100]) - ) - elif config == CONFIG_MAX_WINDOW: - assert len(unbatched_outputs) == 4 - torch.testing.assert_close( - unbatched_outputs[0]["labels"], torch.tensor([-100, 1, 2, 0, 0, 0, 0, -100]) - ) - torch.testing.assert_close( - unbatched_outputs[1]["labels"], torch.tensor([-100, 0, 0, 0, 0, -100, -100, -100]) - ) - torch.testing.assert_close( - unbatched_outputs[2]["labels"], torch.tensor([-100, 3, 0, 0, 0, 0, 3, -100]) - ) - torch.testing.assert_close( - unbatched_outputs[3]["labels"], torch.tensor([-100, 0, 0, 0, 0, -100, -100, -100]) - ) - elif config == CONFIG_PARTITIONS: - assert len(unbatched_outputs) == 1 - torch.testing.assert_close( - unbatched_outputs[0]["labels"], torch.tensor([-100, 3, 0, 0, 0, 0, -100]) - ) - else: - raise ValueError(f"unknown config: {config}") - - -def test_decode_annotations(taskmodule, unbatched_outputs, config): - annotations = [] - for unbatched_output in unbatched_outputs: - decoded_annotations = taskmodule.decode_annotations(unbatched_output) - assert set(decoded_annotations.keys()) == {"labeled_spans"} - # Sort the annotations in each document by start and end position and label - annotations.append( - sorted( - decoded_annotations["labeled_spans"], - key=lambda labeled_span: ( - labeled_span.start, - labeled_span.end, - labeled_span.label, - ), - ) - ) - - # Check based on the config - if config == CONFIG_DEFAULT: - assert annotations == [ - [LabeledSpan(start=1, end=3, label="LOC")], - [ - LabeledSpan(start=1, end=2, label="PER"), - LabeledSpan(start=6, end=7, label="PER"), - ], - ] - - elif config == CONFIG_MAX_WINDOW_WITH_STRIDE: - # We get two annotations for Bob because the window overlaps with the previous one. - # This is not a problem because annotations get de-duplicated during serialization. - assert annotations == [ - [LabeledSpan(start=1, end=3, label="LOC")], - [], - [ - LabeledSpan(start=1, end=2, label="PER"), - LabeledSpan(start=6, end=7, label="PER"), - ], - [LabeledSpan(start=2, end=3, label="PER")], - ] - - elif config == CONFIG_MAX_WINDOW: - assert annotations == [ - [LabeledSpan(start=1, end=3, label="LOC")], - [], - [ - LabeledSpan(start=1, end=2, label="PER"), - LabeledSpan(start=6, end=7, label="PER"), - ], - [], - ] - - elif config == CONFIG_PARTITIONS: - assert annotations == [[LabeledSpan(start=1, end=2, label="PER", score=1.0)]] - - else: - raise ValueError(f"unknown config: {config}") - - # assert that all scores are 0.9 - for doc_annotations in annotations: - for annotation in doc_annotations: - assert round(annotation.score, 4) == 0.9 - - -@pytest.fixture(scope="module") -def annotations_from_output(taskmodule, task_encodings_for_batch, unbatched_outputs, config): - named_annotations_per_document = defaultdict(list) - for task_encoding, task_output in zip(task_encodings_for_batch, unbatched_outputs): - annotations = taskmodule.create_annotations_from_output(task_encoding, task_output) - named_annotations_per_document[task_encoding.document.id].extend(list(annotations)) - return named_annotations_per_document - - -def test_annotations_from_output(annotations_from_output, config, documents): - assert annotations_from_output is not None - # Sort the annotations in each document by start and end positions - annotations_from_output = { - doc_id: sorted(annotations, key=lambda x: (x[0], x[1].start, x[1].end)) - for doc_id, annotations in annotations_from_output.items() - } - documents_by_id = {doc.id: doc for doc in documents} - documents_with_annotations = [] - resolved_annotations = defaultdict(list) - # Check that the number of annotations is correct - for doc_id, layer_names_and_annotations in annotations_from_output.items(): - new_doc = documents_by_id[doc_id].copy() - for layer_name, annotation in layer_names_and_annotations: - assert layer_name == "entities" - assert isinstance(annotation, LabeledSpan) - new_doc.entities.predictions.append(annotation) - resolved_annotations[doc_id].append(str(annotation)) - documents_with_annotations.append(new_doc) - - resolved_annotations = dict(resolved_annotations) - # Check based on the config - if config == CONFIG_DEFAULT: - assert resolved_annotations == {"doc1": ["Mount Everest"], "doc2": ["Alice", "Bob"]} - - elif config == CONFIG_MAX_WINDOW_WITH_STRIDE: - # We get two annotations for Bob because the window overlaps with the previous one. - # This is not a problem because annotations get de-duplicated during serialization. - assert resolved_annotations == {"doc1": ["Mount Everest"], "doc2": ["Alice", "Bob", "Bob"]} - - elif config == CONFIG_MAX_WINDOW: - assert resolved_annotations == {"doc1": ["Mount Everest"], "doc2": ["Alice", "Bob"]} - - elif config == CONFIG_PARTITIONS: - assert resolved_annotations == {"doc2": ["Bob"]} - - else: - raise ValueError(f"unknown config: {config}") - - -def test_document_type(): - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule( - tokenizer_name_or_path="bert-base-uncased" - ) - assert taskmodule.document_type == TextDocumentWithLabeledSpans - - -def test_document_type_with_partitions(): - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule( - tokenizer_name_or_path="bert-base-uncased", partition_annotation="labeled_partitions" - ) - assert taskmodule.document_type == TextDocumentWithLabeledSpansAndLabeledPartitions - - -def test_document_type_with_non_default_span_annotation(caplog): - with caplog.at_level(logging.WARNING): - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule( - tokenizer_name_or_path="bert-base-uncased", span_annotation="entities" - ) - assert taskmodule.document_type is None - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert ( - caplog.records[0].message - == "span_annotation=entities is not the default value ('labeled_spans'), so the taskmodule " - "LabeledSpanExtractionByTokenClassificationTaskModule can not request the usual document type " - "(TextDocumentWithLabeledSpans) for auto-conversion because this has the bespoken default value " - "as layer name(s) instead of the provided one(s)." - ) - - -def test_document_type_with_non_default_partition_annotation(caplog): - with caplog.at_level(logging.WARNING): - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule( - tokenizer_name_or_path="bert-base-uncased", partition_annotation="sentences" - ) - assert taskmodule.document_type is None - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert ( - caplog.records[0].message - == "partition_annotation=sentences is not the default value ('labeled_partitions'), " - "so the taskmodule LabeledSpanExtractionByTokenClassificationTaskModule can not request the usual document type " - "(TextDocumentWithLabeledSpansAndLabeledPartitions) for auto-conversion because this has " - "the bespoken default value as layer name(s) instead of the provided one(s)." - ) - - -def test_document_type_with_non_default_span_and_partition_annotation(caplog): - with caplog.at_level(logging.WARNING): - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule( - tokenizer_name_or_path="bert-base-uncased", - span_annotation="entities", - partition_annotation="sentences", - ) - assert taskmodule.document_type is None - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert ( - caplog.records[0].message - == "span_annotation=entities is not the default value ('labeled_spans') and " - "partition_annotation=sentences is not the default value ('labeled_partitions'), " - "so the taskmodule LabeledSpanExtractionByTokenClassificationTaskModule can not request the usual document " - "type (TextDocumentWithLabeledSpansAndLabeledPartitions) for auto-conversion because " - "this has the bespoken default value as layer name(s) instead of the provided one(s)." - ) - - -def test_configure_model_metric(documents): - taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule( - tokenizer_name_or_path="bert-base-uncased", - span_annotation="entities", - labels=["LOC", "PER"], - ) - taskmodule.post_prepare() - - metric = taskmodule.configure_model_metric(stage="test") - values = metric.compute() - assert values == { - "token/macro/f1": tensor(0.0), - "token/micro/f1": tensor(0.0), - "token/macro/precision": tensor(0.0), - "token/macro/recall": tensor(0.0), - "token/micro/precision": tensor(0.0), - "token/micro/recall": tensor(0.0), - } - - batch = taskmodule.collate(taskmodule.encode(documents, encode_target=True)) - targets = batch[1] - metric.update(targets, targets) - values = metric.compute() - assert values == { - "span/LOC/f1": tensor(1.0), - "span/LOC/precision": tensor(1.0), - "span/LOC/recall": tensor(1.0), - "span/PER/f1": tensor(1.0), - "span/PER/precision": tensor(1.0), - "span/PER/recall": tensor(1.0), - "span/macro/f1": tensor(1.0), - "span/macro/precision": tensor(1.0), - "span/macro/recall": tensor(1.0), - "span/micro/f1": tensor(1.0), - "span/micro/precision": tensor(1.0), - "span/micro/recall": tensor(1.0), - "token/macro/f1": tensor(1.0), - "token/micro/f1": tensor(1.0), - "token/macro/precision": tensor(1.0), - "token/macro/recall": tensor(1.0), - "token/micro/precision": tensor(1.0), - "token/micro/recall": tensor(1.0), - } - - target_labels = targets["labels"] - predicted_labels = torch.ones_like(target_labels) - # we need to set the same padding as in the targets - predicted_labels[target_labels == taskmodule.label_pad_id] = taskmodule.label_pad_id - prediction = {"labels": predicted_labels} - metric.update(prediction, targets) - values = metric.compute() - values_converted = {k: v.item() for k, v in values.items()} - assert values_converted == { - "token/macro/f1": 0.5434783101081848, - "token/micro/f1": 0.5249999761581421, - "token/macro/precision": 0.773809552192688, - "token/macro/recall": 0.625, - "token/micro/precision": 0.5249999761581421, - "token/micro/recall": 0.5249999761581421, - "span/LOC/recall": 0.0476190485060215, - "span/LOC/precision": 0.5, - "span/LOC/f1": 0.08695652335882187, - "span/macro/f1": 0.37681159377098083, - "span/macro/precision": 0.5, - "span/macro/recall": 0.523809552192688, - "span/micro/recall": 0.1304347813129425, - "span/micro/precision": 0.5, - "span/micro/f1": 0.2068965584039688, - "span/PER/recall": 1.0, - "span/PER/precision": 0.5, - "span/PER/f1": 0.6666666865348816, - } - - # ensure that the metric can be pickled - pickle.dumps(metric) diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py deleted file mode 100644 index 5251b8ba8..000000000 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ /dev/null @@ -1,1313 +0,0 @@ -import logging -import pickle -from dataclasses import asdict, dataclass -from typing import Dict, List, Set - -import pytest -import torch -from pie_core import AnnotationLayer, Document, annotation_field -from transformers import LogitsProcessorList - -from pie_modules.annotations import BinaryRelation, LabeledSpan -from pie_modules.documents import TextBasedDocument -from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE -from pie_modules.taskmodules.pointer_network.logits_processor import ( - FinitizeLogitsProcessor, - PrefixConstrainedLogitsProcessorWithMaximum, -) -from pie_modules.taskmodules.pointer_network_for_end2end_re import ( - LabelsAndOptionalConstraints, -) - -logger = logging.getLogger(__name__) - -DUMP_FIXTURE_DATA = False - - -def _config_to_str(cfg: Dict[str, str]) -> str: - result = "-".join([f"{k}={cfg[k]}" for k in sorted(cfg)]) - return result - - -CONFIGS = [{}, {"partition_layer_name": "sentences"}] -CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} - - -@pytest.fixture(scope="module", params=CONFIG_DICT.keys()) -def config_str(request): - return request.param - - -@pytest.fixture(scope="module") -def config(config_str): - return CONFIG_DICT[config_str] - - -@dataclass -class ExampleDocument(TextBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") - sentences: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - - -@pytest.fixture(scope="module") -def document(): - doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") - span1 = LabeledSpan(start=10, end=20, label="content") - span2 = LabeledSpan(start=27, end=34, label="topic") - span3 = LabeledSpan(start=42, end=44, label="person") - doc.entities.extend([span1, span2, span3]) - assert str(span1) == "dummy text" - assert str(span2) == "nothing" - assert str(span3) == "me" - rel = BinaryRelation(head=span1, tail=span2, label="is_about") - doc.relations.append(rel) - assert str(rel.label) == "is_about" - assert str(rel.head) == "dummy text" - assert str(rel.tail) == "nothing" - - no_rel = BinaryRelation(head=span1, tail=span3, label="no_relation") - doc.relations.append(no_rel) - assert str(no_rel.label) == "no_relation" - assert str(no_rel.head) == "dummy text" - assert str(no_rel.tail) == "me" - - sent1 = LabeledSpan(start=0, end=35, label="1") - sent2 = LabeledSpan(start=36, end=45, label="2") - doc.sentences.extend([sent1, sent2]) - assert str(sent1) == "This is a dummy text about nothing." - assert str(sent2) == "Trust me." - return doc - - -def test_document(document): - spans = document.entities - assert len(spans) == 3 - assert (str(spans[0]), spans[0].label) == ("dummy text", "content") - assert (str(spans[1]), spans[1].label) == ("nothing", "topic") - assert (str(spans[2]), spans[2].label) == ("me", "person") - relations = document.relations - assert len(relations) == 2 - assert (str(relations[0].head), relations[0].label, str(relations[0].tail)) == ( - "dummy text", - "is_about", - "nothing", - ) - assert (str(relations[1].head), relations[1].label, str(relations[1].tail)) == ( - "dummy text", - "no_relation", - "me", - ) - sentences = document.sentences - assert len(sentences) == 2 - assert str(sentences[0]) == "This is a dummy text about nothing." - assert str(sentences[1]) == "Trust me." - - -@pytest.fixture(scope="module") -def taskmodule(document, config): - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - relation_layer_name="relations", - exclude_labels_per_layer={"relations": ["no_relation"]}, - annotation_field_mapping={ - "entities": "labeled_spans", - "relations": "binary_relations", - }, - create_constraints=True, - tokenizer_kwargs={"strict_span_conversion": False}, - **config, - ) - - taskmodule.prepare(documents=[document]) - return taskmodule - - -def test_taskmodule(taskmodule): - assert taskmodule.is_prepared - assert taskmodule.prepared_attributes == { - "labels_per_layer": { - "entities": ["content", "person", "topic"], - "relations": ["is_about"], - }, - } - assert taskmodule.layer_names == ["entities", "relations"] - assert taskmodule.special_targets == ["", ""] - assert taskmodule.labels == ["none", "content", "person", "topic", "is_about"] - assert taskmodule.targets == [ - "", - "", - "none", - "content", - "person", - "topic", - "is_about", - ] - assert taskmodule.bos_id == 0 - assert taskmodule.eos_id == 1 - assert taskmodule.none_id == 2 - assert taskmodule.span_ids == [3, 4, 5] - assert taskmodule.relation_ids == [6] - assert taskmodule.label2id == { - "content": 3, - "is_about": 6, - "none": 2, - "person": 4, - "topic": 5, - } - assert taskmodule.label_embedding_weight_mapping == { - 50265: [45260], - 50266: [39763], - 50267: [354, 1215, 9006], - 50268: [5970], - 50269: [10166], - } - assert taskmodule.target_tokens == [ - "", - "", - "<>", - "<>", - "<>", - "<>", - "<>", - ] - assert taskmodule.target_token_ids == [0, 2, 50266, 50269, 50268, 50265, 50267] - - -def test_taskmodule_with_wrong_annotation_field_mapping(): - with pytest.raises(ValueError) as exc_info: - PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - relation_layer_name="relations", - annotation_field_mapping={ - "entities": "labeled_spans", - "sentences": "labeled_spans", - }, - ) - assert str(exc_info.value) == ( - "inverted annotation_field_mapping is not unique. annotation_field_mapping: " - "{'entities': 'labeled_spans', 'sentences': 'labeled_spans'}" - ) - - -def test_prepared_config(taskmodule, config): - if config == {}: - assert taskmodule._config() == { - "taskmodule_type": "PointerNetworkTaskModuleForEnd2EndRE", - "relation_layer_name": "relations", - "symmetric_relations": None, - "none_label": "none", - "loop_dummy_relation_name": "loop", - "labels_per_layer": { - "entities": ["content", "person", "topic"], - "relations": ["is_about"], - }, - "exclude_labels_per_layer": {"relations": ["no_relation"]}, - "create_constraints": True, - "document_type": "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions", - "tokenized_document_type": "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions", - "tokenizer_name_or_path": "facebook/bart-base", - "tokenizer_init_kwargs": None, - "tokenizer_kwargs": {"strict_span_conversion": False}, - "partition_layer_name": None, - "add_reversed_relations": False, - "annotation_field_mapping": { - "entities": "labeled_spans", - "relations": "binary_relations", - }, - "constrained_generation": False, - "label_tokens": None, - "label_representations": None, - "log_first_n_examples": None, - } - elif config == {"partition_layer_name": "sentences"}: - assert taskmodule._config() == { - "taskmodule_type": "PointerNetworkTaskModuleForEnd2EndRE", - "relation_layer_name": "relations", - "symmetric_relations": None, - "none_label": "none", - "loop_dummy_relation_name": "loop", - "labels_per_layer": { - "entities": ["content", "person", "topic"], - "relations": ["is_about"], - }, - "exclude_labels_per_layer": {"relations": ["no_relation"]}, - "create_constraints": True, - "document_type": "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions", - "tokenized_document_type": "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions", - "tokenizer_name_or_path": "facebook/bart-base", - "tokenizer_init_kwargs": None, - "tokenizer_kwargs": {"strict_span_conversion": False}, - "partition_layer_name": "sentences", - "add_reversed_relations": False, - "annotation_field_mapping": { - "entities": "labeled_spans", - "relations": "binary_relations", - }, - "constrained_generation": False, - "label_tokens": None, - "label_representations": None, - "log_first_n_examples": None, - } - else: - raise Exception(f"unknown config: {config}") - - -@pytest.fixture() -def task_encoding_without_target(taskmodule, document): - return taskmodule.encode_input(document)[0] - - -def test_add_reversed_relation_labels(): - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - symmetric_relations=["symmetric_relation"], - ) - - labels = ["is_about", "symmetric_relation"] - labels_with_reversed = taskmodule.add_reversed_relation_labels(labels) - assert labels_with_reversed == {"is_about", "is_about_reversed", "symmetric_relation"} - - -def test_reverse_relation(): - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - symmetric_relations=["symmetric_relation"], - ) - - rel = BinaryRelation( - head=LabeledSpan(start=10, end=20, label="content"), - tail=LabeledSpan(start=27, end=34, label="topic"), - label="is_about", - ) - reversed_relation = taskmodule.reverse_relation(relation=rel) - assert reversed_relation == BinaryRelation( - head=LabeledSpan(start=27, end=34, label="topic", score=1.0), - tail=LabeledSpan(start=10, end=20, label="content", score=1.0), - label="is_about_reversed", - score=1.0, - ) - - sym_rel = BinaryRelation( - head=LabeledSpan(start=10, end=20, label="content"), - tail=LabeledSpan(start=27, end=34, label="topic"), - label="symmetric_relation", - ) - reversed_sym_rel = taskmodule.reverse_relation(relation=sym_rel) - assert reversed_sym_rel == BinaryRelation( - head=LabeledSpan(start=27, end=34, label="topic", score=1.0), - tail=LabeledSpan(start=10, end=20, label="content", score=1.0), - label="symmetric_relation", - score=1.0, - ) - - -def test_unreverse_relation(): - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - symmetric_relations=["symmetric_relation"], - ) - - # nothing should change because the relation is not reversed - rel = BinaryRelation( - head=LabeledSpan(start=10, end=20, label="content"), - tail=LabeledSpan(start=27, end=34, label="topic"), - label="is_about", - ) - same_rel = taskmodule.unreverse_relation(relation=rel) - assert same_rel == rel - - # the relation is reversed, so it should be un-reversed - reversed_rel = BinaryRelation( - head=LabeledSpan(start=10, end=20, label="content"), - tail=LabeledSpan(start=27, end=34, label="topic"), - label="is_about_reversed", - ) - unreversed_relation = taskmodule.unreverse_relation(relation=reversed_rel) - assert unreversed_relation == BinaryRelation( - head=LabeledSpan(start=27, end=34, label="topic", score=1.0), - tail=LabeledSpan(start=10, end=20, label="content", score=1.0), - label="is_about", - score=1.0, - ) - - # nothing should change because the relation is symmetric and already ordered (head < tail) - ordered_sym_rel = BinaryRelation( - head=LabeledSpan(start=10, end=20, label="content"), - tail=LabeledSpan(start=27, end=34, label="topic"), - label="symmetric_relation", - ) - unreversed_ordered_sym_rel = taskmodule.unreverse_relation(relation=ordered_sym_rel) - assert ordered_sym_rel == unreversed_ordered_sym_rel - - # the relation is symmetric and unordered (head > tail), so it should be un-reversed - unordered_sym_rel = BinaryRelation( - head=LabeledSpan(start=27, end=34, label="topic"), - tail=LabeledSpan(start=10, end=20, label="content"), - label="symmetric_relation", - ) - unreversed_unordered_sym_rel = taskmodule.unreverse_relation(relation=unordered_sym_rel) - assert unreversed_unordered_sym_rel == BinaryRelation( - head=LabeledSpan(start=10, end=20, label="content", score=1.0), - tail=LabeledSpan(start=27, end=34, label="topic", score=1.0), - label="symmetric_relation", - score=1.0, - ) - - -@pytest.fixture(params=[False, True]) -def taskmodule_with_reversed_relations(document, request) -> PointerNetworkTaskModuleForEnd2EndRE: - is_about_is_symmetric = request.param - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - relation_layer_name="relations", - exclude_labels_per_layer={"relations": ["no_relation"]}, - annotation_field_mapping={ - "entities": "labeled_spans", - "relations": "binary_relations", - }, - create_constraints=True, - tokenizer_kwargs={"strict_span_conversion": False}, - add_reversed_relations=True, - symmetric_relations=["is_about"] if is_about_is_symmetric else None, - ) - - taskmodule.prepare(documents=[document]) - assert taskmodule.is_prepared - if is_about_is_symmetric: - assert taskmodule.prepared_attributes == { - "labels_per_layer": { - "entities": ["content", "person", "topic"], - "relations": ["is_about"], - } - } - else: - assert taskmodule.prepared_attributes == { - "labels_per_layer": { - "entities": ["content", "person", "topic"], - "relations": ["is_about", "is_about_reversed"], - } - } - - return taskmodule - - -def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, document): - task_encodings = taskmodule_with_reversed_relations.encode(document, encode_target=True) - assert len(task_encodings) == 1 - task_encoding = task_encodings[0] - assert task_encoding is not None - assert asdict(task_encoding.inputs) == { - "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2], - "attention_mask": [1] * 13, - } - tokens = taskmodule_with_reversed_relations.tokenizer.convert_ids_to_tokens( - task_encoding.inputs.input_ids - ) - assert tokens == [ - "", - "This", - "Ġis", - "Ġa", - "Ġdummy", - "Ġtext", - "Ġabout", - "Ġnothing", - ".", - "ĠTrust", - "Ġme", - ".", - "", - ] - if "is_about" in taskmodule_with_reversed_relations.symmetric_relations: - decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( - task_encoding.targets - ) - assert decoded_annotations == { - "entities": [ - LabeledSpan(start=4, end=6, label="content", score=1.0), - LabeledSpan(start=7, end=8, label="topic", score=1.0), - LabeledSpan(start=10, end=11, label="person", score=1.0), - ], - "relations": [ - BinaryRelation( - head=LabeledSpan(start=4, end=6, label="content", score=1.0), - tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), - label="is_about", - score=1.0, - ), - BinaryRelation( - head=LabeledSpan(start=7, end=8, label="topic", score=1.0), - tail=LabeledSpan(start=4, end=6, label="content", score=1.0), - label="is_about", - score=1.0, - ), - ], - } - else: - decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations( - task_encoding.targets - ) - assert decoded_annotations == { - "entities": [ - LabeledSpan(start=4, end=6, label="content", score=1.0), - LabeledSpan(start=7, end=8, label="topic", score=1.0), - LabeledSpan(start=10, end=11, label="person", score=1.0), - ], - "relations": [ - BinaryRelation( - head=LabeledSpan(start=4, end=6, label="content", score=1.0), - tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), - label="is_about", - score=1.0, - ), - BinaryRelation( - head=LabeledSpan(start=7, end=8, label="topic", score=1.0), - tail=LabeledSpan(start=4, end=6, label="content", score=1.0), - label="is_about_reversed", - score=1.0, - ), - ], - } - - -def test_encode_with_add_reversed_relations_already_exists(caplog): - doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") - doc.entities.append(LabeledSpan(start=10, end=20, label="content")) - doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) - doc.relations.append( - BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") - ) - doc.relations.append( - BinaryRelation(head=doc.entities[1], tail=doc.entities[0], label="is_about") - ) - - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - relation_layer_name="relations", - annotation_field_mapping={ - "entities": "labeled_spans", - "relations": "binary_relations", - }, - add_reversed_relations=True, - symmetric_relations=["is_about"], - ) - taskmodule.prepare(documents=[doc]) - - with caplog.at_level(logging.WARNING): - task_encodings = taskmodule.encode(doc, encode_target=True) - assert len(caplog.messages) == 0 - assert len(task_encodings) == 1 - task_encoding = task_encodings[0] - - decoded_annotations, statistics = taskmodule.decode_annotations(task_encoding.targets) - assert decoded_annotations == { - "entities": [ - LabeledSpan(start=4, end=6, label="content", score=1.0), - LabeledSpan(start=7, end=8, label="topic", score=1.0), - ], - "relations": [ - BinaryRelation( - head=LabeledSpan(start=4, end=6, label="content", score=1.0), - tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), - label="is_about", - score=1.0, - ), - BinaryRelation( - head=LabeledSpan(start=7, end=8, label="topic", score=1.0), - tail=LabeledSpan(start=4, end=6, label="content", score=1.0), - label="is_about", - score=1.0, - ), - ], - } - - -def test_decode_with_add_reversed_relations(): - doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") - doc.entities.append(LabeledSpan(start=10, end=20, label="content")) - doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) - doc.relations.append( - BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") - ) - - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - relation_layer_name="relations", - annotation_field_mapping={ - "entities": "labeled_spans", - "relations": "binary_relations", - }, - add_reversed_relations=True, - ) - taskmodule.prepare(documents=[doc]) - - task_encodings = taskmodule.encode(doc, encode_target=True) - assert len(task_encodings) == 1 - decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets) - assert decoded_annotations == { - "entities": [ - LabeledSpan(start=4, end=6, label="content", score=1.0), - LabeledSpan(start=7, end=8, label="topic", score=1.0), - ], - "relations": [ - BinaryRelation( - head=LabeledSpan(start=4, end=6, label="content", score=1.0), - tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), - label="is_about", - score=1.0, - ), - BinaryRelation( - head=LabeledSpan(start=7, end=8, label="topic", score=1.0), - tail=LabeledSpan(start=4, end=6, label="content", score=1.0), - label="is_about_reversed", - score=1.0, - ), - ], - } - - task_outputs = [task_encoding.targets for task_encoding in task_encodings] - docs_with_predictions = taskmodule.decode(task_encodings, task_outputs) - assert len(docs_with_predictions) == 1 - doc_with_predictions: ExampleDocument = docs_with_predictions[0] - assert set(doc_with_predictions.entities.predictions) == set(doc_with_predictions.entities) - assert set(doc_with_predictions.relations.predictions) == set(doc_with_predictions.relations) - - -@pytest.fixture() -def target_encoding(taskmodule, task_encoding_without_target): - return taskmodule.encode_target(task_encoding_without_target) - - -def test_target_encoding(target_encoding, taskmodule): - assert target_encoding is not None - if taskmodule.partition_layer_name is None: - assert asdict(target_encoding) == { - "labels": [14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1], - "constraints": [ - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - } - elif taskmodule.partition_layer_name == "sentences": - assert asdict(target_encoding) == { - "labels": [14, 14, 5, 11, 12, 3, 6, 1], - "constraints": [ - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - } - else: - raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}") - - -def test_task_encoding_with_deduplicated_relations(caplog): - doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") - doc.entities.append(LabeledSpan(start=10, end=20, label="content")) - doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) - doc.entities.append(LabeledSpan(start=42, end=44, label="person")) - assert doc.entities.resolve() == [ - ("content", "dummy text"), - ("topic", "nothing"), - ("person", "me"), - ] - # add the same relation twice (just use a different score, but that should not matter) - doc.relations.append( - BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") - ) - doc.relations.append( - BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about", score=0.9) - ) - assert doc.relations.resolve() == [ - ("is_about", (("content", "dummy text"), ("topic", "nothing"))), - ("is_about", (("content", "dummy text"), ("topic", "nothing"))), - ] - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - relation_layer_name="relations", - annotation_field_mapping={ - "entities": "labeled_spans", - "relations": "binary_relations", - }, - ) - taskmodule.prepare(documents=[doc]) - caplog.clear() - with caplog.at_level(logging.WARNING): - task_encodings = taskmodule.encode(doc, encode_target=True) - messages = list(caplog.messages) - - assert len(task_encodings) == 1 - decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets) - assert decoded_annotations == { - "entities": [ - LabeledSpan(start=4, end=6, label="content", score=1.0), - LabeledSpan(start=7, end=8, label="topic", score=1.0), - LabeledSpan(start=10, end=11, label="person", score=1.0), - ], - "relations": [ - BinaryRelation( - head=LabeledSpan(start=4, end=6, label="content", score=1.0), - tail=LabeledSpan(start=7, end=8, label="topic", score=1.0), - label="is_about", - score=1.0, - ) - ], - } - - assert messages == [ - ( - "encoding errors: {'correct': 2}, skipped annotations:\n" - "{\n" - ' "relations": [\n' - ' "BinaryRelation(' - "head=LabeledSpan(start=4, end=6, label='content', score=1.0), " - "tail=LabeledSpan(start=7, end=8, label='topic', score=1.0), " - "label='is_about', score=0.9" - ')"\n' - " ]\n" - "}" - ) - ] - - -def test_task_encoding_with_conflicting_relations(caplog): - doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.") - doc.entities.append(LabeledSpan(start=10, end=20, label="content")) - doc.entities.append(LabeledSpan(start=27, end=34, label="topic")) - doc.entities.append(LabeledSpan(start=42, end=44, label="person")) - assert doc.entities.resolve() == [ - ("content", "dummy text"), - ("topic", "nothing"), - ("person", "me"), - ] - # add two relations with the same head and tail, but different labels - doc.relations.append( - BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about") - ) - doc.relations.append( - BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="wrong_relation") - ) - assert doc.relations.resolve() == [ - ("is_about", (("content", "dummy text"), ("topic", "nothing"))), - ("wrong_relation", (("content", "dummy text"), ("topic", "nothing"))), - ] - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - relation_layer_name="relations", - annotation_field_mapping={ - "entities": "labeled_spans", - "relations": "binary_relations", - }, - ) - taskmodule.prepare(documents=[doc]) - caplog.clear() - with caplog.at_level(logging.ERROR): - task_encodings = taskmodule.encode(doc, encode_target=True) - messages = list(caplog.messages) - - assert len(task_encodings) == 0 - - assert messages == [ - "failed to encode target, it will be skipped: " - "relation ('Ġdummy', 'Ġtext') -> ('Ġnothing',) already exists, but has " - "another label: is_about (current label: wrong_relation)." - ] - - -@pytest.fixture() -def task_encoding(task_encoding_without_target, target_encoding): - task_encoding_without_target.targets = target_encoding - return task_encoding_without_target - - -def _separate_constraint(constraint, taskmodule): - special_ids = sorted(taskmodule.special_target2id.values()) - none_ids = [taskmodule.none_id] - span_ids = taskmodule.span_ids - rel_ids = taskmodule.relation_ids - result = [[constraint[id] for id in ids] for ids in [special_ids, none_ids, span_ids, rel_ids]] - result += [constraint[taskmodule.pointer_offset :]] - assert sum(len(con_part) for con_part in result) == len(constraint) - return result - - -def test_build_constraint(taskmodule): - target_ids = [14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1] - input_len = 13 - - # empty previous_ids - constraint = taskmodule._build_constraint(previous_ids=[], input_len=input_len) - # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] - constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) - # allow eos and all offsets - assert constraint_formatted == [ - [0, 1], - [0], - [0, 0, 0], - [0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - - # just first span start - constraint = taskmodule._build_constraint(previous_ids=[14], input_len=input_len) - # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] - constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) - # allow all offsets after first span start - assert constraint_formatted == [ - [0, 0], - [0], - [0, 0, 0], - [0], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], - ] - - # first span start and end - constraint = taskmodule._build_constraint(previous_ids=[14, 14], input_len=input_len) - # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] - constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) - # allow all span ids - assert constraint_formatted == [ - [0, 0], - [0], - [1, 1, 1], - [0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - - # first span start, end, and label - constraint = taskmodule._build_constraint(previous_ids=[14, 14, 5], input_len=input_len) - # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] - constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) - # allow none and all offsets except offsets covered by first span - assert constraint_formatted == [ - [0, 0], - [1], - [0, 0, 0], - [0], - [1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1], - ] - - # first span, and second span start - constraint = taskmodule._build_constraint(previous_ids=[14, 14, 5, 11], input_len=input_len) - # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] - constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) - # allow all offsets after second span start, but not after first span start - assert constraint_formatted == [ - [0, 0], - [0], - [0, 0, 0], - [0], - [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0], - ] - - # first span, and second span start and end - constraint = taskmodule._build_constraint( - previous_ids=[14, 14, 5, 11, 12], input_len=input_len - ) - # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] - constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) - # allow all span ids - assert constraint_formatted == [ - [0, 0], - [0], - [1, 1, 1], - [0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - - # first span, and second span - constraint = taskmodule._build_constraint( - previous_ids=[14, 14, 5, 11, 12, 3], input_len=input_len - ) - # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] - constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) - # allow all relation ids - assert constraint_formatted == [ - [0, 0], - [0], - [0, 0, 0], - [1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - - # fist span, and (1 to 3)-times none - for i in range(1, 3): - none_ids = [2] * i - constraint = taskmodule._build_constraint( - previous_ids=[14, 14, 5] + none_ids, input_len=input_len - ) - # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] - constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) - # allow only none - assert constraint_formatted == [ - [0, 0], - [1], - [0, 0, 0], - [0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - - # contains eos - constraint = taskmodule._build_constraint( - previous_ids=[14, 14, 5, 11, 12, 3, 6, 1], input_len=input_len - ) - # [bos, eos/pad], [none], [content, person, topic], [is_about] [13 offsets (all remaining)] - constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule) - # allow only pad (same as eos) - assert constraint_formatted == [ - [0, 1], - [0], - [0, 0, 0], - [0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - - -def test_maybe_log_example(taskmodule, task_encoding, caplog, config): - original_log_first_n_examples = taskmodule.log_first_n_examples - taskmodule.log_first_n_examples = 1 - caplog.clear() - with caplog.at_level(logging.INFO): - taskmodule.maybe_log_example(task_encoding) - if config == {}: - assert caplog.messages == [ - "*** Example ***", - "doc.id: None-tokenized-1-of-1", - "input_ids: 0 713 16 10 34759 2788 59 1085 4 3101 162 4 2", - "input_tokens: This Ġis Ġa Ġdummy Ġtext Ġabout Ġnothing . ĠTrust Ġme . " "", - "label_ids: 14 14 5 11 12 3 6 17 17 4 2 2 2 2 1", - "label_tokens: 14 {Ġnothing} 14 {Ġnothing} topic 11 {Ġdummy} 12 {Ġtext} content is_about 17 {Ġme} 17 " - "{Ġme} person none none none none ", - "constraints: torch.Size([15, 20]) (content is omitted)", - ] - elif config == {"partition_layer_name": "sentences"}: - assert caplog.messages == [ - "*** Example ***", - "doc.id: None-tokenized-1-of-2", - "input_ids: 0 713 16 10 34759 2788 59 1085 4 2", - "input_tokens: This Ġis Ġa Ġdummy Ġtext Ġabout Ġnothing . ", - "label_ids: 14 14 5 11 12 3 6 1", - "label_tokens: 14 {Ġnothing} 14 {Ġnothing} topic 11 {Ġdummy} 12 {Ġtext} content is_about ", - "constraints: torch.Size([8, 17]) (content is omitted)", - ] - else: - raise Exception(f"unknown config: {config}") - - # restore original value - taskmodule.log_first_n_examples = original_log_first_n_examples - - -def test_maybe_log_example_disabled(taskmodule, task_encoding, caplog): - original_log_first_n_examples = taskmodule.log_first_n_examples - taskmodule.log_first_n_examples = None - caplog.clear() - with caplog.at_level(logging.INFO): - taskmodule.maybe_log_example(task_encoding) - assert caplog.record_tuples == [] - - # restore original value - taskmodule.log_first_n_examples = original_log_first_n_examples - - -@pytest.fixture() -def task_encodings(taskmodule, document): - return taskmodule.encode(documents=[document], encode_target=True) - - -@pytest.fixture() -def batch(taskmodule, task_encodings): - return taskmodule.collate(task_encodings) - - -def test_collate(batch, taskmodule): - inputs, targets = batch - for tensor in inputs.values(): - assert isinstance(tensor, torch.Tensor) - assert tensor.dtype == torch.int64 - for tensor in targets.values(): - assert isinstance(tensor, torch.Tensor) - assert tensor.dtype == torch.int64 - inputs_lists = {k: inputs[k].tolist() for k in sorted(inputs)} - targets_lists = {k: targets[k].tolist() for k in sorted(targets)} - if taskmodule.partition_layer_name is None: - assert inputs_lists == { - "input_ids": [[0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2]], - "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], - } - assert targets_lists == { - "constraints": [ - [ - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ], - "labels": [[14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1]], - "decoder_attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], - } - elif taskmodule.partition_layer_name == "sentences": - assert inputs_lists == { - "input_ids": [ - [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 2], - [0, 18823, 162, 4, 2, 1, 1, 1, 1, 1], - ], - "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]], - } - assert targets_lists == { - "constraints": [ - [ - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - [ - [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, -1, -1, -1, -1, -1], - [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - [0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, -1, -1, -1, -1, -1], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1], - ], - ], - "labels": [[14, 14, 5, 11, 12, 3, 6, 1], [9, 9, 4, 2, 2, 2, 2, 1]], - "decoder_attention_mask": [ - [1, 1, 1, 1, 1, 1, 1, 1], - [1, 1, 1, 1, 1, 1, 1, 1], - ], - } - else: - raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}") - - -@pytest.fixture() -def unbatched_output(taskmodule, batch): - inputs, targets = batch - # because the model is trained to reproduce the target tokens, we can just use them as model prediction - return taskmodule.unbatch_output(targets) - - -@pytest.fixture() -def task_outputs(unbatched_output): - return unbatched_output - - -@pytest.fixture() -def task_output(task_outputs) -> LabelsAndOptionalConstraints: - return task_outputs[0] - - -def test_task_output(task_output, taskmodule): - output_list = task_output.labels - if taskmodule.partition_layer_name is None: - assert output_list == [14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1] - elif taskmodule.partition_layer_name == "sentences": - assert output_list == [14, 14, 5, 11, 12, 3, 6, 1] - else: - raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}") - - -def _test_annotations_from_output(task_encodings, task_outputs, taskmodule, layer_names_expected): - assert len(task_outputs) == len(task_encodings) - - # this needs to be outside the below loop because documents can contain duplicates - # which would break the comparison when clearing predictions that were already added - for task_encoding in task_encodings: - for layer_name in layer_names_expected: - task_encoding.document[layer_name].predictions.clear() - - layer_names: Set[str] = set() - # Note: this list may contain duplicates! - documents: List[Document] = [] - for i in range(len(task_outputs)): - task_encoding = task_encodings[i] - task_output = task_outputs[i] - documents.append(task_encoding.document) - - for layer_name, annotation in taskmodule.create_annotations_from_output( - task_encoding=task_encoding, task_output=task_output - ): - task_encoding.document[layer_name].predictions.append(annotation) - layer_names.add(layer_name) - - assert layer_names == layer_names_expected - - for document in documents: - for layer_name in layer_names: - layer = { - str(ann) - for ann in document[layer_name].predictions - if ann.label in taskmodule.labels_per_layer[layer_name] - } - layer_expected = { - str(ann) - for ann in document[layer_name] - if ann.label in taskmodule.labels_per_layer[layer_name] - } - assert layer == layer_expected - - # this needs to be outside the above loop because documents can contain duplicates - # which would break the comparison when clearing predictions too early - for document in documents: - for layer_name in layer_names: - document[layer_name].predictions.clear() - - -def test_annotations_from_output(task_encodings, task_outputs, taskmodule): - _test_annotations_from_output( - taskmodule=taskmodule, - task_encodings=task_encodings, - task_outputs=task_outputs, - layer_names_expected={"entities", "relations"}, - ) - - -def get_default_taskmodule(**kwargs): - taskmodule = PointerNetworkTaskModuleForEnd2EndRE( - tokenizer_name_or_path="facebook/bart-base", - labels_per_layer={ - "labeled_spans": ["content", "person", "topic"], - "binary_relations": ["is_about"], - }, - **kwargs, - ) - taskmodule.post_prepare() - return taskmodule - - -def test_configure_model_metric(): - taskmodule = get_default_taskmodule() - metric = taskmodule.configure_model_metric() - assert metric is not None - values = metric.compute() - assert values == { - "binary_relations": {}, - "decoding_errors": {"all": 0.0}, - "exact_encoding_matches": 0.0, - "labeled_spans": {}, - } - - model_output = {"labels": torch.tensor([[14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1]])} - # test with expected == prediction - metric.update(model_output, model_output) - values = metric.compute() - assert values == { - "exact_encoding_matches": 1.0, - "decoding_errors": {"correct": 1.0, "all": 0.0}, - "labeled_spans": { - "content": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - "person": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - "topic": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - "macro": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - "micro": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - }, - "binary_relations": { - "is_about": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - "macro": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - "micro": {"recall": 1.0, "precision": 1.0, "f1": 1.0}, - }, - } - torch.random.manual_seed(42) - # random_labels = torch.randint(0, 20, (1, 30)) - # split into random_labels1 and random_labels2 just for better code formatting - random_labels1 = [0, 14, 4, 19, 2, 6, 18, 3, 0, 8, 8, 14, 2, 1] - random_labels2 = [14, 6, 7, 8, 4, 1, 17, 9, 14, 7, 13, 15, 5, 12, 18, 13] - labels_random = torch.tensor([random_labels1 + random_labels2]) - metric.reset() - # test the case where we have mixed results (correct and wrong) - metric.update(model_output, model_output) - metric.update(prediction={"labels": labels_random}, expected=model_output) - values = metric.compute() - assert values == { - "exact_encoding_matches": 0.5, - "decoding_errors": {"correct": 0.5, "len": 0.25, "order": 0.25, "all": 0.5}, - "labeled_spans": { - "person": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816}, - "topic": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816}, - "content": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816}, - "macro": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816}, - "micro": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816}, - }, - "binary_relations": { - "is_about": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816}, - "macro": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816}, - "micro": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816}, - }, - } - - # ensure that the metric can be pickled - pickle.dumps(metric) - - -def test_configure_model_generation(): - taskmodule = get_default_taskmodule() - assert taskmodule.configure_model_generation() == { - "no_repeat_ngram_size": 7, - } - - -def test_configure_model_generation_with_constrained_generation(): - taskmodule = get_default_taskmodule(constrained_generation=True) - generation_config = taskmodule.configure_model_generation() - assert set(generation_config) == {"no_repeat_ngram_size", "logits_processor"} - assert generation_config["no_repeat_ngram_size"] == 7 - logits_processor = generation_config["logits_processor"] - assert isinstance(logits_processor, LogitsProcessorList) - assert len(logits_processor) == 2 - assert isinstance(logits_processor[0], FinitizeLogitsProcessor) - assert isinstance(logits_processor[1], PrefixConstrainedLogitsProcessorWithMaximum) - - -def test_prefix_allowed_tokens_fn_with_maximum(): - taskmodule = get_default_taskmodule() - # not that this includes the leading bos token - add_previous_input_ids = torch.tensor([0, 14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1]) - - # empty input (first entry) - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:1], maximum=20 - ) - # allow the eos id [1] and all offset ids [7..19] - assert allowed_ids == [1, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] - - # first span start - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:2], maximum=20 - ) - # allow all offset ids from first span start [14..19] - assert allowed_ids == [14, 15, 16, 17, 18, 19] - - # first span start and end - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:3], maximum=20 - ) - # allow all span ids - assert allowed_ids == [3, 4, 5] - - # first span start, end, and label - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:4], maximum=20 - ) - # allow none [2] and all offsets except offsets covered by first span [14] - assert allowed_ids == [2, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19] - - # first span, and second span start - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:5], maximum=20 - ) - # allow all offsets from second span start [11], but before first span start [14] because it would be an overlap - assert allowed_ids == [11, 12, 13] - - # first span, and second span start and end - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:6], maximum=20 - ) - # allow all span ids - assert allowed_ids == [3, 4, 5] - - # first span, and second span - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:7], maximum=20 - ) - # allow all relation ids - assert allowed_ids == [6] - - # entry begins (second entry) - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:8], maximum=20 - ) - # allow eos [1] and all offsets [7..19] - assert allowed_ids == [1, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] - - # first span start - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:9], maximum=20 - ) - # allow all offsets from first span start [17..19] - assert allowed_ids == [17, 18, 19] - - # first span start and end - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:10], maximum=20 - ) - # allow all span ids - assert allowed_ids == [3, 4, 5] - - # first span start, end, and span label - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:11], maximum=20 - ) - # allow none [2] and all offsets except offsets covered by first span [17] - assert allowed_ids == [2, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19] - - # first span, and none - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:12], maximum=20 - ) - # allow only none [2] because when the entry contains already a none id, it cannot be followed by anything else - assert allowed_ids == [2] - - # first span, and none, and none - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:13], maximum=20 - ) - # allow only none [2] because when the entry contains already a none id, it cannot be followed by anything else - assert allowed_ids == [2] - - # first span, and none, and none, and none - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:14], maximum=20 - ) - # allow only none [2] because when the entry contains already a none id, it cannot be followed by anything else - assert allowed_ids == [2] - - # first span, and none, and none, and none, and none (second entry is complete) - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:15], maximum=20 - ) - # allow eos [1] and all offsets [7..19] - assert allowed_ids == [1, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] - - # got an eos, so the sequence is complete - allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum( - batch_id=0, input_ids=add_previous_input_ids[:16], maximum=20 - ) - # allow only pad [1] (same as eos) because the sequence is complete - assert allowed_ids == [1] diff --git a/tests/taskmodules/test_re_span_pair_classification.py b/tests/taskmodules/test_re_span_pair_classification.py deleted file mode 100644 index f82554c1e..000000000 --- a/tests/taskmodules/test_re_span_pair_classification.py +++ /dev/null @@ -1,614 +0,0 @@ -import dataclasses -import logging -from typing import Any, Dict, Union - -import pytest -import torch -from pie_core import AnnotationLayer, annotation_field -from pie_core.utils.dictionary import flatten_dict_s -from torch import tensor -from torchmetrics import Metric, MetricCollection - -from pie_modules.annotations import BinaryRelation, LabeledSpan -from pie_modules.documents import TextBasedDocument -from pie_modules.taskmodules import RESpanPairClassificationTaskModule -from pie_modules.utils.span import distance -from tests import _config_to_str - -TOKENIZER_NAME_OR_PATH = "bert-base-cased" - -CONFIGS = [{}, {"partition_annotation": "sentences"}] -CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} - - -@pytest.fixture(scope="module", params=CONFIGS_DICT.keys()) -def cfg(request): - return CONFIGS_DICT[request.param] - - -@pytest.fixture(scope="module") -def unprepared_taskmodule(cfg): - taskmodule = RESpanPairClassificationTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=TOKENIZER_NAME_OR_PATH, - log_first_n_examples=10, - collect_statistics=True, - **cfg, - ) - assert not taskmodule.is_from_pretrained - - return taskmodule - - -@dataclasses.dataclass -class FixedTestDocument(TextBasedDocument): - sentences: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") - - -@pytest.fixture(scope="module") -def fixed_documents(documents): - result = [] - for document in documents: - fixed_doc = document.copy(with_annotations=False).as_type(FixedTestDocument) - for sentence in document.sentences: - fixed_doc.sentences.append( - LabeledSpan(start=sentence.start, end=sentence.end, label="sentence") - ) - entity_mapping = {} - for entity in document.entities: - new_entity = entity.copy() - fixed_doc.entities.append(new_entity) - entity_mapping[entity] = new_entity - for relation in document.relations: - new_relation = relation.copy( - head=entity_mapping[relation.head], tail=entity_mapping[relation.tail] - ) - fixed_doc.relations.append(new_relation) - result.append(fixed_doc) - return result - - -@pytest.fixture(scope="module") -def taskmodule(unprepared_taskmodule, fixed_documents) -> RESpanPairClassificationTaskModule: - unprepared_taskmodule.prepare(fixed_documents) - return unprepared_taskmodule - - -def test_taskmodule(taskmodule: RESpanPairClassificationTaskModule): - assert taskmodule.is_prepared - - assert taskmodule.relation_annotation == "relations" - assert taskmodule.labels == ["org:founded_by", "per:employee_of", "per:founder"] - assert taskmodule.entity_labels == ["ORG", "PER"] - assert taskmodule.label_to_id == { - "org:founded_by": 1, - "per:employee_of": 2, - "per:founder": 3, - "no_relation": 0, - } - assert taskmodule.argument_markers == [ - "[/SPAN:ORG]", - "[/SPAN:PER]", - "[SPAN:ORG]", - "[SPAN:PER]", - ] - assert taskmodule.tokenizer.additional_special_tokens == [ - "[SPAN:PER]", - "[/SPAN:ORG]", - "[/SPAN:PER]", - "[SPAN:ORG]", - ] - assert taskmodule.tokenizer.additional_special_tokens_ids == [28996, 28997, 28998, 28999] - - # because this is not the standard value for relation_annotation, we can not determine the document type - assert taskmodule.document_type is None - - -@pytest.fixture(scope="module") -def document(fixed_documents): - result = fixed_documents[4] - assert ( - result.metadata["description"] - == "sentences with multiple relation annotations and cross-sentence relation" - ) - return result - - -def test_create_candidate_relations(taskmodule, document): - # _create_candidate_relations requires normalized documents - normalized_document = taskmodule.normalize_document(document) - candidate_relations = taskmodule._create_candidate_relations(normalized_document) - resolved_relations = [ann.resolve() for ann in candidate_relations] - assert resolved_relations == [ - ("no_relation", (("PER", "Entity G"), ("ORG", "H"))), - ("no_relation", (("PER", "Entity G"), ("ORG", "I"))), - ("no_relation", (("ORG", "H"), ("PER", "Entity G"))), - ("no_relation", (("ORG", "H"), ("ORG", "I"))), - ("no_relation", (("ORG", "I"), ("PER", "Entity G"))), - ("no_relation", (("ORG", "I"), ("ORG", "H"))), - ] - - -def test_create_candidate_relations_with_max_distance(taskmodule, document): - # _create_candidate_relations requires normalized documents - normalized_document = taskmodule.normalize_document(document) - candidate_relations = taskmodule._create_candidate_relations( - normalized_document, max_argument_distance=10 - ) - resolved_relations = [ann.resolve() for ann in candidate_relations] - assert resolved_relations == [ - ("no_relation", (("PER", "Entity G"), ("ORG", "H"))), - ("no_relation", (("ORG", "H"), ("PER", "Entity G"))), - ] - distances = [ - distance( - start_end=(rel.head.start, rel.head.end), - other_start_end=(rel.tail.start, rel.tail.end), - distance_type="inner", - ) - for rel in candidate_relations - ] - assert distances == [10.0, 10.0] - - -@pytest.fixture(scope="module") -def task_encodings(taskmodule, document): - result = taskmodule.encode(document, encode_target=True) - return result - - -def test_encode_input(task_encodings, document, taskmodule, cfg): - assert task_encodings is not None - if cfg == {}: - assert len(task_encodings) == 1 - inputs = task_encodings[0].inputs - assert set(inputs) == { - "input_ids", - "attention_mask", - "span_start_indices", - "span_end_indices", - "tuple_indices", - "tuple_indices_mask", - } - tokens = taskmodule.tokenizer.convert_ids_to_tokens(inputs["input_ids"]) - assert tokens == [ - "[CLS]", - "First", - "sentence", - ".", - "[SPAN:PER]", - "En", - "##ti", - "##ty", - "G", - "[/SPAN:PER]", - "works", - "at", - "[SPAN:ORG]", - "H", - "[/SPAN:ORG]", - ".", - "And", - "founded", - "[SPAN:ORG]", - "I", - "[/SPAN:ORG]", - ".", - "[SEP]", - ] - span_tokens = [ - tokens[start:end] - for start, end in zip(inputs["span_start_indices"], inputs["span_end_indices"]) - ] - assert span_tokens == [ - ["[SPAN:PER]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"], - ["[SPAN:ORG]", "H", "[/SPAN:ORG]"], - ["[SPAN:ORG]", "I", "[/SPAN:ORG]"], - ] - tuple_spans = [ - [span_tokens[idx] for idx in indices] for indices in inputs["tuple_indices"] - ] - assert tuple_spans == [ - [ - ["[SPAN:PER]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"], - ["[SPAN:ORG]", "H", "[/SPAN:ORG]"], - ], - [ - ["[SPAN:PER]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"], - ["[SPAN:ORG]", "I", "[/SPAN:ORG]"], - ], - [["[SPAN:ORG]", "I", "[/SPAN:ORG]"], ["[SPAN:ORG]", "H", "[/SPAN:ORG]"]], - ] - assert inputs["tuple_indices_mask"].tolist() == [True, True, True] - elif cfg == {"partition_annotation": "sentences"}: - assert len(task_encodings) == 1 - for idx, encoding in enumerate(task_encodings): - inputs = encoding.inputs - assert set(inputs) == { - "input_ids", - "attention_mask", - "span_start_indices", - "span_end_indices", - "tuple_indices", - "tuple_indices_mask", - } - tokens = taskmodule.tokenizer.convert_ids_to_tokens(inputs["input_ids"]) - span_tokens = [ - tokens[start:end] - for start, end in zip(inputs["span_start_indices"], inputs["span_end_indices"]) - ] - tuple_spans = [ - [span_tokens[idx] for idx in indices] for indices in inputs["tuple_indices"] - ] - if idx == 0: - assert tokens == [ - "[CLS]", - "En", - "##ti", - "##ty", - "G", - "[/SPAN:PER]", - "works", - "at", - "[SPAN:ORG]", - "H", - "[/SPAN:ORG]", - ".", - "[SEP]", - ] - assert span_tokens == [ - ["[CLS]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"], - ["[SPAN:ORG]", "H", "[/SPAN:ORG]"], - ] - assert tuple_spans == [ - [ - ["[CLS]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"], - ["[SPAN:ORG]", "H", "[/SPAN:ORG]"], - ] - ] - assert inputs["tuple_indices_mask"].tolist() == [True] - else: - raise ValueError(f"unexpected idx: {idx}") - else: - raise ValueError(f"unexpected config: {cfg}") - - -def test_encode_target(taskmodule, task_encodings, cfg): - if cfg == {}: - assert len(task_encodings) == 1 - targets = task_encodings[0].targets - labels = [taskmodule.id_to_label[label] for label in targets["labels"].tolist()] - assert labels == ["per:employee_of", "per:founder", "org:founded_by"] - elif cfg == {"partition_annotation": "sentences"}: - assert len(task_encodings) == 1 - for idx, encoding in enumerate(task_encodings): - targets = encoding.targets - labels = [taskmodule.id_to_label[label] for label in targets["labels"].tolist()] - if idx == 0: - assert labels == ["per:employee_of"] - else: - raise ValueError(f"unexpected idx: {idx}") - else: - raise ValueError(f"unexpected config: {cfg}") - - -def test_encode_with_no_gold_relation(document): - # create a new taskmodule that does create candidate relations - taskmodule = RESpanPairClassificationTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=TOKENIZER_NAME_OR_PATH, - create_candidate_relations=True, - labels=["org:founded_by", "per:employee_of", "per:founder"], - entity_labels=["ORG", "PER"], - ) - taskmodule.post_prepare() - # create a new document that has no relations - document = document.copy() - document.relations.clear() - - encodings = taskmodule.encode(document, encode_target=True) - - assert len(encodings) == 1 - encoding = encodings[0] - # same number of candidate relations as there are labels - assert len(encoding.metadata["candidate_relations"]) == encoding.targets["labels"].numel() - assert all(rel.label == "no_relation" for rel in encoding.metadata["candidate_relations"]) - assert encoding.targets["labels"].tolist() == [0, 0, 0, 0, 0, 0] - - -def test_encode_with_multiple_gold_relations_with_same_arguments(document, caplog): - # create a new taskmodule that does create candidate relations - taskmodule = RESpanPairClassificationTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=TOKENIZER_NAME_OR_PATH, - labels=["org:founded_by", "per:employee_of", "per:founder"], - entity_labels=["ORG", "PER"], - ) - taskmodule.post_prepare() - # create a new document that has multiple relations with the same arguments - document = document.copy() - document.relations.clear() - head = document.entities[0] - tail = document.entities[1] - document.relations.extend( - [ - BinaryRelation(head=head, tail=tail, label="org:founded_by"), - BinaryRelation(head=head, tail=tail, label="per:employee_of"), - ] - ) - - caplog.clear() - with caplog.at_level(logging.WARNING): - encodings = taskmodule.encode(document, encode_target=True) - assert len(caplog.messages) == 2 - assert ( - caplog.messages[0] - == "skip the candidate relation because there are more than one gold relation for " - "its args and roles: [BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), " - "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='org:founded_by', score=1.0), " - "BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), " - "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='per:employee_of', score=1.0)]" - ) - assert ( - caplog.messages[1] - == "skip the candidate relation because there are more than one gold relation for " - "its args and roles: [BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), " - "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='org:founded_by', score=1.0), " - "BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), " - "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='per:employee_of', score=1.0)]" - ) - - assert len(encodings) == 1 - encoding = encodings[0] - candidate_relations = encoding.metadata["candidate_relations"] - # same number of candidate relations as there are labels - assert len(candidate_relations) == encoding.targets["labels"].numel() - assert candidate_relations[0].label == "org:founded_by" - assert candidate_relations[1].label == "per:employee_of" - assert encoding.targets["labels"].tolist() == [-100, -100] - - -def test_maybe_log_example(taskmodule, task_encodings, caplog, cfg): - caplog.clear() - if cfg == {}: - with caplog.at_level(logging.INFO): - taskmodule._maybe_log_example(task_encodings[0], target=task_encodings[0].targets) - assert caplog.messages == [ - "*** Example ***", - "doc id: train_doc5", - "tokens: [CLS] First sentence . [SPAN:PER] En ##ti ##ty G [/SPAN:PER] works at [SPAN:ORG] H [/SPAN:ORG] . And founded [SPAN:ORG] I [/SPAN:ORG] . [SEP]", - "input_ids: 101 1752 5650 119 28996 13832 3121 2340 144 28998 1759 1120 28999 145 28997 119 1262 1771 28999 146 28997 119 102", - "relation 0: per:employee_of", - "\targ 0: [SPAN:PER] En ##ti ##ty G [/SPAN:PER]", - "\targ 1: [SPAN:ORG] H [/SPAN:ORG]", - "relation 1: per:founder", - "\targ 0: [SPAN:PER] En ##ti ##ty G [/SPAN:PER]", - "\targ 1: [SPAN:ORG] I [/SPAN:ORG]", - "relation 2: org:founded_by", - "\targ 0: [SPAN:ORG] I [/SPAN:ORG]", - "\targ 1: [SPAN:ORG] H [/SPAN:ORG]", - ] - elif cfg == {"partition_annotation": "sentences"}: - with caplog.at_level(logging.INFO): - taskmodule._maybe_log_example(task_encodings[0], target=task_encodings[0].targets) - assert caplog.messages == [ - "*** Example ***", - "doc id: train_doc5", - "tokens: [CLS] En ##ti ##ty G [/SPAN:PER] works at [SPAN:ORG] H [/SPAN:ORG] . [SEP]", - "input_ids: 101 13832 3121 2340 144 28998 1759 1120 28999 145 28997 119 102", - "relation 0: per:employee_of", - "\targ 0: [CLS] En ##ti ##ty G [/SPAN:PER]", - "\targ 1: [SPAN:ORG] H [/SPAN:ORG]", - ] - else: - raise ValueError(f"unexpected config: {cfg}") - - -def test_encode_with_statistics(taskmodule, fixed_documents, cfg, caplog): - caplog.clear() - with caplog.at_level(logging.INFO): - taskmodule.encode(fixed_documents, encode_target=True) - assert len(caplog.messages) > 0 - statistics = caplog.messages[-1] - if cfg == {}: - assert ( - statistics - == """statistics: -| | org:founded_by | per:employee_of | per:founder | -|:--------------------|-----------------:|------------------:|--------------:| -| available | 2 | 3 | 2 | -| available_tokenized | 2 | 3 | 2 | -| used | 2 | 3 | 2 |""" - ) - elif cfg == {"partition_annotation": "sentences"}: - assert ( - statistics - == """statistics: -| | org:founded_by | per:employee_of | per:founder | -|:--------------------|-----------------:|------------------:|--------------:| -| available | 2 | 3 | 2 | -| available_tokenized | 1 | 3 | 1 | -| used | 1 | 3 | 1 |""" - ) - else: - raise ValueError(f"unexpected config: {cfg}") - - -def test_collate(taskmodule, task_encodings, cfg): - result = taskmodule.collate(task_encodings) - assert result is not None - inputs, targets = result - assert set(inputs) == { - "input_ids", - "attention_mask", - "span_start_indices", - "span_end_indices", - "tuple_indices", - "tuple_indices_mask", - } - if cfg == {}: - torch.testing.assert_close( - inputs["input_ids"], - tensor( - [ - [ - 101, - 1752, - 5650, - 119, - 28996, - 13832, - 3121, - 2340, - 144, - 28998, - 1759, - 1120, - 28999, - 145, - 28997, - 119, - 1262, - 1771, - 28999, - 146, - 28997, - 119, - 102, - ] - ] - ), - ) - torch.testing.assert_close(inputs["attention_mask"], torch.ones_like(inputs["input_ids"])) - torch.testing.assert_close(inputs["span_start_indices"], tensor([[4, 12, 18]])) - torch.testing.assert_close(inputs["span_end_indices"], tensor([[10, 15, 21]])) - torch.testing.assert_close(inputs["tuple_indices"], tensor([[[0, 1], [0, 2], [2, 1]]])) - torch.testing.assert_close(inputs["tuple_indices_mask"], tensor([[True, True, True]])) - assert set(targets) == {"labels"} - torch.testing.assert_close(targets["labels"], tensor([[2, 3, 1]])) - elif cfg == {"partition_annotation": "sentences"}: - torch.testing.assert_close( - inputs["input_ids"], - tensor( - [[101, 13832, 3121, 2340, 144, 28998, 1759, 1120, 28999, 145, 28997, 119, 102]] - ), - ) - torch.testing.assert_close( - inputs["attention_mask"], tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) - ) - torch.testing.assert_close(inputs["span_start_indices"], tensor([[0, 8]])) - torch.testing.assert_close(inputs["span_end_indices"], tensor([[6, 11]])) - torch.testing.assert_close(inputs["tuple_indices"], tensor([[[0, 1]]])) - torch.testing.assert_close(inputs["tuple_indices_mask"], tensor([[True]])) - assert set(targets) == {"labels"} - torch.testing.assert_close(targets["labels"], tensor([[2]])) - else: - raise ValueError(f"unexpected config: {cfg}") - - -@pytest.fixture -def model_output(): - return { - "labels": torch.tensor([[2, 3, 1]]), - "probabilities": torch.tensor( - [ - [ - # no_relation, org:founded_by, per:employee_of, per:founder - [0.1, 0.2, 0.6, 0.1], - [0.1, 0.2, 0.2, 0.5], - [0.2, 0.5, 0.2, 0.1], - ] - ] - ), - } - - -@pytest.fixture -def unbatched_model_outputs(taskmodule, model_output): - return taskmodule.unbatch_output(model_output) - - -def test_unbatch_outputs(taskmodule, unbatched_model_outputs): - assert len(unbatched_model_outputs) == 1 - result = unbatched_model_outputs[0] - assert set(result) == {"labels", "probabilities"} - assert result["labels"] == ["per:employee_of", "per:founder", "org:founded_by"] - assert result["probabilities"] == [0.6000000238418579, 0.5, 0.5] - - -def test_create_annotations_from_output( - taskmodule, unbatched_model_outputs, task_encodings, document -): - result = list( - taskmodule.create_annotations_from_output( - task_encoding=task_encodings[0], task_output=unbatched_model_outputs[0] - ) - ) - scores = [0.6000000238418579, 0.5, 0.5] - for i, ((layer_name, predicted_relation), original_relation) in enumerate( - zip(result, document.relations) - ): - assert layer_name == taskmodule.relation_annotation - assert predicted_relation == original_relation.copy() - assert predicted_relation.score == scores[i] - - -def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]: - if isinstance(metric_or_collection, Metric): - return { - k: v.tolist() for k, v in flatten_dict_s(metric_or_collection.metric_state).items() - } - elif isinstance(metric_or_collection, MetricCollection): - return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()}) - else: - raise ValueError(f"unsupported type: {type(metric_or_collection)}") - - -def test_configure_model_metrics(taskmodule, model_output): - metrics = taskmodule.configure_model_metric(stage="train") - assert metrics is not None - assert isinstance(metrics, (Metric, MetricCollection)) - state = get_metric_state(metrics) - assert state == { - "f1_per_label/tp": [0, 0, 0, 0], - "f1_per_label/fp": [0, 0, 0, 0], - "f1_per_label/tn": [0, 0, 0, 0], - "f1_per_label/fn": [0, 0, 0, 0], - "macro/f1/tp": [0, 0, 0, 0], - "macro/f1/fp": [0, 0, 0, 0], - "macro/f1/tn": [0, 0, 0, 0], - "macro/f1/fn": [0, 0, 0, 0], - "micro/f1/tp": [0], - "micro/f1/fp": [0], - "micro/f1/tn": [0], - "micro/f1/fn": [0], - } - - metric_values = metrics(model_output, model_output) - state = get_metric_state(metrics) - assert state == { - "f1_per_label/tp": [0, 1, 1, 1], - "f1_per_label/fp": [0, 0, 0, 0], - "f1_per_label/tn": [3, 2, 2, 2], - "f1_per_label/fn": [0, 0, 0, 0], - "macro/f1/tp": [0, 1, 1, 1], - "macro/f1/fp": [0, 0, 0, 0], - "macro/f1/tn": [3, 2, 2, 2], - "macro/f1/fn": [0, 0, 0, 0], - "micro/f1/tp": [3], - "micro/f1/fp": [0], - "micro/f1/tn": [9], - "micro/f1/fn": [0], - } - - metric_values_converted = {key: value.item() for key, value in metric_values.items()} - assert metric_values_converted == { - "macro/f1": 1.0, - "micro/f1": 1.0, - "no_relation/f1": 0.0, - "org:founded_by/f1": 1.0, - "per:employee_of/f1": 1.0, - "per:founder/f1": 1.0, - } diff --git a/tests/taskmodules/test_re_text_classification_with_indices.py b/tests/taskmodules/test_re_text_classification_with_indices.py deleted file mode 100644 index 0574e10e5..000000000 --- a/tests/taskmodules/test_re_text_classification_with_indices.py +++ /dev/null @@ -1,3159 +0,0 @@ -import dataclasses -import logging -import pickle -import re -from dataclasses import dataclass -from typing import Any, Dict, List, Union - -import pytest -import torch -from pie_core import ( - Annotation, - AnnotationLayer, - Document, - TaskEncoding, - annotation_field, -) -from pie_core.utils.dictionary import flatten_dict_s -from torch import tensor -from torchmetrics import Metric, MetricCollection - -from pie_modules.annotations import BinaryRelation, LabeledSpan, NaryRelation -from pie_modules.documents import ( - TextBasedDocument, - TextDocumentWithLabeledSpansAndBinaryRelations, -) -from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule -from pie_modules.taskmodules.re_text_classification_with_indices import ( - HEAD, - TAIL, - find_sublist, - get_relation_argument_spans_and_roles, - span_distance, -) -from pie_modules.utils.span import distance_inner -from tests import _config_to_str -from tests.conftest import _TABULATE_AVAILABLE, TestDocument - -CONFIGS = [ - {"add_type_to_marker": False, "append_markers": False}, - {"add_type_to_marker": True, "append_markers": False}, - {"add_type_to_marker": False, "append_markers": True}, - {"add_type_to_marker": True, "append_markers": True}, -] -CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} - - -@pytest.fixture(scope="module", params=CONFIGS_DICT.keys()) -def cfg(request): - return CONFIGS_DICT[request.param] - - -def test_taskmodule_with_deprecated_parameters(caplog): - with caplog.at_level(logging.WARNING): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - tokenizer_name_or_path=tokenizer_name_or_path, label_to_id={"a": 0, "b": 1} - ) - assert taskmodule.labels == ["a", "b"] - # check the warning message - assert len(caplog.records) == 1 - assert ( - caplog.records[0].message - == "The parameter label_to_id is deprecated and will be removed in a future version. Please use labels instead." - ) - - -@pytest.fixture(scope="module") -def unprepared_taskmodule(cfg): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", tokenizer_name_or_path=tokenizer_name_or_path, **cfg - ) - assert not taskmodule.is_from_pretrained - - return taskmodule - - -@pytest.fixture(scope="module") -def taskmodule(unprepared_taskmodule, documents): - unprepared_taskmodule.prepare(documents) - return unprepared_taskmodule - - -@pytest.fixture -def model_output(): - return { - "labels": torch.tensor([1, 0, 2, 3, 1, 0, 0, 0]), - "probabilities": torch.tensor( - [ - # O, org:founded_by, per:employee_of, per:founder - [0.1, 0.6, 0.1, 0.2], - [0.5, 0.2, 0.2, 0.1], - [0.1, 0.2, 0.6, 0.1], - [0.1, 0.2, 0.2, 0.5], - [0.2, 0.4, 0.3, 0.1], - [0.5, 0.2, 0.2, 0.1], - [0.6, 0.1, 0.2, 0.1], - [0.5, 0.2, 0.2, 0.1], - ] - ), - } - - -def test_prepared_taskmodule(taskmodule, documents): - assert taskmodule.is_prepared - - assert taskmodule.entity_labels == ["ORG", "PER"] - - if taskmodule.append_markers: - if taskmodule.add_type_to_marker: - assert taskmodule.argument_markers == [ - "[/H:ORG]", - "[/H:PER]", - "[/H]", - "[/T:ORG]", - "[/T:PER]", - "[/T]", - "[H:ORG]", - "[H:PER]", - "[H=ORG]", - "[H=PER]", - "[H]", - "[T:ORG]", - "[T:PER]", - "[T=ORG]", - "[T=PER]", - "[T]", - ] - assert taskmodule.argument_markers_to_id == { - "[/H:ORG]": 28996, - "[/H:PER]": 28997, - "[/H]": 28998, - "[/T:ORG]": 28999, - "[/T:PER]": 29000, - "[/T]": 29001, - "[H:ORG]": 29002, - "[H:PER]": 29003, - "[H=ORG]": 29004, - "[H=PER]": 29005, - "[H]": 29006, - "[T:ORG]": 29007, - "[T:PER]": 29008, - "[T=ORG]": 29009, - "[T=PER]": 29010, - "[T]": 29011, - } - - else: - assert taskmodule.argument_markers == [ - "[/H]", - "[/T]", - "[H=ORG]", - "[H=PER]", - "[H]", - "[T=ORG]", - "[T=PER]", - "[T]", - ] - assert taskmodule.argument_markers_to_id == { - "[/H]": 28996, - "[/T]": 28997, - "[H=ORG]": 28998, - "[H=PER]": 28999, - "[H]": 29000, - "[T=ORG]": 29001, - "[T=PER]": 29002, - "[T]": 29003, - } - else: - if taskmodule.add_type_to_marker: - assert taskmodule.argument_markers == [ - "[/H:ORG]", - "[/H:PER]", - "[/H]", - "[/T:ORG]", - "[/T:PER]", - "[/T]", - "[H:ORG]", - "[H:PER]", - "[H]", - "[T:ORG]", - "[T:PER]", - "[T]", - ] - assert taskmodule.argument_markers_to_id == { - "[/H:ORG]": 28996, - "[/H:PER]": 28997, - "[/H]": 28998, - "[/T:ORG]": 28999, - "[/T:PER]": 29000, - "[/T]": 29001, - "[H:ORG]": 29002, - "[H:PER]": 29003, - "[H]": 29004, - "[T:ORG]": 29005, - "[T:PER]": 29006, - "[T]": 29007, - } - else: - assert taskmodule.argument_markers == ["[/H]", "[/T]", "[H]", "[T]"] - assert taskmodule.argument_markers_to_id == { - "[/H]": 28996, - "[/T]": 28997, - "[H]": 28998, - "[T]": 28999, - } - - assert taskmodule.label_to_id == { - "org:founded_by": 1, - "per:employee_of": 2, - "per:founder": 3, - "no_relation": 0, - } - assert taskmodule.id_to_label == { - 1: "org:founded_by", - 2: "per:employee_of", - 3: "per:founder", - 0: "no_relation", - } - - -def test_config(taskmodule): - config = taskmodule._config() - assert config["taskmodule_type"] == "RETextClassificationWithIndicesTaskModule" - assert taskmodule.PREPARED_ATTRIBUTES == ["labels", "entity_labels"] - assert all(attribute in config for attribute in taskmodule.PREPARED_ATTRIBUTES) - assert config["labels"] == ["org:founded_by", "per:employee_of", "per:founder"] - assert config["entity_labels"] == ["ORG", "PER"] - - -@pytest.mark.parametrize("encode_target", [False, True]) -def test_encode(taskmodule, documents, encode_target): - task_encodings = taskmodule.encode(documents, encode_target=encode_target) - - assert len(task_encodings) == 7 - - encoding = task_encodings[0] - - tokens = taskmodule.tokenizer.convert_ids_to_tokens(encoding.inputs["input_ids"]) - assert len(tokens) == len(encoding.inputs["input_ids"]) - - if taskmodule.add_type_to_marker: - assert tokens[:14] == [ - "[CLS]", - "[H:PER]", - "En", - "##ti", - "##ty", - "A", - "[/H:PER]", - "works", - "at", - "[T:ORG]", - "B", - "[/T:ORG]", - ".", - "[SEP]", - ] - else: - assert tokens[:14] == [ - "[CLS]", - "[H]", - "En", - "##ti", - "##ty", - "A", - "[/H]", - "works", - "at", - "[T]", - "B", - "[/T]", - ".", - "[SEP]", - ] - if taskmodule.append_markers: - assert len(tokens) == 14 + 4 - assert tokens[-4:] == ["[H=PER]", "[SEP]", "[T=ORG]", "[SEP]"] - else: - assert len(tokens) == 14 - - if encode_target: - assert encoding.targets == [2] - else: - assert not encoding.has_targets - - with pytest.raises(ValueError, match=re.escape("task encoding has no target")): - encoding.targets - - -@pytest.fixture(scope="module") -def batch(taskmodule, documents): - documents = [documents[i] for i in [0, 1, 4]] - task_encodings = taskmodule.encode(documents, encode_target=True) - return taskmodule.collate(task_encodings[:2]) - - -def test_collate(taskmodule, batch): - inputs, targets = batch - - assert "input_ids" in inputs - assert "attention_mask" in inputs - assert inputs["input_ids"].shape == inputs["attention_mask"].shape - - if taskmodule.append_markers: - assert inputs["input_ids"].shape == (2, 25) - if taskmodule.add_type_to_marker: - torch.testing.assert_close( - inputs.input_ids, - torch.tensor( - [ - [ - 101, - 29003, - 13832, - 3121, - 2340, - 138, - 28997, - 1759, - 1120, - 29007, - 139, - 28999, - 119, - 102, - 29005, - 102, - 29009, - 102, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 29003, - 13832, - 3121, - 2340, - 144, - 28997, - 1759, - 1120, - 29007, - 145, - 28999, - 119, - 1262, - 1771, - 146, - 119, - 102, - 29005, - 102, - 29009, - 102, - ], - ] - ), - ) - else: - torch.testing.assert_close( - inputs.input_ids, - torch.tensor( - [ - [ - 101, - 29000, - 13832, - 3121, - 2340, - 138, - 28996, - 1759, - 1120, - 29003, - 139, - 28997, - 119, - 102, - 28999, - 102, - 29001, - 102, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 29000, - 13832, - 3121, - 2340, - 144, - 28996, - 1759, - 1120, - 29003, - 145, - 28997, - 119, - 1262, - 1771, - 146, - 119, - 102, - 28999, - 102, - 29001, - 102, - ], - ] - ), - ) - torch.testing.assert_close( - inputs.attention_mask, - torch.tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - ), - ) - - else: - assert inputs["input_ids"].shape == (2, 21) - - if taskmodule.add_type_to_marker: - torch.testing.assert_close( - inputs.input_ids, - torch.tensor( - [ - [ - 101, - 29003, - 13832, - 3121, - 2340, - 138, - 28997, - 1759, - 1120, - 29005, - 139, - 28999, - 119, - 102, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 29003, - 13832, - 3121, - 2340, - 144, - 28997, - 1759, - 1120, - 29005, - 145, - 28999, - 119, - 1262, - 1771, - 146, - 119, - 102, - ], - ] - ), - ) - else: - torch.testing.assert_close( - inputs.input_ids, - torch.tensor( - [ - [ - 101, - 28998, - 13832, - 3121, - 2340, - 138, - 28996, - 1759, - 1120, - 28999, - 139, - 28997, - 119, - 102, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - [ - 101, - 1752, - 5650, - 119, - 28998, - 13832, - 3121, - 2340, - 144, - 28996, - 1759, - 1120, - 28999, - 145, - 28997, - 119, - 1262, - 1771, - 146, - 119, - 102, - ], - ] - ), - ) - torch.testing.assert_close( - inputs.attention_mask, - torch.tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - ), - ) - - assert set(targets) == {"labels"} - torch.testing.assert_close(targets["labels"], torch.tensor([2, 2])) - - -def test_unbatch_output(taskmodule, model_output): - unbatched_outputs = taskmodule.unbatch_output(model_output) - - assert len(unbatched_outputs) == 8 - - labels = [ - "org:founded_by", - "no_relation", - "per:employee_of", - "per:founder", - "org:founded_by", - "no_relation", - "no_relation", - "no_relation", - ] - probabilities = [0.6, 0.5, 0.6, 0.5, 0.4, 0.5, 0.6, 0.5] - - for output, label, probability in zip(unbatched_outputs, labels, probabilities): - assert set(output.keys()) == {"labels", "probabilities"} - assert output["labels"] == [label] - assert output["probabilities"] == pytest.approx([probability]) - - -@pytest.mark.parametrize("inplace", [False, True]) -def test_decode(taskmodule, documents, model_output, inplace): - # copy the documents, because the taskmodule may modify them - documents = [documents[i].copy() for i in [0, 1, 4]] - - encodings = taskmodule.encode(documents, encode_target=False) - unbatched_outputs = taskmodule.unbatch_output(model_output) - decoded_documents = taskmodule.decode( - task_encodings=encodings, - task_outputs=unbatched_outputs, - inplace=inplace, - ) - - assert len(decoded_documents) == len(documents) - - if inplace: - assert {id(doc) for doc in decoded_documents} == {id(doc) for doc in documents} - else: - assert {id(doc) for doc in decoded_documents}.isdisjoint({id(doc) for doc in documents}) - - expected_scores = [0.6, 0.5, 0.6, 0.5, 0.4, 0.5, 0.6, 0.5] - i = 0 - for document in decoded_documents: - for relation_expected, relation_decoded in zip( - document["entities"], document["entities"].predictions - ): - assert relation_expected.start == relation_decoded.start - assert relation_expected.end == relation_decoded.end - assert relation_expected.label == relation_decoded.label - assert expected_scores[i] == pytest.approx(relation_decoded.score) - i += 1 - - if not inplace: - for document in documents: - assert not document["relations"].predictions - - -def test_encode_with_partition(documents): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - partition_annotation="sentences", - ) - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - assert len(documents) == 7 - encodings = taskmodule.encode(documents) - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(encoding.inputs["input_ids"]) - for encoding in encodings - ] - assert len(encodings) == 5 - assert encodings[0].document != encodings[1].document - assert encodings[1].document != encodings[2].document - # the last document contains 3 valid relations - assert encodings[2].document == encodings[3].document - assert encodings[3].document == encodings[4].document - assert tokens[0] == [ - "[CLS]", - "[H]", - "En", - "##ti", - "##ty", - "A", - "[/H]", - "works", - "at", - "[T]", - "B", - "[/T]", - ".", - "[SEP]", - ] - assert tokens[1] == [ - "[CLS]", - "[H]", - "En", - "##ti", - "##ty", - "G", - "[/H]", - "works", - "at", - "[T]", - "H", - "[/T]", - ".", - "[SEP]", - ] - assert tokens[2] == [ - "[CLS]", - "[H]", - "En", - "##ti", - "##ty", - "M", - "[/H]", - "works", - "at", - "[T]", - "N", - "[/T]", - ".", - "[SEP]", - ] - assert tokens[3] == [ - "[CLS]", - "And", - "[H]", - "it", - "[/H]", - "founded", - "[T]", - "O", - "[/T]", - "[SEP]", - ] - assert tokens[4] == [ - "[CLS]", - "And", - "[T]", - "it", - "[/T]", - "founded", - "[H]", - "O", - "[/H]", - "[SEP]", - ] - - -def test_encode_with_windowing(documents): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - max_window=12, - ) - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - assert len(documents) == 7 - encodings = taskmodule.encode(documents) - assert len(encodings) == 3 - for encoding in encodings: - assert len(encoding.inputs["input_ids"]) <= taskmodule.max_window - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(encoding.inputs["input_ids"]) - for encoding in encodings - ] - assert tokens[0] == [ - "[CLS]", - "at", - "[T]", - "H", - "[/T]", - ".", - "And", - "founded", - "[H]", - "I", - "[/H]", - "[SEP]", - ] - assert tokens[1] == [ - "[CLS]", - ".", - "And", - "[H]", - "it", - "[/H]", - "founded", - "[T]", - "O", - "[/T]", - ".", - "[SEP]", - ] - assert tokens[2] == [ - "[CLS]", - ".", - "And", - "[T]", - "it", - "[/T]", - "founded", - "[H]", - "O", - "[/H]", - ".", - "[SEP]", - ] - - -def test_encode_with_allow_discontinuous_text(documents): - tokenizer_name_or_path = "bert-base-cased" - # tokenizer_name_or_path = "allenai/longformer-scico" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - max_window=12, - allow_discontinuous_text=True, - ) - - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - assert len(documents) == 7 - encodings = taskmodule.encode(documents) - assert len(encodings) == 3 - - for encoding in encodings: - assert len(encoding.inputs["input_ids"]) <= taskmodule.max_window - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(encoding.inputs["input_ids"]) - for encoding in encodings - ] - assert tokens == [ - ["[CLS]", "at", "[T]", "H", "[/T]", "[SEP]", "founded", "[H]", "I", "[/H]", "[SEP]"], - ["[CLS]", "And", "[H]", "it", "[/H]", "founded", "[T]", "O", "[/T]", "[SEP]"], - ["[CLS]", "And", "[T]", "it", "[/T]", "founded", "[H]", "O", "[/H]", "[SEP]"], - ] - - -def test_encode_with_allow_discontinuous_text_and_binary_relations(): - """This checks whether relation arguments at the very beginning or end of the document are - encoded correctly. - - Also, it checks whether the encoding of the consecutive spans that fit within the frame - specified by max_window is correct. - """ - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - tokenizer_name_or_path=tokenizer_name_or_path, - max_window=128, - allow_discontinuous_text=True, - ) - texts = [ - "Loren ipsun dolor sit anet, consectetur adipisci elit, sed eiusnod tenpor incidunt ut labore et dolore nagna aliqua.", - "Ut enin ad ninin venian, quis nostrun exercitationen ullan corporis suscipit laboriosan, nisi ut aliquid ex ea connodi consequatur.", - "Quis aute iure reprehenderit in voluptate velit esse cillun dolore eu fugiat nulla pariatur.", - "Excepteur sint obcaecat cupiditat non proident, sunt in culpa qui officia deserunt nollit anin id est laborun.", - ] - text_lengths = [len(text) for text in texts] - sep = " " - - doc = TextDocumentWithLabeledSpansAndBinaryRelations( - text=sep.join(texts), - id="123", - ) - - labeled_spans = [] - offset = 0 - for i, text in enumerate(texts): - labeled_spans.append( - LabeledSpan(start=0 + offset, end=text_lengths[i] + offset, label="sentence") - ) - offset += text_lengths[i] + len(sep) - - for span in labeled_spans: - doc.labeled_spans.append(span) - assert doc.labeled_spans.resolve() == [ - ( - "sentence", - "Loren ipsun dolor sit anet, consectetur adipisci elit, sed eiusnod tenpor incidunt ut " - "labore et dolore nagna aliqua.", - ), - ( - "sentence", - "Ut enin ad ninin venian, quis nostrun exercitationen ullan corporis suscipit laboriosan, " - "nisi ut aliquid ex ea connodi consequatur.", - ), - ( - "sentence", - "Quis aute iure reprehenderit in voluptate velit esse cillun dolore eu fugiat nulla pariatur.", - ), - ( - "sentence", - "Excepteur sint obcaecat cupiditat non proident, sunt in culpa qui officia deserunt nollit " - "anin id est laborun.", - ), - ] - - rel_start = BinaryRelation( - head=doc.labeled_spans[0], tail=doc.labeled_spans[2], label="relation", score=1.0 - ) - doc.binary_relations.append(rel_start) - rel_end = BinaryRelation( - head=doc.labeled_spans[-1], tail=doc.labeled_spans[0], label="relation", score=1.0 - ) - doc.binary_relations.append(rel_end) - rel_consecutive = BinaryRelation( - head=doc.labeled_spans[2], tail=doc.labeled_spans[3], label="relation", score=1.0 - ) - doc.binary_relations.append(rel_consecutive) - - # test document where everything is already included in one argument frame - doc2 = TextDocumentWithLabeledSpansAndBinaryRelations("A founded B.", id="123") - doc2.labeled_spans.append(LabeledSpan(start=0, end=1, label="PER")) - doc2.labeled_spans.append(LabeledSpan(start=10, end=11, label="PER")) - assert doc2.labeled_spans.resolve() == [("PER", "A"), ("PER", "B")] - rel = BinaryRelation(head=doc2.labeled_spans[0], tail=doc2.labeled_spans[1], label="relation") - doc2.binary_relations.append(rel) - - taskmodule.prepare([doc, doc2]) - encoded = taskmodule.encode_input(doc) - - decoded_arg_start = taskmodule.tokenizer.decode(encoded[0].inputs["input_ids"]) - decoded_arg_end = taskmodule.tokenizer.decode(encoded[1].inputs["input_ids"]) - decoded_arg_consecutive = taskmodule.tokenizer.decode(encoded[2].inputs["input_ids"]) - - assert ( - decoded_arg_start - == "[CLS] [H] Loren ipsun dolor sit anet, consectetur adipisci elit, sed eiusnod tenpor incidunt ut labore et dolore nagna aliqua. [/H] Ut enin ad ninin venian, quis no [SEP] ex ea connodi consequatur. [T] Quis aute iure reprehenderit in voluptate velit esse cillun dolore eu fugiat nulla pariatur. [/T] Excepteur sint obcaecat cupid [SEP]" - ) - - assert ( - decoded_arg_end - == "[CLS] [T] Loren ipsun dolor sit anet, consectetur adipisci elit, sed eiusnod tenpor incidunt ut labore et dolore nagna aliqua. [/T] Ut enin ad ninin venian, quis no [SEP]se cillun dolore eu fugiat nulla pariatur. [H] Excepteur sint obcaecat cupiditat non proident, sunt in culpa qui officia deserunt nollit anin id est laborun. [/H] [SEP]" - ) - - assert ( - decoded_arg_consecutive - == "[CLS] ex ea connodi consequatur. [H] Quis aute iure reprehenderit in voluptate velit esse cillun dolore eu fugiat nulla pariatur. [/H] [T] Excepteur sint obcaecat cupiditat non proident, sunt in culpa qui officia deserunt nollit anin id est laborun. [/T] [SEP]" - ) - - encoded2 = taskmodule.encode_input(doc2) - assert len(encoded2) == 1 - decoded2 = taskmodule.tokenizer.decode(encoded2[0].inputs["input_ids"]) - assert decoded2 == "[CLS] [H] A [/H] founded [T] B [/T]. [SEP]" - - -def get_arg_token_span( - tokens: List[str], - start_indices: List[int], - end_indices: List[int], - argument_role2idx: Dict[str, int], -) -> Dict[str, List[str]]: - return { - role: tokens[start_indices[argument_role2idx[role]] : end_indices[argument_role2idx[role]]] - for role, idx in argument_role2idx.items() - } - - -def test_encode_with_add_argument_indices(documents): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - add_argument_indices_to_input=True, - ) - - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - encodings = taskmodule.encode(documents, encode_target=True) - assert len(encodings) == 7 - batch = taskmodule.collate(encodings) - inputs, targets = batch - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"] - ] - - arg_spans = [ - get_arg_token_span( - current_tokens, - current_start_indices, - current_end_indices, - taskmodule.argument_role2idx, - ) - for current_tokens, current_start_indices, current_end_indices in zip( - tokens, inputs["pooler_start_indices"].tolist(), inputs["pooler_end_indices"].tolist() - ) - ] - - assert arg_spans == [ - {"head": ["En", "##ti", "##ty", "A"], "tail": ["B"]}, - {"head": ["En", "##ti", "##ty", "G"], "tail": ["H"]}, - {"head": ["En", "##ti", "##ty", "G"], "tail": ["I"]}, - {"head": ["I"], "tail": ["H"]}, - {"head": ["En", "##ti", "##ty", "M"], "tail": ["N"]}, - {"head": ["it"], "tail": ["O"]}, - {"head": ["O"], "tail": ["it"]}, - ] - - -def test_encode_with_add_argument_indices_and_without_insert_markers(documents): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - add_argument_indices_to_input=True, - insert_markers=False, - ) - - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - encodings = taskmodule.encode(documents, encode_target=True) - assert len(encodings) == 7 - batch = taskmodule.collate(encodings) - inputs, targets = batch - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"] - ] - - arg_spans = [ - get_arg_token_span( - current_tokens, - current_start_indices, - current_end_indices, - taskmodule.argument_role2idx, - ) - for current_tokens, current_start_indices, current_end_indices in zip( - tokens, inputs["pooler_start_indices"].tolist(), inputs["pooler_end_indices"].tolist() - ) - ] - - assert arg_spans == [ - {"head": ["En", "##ti", "##ty", "A"], "tail": ["B"]}, - {"head": ["En", "##ti", "##ty", "G"], "tail": ["H"]}, - {"head": ["En", "##ti", "##ty", "G"], "tail": ["I"]}, - {"head": ["I"], "tail": ["H"]}, - {"head": ["En", "##ti", "##ty", "M"], "tail": ["N"]}, - {"head": ["it"], "tail": ["O"]}, - {"head": ["O"], "tail": ["it"]}, - ] - - -def test_find_sublist(): - # default case - assert find_sublist(sub=[2, 3], bigger=[1, 2, 3, 4]) == 1 - # no sublist - assert find_sublist(sub=[2, 3], bigger=[1, 3, 2, 4]) == -1 - # empty sublist: occurs on every position, but first is returned - assert find_sublist(sub=[], bigger=[1, 3, 2, 4]) == 0 - # empty bigger - assert find_sublist(sub=[2, 3], bigger=[]) == -1 - - -def test_encode_with_add_argument_indices_and_windowing(documents): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - add_argument_indices_to_input=True, - max_window=12, - ) - - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - encodings = taskmodule.encode(documents, encode_target=True) - assert len(encodings) == 3 - batch = taskmodule.collate(encodings) - inputs, targets = batch - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"] - ] - - arg_spans = [ - get_arg_token_span( - current_tokens, - current_start_indices, - current_end_indices, - taskmodule.argument_role2idx, - ) - for current_tokens, current_start_indices, current_end_indices in zip( - tokens, inputs["pooler_start_indices"].tolist(), inputs["pooler_end_indices"].tolist() - ) - ] - - assert arg_spans == [ - {"head": ["I"], "tail": ["H"]}, - {"head": ["it"], "tail": ["O"]}, - {"head": ["O"], "tail": ["it"]}, - ] - - -def test_encode_with_add_argument_indices_windowing_and_without_insert_markers(documents): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - add_argument_indices_to_input=True, - max_window=8, - insert_markers=False, - ) - - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - encodings = taskmodule.encode(documents, encode_target=True) - assert len(encodings) == 3 - batch = taskmodule.collate(encodings) - inputs, targets = batch - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"] - ] - - arg_spans = [ - get_arg_token_span( - current_tokens, - current_start_indices, - current_end_indices, - taskmodule.argument_role2idx, - ) - for current_tokens, current_start_indices, current_end_indices in zip( - tokens, inputs["pooler_start_indices"].tolist(), inputs["pooler_end_indices"].tolist() - ) - ] - - assert arg_spans == [ - {"head": ["I"], "tail": ["H"]}, - {"head": ["it"], "tail": ["O"]}, - {"head": ["O"], "tail": ["it"]}, - ] - - -@pytest.mark.parametrize("handle_relations_with_same_arguments", ["keep_first", "keep_none"]) -@pytest.mark.parametrize("add_candidate_relations", [False, True]) -@pytest.mark.parametrize("collect_statistics", [False, True]) -def test_encode_input_multiple_relations_for_same_arguments( - caplog, handle_relations_with_same_arguments, add_candidate_relations, collect_statistics -): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - handle_relations_with_same_arguments=handle_relations_with_same_arguments, - collect_statistics=collect_statistics, - add_candidate_relations=add_candidate_relations, - ) - document = TestDocument(text="A founded B.", id="test_doc") - document.entities.append(LabeledSpan(start=0, end=1, label="PER")) - document.entities.append(LabeledSpan(start=10, end=11, label="PER")) - entities = document.entities - assert str(entities[0]) == "A" - assert str(entities[1]) == "B" - document.relations.extend( - [ - BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"), - BinaryRelation(head=entities[0], tail=entities[1], label="per:founder"), - BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"), - ] - ) - taskmodule.prepare([document]) - - with caplog.at_level(logging.WARNING): - encodings = taskmodule.encode_input(document) - - statistics = taskmodule.get_statistics() - candidate_relation = [enc.metadata["candidate_annotation"] for enc in encodings] - candidate_relation_tuples = [ - (rel.head.resolve(), rel.label, rel.tail.resolve()) for rel in candidate_relation - ] - - if handle_relations_with_same_arguments == "keep_first": - # Note: Warnings are shown only if statistics are disabled. For details see comment at - # src/pie_modules/taskmodules/re_text_classification_with_indices.py:811-818 - expected_warning = ( - "doc.id=test_doc: there are multiple relations with the same arguments " - "(('head', ('PER', 'A')), ('tail', ('PER', 'B'))), but different labels: " - "['per:founded_by', 'per:founder', 'per:founded_by']. We only keep the first " - "occurring relation which has the label='per:founded_by'." - ) - if not add_candidate_relations: - # with 'keep_first', only first relation occurred is kept ('per:founded_by'). - # full duplicate of 'per:founded_by' is removed and appears neither as available, - # nor as skipped in statistics. - assert candidate_relation_tuples == [(("PER", "A"), "per:founded_by", ("PER", "B"))] - if collect_statistics: - assert statistics == { - ("available", "per:founded_by"): 1, - ("available", "per:founder"): 1, - ("skipped_same_arguments", "per:founder"): 1, - ("used", "per:founded_by"): 1, - } - assert caplog.messages == [] - else: - assert statistics == {} - assert caplog.messages == [expected_warning] - - else: - # as above, but with candidate (negative) relations added - assert candidate_relation_tuples == [ - (("PER", "A"), "per:founded_by", ("PER", "B")), - (("PER", "B"), "no_relation", ("PER", "A")), - ] - if collect_statistics: - assert statistics == { - ("available", "per:founded_by"): 1, - ("available", "per:founder"): 1, - ("used", "no_relation"): 1, - ("used", "per:founded_by"): 1, - ("skipped_same_arguments", "per:founder"): 1, - } - assert caplog.messages == [] - else: - assert statistics == {} - assert caplog.messages == [expected_warning] - - elif handle_relations_with_same_arguments == "keep_none": - # Note: Warnings are shown only if statistics are disabled. For details see comment at - # src/pie_modules/taskmodules/re_text_classification_with_indices.py:811-818 - expected_warning = ( - "doc.id=test_doc: there are multiple relations with the same arguments " - "(('head', ('PER', 'A')), ('tail', ('PER', 'B'))), but different labels: " - "['per:founded_by', 'per:founder', 'per:founded_by']. All relations will be removed." - ) - if not add_candidate_relations: - # with 'keep_none' both relations sharing same arguments are removed - # full duplicate of 'per:founded_by' is removed and appears neither as available, - # nor as skipped in statistics. - assert candidate_relation_tuples == [] - if collect_statistics: - assert statistics == { - ("available", "per:founded_by"): 1, - ("available", "per:founder"): 1, - ("skipped_same_arguments", "per:founder"): 1, - ("skipped_same_arguments", "per:founded_by"): 1, - } - assert caplog.messages == [] - else: - assert statistics == {} - assert caplog.messages == [expected_warning] - else: - # all conflicting relations go into the same direction, so we can create a candidate (negative) - # relation for the other direction. - assert candidate_relation_tuples == [(("PER", "B"), "no_relation", ("PER", "A"))] - if collect_statistics: - assert statistics == { - ("available", "per:founded_by"): 1, - ("available", "per:founder"): 1, - ("skipped_same_arguments", "per:founded_by"): 1, - ("skipped_same_arguments", "per:founder"): 1, - ("used", "no_relation"): 1, - } - assert caplog.messages == [] - else: - assert statistics == {} - assert caplog.messages == [expected_warning] - - -def test_encode_input_handle_relations_with_same_arguments_unknown_value(): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - handle_relations_with_same_arguments="unknown_value", - ) - document = TestDocument(text="A founded B.", id="test_doc") - document.entities.append(LabeledSpan(start=0, end=1, label="PER")) - document.entities.append(LabeledSpan(start=10, end=11, label="PER")) - document.relations.append( - BinaryRelation( - head=document.entities[0], tail=document.entities[1], label="per:founded_by" - ) - ) - document.relations.append( - BinaryRelation(head=document.entities[0], tail=document.entities[1], label="per:founder") - ) - assert document.relations.resolve() == [ - ("per:founded_by", (("PER", "A"), ("PER", "B"))), - ("per:founder", (("PER", "A"), ("PER", "B"))), - ] - taskmodule.prepare([document]) - - with pytest.raises(ValueError) as excinfo: - taskmodule.encode_input(document) - assert str(excinfo.value) == ( - "'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', but got `unknown_value`." - ) - - -@pytest.mark.parametrize("handle_relations_with_same_arguments", ["keep_first", "keep_none"]) -@pytest.mark.parametrize("add_candidate_relations", [False, True]) -@pytest.mark.parametrize("collect_statistics", [False, True]) -def test_encode_input_duplicated_relations( - caplog, handle_relations_with_same_arguments, add_candidate_relations, collect_statistics -): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - handle_relations_with_same_arguments=handle_relations_with_same_arguments, - add_candidate_relations=add_candidate_relations, - collect_statistics=collect_statistics, - ) - document = TestDocument(text="A founded B.", id="test_doc") - document.entities.append(LabeledSpan(start=0, end=1, label="PER")) - document.entities.append(LabeledSpan(start=10, end=11, label="PER")) - entities = document.entities - assert str(entities[0]) == "A" - assert str(entities[1]) == "B" - document.relations.extend( - [ - BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"), - BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"), - ] - ) - taskmodule.prepare([document]) - with caplog.at_level(logging.WARNING): - encodings = taskmodule.encode_input(document) - - statistics = taskmodule.get_statistics() - - assert len(caplog.messages) == 1 - assert ( - caplog.messages[0] == "doc.id=test_doc: Relation annotation " - "`('per:founded_by', (('PER', 'A'), ('PER', 'B')))` is duplicated. We keep " - "only one of them. Duplicate won't appear in statistics either as 'available' or as skipped." - ) - candidate_relation = [enc.metadata["candidate_annotation"] for enc in encodings] - candidate_relation_tuples = [ - (rel.head.resolve(), rel.label, rel.tail.resolve()) for rel in candidate_relation - ] - # equally for 'keep_first' and 'keep_last', full duplicates are not affected and do not appear in statistics, but still - # generate a warning. - if add_candidate_relations: - assert candidate_relation_tuples == [ - (("PER", "A"), "per:founded_by", ("PER", "B")), - (("PER", "B"), "no_relation", ("PER", "A")), - ] - if collect_statistics: - assert statistics == { - ("available", "per:founded_by"): 1, - ("used", "no_relation"): 1, - ("used", "per:founded_by"): 1, - } - else: - assert statistics == {} - else: - assert candidate_relation_tuples == [(("PER", "A"), "per:founded_by", ("PER", "B"))] - if collect_statistics: - assert statistics == { - ("available", "per:founded_by"): 1, - ("used", "per:founded_by"): 1, - } - else: - assert statistics == {} - - -def test_encode_input_argument_role_unknown(documents): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - # the tail argument is not in the role_to_marker - argument_role_to_marker={HEAD: "H"}, - ) - taskmodule.prepare(documents) - with pytest.raises(ValueError) as excinfo: - taskmodule.encode_input(documents[1]) - assert ( - str(excinfo.value) == "role='tail' not in known roles=['head'] (did you initialise the " - "taskmodule with the correct argument_role_to_marker dictionary?)" - ) - - -def test_encode_input_with_add_candidate_relations(documents): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - add_candidate_relations=True, - ) - taskmodule.prepare(documents) - documents_without_relations = [] - encodings = [] - # just take the first three documents - for doc in documents[:3]: - doc_without_relations = doc.copy() - relations = list(doc_without_relations.relations) - doc_without_relations.relations.clear() - # re-add one relation to test if it is kept - if len(relations) > 0: - doc_without_relations.relations.append(relations[0]) - documents_without_relations.append(doc_without_relations) - encodings.extend(taskmodule.encode(doc_without_relations)) - - assert len(encodings) == 4 - relations = [encoding.metadata["candidate_annotation"] for encoding in encodings] - texts = [encoding.document.text for encoding in encodings] - relation_tuples = [(str(rel.head), rel.label, str(rel.tail)) for rel in relations] - - # There are no entities in the first document, so there are no created relation candidates - - # this relation was kept - assert texts[0] == "Entity A works at B." - assert relation_tuples[0] == ("Entity A", "per:employee_of", "B") - - # the following relations were added - assert texts[1] == "Entity A works at B." - assert relation_tuples[1] == ("B", "no_relation", "Entity A") - assert texts[2] == "Entity C and D." - assert relation_tuples[2] == ("Entity C", "no_relation", "D") - assert texts[3] == "Entity C and D." - assert relation_tuples[3] == ("D", "no_relation", "Entity C") - - -@pytest.fixture -def document_with_nary_relations(): - @dataclasses.dataclass - class TestDocumentWithNaryRelations(TextBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - relations: AnnotationLayer[NaryRelation] = annotation_field(target="entities") - - document = TestDocumentWithNaryRelations( - text="Entity A works at B.", id="doc_with_nary_relations" - ) - document.entities.append(LabeledSpan(start=0, end=8, label="PER")) - document.entities.append(LabeledSpan(start=18, end=19, label="PER")) - document.relations.append( - NaryRelation( - arguments=tuple(document.entities), - roles=tuple(["head", "tail"]), - label="per:employee_of", - ) - ) - return document - - -def test_encode_input_with_add_candidate_relations_with_wrong_relation_type( - document_with_nary_relations, -): - doc = document_with_nary_relations - - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - add_candidate_relations=True, - argument_role_to_marker={HEAD: "H", "arg2": "T"}, - ) - taskmodule.prepare([doc]) - with pytest.raises(NotImplementedError) as excinfo: - taskmodule.encode_input(doc) - assert ( - str(excinfo.value) - == "doc.id=doc_with_nary_relations: the taskmodule does not yet support adding relation candidates " - "with argument roles other than 'head' and 'tail': ['arg2', 'head']" - ) - - -def test_filter_relations_by_argument_type_whitelist(documents): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - add_candidate_relations=True, - argument_type_whitelist=[["PER", "ORG"], ["ORG", "PER"]], - ) - doc = documents[4] - taskmodule.prepare(documents) - - assert doc.entities.resolve() == [("PER", "Entity G"), ("ORG", "H"), ("ORG", "I")] - assert doc.relations.resolve() == [ - ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), - ("per:founder", (("PER", "Entity G"), ("ORG", "I"))), - ("org:founded_by", (("ORG", "I"), ("ORG", "H"))), - ] - arguments2relation = {} - for rel in doc.relations: - arguments2relation[get_relation_argument_spans_and_roles(rel)] = rel - assert len(arguments2relation) == 3 - - taskmodule._filter_relations_by_argument_type_whitelist(arguments2relation=arguments2relation) - assert len(arguments2relation) == 2 - - relation_tuples = [rel.resolve() for rel in arguments2relation.values()] - assert relation_tuples[0] == ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))) - assert relation_tuples[1] == ("per:founder", (("PER", "Entity G"), ("ORG", "I"))) - - assert ("org:founded_by", (("ORG", "I"), ("ORG", "H"))) not in relation_tuples - - -def test_add_candidate_relations_with_argument_type_whitelist(documents): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - add_candidate_relations=True, - argument_type_whitelist=[["PER", "ORG"], ["ORG", "PER"]], - ) - doc = documents[4] - taskmodule.prepare(documents) - - assert doc.entities.resolve() == [("PER", "Entity G"), ("ORG", "H"), ("ORG", "I")] - assert doc.relations.resolve() == [ - ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), - ("per:founder", (("PER", "Entity G"), ("ORG", "I"))), - ("org:founded_by", (("ORG", "I"), ("ORG", "H"))), - ] - arguments2relation = {} - for rel in doc.relations: - arguments2relation[get_relation_argument_spans_and_roles(rel)] = rel - assert len(arguments2relation) == 3 - - taskmodule._add_candidate_relations( - arguments2relation=arguments2relation, entities=doc.entities - ) - assert len(arguments2relation) == 5 - - relation_tuples = [rel.resolve() for rel in arguments2relation.values()] - - # Original relations from document (aren't affected by whitelist) - assert relation_tuples[0] == ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))) - assert relation_tuples[1] == ("per:founder", (("PER", "Entity G"), ("ORG", "I"))) - assert relation_tuples[2] == ("org:founded_by", (("ORG", "I"), ("ORG", "H"))) - - # Relation candidate added by _add_candidate_relations() - assert relation_tuples[3] == ("no_relation", (("ORG", "H"), ("PER", "Entity G"))) - assert relation_tuples[4] == ("no_relation", (("ORG", "I"), ("PER", "Entity G"))) - - # Relations not created due to whitelist - assert ("no_relation", (("ORG", "H"), ("ORG", "I"))) not in relation_tuples - - -def test_filter_relations_by_argument_and_relation_type_whitelist(documents): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - add_candidate_relations=True, - argument_and_relation_type_whitelist={ - "per:employee_of": [["PER", "ORG"]], - "per:founder": [["PER", "ORG"]], - "org:founded_by": [["ORG", "PER"]], - }, - ) - doc = documents[4] - taskmodule.prepare(documents) - - assert doc.entities.resolve() == [("PER", "Entity G"), ("ORG", "H"), ("ORG", "I")] - assert doc.relations.resolve() == [ - ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), - ("per:founder", (("PER", "Entity G"), ("ORG", "I"))), - ("org:founded_by", (("ORG", "I"), ("ORG", "H"))), - ] - arguments2relation = {} - for rel in doc.relations: - arguments2relation[get_relation_argument_spans_and_roles(rel)] = rel - assert len(arguments2relation) == 3 - - taskmodule._filter_relations_by_argument_and_relation_type_whitelist( - arguments2relation=arguments2relation - ) - assert len(arguments2relation) == 2 - - relation_tuples = [rel.resolve() for rel in arguments2relation.values()] - assert relation_tuples[0] == ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))) - assert relation_tuples[1] == ("per:founder", (("PER", "Entity G"), ("ORG", "I"))) - - assert ("org:founded_by", (("ORG", "I"), ("ORG", "H"))) not in relation_tuples - - -def test_add_candidate_relations_with_argument_and_relation_type_whitelist(documents): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - add_candidate_relations=True, - argument_and_relation_type_whitelist={ - "per:employee_of": [["PER", "ORG"]], - "per:founder": [["PER", "ORG"]], - "org:founded_by": [["ORG", "PER"]], - }, - ) - doc = documents[4] - taskmodule.prepare(documents) - - assert doc.entities.resolve() == [("PER", "Entity G"), ("ORG", "H"), ("ORG", "I")] - assert doc.relations.resolve() == [ - ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), - ("per:founder", (("PER", "Entity G"), ("ORG", "I"))), - ("org:founded_by", (("ORG", "I"), ("ORG", "H"))), - ] - arguments2relation = {} - for rel in doc.relations: - arguments2relation[get_relation_argument_spans_and_roles(rel)] = rel - assert len(arguments2relation) == 3 - - taskmodule._add_candidate_relations( - arguments2relation=arguments2relation, entities=doc.entities - ) - assert len(arguments2relation) == 5 - - relation_tuples = [rel.resolve() for rel in arguments2relation.values()] - - # Original relations from document (aren't affected by whitelist) - assert relation_tuples[0] == ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))) - assert relation_tuples[1] == ("per:founder", (("PER", "Entity G"), ("ORG", "I"))) - assert relation_tuples[2] == ("org:founded_by", (("ORG", "I"), ("ORG", "H"))) - - # Relation candidate added by _add_candidate_relations() - assert relation_tuples[3] == ("no_relation", (("ORG", "H"), ("PER", "Entity G"))) - assert relation_tuples[4] == ("no_relation", (("ORG", "I"), ("PER", "Entity G"))) - - # Relations not created due to whitelist - assert ("no_relation", (("ORG", "H"), ("ORG", "I"))) not in relation_tuples - - -def test_encode_input_with_add_reversed_relations(documents): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - add_reversed_relations=True, - ) - taskmodule.prepare(documents) - encodings = [] - # just take the first three documents - for doc in documents[:3]: - encodings.extend(taskmodule.encode_input(doc)) - - assert len(encodings) == 2 - texts = [encoding.document.text for encoding in encodings] - relations = [encoding.metadata["candidate_annotation"] for encoding in encodings] - relation_tuples = [(str(rel.head), rel.label, str(rel.tail)) for rel in relations] - - # There are no relations in the first and last document, so there are also no new reversed relations - - # this is the original relation - assert texts[0] == "Entity A works at B." - assert relation_tuples[0] == ("Entity A", "per:employee_of", "B") - - # this is the reversed relation - assert texts[1] == "Entity A works at B." - assert relation_tuples[1] == ("B", "per:employee_of_reversed", "Entity A") - - # test that an already reversed relation is not reversed again - document = TestDocument( - text="Entity A works at B.", id="doc_with_relation_with_reversed_suffix" - ) - document.entities.extend( - [LabeledSpan(start=0, end=8, label="PER"), LabeledSpan(start=18, end=19, label="PER")] - ) - document.relations.append( - BinaryRelation( - head=document.entities[1], - tail=document.entities[0], - label=f"per:employee_of{taskmodule.reversed_relation_label_suffix}", - ) - ) - with pytest.raises(ValueError) as excinfo: - taskmodule.encode_input(document) - assert str(excinfo.value) == ( - "doc.id=doc_with_relation_with_reversed_suffix: The relation has the label 'per:employee_of_reversed' " - "which already ends with the reversed_relation_label_suffix='_reversed'. It looks like the relation is " - "already reversed, which is not allowed." - ) - - -def test_prepare_with_add_reversed_relations_with_label_has_suffix(): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - add_reversed_relations=True, - ) - document = TestDocument( - text="Entity A works at B.", id="doc_with_relation_with_reversed_suffix" - ) - document.entities.extend( - [LabeledSpan(start=0, end=8, label="PER"), LabeledSpan(start=18, end=19, label="PER")] - ) - document.relations.append( - BinaryRelation( - head=document.entities[0], - tail=document.entities[1], - label=f"per:employee_of{taskmodule.reversed_relation_label_suffix}", - ) - ) - - with pytest.raises(ValueError) as excinfo: - taskmodule.prepare([document]) - assert ( - str(excinfo.value) - == "doc.id=doc_with_relation_with_reversed_suffix: the relation label 'per:employee_of_reversed' " - "already ends with the reversed_relation_label_suffix '_reversed', this is not allowed because " - "we would not know if we should strip the suffix and revert the arguments during inference or not" - ) - - -@pytest.mark.parametrize("reverse_symmetric_relations", [False, True]) -def test_encode_input_with_add_reversed_relations_with_symmetric_relations( - reverse_symmetric_relations, caplog -): - document = TestDocument( - text="Entity A is married with B, but likes C, who is married with D.", - id="doc_with_symmetric_relation", - ) - document.entities.extend( - [ - LabeledSpan(start=0, end=8, label="PER"), - LabeledSpan(start=25, end=26, label="PER"), - LabeledSpan(start=38, end=39, label="PER"), - LabeledSpan(start=61, end=62, label="PER"), - ] - ) - assert str(document.entities[0]) == "Entity A" - assert str(document.entities[1]) == "B" - assert str(document.entities[2]) == "C" - assert str(document.entities[3]) == "D" - document.relations.extend( - [ - BinaryRelation( - head=document.entities[0], tail=document.entities[1], label="per:is_married_with" - ), - BinaryRelation( - head=document.entities[0], tail=document.entities[2], label="per:likes" - ), - BinaryRelation( - head=document.entities[2], tail=document.entities[3], label="per:is_married_with" - ), - BinaryRelation( - head=document.entities[3], tail=document.entities[2], label="per:is_married_with" - ), - ] - ) - - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - add_reversed_relations=True, - symmetric_relations=["per:is_married_with"], - reverse_symmetric_relations=reverse_symmetric_relations, - ) - taskmodule.prepare([document]) - encodings = taskmodule.encode_input(document) - relations = [encoding.metadata["candidate_annotation"] for encoding in encodings] - relation_tuples = [ - (str(relation.head), relation.label, str(relation.tail)) for relation in relations - ] - if reverse_symmetric_relations: - assert relation_tuples == [ - ("Entity A", "per:is_married_with", "B"), - ("Entity A", "per:likes", "C"), - ("C", "per:is_married_with", "D"), - ("D", "per:is_married_with", "C"), - ("B", "per:is_married_with", "Entity A"), - ("C", "per:likes_reversed", "Entity A"), - ] - assert len(caplog.messages) == 2 - assert ( - caplog.messages[0] - == "doc.id=doc_with_symmetric_relation: there is already a relation with reversed " - "arguments=(('head', LabeledSpan(start=61, end=62, label='PER', score=1.0)), " - "('tail', LabeledSpan(start=38, end=39, label='PER', score=1.0))) and label=per:is_married_with, " - "so we do not add the reversed relation (with label per:is_married_with) for these arguments" - ) - assert ( - caplog.messages[1] - == "doc.id=doc_with_symmetric_relation: there is already a relation with reversed " - "arguments=(('head', LabeledSpan(start=38, end=39, label='PER', score=1.0)), " - "('tail', LabeledSpan(start=61, end=62, label='PER', score=1.0))) and label=per:is_married_with, " - "so we do not add the reversed relation (with label per:is_married_with) for these arguments" - ) - else: - assert relation_tuples == [ - ("Entity A", "per:is_married_with", "B"), - ("Entity A", "per:likes", "C"), - ("C", "per:is_married_with", "D"), - ("D", "per:is_married_with", "C"), - ("C", "per:likes_reversed", "Entity A"), - ] - assert len(caplog.messages) == 0 - - caplog.clear() - document = TestDocument( - text="Entity A is married with B.", - id="doc_with_reversed_symmetric_relation", - ) - document.entities.append(LabeledSpan(start=0, end=8, label="PER")) - document.entities.append(LabeledSpan(start=25, end=26, label="PER")) - document.relations.append( - BinaryRelation( - head=document.entities[1], tail=document.entities[0], label="per:is_married_with" - ) - ) - encodings = taskmodule.encode_input(document) - relations = [encoding.metadata["candidate_annotation"] for encoding in encodings] - relation_tuples = [ - (str(relation.head), relation.label, str(relation.tail)) for relation in relations - ] - if reverse_symmetric_relations: - assert len(relation_tuples) == 2 - assert relation_tuples[0] == ("B", "per:is_married_with", "Entity A") - assert relation_tuples[1] == ("Entity A", "per:is_married_with", "B") - assert len(caplog.messages) == 1 - assert ( - caplog.messages[0] - == "doc.id=doc_with_reversed_symmetric_relation: The symmetric relation with label 'per:is_married_with' " - "has arguments (('head', LabeledSpan(start=25, end=26, label='PER', score=1.0)), " - "('tail', LabeledSpan(start=0, end=8, label='PER', score=1.0))) which are not sorted by their start " - "and end positions. This may lead to problems during evaluation because we assume that the arguments " - "of symmetric relations were sorted in the beginning and, thus, interpret relations where this is not " - "the case as reversed. All reversed relations will get their arguments swapped during inference in " - "the case of add_reversed_relations=True to remove duplicates. You may consider adding reversed " - "versions of the *symmetric* relations on your own and then setting *reverse_symmetric_relations* " - "to False." - ) - else: - assert len(relation_tuples) == 1 - assert relation_tuples[0] == ("B", "per:is_married_with", "Entity A") - assert len(caplog.messages) == 0 - - -def test_encode_input_with_add_reversed_relations_with_wrong_relation_type( - document_with_nary_relations, -): - doc = document_with_nary_relations - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - add_reversed_relations=True, - symmetric_relations=["per:employee_of"], - ) - taskmodule.prepare([doc]) - with pytest.raises(NotImplementedError) as excinfo: - taskmodule.encode_input(doc) - assert ( - str(excinfo.value) - == "doc.id=doc_with_nary_relations: the taskmodule does not yet support adding " - "reversed relations for type: " - ) - - -def test_inner_span_distance_overlap(): - dist = distance_inner((0, 2), (1, 3)) - assert dist == -1 - - -def test_span_distance_unknown_type(): - with pytest.raises(ValueError) as excinfo: - span_distance((0, 1), (2, 3), "unknown") - assert str(excinfo.value) == "unknown distance_type=unknown. use one of: center, inner, outer" - - -def test_encode_input_with_max_argument_distance(): - document = TestDocument( - text="Entity A works at B and C.", id="doc_with_three_entities_and_two_relations" - ) - e0 = LabeledSpan(start=0, end=8, label="PER") - e1 = LabeledSpan(start=18, end=19, label="PER") - e2 = LabeledSpan(start=24, end=25, label="PER") - document.entities.extend([e0, e1, e2]) - assert str(document.entities[0]) == "Entity A" - assert str(document.entities[1]) == "B" - assert str(document.entities[2]) == "C" - document.relations.append( - BinaryRelation( - head=document.entities[0], tail=document.entities[1], label="per:employee_of" - ) - ) - document.relations.append( - BinaryRelation( - head=document.entities[0], tail=document.entities[2], label="per:employee_of" - ) - ) - dist_01 = span_distance((e0.start, e0.end), (e1.start, e1.end), "inner") - dist_02 = span_distance((e0.start, e0.end), (e2.start, e2.end), "inner") - assert dist_01 == 10 - assert dist_02 == 16 - - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - max_argument_distance=10, - ) - taskmodule.prepare([document]) - encodings = taskmodule.encode_input(document) - - # there are two relations, but only one is within the max_argument_distance - assert len(encodings) == 1 - relation = encodings[0].metadata["candidate_annotation"] - assert str(relation.head) == "Entity A" - assert str(relation.tail) == "B" - assert relation.label == "per:employee_of" - - -def test_encode_input_with_max_argument_distance_with_wrong_relation_type( - document_with_nary_relations, -): - doc = document_with_nary_relations - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - max_argument_distance=10, - ) - taskmodule.prepare([doc]) - with pytest.raises(NotImplementedError) as excinfo: - encodings = taskmodule.encode_input(doc) - assert ( - str(excinfo.value) - == "doc.id=doc_with_nary_relations: the taskmodule does not yet support filtering " - "relation candidates for type: " - ) - - -@pytest.mark.parametrize("distance_type", ["inner", "outer", "unknown"]) -def test_encode_input_with_max_argument_distance_tokens(distance_type): - document = TestDocument( - text="Entity A works at B and C.", id="doc_with_three_entities_and_two_relations" - ) - e0 = LabeledSpan(start=0, end=8, label="PER") - e1 = LabeledSpan(start=18, end=19, label="PER") - e2 = LabeledSpan(start=24, end=25, label="PER") - document.entities.extend([e0, e1, e2]) - assert str(document.entities[0]) == "Entity A" - assert str(document.entities[1]) == "B" - assert str(document.entities[2]) == "C" - document.relations.append( - BinaryRelation( - head=document.entities[0], tail=document.entities[1], label="per:employee_of" - ) - ) - document.relations.append( - BinaryRelation( - head=document.entities[0], tail=document.entities[2], label="per:employee_of" - ) - ) - dist_01 = span_distance((e0.start, e0.end), (e1.start, e1.end), "inner") - dist_02 = span_distance((e0.start, e0.end), (e2.start, e2.end), "inner") - assert dist_01 == 10 - assert dist_02 == 16 - - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - max_argument_distance_tokens=3 if distance_type == "inner" else 8, - max_argument_distance_type_tokens=distance_type, - ) - taskmodule.prepare([document]) - if distance_type == "unknown": - with pytest.raises(ValueError) as excinfo: - taskmodule.encode_input(document) - assert ( - str(excinfo.value) == "unknown distance_type=unknown. use one of: center, inner, outer" - ) - return - - encodings = taskmodule.encode_input(document) - - # there are two relations, but only one is within the max_argument_distance - assert len(encodings) == 1 - encoding = encodings[0] - tokens = taskmodule.tokenizer.convert_ids_to_tokens(encoding.inputs["input_ids"]) - assert tokens == [ - "[CLS]", - "[H]", - "En", - "##ti", - "##ty", - "A", - "[/H]", - "works", - "at", - "[T]", - "B", - "[/T]", - "and", - "C", - ".", - "[SEP]", - ] - head_start = tokens.index("[H]") + 1 - head_end = tokens.index("[/H]") - tail_start = tokens.index("[T]") + 1 - tail_end = tokens.index("[/T]") - assert (head_start, head_end, tail_start, tail_end) == (2, 6, 10, 11) - # subtract 2 for the special marker tokens [/H] and [T] - inner_dist = tail_start - head_end - 2 - assert inner_dist == 2 - # subtract 2 for the special marker tokens [H] and [/T] - outer_dist = tail_end - head_start - 2 - assert outer_dist == 7 - - relation = encodings[0].metadata["candidate_annotation"] - assert str(relation.head) == "Entity A" - assert str(relation.tail) == "B" - assert relation.label == "per:employee_of" - - -def test_encode_input_with_unknown_label(): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - labels=["rel"], - entity_labels=["a", "b"], - collect_statistics=True, - ) - taskmodule.post_prepare() - - doc = TestDocument(text="hello world", id="doc_with_unknown_label") - doc.entities.append(LabeledSpan(start=0, end=5, label="a")) - doc.entities.append(LabeledSpan(start=6, end=11, label="b")) - doc.relations.append( - BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="unknown") - ) - - task_encodings = taskmodule.encode_input(doc) - assert len(task_encodings) == 0 - - statistics = taskmodule.get_statistics() - assert statistics == {("available", "unknown"): 1, ("skipped_unknown_label", "unknown"): 1} - - -def test_encode_with_empty_partition_layer(documents): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - partition_annotation="sentences", - ) - taskmodule.prepare(documents) - documents_without_sentences = [] - # just take the first three documents - for doc in documents[:3]: - doc_without_sentences = doc.copy() - doc_without_sentences.sentences.clear() - documents_without_sentences.append(doc_without_sentences) - - encodings = taskmodule.encode(documents_without_sentences) - # since there are no sentences, but we use partition_annotation="sentences", - # there are no encodings - assert len(encodings) == 0 - - -def test_encode_nary_relatio(): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - argument_role_to_marker={"r1": "R1", "r2": "R2", "r3": "R3"}, - # setting labels and entity_labels makes the taskmodule prepared - labels=["rel"], - entity_labels=["a", "b", "c"], - ) - taskmodule._post_prepare() - - @dataclass - class DocWithNaryRelation(TextBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - relations: AnnotationLayer[NaryRelation] = annotation_field(target="entities") - - doc = DocWithNaryRelation(text="hello my world") - entity1 = LabeledSpan(start=0, end=5, label="a") - entity2 = LabeledSpan(start=6, end=8, label="b") - entity3 = LabeledSpan(start=9, end=14, label="c") - doc.entities.extend([entity1, entity2, entity3]) - doc.relations.append( - NaryRelation( - arguments=tuple([entity1, entity2, entity3]), - roles=tuple(["r1", "r2", "r3"]), - label="rel", - ) - ) - - task_encodings = taskmodule.encode([doc]) - assert len(task_encodings) == 1 - encoding = task_encodings[0] - assert encoding.document == doc - assert encoding.document.text == "hello my world" - rel = encoding.metadata["candidate_annotation"] - assert str(rel.arguments[0]) == "hello" - assert str(rel.arguments[1]) == "my" - assert str(rel.arguments[2]) == "world" - assert rel.label == "rel" - - -def test_encode_unknown_relation_type(): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - # setting labels and entity_labels makes the taskmodule prepared - labels=["has_wrong_type"], - entity_labels=["a"], - ) - taskmodule._post_prepare() - - @dataclass(frozen=True) - class UnknownRelation(Annotation): - arg: LabeledSpan - label: str - - @dataclass - class DocWithUnknownRelationType(TextBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - relations: AnnotationLayer[UnknownRelation] = annotation_field(target="entities") - - doc = DocWithUnknownRelationType(text="hello world") - entity = LabeledSpan(start=0, end=1, label="a") - doc.entities.append(entity) - doc.relations.append(UnknownRelation(arg=entity, label="has_wrong_type")) - - with pytest.raises(NotImplementedError) as excinfo: - taskmodule.encode([doc]) - assert str(excinfo.value).startswith( - "the taskmodule does not yet support getting relation arguments for type: " - ) and str(excinfo.value).endswith(".UnknownRelation'>") - - -def test_encode_with_unaligned_span(caplog): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - # setting v and entity_labels makes the taskmodule prepared - labels=["rel"], - entity_labels=["a"], - ) - taskmodule._post_prepare() - - @dataclass - class MyDocument(TextBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") - - doc = MyDocument(text="hello space", id="doc1") - entity1 = LabeledSpan(start=0, end=5, label="a") - entity2 = LabeledSpan(start=7, end=13, label="a") - entity3 = LabeledSpan(start=6, end=8, label="a") - doc.entities.extend([entity1, entity2, entity3]) - # the start of entity2 is not aligned with a token, but this will get fixed - assert str(entity2) == " space" - doc.relations.append(BinaryRelation(head=entity1, tail=entity2, label="rel")) - # entity3 can not get fixed because it contains only space - assert str(entity3) == " " - doc.relations.append(BinaryRelation(head=entity1, tail=entity3, label="rel")) - - task_encodings = taskmodule.encode([doc]) - # the second relation is skipped because we can not get an aligned token span for it - assert len(task_encodings) == 1 - task_encoding = task_encodings[0] - tokens = taskmodule.tokenizer.convert_ids_to_tokens(task_encoding.inputs["input_ids"]) - assert tokens == ["[CLS]", "[H]", "hello", "[/H]", "[T]", "space", "[/T]", "[SEP]"] - - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert ( - caplog.messages[0] - == "doc.id=doc1: Skipping invalid example, cannot get argument token slice for LabeledSpan(start=6, end=8, label='a', score=1.0): \" \"" - ) - - -def test_encode_with_log_first_n_examples(caplog): - @dataclass - class DocumentWithLabeledEntitiesAndRelations(TextBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities") - - doc = DocumentWithLabeledEntitiesAndRelations(text="hello world", id="doc1") - entity1 = LabeledSpan(start=0, end=5, label="a") - entity2 = LabeledSpan(start=6, end=11, label="a") - doc.entities.extend([entity1, entity2]) - doc.relations.append(BinaryRelation(head=entity1, tail=entity2, label="rel")) - - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - log_first_n_examples=1, - ) - taskmodule.prepare([doc]) - - # we need to set the log level to INFO, otherwise the log messages are not captured - with caplog.at_level(logging.INFO): - task_encodings = taskmodule.encode([doc, doc], encode_target=True) - - # the second example is skipped because log_first_n_examples=1 - assert len(task_encodings) == 2 - assert len(caplog.records) == 5 - assert all([record.levelname == "INFO" for record in caplog.records]) - assert caplog.records[0].message == "*** Example ***" - assert caplog.records[1].message == "doc id: doc1" - assert caplog.records[2].message == "tokens: [CLS] [H] hello [/H] [T] world [/T] [SEP]" - assert caplog.records[3].message == "input_ids: 101 28998 19082 28996 28999 1362 28997 102" - assert caplog.records[4].message == "Expected label: ['rel'] (ids = [1])" - - -@pytest.mark.skipif(condition=not _TABULATE_AVAILABLE, reason="requires the 'tabulate' package") -def test_encode_with_collect_statistics(documents): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - collect_statistics=True, - ) - taskmodule.prepare(documents) - task_encodings = taskmodule.encode(documents) - statistics = taskmodule.get_statistics() - assert len(task_encodings) == 7 - - assert statistics == { - ("available", "org:founded_by"): 2, - ("available", "per:employee_of"): 3, - ("available", "per:founder"): 2, - ("used", "org:founded_by"): 2, - ("used", "per:employee_of"): 3, - ("used", "per:founder"): 2, - } - - -def test_get_global_attention(taskmodule, batch, cfg): - global_attention_mask = taskmodule._get_global_attention(input_ids=batch[0]["input_ids"]) - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(token_ids) - for token_ids in batch[0]["input_ids"].tolist() - ] - global_attention_tokens = [ - [tok for tok, m in zip(tkns, glob_attn_mask) if m] - for tkns, glob_attn_mask in zip(tokens, global_attention_mask) - ] - pad_tok = taskmodule.tokenizer.pad_token - not_global_attention_tokens = [ - [tok for tok, m in zip(tkns, glob_attn_mask) if not (m or tok == pad_tok)] - for tkns, glob_attn_mask in zip(tokens, global_attention_mask) - ] - if not cfg.get("append_markers", False): - torch.testing.assert_close( - global_attention_mask, - torch.tensor( - [ - [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0], - ] - ), - ) - assert not_global_attention_tokens == [ - ["En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]"], - [ - "First", - "sentence", - ".", - "En", - "##ti", - "##ty", - "G", - "works", - "at", - "H", - ".", - "And", - "founded", - "I", - ".", - "[SEP]", - ], - ] - else: - torch.testing.assert_close( - global_attention_mask, - torch.tensor( - [ - [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0], - ] - ), - ) - assert not_global_attention_tokens == [ - ["En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]", "[SEP]", "[SEP]"], - [ - "First", - "sentence", - ".", - "En", - "##ti", - "##ty", - "G", - "works", - "at", - "H", - ".", - "And", - "founded", - "I", - ".", - "[SEP]", - "[SEP]", - "[SEP]", - ], - ] - - if cfg == {"add_type_to_marker": False, "append_markers": False}: - assert global_attention_tokens == [ - ["[CLS]", "[H]", "[/H]", "[T]", "[/T]"], - ["[CLS]", "[H]", "[/H]", "[T]", "[/T]"], - ] - elif cfg == {"add_type_to_marker": True, "append_markers": False}: - assert global_attention_tokens == [ - ["[CLS]", "[H:PER]", "[/H:PER]", "[T:ORG]", "[/T:ORG]"], - ["[CLS]", "[H:PER]", "[/H:PER]", "[T:ORG]", "[/T:ORG]"], - ] - elif cfg == {"add_type_to_marker": False, "append_markers": True}: - assert global_attention_tokens == [ - ["[CLS]", "[H]", "[/H]", "[T]", "[/T]", "[H=PER]", "[T=ORG]"], - ["[CLS]", "[H]", "[/H]", "[T]", "[/T]", "[H=PER]", "[T=ORG]"], - ] - elif cfg == {"add_type_to_marker": True, "append_markers": True}: - assert global_attention_tokens == [ - ["[CLS]", "[H:PER]", "[/H:PER]", "[T:ORG]", "[/T:ORG]", "[H=PER]", "[T=ORG]"], - ["[CLS]", "[H:PER]", "[/H:PER]", "[T:ORG]", "[/T:ORG]", "[H=PER]", "[T=ORG]"], - ] - else: - raise ValueError(f"unexpected config: {cfg}") - - -def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]: - if isinstance(metric_or_collection, Metric): - return { - k: v.tolist() for k, v in flatten_dict_s(metric_or_collection.metric_state).items() - } - elif isinstance(metric_or_collection, MetricCollection): - return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()}) - else: - raise ValueError(f"unsupported type: {type(metric_or_collection)}") - - -def test_configure_model_metric(documents, taskmodule): - task_encodings = taskmodule.encode(documents, encode_target=True) - batch = taskmodule.collate(task_encodings) - - metric = taskmodule.configure_model_metric(stage="train") - assert isinstance(metric, (Metric, MetricCollection)) - state = get_metric_state(metric) - assert state == { - "micro/f1_without_tn/tp": [0], - "micro/f1_without_tn/fp": [0], - "micro/f1_without_tn/tn": [0], - "micro/f1_without_tn/fn": [0], - "with_tn/f1_per_label/tp": [0, 0, 0, 0], - "with_tn/f1_per_label/fp": [0, 0, 0, 0], - "with_tn/f1_per_label/tn": [0, 0, 0, 0], - "with_tn/f1_per_label/fn": [0, 0, 0, 0], - "with_tn/macro/f1/tp": [0, 0, 0, 0], - "with_tn/macro/f1/fp": [0, 0, 0, 0], - "with_tn/macro/f1/tn": [0, 0, 0, 0], - "with_tn/macro/f1/fn": [0, 0, 0, 0], - "with_tn/micro/f1/tp": [0], - "with_tn/micro/f1/fp": [0], - "with_tn/micro/f1/tn": [0], - "with_tn/micro/f1/fn": [0], - } - assert metric.compute() == { - "no_relation/f1": tensor(0.0), - "org:founded_by/f1": tensor(0.0), - "per:employee_of/f1": tensor(0.0), - "per:founder/f1": tensor(0.0), - "macro/f1": tensor(0.0), - "micro/f1": tensor(0.0), - "micro/f1_without_tn": tensor(0.0), - } - - targets = batch[1] - metric.update(targets, targets) - state = get_metric_state(metric) - assert state == { - "micro/f1_without_tn/tp": [7], - "micro/f1_without_tn/fp": [0], - "micro/f1_without_tn/tn": [21], - "micro/f1_without_tn/fn": [0], - "with_tn/f1_per_label/tp": [0, 2, 3, 2], - "with_tn/f1_per_label/fp": [0, 0, 0, 0], - "with_tn/f1_per_label/tn": [7, 5, 4, 5], - "with_tn/f1_per_label/fn": [0, 0, 0, 0], - "with_tn/macro/f1/tp": [0, 2, 3, 2], - "with_tn/macro/f1/fp": [0, 0, 0, 0], - "with_tn/macro/f1/tn": [7, 5, 4, 5], - "with_tn/macro/f1/fn": [0, 0, 0, 0], - "with_tn/micro/f1/tp": [7], - "with_tn/micro/f1/fp": [0], - "with_tn/micro/f1/tn": [21], - "with_tn/micro/f1/fn": [0], - } - assert metric.compute() == { - "no_relation/f1": tensor(0.0), - "org:founded_by/f1": tensor(1.0), - "per:employee_of/f1": tensor(1.0), - "per:founder/f1": tensor(1.0), - "macro/f1": tensor(1.0), - "micro/f1": tensor(1.0), - "micro/f1_without_tn": tensor(1.0), - } - - metric.reset() - modified_targets = {"labels": torch.tensor([2, 2, 3, 1, 2, 0, 1])} - # three positive matches and one true negative - random_predictions = {"labels": torch.tensor([1, 1, 3, 1, 2, 0, 0])} - metric.update(random_predictions, modified_targets) - state = get_metric_state(metric) - assert state == { - "micro/f1_without_tn/tp": [3], - "micro/f1_without_tn/fp": [3], - "micro/f1_without_tn/tn": [15], - "micro/f1_without_tn/fn": [3], - "with_tn/f1_per_label/tp": [1, 1, 1, 1], - "with_tn/f1_per_label/fp": [1, 2, 0, 0], - "with_tn/f1_per_label/tn": [5, 3, 4, 6], - "with_tn/f1_per_label/fn": [0, 1, 2, 0], - "with_tn/macro/f1/tp": [1, 1, 1, 1], - "with_tn/macro/f1/fp": [1, 2, 0, 0], - "with_tn/macro/f1/tn": [5, 3, 4, 6], - "with_tn/macro/f1/fn": [0, 1, 2, 0], - "with_tn/micro/f1/tp": [4], - "with_tn/micro/f1/fp": [3], - "with_tn/micro/f1/tn": [18], - "with_tn/micro/f1/fn": [3], - } - # created with torch.set_printoptions(precision=6) - torch.testing.assert_close( - metric.compute(), - { - "no_relation/f1": tensor(0.666667), - "org:founded_by/f1": tensor(0.400000), - "per:employee_of/f1": tensor(0.500000), - "per:founder/f1": tensor(1.0), - "macro/f1": tensor(0.641667), - "micro/f1": tensor(0.571429), - "micro/f1_without_tn": tensor(0.500000), - }, - ) - - # no targets and no predictions - metric.reset() - no_targets = {"labels": torch.tensor([0, 0, 0])} - no_predictions = {"labels": torch.tensor([0, 0, 0])} - metric.update(no_targets, no_predictions) - state = get_metric_state(metric) - - assert state == { - "micro/f1_without_tn/tp": [0], - "micro/f1_without_tn/fp": [0], - "micro/f1_without_tn/tn": [0], - "micro/f1_without_tn/fn": [0], - "with_tn/f1_per_label/tp": [3, 0, 0, 0], - "with_tn/f1_per_label/fp": [0, 0, 0, 0], - "with_tn/f1_per_label/tn": [0, 3, 3, 3], - "with_tn/f1_per_label/fn": [0, 0, 0, 0], - "with_tn/macro/f1/tp": [3, 0, 0, 0], - "with_tn/macro/f1/fp": [0, 0, 0, 0], - "with_tn/macro/f1/tn": [0, 3, 3, 3], - "with_tn/macro/f1/fn": [0, 0, 0, 0], - "with_tn/micro/f1/tp": [3], - "with_tn/micro/f1/fp": [0], - "with_tn/micro/f1/tn": [9], - "with_tn/micro/f1/fn": [0], - } - torch.testing.assert_close( - metric.compute(), - { - "micro/f1_without_tn": tensor(0.0), - "no_relation/f1": tensor(1.0), - "org:founded_by/f1": tensor(0.0), - "per:employee_of/f1": tensor(0.0), - "per:founder/f1": tensor(0.0), - "macro/f1": tensor(1.0), - "micro/f1": tensor(1.0), - }, - ) - - # ensure that the metric can be pickled - pickle.dumps(metric) - - -def get_bio_tag(tag_id: int, idx2label: Dict[int, str]) -> str: - if tag_id == 0: - return "O" - tag_id -= 1 - label = idx2label[tag_id // 2] - if tag_id % 2 == 0: - return f"B-{label}" - else: - return f"I-{label}" - - -def test_encode_without_insert_marker_but_argument_tags(documents): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - insert_markers=False, - add_argument_tags_to_input=True, - ) - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - assert len(documents) == 7 - encodings = taskmodule.encode(documents) - batch = taskmodule.collate(encodings) - inputs, targets = batch - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"] - ] - - idx2role = {v: k for k, v in taskmodule.argument_role2idx.items()} - argument_tag_ids = [ - [get_bio_tag(tag_id, idx2role) for tag_id in (argument_tags - 1).tolist() if tag_id >= 0] - for argument_tags in inputs["argument_tags"] - ] - tokens_with_tags = [ - [(tok, tag) for tok, tag in zip(tkns, tags)] - for tkns, tags in zip(tokens, argument_tag_ids) - ] - assert tokens_with_tags == [ - [ - ("[CLS]", "O"), - ("En", "B-head"), - ("##ti", "I-head"), - ("##ty", "I-head"), - ("A", "I-head"), - ("works", "O"), - ("at", "O"), - ("B", "B-tail"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - ("First", "O"), - ("sentence", "O"), - (".", "O"), - ("En", "B-head"), - ("##ti", "I-head"), - ("##ty", "I-head"), - ("G", "I-head"), - ("works", "O"), - ("at", "O"), - ("H", "B-tail"), - (".", "O"), - ("And", "O"), - ("founded", "O"), - ("I", "O"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - ("First", "O"), - ("sentence", "O"), - (".", "O"), - ("En", "B-head"), - ("##ti", "I-head"), - ("##ty", "I-head"), - ("G", "I-head"), - ("works", "O"), - ("at", "O"), - ("H", "O"), - (".", "O"), - ("And", "O"), - ("founded", "O"), - ("I", "B-tail"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - ("First", "O"), - ("sentence", "O"), - (".", "O"), - ("En", "O"), - ("##ti", "O"), - ("##ty", "O"), - ("G", "O"), - ("works", "O"), - ("at", "O"), - ("H", "B-tail"), - (".", "O"), - ("And", "O"), - ("founded", "O"), - ("I", "B-head"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - ("First", "O"), - ("sentence", "O"), - (".", "O"), - ("En", "B-head"), - ("##ti", "I-head"), - ("##ty", "I-head"), - ("M", "I-head"), - ("works", "O"), - ("at", "O"), - ("N", "B-tail"), - (".", "O"), - ("And", "O"), - ("it", "O"), - ("founded", "O"), - ("O", "O"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - ("First", "O"), - ("sentence", "O"), - (".", "O"), - ("En", "O"), - ("##ti", "O"), - ("##ty", "O"), - ("M", "O"), - ("works", "O"), - ("at", "O"), - ("N", "O"), - (".", "O"), - ("And", "O"), - ("it", "B-head"), - ("founded", "O"), - ("O", "B-tail"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - ("First", "O"), - ("sentence", "O"), - (".", "O"), - ("En", "O"), - ("##ti", "O"), - ("##ty", "O"), - ("M", "O"), - ("works", "O"), - ("at", "O"), - ("N", "O"), - (".", "O"), - ("And", "O"), - ("it", "B-tail"), - ("founded", "O"), - ("O", "B-head"), - (".", "O"), - ("[SEP]", "O"), - ], - ] - - -@pytest.mark.parametrize("add_argument_indices_to_input", [True, False]) -def test_encode_without_insert_marker_but_argument_tags_and_windowing( - documents, add_argument_indices_to_input -): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - add_argument_indices_to_input=add_argument_indices_to_input, - add_argument_tags_to_input=True, - max_window=8, - insert_markers=False, - ) - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - encodings = taskmodule.encode(documents, encode_target=True) - assert len(encodings) == 3 - batch = taskmodule.collate(encodings) - inputs, targets = batch - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"] - ] - - if add_argument_indices_to_input: - arg_spans = [ - get_arg_token_span( - current_tokens, - current_start_indices, - current_end_indices, - taskmodule.argument_role2idx, - ) - for current_tokens, current_start_indices, current_end_indices in zip( - tokens, - inputs["pooler_start_indices"].tolist(), - inputs["pooler_end_indices"].tolist(), - ) - ] - - assert arg_spans == [ - {"head": ["I"], "tail": ["H"]}, - {"head": ["it"], "tail": ["O"]}, - {"head": ["O"], "tail": ["it"]}, - ] - - idx2role = {v: k for k, v in taskmodule.argument_role2idx.items()} - argument_tag_ids = [ - [get_bio_tag(tag_id, idx2role) for tag_id in (argument_tags - 1).tolist() if tag_id >= 0] - for argument_tags in inputs["argument_tags"] - ] - tokens_with_tags = [ - [(tok, tag) for tok, tag in zip(tkns, tags)] - for tkns, tags in zip(tokens, argument_tag_ids) - ] - assert tokens_with_tags == [ - [ - ("[CLS]", "O"), - ("at", "O"), - ("H", "B-tail"), - (".", "O"), - ("And", "O"), - ("founded", "O"), - ("I", "B-head"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - (".", "O"), - ("And", "O"), - ("it", "B-head"), - ("founded", "O"), - ("O", "B-tail"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - (".", "O"), - ("And", "O"), - ("it", "B-tail"), - ("founded", "O"), - ("O", "B-head"), - (".", "O"), - ("[SEP]", "O"), - ], - ] - - -@pytest.mark.parametrize("insert_markers", [True, False]) -def test_encode_with_add_entity_tags_to_input(documents, insert_markers): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - add_entity_tags_to_input=True, - insert_markers=insert_markers, - ) - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - encodings = taskmodule.encode(documents) - assert len(encodings) == 7 - batch = taskmodule.collate(encodings) - inputs, targets = batch - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"] - ] - - idx2label = {k: v for k, v in enumerate(taskmodule.entity_labels)} - entity_tag_ids = [ - [get_bio_tag(tag_id, idx2label) for tag_id in (argument_tags - 1).tolist() if tag_id >= 0] - for argument_tags in inputs["entity_tags"] - ] - tokens_with_tags = [ - [(tok, tag) for tok, tag in zip(tkns, tags)] for tkns, tags in zip(tokens, entity_tag_ids) - ] - if insert_markers: - assert tokens_with_tags[:3] == [ - [ - ("[CLS]", "O"), - ("[H]", "O"), - ("En", "B-PER"), - ("##ti", "I-PER"), - ("##ty", "I-PER"), - ("A", "I-PER"), - ("[/H]", "O"), - ("works", "O"), - ("at", "O"), - ("[T]", "O"), - ("B", "B-ORG"), - ("[/T]", "O"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - ("First", "O"), - ("sentence", "O"), - (".", "O"), - ("[H]", "O"), - ("En", "B-PER"), - ("##ti", "I-PER"), - ("##ty", "I-PER"), - ("G", "I-PER"), - ("[/H]", "O"), - ("works", "O"), - ("at", "O"), - ("[T]", "O"), - ("H", "B-ORG"), - ("[/T]", "O"), - (".", "O"), - ("And", "O"), - ("founded", "O"), - ("I", "B-ORG"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - ("First", "O"), - ("sentence", "O"), - (".", "O"), - ("[H]", "O"), - ("En", "B-PER"), - ("##ti", "I-PER"), - ("##ty", "I-PER"), - ("G", "I-PER"), - ("[/H]", "O"), - ("works", "O"), - ("at", "O"), - ("H", "B-ORG"), - (".", "O"), - ("And", "O"), - ("founded", "O"), - ("[T]", "O"), - ("I", "B-ORG"), - ("[/T]", "O"), - (".", "O"), - ("[SEP]", "O"), - ], - ] - else: - assert tokens_with_tags[:3] == [ - [ - ("[CLS]", "O"), - ("En", "B-PER"), - ("##ti", "I-PER"), - ("##ty", "I-PER"), - ("A", "I-PER"), - ("works", "O"), - ("at", "O"), - ("B", "B-ORG"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - ("First", "O"), - ("sentence", "O"), - (".", "O"), - ("En", "B-PER"), - ("##ti", "I-PER"), - ("##ty", "I-PER"), - ("G", "I-PER"), - ("works", "O"), - ("at", "O"), - ("H", "B-ORG"), - (".", "O"), - ("And", "O"), - ("founded", "O"), - ("I", "B-ORG"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - ("First", "O"), - ("sentence", "O"), - (".", "O"), - ("En", "B-PER"), - ("##ti", "I-PER"), - ("##ty", "I-PER"), - ("G", "I-PER"), - ("works", "O"), - ("at", "O"), - ("H", "B-ORG"), - (".", "O"), - ("And", "O"), - ("founded", "O"), - ("I", "B-ORG"), - (".", "O"), - ("[SEP]", "O"), - ], - ] - - -@pytest.mark.parametrize("insert_markers", [True, False]) -def test_encode_with_add_entity_tags_to_input_windowing(documents, insert_markers): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path=tokenizer_name_or_path, - add_entity_tags_to_input=True, - insert_markers=insert_markers, - max_window=12 if insert_markers else 8, - ) - assert not taskmodule.is_from_pretrained - taskmodule.prepare(documents) - - encodings = taskmodule.encode(documents, encode_target=True) - assert len(encodings) == 3 - batch = taskmodule.collate(encodings) - inputs, targets = batch - tokens = [ - taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"] - ] - - idx2label = {k: v for k, v in enumerate(taskmodule.entity_labels)} - entity_tag_ids = [ - [get_bio_tag(tag_id, idx2label) for tag_id in (argument_tags - 1).tolist() if tag_id >= 0] - for argument_tags in inputs["entity_tags"] - ] - tokens_with_tags = [ - [(tok, tag) for tok, tag in zip(tkns, tags)] for tkns, tags in zip(tokens, entity_tag_ids) - ] - - if insert_markers: - assert tokens_with_tags == [ - [ - ("[CLS]", "O"), - ("at", "O"), - ("[T]", "O"), - ("H", "B-ORG"), - ("[/T]", "O"), - (".", "O"), - ("And", "O"), - ("founded", "O"), - ("[H]", "O"), - ("I", "B-ORG"), - ("[/H]", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - (".", "O"), - ("And", "O"), - ("[H]", "O"), - ("it", "B-PER"), - ("[/H]", "O"), - ("founded", "O"), - ("[T]", "O"), - ("O", "B-ORG"), - ("[/T]", "O"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - (".", "O"), - ("And", "O"), - ("[T]", "O"), - ("it", "B-PER"), - ("[/T]", "O"), - ("founded", "O"), - ("[H]", "O"), - ("O", "B-ORG"), - ("[/H]", "O"), - (".", "O"), - ("[SEP]", "O"), - ], - ] - else: - assert tokens_with_tags == [ - [ - ("[CLS]", "O"), - ("at", "O"), - ("H", "B-ORG"), - (".", "O"), - ("And", "O"), - ("founded", "O"), - ("I", "B-ORG"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - (".", "O"), - ("And", "O"), - ("it", "B-PER"), - ("founded", "O"), - ("O", "B-ORG"), - (".", "O"), - ("[SEP]", "O"), - ], - [ - ("[CLS]", "O"), - (".", "O"), - ("And", "O"), - ("it", "B-PER"), - ("founded", "O"), - ("O", "B-ORG"), - (".", "O"), - ("[SEP]", "O"), - ], - ] - - -@pytest.mark.parametrize("add_candidate_relations", [False, True]) -def test_create_annotations_from_output(add_candidate_relations): - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - # pass in the labels and entity_labels to avoid calling prepare - # (which would required documents to collect the labels from) - labels=["org:founded_by", "per:employee_of", "per:founder"], - entity_labels=["PER", "ORG"], - # we want to test the effect of creating candidate relations - add_candidate_relations=add_candidate_relations, - ) - # just call post_prepare to set up the taskmodule since labels - # and entity_labels are already set - taskmodule.post_prepare() - - entities = [ - LabeledSpan(start=16, end=24, label="PER"), - LabeledSpan(start=34, end=35, label="ORG"), - LabeledSpan(start=49, end=50, label="ORG"), - ] - - assert taskmodule.none_label == "no_relation" - candidate_relations = [ - BinaryRelation(head=entities[0], tail=entities[1], label="no_relation"), - BinaryRelation(head=entities[0], tail=entities[2], label="no_relation"), - BinaryRelation(head=entities[2], tail=entities[1], label="no_relation"), - ] - - # Just create the task encodings with dummy inputs and a dummy document since - # we do not want to pass them into the model, but add correct metadata - # (which is used to create the annotations). - task_encodings = [ - TaskEncoding(inputs={}, metadata={"candidate_annotation": rel}, document=Document()) - for rel in candidate_relations - ] - unbatched_model_outputs = [ - {"labels": ["per:employee_of"], "probabilities": [0.6000000238418579]}, - {"labels": ["per:founder"], "probabilities": [0.5]}, - {"labels": ["no_relation"], "probabilities": [0.6000000238418579]}, - ] - - result_flat = [] - for i in range(len(unbatched_model_outputs)): - result_flat.extend( - list( - taskmodule.create_annotations_from_output( - task_encoding=task_encodings[i], task_output=unbatched_model_outputs[i] - ) - ) - ) - - # The entities need to be added to a document. This is only required to resolve - # the relations later on for better readability! - document = TestDocument(text="First sentence. Entity G works at H. And founded I.") - document.entities.extend(entities) - - # this would be the "model input" - assert [rel.resolve() for rel in candidate_relations] == [ - ("no_relation", (("PER", "Entity G"), ("ORG", "H"))), - ("no_relation", (("PER", "Entity G"), ("ORG", "I"))), - ("no_relation", (("ORG", "I"), ("ORG", "H"))), - ] - - # this is the final "output" - relations_resolved_with_score = [ - (rel.resolve(), round(rel.score, 4)) for _, rel in result_flat - ] - if add_candidate_relations: - # if candidate relations were added, the no-relation is removed - assert relations_resolved_with_score == [ - (("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), 0.6), - (("per:founder", (("PER", "Entity G"), ("ORG", "I"))), 0.5), - ] - else: - # if no candidate relations were added, the no-relation is kept - assert relations_resolved_with_score == [ - (("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), 0.6), - (("per:founder", (("PER", "Entity G"), ("ORG", "I"))), 0.5), - (("no_relation", (("ORG", "I"), ("ORG", "H"))), 0.6), - ] - - -@pytest.mark.parametrize("as_list", [False, True]) -@pytest.mark.parametrize("add_candidate_relations", [False, True]) -def test_create_annotations_from_output_with_argument_and_relation_type_whitelist( - add_candidate_relations, as_list -): - if as_list: - argument_and_relation_type_whitelist = [ - ["per:employee_of", "PER", "ORG"], - ["per:founder", "PER", "ORG"], - ["org:founded_by", "ORG", "PER"], - ["no_relation", "PER", "ORG"], - ["no_relation", "ORG", "PER"], - ] - else: - argument_and_relation_type_whitelist = { - "per:employee_of": [["PER", "ORG"]], - "per:founder": [["PER", "ORG"]], - "org:founded_by": [["ORG", "PER"]], - "no_relation": [["PER", "ORG"], ["ORG", "PER"]], - } - taskmodule = RETextClassificationWithIndicesTaskModule( - relation_annotation="relations", - tokenizer_name_or_path="bert-base-cased", - # pass in the labels and entity_labels to avoid calling prepare - # (which would required documents to collect the labels from) - labels=["org:founded_by", "per:employee_of", "per:founder"], - entity_labels=["PER", "ORG"], - # we want to test the effect of creating candidate relations - add_candidate_relations=add_candidate_relations, - argument_and_relation_type_whitelist=argument_and_relation_type_whitelist, - ) - # just call post_prepare to set up the taskmodule since labels - # and entity_labels are already set - taskmodule.post_prepare() - - entities = [ - LabeledSpan(start=16, end=24, label="PER"), - LabeledSpan(start=34, end=35, label="ORG"), - LabeledSpan(start=49, end=50, label="ORG"), - ] - - assert taskmodule.none_label == "no_relation" - candidate_relations = [ - BinaryRelation(head=entities[0], tail=entities[1], label="no_relation"), - BinaryRelation(head=entities[0], tail=entities[2], label="no_relation"), - BinaryRelation(head=entities[2], tail=entities[0], label="no_relation"), - BinaryRelation(head=entities[2], tail=entities[1], label="no_relation"), - BinaryRelation(head=entities[1], tail=entities[2], label="no_relation"), - ] - - # Just create the task encodings with dummy inputs and a dummy document since - # we do not want to pass them into the model, but add correct metadata - # (which is used to create the annotations). - task_encodings = [ - TaskEncoding(inputs={}, metadata={"candidate_annotation": rel}, document=Document()) - for rel in candidate_relations - ] - unbatched_model_outputs = [ - {"labels": ["per:employee_of"], "probabilities": [0.6000000238418579]}, - {"labels": ["per:founder"], "probabilities": [0.5]}, - {"labels": ["no_relation"], "probabilities": [0.6000000238418579]}, - {"labels": ["org:founded_by"], "probabilities": [0.6000000238418579]}, - {"labels": ["no_relation"], "probabilities": [0.6000000238418579]}, - ] - - result_flat = [] - for i in range(len(unbatched_model_outputs)): - result_flat.extend( - list( - taskmodule.create_annotations_from_output( - task_encoding=task_encodings[i], task_output=unbatched_model_outputs[i] - ) - ) - ) - - # The entities need to be added to a document. This is only required to resolve - # the relations later on for better readability! - document = TestDocument(text="First sentence. Entity G works at H. And founded I.") - document.entities.extend(entities) - - # this would be the "model input" - assert [rel.resolve() for rel in candidate_relations] == [ - ("no_relation", (("PER", "Entity G"), ("ORG", "H"))), - ("no_relation", (("PER", "Entity G"), ("ORG", "I"))), - ("no_relation", (("ORG", "I"), ("PER", "Entity G"))), - ("no_relation", (("ORG", "I"), ("ORG", "H"))), - ("no_relation", (("ORG", "H"), ("ORG", "I"))), - ] - - # this is the final "output" - relations_resolved_with_score = [ - (rel.resolve(), round(rel.score, 4)) for _, rel in result_flat - ] - if add_candidate_relations: - # if candidate relations were added, no-relations are removed - # relations with wrong entity types are also removed. - assert relations_resolved_with_score == [ - (("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), 0.6), - (("per:founder", (("PER", "Entity G"), ("ORG", "I"))), 0.5), - ] - else: - # if no candidate relations were added, only relations not fitting the filter - # are removed. We explicitly need to add "no_relation" with possible argument types - # to whitelist if we don't want them to be filtered. - assert relations_resolved_with_score == [ - (("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), 0.6), - (("per:founder", (("PER", "Entity G"), ("ORG", "I"))), 0.5), - (("no_relation", (("ORG", "I"), ("PER", "Entity G"))), 0.6), - ] diff --git a/tests/taskmodules/test_text2text.py b/tests/taskmodules/test_text2text.py deleted file mode 100644 index 7dc779acb..000000000 --- a/tests/taskmodules/test_text2text.py +++ /dev/null @@ -1,275 +0,0 @@ -import pickle -from typing import Any, Dict, List, Sequence, Tuple - -import pytest -import torch -from pie_core import Annotation, TaskEncoding - -from pie_modules.annotations import AbstractiveSummary -from pie_modules.documents import ( - TextDocumentWithAbstractiveSummary, - TokenDocumentWithAbstractiveSummary, -) -from pie_modules.models.common import VALIDATION -from pie_modules.taskmodules import TextToTextTaskModule -from pie_modules.taskmodules.text_to_text import ( - InputEncodingType, - TargetEncodingType, - TaskEncodingType, - TaskOutputType, -) - - -@pytest.fixture(scope="module") -def documents(): - result = [] - - doc = TextDocumentWithAbstractiveSummary(text="This is a test document") - summary = AbstractiveSummary(text="a document") - doc.abstractive_summary.append(summary) - result.append(doc) - - doc = TextDocumentWithAbstractiveSummary( - text="This is another test document which is a bit longer" - ) - summary = AbstractiveSummary(text="a longer document") - doc.abstractive_summary.append(summary) - result.append(doc) - - return result - - -@pytest.fixture(scope="module") -def taskmodule(): - return TextToTextTaskModule( - tokenizer_name_or_path="google/t5-efficient-tiny-nl2", - document_type="pie_modules.documents.TextDocumentWithAbstractiveSummary", - target_layer="abstractive_summary", - target_annotation_type="pie_modules.annotations.AbstractiveSummary", - tokenized_document_type="pie_modules.documents.TokenDocumentWithAbstractiveSummary", - text_metric_type="torchmetrics.text.ROUGEScore", - ) - - -def test_taskmodule(taskmodule): - assert taskmodule is not None - assert taskmodule.document_type == TextDocumentWithAbstractiveSummary - assert taskmodule.tokenized_document_type == TokenDocumentWithAbstractiveSummary - assert taskmodule.target_annotation_type == AbstractiveSummary - assert taskmodule.layer_names == ["abstractive_summary"] - assert taskmodule.generation_config == {} - - -@pytest.fixture(scope="module") -def task_encodings(taskmodule, documents) -> Sequence[TaskEncodingType]: - encodings = taskmodule.encode(documents, encode_target=True) - assert all(isinstance(encoding, TaskEncoding) for encoding in encodings) - assert len(encodings) == 2 == len(documents) - assert encodings[0].document == documents[0] - assert encodings[1].document == documents[1] - return encodings - - -def test_maybe_log_example(taskmodule, task_encodings, caplog): - counter_backup = taskmodule.log_first_n_examples - - taskmodule.log_first_n_examples = 1 - with caplog.at_level("INFO"): - taskmodule.maybe_log_example(task_encodings[0]) - - assert len(caplog.messages) == 3 - assert caplog.messages[0] == "input_ids: [100, 19, 3, 9, 794, 1708, 1]" - assert caplog.messages[1] == "attention_mask: [1, 1, 1, 1, 1, 1, 1]" - assert caplog.messages[2] == "labels: [3, 9, 1708, 1]" - - taskmodule.log_first_n_examples = counter_backup - - -@pytest.fixture(scope="module") -def input_encoding(taskmodule, task_encodings) -> InputEncodingType: - assert len(task_encodings) > 0 - return task_encodings[0].inputs - - -def test_input_encoding(taskmodule, input_encoding): - assert isinstance(input_encoding, InputEncodingType) - assert input_encoding.input_ids == [100, 19, 3, 9, 794, 1708, 1] - assert input_encoding.attention_mask == [1, 1, 1, 1, 1, 1, 1] - - tokens = taskmodule.tokenizer.convert_ids_to_tokens(input_encoding.input_ids) - assert tokens == ["▁This", "▁is", "▁", "a", "▁test", "▁document", ""] - - -@pytest.fixture(scope="module") -def metadata(taskmodule, task_encodings) -> Dict[str, Any]: - assert len(task_encodings) > 0 - return task_encodings[0].metadata - - -def test_metadata(taskmodule, metadata): - assert set(metadata) == {"tokenized_document", "guidance_annotation"} - - tokenized_document = metadata["tokenized_document"] - assert isinstance(tokenized_document, TokenDocumentWithAbstractiveSummary) - assert tokenized_document.tokens == ("▁This", "▁is", "▁", "a", "▁test", "▁document", "") - assert len(tokenized_document.abstractive_summary) == 1 - assert tokenized_document.abstractive_summary[0].text == "a document" - - -@pytest.fixture(scope="module") -def target_encoding(taskmodule, task_encodings) -> TargetEncodingType: - assert len(task_encodings) > 0 - return task_encodings[0].targets - - -def test_target_encoding(taskmodule, target_encoding): - assert isinstance(target_encoding, TargetEncodingType) - assert target_encoding.labels == [3, 9, 1708, 1] - assert target_encoding.decoder_attention_mask == [1, 1, 1, 1] - - -@pytest.fixture(scope="module") -def batch(taskmodule, task_encodings) -> List[TaskEncodingType]: - result = taskmodule.collate(task_encodings) - return result - - -def test_batch(taskmodule, batch): - assert len(batch) == 2 - inputs, targets = batch - - assert set(inputs) == {"input_ids", "attention_mask"} - torch.testing.assert_close( - inputs["input_ids"], - torch.tensor( - [ - [100, 19, 3, 9, 794, 1708, 1, 0, 0, 0, 0, 0], - [100, 19, 430, 794, 1708, 84, 19, 3, 9, 720, 1200, 1], - ] - ), - ) - torch.testing.assert_close( - inputs["attention_mask"], - torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), - ) - - assert set(targets) == {"labels", "decoder_attention_mask"} - torch.testing.assert_close( - targets["labels"], torch.tensor([[3, 9, 1708, 1, 0], [3, 9, 1200, 1708, 1]]) - ) - torch.testing.assert_close( - targets["decoder_attention_mask"], torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]) - ) - - -@pytest.fixture(scope="module") -def unbatched_output(taskmodule, batch) -> Sequence[TaskOutputType]: - inputs, targets = batch - return taskmodule.unbatch_output(targets) - - -def test_unbatched_output(taskmodule, unbatched_output): - assert all(isinstance(output, TargetEncodingType) for output in unbatched_output) - assert len(unbatched_output) == 2 - - assert unbatched_output[0].labels == [3, 9, 1708, 1] - assert unbatched_output[0].decoder_attention_mask is None - - assert unbatched_output[1].labels == [3, 9, 1200, 1708, 1] - assert unbatched_output[1].decoder_attention_mask is None - - -@pytest.fixture(scope="module") -def decoded_annotations( - taskmodule, task_encodings, unbatched_output -) -> List[Tuple[str, Annotation]]: - result = [] - for encoding, output in zip(task_encodings, unbatched_output): - result.extend( - taskmodule.create_annotations_from_output(task_encoding=encoding, task_output=output) - ) - return result - - -def test_decoded_annotations(taskmodule, decoded_annotations): - names, annotations = zip(*decoded_annotations) - assert all(layer_name == taskmodule.target_layer for layer_name in names) - assert all( - isinstance(annotation, taskmodule.target_annotation_type) for annotation in annotations - ) - - assert len(annotations) == 2 - assert annotations[0].text == "a document" - assert annotations[0].score is None - assert annotations[1].text == "a longer document" - assert annotations[1].score is None - - -def test_configure_model_metrics(taskmodule): - metric = taskmodule.configure_model_metric(stage=VALIDATION) - assert metric is not None - values = metric.compute() - keys = { - "rouge2_fmeasure", - "rougeL_recall", - "rouge1_precision", - "rouge1_recall", - "rouge2_recall", - "rougeL_precision", - "rouge1_fmeasure", - "rougeLsum_recall", - "rougeLsum_precision", - "rougeL_fmeasure", - "rouge2_precision", - "rougeLsum_fmeasure", - } - assert set(values) == keys - assert all(torch.isnan(value) for value in values.values()) - - labels = torch.tensor([[3, 9, 1708, 1, 0], [3, 9, 1200, 1708, 1]]) - metric.update(prediction={"labels": labels}, target={"labels": labels}) - assert set(metric.metric_state) == keys - assert all( - value == [torch.tensor(1.0), torch.tensor(1.0)] for value in metric.metric_state.values() - ) - values = metric.compute() - assert set(values) == keys - assert all(value == torch.tensor(1.0) for value in values.values()) - - random_labels = torch.tensor([[875, 885, 112, 289, 769], [270, 583, 970, 114, 71]]) - metric.update(prediction={"labels": random_labels}, target={"labels": labels}) - values = metric.compute() - assert {k: v.item() for k, v in values.items()} == { - "rouge1_fmeasure": 0.5625, - "rouge1_precision": 0.550000011920929, - "rouge1_recall": 0.5833333134651184, - "rouge2_fmeasure": 0.5, - "rouge2_precision": 0.5, - "rouge2_recall": 0.5, - "rougeL_fmeasure": 0.5625, - "rougeL_precision": 0.550000011920929, - "rougeL_recall": 0.5833333134651184, - "rougeLsum_fmeasure": 0.5625, - "rougeLsum_precision": 0.550000011920929, - "rougeLsum_recall": 0.5833333134651184, - } - - # ensure that the metric can be pickled - pickle.dumps(metric) - - -def test_configure_model_generation(taskmodule): - generation_config = taskmodule.configure_model_generation() - assert generation_config is not None - assert generation_config == {} - - -def test_warn_once(taskmodule, caplog): - with caplog.at_level("WARNING"): - taskmodule.warn_only_once("test") - taskmodule.warn_only_once("test") - taskmodule.warn_only_once("test2") - - assert len(caplog.messages) == 2 - assert caplog.messages[0] == "test (This warning will only be shown once)" - assert caplog.messages[1] == "test2 (This warning will only be shown once)" diff --git a/tests/taskmodules/test_text2text_with_guidance.py b/tests/taskmodules/test_text2text_with_guidance.py deleted file mode 100644 index c66711610..000000000 --- a/tests/taskmodules/test_text2text_with_guidance.py +++ /dev/null @@ -1,240 +0,0 @@ -from typing import Any, Dict, List, Sequence, Tuple - -import pytest -import torch -from pie_core import Annotation, TaskEncoding - -from pie_modules.annotations import GenerativeAnswer, Question -from pie_modules.documents import ( - TextDocumentWithQuestionsAndGenerativeAnswers, - TokenDocumentWithQuestionsAndGenerativeAnswers, -) -from pie_modules.taskmodules import TextToTextTaskModule -from pie_modules.taskmodules.text_to_text import ( - InputEncodingType, - TargetEncodingType, - TaskEncodingType, - TaskOutputType, -) - - -@pytest.fixture(scope="module") -def documents(): - result = [] - - doc = TextDocumentWithQuestionsAndGenerativeAnswers(text="This is a test document") - question = Question(text="What is this?") - doc.questions.append(question) - answer = GenerativeAnswer(text="a document", question=question) - doc.generative_answers.append(answer) - result.append(doc) - - doc = TextDocumentWithQuestionsAndGenerativeAnswers( - text="This is another test document which is a bit longer." - ) - question = Question(text="And what is this?") - doc.questions.append(question) - answer = GenerativeAnswer(text="a longer document", question=question) - doc.generative_answers.append(answer) - result.append(doc) - - return result - - -@pytest.fixture(scope="module") -def taskmodule(): - return TextToTextTaskModule( - tokenizer_name_or_path="google/t5-efficient-tiny-nl2", - document_type="pie_modules.documents.TextDocumentWithQuestionsAndGenerativeAnswers", - target_layer="generative_answers", - target_annotation_type="pie_modules.annotations.GenerativeAnswer", - tokenized_document_type="pie_modules.documents.TokenDocumentWithQuestionsAndGenerativeAnswers", - guidance_layer="questions", - guidance_annotation_field="question", - text_metric_type="torchmetrics.text.ROUGEScore", - ) - - -def test_taskmodule(taskmodule): - assert taskmodule is not None - assert taskmodule.document_type == TextDocumentWithQuestionsAndGenerativeAnswers - assert taskmodule.tokenized_document_type == TokenDocumentWithQuestionsAndGenerativeAnswers - assert taskmodule.target_annotation_type == GenerativeAnswer - assert taskmodule.layer_names == ["generative_answers"] - assert taskmodule.generation_config == {} - - -@pytest.fixture(scope="module") -def task_encodings(taskmodule, documents) -> Sequence[TaskEncodingType]: - encodings = taskmodule.encode(documents, encode_target=True) - assert all(isinstance(encoding, TaskEncoding) for encoding in encodings) - assert len(encodings) == 2 == len(documents) - assert encodings[0].document == documents[0] - assert encodings[1].document == documents[1] - return encodings - - -def test_maybe_log_example(taskmodule, task_encodings, caplog): - counter_backup = taskmodule.log_first_n_examples - - taskmodule.log_first_n_examples = 1 - with caplog.at_level("INFO"): - taskmodule.maybe_log_example(task_encodings[0]) - - assert len(caplog.messages) == 3 - assert caplog.messages[0] == "input_ids: [363, 19, 48, 58, 1, 100, 19, 3, 9, 794, 1708, 1]" - assert caplog.messages[1] == "attention_mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]" - assert caplog.messages[2] == "labels: [3, 9, 1708, 1]" - - taskmodule.log_first_n_examples = counter_backup - - -@pytest.fixture(scope="module") -def input_encoding(taskmodule, task_encodings) -> InputEncodingType: - assert len(task_encodings) > 0 - return task_encodings[0].inputs - - -def test_input_encoding(taskmodule, input_encoding): - assert isinstance(input_encoding, InputEncodingType) - assert input_encoding.input_ids == [363, 19, 48, 58, 1, 100, 19, 3, 9, 794, 1708, 1] - assert input_encoding.attention_mask == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] - - tokens = taskmodule.tokenizer.convert_ids_to_tokens(input_encoding.input_ids) - assert tokens == [ - "▁What", - "▁is", - "▁this", - "?", - "", - "▁This", - "▁is", - "▁", - "a", - "▁test", - "▁document", - "", - ] - - -@pytest.fixture(scope="module") -def metadata(taskmodule, task_encodings) -> Dict[str, Any]: - assert len(task_encodings) > 0 - return task_encodings[0].metadata - - -def test_metadata(taskmodule, metadata): - assert set(metadata) == {"tokenized_document", "guidance_annotation"} - - tokenized_document = metadata["tokenized_document"] - assert isinstance(tokenized_document, TokenDocumentWithQuestionsAndGenerativeAnswers) - assert tokenized_document.tokens == ( - "▁What", - "▁is", - "▁this", - "?", - "", - "▁This", - "▁is", - "▁", - "a", - "▁test", - "▁document", - "", - ) - assert len(tokenized_document.questions) == 1 - assert tokenized_document.questions[0].text == "What is this?" - - -@pytest.fixture(scope="module") -def target_encoding(taskmodule, task_encodings) -> TargetEncodingType: - assert len(task_encodings) > 0 - return task_encodings[0].targets - - -def test_target_encoding(taskmodule, target_encoding): - assert isinstance(target_encoding, TargetEncodingType) - assert target_encoding.labels == [3, 9, 1708, 1] - assert target_encoding.decoder_attention_mask == [1, 1, 1, 1] - - -@pytest.fixture(scope="module") -def batch(taskmodule, task_encodings) -> List[TaskEncodingType]: - result = taskmodule.collate(task_encodings) - return result - - -def test_batch(taskmodule, batch): - assert len(batch) == 2 - inputs, targets = batch - - assert set(inputs) == {"input_ids", "attention_mask"} - torch.testing.assert_close( - inputs["input_ids"], - torch.tensor( - [ - [363, 19, 48, 58, 1, 100, 19, 3, 9, 794, 1708, 1, 0, 0, 0, 0, 0, 0, 0], - [275, 125, 19, 48, 58, 1, 100, 19, 430, 794, 1708, 84, 19, 3, 9, 720, 1200, 5, 1], - ] - ), - ) - torch.testing.assert_close( - inputs["attention_mask"], - torch.tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - ] - ), - ) - - assert set(targets) == {"labels", "decoder_attention_mask"} - torch.testing.assert_close( - targets["labels"], torch.tensor([[3, 9, 1708, 1, 0], [3, 9, 1200, 1708, 1]]) - ) - torch.testing.assert_close( - targets["decoder_attention_mask"], torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]) - ) - - -@pytest.fixture(scope="module") -def unbatched_output(taskmodule, batch) -> Sequence[TaskOutputType]: - inputs, targets = batch - return taskmodule.unbatch_output(targets) - - -def test_unbatched_output(taskmodule, unbatched_output): - assert all(isinstance(output, TargetEncodingType) for output in unbatched_output) - assert len(unbatched_output) == 2 - - assert unbatched_output[0].labels == [3, 9, 1708, 1] - assert unbatched_output[0].decoder_attention_mask is None - - assert unbatched_output[1].labels == [3, 9, 1200, 1708, 1] - assert unbatched_output[1].decoder_attention_mask is None - - -@pytest.fixture(scope="module") -def decoded_annotations( - taskmodule, task_encodings, unbatched_output -) -> List[Tuple[str, Annotation]]: - result = [] - for encoding, output in zip(task_encodings, unbatched_output): - result.extend( - taskmodule.create_annotations_from_output(task_encoding=encoding, task_output=output) - ) - return result - - -def test_decoded_annotations(taskmodule, decoded_annotations): - names, annotations = zip(*decoded_annotations) - assert all(layer_name == taskmodule.target_layer for layer_name in names) - assert all( - isinstance(annotation, taskmodule.target_annotation_type) for annotation in annotations - ) - - assert len(annotations) == 2 - assert annotations[0].text == "a document" - assert annotations[0].score is None - assert annotations[1].text == "a longer document" - assert annotations[1].score is None diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 5ea2a0b85..b41d4bc5d 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -1,10 +1,23 @@ +import dataclasses import json +import re from typing import Dict, Optional import pytest -from pie_core import Annotation +from pie_core import Annotation, AnnotationLayer, annotation_field -from pie_modules.annotations import LabeledMultiSpan +from pie_modules.annotations import ( + BinaryRelation, + Label, + LabeledMultiSpan, + LabeledSpan, + MultiLabel, + MultiLabeledBinaryRelation, + MultiLabeledSpan, + NaryRelation, + Span, +) +from pie_modules.documents import TextBasedDocument def _test_annotation_reconstruction( @@ -17,6 +30,163 @@ def _test_annotation_reconstruction( assert annotation_reconstructed == annotation +def test_label(): + label1 = Label(label="label1") + assert label1.label == "label1" + assert label1.score == pytest.approx(1.0) + assert label1.resolve() == "label1" + + label2 = Label(label="label2", score=0.5) + assert label2.label == "label2" + assert label2.score == pytest.approx(0.5) + + assert label2.asdict() == { + "_id": label2._id, + "label": "label2", + "score": 0.5, + } + + _test_annotation_reconstruction(label2) + + +def test_multilabel(): + multilabel1 = MultiLabel(label=("label1", "label2")) + assert multilabel1.label == ("label1", "label2") + assert multilabel1.score == pytest.approx((1.0, 1.0)) + assert multilabel1.resolve() == ("label1", "label2") + + multilabel2 = MultiLabel(label=("label3", "label4"), score=(0.4, 0.5)) + assert multilabel2.label == ("label3", "label4") + assert multilabel2.score == pytest.approx((0.4, 0.5)) + + assert multilabel2.asdict() == { + "_id": multilabel2._id, + "label": ("label3", "label4"), + "score": (0.4, 0.5), + } + + _test_annotation_reconstruction(multilabel2) + + with pytest.raises( + ValueError, match=re.escape("Number of labels (2) and scores (3) must be equal.") + ): + MultiLabel(label=("label5", "label6"), score=(0.1, 0.2, 0.3)) + + +def test_span(): + span = Span(start=1, end=2) + assert span.start == 1 + assert span.end == 2 + + assert span.asdict() == { + "_id": span._id, + "start": 1, + "end": 2, + } + + _test_annotation_reconstruction(span) + + with pytest.raises(ValueError) as excinfo: + span.resolve() + assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target." + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[Span] = annotation_field(target="text") + + doc = TestDocument(text="Hello, world!") + span = Span(start=7, end=12) + doc.spans.append(span) + assert span.resolve() == "world" + + +def test_labeled_span(): + labeled_span1 = LabeledSpan(start=1, end=2, label="label1") + assert labeled_span1.start == 1 + assert labeled_span1.end == 2 + assert labeled_span1.label == "label1" + assert labeled_span1.score == pytest.approx(1.0) + + labeled_span2 = LabeledSpan(start=3, end=4, label="label2", score=0.5) + assert labeled_span2.start == 3 + assert labeled_span2.end == 4 + assert labeled_span2.label == "label2" + assert labeled_span2.score == pytest.approx(0.5) + + assert labeled_span2.asdict() == { + "_id": labeled_span2._id, + "start": 3, + "end": 4, + "label": "label2", + "score": 0.5, + } + + _test_annotation_reconstruction(labeled_span2) + + with pytest.raises(ValueError) as excinfo: + labeled_span1.resolve() + assert ( + str(excinfo.value) + == "LabeledSpan(start=1, end=2, label='label1', score=1.0) is not attached to a target." + ) + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text") + + doc = TestDocument(text="Hello, world!") + labeled_span = LabeledSpan(start=7, end=12, label="LOC") + doc.spans.append(labeled_span) + assert labeled_span.resolve() == ("LOC", "world") + + +def test_multilabeled_span(): + multilabeled_span1 = MultiLabeledSpan(start=1, end=2, label=("label1", "label2")) + assert multilabeled_span1.start == 1 + assert multilabeled_span1.end == 2 + assert multilabeled_span1.label == ("label1", "label2") + assert multilabeled_span1.score == pytest.approx((1.0, 1.0)) + + multilabeled_span2 = MultiLabeledSpan( + start=3, end=4, label=("label3", "label4"), score=(0.4, 0.5) + ) + assert multilabeled_span2.start == 3 + assert multilabeled_span2.end == 4 + assert multilabeled_span2.label == ("label3", "label4") + assert multilabeled_span2.score == pytest.approx((0.4, 0.5)) + + assert multilabeled_span2.asdict() == { + "_id": multilabeled_span2._id, + "start": 3, + "end": 4, + "label": ("label3", "label4"), + "score": (0.4, 0.5), + } + + _test_annotation_reconstruction(multilabeled_span2) + + with pytest.raises( + ValueError, match=re.escape("Number of labels (2) and scores (3) must be equal.") + ): + MultiLabeledSpan(start=5, end=6, label=("label5", "label6"), score=(0.1, 0.2, 0.3)) + + with pytest.raises(ValueError) as excinfo: + multilabeled_span1.resolve() + assert ( + str(excinfo.value) + == "MultiLabeledSpan(start=1, end=2, label=('label1', 'label2'), score=(1.0, 1.0)) is not attached to a target." + ) + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[MultiLabeledSpan] = annotation_field(target="text") + + doc = TestDocument(text="Hello, world!") + multilabeled_span = MultiLabeledSpan(start=7, end=12, label=("LOC", "ORG")) + doc.spans.append(multilabeled_span) + assert multilabeled_span.resolve() == (("LOC", "ORG"), "world") + + def test_labeled_multi_span(): labeled_multi_span1 = LabeledMultiSpan(slices=((1, 2), (3, 4)), label="label1") assert labeled_multi_span1.slices == ((1, 2), (3, 4)) @@ -40,3 +210,179 @@ def test_labeled_multi_span(): } _test_annotation_reconstruction(labeled_multi_span2) + + +def test_binary_relation(): + head = Span(start=1, end=2) + tail = Span(start=3, end=4) + + binary_relation1 = BinaryRelation(head=head, tail=tail, label="label1") + assert binary_relation1.head == head + assert binary_relation1.tail == tail + assert binary_relation1.label == "label1" + assert binary_relation1.score == pytest.approx(1.0) + + binary_relation2 = BinaryRelation(head=head, tail=tail, label="label2", score=0.5) + assert binary_relation2.head == head + assert binary_relation2.tail == tail + assert binary_relation2.label == "label2" + assert binary_relation2.score == pytest.approx(0.5) + + assert binary_relation2.asdict() == { + "_id": binary_relation2._id, + "head": head._id, + "tail": tail._id, + "label": "label2", + "score": 0.5, + } + + annotation_store = { + head._id: head, + tail._id: tail, + } + _test_annotation_reconstruction(binary_relation2, annotation_store=annotation_store) + + with pytest.raises( + ValueError, + match=re.escape("Unable to resolve the annotation id without annotation_store."), + ): + BinaryRelation.fromdict(binary_relation2.asdict()) + + with pytest.raises(ValueError) as excinfo: + binary_relation1.resolve() + assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target." + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[Span] = annotation_field(target="text") + relations: AnnotationLayer[BinaryRelation] = annotation_field(target="spans") + + doc = TestDocument(text="Hello, world!") + head = Span(start=0, end=5) + tail = Span(start=7, end=12) + doc.spans.extend([head, tail]) + relation = BinaryRelation(head=head, tail=tail, label="LABEL") + doc.relations.append(relation) + assert relation.resolve() == ("LABEL", ("Hello", "world")) + + +def test_multilabeled_binary_relation(): + head = Span(start=1, end=2) + tail = Span(start=3, end=4) + + binary_relation1 = MultiLabeledBinaryRelation(head=head, tail=tail, label=("label1", "label2")) + assert binary_relation1.head == head + assert binary_relation1.tail == tail + assert binary_relation1.label == ("label1", "label2") + assert binary_relation1.score == pytest.approx((1.0, 1.0)) + + binary_relation2 = MultiLabeledBinaryRelation( + head=head, tail=tail, label=("label3", "label4"), score=(0.4, 0.5) + ) + assert binary_relation2.head == head + assert binary_relation2.tail == tail + assert binary_relation2.label == ("label3", "label4") + assert binary_relation2.score == pytest.approx((0.4, 0.5)) + + assert binary_relation2.asdict() == { + "_id": binary_relation2._id, + "head": head._id, + "tail": tail._id, + "label": ("label3", "label4"), + "score": (0.4, 0.5), + } + + annotation_store = { + head._id: head, + tail._id: tail, + } + _test_annotation_reconstruction(binary_relation2, annotation_store=annotation_store) + + with pytest.raises( + ValueError, + match=re.escape("Unable to resolve the annotation id without annotation_store."), + ): + MultiLabeledBinaryRelation.fromdict(binary_relation2.asdict()) + + with pytest.raises( + ValueError, match=re.escape("Number of labels (2) and scores (3) must be equal.") + ): + MultiLabeledBinaryRelation( + head=head, tail=tail, label=("label5", "label6"), score=(0.1, 0.2, 0.3) + ) + + with pytest.raises(ValueError) as excinfo: + binary_relation1.resolve() + assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target." + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[Span] = annotation_field(target="text") + relations: AnnotationLayer[MultiLabeledBinaryRelation] = annotation_field(target="spans") + + doc = TestDocument(text="Hello, world!") + head = Span(start=0, end=5) + tail = Span(start=7, end=12) + doc.spans.extend([head, tail]) + relation = MultiLabeledBinaryRelation(head=head, tail=tail, label=("LABEL1", "LABEL2")) + doc.relations.append(relation) + assert relation.resolve() == (("LABEL1", "LABEL2"), ("Hello", "world")) + + +def test_nary_relation(): + arg1 = Span(start=1, end=2) + arg2 = Span(start=3, end=4) + arg3 = Span(start=5, end=6) + + nary_relation1 = NaryRelation( + arguments=(arg1, arg2, arg3), roles=("role1", "role2", "role3"), label="label1" + ) + + assert nary_relation1.arguments == (arg1, arg2, arg3) + assert nary_relation1.roles == ("role1", "role2", "role3") + assert nary_relation1.label == "label1" + assert nary_relation1.score == pytest.approx(1.0) + + assert nary_relation1.asdict() == { + "_id": nary_relation1._id, + "arguments": [arg1._id, arg2._id, arg3._id], + "roles": ("role1", "role2", "role3"), + "label": "label1", + "score": 1.0, + } + + annotation_store = { + arg1._id: arg1, + arg2._id: arg2, + arg3._id: arg3, + } + _test_annotation_reconstruction(nary_relation1, annotation_store=annotation_store) + + with pytest.raises( + ValueError, + match=re.escape("Unable to resolve the annotation id without annotation_store."), + ): + NaryRelation.fromdict(nary_relation1.asdict()) + + with pytest.raises(ValueError) as excinfo: + nary_relation1.resolve() + assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target." + + @dataclasses.dataclass + class TestDocument(TextBasedDocument): + spans: AnnotationLayer[Span] = annotation_field(target="text") + relations: AnnotationLayer[NaryRelation] = annotation_field(target="spans") + + doc = TestDocument(text="Hello, world A and B!") + arg1 = Span(start=0, end=5) + arg2 = Span(start=7, end=14) + arg3 = Span(start=19, end=20) + doc.spans.extend([arg1, arg2, arg3]) + relation = NaryRelation( + arguments=(arg1, arg2, arg3), roles=("ARG1", "ARG2", "ARG3"), label="LABEL" + ) + doc.relations.append(relation) + assert relation.resolve() == ( + "LABEL", + (("ARG1", "Hello"), ("ARG2", "world A"), ("ARG3", "B")), + ) diff --git a/tests/utils/test_hydra.py b/tests/utils/test_hydra.py deleted file mode 100644 index e0f6ddf82..000000000 --- a/tests/utils/test_hydra.py +++ /dev/null @@ -1,43 +0,0 @@ -import dataclasses - -import pytest -from pie_core import AnnotationLayer, annotation_field -from pie_core.utils.hydra import resolve_type - -from pie_modules.annotations import LabeledSpan, Span -from pie_modules.documents import TextBasedDocument - - -@dataclasses.dataclass -class TestDocumentWithEntities(TextBasedDocument): - entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - - -@dataclasses.dataclass -class TestDocumentWithSentences(TextBasedDocument): - sentences: AnnotationLayer[Span] = annotation_field(target="text") - - -def test_resolve_document_type(): - assert resolve_type(TestDocumentWithEntities) == TestDocumentWithEntities - assert ( - resolve_type("tests.utils.test_hydra.TestDocumentWithEntities") == TestDocumentWithEntities - ) - with pytest.raises(TypeError) as exc_info: - resolve_type("tests.utils.test_hydra.test_resolve_document_type") - assert str(exc_info.value).startswith( - "type must be a subclass of None or a string that resolves to that, but got " - "