From 956a53599402f72a78bbbc0cd29bc573a3685fee Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Wed, 20 Apr 2022 15:22:58 +0200 Subject: [PATCH 01/12] Add streamlit app to learn how to edit penman notation with ud graphs --- frontend/graphedit.py | 120 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 frontend/graphedit.py diff --git a/frontend/graphedit.py b/frontend/graphedit.py new file mode 100644 index 0000000..80ebe9d --- /dev/null +++ b/frontend/graphedit.py @@ -0,0 +1,120 @@ +import argparse +import copy +import os +import time +import json +import streamlit as st +import pandas as pd +import penman + +from graphviz import Source +from xpotato.dataset.utils import default_pn_to_graph + +from tuw_nlp.graph.utils import graph_to_pn, pn_to_graph + +from utils import ( + train_df, + add_rule_manually, + annotate_df, + extract_data_from_dataframe, + get_df_from_rules, + graph_viewer, + init_evaluator, + init_extractor, + init_session_states, + rank_and_suggest, + read_df, + rerun, + rule_chooser, + save_ruleset, + read_ruleset, + save_after_modify, + save_dataframe, + match_texts, + show_ml_feature, + st_stdout, + to_dot, +) + +def main(): + st.set_page_config(layout="wide") + hide_streamlit_style = """ + + + """ + st.markdown(hide_streamlit_style, unsafe_allow_html=True) + st.markdown( + "

GraphEdit

" + "

Edit ud PENMAN graphs!

", + unsafe_allow_html=True, + ) + init_session_states() + text_input = st.text_area("Provide the text here you want to convert to Penman notation!" + ) + topenman = st.button("To Penman!") + output = st.empty() + penman_input = st.text_area("Provide the penman notation here you want to render!" + ) + global generatedpenman + render = st.button("Render!") + extractor = init_extractor("en", "ud") + if topenman: + if text_input: + texts = text_input.split("\n") + graphs = list(extractor.parse_iterable([text for text in texts], "ud")) + dot_current_graph = to_dot( + graphs[0], + ) + penmanstring = value=(penman.encode(penman.decode(graph_to_pn(graphs[0])), indent=10)) + #penman_input = placeholder.text_area("Penman notation provided.", penmanstring) + output.text(penmanstring) + generatedpenman = penmanstring + if render: + #penman_input = placeholder.text_area("Provide the penman notation here you want to render!", penmanstring) + if penman_input: + + graph, ind = default_pn_to_graph(penman_input) + dot_current_graph = to_dot( + graph, + ) + + if st.session_state.download: + graph_pipe = Source(dot_current_graph).pipe(format="svg") + st.download_button( + label="Download graph as SVG", + data=graph_pipe, + file_name="graph.svg", + mime="mage/svg+xml", + ) + + with st.expander("Graph dot source", expanded=False): + st.write(dot_current_graph) + + st.graphviz_chart( + dot_current_graph, + use_container_width=True, + ) + + st.write("Penman format:") + st.text(penman.encode(penman.decode(graph_to_pn(graph)), indent=10)) + st.write("In one line format:") + st.write(graph_to_pn(graph)) + +if __name__ == "__main__": + main() From f02bb96d0cb1ef5169e3458d5f32c64f4d778d9e Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Wed, 1 Jun 2022 15:50:40 +0200 Subject: [PATCH 02/12] Add eraser to the hatexplain evaluation script Still missing: multi rule matching, labels without and only rationales --- scripts/call_eraser.py | 87 ++ scripts/eraserbenchmark/.gitignore | 111 +++ scripts/eraserbenchmark/LICENSE | 201 +++++ scripts/eraserbenchmark/README.md | 47 ++ scripts/eraserbenchmark/REPRODUCTION.txt | 32 + .../eraserbenchmark/data_exploration.ipynb | 309 +++++++ scripts/eraserbenchmark/params/boolq.json | 26 + .../eraserbenchmark/params/boolq_baas.json | 26 + .../eraserbenchmark/params/boolq_bert.json | 32 + .../eraserbenchmark/params/boolq_soft.json | 21 + scripts/eraserbenchmark/params/cose_bert.json | 30 + .../eraserbenchmark/params/esnli_bert.json | 28 + .../params/evidence_inference.json | 26 + .../params/evidence_inference_bert.json | 33 + .../params/evidence_inference_soft.json | 22 + scripts/eraserbenchmark/params/fever.json | 26 + .../eraserbenchmark/params/fever_baas.json | 25 + .../eraserbenchmark/params/fever_bert.json | 32 + .../eraserbenchmark/params/fever_soft.json | 21 + scripts/eraserbenchmark/params/movies.json | 26 + .../eraserbenchmark/params/movies_baas.json | 26 + .../eraserbenchmark/params/movies_bert.json | 32 + .../eraserbenchmark/params/movies_soft.json | 21 + scripts/eraserbenchmark/params/multirc.json | 26 + .../eraserbenchmark/params/multirc_baas.json | 26 + .../eraserbenchmark/params/multirc_bert.json | 32 + .../eraserbenchmark/params/multirc_soft.json | 21 + .../rationale_benchmark/__init__.py | 0 .../rationale_benchmark/metrics.py | 760 +++++++++++++++++ .../rationale_benchmark/models/__init__.py | 0 .../models/encode_attend.py | 520 ++++++++++++ .../rationale_benchmark/models/mlp.py | 282 +++++++ .../rationale_benchmark/models/model_utils.py | 155 ++++ .../models/pipeline/__init__.py | 0 .../models/pipeline/bert_pipeline.py | 330 ++++++++ .../models/pipeline/evidence_classifier.py | 193 +++++ .../models/pipeline/evidence_identifier.py | 261 ++++++ .../pipeline/evidence_token_identifier.py | 246 ++++++ .../models/pipeline/pipeline_train.py | 167 ++++ .../models/pipeline/pipeline_utils.py | 799 ++++++++++++++++++ .../models/sequence_taggers.py | 62 ++ .../rationale_benchmark/utils.py | 226 +++++ scripts/eraserbenchmark/requirements.txt | 12 + scripts/evaluate_hatexplain.py | 16 +- scripts/hatexplain_to_eraser.py | 247 ++++++ scripts/read_hatexplain.py | 4 +- 46 files changed, 5619 insertions(+), 6 deletions(-) create mode 100644 scripts/call_eraser.py create mode 100644 scripts/eraserbenchmark/.gitignore create mode 100644 scripts/eraserbenchmark/LICENSE create mode 100644 scripts/eraserbenchmark/README.md create mode 100644 scripts/eraserbenchmark/REPRODUCTION.txt create mode 100644 scripts/eraserbenchmark/data_exploration.ipynb create mode 100644 scripts/eraserbenchmark/params/boolq.json create mode 100644 scripts/eraserbenchmark/params/boolq_baas.json create mode 100644 scripts/eraserbenchmark/params/boolq_bert.json create mode 100644 scripts/eraserbenchmark/params/boolq_soft.json create mode 100644 scripts/eraserbenchmark/params/cose_bert.json create mode 100644 scripts/eraserbenchmark/params/esnli_bert.json create mode 100644 scripts/eraserbenchmark/params/evidence_inference.json create mode 100644 scripts/eraserbenchmark/params/evidence_inference_bert.json create mode 100644 scripts/eraserbenchmark/params/evidence_inference_soft.json create mode 100644 scripts/eraserbenchmark/params/fever.json create mode 100644 scripts/eraserbenchmark/params/fever_baas.json create mode 100644 scripts/eraserbenchmark/params/fever_bert.json create mode 100644 scripts/eraserbenchmark/params/fever_soft.json create mode 100644 scripts/eraserbenchmark/params/movies.json create mode 100644 scripts/eraserbenchmark/params/movies_baas.json create mode 100644 scripts/eraserbenchmark/params/movies_bert.json create mode 100644 scripts/eraserbenchmark/params/movies_soft.json create mode 100644 scripts/eraserbenchmark/params/multirc.json create mode 100644 scripts/eraserbenchmark/params/multirc_baas.json create mode 100644 scripts/eraserbenchmark/params/multirc_bert.json create mode 100644 scripts/eraserbenchmark/params/multirc_soft.json create mode 100644 scripts/eraserbenchmark/rationale_benchmark/__init__.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/metrics.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/__init__.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/encode_attend.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/mlp.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/model_utils.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/pipeline/__init__.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/pipeline/bert_pipeline.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_classifier.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_identifier.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_token_identifier.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/pipeline/pipeline_train.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/pipeline/pipeline_utils.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/models/sequence_taggers.py create mode 100644 scripts/eraserbenchmark/rationale_benchmark/utils.py create mode 100644 scripts/eraserbenchmark/requirements.txt create mode 100644 scripts/hatexplain_to_eraser.py diff --git a/scripts/call_eraser.py b/scripts/call_eraser.py new file mode 100644 index 0000000..8ab40c0 --- /dev/null +++ b/scripts/call_eraser.py @@ -0,0 +1,87 @@ +import os, sys +file_path = 'eraserbenchmark/' +sys.path.append(os.path.dirname(file_path)) + +import eraserbenchmark.rationale_benchmark.metrics as eb + +import contextlib, logging + +class DiscardEraserBenchMarkStdOut(object): + def write(self, x): pass + +@contextlib.contextmanager +def nostdout(): + save_stdout = sys.stdout + sys.stdout = DiscardEraserBenchMarkStdOut() + yield + sys.stdout = save_stdout + +#--data_dir : Location of the folder which contains the dataset in eraser format +#--results : The location of the model output file in eraser format +#--score_file : The file name and location to write the output + +def call_eraser(datadir, testtrainorval, pathtopredictions, silent=False): + import sys, os + pkgpath = os.getcwd()+"\\eraserbenchmark" + print(pkgpath) + sys.path.append(pkgpath) + import rationale_benchmark.metrics as eraser + if silent: + logger = logging.getLogger() + logger.disabled = True + #dir(rationale_benchmark.metrics) + eraser.runEvaluation(data_dir=datadir, #data dir + split=testtrainorval, # split + results=pathtopredictions, # results + score_file="eraser_output.json", # score + strict=False) # strict + #iou_thresholds=[0.5], # iou + #aopc_thresholds=[0.01, 0.05, 0.1, 0.2, 0.5]) # aopc + if silent: + logger.disabled = False + print_eraser_results() + +""" +def call_eraser(datadir, testtrainorval, pathtopredictions): + #args = ['--split', 'test', '--strict', '--data_dir', 'movies', '--results', './movies/movies_majority_human_perf.jsonl'] + args = ['--split', testtrainorval, '--data_dir', datadir, '--results', pathtopredictions, '--score_file', 'eraser_output.json'] + for arg in args: + sys.argv.append(arg) + + # suppress text + SUPPRESS_TEXT = False + + if SUPPRESS_TEXT: + logger = logging.getLogger() + logger.disabled = True + with nostdout(): + eb.main() + logger.disabled = False + else: + eb.main() + print_eraser_results() +""" + +def print_eraser_results(): + # print the required results + import json + with open('./eraser_output.json') as fp: + output_data = json.load(fp) + + print('\nPlausibility') + if 'iou_scores' in output_data: + print('IOU F1 :', output_data['iou_scores'][0]['macro']['f1']) + print('Token F1 :', output_data['token_prf']['instance_macro']['f1']) + + if 'token_soft_metrics' in output_data: + print('AUPRC :', output_data['token_soft_metrics']['auprc']) + + print('\nFaithfulness') + if 'classification_scores' in output_data: + print('Comprehensiveness :', output_data['classification_scores']['comprehensiveness']) + print('Sufficiency :', output_data['classification_scores']['sufficiency']) + else: + print('--') + +if __name__ == "__main__": + call_eraser("./hatexplain", "train", "./hatexplain/train_prediction.jsonl") \ No newline at end of file diff --git a/scripts/eraserbenchmark/.gitignore b/scripts/eraserbenchmark/.gitignore new file mode 100644 index 0000000..e807186 --- /dev/null +++ b/scripts/eraserbenchmark/.gitignore @@ -0,0 +1,111 @@ +model_components +pipeline_outputs +*.swp +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +rationale_benchmark/data/esnli_previous +data/esnli_previous +esnli_union/ diff --git a/scripts/eraserbenchmark/LICENSE b/scripts/eraserbenchmark/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/scripts/eraserbenchmark/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/scripts/eraserbenchmark/README.md b/scripts/eraserbenchmark/README.md new file mode 100644 index 0000000..24f7ace --- /dev/null +++ b/scripts/eraserbenchmark/README.md @@ -0,0 +1,47 @@ +# eraserbenchmark +A benchmark for understanding and evaluating rationales: http://www.eraserbenchmark.com/ + +## Core Files + +The core files are [utils](rationale_benchmark/utils.py) and [metrics](rationale_benchmark/metrics.py). +These two files comprise everything you need to work with our released datasets. + +[utils](rationale_benchmark/utils.py) documents everything you need to know about our input formats. Output +formats and validation code are covered in [metrics](rationale_benchmark/metrics.py). + +## Models + +At the moment we offer two forms of pipeline models: +* (Lehman, et al., 2019) - sentence level rationale identification, followed by taking the best resulting sentence and classifying it. + * Both portions of this pipeline function via encoding the input sentence (via a GRU), attending (conditioned on a query vector), and making a classification. +* BERT-To-BERT - the same as above, but using a BERT model. + +### (Lehman, et al., 2019) Pipeline + +To run this model, we need to first: +* create a `model_components`, `data`, and `output`, directory +* download GloVe vectors from http://nlp.stanford.edu/data/glove.6B.zip and extract the 200 dimensional vector to `model_components` +* download http://evexdb.org/pmresources/vec-space-models/PubMed-w2v.bin to `model_components` +* set up a virtual env meeting requirements.txt +* download data from the primary website to `data` and extract each dataset to its respective directory +* ensure that we have at least an 11G GPU. Reducing batch sizes may enable running on a smaller GPU. + +Then we can run (as an example): +``` +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/models/pipeline/pipeline_train.py --data_dir data/movies --output_dir output/movies --model_params params/movies.json +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/metrics.py --split test --data_dir data/movies --results output/movies/test_decoded.jsonl --score_file output/movies/test_scores.json +``` + +### BERT-To-BERT Pipeline + +To run this model, instructions are effectively the same as the simple pipeline above, except we also require a GPU with approximately 16G of memory (e.g. Tesla V100). The same caveats about batch sizes apply here as well. + +Then we can run (as an example): +``` +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/movies --output_dir output_bert/movies --model_params param/movies.json +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/metrics.py --split test --data_dir data/movies --results output_bert/movies/test_decoded.jsonl --score_file output_bert/movies/test_scores.json +``` + +For more examples, see the [BERT-to-BERT reproduction](REPRODUCTION.txt). + +More models including Lei et al can be found at : https://github.com/successar/Eraser-Benchmark-Baseline-Models diff --git a/scripts/eraserbenchmark/REPRODUCTION.txt b/scripts/eraserbenchmark/REPRODUCTION.txt new file mode 100644 index 0000000..5b7512d --- /dev/null +++ b/scripts/eraserbenchmark/REPRODUCTION.txt @@ -0,0 +1,32 @@ +# commands to reproduce results +# Note that we control the GPU device directly via CUDA_VISIBLE_DEVICES +# Note that we add the current directory to the PYTHONPATH directly. +# Evidence Inference +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/evidence_inference/ --output_dir bert_models/evidence_inference --model_params params/evidence_inference_bert.json +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/metrics.py --data_dir data/evidence_inference/ --split test --strict --results bert_models/evidence_inference/test_decoded.jsonl --iou_thresholds 0.5 --score_file bert_models/evidence_inference_test_scores.json + +# FEVER, takes a very long time to run +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/fever/ --output_dir bert_models/fever --model_params params/fever_bert.json +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/metrics.py --data_dir data/fever/ --split test --strict --results bert_models/fever/test_decoded.jsonl --iou_thresholds 0.5 --score_file bert_models/fever_test_scores.json + +# BoolQ, takes a very long time to run +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/boolq/ --output_dir bert_models/boolq --model_params params/boolq_bert.json +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/metrics.py --data_dir data/boolq/ --split test --strict --results bert_models/boolq/test_decoded.jsonl --iou_thresholds 0.5 --score_file bert_models/boolq_test_scores.json + +# MultiRC +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/multirc/ --output_dir bert_models/multirc --model_params params/multirc_bert.json +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/metrics.py --data_dir data/multirc/ --split test --strict --results bert_models/multirc/test_decoded.jsonl --iou_thresholds 0.5 --score_file bert_models/multirc_test_scores.json + +# Movies +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/movies/ --output_dir bert_models/movies --model_params params/movies_bert.json +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/metrics.py --data_dir data/movies/ --split test --strict --results bert_models/movies/test_decoded.jsonl --iou_thresholds 0.5 --score_file bert_models/movies_test_scores.json + +# CoS-E +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/cose_simplified/ --output_dir bert_models/cose --model_params params/cose_bert.json +# Note that the training files and the evaluation files are different here. This is because we have done terrible things in order to get COS-E to run in this situation, as it is different from all of the other datasets. +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/metrics.py --data_dir data/cose/ --split test --strict --results bert_models/cose/test_decoded.jsonl --iou_thresholds 0.5 --score_file bert_models/cose_test_scores.json + +# e-SNLI +# This is only an example; scoring should be done against esnli +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/models/pipeline/bert_pipeline.py --data_dir data/esnli_flat/ --output_dir bert_models/esnli_flat --model_params params/esnli_bert.json +PYTHONPATH=./:$PYTHONPATH python rationale_benchmark/metrics.py --data_dir data/esnli_flat --split test --strict --results bert_models/esnli_flat/test_decoded.jsonl --iou_thresholds 0.5 --score_file bert_models/esnli_test_scores.json diff --git a/scripts/eraserbenchmark/data_exploration.ipynb b/scripts/eraserbenchmark/data_exploration.ipynb new file mode 100644 index 0000000..4ee63bc --- /dev/null +++ b/scripts/eraserbenchmark/data_exploration.ipynb @@ -0,0 +1,309 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read and explore data\n", + "\n", + "We'll start off with reading and doing some basic data processing. We'll assume that:\n", + "* you've downloaded the data from http://www.eraserbenchmark.com/ and have unpacked it to a directory called `data`\n", + "* you're running the kernel in the root of the `eraserbenchmark` repo\n", + "\n", + "We're going to work with the movies dataset as it's the smallest and easiest to get started with. All the data is stored in either plain text, or jsonl, and should be pre-tokenized and ready to go!" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from rationale_benchmark.utils import load_documents, load_datasets, annotations_from_jsonl, Annotation\n", + "\n", + "data_root = os.path.join('data', 'movies')\n", + "documents = load_documents(data_root)\n", + "val = annotations_from_jsonl(os.path.join(data_root, 'val.jsonl'))\n", + "## Or load everything:\n", + "train, val, test = load_datasets(data_root)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "What is the sentiment of this review?\n", + "NEG\n", + "16\n" + ] + } + ], + "source": [ + "ann = train[0]\n", + "evidences = ann.all_evidences()\n", + "print(type(ann))\n", + "print(ann.query)\n", + "print(ann.classification)\n", + "print(len(evidences))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "So we have a review with a negative sentiment, and 16 evidence statements. Let's take a look." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the sad part is\n", + "what 's the deal ?\n", + "not really\n", + "just did n't snag this one correctly\n", + "it does n't entertain , it 's confusing , it rarely excites\n", + "have no idea what 's going on\n", + "do we really need to see it over and over again ?\n", + "skip it !\n", + "pretty redundant\n", + "the film does n't stick\n", + "it 's simply too jumbled\n", + "i get kind of fed up after a while\n", + "downshifts into this \" fantasy \" world\n", + "executed it terribly\n", + "a very bad package\n", + "mind - fuck movie\n" + ] + } + ], + "source": [ + "for ev in evidences:\n", + " print(ev.text)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's get all the documents, and a take a look at the contents" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "44\n", + "plot : two teen couples go to a church party , drink and then drive .\n", + "they get into an accident .\n", + "one of the guys dies , but his girlfriend continues to see him in her life , and has nightmares .\n", + "what 's the deal ?\n", + "watch the movie and \" sorta \" find out . . .\n", + "critique : a mind - fuck movie for the teen generation that touches on a very cool idea , but presents it in a very bad package .\n", + "which is what makes this review an even harder one to write , since i generally applaud films which attempt to break the mold , mess with your head and such ( lost highway & memento ) , but there are good and bad ways of making all types of films , and these folks just did n't snag this one correctly\n", + ".\n", + "they seem to have taken this pretty neat concept , but executed it terribly .\n", + "so what are the problems with the movie ?\n", + "well , its main problem is that it 's simply too jumbled\n", + ".\n", + "it starts off \" normal \" but then downshifts into this \" fantasy \" world in which you , as an audience member ,\n", + "have no idea what 's going on\n", + ".\n", + "there are dreams , there are characters coming back from the dead , there are others who look like the dead , there are strange apparitions , there are disappearances , there are a looooot of chase scenes , there are tons of weird things that happen , and most of it is simply not explained .\n", + "now i personally do n't mind trying to unravel a film every now and then , but when all it does is give me the same clue over and over again , i get kind of fed up after a while , which is this film 's biggest problem .\n", + "it 's obviously got this big secret to hide , but it seems to want to hide it completely until its final five minutes .\n", + "and do they make things entertaining , thrilling or even engaging , in the meantime ?\n", + "not really .\n", + "the sad part is that the arrow and i both dig on flicks like this , so we actually figured most of it out by the half - way point , so all of the strangeness after that did start to make a little bit of sense , but it still did n't the make the film all that more entertaining .\n", + "i guess the bottom line with movies like this is that you should always make sure that the audience is \" into it \" even before they are given the secret password to enter your world of understanding .\n", + "i mean , showing melissa sagemiller running away from visions for about 20 minutes throughout the movie is just plain lazy ! !\n", + "okay , we get it .\n", + ". .\n", + "there are people chasing her and we do n't know who they are .\n", + "do we really need to see it over and over again ?\n", + "how about giving us different scenes offering further insight into all of the strangeness going down in the movie ?\n", + "apparently , the studio took this film away from its director and chopped it up themselves , and it shows .\n", + "there might 've been a pretty decent teen mind - fuck movie in here somewhere , but i guess \" the suits \" decided that turning it into a music video with little edge , would make more sense .\n", + "the actors are pretty good for the most part , although wes bentley just seemed to be playing the exact same character that he did in american beauty , only in a new neighborhood .\n", + "but my biggest kudos go out to sagemiller , who holds her own throughout the entire film , and actually has you feeling her character 's unraveling .\n", + "overall , the film does n't stick\n", + "because it does n't entertain , it 's confusing , it rarely excites and\n", + "it feels pretty redundant for most of its runtime , despite a pretty cool ending and explanation to all of the craziness that came before it .\n", + "oh ,\n", + "and by the way , this is not a horror or teen slasher flick . . .\n", + "it 's just packaged to look that way because someone is apparently assuming that the genre is still hot with the kids .\n", + "it also wrapped production two years ago and has been sitting on the shelves ever since .\n", + "whatever . .\n", + ". skip it !\n", + "where 's joblo coming from ?\n", + "a nightmare of elm street 3 ( 7/10 ) - blair witch 2 ( 7/10 ) - the crow ( 9/10 ) - the crow : salvation ( 4/10 )\n", + "- lost highway ( 10/10 ) - memento ( 10/10 ) - the others ( 9/10 ) - stir of echoes ( 8/10 )\n" + ] + } + ], + "source": [ + "(docid,) = set(ev.docid for ev in evidences)\n", + "doc = documents[docid]\n", + "print(len(doc))\n", + "for sent in doc:\n", + " print(' '.join(sent))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's take a look at where in the document these start appearing:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the sad part is\n", + "the sad part is\n", + "what 's the deal ?\n", + "what 's the deal ?\n", + "not really\n", + "not really\n", + "just did n't snag this one correctly\n", + "just did n't snag this one correctly\n", + "it does n't entertain , it 's confusing , it rarely excites\n", + "it does n't entertain , it 's confusing , it rarely excites\n", + "have no idea what 's going on\n", + "have no idea what 's going on\n", + "do we really need to see it over and over again ?\n", + "do we really need to see it over and over again ?\n", + "skip it !\n", + "skip it !\n", + "pretty redundant\n", + "pretty redundant\n", + "the film does n't stick\n", + "the film does n't stick\n", + "it 's simply too jumbled\n", + "it 's simply too jumbled\n", + "i get kind of fed up after a while\n", + "i get kind of fed up after a while\n", + "downshifts into this \" fantasy \" world\n", + "downshifts into this \" fantasy \" world\n", + "executed it terribly\n", + "executed it terribly\n", + "a very bad package\n", + "a very bad package\n", + "mind - fuck movie\n", + "mind - fuck movie\n" + ] + } + ], + "source": [ + "import itertools\n", + "flattened_doc = list(itertools.chain.from_iterable(doc))\n", + "for ev in evidences:\n", + " # saved text\n", + " print(ev.text)\n", + " # offset text\n", + " print(' '.join(flattened_doc[ev.start_token:ev.end_token]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Count rationale tokens, tokens, sentences" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "evidences 8.679174484052533\n", + "document_sentences 36.78924327704816\n", + "document_tokens 773.5622263914947\n", + "rationale_tokens 66.83989993746091\n", + "rationale_token_fraction 0.09350348753236702\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "def process_annotation(ann: Annotation, docs: dict) -> dict:\n", + " evidences = ann.all_evidences()\n", + " if len(evidences) == 0:\n", + " return {}\n", + " (docid,) = set(ev.docid for ev in evidences)\n", + " doc = docs[docid]\n", + " sentences = len(doc)\n", + " tokens = sum(len(s) for s in doc)\n", + " # this accumulation will take care of any potentially overlapping evidence statements.\n", + " # there should be none in the data, but getting familiar with the idea of how to do this is potentially useful\n", + " rationale_tokens = len(set(itertools.chain.from_iterable(range(ev.start_token, ev.end_token) for ev in evidences)))\n", + " return {\n", + " 'class': ann.classification,\n", + " 'evidences': len(evidences),\n", + " 'document_sentences': sentences,\n", + " 'document_tokens': tokens,\n", + " 'rationale_tokens': rationale_tokens,\n", + " 'rationale_token_fraction': rationale_tokens / tokens\n", + " }\n", + "\n", + "def average(counts, key) -> float:\n", + " ns = [c[key] for c in counts]\n", + " return np.mean(ns)\n", + "\n", + "# this filter skips an empty document \n", + "annotation_counts = list(filter(lambda x: len(x) > 0, (process_annotation(ann, documents) for ann in train)))\n", + "for key in ['evidences', 'document_sentences', 'document_tokens', 'rationale_tokens', 'rationale_token_fraction']:\n", + " print(key, average(annotation_counts, key))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/scripts/eraserbenchmark/params/boolq.json b/scripts/eraserbenchmark/params/boolq.json new file mode 100644 index 0000000..c9317b6 --- /dev/null +++ b/scripts/eraserbenchmark/params/boolq.json @@ -0,0 +1,26 @@ +{ + "embeddings": { + "embedding_file": "model_components/glove.6B.200d.txt", + "dropout": 0.05 + }, + "evidence_identifier": { + "mlp_size": 128, + "dropout": 0.2, + "batch_size": 768, + "epochs": 50, + "patience": 10, + "lr": 1e-3, + "sampling_method": "random", + "sampling_ratio": 1.0 + }, + "evidence_classifier": { + "classes": [ "False", "True" ], + "mlp_size": 128, + "dropout": 0.2, + "batch_size": 768, + "epochs": 50, + "patience": 10, + "lr": 1e-3, + "sampling_method": "everything" + } +} diff --git a/scripts/eraserbenchmark/params/boolq_baas.json b/scripts/eraserbenchmark/params/boolq_baas.json new file mode 100644 index 0000000..9ea4928 --- /dev/null +++ b/scripts/eraserbenchmark/params/boolq_baas.json @@ -0,0 +1,26 @@ +{ + "start_server": 0, + "bert_dir": "model_components/uncased_L-12_H-768_A-12/", + "max_length": 512, + "pooling_strategy": "CLS_TOKEN", + "evidence_identifier": { + "batch_size": 64, + "epochs": 3, + "patience": 10, + "lr": 1e-3, + "max_grad_norm": 1.0, + "sampling_method": "random", + "sampling_ratio": 1.0 + }, + "evidence_classifier": { + "classes": [ "False", "True" ], + "batch_size": 64, + "epochs": 3, + "patience": 10, + "lr": 1e-3, + "max_grad_norm": 1.0, + "sampling_method": "everything" + } +} + + diff --git a/scripts/eraserbenchmark/params/boolq_bert.json b/scripts/eraserbenchmark/params/boolq_bert.json new file mode 100644 index 0000000..e454fb1 --- /dev/null +++ b/scripts/eraserbenchmark/params/boolq_bert.json @@ -0,0 +1,32 @@ +{ + "max_length": 512, + "bert_vocab": "bert-base-uncased", + "bert_dir": "bert-base-uncased", + "use_evidence_sentence_identifier": 1, + "use_evidence_token_identifier": 0, + "evidence_identifier": { + "batch_size": 10, + "epochs": 10, + "patience": 10, + "warmup_steps": 50, + "lr": 1e-05, + "max_grad_norm": 1, + "sampling_method": "random", + "sampling_ratio": 1, + "use_half_precision": 0 + }, + "evidence_classifier": { + "classes": [ + "False", + "True" + ], + "batch_size": 10, + "warmup_steps": 50, + "epochs": 10, + "patience": 10, + "lr": 1e-05, + "max_grad_norm": 1, + "sampling_method": "everything", + "use_half_precision": 0 + } +} diff --git a/scripts/eraserbenchmark/params/boolq_soft.json b/scripts/eraserbenchmark/params/boolq_soft.json new file mode 100644 index 0000000..721697d --- /dev/null +++ b/scripts/eraserbenchmark/params/boolq_soft.json @@ -0,0 +1,21 @@ +{ + "embeddings": { + "embedding_file": "model_components/glove.6B.200d.txt", + "dropout": 0.2 + }, + "classifier": { + "classes": [ "False", "True" ], + "has_query": 1, + "hidden_size": 32, + "mlp_size": 128, + "dropout": 0.2, + "batch_size": 16, + "epochs": 50, + "attention_epochs": 50, + "patience": 10, + "lr": 1e-3, + "dropout": 0.2, + "k_fraction": 0.07, + "threshold": 0.1 + } +} diff --git a/scripts/eraserbenchmark/params/cose_bert.json b/scripts/eraserbenchmark/params/cose_bert.json new file mode 100644 index 0000000..f32cadd --- /dev/null +++ b/scripts/eraserbenchmark/params/cose_bert.json @@ -0,0 +1,30 @@ +{ + "max_length": 512, + "bert_vocab": "bert-base-uncased", + "bert_dir": "bert-base-uncased", + "use_evidence_sentence_identifier": 0, + "use_evidence_token_identifier": 1, + "evidence_token_identifier": { + "batch_size": 32, + "epochs": 10, + "patience": 10, + "warmup_steps": 10, + "lr": 1e-05, + "max_grad_norm": 0.5, + "sampling_method": "everything", + "use_half_precision": 0, + "cose_data_hack": 1 + }, + "evidence_classifier": { + "classes": [ "false", "true"], + "batch_size": 32, + "warmup_steps": 10, + "epochs": 10, + "patience": 10, + "lr": 1e-05, + "max_grad_norm": 0.5, + "sampling_method": "everything", + "use_half_precision": 0, + "cose_data_hack": 1 + } +} diff --git a/scripts/eraserbenchmark/params/esnli_bert.json b/scripts/eraserbenchmark/params/esnli_bert.json new file mode 100644 index 0000000..7feb838 --- /dev/null +++ b/scripts/eraserbenchmark/params/esnli_bert.json @@ -0,0 +1,28 @@ +{ + "max_length": 512, + "bert_vocab": "bert-base-uncased", + "bert_dir": "bert-base-uncased", + "use_evidence_sentence_identifier": 0, + "use_evidence_token_identifier": 1, + "evidence_token_identifier": { + "batch_size": 32, + "epochs": 10, + "patience": 10, + "warmup_steps": 10, + "lr": 1e-05, + "max_grad_norm": 1, + "sampling_method": "everything", + "use_half_precision": 0 + }, + "evidence_classifier": { + "classes": [ "contradiction", "neutral", "entailment" ], + "batch_size": 32, + "warmup_steps": 10, + "epochs": 10, + "patience": 10, + "lr": 1e-05, + "max_grad_norm": 1, + "sampling_method": "everything", + "use_half_precision": 0 + } +} diff --git a/scripts/eraserbenchmark/params/evidence_inference.json b/scripts/eraserbenchmark/params/evidence_inference.json new file mode 100644 index 0000000..910dcff --- /dev/null +++ b/scripts/eraserbenchmark/params/evidence_inference.json @@ -0,0 +1,26 @@ +{ + "embeddings": { + "embedding_file": "model_components/PubMed-w2v.bin", + "dropout": 0.05 + }, + "evidence_identifier": { + "mlp_size": 128, + "dropout": 0.05, + "batch_size": 768, + "epochs": 50, + "patience": 10, + "lr": 1e-3, + "sampling_method": "random", + "sampling_ratio": 1.0 + }, + "evidence_classifier": { + "classes": [ "significantly decreased", "no significant difference", "significantly increased" ], + "mlp_size": 128, + "dropout": 0.05, + "batch_size": 768, + "epochs": 50, + "patience": 10, + "lr": 1e-3, + "sampling_method": "everything" + } +} diff --git a/scripts/eraserbenchmark/params/evidence_inference_bert.json b/scripts/eraserbenchmark/params/evidence_inference_bert.json new file mode 100644 index 0000000..b595a64 --- /dev/null +++ b/scripts/eraserbenchmark/params/evidence_inference_bert.json @@ -0,0 +1,33 @@ +{ + "max_length": 512, + "bert_vocab": "allenai/scibert_scivocab_uncased", + "bert_dir": "allenai/scibert_scivocab_uncased", + "use_evidence_sentence_identifier": 1, + "use_evidence_token_identifier": 0, + "evidence_identifier": { + "batch_size": 10, + "epochs": 10, + "patience": 10, + "warmup_steps": 10, + "lr": 1e-05, + "max_grad_norm": 1, + "sampling_method": "random", + "use_half_precision": 0, + "sampling_ratio": 1 + }, + "evidence_classifier": { + "classes": [ + "significantly decreased", + "no significant difference", + "significantly increased" + ], + "batch_size": 10, + "warmup_steps": 10, + "epochs": 10, + "patience": 10, + "lr": 1e-05, + "max_grad_norm": 1, + "sampling_method": "everything", + "use_half_precision": 0 + } +} diff --git a/scripts/eraserbenchmark/params/evidence_inference_soft.json b/scripts/eraserbenchmark/params/evidence_inference_soft.json new file mode 100644 index 0000000..e416f0b --- /dev/null +++ b/scripts/eraserbenchmark/params/evidence_inference_soft.json @@ -0,0 +1,22 @@ +{ + "embeddings": { + "embedding_file": "model_components/PubMed-w2v.bin", + "dropout": 0.2 + }, + "classifier": { + "classes": [ "significantly decreased", "no significant difference", "significantly increased" ], + "use_token_selection": 1, + "has_query": 1, + "hidden_size": 32, + "mlp_size": 128, + "dropout": 0.2, + "batch_size": 16, + "epochs": 50, + "attention_epochs": 0, + "patience": 10, + "lr": 1e-3, + "dropout": 0.2, + "k_fraction": 0.013, + "threshold": 0.1 + } +} diff --git a/scripts/eraserbenchmark/params/fever.json b/scripts/eraserbenchmark/params/fever.json new file mode 100644 index 0000000..3933882 --- /dev/null +++ b/scripts/eraserbenchmark/params/fever.json @@ -0,0 +1,26 @@ +{ + "embeddings": { + "embedding_file": "model_components/glove.6B.200d.txt", + "dropout": 0.05 + }, + "evidence_identifier": { + "mlp_size": 128, + "dropout": 0.05, + "batch_size": 768, + "epochs": 50, + "patience": 10, + "lr": 1e-3, + "sampling_method": "random", + "sampling_ratio": 1.0 + }, + "evidence_classifier": { + "classes": [ "SUPPORTS", "REFUTES" ], + "mlp_size": 128, + "dropout": 0.05, + "batch_size": 768, + "epochs": 50, + "patience": 10, + "lr": 1e-3, + "sampling_method": "everything" + } +} diff --git a/scripts/eraserbenchmark/params/fever_baas.json b/scripts/eraserbenchmark/params/fever_baas.json new file mode 100644 index 0000000..10c4c9f --- /dev/null +++ b/scripts/eraserbenchmark/params/fever_baas.json @@ -0,0 +1,25 @@ +{ + "start_server": 0, + "bert_dir": "model_components/uncased_L-12_H-768_A-12/", + "max_length": 512, + "pooling_strategy": "CLS_TOKEN", + "evidence_identifier": { + "batch_size": 64, + "epochs": 3, + "patience": 10, + "lr": 1e-3, + "max_grad_norm": 1.0, + "sampling_method": "random", + "sampling_ratio": 1.0 + }, + "evidence_classifier": { + "classes": [ "SUPPORTS", "REFUTES" ], + "batch_size": 64, + "epochs": 3, + "patience": 10, + "lr": 1e-3, + "max_grad_norm": 1.0, + "sampling_method": "everything" + } +} + diff --git a/scripts/eraserbenchmark/params/fever_bert.json b/scripts/eraserbenchmark/params/fever_bert.json new file mode 100644 index 0000000..64fcf80 --- /dev/null +++ b/scripts/eraserbenchmark/params/fever_bert.json @@ -0,0 +1,32 @@ +{ + "max_length": 512, + "bert_vocab": "bert-base-uncased", + "bert_dir": "bert-base-uncased", + "use_evidence_sentence_identifier": 1, + "use_evidence_token_identifier": 0, + "evidence_identifier": { + "batch_size": 16, + "epochs": 10, + "patience": 10, + "warmup_steps": 10, + "lr": 1e-05, + "max_grad_norm": 1.0, + "sampling_method": "random", + "sampling_ratio": 1.0, + "use_half_precision": 0 + }, + "evidence_classifier": { + "classes": [ + "SUPPORTS", + "REFUTES" + ], + "batch_size": 16, + "warmup_steps": 10, + "epochs": 10, + "patience": 10, + "lr": 1e-05, + "max_grad_norm": 1.0, + "sampling_method": "everything", + "use_half_precision": 0 + } +} diff --git a/scripts/eraserbenchmark/params/fever_soft.json b/scripts/eraserbenchmark/params/fever_soft.json new file mode 100644 index 0000000..acd23ef --- /dev/null +++ b/scripts/eraserbenchmark/params/fever_soft.json @@ -0,0 +1,21 @@ +{ + "embeddings": { + "embedding_file": "model_components/glove.6B.200d.txt", + "dropout": 0.2 + }, + "classifier": { + "classes": [ "SUPPORTS", "REFUTES" ], + "has_query": 1, + "hidden_size": 32, + "mlp_size": 128, + "dropout": 0.2, + "batch_size": 128, + "epochs": 50, + "attention_epochs": 50, + "patience": 10, + "lr": 1e-3, + "dropout": 0.2, + "k_fraction": 0.07, + "threshold": 0.1 + } +} diff --git a/scripts/eraserbenchmark/params/movies.json b/scripts/eraserbenchmark/params/movies.json new file mode 100644 index 0000000..546f21c --- /dev/null +++ b/scripts/eraserbenchmark/params/movies.json @@ -0,0 +1,26 @@ +{ + "embeddings": { + "embedding_file": "model_components/glove.6B.200d.txt", + "dropout": 0.05 + }, + "evidence_identifier": { + "mlp_size": 128, + "dropout": 0.05, + "batch_size": 768, + "epochs": 50, + "patience": 10, + "lr": 1e-4, + "sampling_method": "random", + "sampling_ratio": 1.0 + }, + "evidence_classifier": { + "classes": [ "NEG", "POS" ], + "mlp_size": 128, + "dropout": 0.05, + "batch_size": 768, + "epochs": 50, + "patience": 10, + "lr": 1e-3, + "sampling_method": "everything" + } +} diff --git a/scripts/eraserbenchmark/params/movies_baas.json b/scripts/eraserbenchmark/params/movies_baas.json new file mode 100644 index 0000000..846b020 --- /dev/null +++ b/scripts/eraserbenchmark/params/movies_baas.json @@ -0,0 +1,26 @@ +{ + "start_server": 0, + "bert_dir": "model_components/uncased_L-12_H-768_A-12/", + "max_length": 512, + "pooling_strategy": "CLS_TOKEN", + "evidence_identifier": { + "batch_size": 64, + "epochs": 3, + "patience": 10, + "lr": 1e-3, + "max_grad_norm": 1.0, + "sampling_method": "random", + "sampling_ratio": 1.0 + }, + "evidence_classifier": { + "classes": [ "NEG", "POS" ], + "batch_size": 64, + "epochs": 3, + "patience": 10, + "lr": 1e-3, + "max_grad_norm": 1.0, + "sampling_method": "everything" + } +} + + diff --git a/scripts/eraserbenchmark/params/movies_bert.json b/scripts/eraserbenchmark/params/movies_bert.json new file mode 100644 index 0000000..fa09aef --- /dev/null +++ b/scripts/eraserbenchmark/params/movies_bert.json @@ -0,0 +1,32 @@ +{ + "max_length": 512, + "bert_vocab": "bert-base-uncased", + "bert_dir": "bert-base-uncased", + "use_evidence_sentence_identifier": 1, + "use_evidence_token_identifier": 0, + "evidence_identifier": { + "batch_size": 32, + "epochs": 10, + "patience": 10, + "warmup_steps": 50, + "lr": 1e-05, + "max_grad_norm": 1, + "sampling_method": "random", + "sampling_ratio": 1, + "use_half_precision": 0 + }, + "evidence_classifier": { + "classes": [ + "NEG", + "POS" + ], + "batch_size": 32, + "warmup_steps": 50, + "epochs": 10, + "patience": 10, + "lr": 1e-05, + "max_grad_norm": 1, + "sampling_method": "everything", + "use_half_precision": 0 + } +} diff --git a/scripts/eraserbenchmark/params/movies_soft.json b/scripts/eraserbenchmark/params/movies_soft.json new file mode 100644 index 0000000..99d54da --- /dev/null +++ b/scripts/eraserbenchmark/params/movies_soft.json @@ -0,0 +1,21 @@ +{ + "embeddings": { + "embedding_file": "model_components/glove.6B.200d.txt", + "dropout": 0.2 + }, + "classifier": { + "classes": [ "NEG", "POS" ], + "has_query": 0, + "hidden_size": 32, + "mlp_size": 128, + "dropout": 0.2, + "batch_size": 16, + "epochs": 50, + "attention_epochs": 50, + "patience": 10, + "lr": 1e-3, + "dropout": 0.2, + "k_fraction": 0.07, + "threshold": 0.1 + } +} diff --git a/scripts/eraserbenchmark/params/multirc.json b/scripts/eraserbenchmark/params/multirc.json new file mode 100644 index 0000000..dc3cb2d --- /dev/null +++ b/scripts/eraserbenchmark/params/multirc.json @@ -0,0 +1,26 @@ +{ + "embeddings": { + "embedding_file": "model_components/glove.6B.200d.txt", + "dropout": 0.05 + }, + "evidence_identifier": { + "mlp_size": 128, + "dropout": 0.05, + "batch_size": 768, + "epochs": 50, + "patience": 10, + "lr": 1e-3, + "sampling_method": "random", + "sampling_ratio": 1.0 + }, + "evidence_classifier": { + "classes": [ "False", "True" ], + "mlp_size": 128, + "dropout": 0.05, + "batch_size": 768, + "epochs": 50, + "patience": 10, + "lr": 1e-3, + "sampling_method": "everything" + } +} diff --git a/scripts/eraserbenchmark/params/multirc_baas.json b/scripts/eraserbenchmark/params/multirc_baas.json new file mode 100644 index 0000000..9ea4928 --- /dev/null +++ b/scripts/eraserbenchmark/params/multirc_baas.json @@ -0,0 +1,26 @@ +{ + "start_server": 0, + "bert_dir": "model_components/uncased_L-12_H-768_A-12/", + "max_length": 512, + "pooling_strategy": "CLS_TOKEN", + "evidence_identifier": { + "batch_size": 64, + "epochs": 3, + "patience": 10, + "lr": 1e-3, + "max_grad_norm": 1.0, + "sampling_method": "random", + "sampling_ratio": 1.0 + }, + "evidence_classifier": { + "classes": [ "False", "True" ], + "batch_size": 64, + "epochs": 3, + "patience": 10, + "lr": 1e-3, + "max_grad_norm": 1.0, + "sampling_method": "everything" + } +} + + diff --git a/scripts/eraserbenchmark/params/multirc_bert.json b/scripts/eraserbenchmark/params/multirc_bert.json new file mode 100644 index 0000000..1ab31b5 --- /dev/null +++ b/scripts/eraserbenchmark/params/multirc_bert.json @@ -0,0 +1,32 @@ +{ + "max_length": 512, + "bert_vocab": "bert-base-uncased", + "bert_dir": "bert-base-uncased", + "use_evidence_sentence_identifier": 1, + "use_evidence_token_identifier": 0, + "evidence_identifier": { + "batch_size": 32, + "epochs": 10, + "patience": 10, + "warmup_steps": 50, + "lr": 1e-05, + "max_grad_norm": 1, + "sampling_method": "random", + "sampling_ratio": 1, + "use_half_precision": 0 + }, + "evidence_classifier": { + "classes": [ + "False", + "True" + ], + "batch_size": 32, + "warmup_steps": 50, + "epochs": 10, + "patience": 10, + "lr": 1e-05, + "max_grad_norm": 1, + "sampling_method": "everything", + "use_half_precision": 0 + } +} diff --git a/scripts/eraserbenchmark/params/multirc_soft.json b/scripts/eraserbenchmark/params/multirc_soft.json new file mode 100644 index 0000000..721697d --- /dev/null +++ b/scripts/eraserbenchmark/params/multirc_soft.json @@ -0,0 +1,21 @@ +{ + "embeddings": { + "embedding_file": "model_components/glove.6B.200d.txt", + "dropout": 0.2 + }, + "classifier": { + "classes": [ "False", "True" ], + "has_query": 1, + "hidden_size": 32, + "mlp_size": 128, + "dropout": 0.2, + "batch_size": 16, + "epochs": 50, + "attention_epochs": 50, + "patience": 10, + "lr": 1e-3, + "dropout": 0.2, + "k_fraction": 0.07, + "threshold": 0.1 + } +} diff --git a/scripts/eraserbenchmark/rationale_benchmark/__init__.py b/scripts/eraserbenchmark/rationale_benchmark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/eraserbenchmark/rationale_benchmark/metrics.py b/scripts/eraserbenchmark/rationale_benchmark/metrics.py new file mode 100644 index 0000000..f41296d --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/metrics.py @@ -0,0 +1,760 @@ +import argparse +import json +import logging +import os +import pprint + +from collections import Counter, defaultdict, namedtuple +from dataclasses import dataclass +from itertools import chain +from typing import Any, Callable, Dict, List, Set, Tuple + +import numpy as np +import torch + +from scipy.stats import entropy +from sklearn.metrics import accuracy_score, auc, average_precision_score, classification_report, precision_recall_curve, roc_auc_score + +from rationale_benchmark.utils import ( + Annotation, + Evidence, + annotations_from_jsonl, + load_jsonl, + load_documents, + load_flattened_documents + ) + +logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s') + +# start_token is inclusive, end_token is exclusive +@dataclass(eq=True, frozen=True) +class Rationale: + ann_id: str + docid: str + start_token: int + end_token: int + + def to_token_level(self) -> List['Rationale']: + ret = [] + for t in range(self.start_token, self.end_token): + ret.append(Rationale(self.ann_id, self.docid, t, t+1)) + return ret + + @classmethod + def from_annotation(cls, ann: Annotation) -> List['Rationale']: + ret = [] + for ev_group in ann.evidences: + for ev in ev_group: + ret.append(Rationale(ann.annotation_id, ev.docid, ev.start_token, ev.end_token)) + return ret + + @classmethod + def from_instance(cls, inst: dict) -> List['Rationale']: + ret = [] + for rat in inst['rationales']: + for pred in rat.get('hard_rationale_predictions', []): + ret.append(Rationale(inst['annotation_id'], rat['docid'], pred['start_token'], pred['end_token'])) + return ret + +@dataclass(eq=True, frozen=True) +class PositionScoredDocument: + ann_id: str + docid: str + scores: Tuple[float] + truths: Tuple[bool] + + @classmethod + def from_results(cls, instances: List[dict], annotations: List[Annotation], docs: Dict[str, List[Any]], use_tokens: bool=True) -> List['PositionScoredDocument']: + """Creates a paired list of annotation ids/docids/predictions/truth values""" + key_to_annotation = dict() + for ann in annotations: + for ev in chain.from_iterable(ann.evidences): + key = (ann.annotation_id, ev.docid) + if key not in key_to_annotation: + key_to_annotation[key] = [False for _ in docs[ev.docid]] + if use_tokens: + start, end = ev.start_token, ev.end_token + else: + start, end = ev.start_sentence, ev.end_sentence + for t in range(start, end): + key_to_annotation[key][t] = True + ret = [] + if use_tokens: + field = 'soft_rationale_predictions' + else: + field = 'soft_sentence_predictions' + for inst in instances: + for rat in inst['rationales']: + docid = rat['docid'] + scores = rat[field] + key = (inst['annotation_id'], docid) + assert len(scores) == len(docs[docid]) + if key in key_to_annotation : + assert len(scores) == len(key_to_annotation[key]) + else : + #In case model makes a prediction on docuemnt(s) for which ground truth evidence is not present + key_to_annotation[key] = [False for _ in docs[docid]] + ret.append(PositionScoredDocument(inst['annotation_id'], docid, tuple(scores), tuple(key_to_annotation[key]))) + return ret + +def _f1(_p, _r): + if _p == 0 or _r == 0: + return 0 + return 2 * _p * _r / (_p + _r) + +def _keyed_rationale_from_list(rats: List[Rationale]) -> Dict[Tuple[str, str], Rationale]: + ret = defaultdict(set) + for r in rats: + ret[(r.ann_id, r.docid)].add(r) + return ret + +def partial_match_score(truth: List[Rationale], pred: List[Rationale], thresholds: List[float]) -> List[Dict[str, Any]]: + """Computes a partial match F1 + + Computes an instance-level (annotation) micro- and macro-averaged F1 score. + True Positives are computed by using intersection-over-union and + thresholding the resulting intersection-over-union fraction. + + Micro-average results are computed by ignoring instance level distinctions + in the TP calculation (and recall, and precision, and finally the F1 of + those numbers). Macro-average results are computed first by measuring + instance (annotation + document) precisions and recalls, averaging those, + and finally computing an F1 of the resulting average. + """ + + ann_to_rat = _keyed_rationale_from_list(truth) + pred_to_rat = _keyed_rationale_from_list(pred) + print(ann_to_rat.items()) + print(pred_to_rat.items()) + num_classifications = {k:len(v) for k,v in pred_to_rat.items()} + num_truth = {k:len(v) for k,v in ann_to_rat.items()} + ious = defaultdict(dict) + for k in set(ann_to_rat.keys()) | set(pred_to_rat.keys()): + for p in pred_to_rat.get(k, []): + best_iou = 0.0 + for t in ann_to_rat.get(k, []): + num = len(set(range(p.start_token, p.end_token)) & set(range(t.start_token, t.end_token))) + denom = len(set(range(p.start_token, p.end_token)) | set(range(t.start_token, t.end_token))) + iou = 0 if denom == 0 else num / denom + if iou > best_iou: + best_iou = iou + ious[k][p] = best_iou + scores = [] + for threshold in thresholds: + threshold_tps = dict() + for k, vs in ious.items(): + threshold_tps[k] = sum(int(x >= threshold) for x in vs.values()) + micro_r = sum(threshold_tps.values()) / sum(num_truth.values()) if sum(num_truth.values()) > 0 else 0 + micro_p = sum(threshold_tps.values()) / sum(num_classifications.values()) if sum(num_classifications.values()) > 0 else 0 + micro_f1 = _f1(micro_r, micro_p) + macro_rs = list(threshold_tps.get(k, 0.0) / n if n > 0 else 0 for k, n in num_truth.items()) + macro_ps = list(threshold_tps.get(k, 0.0) / n if n > 0 else 0 for k, n in num_classifications.items()) + macro_r = sum(macro_rs) / len(macro_rs) if len(macro_rs) > 0 else 0 + macro_p = sum(macro_ps) / len(macro_ps) if len(macro_ps) > 0 else 0 + macro_f1 = _f1(macro_r, macro_p) + scores.append({'threshold': threshold, + 'micro': { + 'p': micro_p, + 'r': micro_r, + 'f1': micro_f1 + }, + 'macro': { + 'p': macro_p, + 'r': macro_r, + 'f1': macro_f1 + }, + }) + + return scores + +def score_hard_rationale_predictions(truth: List[Rationale], pred: List[Rationale]) -> Dict[str, Dict[str, float]]: + """Computes instance (annotation)-level micro/macro averaged F1s""" + scores = dict() + truth = set(truth) + pred = set(pred) + micro_prec = len(truth & pred) / (len(pred) if len(pred)>0 else 1) + micro_rec = len(truth & pred) / (len(truth) if len(truth)>0 else 1) + micro_f1 = _f1(micro_prec, micro_rec) + + scores['instance_micro'] = { + 'p': micro_prec, + 'r': micro_rec, + 'f1': micro_f1, + } + + ann_to_rat = _keyed_rationale_from_list(truth) + pred_to_rat = _keyed_rationale_from_list(pred) + instances_to_scores = dict() + for k in set(ann_to_rat.keys()) | (pred_to_rat.keys()): + if len(pred_to_rat.get(k, set())) > 0: + instance_prec = len(ann_to_rat.get(k, set()) & pred_to_rat.get(k, set())) / len(pred_to_rat[k]) + else: + instance_prec = 0 + if len(ann_to_rat.get(k, set())) > 0: + instance_rec = len(ann_to_rat.get(k, set()) & pred_to_rat.get(k, set())) / len(ann_to_rat[k]) + else: + instance_rec = 0 + instance_f1 = _f1(instance_prec, instance_rec) + instances_to_scores[k] = { + 'p': instance_prec, + 'r': instance_rec, + 'f1': instance_f1, + } + # these are calculated as sklearn would + macro_prec = sum(instance['p'] for instance in instances_to_scores.values()) / (len(instances_to_scores) if len(instances_to_scores)>0 else 1) + macro_rec = sum(instance['r'] for instance in instances_to_scores.values()) / (len(instances_to_scores) if len(instances_to_scores)>0 else 1) + macro_f1 = sum(instance['f1'] for instance in instances_to_scores.values()) / (len(instances_to_scores) if len(instances_to_scores)>0 else 1) + scores['instance_macro'] = { + 'p': macro_prec, + 'r': macro_rec, + 'f1': macro_f1, + } + return scores + +def _auprc(truth: Dict[Any, List[bool]], preds: Dict[Any, List[float]]) -> float: + if len(preds) == 0: + return 0.0 + assert len(truth.keys() and preds.keys()) == len(truth.keys()) + aucs = [] + for k, true in truth.items(): + pred = preds[k] + true = [int(t) for t in true] + precision, recall, _ = precision_recall_curve(true, pred) + aucs.append(auc(recall, precision)) + return np.average(aucs) + +def _score_aggregator(truth: Dict[Any, List[bool]], preds: Dict[Any, List[float]], score_function: Callable[[List[float], List[float]], float ], discard_single_class_answers: bool) -> float: + if len(preds) == 0: + return 0.0 + assert len(truth.keys() and preds.keys()) == len(truth.keys()) + scores = [] + for k, true in truth.items(): + pred = preds[k] + if (all(true) or all(not x for x in true)) and discard_single_class_answers: + continue + true = [int(t) for t in true] + scores.append(score_function(true, pred)) + return np.average(scores) + +def score_soft_tokens(paired_scores: List[PositionScoredDocument]) -> Dict[str, float]: + truth = {(ps.ann_id, ps.docid): ps.truths for ps in paired_scores} + pred = {(ps.ann_id, ps.docid): ps.scores for ps in paired_scores} + auprc_score = _auprc(truth, pred) + ap = _score_aggregator(truth, pred, average_precision_score, True) + roc_auc = _score_aggregator(truth, pred, roc_auc_score, True) + + return { + 'auprc': auprc_score, + 'average_precision': ap, + 'roc_auc_score': roc_auc, + } + +def _instances_aopc(instances: List[dict], thresholds: List[float], key: str) -> Tuple[float, List[float]]: + dataset_scores = [] + for inst in instances: + kls = inst['classification'] + beta_0 = inst['classification_scores'][kls] + instance_scores = [] + for score in filter(lambda x : x['threshold'] in thresholds, sorted(inst['thresholded_scores'], key=lambda x: x['threshold'])): + beta_k = score[key][kls] + delta = beta_0 - beta_k + instance_scores.append(delta) + assert len(instance_scores) == len(thresholds) + dataset_scores.append(instance_scores) + dataset_scores = np.array(dataset_scores) + # a careful reading of Samek, et al. "Evaluating the Visualization of What a Deep Neural Network Has Learned" + # and some algebra will show the reader that we can average in any of several ways and get the same result: + # over a flattened array, within an instance and then between instances, or over instances (by position) an + # then across them. + final_score = np.average(dataset_scores) + position_scores = np.average(dataset_scores, axis=0).tolist() + + return final_score, position_scores + +def compute_aopc_scores(instances: List[dict], aopc_thresholds: List[float]): + if aopc_thresholds is None : + aopc_thresholds = sorted(set(chain.from_iterable([x['threshold'] for x in y['thresholded_scores']] for y in instances))) + aopc_comprehensiveness_score, aopc_comprehensiveness_points = _instances_aopc(instances, aopc_thresholds, 'comprehensiveness_classification_scores') + aopc_sufficiency_score, aopc_sufficiency_points = _instances_aopc(instances, aopc_thresholds, 'sufficiency_classification_scores') + return aopc_thresholds, aopc_comprehensiveness_score, aopc_comprehensiveness_points, aopc_sufficiency_score, aopc_sufficiency_points + +def score_classifications(instances: List[dict], annotations: List[Annotation], docs: Dict[str, List[str]], aopc_thresholds: List[float]) -> Dict[str, float]: + def compute_kl(cls_scores_, faith_scores_): + keys = list(cls_scores_.keys()) + cls_scores_ = [cls_scores_[k] for k in keys] + faith_scores_ = [faith_scores_[k] for k in keys] + return entropy(faith_scores_, cls_scores_) + labels = list(set(x.classification for x in annotations)) + labels +=['None'] + #print("UIUIUIIU") + label_to_int = {l:i for i,l in enumerate(labels)} + key_to_instances = {inst['annotation_id']:inst for inst in instances} + truth = [] + predicted = [] + for ann in annotations: + truth.append(label_to_int[ann.classification]) + inst = key_to_instances[ann.annotation_id] + predicted.append(label_to_int[inst['classification']]) + classification_scores = classification_report(truth, predicted, output_dict=True, target_names=labels, digits=3) + accuracy = accuracy_score(truth, predicted) + if 'comprehensiveness_classification_scores' in instances[0]: + comprehensiveness_scores = [x['classification_scores'][x['classification']] - x['comprehensiveness_classification_scores'][x['classification']] for x in instances] + comprehensiveness_score = np.average(comprehensiveness_scores) + else : + comprehensiveness_score = None + comprehensiveness_scores = None + + if 'sufficiency_classification_scores' in instances[0]: + sufficiency_scores = [x['classification_scores'][x['classification']] - x['sufficiency_classification_scores'][x['classification']] for x in instances] + sufficiency_score = np.average(sufficiency_scores) + else : + sufficiency_score = None + sufficiency_scores = None + + if 'comprehensiveness_classification_scores' in instances[0]: + comprehensiveness_entropies = [entropy(list(x['classification_scores'].values())) - entropy(list(x['comprehensiveness_classification_scores'].values())) for x in instances] + comprehensiveness_entropy = np.average(comprehensiveness_entropies) + comprehensiveness_kl = np.average(list(compute_kl(x['classification_scores'], x['comprehensiveness_classification_scores']) for x in instances)) + else: + comprehensiveness_entropies = None + comprehensiveness_kl = None + comprehensiveness_entropy = None + + if 'sufficiency_classification_scores' in instances[0]: + sufficiency_entropies = [entropy(list(x['classification_scores'].values())) - entropy(list(x['sufficiency_classification_scores'].values())) for x in instances] + sufficiency_entropy = np.average(sufficiency_entropies) + sufficiency_kl = np.average(list(compute_kl(x['classification_scores'], x['sufficiency_classification_scores']) for x in instances)) + else: + sufficiency_entropies = None + sufficiency_kl = None + sufficiency_entropy = None + + if 'thresholded_scores' in instances[0]: + aopc_thresholds, aopc_comprehensiveness_score, aopc_comprehensiveness_points, aopc_sufficiency_score, aopc_sufficiency_points = compute_aopc_scores(instances, aopc_thresholds) + else: + aopc_thresholds, aopc_comprehensiveness_score, aopc_comprehensiveness_points, aopc_sufficiency_score, aopc_sufficiency_points = None, None, None, None, None + if 'tokens_to_flip' in instances[0]: + token_percentages = [] + for ann in annotations: + # in practice, this is of size 1 for everything except e-snli + docids = set(ev.docid for ev in chain.from_iterable(ann.evidences)) + inst = key_to_instances[ann.annotation_id] + tokens = inst['tokens_to_flip'] + doc_lengths = sum(len(docs[d]) for d in docids) + token_percentages.append(tokens / doc_lengths) + token_percentages = np.average(token_percentages) + else: + token_percentages = None + + return { + 'accuracy': accuracy, + 'prf': classification_scores, + 'comprehensiveness': comprehensiveness_score, + 'sufficiency': sufficiency_score, + 'comprehensiveness_entropy': comprehensiveness_entropy, + 'comprehensiveness_kl': comprehensiveness_kl, + 'sufficiency_entropy': sufficiency_entropy, + 'sufficiency_kl': sufficiency_kl, + 'aopc_thresholds': aopc_thresholds, + 'comprehensiveness_aopc': aopc_comprehensiveness_score, + 'comprehensiveness_aopc_points': aopc_comprehensiveness_points, + 'sufficiency_aopc': aopc_sufficiency_score, + 'sufficiency_aopc_points': aopc_sufficiency_points, + } + +def verify_instance(instance: dict, docs: Dict[str, list], thresholds: Set[float]): + error = False + docids = [] + # verify the internal structure of these instances is correct: + # * hard predictions are present + # * start and end tokens are valid + # * soft rationale predictions, if present, must have the same document length + + for rat in instance['rationales']: + docid = rat['docid'] + if docid not in docid: + error = True + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} could not be found as a preprocessed document! Gave up on additional processing.') + continue + doc_length = len(docs[docid]) + for h1 in rat.get('hard_rationale_predictions', []): + # verify that each token is valid + # verify that no annotations overlap + for h2 in rat.get('hard_rationale_predictions', []): + if h1 == h2: + continue + if len(set(range(h1['start_token'], h1['end_token'])) & set(range(h2['start_token'], h2['end_token']))) > 0: + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} {h1} and {h2} overlap!') + error = True + if h1['start_token'] > doc_length: + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} received an impossible tokenspan: {h1} for a document of length {doc_length}') + error = True + if h1['end_token'] > doc_length: + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} received an impossible tokenspan: {h1} for a document of length {doc_length}') + error = True + # length check for soft rationale + # note that either flattened_documents or sentence-broken documents must be passed in depending on result + soft_rationale_predictions = rat.get('soft_rationale_predictions', []) + if len(soft_rationale_predictions) > 0 and len(soft_rationale_predictions) != doc_length: + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} expected classifications for {doc_length} tokens but have them for {len(soft_rationale_predictions)} tokens instead!') + error = True + + # count that one appears per-document + docids = Counter(docids) + for docid, count in docids.items(): + if count > 1: + error = True + logging.info('Error! For instance annotation={instance["annotation_id"]}, docid={docid} appear {count} times, may only appear once!') + + classification = instance.get('classification', '') + if not isinstance(classification, str): + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, classification field {classification} is not a string!') + error = True + classification_scores = instance.get('classification_scores', dict()) + if not isinstance(classification_scores, dict): + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, classification_scores field {classification_scores} is not a dict!') + error = True + comprehensiveness_classification_scores = instance.get('comprehensiveness_classification_scores', dict()) + if not isinstance(comprehensiveness_classification_scores, dict): + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, comprehensiveness_classification_scores field {comprehensiveness_classification_scores} is not a dict!') + error = True + sufficiency_classification_scores = instance.get('sufficiency_classification_scores', dict()) + if not isinstance(sufficiency_classification_scores, dict): + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, sufficiency_classification_scores field {sufficiency_classification_scores} is not a dict!') + error = True + if ('classification' in instance) != ('classification_scores' in instance): + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, when providing a classification, you must also provide classification scores!') + error = True + if ('comprehensiveness_classification_scores' in instance) and not ('classification' in instance): + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, when providing a classification, you must also provide a comprehensiveness_classification_score') + error = True + if ('sufficiency_classification_scores' in instance) and not ('classification_scores' in instance): + logging.info(f'Error! For instance annotation={instance["annotation_id"]}, when providing a sufficiency_classification_score, you must also provide a classification score!') + error = True + if 'thresholded_scores' in instance: + instance_thresholds = set(x['threshold'] for x in instance['thresholded_scores']) + if instance_thresholds != thresholds: + error = True + logging.info('Error: {instance["thresholded_scores"]} has thresholds that differ from previous thresholds: {thresholds}') + if 'comprehensiveness_classification_scores' not in instance\ + or 'sufficiency_classification_scores' not in instance\ + or 'classification' not in instance\ + or 'classification_scores' not in instance: + error = True + logging.info('Error: {instance} must have comprehensiveness_classification_scores, sufficiency_classification_scores, classification, and classification_scores defined when including thresholded scores') + if not all('sufficiency_classification_scores' in x for x in instance['thresholded_scores']): + error = True + logging.info('Error: {instance} must have sufficiency_classification_scores for every threshold') + if not all('comprehensiveness_classification_scores' in x for x in instance['thresholded_scores']): + error = True + logging.info('Error: {instance} must have comprehensiveness_classification_scores for every threshold') + return error + +def verify_instances(instances: List[dict], docs: Dict[str, list]): + annotation_ids = list(x['annotation_id'] for x in instances) + key_counter = Counter(annotation_ids) + multi_occurrence_annotation_ids = list(filter(lambda kv: kv[1] > 1, key_counter.items())) + error = False + if len(multi_occurrence_annotation_ids) > 0: + error = True + logging.info(f'Error in instances: {len(multi_occurrence_annotation_ids)} appear multiple times in the annotations file: {multi_occurrence_annotation_ids}') + failed_validation = set() + instances_with_classification = list() + instances_with_soft_rationale_predictions = list() + instances_with_soft_sentence_predictions = list() + instances_with_comprehensiveness_classifications = list() + instances_with_sufficiency_classifications = list() + instances_with_thresholded_scores = list() + if 'thresholded_scores' in instances[0]: + thresholds = set(x['threshold'] for x in instances[0]['thresholded_scores']) + else: + thresholds = None + for instance in instances: + instance_error = verify_instance(instance, docs, thresholds) + if instance_error: + error = True + failed_validation.add(instance['annotation_id']) + if instance.get('classification', None) != None: + instances_with_classification.append(instance) + if instance.get('comprehensiveness_classification_scores', None) != None: + instances_with_comprehensiveness_classifications.append(instance) + if instance.get('sufficiency_classification_scores', None) != None: + instances_with_sufficiency_classifications.append(instance) + has_soft_rationales = [] + has_soft_sentences = [] + for rat in instance['rationales']: + if rat.get('soft_rationale_predictions', None) != None: + has_soft_rationales.append(rat) + if rat.get('soft_sentence_predictions', None) != None: + has_soft_sentences.append(rat) + if len(has_soft_rationales) > 0: + instances_with_soft_rationale_predictions.append(instance) + if len(has_soft_rationales) != len(instance['rationales']): + error = True + logging.info(f'Error: instance {instance["annotation"]} has soft rationales for some but not all reported documents!') + if len(has_soft_sentences) > 0: + instances_with_soft_sentence_predictions.append(instance) + if len(has_soft_sentences) != len(instance['rationales']): + error = True + logging.info(f'Error: instance {instance["annotation"]} has soft sentences for some but not all reported documents!') + if 'thresholded_scores' in instance: + instances_with_thresholded_scores.append(instance) + logging.info(f'Error in instances: {len(failed_validation)} instances fail validation: {failed_validation}') + if len(instances_with_classification) != 0 and len(instances_with_classification) != len(instances): + logging.info(f'Either all {len(instances)} must have a classification or none may, instead {len(instances_with_classification)} do!') + error = True + if len(instances_with_soft_sentence_predictions) != 0 and len(instances_with_soft_sentence_predictions) != len(instances): + logging.info(f'Either all {len(instances)} must have a sentence prediction or none may, instead {len(instances_with_soft_sentence_predictions)} do!') + error = True + if len(instances_with_soft_rationale_predictions) != 0 and len(instances_with_soft_rationale_predictions) != len(instances): + logging.info(f'Either all {len(instances)} must have a soft rationale prediction or none may, instead {len(instances_with_soft_rationale_predictions)} do!') + error = True + if len(instances_with_comprehensiveness_classifications) != 0 and len(instances_with_comprehensiveness_classifications) != len(instances): + error = True + logging.info(f'Either all {len(instances)} must have a comprehensiveness classification or none may, instead {len(instances_with_comprehensiveness_classifications)} do!') + if len(instances_with_sufficiency_classifications) != 0 and len(instances_with_sufficiency_classifications) != len(instances): + error = True + logging.info(f'Either all {len(instances)} must have a sufficiency classification or none may, instead {len(instances_with_sufficiency_classifications)} do!') + if len(instances_with_thresholded_scores) != 0 and len(instances_with_thresholded_scores) != len(instances): + error = True + logging.info(f'Either all {len(instances)} must have thresholded scores or none may, instead {len(instances_with_thresholded_scores)} do!') + if error: + raise ValueError('Some instances are invalid, please fix your formatting and try again') + +def _has_hard_predictions(results: List[dict]) -> bool: + # assumes that we have run "verification" over the inputs + (results[0]) + #print('rationales' in results[0]) + #print(len(results[0]['rationales']) > 0) + #print('hard_rationale_predictions' in results[0]['rationales'][0]) + #print(results[0]['rationales'][0]['hard_rationale_predictions'] is not None) + #print(len(results[0]['rationales'][0]['hard_rationale_predictions']) > 0) + #print(results[0]['rationales'][0]['hard_rationale_predictions']) + return 'rationales' in results[0]\ + and len(results[0]['rationales']) > 0\ + and 'hard_rationale_predictions' in results[0]['rationales'][0]\ + and results[0]['rationales'][0]['hard_rationale_predictions'] is not None\ + #and len(results[0]['rationales'][0]['hard_rationale_predictions']) > 0 + +def _has_soft_predictions(results: List[dict]) -> bool: + # assumes that we have run "verification" over the inputs + return 'rationales' in results[0] and len(results[0]['rationales']) > 0 and 'soft_rationale_predictions' in results[0]['rationales'][0] and results[0]['rationales'][0]['soft_rationale_predictions'] is not None + +def _has_soft_sentence_predictions(results: List[dict]) -> bool: + # assumes that we have run "verification" over the inputs + return 'rationales' in results[0] and len(results[0]['rationales']) > 0 and 'soft_sentence_predictions' in results[0]['rationales'][0] and results[0]['rationales'][0]['soft_sentence_predictions'] is not None + +def _has_classifications(results: List[dict]) -> bool: + # assumes that we have run "verification" over the inputs + return 'classification' in results[0] and results[0]['classification'] is not None + +def main(): + parser = argparse.ArgumentParser(description="""Computes rationale and final class classification scores""", formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('--data_dir', dest='data_dir', required=True, help='Which directory contains a {train,val,test}.jsonl file?') + parser.add_argument('--split', dest='split', required=True, help='Which of {train,val,test} are we scoring on?') + parser.add_argument('--strict', dest='strict', required=False, action='store_true', default=False, help='Do we perform strict scoring?') + parser.add_argument('--results', dest='results', required=True, help="""Results File + Contents are expected to be jsonl of: + { + "annotation_id": str, required + # these classifications *must not* overlap + "rationales": List[ + { + "docid": str, required + "hard_rationale_predictions": List[{ + "start_token": int, inclusive, required + "end_token": int, exclusive, required + }], optional, + # token level classifications, a value must be provided per-token + # in an ideal world, these correspond to the hard-decoding above. + "soft_rationale_predictions": List[float], optional. + # sentence level classifications, a value must be provided for every + # sentence in each document, or not at all + "soft_sentence_predictions": List[float], optional. + } + ], + # the classification the model made for the overall classification task + "classification": str, optional + # A probability distribution output by the model. We require this to be normalized. + "classification_scores": Dict[str, float], optional + # The next two fields are measures for how faithful your model is (the + # rationales it predicts are in some sense causal of the prediction), and + # how sufficient they are. We approximate a measure for comprehensiveness by + # asking that you remove the top k%% of tokens from your documents, + # running your models again, and reporting the score distribution in the + # "comprehensiveness_classification_scores" field. + # We approximate a measure of sufficiency by asking exactly the converse + # - that you provide model distributions on the removed k%% tokens. + # 'k' is determined by human rationales, and is documented in our paper. + # You should determine which of these tokens to remove based on some kind + # of information about your model: gradient based, attention based, other + # interpretability measures, etc. + # scores per class having removed k%% of the data, where k is determined by human comprehensive rationales + "comprehensiveness_classification_scores": Dict[str, float], optional + # scores per class having access to only k%% of the data, where k is determined by human comprehensive rationales + "sufficiency_classification_scores": Dict[str, float], optional + # the number of tokens required to flip the prediction - see "Is Attention Interpretable" by Serrano and Smith. + "tokens_to_flip": int, optional + "thresholded_scores": List[{ + "threshold": float, required, + "comprehensiveness_classification_scores": like "classification_scores" + "sufficiency_classification_scores": like "classification_scores" + }], optional. if present, then "classification" and "classification_scores" must be present + } + When providing one of the optional fields, it must be provided for *every* instance. + The classification, classification_score, and comprehensiveness_classification_scores + must together be present for every instance or absent for every instance. + """) + parser.add_argument('--iou_thresholds', dest='iou_thresholds', required=False, nargs='+', type=float, default=[0.5], help='''Thresholds for IOU scoring. + + These are used for "soft" or partial match scoring of rationale spans. + A span is considered a match if the size of the intersection of the prediction + and the annotation, divided by the union of the two spans, is larger than + the IOU threshold. This score can be computed for arbitrary thresholds. + ''') + parser.add_argument('--score_file', dest='score_file', required=False, default=None, help='Where to write results?') + parser.add_argument('--aopc_thresholds', nargs='+', required=False, type=float, default=[0.01, 0.05, 0.1, 0.2, 0.5], help='Thresholds for AOPC Thresholds') + args = parser.parse_args() + results = load_jsonl(args.results) + docids = set(chain.from_iterable([rat['docid'] for rat in res['rationales']] for res in results)) + docs = load_flattened_documents(args.data_dir, docids) + verify_instances(results, docs) + # load truth + annotations = annotations_from_jsonl(os.path.join(args.data_dir, args.split + '.jsonl')) + docids |= set(chain.from_iterable((ev.docid for ev in chain.from_iterable(ann.evidences)) for ann in annotations)) + + has_final_predictions = _has_classifications(results) + scores = dict() + if args.strict: + if not args.iou_thresholds: + raise ValueError("--iou_thresholds must be provided when running strict scoring") + if not has_final_predictions: + raise ValueError("We must have a 'classification', 'classification_score', and 'comprehensiveness_classification_score' field in order to perform scoring!") + # TODO think about offering a sentence level version of these scores. + if _has_hard_predictions(results): + truth = list(chain.from_iterable(Rationale.from_annotation(ann) for ann in annotations)) + pred = list(chain.from_iterable(Rationale.from_instance(inst) for inst in results)) + if args.iou_thresholds is not None: + iou_scores = partial_match_score(truth, pred, args.iou_thresholds) + scores['iou_scores'] = iou_scores + # NER style scoring + rationale_level_prf = score_hard_rationale_predictions(truth, pred) + scores['rationale_prf'] = rationale_level_prf + token_level_truth = list(chain.from_iterable(rat.to_token_level() for rat in truth)) + token_level_pred = list(chain.from_iterable(rat.to_token_level() for rat in pred)) + token_level_prf = score_hard_rationale_predictions(token_level_truth, token_level_pred) + scores['token_prf'] = token_level_prf + else: + logging.info("No hard predictions detected, skipping rationale scoring") + + if _has_soft_predictions(results): + flattened_documents = load_flattened_documents(args.data_dir, docids) + paired_scoring = PositionScoredDocument.from_results(results, annotations, flattened_documents, use_tokens=True) + token_scores = score_soft_tokens(paired_scoring) + scores['token_soft_metrics'] = token_scores + else: + logging.info("No soft predictions detected, skipping rationale scoring") + + if _has_soft_sentence_predictions(results): + documents = load_documents(args.data_dir, docids) + paired_scoring = PositionScoredDocument.from_results(results, annotations, documents, use_tokens=False) + sentence_scores = score_soft_tokens(paired_scoring) + scores['sentence_soft_metrics'] = sentence_scores + else: + logging.info("No sentence level predictions detected, skipping sentence-level diagnostic") + + if has_final_predictions: + flattened_documents = load_flattened_documents(args.data_dir, docids) + class_results = score_classifications(results, annotations, flattened_documents, args.aopc_thresholds) + scores['classification_scores'] = class_results + else: + logging.info("No classification scores detected, skipping classification") + + pprint.pprint(scores) + #logger= logging.getLogger( __name__ ) + #logging.info("UIUIUI"+__name__) + + if args.score_file: + with open(args.score_file, 'w') as of: + json.dump(scores, of, indent=4, sort_keys=True) + +### COPY TO RUN FROM PYTHON FILE +def runEvaluation(data_dir, split, results, score_file, strict=False, iou_thresholds=[0.5], aopc_thresholds=[0.01, 0.05, 0.1, 0.2, 0.5]): + #print(results) + args = type("args", (object, ), { + # data members + "data_dir": data_dir, + "split": split, + "strict": strict, + "results": results, + "iou_thresholds": iou_thresholds, + "score_file": score_file, + "aopc_thresholds": aopc_thresholds, + }) + + results = load_jsonl(args.results) + docids = set(chain.from_iterable([rat['docid'] for rat in res['rationales']] for res in results)) + docs = load_flattened_documents(args.data_dir, docids) + + verify_instances(results, docs) + # load truth + annotations = annotations_from_jsonl(os.path.join(args.data_dir, args.split + '.jsonl')) + docids |= set(chain.from_iterable((ev.docid for ev in chain.from_iterable(ann.evidences)) for ann in annotations)) + + has_final_predictions = _has_classifications(results) + scores = dict() + if args.strict: + if not args.iou_thresholds: + raise ValueError("--iou_thresholds must be provided when running strict scoring") + if not has_final_predictions: + raise ValueError("We must have a 'classification', 'classification_score', and 'comprehensiveness_classification_score' field in order to perform scoring!") + # TODO think about offering a sentence level version of these scores. + if _has_hard_predictions(results): + truth = list(chain.from_iterable(Rationale.from_annotation(ann) for ann in annotations)) + pred = list(chain.from_iterable(Rationale.from_instance(inst) for inst in results)) + if args.iou_thresholds is not None: + iou_scores = partial_match_score(truth, pred, args.iou_thresholds) + scores['iou_scores'] = iou_scores + # NER style scoring + rationale_level_prf = score_hard_rationale_predictions(truth, pred) + scores['rationale_prf'] = rationale_level_prf + token_level_truth = list(chain.from_iterable(rat.to_token_level() for rat in truth)) + token_level_pred = list(chain.from_iterable(rat.to_token_level() for rat in pred)) + token_level_prf = score_hard_rationale_predictions(token_level_truth, token_level_pred) + scores['token_prf'] = token_level_prf + else: + logging.info("No hard predictions detected, skipping rationale scoring") + + if _has_soft_predictions(results): + flattened_documents = load_flattened_documents(args.data_dir, docids) + paired_scoring = PositionScoredDocument.from_results(results, annotations, flattened_documents, use_tokens=True) + token_scores = score_soft_tokens(paired_scoring) + scores['token_soft_metrics'] = token_scores + else: + logging.info("No soft predictions detected, skipping rationale scoring") + + if _has_soft_sentence_predictions(results): + documents = load_documents(args.data_dir, docids) + paired_scoring = PositionScoredDocument.from_results(results, annotations, documents, use_tokens=False) + sentence_scores = score_soft_tokens(paired_scoring) + scores['sentence_soft_metrics'] = sentence_scores + else: + logging.info("No sentence level predictions detected, skipping sentence-level diagnostic") + + if has_final_predictions: + flattened_documents = load_flattened_documents(args.data_dir, docids) + class_results = score_classifications(results, annotations, flattened_documents, args.aopc_thresholds) + scores['classification_scores'] = class_results + else: + logging.info("No classification scores detected, skipping classification") + + pprint.pprint(scores) + #logger= logging.getLogger( __name__ ) + #logging.info("UIUIUI"+__name__) + + if args.score_file: + with open(args.score_file, 'w') as of: + json.dump(scores, of, indent=4, sort_keys=True) + +if __name__ == '__main__': + main() diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/__init__.py b/scripts/eraserbenchmark/rationale_benchmark/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/encode_attend.py b/scripts/eraserbenchmark/rationale_benchmark/models/encode_attend.py new file mode 100644 index 0000000..3a0f1fa --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/models/encode_attend.py @@ -0,0 +1,520 @@ +import argparse +import json +import logging +import random +import os + +from collections import defaultdict, OrderedDict +from heapq import nsmallest, nlargest +from itertools import chain +from typing import Any, Dict, List, Set, Tuple + +#import apex +import numpy as np +import torch +import torch.nn as nn + +from sklearn.metrics import accuracy_score, classification_report + +from rationale_benchmark.utils import ( + Annotation, + load_datasets, + load_documents, + intern_documents, + intern_annotations +) +from rationale_benchmark.models.model_utils import ( + PaddedSequence, + extract_embeddings, +) +from rationale_benchmark.models.mlp import ( + AttentiveClassifier, + LuongAttention, + RNNEncoder, + WordEmbedder +) + +logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s') +# let's make this more or less deterministic (not resistent to restarts) +random.seed(12345) +np.random.seed(67890) +torch.manual_seed(10111213) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +def initialize_model(params: dict, vocab: Set[str], batch_first: bool, unk_token='UNK'): + # TODO this is obviously asking for some sort of dependency injection. implement if it saves me time. + if 'embedding_file' in params['embeddings']: + embeddings, word_interner, de_interner = extract_embeddings(vocab, params['embeddings']['embedding_file'], + unk_token=unk_token) + if torch.cuda.is_available(): + embeddings = embeddings.cuda() + else: + raise ValueError("No 'embedding_file' found in params!") + word_embedder = WordEmbedder(embeddings, params['embeddings']['dropout']) + encoding_size = params['classifier'].get('hidden_size', word_embedder.output_dimension) + if bool(params['classifier']['has_query']): + attention_mechanism = LuongAttention(encoding_size) + query_encoder = RNNEncoder(word_embedder, + batch_first=batch_first, + condition=False, + output_dimension=encoding_size, + attention_mechanism=attention_mechanism) + condition = True + query_size = query_encoder.output_dimension + else: + query_encoder = None + condition = False + query_size = None + attention_mechanism = LuongAttention(encoding_size, encoding_size) + document_encoder = RNNEncoder(word_embedder, + batch_first=batch_first, + condition=condition, + output_dimension=encoding_size, + attention_mechanism=attention_mechanism) + evidence_classes = dict((y, x) for (x, y) in enumerate(params['classifier']['classes'])) + classifier = AttentiveClassifier(document_encoder, + query_encoder, + len(evidence_classes), + params['classifier']['mlp_size'], + params['classifier']['dropout']) + return classifier, word_interner, de_interner, evidence_classes + + +def annotation_to_instances(ann: Annotation, docs: Dict[str, List[List[int]]], class_interner: Dict[str, int]): + evidences = defaultdict(set) + for ev in ann.all_evidences(): + evidences[ev.docid].add(ev) + output_documents = dict() + evidence_spans = dict() + for d, evs in evidences.items(): + output_documents[d] = list(chain.from_iterable(docs[d])) + evidence_targets = [0 for _ in range(sum(len(s) for s in docs[d]))] + for ev in evs: + for t in range(ev.start_token, ev.end_token): + evidence_targets[t] = 1 + evidence_spans[d] = evidence_targets + return class_interner.get(ann.classification, -1), output_documents, evidence_spans + + +def convert_for_training(annotations: List[Annotation], docs: Dict[str, List[List[int]]], + class_interner: Dict[str, int]): + ids = [] + classes = [] + queries = [] + doc_vecs = [] + evidence_spans = [] + for ann in annotations: + kls, flattened_docs, ev_spans = annotation_to_instances(ann, docs, class_interner) + if len(flattened_docs) == 0: + continue + if ann.query and len(ann.query) > 0: + queries.append(torch.tensor(ann.query)) + classes.append(kls) + ids.append((ann.annotation_id, flattened_docs.keys())) + combined_doc_vecs = [] + combined_evidence_spans = [] + for d, doc_vec in flattened_docs.items(): + combined_doc_vecs.extend(doc_vec) + combined_evidence_spans.extend(ev_spans[d]) + doc_vecs.append(torch.tensor(combined_doc_vecs)) + evidence_spans.append(torch.tensor(combined_evidence_spans)) + if len(queries) == 0: + queries = None + return ids, classes, queries, doc_vecs, evidence_spans + + +def train_classifier(classifier: nn.Module, + save_dir: str, + train: List[Annotation], + val: List[Annotation], + documents: Dict[str, List[List[int]]], + model_pars: dict, + class_interner: Dict[str, int], + attention_optimizer=None, + classifier_optimizer=None) -> Tuple[nn.Module, dict]: + logging.info(f'Beginning training classifier with {len(train)} annotations, {len(val)} for validation') + # TODO this paramterization is a mess and all parameters should be easier to track + classifier_output_dir = os.path.join(save_dir, 'classifier') + os.makedirs(save_dir, exist_ok=True) + os.makedirs(classifier_output_dir, exist_ok=True) + model_save_file = os.path.join(classifier_output_dir, 'classifier.pt') + epoch_save_file = os.path.join(classifier_output_dir, 'classifier_epoch_data.pt') + train_ids, train_classes, train_queries, train_docs, train_evidence_spans = convert_for_training(train, + documents, + class_interner) + val_ids, val_classes, val_queries, val_docs, val_evidence_spans = convert_for_training(val, + documents, + class_interner) + if not bool(model_pars['classifier']['has_query']): + train_queries = None + val_queries = None + device = next(classifier.parameters()).device + + if attention_optimizer is None: + attention_optimizer = torch.optim.Adam(classifier.parameters(), lr=model_pars['classifier']['lr']) + if classifier_optimizer is None: + classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=model_pars['classifier']['lr']) + attention_criterion = nn.BCELoss(reduction='sum') + criterion = nn.CrossEntropyLoss(reduction='sum') + batch_size = model_pars['classifier']['batch_size'] + epochs = model_pars['classifier']['epochs'] + attention_epochs = model_pars['classifier']['attention_epochs'] + patience = model_pars['classifier']['patience'] + max_grad_norm = model_pars['classifier'].get('max_grad_norm', None) + class_labels = [k for k, v in sorted(class_interner.items())] + + results = { + 'attention_train_losses': [], + 'attention_val_losses': [], + 'train_loss': [], + 'train_f1': [], + 'train_acc': [], + 'val_loss': [], + 'val_f1': [], + 'val_acc': [], + } + best_attention_epoch = -1 + best_classifier_epoch = -1 + best_attention_loss = float('inf') + best_classifier_loss = float('inf') + best_model_state_dict = None + start_attention_epoch = 0 + start_classifier_epoch = 0 + epoch_data = {} + if os.path.exists(epoch_save_file): + logging.info(f'Restoring model from {model_save_file}') + classifier.load_state_dict(torch.load(model_save_file)) + epoch_data = torch.load(epoch_save_file) + start_attention_epoch = epoch_data.get('attention_epoch', -1) + 1 + start_classifier_epoch = epoch_data.get('classifier_epoch', -1) + 1 + best_attention_loss = epoch_data.get('best_attention_loss', float('inf')) + best_classifier_loss = epoch_data.get('best_classifier_loss', float('inf')) + # handle finishing because patience was exceeded or we didn't get the best final epoch + if bool(epoch_data.get('done_attention', 0)): + start_attention_epoch = epochs + if bool(epoch_data.get('done_classifier', 0)): + start_classifier_epoch = epochs + results = epoch_data['results'] + best_attention_epoch = start_attention_epoch + best_classifier_epoch = start_classifier_epoch + best_model_state_dict = OrderedDict({k: v.cpu() for k, v in classifier.state_dict().items()}) + logging.info(f'Restoring training from attention epoch {start_attention_epoch} / {start_classifier_epoch}') + logging.info(f'Training classifier attention from epoch {start_attention_epoch} until epoch {attention_epochs}') + for attention_epoch in range(start_attention_epoch, attention_epochs): + epoch_train_loss = 0 + epoch_train_tokens = 0 + epoch_val_loss = 0 + epoch_val_tokens = 0 + for batch_start in range(0, len(train_ids), batch_size): + #targets = train_classes[batch_start:batch_start + batch_size] + classifier.train() + attention_optimizer.zero_grad() + if train_queries is None: + queries = None + else: + queries = train_queries[batch_start:batch_start + batch_size] + docs = train_docs[batch_start:batch_start + batch_size] + train_spans = train_evidence_spans[batch_start:batch_start + batch_size] + _, _, _, unnormalized_document_attention, _ = classifier(queries, None, docs, return_attentions=True) + partially_normalized_document_attention = torch.sigmoid(unnormalized_document_attention.data.squeeze()) + train_spans = PaddedSequence.autopad(train_spans, batch_first=True, + device=unnormalized_document_attention.data.device) + batch_loss = attention_criterion(partially_normalized_document_attention, train_spans.data.float()) + epoch_train_loss += batch_loss.item() + train_size = torch.sum(train_spans.batch_sizes).item() + epoch_train_tokens += train_size + batch_loss = batch_loss / train_size + batch_loss.backward() + attention_optimizer.step() + results['attention_train_losses'].append(epoch_train_loss / epoch_train_tokens) + logging.info(f'Epoch {attention_epoch} attention train loss {epoch_train_loss / epoch_train_tokens}') + with torch.no_grad(): + classifier.eval() + for batch_start in range(0, len(val_ids), batch_size): + #targets = val_classes[batch_start:batch_start + batch_size] + if val_queries is None: + queries = None + else: + queries = val_queries[batch_start:batch_start + batch_size] + docs = val_docs[batch_start:batch_start + batch_size] + val_spans = val_evidence_spans[batch_start:batch_start + batch_size] + _, _, _, unnormalized_document_attention, _ = classifier(queries, None, docs, return_attentions=True) + unnormalized_document_attention = torch.sigmoid(unnormalized_document_attention.data) + val_spans = PaddedSequence.autopad(val_spans, batch_first=True, device=device) + batch_loss = attention_criterion(unnormalized_document_attention.squeeze(), val_spans.data.float()) + epoch_val_loss += batch_loss.item() + epoch_val_tokens += torch.sum(val_spans.batch_sizes).item() + epoch_val_loss = epoch_val_loss / epoch_val_tokens + results['attention_val_losses'].append(epoch_val_loss) + logging.info(f'Epoch {attention_epoch} attention val loss {epoch_val_loss}') + if epoch_val_loss < best_attention_loss: + best_model_state_dict = OrderedDict({k: v.cpu() for k, v in classifier.state_dict().items()}) + best_attention_epoch = attention_epoch + best_attention_loss = epoch_val_loss + epoch_data['attention_epoch'] = attention_epoch + epoch_data['results'] = results + epoch_data['best_attention_loss'] = best_attention_loss + epoch_data['best_classifier_loss'] = float('inf') + epoch_data['done_attention'] = 0 + epoch_data['done_classifier'] = 0 + torch.save(classifier.state_dict(), model_save_file) + torch.save(epoch_data, epoch_save_file) + logging.info(f'Epoch {attention_epoch} new best model with val loss {epoch_val_loss}') + if attention_epoch - best_attention_epoch > patience: + logging.info(f'Exiting after epoch {attention_epoch} due to no improvement') + epoch_data['done_attention'] = 1 + torch.save(epoch_data, epoch_save_file) + break + logging.info(f'Training classifier from epoch {start_classifier_epoch} until epoch {epochs}') + for classifier_epoch in range(start_classifier_epoch, epochs): + epoch_train_loss = 0 + epoch_val_loss = 0 + train_preds = [] + train_truth = [] + classifier.train() + for batch_start in range(0, len(train_ids), batch_size): + classifier.train() + classifier_optimizer.zero_grad() + targets = train_classes[batch_start:batch_start + batch_size] + train_truth.extend(targets) + targets = torch.tensor(targets, device=device) + if train_queries is not None: + queries = train_queries[batch_start:batch_start + batch_size] + else: + queries = None + docs = train_docs[batch_start:batch_start + batch_size] + classes = classifier(queries, None, docs, return_attentions=False) + train_preds.extend(x.item() for x in torch.argmax(classes, dim=1)) + batch_loss = criterion(classes.squeeze(), targets) + epoch_train_loss += batch_loss.item() + batch_loss /= len(docs) + batch_loss.backward() + classifier_optimizer.step() + train_accuracy = accuracy_score(train_truth, train_preds) + train_f1 = classification_report(train_truth, train_preds, output_dict=True) + results['train_loss'].append(epoch_train_loss / len(train_ids)) + results['train_acc'].append(train_accuracy) + results['train_f1'].append(train_f1) + logging.info( + f'Epoch {classifier_epoch} train loss {epoch_train_loss / len(train_ids)}, accuracy: {train_accuracy}, f1: {train_f1}') + with torch.no_grad(): + classifier.eval() + val_preds = [] + val_truth = [] + for batch_start in range(0, len(val_ids), batch_size): + targets = val_classes[batch_start:batch_start + batch_size] + val_truth.extend(targets) + if val_queries is not None: + queries = val_queries[batch_start:batch_start + batch_size] + else: + queries = None + docs = val_docs[batch_start:batch_start + batch_size] + classes = classifier(queries, None, docs, return_attentions=False) + targets = torch.tensor(targets, device=classes.device) + val_preds.extend(x.item() for x in torch.argmax(classes, dim=1)) + batch_loss = criterion(classes, targets) + if not torch.all(batch_loss == batch_loss): + import pdb; pdb.set_trace() + epoch_val_loss += batch_loss.item() + batch_loss /= len(docs) + epoch_val_loss /= len(val_ids) + val_accuracy = accuracy_score(val_truth, val_preds) + val_f1 = classification_report(val_truth, val_preds, output_dict=True) + results['val_loss'].append(epoch_val_loss) + results['val_acc'].append(val_accuracy) + results['val_f1'].append(val_f1) + logging.info(f'Epoch {classifier_epoch} val loss {epoch_val_loss}, accuracy: {val_accuracy}, f1: {val_f1}') + if epoch_val_loss < best_classifier_loss: + best_model_state_dict = OrderedDict({k: v.cpu() for k, v in classifier.state_dict().items()}) + best_classifier_epoch = classifier_epoch + best_val_loss = epoch_val_loss + epoch_data['classifier_epoch'] = classifier_epoch + epoch_data['attention_epoch'] = best_attention_epoch + epoch_data['best_attention_loss'] = best_attention_loss + epoch_data['results'] = results + epoch_data['best_classifier_loss'] = best_val_loss + epoch_data['done_classifier'] = 0 + epoch_data['done_attention'] = 1 + torch.save(classifier.state_dict(), model_save_file) + torch.save(epoch_data, epoch_save_file) + logging.info(f'Epoch {classifier_epoch} new best model with val loss {epoch_val_loss}') + if classifier_epoch - best_classifier_epoch > patience: + logging.info(f'Exiting after epoch {classifier_epoch} due to no improvement') + epoch_data['done_classifier'] = 1 + torch.save(epoch_data, epoch_save_file) + break + return classifier, results + + +def decode(classifier: nn.Module, + train: List[Annotation], + val: List[Annotation], + test: List[Annotation], + documents: Dict[str, List[List[int]]], + class_interner: Dict[str, int], + batch_size: int, + tensorize_model_inputs: bool, + threshold: float, + k_fraction: float, + has_query: bool) -> dict: + class_deinterner = {v: k for k, v in class_interner.items()} + train_ids, train_classes, train_queries, train_docs, train_evidence_spans = convert_for_training(train, documents, + class_interner) + val_ids, val_classes, val_queries, val_docs, val_evidence_spans = convert_for_training(val, documents, + class_interner) + test_ids, _, test_queries, test_docs, test_evidence_spans = convert_for_training(test, documents, class_interner) + if not bool(has_query): + train_queries = None + val_queries = None + test_queries = None + device = next(classifier.parameters()).device + + def decode_set(queries, docs): + classifier.eval() + with torch.no_grad(): + preds, unnormalized_attentions, attentions = [], [], [] + for batch_start in range(0, len(docs), batch_size): + if queries is not None: + q = queries[batch_start:batch_start + batch_size] + q = [torch.tensor(x) for x in q] + else: + q = None + d = docs[batch_start:batch_start + batch_size] + d = [torch.tensor(x) for x in d] + classes, _, _, unnormalized_document_attention, normalized_document_attention = classifier(q, None, d, + return_attentions=True) + preds.extend([[y.item() for y in x] for x in classes]) + attentions.extend([[x.item() for x in y] for y in + normalized_document_attention.unpad(normalized_document_attention.data.squeeze())]) + unnormalized_attentions.extend([[x.item() for x in y] for y in unnormalized_document_attention.unpad( + unnormalized_document_attention.data.squeeze())]) + #unnormalized_attentions.append([x.item() for x in unnormalized_document_attention]) + return preds, attentions + + def generate_rationales(ids, queries, docs): + rats = [] + evidence_only_docs = [] + no_evidence_docs = [] + for (ann_id, docids), doc, pred, attentions in zip(ids, docs, *decode_set(queries, docs)): + doc = np.array(doc) + classification_scores = {class_deinterner[cls]: p for cls, p in enumerate(pred)} + cls = np.argmax(pred) + if len(docids) == 1: + (docid,) = docids + soft_sentence_predictions = [] + start = 0 + for sent in documents[docid]: + end = start + len(sent) + soft_sentence_predictions.append(sum(attentions[start:end])) + start = end + rat = { + 'annotation_id': ann_id, + 'classification': class_deinterner[cls], + 'classification_scores': classification_scores, + 'rationales': [{ + 'docid': docid, + 'soft_rationale_predictions': attentions, + 'soft_sentence_predictions': soft_sentence_predictions, + }] + } + else: + raise ValueError() + # TODO make hard predictions + top_k = nlargest(int(k_fraction * len(doc)), zip(attentions, doc)) + top_k = [x[1] for x in top_k] + evidence_only_docs.append(top_k) + bottom_k = nsmallest(int((1 - k_fraction) * len(doc)), zip(attentions, doc)) + bottom_k = [x[1] for x in bottom_k] + no_evidence_docs.append(bottom_k) + rats.append(rat) + + for i, ((ann_id, _), pred, _) in enumerate(zip(ids, *decode_set(queries, evidence_only_docs))): + classification_scores = {class_deinterner[cls]: p for cls, p in enumerate(pred)} + rats[i]['sufficiency_classification_scores'] = classification_scores + + for i, ((ann_id, _), pred, _) in enumerate(zip(ids, *decode_set(queries, no_evidence_docs))): + classification_scores = {class_deinterner[cls]: p for cls, p in enumerate(pred)} + rats[i]['comprehensiveness_classification_scores'] = classification_scores + + return rats + + val_rats = generate_rationales(val_ids, val_queries, val_docs) + test_rats = generate_rationales(test_ids, test_queries, test_docs) + return val_rats, test_rats + + +def main(): + parser = argparse.ArgumentParser(description=""" """, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('--data_dir', dest='data_dir', required=True, + help='Which directory contains a {train,val,test}.jsonl file?') + parser.add_argument('--output_dir', dest='output_dir', required=True, + help='Where shall we write intermediate models + final data to?') + parser.add_argument('--model_params', dest='model_params', required=True, + help='JSoN file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.') + args = parser.parse_args() + BATCH_FIRST = True + + with open(args.model_params, 'r') as fp: + logging.debug(f'Loading model parameters from {args.model_params}') + model_params = json.load(fp) + train, val, test = load_datasets(args.data_dir) + documents = load_documents(args.data_dir) + document_vocab = set(chain.from_iterable(chain.from_iterable(documents.values()))) + annotation_vocab = set(chain.from_iterable(e.query.split() for e in chain(train, val, test))) + logging.debug(f'Loaded {len(documents)} documents with {len(document_vocab)} unique words') + # this ignores the case where annotations don't align perfectly with token boundaries, but this isn't that important + vocab = document_vocab | annotation_vocab + unk_token = 'UNK' + classifier, word_interner, de_interner, evidence_classes = initialize_model(model_params, vocab, + batch_first=BATCH_FIRST, + unk_token=unk_token) + classifier = classifier.cuda() + logging.debug( + f'Including annotations, we have {len(vocab)} total words in the data, with embeddings for {len(word_interner)}') + interned_documents = intern_documents(documents, word_interner, unk_token) + interned_train = intern_annotations(train, word_interner, unk_token) + interned_val = intern_annotations(val, word_interner, unk_token) + interned_test = intern_annotations(test, word_interner, unk_token) + + classifier, results = train_classifier(classifier, args.output_dir, interned_train, interned_val, + interned_documents, model_params, evidence_classes) + val_rats, test_rats = decode(classifier, + interned_train, + interned_val, + interned_test, + interned_documents, + evidence_classes, + batch_size=model_params['classifier']['batch_size'], + tensorize_model_inputs=True, + threshold=model_params['classifier']['threshold'], + k_fraction=model_params['classifier']['k_fraction'], + has_query=bool(model_params['classifier']['has_query'])) + with open(os.path.join(args.output_dir, 'val_decoded.jsonl'), 'w') as val_output: + for line in val_rats: + val_output.write(json.dumps(line)) + val_output.write('\n') + + with open(os.path.join(args.output_dir, 'test_decoded.jsonl'), 'w') as test_output: + for line in test_rats: + test_output.write(json.dumps(line)) + test_output.write('\n') + #training_results, train_decoded, val_decoded, test_decoded = decode(classifier, interned_train, interned_val, interned_test, interned_documents, evidence_classes, params['classifier']['batch_size'], tensorize_model_inputs=True) + #write_jsonl(train_decoded, os.path.join(args.output_dir, 'train_decoded.jsonl')) + #write_jsonl(val_decoded, os.path.join(args.output_dir, 'val_decoded.jsonl')) + #write_jsonl(test_decoded, os.path.join(args.output_dir, 'test_decoded.jsonl')) + #with open(os.path.join(args.output_dir, 'identifier_results.json'), 'w') as ident_output, \ + # open(os.path.join(args.output_dir, 'classifier_results.json'), 'w') as class_output: + # ident_output.write(json.dumps(evidence_ident_results)) + # class_output.write(json.dumps(evidence_class_results)) + #for k, v in pipeline_results.items(): + # if type(v) is dict: + # for k1, v1 in v.items(): + # logging.info(f'Pipeline results for {k}, {k1}={v1}') + # else: + # logging.info(f'Pipeline results {k}\t={v}') + + +if __name__ == '__main__': + main() diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/mlp.py b/scripts/eraserbenchmark/rationale_benchmark/models/mlp.py new file mode 100644 index 0000000..fab798b --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/models/mlp.py @@ -0,0 +1,282 @@ +from typing import Any, List + +import torch +import torch.nn as nn + +from transformers import BertForSequenceClassification + +from rationale_benchmark.models.model_utils import PaddedSequence + + +class WordEmbedder(nn.Module): + """ A thin wrapping for an nn.embedding """ + + def __init__(self, embeddings: nn.Embedding, dropout_rate: float): + super(WordEmbedder, self).__init__() + self.embeddings = embeddings + self.output_dimension = embeddings.embedding_dim + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, docids: List[Any], ps: PaddedSequence) -> PaddedSequence: + # n.b. the docids are to allow for wrapping a pre-decoded BERT or ELMo or $FAVORITE_LANGUAGE_MODEL_OF_THE_DAY + if docids and len(docids) not in ps.data.size(): + raise ValueError(f"Document id dimension {len(docids)} does not match input data dimensions {ps.data.size()}") + embedded = self.embeddings(ps.data) + embedded = self.dropout(embedded) + return PaddedSequence(embedded, ps.batch_sizes, ps.batch_first) + + +class LuongAttention(nn.Module): + def __init__(self, + output_size: int, + query_size: int=None, + dropout_rate:float=0.0): + super(LuongAttention, self).__init__() + self.use_query = query_size is not None + self.query_size = query_size + self.hidden_size = output_size + input_size = query_size + output_size if self.use_query else output_size + self.w = nn.Linear(input_size, output_size) + self.v = nn.Parameter(torch.randn((output_size,)), requires_grad=True) + self.dropout = nn.Dropout(p=dropout_rate) + + def forward(self, + output: PaddedSequence, # batch x length x rep + query: torch.Tensor=None): + assert output.batch_first + if query is not None: + query = query.unsqueeze(1).repeat((1, output.data.size()[1], 1)) + attn_input = torch.cat([output.data, query], dim=2) + else: + attn_input = output.data + raw_score = self.w(attn_input) + score = torch.tanh(raw_score) @ self.v + score = score + output.mask(size=score.data.size(), + on=0, + off=float('-inf'), + dtype=torch.float, + device=self.v.device) + weights = torch.softmax(score, dim=1).unsqueeze(dim=-1) + expectation = weights * output.data + expectation = expectation.sum(dim=1) + expectation = self.dropout(expectation) + return score, weights, expectation + + +class BahadanauAttention(nn.Module): + + def __init__(self, + output_size: int, + query_size: int=None, + dropout_rate:float=0.0): + super(BahadanauAttention, self).__init__() + self.v = nn.Parameter(torch.randn((output_size, 1)), requires_grad=True) + self.w = nn.Linear(output_size, output_size, bias=False) + if query_size: + self.u = nn.Linear(query_size, output_size, bias=False) + else: + self.u = None + self.dropout = nn.Dropout(p=dropout_rate) + self.output_dimension = output_size + + def forward(self, + output: PaddedSequence, # batch x length x rep OR length x batch x rep + query: torch.Tensor=None): + raw_score = self.w(output.data) + if self.u: + raw_score += self.u(query.unsqueeze(1)) + score = torch.tanh(raw_score) @ self.v + score = score + output.mask( + size=output.data.size()[:2], + on=0, + off=float('-inf'), + dtype=torch.float, + device=self.v.device).unsqueeze(dim=-1) + dimension = 1 if output.batch_first else 0 + assert output.batch_first # for simplicity we don't bother + weights = torch.softmax(score, dim=dimension) + expectation = weights * output.data + expectation = expectation.sum(dim=dimension) + expectation = self.dropout(expectation) + return score, weights, expectation + + +class RNNEncoder(nn.Module): + """Recurrently encodes a sequence of words into a single vector. + + Collapsing the sequence of encoded values can be done either via an + attention mechanism, or if none provided, by just using the final hidden + states. + """ + + def __init__(self, + word_embedder: WordEmbedder, + output_dimension: int=None, + condition: bool=False, + batch_first: bool=False, + dropout_rate: float=0, + num_layers=1, + bidirectional=False, + attention_mechanism=None): + super(RNNEncoder, self).__init__() + if output_dimension is None: + output_dimension = word_embedder.output_dimension + self.word_embedder = word_embedder + self.output_dimension = output_dimension + self.condition = condition + self.batch_first = batch_first + input_size = word_embedder.output_dimension + self.rnn = nn.GRU(input_size=input_size, + hidden_size=self.output_dimension, + batch_first=batch_first, + dropout=dropout_rate, + num_layers=num_layers, + bidirectional=bidirectional) + self.attention_mechanism = attention_mechanism + + def forward(self, + docids: List[Any], + docs: PaddedSequence, + query: torch.Tensor): + embedded = self.word_embedder(docids, docs) + assert embedded.batch_first == self.rnn.batch_first + # concatenate the query to every input + if self.condition: + assert query is not None + #if self.batch_first: + # query = torch.cat(docs.data.size()[1]*[query.unsqueeze(dim=1)],dim=1) + #else: + # # TODO verify this works properly! + # import pdb; pdb.set_trace() + # query = torch.cat(docs.data.size()[0]*[query.unsqueeze(dim=0)],dim=0) + #embedded = torch.cat([query, embedded.data], dim=-1) + # this doesn't handle multilayer and multidirectional cases + output, hidden = self.rnn(docs.pack_other(embedded.data)) + output = PaddedSequence.from_packed_sequence(output, batch_first=docs.batch_first) + assert hidden.size()[-1] == self.rnn.hidden_size + if self.attention_mechanism is not None: + unnormalized_attention, attention, hidden = self.attention_mechanism(output, query) + assert hidden.size()[-1] == self.rnn.hidden_size + else: + unnormalized_attention, attention, hidden = None, None, hidden + return hidden, unnormalized_attention, attention, output + + +class AttentiveClassifier(nn.Module): + """Encodes a document + a query and makes a classification. Supports query-only and document-only modes. + + Args: + query_encoder: + - takes a list of query ids, query representation, and optional additional encoding element to a fixed size. + - same parameterization as the document_encoder (just think of the query as a document) + document_encoder: + - takes a list of docids (for convenience if working with pre-computed representations), document representations, and an encoded query to a fixed size + num_classes: + - how many things to make a prediction for + mlp_size: + - document + query -> linear (mlp_size) -> non-linear -> num_classes -> softmax + """ + + def __init__(self, + document_encoder: RNNEncoder, + query_encoder: RNNEncoder, + num_classes: int, + mlp_size: int, + dropout_rate: float): + super(AttentiveClassifier, self).__init__() + self.document_encoder = document_encoder + self.query_encoder = query_encoder + + document_output_dimension = self.document_encoder.output_dimension if document_encoder else 0 + query_output_dimension = self.query_encoder.output_dimension if query_encoder else 0 + + self.mlp = nn.Sequential(nn.Dropout(p=dropout_rate), + nn.Linear(document_output_dimension + query_output_dimension, mlp_size), + nn.ReLU(), + nn.Dropout(p=dropout_rate), + nn.Linear(mlp_size, num_classes), + nn.Softmax(dim=-1)) + + def forward(self, + query: List[torch.tensor], + docids: List[Any], + document_batch: List[torch.tensor], + return_attentions: bool=False): + # note about device management: + # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module) + # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access + device = next(self.parameters()).device + if query is not None: + assert self.query_encoder is not None + query = PaddedSequence.autopad(query, batch_first=self.query_encoder.batch_first, device=device) + query_vector, unnormalized_query_attention, normalized_query_attention, _ = self.query_encoder(None, query, None) + unnormalized_query_attention = PaddedSequence(unnormalized_query_attention, query.batch_sizes, query.batch_first) + normalized_query_attention = PaddedSequence(normalized_query_attention, query.batch_sizes, query.batch_first) + else: + query_vector = None + unnormalized_query_attention = None + normalized_query_attention = None + if document_batch is not None: + assert self.document_encoder is not None + document_batch = PaddedSequence.autopad(document_batch, batch_first=self.document_encoder.batch_first, device=device) + document_vector, unnormalized_document_attention, normalized_document_attention, _ = self.document_encoder(docids, document_batch, query_vector) + unnormalized_document_attention = PaddedSequence(unnormalized_document_attention, document_batch.batch_sizes, document_batch.batch_first) + normalized_document_attention = PaddedSequence(normalized_document_attention, document_batch.batch_sizes, document_batch.batch_first) + else: + document_vector = None + if query_vector is not None and document_vector is not None: + assert query_vector.size()[:2] == document_vector.size()[:2] + combined_vector = torch.cat([query_vector, document_vector], dim=-1) + else: + assert query_vector is not None or document_vector is not None + combined_vector = query_vector if query_vector is not None else document_vector + if return_attentions: + return self.mlp(combined_vector), unnormalized_query_attention, normalized_query_attention, unnormalized_document_attention, normalized_document_attention + else: + return self.mlp(combined_vector) + + +class BertClassifier(nn.Module): + """Thin wrapper around BertForSequenceClassification""" + def __init__(self, + bert_dir: str, + pad_token_id: int, + cls_token_id: int, + sep_token_id: int, + num_labels: int, + max_length: int=512, + use_half_precision=True): + super(BertClassifier, self).__init__() + bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels) + if use_half_precision: + import apex + bert = bert.half() + self.bert = bert + self.pad_token_id = pad_token_id + self.cls_token_id = cls_token_id + self.sep_token_id = sep_token_id + self.max_length = max_length + + def forward(self, + query: List[torch.tensor], + docids: List[Any], + document_batch: List[torch.tensor]): + assert len(query) == len(document_batch) + # note about device management: + # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module) + # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access + target_device = next(self.parameters()).device + cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device) + sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device) + input_tensors = [] + position_ids = [] + for q, d in zip(query, document_batch): + if len(q) + len(d) + 2 > self.max_length: + d = d[:(self.max_length - len(q) - 2)] + input_tensors.append(torch.cat([cls_token, q, sep_token, d])) + position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))) + bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, device=target_device) + positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device) + (classes,) = self.bert(bert_input.data, attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device), position_ids=positions.data) + assert torch.all(classes == classes) # for nans + return classes diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/model_utils.py b/scripts/eraserbenchmark/rationale_benchmark/models/model_utils.py new file mode 100644 index 0000000..ccefaa8 --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/models/model_utils.py @@ -0,0 +1,155 @@ +from dataclasses import dataclass +from typing import Dict, List, Set + +import numpy as np +from gensim.models import KeyedVectors + +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence, PackedSequence, pack_padded_sequence, pad_packed_sequence + + +@dataclass(eq=True, frozen=True) +class PaddedSequence: + """A utility class for padding variable length sequences mean for RNN input + This class is in the style of PackedSequence from the PyTorch RNN Utils, + but is somewhat more manual in approach. It provides the ability to generate masks + for outputs of the same input dimensions. + The constructor should never be called directly and should only be called via + the autopad classmethod. + + We'd love to delete this, but we pad_sequence, pack_padded_sequence, and + pad_packed_sequence all require shuffling around tuples of information, and some + convenience methods using these are nice to have. + """ + + data: torch.Tensor + batch_sizes: torch.Tensor + batch_first: bool = False + + @classmethod + def autopad(cls, data, batch_first: bool = False, padding_value=0, device=None) -> 'PaddedSequence': + # handle tensors of size 0 (single item) + data_ = [] + for d in data: + if len(d.size()) == 0: + d = d.unsqueeze(0) + data_.append(d) + padded = pad_sequence(data_, batch_first=batch_first, padding_value=padding_value) + if batch_first: + batch_lengths = torch.LongTensor([len(x) for x in data_]) + if any([x == 0 for x in batch_lengths]): + raise ValueError( + "Found a 0 length batch element, this can't possibly be right: {}".format(batch_lengths)) + else: + # TODO actually test this codepath + batch_lengths = torch.LongTensor([len(x) for x in data]) + return PaddedSequence(padded, batch_lengths, batch_first).to(device=device) + + def pack_other(self, data: torch.Tensor): + return pack_padded_sequence(data, self.batch_sizes, batch_first=self.batch_first, enforce_sorted=False) + + @classmethod + def from_packed_sequence(cls, ps: PackedSequence, batch_first: bool, padding_value=0) -> 'PaddedSequence': + padded, batch_sizes = pad_packed_sequence(ps, batch_first, padding_value) + return PaddedSequence(padded, batch_sizes, batch_first) + + def cuda(self) -> 'PaddedSequence': + return PaddedSequence(self.data.cuda(), self.batch_sizes.cuda(), batch_first=self.batch_first) + + def to(self, dtype=None, device=None, copy=False, non_blocking=False) -> 'PaddedSequence': + # TODO make to() support all of the torch.Tensor to() variants + return PaddedSequence( + self.data.to(dtype=dtype, device=device, copy=copy, non_blocking=non_blocking), + self.batch_sizes.to(device=device, copy=copy, non_blocking=non_blocking), + batch_first=self.batch_first) + + def mask(self, on=int(0), off=int(0), device='cpu', size=None, dtype=None) -> torch.Tensor: + if size is None: + size = self.data.size() + out_tensor = torch.zeros(*size, dtype=dtype) + # TODO this can be done more efficiently + out_tensor.fill_(off) + # note to self: these are probably less efficient than explicilty populating the off values instead of the on values. + if self.batch_first: + for i, bl in enumerate(self.batch_sizes): + out_tensor[i, :bl] = on + else: + for i, bl in enumerate(self.batch_sizes): + out_tensor[:bl, i] = on + return out_tensor.to(device) + + def unpad(self, other: torch.Tensor) -> List[torch.Tensor]: + out = [] + for o, bl in zip(other, self.batch_sizes): + out.append(o[:bl]) + return out + + def flip(self) -> 'PaddedSequence': + return PaddedSequence(self.data.transpose(0, 1), not self.batch_first, self.padding_value) + + +def extract_embeddings(vocab: Set[str], embedding_file: str, unk_token: str = 'UNK', pad_token: str = 'PAD') -> ( +nn.Embedding, Dict[str, int], List[str]): + vocab = vocab | set([unk_token, pad_token]) + if embedding_file.endswith('.bin'): + WVs = KeyedVectors.load_word2vec_format(embedding_file, binary=True) + + word_to_vector = dict() + WV_matrix = np.matrix([WVs[v] for v in WVs.vocab.keys()]) + + if unk_token not in WVs: + mean_vector = np.mean(WV_matrix, axis=0) + word_to_vector[unk_token] = mean_vector + if pad_token not in WVs: + word_to_vector[pad_token] = np.zeros(WVs.vector_size) + + for v in vocab: + if v in WVs: + word_to_vector[v] = WVs[v] + + interner = dict() + deinterner = list() + vectors = [] + count = 0 + for word in [pad_token, unk_token] + sorted(list(word_to_vector.keys() - {unk_token, pad_token})): + vector = word_to_vector[word] + vectors.append(np.array(vector)) + interner[word] = count + deinterner.append(word) + count += 1 + vectors = torch.FloatTensor(np.array(vectors)) + embedding = nn.Embedding.from_pretrained(vectors, padding_idx=interner[pad_token]) + embedding.weight.requires_grad = False + return embedding, interner, deinterner + elif embedding_file.endswith('.txt'): + word_to_vector = dict() + vector = [] + with open(embedding_file, 'r') as inf: + for line in inf: + contents = line.strip().split() + word = contents[0] + vector = torch.tensor([float(v) for v in contents[1:]]).unsqueeze(0) + word_to_vector[word] = vector + embed_size = vector.size() + if unk_token not in word_to_vector: + mean_vector = torch.cat(list(word_to_vector.values()), dim=0).mean(dim=0) + word_to_vector[unk_token] = mean_vector.unsqueeze(0) + if pad_token not in word_to_vector: + word_to_vector[pad_token] = torch.zeros(embed_size) + interner = dict() + deinterner = list() + vectors = [] + count = 0 + for word in [pad_token, unk_token] + sorted(list(word_to_vector.keys() - {unk_token, pad_token})): + vector = word_to_vector[word] + vectors.append(vector) + interner[word] = count + deinterner.append(word) + count += 1 + vectors = torch.cat(vectors, dim=0) + embedding = nn.Embedding.from_pretrained(vectors, padding_idx=interner[pad_token]) + embedding.weight.requires_grad = False + return embedding, interner, deinterner + else: + raise ValueError("Unable to open embeddings file {}".format(embedding_file)) diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/__init__.py b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/bert_pipeline.py b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/bert_pipeline.py new file mode 100644 index 0000000..19f865b --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/bert_pipeline.py @@ -0,0 +1,330 @@ +# TODO consider if this can be collapsed back down into the pipeline_train.py +import argparse +import json +import logging +import random +import os + +from itertools import chain +from typing import List, Tuple + +import numpy as np +import torch + +from transformers import BertTokenizer + +from rationale_benchmark.models.sequence_taggers import BertTagger +from rationale_benchmark.utils import ( + Annotation, + Evidence, + write_jsonl, + load_datasets, + load_documents, +) +from rationale_benchmark.models.mlp import ( + BertClassifier, +) +from rationale_benchmark.models.pipeline.evidence_identifier import train_evidence_identifier +from rationale_benchmark.models.pipeline.evidence_token_identifier import train_evidence_token_identifier +from rationale_benchmark.models.pipeline.evidence_classifier import train_evidence_classifier +from rationale_benchmark.models.pipeline.pipeline_utils import decode, decode_evidence_tokens_and_classify + +logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s') +logger = logging.getLogger(__name__) +# let's make this more or less deterministic (not resistent to restarts) +random.seed(12345) +np.random.seed(67890) +torch.manual_seed(10111213) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +def initialize_models(params: dict, batch_first: bool, unk_token=''): + assert batch_first + max_length = params['max_length'] + tokenizer = BertTokenizer.from_pretrained(params['bert_vocab']) + pad_token_id = tokenizer.pad_token_id + cls_token_id = tokenizer.cls_token_id + sep_token_id = tokenizer.sep_token_id + bert_dir = params['bert_dir'] + if bool(params['use_evidence_sentence_identifier']): + use_half_precision = bool(params['evidence_identifier'].get('use_half_precision', 1)) + evidence_identifier = BertClassifier(bert_dir=bert_dir, + pad_token_id=pad_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + num_labels=2, + max_length=max_length, + use_half_precision=use_half_precision) + else: + evidence_identifier = None + + if bool(params['use_evidence_token_identifier']): + use_half_precision = bool(params['evidence_token_identifier'].get('use_half_precision', 1)) + evidence_token_identifier = BertTagger(bert_dir=bert_dir, + pad_token_id=pad_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + max_length=max_length, + use_half_precision=use_half_precision) + else: + evidence_token_identifier = None + use_half_precision = bool(params['evidence_classifier'].get('use_half_precision', 1)) + evidence_classes = dict((y, x) for (x, y) in enumerate(params['evidence_classifier']['classes'])) + evidence_classifier = BertClassifier(bert_dir=bert_dir, + pad_token_id=pad_token_id, + cls_token_id=cls_token_id, + sep_token_id=sep_token_id, + num_labels=len(evidence_classes), + max_length=max_length, + use_half_precision=use_half_precision) + word_interner = tokenizer.vocab + de_interner = tokenizer.ids_to_tokens + return evidence_identifier, evidence_token_identifier, evidence_classifier, word_interner, de_interner, evidence_classes, tokenizer + + +def bert_tokenize_doc(doc: List[List[str]], tokenizer, special_token_map) -> Tuple[List[List[str]], List[List[Tuple[int, int]]]]: + """ Tokenizes a document and returns [start, end) spans to map the wordpieces back to their source words""" + sents = [] + sent_token_spans = [] + for sent in doc: + tokens = [] + spans = [] + start = 0 + for w in sent: + if w in special_token_map: + tokens.append(w) + else: + tokens.extend(tokenizer.tokenize(w)) + end = len(tokens) + spans.append((start, end)) + start = end + sents.append(tokens) + sent_token_spans.append(spans) + return sents, sent_token_spans + + +def bert_intern_doc(doc: List[List[str]], tokenizer, special_token_map) -> List[List[int]]: + return [list(chain.from_iterable(special_token_map.get(w, tokenizer.encode(w)) for w in s)) for s in doc] + + +def bert_intern_annotation(annotations: List[Annotation], tokenizer): + ret = [] + for ann in annotations: + ev_groups = [] + for ev_group in ann.evidences: + evs = [] + for ev in ev_group: + text = list(chain.from_iterable(tokenizer.tokenize(w) for w in ev.text.split())) + if len(text) == 0: + continue + text = tokenizer.encode(text, add_special_tokens=False) + evs.append(Evidence(text=tuple(text), + docid=ev.docid, + start_token=ev.start_token, + end_token=ev.end_token, + start_sentence=ev.start_sentence, + end_sentence=ev.end_sentence)) + ev_groups.append(tuple(evs)) + query = list(chain.from_iterable(tokenizer.tokenize(w) for w in ann.query.split())) + if len(query) > 0: + query = tokenizer.encode(query, add_special_tokens=False) + else: + query = [] + ret.append(Annotation(annotation_id=ann.annotation_id, + query=tuple(query), + evidences=frozenset(ev_groups), + classification=ann.classification, + query_type=ann.query_type)) + return ret + + +BATCH_FIRST = True + + +def main(): + parser = argparse.ArgumentParser(description="""Trains a pipeline model. + + Step 1 is evidence identification, that is identify if a given sentence is evidence or not + Step 2 is evidence classification, that is given an evidence sentence, classify the final outcome for the final task + (e.g. sentiment or significance). + + These models should be separated into two separate steps, but at the moment: + * prep data (load, intern documents, load json) + * convert data for evidence identification - in the case of training data we take all the positives and sample some + negatives + * side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a + broader sampling of negative values. + * train evidence identification + * convert data for evidence classification - take all rationales + decisions and use this as input + * train evidence classification + * decode first the evidence, then run classification for each split + + """, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('--data_dir', dest='data_dir', required=True, + help='Which directory contains a {train,val,test}.jsonl file?') + parser.add_argument('--output_dir', dest='output_dir', required=True, + help='Where shall we write intermediate models + final data to?') + parser.add_argument('--model_params', dest='model_params', required=True, + help='JSoN file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.') + args = parser.parse_args() + assert BATCH_FIRST + os.makedirs(args.output_dir, exist_ok=True) + + with open(args.model_params, 'r') as fp: + logger.info(f'Loading model parameters from {args.model_params}') + model_params = json.load(fp) + logger.info(f'Params: {json.dumps(model_params, indent=2, sort_keys=True)}') + train, val, test = load_datasets(args.data_dir) + docids = set(e.docid for e in + chain.from_iterable(chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test))))) + documents = load_documents(args.data_dir, docids) + logger.info(f'Loaded {len(documents)} documents') + # this ignores the case where annotations don't align perfectly with token boundaries, but this isn't that important + unk_token = '' + evidence_identifier, evidence_token_identifier, evidence_classifier, word_interner, de_interner, evidence_classes, tokenizer = \ + initialize_models(model_params, + batch_first=BATCH_FIRST, + unk_token=unk_token) + if not ((evidence_identifier is None) ^ (evidence_token_identifier is None)): + raise ValueError('Exactly one of the evidence identifier and evidence token identifier must be defined, not both!') + logger.info(f'We have {len(word_interner)} wordpieces') + cache = os.path.join(args.output_dir, 'preprocessed.pkl') + if os.path.exists(cache): + logger.info(f'Loading interned documents from {cache}') + (interned_documents, interned_document_token_slices) = torch.load(cache) + else: + logger.info(f'Interning documents') + special_token_map = { + 'SEP': [evidence_classifier.sep_token_id], + '[SEP]': [evidence_classifier.sep_token_id], + '[sep]': [evidence_classifier.sep_token_id], + 'UNK': [tokenizer.unk_token_id], + '[UNK]': [tokenizer.unk_token_id], + '[unk]': [tokenizer.unk_token_id], + 'PAD': [tokenizer.unk_token_id], + '[PAD]': [tokenizer.unk_token_id], + '[pad]': [tokenizer.unk_token_id], + } + interned_documents = {} + interned_document_token_slices = {} + for d, doc in documents.items(): + tokenized, w_slices = bert_tokenize_doc(doc, tokenizer, special_token_map=special_token_map) + interned_documents[d] = bert_intern_doc(tokenized, tokenizer, special_token_map=special_token_map) + interned_document_token_slices[d] = w_slices + torch.save((interned_documents, interned_document_token_slices), cache) + interned_train = bert_intern_annotation(train, tokenizer) + interned_val = bert_intern_annotation(val, tokenizer) + interned_test = bert_intern_annotation(test, tokenizer) + + # train the evidence identifier + if evidence_identifier is not None: + logger.info('Beginning training of the evidence sentence identifier') + evidence_identifier = evidence_identifier.cuda() + optimizer = None + scheduler = None + evidence_identifier, evidence_ident_results = train_evidence_identifier(evidence_identifier, + args.output_dir, + interned_train, + interned_val, + interned_documents, + model_params, + optimizer=optimizer, + scheduler=scheduler, + tensorize_model_inputs=True) + evidence_identifier = evidence_identifier.cpu() # free GPU space for next model + else: + logger.info('No evidence sentence identifier provided; skipping') + evidence_identifier, evidence_ident_results = None, None + + if evidence_token_identifier is not None: + logger.info('Beginning training of the evidence token identifier') + evidence_token_identifier = evidence_token_identifier.cuda() + evidence_token_identifier, evidence_token_identifier_results = \ + train_evidence_token_identifier(evidence_token_identifier, + args.output_dir, + interned_train, + interned_val, + interned_documents=interned_documents, + source_documents=documents, + token_mapping=interned_document_token_slices, + model_pars=model_params, + tensorize_model_inputs=True + ) + evidence_token_identifier = evidence_token_identifier.cpu() + else: + logger.info('No evidence token identifier provided; skipping') + evidence_token_identifier_results = None + + # train the evidence classifier + logger.info('Beginning training of the evidence classifier') + evidence_classifier = evidence_classifier.cuda() + optimizer = None + scheduler = None + evidence_classifier, evidence_class_results = train_evidence_classifier(evidence_classifier, + args.output_dir, + interned_train, + interned_val, + interned_documents, + model_params, + optimizer=optimizer, + scheduler=scheduler, + class_interner=evidence_classes, + tensorize_model_inputs=True, + token_only_evidence=(evidence_token_identifier is not None)) + + # decode + logger.info('Beginning final decoding') + + if evidence_identifier is not None: + pipeline_batch_size = min([model_params['evidence_classifier']['batch_size'], + model_params['evidence_identifier']['batch_size']]) + pipeline_results, train_decoded, val_decoded, test_decoded = decode(evidence_identifier.cuda(), + evidence_classifier.cuda(), + interned_train, + interned_val, + interned_test, + interned_documents, + evidence_classes, + pipeline_batch_size, + BATCH_FIRST, + decoding_docs=documents) + elif evidence_token_identifier is not None: + pipeline_batch_size = min([model_params['evidence_classifier']['batch_size'], + model_params['evidence_token_identifier']['batch_size']]) + use_cose_hack = bool(model_params['evidence_token_identifier'].get('cose_data_hack', 0)) + if use_cose_hack: + logger.info("Using COS-E data prep and processing hacks!") + logger.info("These hacks impact evidence identification, classification, and decoding. They are not general") + pipeline_results, train_decoded, val_decoded, test_decoded = decode_evidence_tokens_and_classify(evidence_token_identifier.cuda(), + evidence_classifier.cuda(), + interned_train, + interned_val, + interned_test, + interned_documents, + documents, + interned_document_token_slices, # token_mapping + evidence_classes, + pipeline_batch_size, + decoding_docs=documents, + use_cose_hack=use_cose_hack) + else: + raise StateError('Impossible state!') + + write_jsonl(train_decoded, os.path.join(args.output_dir, 'train_decoded.jsonl')) + write_jsonl(val_decoded, os.path.join(args.output_dir, 'val_decoded.jsonl')) + write_jsonl(test_decoded, os.path.join(args.output_dir, 'test_decoded.jsonl')) + with open(os.path.join(args.output_dir, 'identifier_results.json'), 'w') as ident_output, \ + open(os.path.join(args.output_dir, 'classifier_results.json'), 'w') as class_output: + ident_output.write(json.dumps(evidence_ident_results)) + class_output.write(json.dumps(evidence_class_results)) + for k, v in pipeline_results.items(): + if type(v) is dict: + for k1, v1 in v.items(): + logging.info(f'Pipeline results for {k}, {k1}={v1}') + else: + logging.info(f'Pipeline results {k}\t={v}') + + +if __name__ == '__main__': + main() diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_classifier.py b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_classifier.py new file mode 100644 index 0000000..822b79d --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_classifier.py @@ -0,0 +1,193 @@ +import logging +import os +import random + +from collections import OrderedDict +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn + +from sklearn.metrics import accuracy_score, classification_report + +from rationale_benchmark.utils import Annotation + +from rationale_benchmark.models.pipeline.pipeline_utils import ( + annotations_to_evidence_classification, + token_annotations_to_evidence_classification, + make_preds_epoch, +) + + +def train_evidence_classifier(evidence_classifier: nn.Module, + save_dir: str, + train: List[Annotation], + val: List[Annotation], + documents: Dict[str, List[List[int]]], + model_pars: dict, + class_interner: Dict[str, int], + optimizer=None, + scheduler=None, + tensorize_model_inputs: bool = True, + token_only_evidence: bool=False) -> Tuple[nn.Module, dict]: + """Trains an end-task classifier based on the ground truth evidence + + This method tracks loss on the validation set, saves intermediate + models, and supports restoring from an unfinished state. The best model on + the validation set is maintained, and the model stops training if a patience + (see below) number of epochs with no improvement is exceeded. + + Args: + evidence_classifier: a module like the AttentiveClassifier + save_dir: a place to save intermediate and final results and models. + train: a List of interned Annotation objects. + val: a List of interned Annotation objects. + documents: a Dict of interned sentences + model_pars: Arbitrary parameters directory, assumed to contain an "evidence_classifier" sub-dict with: + lr: learning rate + batch_size: an int + sampling_method: a string, plus additional params in the dict to define creation of a sampler. + This should probably just be "everything" + epochs: the number of epochs to train for + patience: how long to wait for an improvement before giving up. + max_grad_norm: optional, clip gradients. + class_interner: an object for converting Annotation classes into ints. + optimizer: what pytorch optimizer to use, if none, initialize Adam + scheduler: optional, do we want a scheduler involved in learning? + tensorize_model_inputs: should we convert our data to tensors before passing it to the model? + Useful if we have a model that performs its own tokenization (e.g. BERT as a Service) + + Returns: + the trained evidence classifier and a dictionary of intermediate results. + """ + logging.info(f'Beginning training evidence classifier with {len(train)} annotations, {len(val)} for validation') + evidence_classifier_output_dir = os.path.join(save_dir, 'evidence_classifier') + os.makedirs(save_dir, exist_ok=True) + os.makedirs(evidence_classifier_output_dir, exist_ok=True) + model_save_file = os.path.join(evidence_classifier_output_dir, 'evidence_classifier.pt') + epoch_save_file = os.path.join(evidence_classifier_output_dir, 'evidence_classifier_epoch_data.pt') + + device = next(evidence_classifier.parameters()).device + if optimizer is None: + optimizer = torch.optim.Adam(evidence_classifier.parameters(), lr=model_pars['evidence_classifier']['lr']) + criterion = nn.CrossEntropyLoss(reduction='none') + batch_size = model_pars['evidence_classifier']['batch_size'] + epochs = model_pars['evidence_classifier']['epochs'] + patience = model_pars['evidence_classifier']['patience'] + max_grad_norm = model_pars['evidence_classifier'].get('max_grad_norm', None) + + if token_only_evidence: + evidence_train_data = token_annotations_to_evidence_classification(train, documents, class_interner) + evidence_val_data = token_annotations_to_evidence_classification(val, documents, class_interner) + else: + evidence_train_data = annotations_to_evidence_classification(train, documents, class_interner, include_all=False) + evidence_val_data = annotations_to_evidence_classification(val, documents, class_interner, include_all=False) + + class_labels = [k for k, v in sorted(class_interner.items())] + + results = { + 'train_loss': [], + 'train_f1': [], + 'train_acc': [], + 'val_loss': [], + 'val_f1': [], + 'val_acc': [], + } + best_epoch = -1 + best_val_loss = float('inf') + best_model_state_dict = None + start_epoch = 0 + epoch_data = {} + if os.path.exists(epoch_save_file): + logging.info(f'Restoring model from {model_save_file}') + evidence_classifier.load_state_dict(torch.load(model_save_file)) + epoch_data = torch.load(epoch_save_file) + start_epoch = epoch_data['epoch'] + 1 + # handle finishing because patience was exceeded or we didn't get the best final epoch + if bool(epoch_data.get('done', 0)): + start_epoch = epochs + results = epoch_data['results'] + best_epoch = start_epoch + best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_classifier.state_dict().items()}) + logging.info(f'Restoring training from epoch {start_epoch}') + logging.info(f'Training evidence classifier from epoch {start_epoch} until epoch {epochs}') + optimizer.zero_grad() + for epoch in range(start_epoch, epochs): + epoch_train_data = random.sample(evidence_train_data, k=len(evidence_train_data)) + epoch_val_data = random.sample(evidence_val_data, k=len(evidence_val_data)) + epoch_train_loss = 0 + evidence_classifier.train() + logging.info( + f'Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples') + for batch_start in range(0, len(epoch_train_data), batch_size): + batch_elements = epoch_train_data[batch_start:min(batch_start + batch_size, len(epoch_train_data))] + targets, queries, sentences = zip(*[(s.kls, s.query, s.sentence) for s in batch_elements]) + ids = [(s.ann_id, s.docid, s.index) for s in batch_elements] + targets = torch.tensor(targets, dtype=torch.long, device=device) + if tensorize_model_inputs: + queries = [torch.tensor(q, dtype=torch.long) for q in queries] + sentences = [torch.tensor(s, dtype=torch.long) for s in sentences] + preds = evidence_classifier(queries, ids, sentences) + loss = criterion(preds, targets.to(device=preds.device)).sum() + epoch_train_loss += loss.item() + loss = loss / len(preds) # accumulate entire loss above + loss.backward() + assert loss == loss # for nans + if max_grad_norm: + torch.nn.utils.clip_grad_norm_(evidence_classifier.parameters(), max_grad_norm) + optimizer.step() + if scheduler: + scheduler.step() + optimizer.zero_grad() + epoch_train_loss /= len(epoch_train_data) + assert epoch_train_loss == epoch_train_loss # for nans + results['train_loss'].append(epoch_train_loss) + logging.info(f'Epoch {epoch} training loss {epoch_train_loss}') + + with torch.no_grad(): + epoch_train_loss, epoch_train_soft_pred, epoch_train_hard_pred, epoch_train_truth = make_preds_epoch( + evidence_classifier, epoch_train_data, batch_size, device, criterion=criterion, + tensorize_model_inputs=tensorize_model_inputs) + results['train_f1'].append( + classification_report(epoch_train_truth, epoch_train_hard_pred, target_names=class_labels, + output_dict=True)) + results['train_acc'].append(accuracy_score(epoch_train_truth, epoch_train_hard_pred)) + epoch_val_loss, epoch_val_soft_pred, epoch_val_hard_pred, epoch_val_truth = make_preds_epoch( + evidence_classifier, epoch_val_data, batch_size, device, criterion=criterion, + tensorize_model_inputs=tensorize_model_inputs) + results['val_loss'].append(epoch_val_loss) + results['val_f1'].append( + classification_report(epoch_val_truth, epoch_val_hard_pred, target_names=class_labels, + output_dict=True)) + results['val_acc'].append(accuracy_score(epoch_val_truth, epoch_val_hard_pred)) + assert epoch_val_loss == epoch_val_loss # for nans + logging.info(f'Epoch {epoch} val loss {epoch_val_loss}') + logging.info(f'Epoch {epoch} val acc {results["val_acc"][-1]}') + logging.info(f'Epoch {epoch} val f1 {results["val_f1"][-1]}') + + if epoch_val_loss < best_val_loss: + best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_classifier.state_dict().items()}) + best_epoch = epoch + best_val_loss = epoch_val_loss + epoch_data = { + 'epoch': epoch, + 'results': results, + 'best_val_loss': best_val_loss, + 'done': 0, + } + torch.save(evidence_classifier.state_dict(), model_save_file) + torch.save(epoch_data, epoch_save_file) + logging.debug(f'Epoch {epoch} new best model with val loss {epoch_val_loss}') + if epoch - best_epoch > patience: + logging.info(f'Exiting after epoch {epoch} due to no improvement') + epoch_data['done'] = 1 + torch.save(epoch_data, epoch_save_file) + break + + epoch_data['done'] = 1 + epoch_data['results'] = results + torch.save(epoch_data, epoch_save_file) + evidence_classifier.load_state_dict(best_model_state_dict) + evidence_classifier = evidence_classifier.to(device=device) + evidence_classifier.eval() + return evidence_classifier, results diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_identifier.py b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_identifier.py new file mode 100644 index 0000000..7fd9def --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_identifier.py @@ -0,0 +1,261 @@ +import logging +import os +import random + +from collections import OrderedDict +from itertools import chain +from typing import Callable, Dict, List, Tuple + +import torch +import torch.nn as nn + +from sklearn.metrics import accuracy_score + +from rationale_benchmark.utils import Annotation + +from rationale_benchmark.models.pipeline.pipeline_utils import ( + SentenceEvidence, + annotations_to_evidence_identification, + make_preds_epoch, + score_rationales, +) + + +def _get_sampling_method(training_pars: dict) -> Callable[ + [List[SentenceEvidence], Dict[str, List[SentenceEvidence]]], List[SentenceEvidence]]: + """Generates a sampler that produces (positive, negative) sentence-level examples + + Returns a function that takes a document converted to sentence level + annotations and a dictionary of docid -> sentence level annotations, and + returns a set of sentence level annotations. + + This sampling method is necessary as we can have far too many negative + examples in our training data (almost nothing is actually evidence). + + n.b. this factory is clearly crying for modularization, again into + something that would call for dependency injection, but for the duration of + this project, this will be fine. + """ + + # TODO implement sampling for nearby sentences (within the document) + if training_pars['sampling_method'] == 'random': + sampling_ratio = training_pars['sampling_ratio'] + logging.info(f'Setting up random sampling with negative/positive ratio = {sampling_ratio}') + + def random_sampler(document: List[SentenceEvidence], _: Dict[str, List[SentenceEvidence]]) -> \ + List[SentenceEvidence]: + """Takes all the positives from a document, and a random choice over negatives""" + positives = list(filter(lambda s: s.kls == 1 and len(s.sentence) > 0, document)) + if any(map(lambda s: len(s.sentence) == 0, positives)): + raise ValueError("Some positive sentences are of zero length!") + all_negatives = list(filter(lambda s: s.kls == 0 and len(s.sentence) > 0, document)) + # handle an edge case where a document can be only or mostly evidence for a statement + num_negatives = min(len(all_negatives), round(len(positives) * sampling_ratio)) + random_negatives = random.choices(all_negatives, k=num_negatives) + # sort the results so the next step is deterministic, + results = sorted(positives + random_negatives) + # this is an inplace shuffle. + random.shuffle(results) + return results + + return random_sampler + elif training_pars['sampling_method'] == 'everything': + def everything_sampler(document: List[SentenceEvidence], + _: Dict[str, List[SentenceEvidence]]) -> List[SentenceEvidence]: + return document + return everything_sampler + else: + raise ValueError(f"Unknown sampling method for training: {training_pars['sampling_method']}") + + +def train_evidence_identifier(evidence_identifier: nn.Module, + save_dir: str, + train: List[Annotation], + val: List[Annotation], + documents: Dict[str, List[List[int]]], + model_pars: dict, + optimizer=None, + scheduler=None, + tensorize_model_inputs: bool = True) -> Tuple[nn.Module, dict]: + """Trains a module for rationale identification. + + This method tracks loss on the entire validation set, saves intermediate + models, and supports restoring from an unfinished state. The best model on + the validation set is maintained, and the model stops training if a patience + (see below) number of epochs with no improvement is exceeded. + + As there are likely too many negative examples to reasonably train a + classifier on everything, every epoch we subsample the negatives. + + Args: + evidence_identifier: a module like the AttentiveClassifier + save_dir: a place to save intermediate and final results and models. + train: a List of interned Annotation objects. + val: a List of interned Annotation objects. + documents: a Dict of interned sentences + model_pars: Arbitrary parameters directory, assumed to contain an "evidence_identifier" sub-dict with: + lr: learning rate + batch_size: an int + sampling_method: a string, plus additional params in the dict to define creation of a sampler + epochs: the number of epochs to train for + patience: how long to wait for an improvement before giving up. + max_grad_norm: optional, clip gradients. + optimizer: what pytorch optimizer to use, if none, initialize Adam + scheduler: optional, do we want a scheduler involved in learning? + tensorize_model_inputs: should we convert our data to tensors before passing it to the model? + Useful if we have a model that performs its own tokenization (e.g. BERT as a Service) + + Returns: + the trained evidence identifier and a dictionary of intermediate results. + """ + + def _prep_data_for_epoch(evidence_data: Dict[str, Dict[str, List[SentenceEvidence]]], + sampler: Callable[ + [List[SentenceEvidence], Dict[str, List[SentenceEvidence]]], List[SentenceEvidence]] + ) -> List[SentenceEvidence]: + output_sentences = [] + ann_ids = sorted(evidence_data.keys()) + # in place shuffle so we get a different per-epoch ordering + random.shuffle(ann_ids) + for ann_id in ann_ids: + for docid, sentences in evidence_data[ann_id].items(): + data = sampler(sentences, None) + output_sentences.extend(data) + return output_sentences + + logging.info(f'Beginning training with {len(train)} annotations, {len(val)} for validation') + evidence_identifier_output_dir = os.path.join(save_dir, 'evidence_identifier') + os.makedirs(save_dir, exist_ok=True) + os.makedirs(evidence_identifier_output_dir, exist_ok=True) + + model_save_file = os.path.join(evidence_identifier_output_dir, 'evidence_identifier.pt') + epoch_save_file = os.path.join(evidence_identifier_output_dir, 'evidence_identifier_epoch_data.pt') + + if optimizer is None: + optimizer = torch.optim.Adam(evidence_identifier.parameters(), lr=model_pars['evidence_identifier']['lr']) + criterion = nn.CrossEntropyLoss(reduction='none') + sampling_method = _get_sampling_method(model_pars['evidence_identifier']) + batch_size = model_pars['evidence_identifier']['batch_size'] + epochs = model_pars['evidence_identifier']['epochs'] + patience = model_pars['evidence_identifier']['patience'] + max_grad_norm = model_pars['evidence_classifier'].get('max_grad_norm', None) + + evidence_train_data = annotations_to_evidence_identification(train, documents) + evidence_val_data = annotations_to_evidence_identification(val, documents) + + device = next(evidence_identifier.parameters()).device + + results = { + # "sampled" losses do not represent the true data distribution, but do represent training data + 'sampled_epoch_train_losses': [], + 'sampled_epoch_val_losses': [], + # "full" losses do represent the true data distribution + 'full_epoch_val_losses': [], + 'full_epoch_val_acc': [], + 'full_epoch_val_rationale_scores': [], + } + # allow restoring an existing training run + start_epoch = 0 + best_epoch = -1 + best_val_loss = float('inf') + best_model_state_dict = None + epoch_data = {} + if os.path.exists(epoch_save_file): + evidence_identifier.load_state_dict(torch.load(model_save_file)) + epoch_data = torch.load(epoch_save_file) + start_epoch = epoch_data['epoch'] + 1 + # handle finishing because patience was exceeded or we didn't get the best final epoch + if bool(epoch_data.get('done', 0)): + start_epoch = epochs + results = epoch_data['results'] + best_epoch = start_epoch + best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_identifier.state_dict().items()}) + logging.info(f'Training evidence identifier from epoch {start_epoch} until epoch {epochs}') + optimizer.zero_grad() + for epoch in range(start_epoch, epochs): + epoch_train_data = _prep_data_for_epoch(evidence_train_data, sampling_method) + epoch_val_data = _prep_data_for_epoch(evidence_val_data, sampling_method) + sampled_epoch_train_loss = 0 + evidence_identifier.train() + logging.info( + f'Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples') + for batch_start in range(0, len(epoch_train_data), batch_size): + batch_elements = epoch_train_data[batch_start:min(batch_start + batch_size, len(epoch_train_data))] + # we sample every time to thereoretically get a better representation of instances over the corpus. + # this might just take more time than doing so in advance. + targets, queries, sentences = zip(*[(s.kls, s.query, s.sentence) for s in batch_elements]) + ids = [(s.ann_id, s.docid, s.index) for s in batch_elements] + targets = torch.tensor(targets, dtype=torch.long, device=device) + if tensorize_model_inputs: + queries = [torch.tensor(q, dtype=torch.long) for q in queries] + sentences = [torch.tensor(s, dtype=torch.long) for s in sentences] + preds = evidence_identifier(queries, ids, sentences) + loss = criterion(preds, targets.to(device=preds.device)).sum() + sampled_epoch_train_loss += loss.item() + loss = loss / len(preds) + loss.backward() + if max_grad_norm: + torch.nn.utils.clip_grad_norm_(evidence_identifier.parameters(), max_grad_norm) + optimizer.step() + if scheduler: + scheduler.step() + optimizer.zero_grad() + sampled_epoch_train_loss /= len(epoch_train_data) + results['sampled_epoch_train_losses'].append(sampled_epoch_train_loss) + logging.info(f'Epoch {epoch} sampled training loss {sampled_epoch_train_loss}') + + with torch.no_grad(): + evidence_identifier.eval() + sampled_epoch_val_loss, _, sampled_epoch_val_hard_pred, sampled_epoch_val_truth = \ + make_preds_epoch(evidence_identifier, + epoch_val_data, + batch_size, + device, + criterion, + tensorize_model_inputs) + results['sampled_epoch_val_losses'].append(sampled_epoch_val_loss) + sampled_epoch_val_acc = accuracy_score(sampled_epoch_val_truth, sampled_epoch_val_hard_pred) + logging.info(f'Epoch {epoch} sampled val loss {sampled_epoch_val_loss}, acc {sampled_epoch_val_acc}') + # evaluate over *all* of the validation data + all_val_data = list(filter(lambda se: len(se.sentence) > 0, chain.from_iterable( + chain.from_iterable(x.values() for x in evidence_val_data.values())))) + epoch_val_loss, epoch_val_soft_pred, epoch_val_hard_pred, epoch_val_truth = \ + make_preds_epoch(evidence_identifier, + all_val_data, + batch_size, + device, + criterion, + tensorize_model_inputs) + results['full_epoch_val_losses'].append(epoch_val_loss) + results['full_epoch_val_acc'].append(accuracy_score(epoch_val_truth, epoch_val_hard_pred)) + results['full_epoch_val_rationale_scores'].append( + score_rationales(val, documents, epoch_val_data, epoch_val_soft_pred)) + logging.info( + f'Epoch {epoch} full val loss {epoch_val_loss}, accuracy: {results["full_epoch_val_acc"][-1]}, rationale scores: {results["full_epoch_val_rationale_scores"][-1]}') + + # if epoch_val_loss < best_val_loss: + if sampled_epoch_val_loss < best_val_loss: + logging.debug(f'Epoch {epoch} new best model with sampled val loss {sampled_epoch_val_loss}') + best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_identifier.state_dict().items()}) + best_epoch = epoch + best_val_loss = sampled_epoch_val_loss + torch.save(evidence_identifier.state_dict(), model_save_file) + epoch_data = { + 'epoch': epoch, + 'results': results, + 'best_val_loss': best_val_loss, + 'done': 0 + } + torch.save(epoch_data, epoch_save_file) + if epoch - best_epoch > patience: + epoch_data['done'] = 1 + torch.save(epoch_data, epoch_save_file) + break + + epoch_data['done'] = 1 + epoch_data['results'] = results + torch.save(epoch_data, epoch_save_file) + evidence_identifier.load_state_dict(best_model_state_dict) + evidence_identifier = evidence_identifier.to(device=device) + evidence_identifier.eval() + return evidence_identifier, results diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_token_identifier.py b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_token_identifier.py new file mode 100644 index 0000000..4641353 --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/evidence_token_identifier.py @@ -0,0 +1,246 @@ +import itertools +import logging +import os +import random + +from collections import OrderedDict +from itertools import chain +from sklearn.metrics import accuracy_score, classification_report +from typing import List, Dict, Tuple, Callable + +import torch + +from torch import nn + +from rationale_benchmark.models.pipeline.pipeline_utils import SentenceEvidence, score_rationales, \ + annotations_to_evidence_token_identification, make_token_preds_epoch +from rationale_benchmark.utils import Annotation +from rationale_benchmark.models.model_utils import PaddedSequence + + +def _get_sampling_method(training_pars: dict) -> Callable[ + [List[SentenceEvidence], Dict[str, List[SentenceEvidence]]], List[SentenceEvidence]]: + """Generates a sampler that produces sentences with a mix of positive and negative evidence tokens + + Returns a function that takes a document converted to sentence level + annotations and a dictionary of docid -> sentence level annotations, and + returns a set of sentence level annotations. + + This is theoretically necessary to support a pipeline with three stages: evidence sentence identification, token + identification, followed by classification on the survivors. + + For e-SNLI and COS-E the correct assignment is to use "everything" or a variant like everything. + """ + + if training_pars['sampling_method'] == 'everything': + def everything_sampler(document: List[SentenceEvidence], + _: Dict[str, List[SentenceEvidence]]) -> List[SentenceEvidence]: + assert len(document) == 1 + return document[0] + return everything_sampler + else: + raise ValueError(f"Unknown sampling method for training: {training_pars['sampling_method']}") + + +def train_evidence_token_identifier(evidence_token_identifier: nn.Module, + save_dir: str, + train: List[Annotation], + val: List[Annotation], + interned_documents: Dict[str, List[List[int]]], + source_documents: Dict[str, List[List[str]]], + token_mapping: Dict[str, List[List[Tuple[int, int]]]], + model_pars: dict, + optimizer=None, + scheduler=None, + tensorize_model_inputs: bool = True) -> Tuple[nn.Module, dict]: + """Trains a module for token-level rationale identification. + + This method tracks loss on the entire validation set, saves intermediate + models, and supports restoring from an unfinished state. The best model on + the validation set is maintained, and the model stops training if a patience + (see below) number of epochs with no improvement is exceeded. + + As there are likely too many negative examples to reasonably train a + classifier on everything, every epoch we subsample the negatives. + + Args: + evidence_token_identifier: a module like the AttentiveClassifier + save_dir: a place to save intermediate and final results and models. + train: a List of interned Annotation objects. + val: a List of interned Annotation objects. + interned_documents: a Dict of interned sentences + source_documents: + token_mapping: + model_pars: Arbitrary parameters directory, assumed to contain an "evidence_identifier" sub-dict with: + lr: learning rate + batch_size: an int + sampling_method: a string, plus additional params in the dict to define creation of a sampler + epochs: the number of epochs to train for + patience: how long to wait for an improvement before giving up. + max_grad_norm: optional, clip gradients. + optimizer: what pytorch optimizer to use, if none, initialize Adam + scheduler: optional, do we want a scheduler involved in learning? + tensorize_model_inputs: should we convert our data to tensors before passing it to the model? + Useful if we have a model that performs its own tokenization (e.g. BERT as a Service) + + Returns: + the trained evidence token identifier and a dictionary of intermediate results. + """ + + def _prep_data_for_epoch(evidence_data: Dict[str, Dict[str, List[SentenceEvidence]]], + sampler: Callable[ + [List[SentenceEvidence], Dict[str, List[SentenceEvidence]]], List[SentenceEvidence]] + ) -> List[SentenceEvidence]: + output_sentences = [] + ann_ids = sorted(evidence_data.keys()) + # in place shuffle so we get a different per-epoch ordering + random.shuffle(ann_ids) + for ann_id in ann_ids: + for docid, sentences in evidence_data[ann_id].items(): + data = sampler(sentences, None) + output_sentences.append(data) + return output_sentences + + logging.info(f'Beginning training with {len(train)} annotations, {len(val)} for validation') + evidence_identifier_output_dir = os.path.join(save_dir, 'evidence_token_identifier') + os.makedirs(save_dir, exist_ok=True) + os.makedirs(evidence_identifier_output_dir, exist_ok=True) + + model_save_file = os.path.join(evidence_identifier_output_dir, 'evidence_token_identifier.pt') + epoch_save_file = os.path.join(evidence_identifier_output_dir, 'evidence_token_identifier_epoch_data.pt') + + if optimizer is None: + optimizer = torch.optim.Adam(evidence_token_identifier.parameters(), lr=model_pars['evidence_token_identifier']['lr']) + criterion = nn.BCELoss(reduction='none') + sampling_method = _get_sampling_method(model_pars['evidence_token_identifier']) + batch_size = model_pars['evidence_token_identifier']['batch_size'] + epochs = model_pars['evidence_token_identifier']['epochs'] + patience = model_pars['evidence_token_identifier']['patience'] + max_grad_norm = model_pars['evidence_token_identifier'].get('max_grad_norm', None) + use_cose_hack = bool(model_pars['evidence_token_identifier'].get('cose_data_hack', 0)) + + # annotation id -> docid -> [SentenceEvidence]) + evidence_train_data = annotations_to_evidence_token_identification(train, + source_documents=source_documents, + interned_documents=interned_documents, + token_mapping=token_mapping) + evidence_val_data = annotations_to_evidence_token_identification(val, + source_documents=source_documents, + interned_documents=interned_documents, + token_mapping=token_mapping) + + device = next(evidence_token_identifier.parameters()).device + + results = { + # "sampled" losses do not represent the true data distribution, but do represent training data + 'sampled_epoch_train_losses': [], + # "full" losses do represent the true data distribution + 'epoch_val_losses': [], + 'epoch_val_acc': [], + 'epoch_val_f': [], + 'epoch_val_rationale_scores': [], + } + # allow restoring an existing training run + start_epoch = 0 + best_epoch = -1 + best_val_loss = float('inf') + best_model_state_dict = None + epoch_data = {} + if os.path.exists(epoch_save_file): + evidence_token_identifier.load_state_dict(torch.load(model_save_file)) + epoch_data = torch.load(epoch_save_file) + start_epoch = epoch_data['epoch'] + 1 + # handle finishing because patience was exceeded or we didn't get the best final epoch + if bool(epoch_data.get('done', 0)): + start_epoch = epochs + results = epoch_data['results'] + best_epoch = start_epoch + best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_token_identifier.state_dict().items()}) + logging.info(f'Training evidence identifier from epoch {start_epoch} until epoch {epochs}') + optimizer.zero_grad() + for epoch in range(start_epoch, epochs): + epoch_train_data = _prep_data_for_epoch(evidence_train_data, sampling_method) + epoch_val_data = _prep_data_for_epoch(evidence_val_data, sampling_method) + sampled_epoch_train_loss = 0 + evidence_token_identifier.train() + logging.info( + f'Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples') + for batch_start in range(0, len(epoch_train_data), batch_size): + batch_elements = epoch_train_data[batch_start:min(batch_start + batch_size, len(epoch_train_data))] + # we sample every time to thereoretically get a better representation of instances over the corpus. + # this might just take more time than doing so in advance. + targets, queries, sentences = zip(*[(s.kls, s.query, s.sentence) for s in batch_elements]) + ids = [(s.ann_id, s.docid, s.index) for s in batch_elements] + #targets = torch.tensor(targets, dtype=torch.long, device=device) + targets = PaddedSequence.autopad([torch.tensor(t, dtype=torch.long, device=device) for t in targets], batch_first=True, device=device) + aggregate_spans = [token_mapping[s.docid][s.index] for s in batch_elements] + if tensorize_model_inputs: + if all(q is None for q in queries): + queries = [torch.tensor([], dtype=torch.long) for _ in queries] + else: + assert all(q is not None for q in queries) + queries = [torch.tensor(q, dtype=torch.long) for q in queries] + sentences = [torch.tensor(s, dtype=torch.long) for s in sentences] + preds = evidence_token_identifier(queries, ids, sentences, aggregate_spans) + mask = targets.mask(on=1, off=0, device=device, dtype=torch.float) + preds = preds * mask + loss = criterion(preds, (targets.data.to(device=preds.device) * mask).squeeze()).sum() + sampled_epoch_train_loss += loss.item() + loss = loss / torch.sum(mask) + loss.backward() + if max_grad_norm: + torch.nn.utils.clip_grad_norm_(evidence_token_identifier.parameters(), max_grad_norm) + optimizer.step() + if scheduler: + scheduler.step() + optimizer.zero_grad() + sampled_epoch_train_loss /= len(epoch_train_data) + results['sampled_epoch_train_losses'].append(sampled_epoch_train_loss) + logging.info(f'Epoch {epoch} training loss {sampled_epoch_train_loss}') + + with torch.no_grad(): + evidence_token_identifier.eval() + epoch_val_loss, epoch_val_soft_pred, epoch_val_hard_pred, epoch_val_truth = \ + make_token_preds_epoch(evidence_token_identifier, + epoch_val_data, + token_mapping, + batch_size, + device, + criterion, + tensorize_model_inputs) + #epoch_val_soft_pred = list(chain.from_iterable(epoch_val_soft_pred)) + epoch_val_hard_pred = list(chain.from_iterable(epoch_val_hard_pred)) + epoch_val_truth = list(chain.from_iterable(epoch_val_truth)) + results['epoch_val_losses'].append(epoch_val_loss) + results['epoch_val_acc'].append(accuracy_score(epoch_val_truth, epoch_val_hard_pred)) + results['epoch_val_f'].append(classification_report(epoch_val_truth, epoch_val_hard_pred, output_dict=True)) + epoch_val_soft_pred_for_scoring = [[[1 - z, z] for z in y] for y in epoch_val_soft_pred] + logging.info( + f'Epoch {epoch} full val loss {epoch_val_loss}, accuracy: {results["epoch_val_acc"][-1]}, f: {results["epoch_val_f"][-1]}, rationale scores: look, it\'s already a pain to duplicate this code. What do you want from me.') + + # if epoch_val_loss < best_val_loss: + if epoch_val_loss < best_val_loss: + logging.debug(f'Epoch {epoch} new best model with val loss {epoch_val_loss}') + best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_token_identifier.state_dict().items()}) + best_epoch = epoch + best_val_loss = epoch_val_loss + torch.save(evidence_token_identifier.state_dict(), model_save_file) + epoch_data = { + 'epoch': epoch, + 'results': results, + 'best_val_loss': best_val_loss, + 'done': 0 + } + torch.save(epoch_data, epoch_save_file) + if epoch - best_epoch > patience: + epoch_data['done'] = 1 + torch.save(epoch_data, epoch_save_file) + break + + epoch_data['done'] = 1 + epoch_data['results'] = results + torch.save(epoch_data, epoch_save_file) + evidence_token_identifier.load_state_dict(best_model_state_dict) + evidence_token_identifier = evidence_token_identifier.to(device=device) + evidence_token_identifier.eval() + return evidence_token_identifier, results diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/pipeline_train.py b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/pipeline_train.py new file mode 100644 index 0000000..1d82d4b --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/pipeline_train.py @@ -0,0 +1,167 @@ +import argparse +import json +import logging +import random +import os + +from itertools import chain +from typing import Set + +import numpy as np +import torch + +from rationale_benchmark.utils import ( + write_jsonl, + load_datasets, + load_documents, + intern_documents, + intern_annotations +) +from rationale_benchmark.models.mlp import ( + AttentiveClassifier, + BahadanauAttention, + RNNEncoder, + WordEmbedder +) +from rationale_benchmark.models.model_utils import extract_embeddings +from rationale_benchmark.models.pipeline.evidence_identifier import train_evidence_identifier +from rationale_benchmark.models.pipeline.evidence_classifier import train_evidence_classifier +from rationale_benchmark.models.pipeline.pipeline_utils import decode + +logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s') +# let's make this more or less deterministic (not resistant to restarts) +random.seed(12345) +np.random.seed(67890) +torch.manual_seed(10111213) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +def initialize_models(params: dict, vocab: Set[str], batch_first: bool, unk_token='UNK'): + # TODO this is obviously asking for some sort of dependency injection. implement if it saves me time. + if 'embedding_file' in params['embeddings']: + embeddings, word_interner, de_interner = extract_embeddings(vocab, params['embeddings']['embedding_file'], unk_token=unk_token) + if torch.cuda.is_available(): + embeddings = embeddings.cuda() + else: + raise ValueError("No 'embedding_file' found in params!") + word_embedder = WordEmbedder(embeddings, params['embeddings']['dropout']) + query_encoder = RNNEncoder(word_embedder, + batch_first=batch_first, + condition=False, + attention_mechanism=BahadanauAttention(word_embedder.output_dimension)) + document_encoder = RNNEncoder(word_embedder, + batch_first=batch_first, + condition=True, + attention_mechanism=BahadanauAttention(word_embedder.output_dimension, + query_size=query_encoder.output_dimension)) + evidence_identifier = AttentiveClassifier(document_encoder, + query_encoder, + 2, + params['evidence_identifier']['mlp_size'], + params['evidence_identifier']['dropout']) + query_encoder = RNNEncoder(word_embedder, + batch_first=batch_first, + condition=False, + attention_mechanism=BahadanauAttention(word_embedder.output_dimension)) + document_encoder = RNNEncoder(word_embedder, + batch_first=batch_first, + condition=True, + attention_mechanism=BahadanauAttention(word_embedder.output_dimension, + query_size=query_encoder.output_dimension)) + evidence_classes = dict((y,x) for (x,y) in enumerate(params['evidence_classifier']['classes'])) + evidence_classifier = AttentiveClassifier(document_encoder, + query_encoder, + len(evidence_classes), + params['evidence_classifier']['mlp_size'], + params['evidence_classifier']['dropout']) + return evidence_identifier, evidence_classifier, word_interner, de_interner, evidence_classes + + +def main(): + parser = argparse.ArgumentParser(description="""Trains a pipeline model. + + Step 1 is evidence identification, that is identify if a given sentence is evidence or not + Step 2 is evidence classification, that is given an evidence sentence, classify the final outcome for the final task (e.g. sentiment or significance). + + These models should be separated into two separate steps, but at the moment: + * prep data (load, intern documents, load json) + * convert data for evidence identification - in the case of training data we take all the positives and sample some negatives + * side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a broader sampling of negative values. + * train evidence identification + * convert data for evidence classification - take all rationales + decisions and use this as input + * train evidence classification + * decode first the evidence, then run classification for each split + + """, formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('--data_dir', dest='data_dir', required=True, + help='Which directory contains a {train,val,test}.jsonl file?') + parser.add_argument('--output_dir', dest='output_dir', required=True, + help='Where shall we write intermediate models + final data to?') + parser.add_argument('--model_params', dest='model_params', required=True, + help='JSoN file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.') + args = parser.parse_args() + BATCH_FIRST = True + + with open(args.model_params, 'r') as fp: + logging.debug(f'Loading model parameters from {args.model_params}') + model_params = json.load(fp) + train, val, test = load_datasets(args.data_dir) + docids = set(e.docid for e in chain.from_iterable(chain.from_iterable(map(lambda ann: ann.evidences, chain(train, val, test))))) + documents = load_documents(args.data_dir, docids) + document_vocab = set(chain.from_iterable(chain.from_iterable(documents.values()))) + annotation_vocab = set(chain.from_iterable(e.query.split() for e in chain(train, val, test))) + logging.debug(f'Loaded {len(documents)} documents with {len(document_vocab)} unique words') + # this ignores the case where annotations don't align perfectly with token boundaries, but this isn't that important + vocab = document_vocab | annotation_vocab + unk_token = 'UNK' + evidence_identifier, evidence_classifier, word_interner, de_interner, evidence_classes = \ + initialize_models(model_params, vocab, batch_first=BATCH_FIRST, unk_token=unk_token) + logging.debug(f'Including annotations, we have {len(vocab)} total words in the data, with embeddings for {len(word_interner)}') + interned_documents = intern_documents(documents, word_interner, unk_token) + interned_train = intern_annotations(train, word_interner, unk_token) + interned_val = intern_annotations(val, word_interner, unk_token) + interned_test = intern_annotations(test, word_interner, unk_token) + assert BATCH_FIRST # for correctness of the split dimension for DataParallel + evidence_identifier, evidence_ident_results = train_evidence_identifier(evidence_identifier.cuda(), + args.output_dir, interned_train, + interned_val, + interned_documents, + model_params, + tensorize_model_inputs=True) + evidence_classifier, evidence_class_results = train_evidence_classifier(evidence_classifier.cuda(), + args.output_dir, + interned_train, + interned_val, + interned_documents, + model_params, + class_interner=evidence_classes, + tensorize_model_inputs=True) + pipeline_batch_size = min([model_params['evidence_classifier']['batch_size'], + model_params['evidence_identifier']['batch_size']]) + pipeline_results, train_decoded, val_decoded, test_decoded = decode(evidence_identifier, + evidence_classifier, + interned_train, + interned_val, + interned_test, + interned_documents, + evidence_classes, + pipeline_batch_size, + tensorize_model_inputs=True) + write_jsonl(train_decoded, os.path.join(args.output_dir, 'train_decoded.jsonl')) + write_jsonl(val_decoded, os.path.join(args.output_dir, 'val_decoded.jsonl')) + write_jsonl(test_decoded, os.path.join(args.output_dir, 'test_decoded.jsonl')) + with open(os.path.join(args.output_dir, 'identifier_results.json'), 'w') as ident_output, \ + open(os.path.join(args.output_dir, 'classifier_results.json'), 'w') as class_output: + ident_output.write(json.dumps(evidence_ident_results)) + class_output.write(json.dumps(evidence_class_results)) + for k, v in pipeline_results.items(): + if type(v) is dict: + for k1, v1 in v.items(): + logging.info(f'Pipeline results for {k}, {k1}={v1}') + else: + logging.info(f'Pipeline results {k}\t={v}') + + +if __name__ == '__main__': + main() diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/pipeline_utils.py b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/pipeline_utils.py new file mode 100644 index 0000000..021976c --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/models/pipeline/pipeline_utils.py @@ -0,0 +1,799 @@ +import itertools +import logging + +from collections import defaultdict, namedtuple +from itertools import chain +from typing import Any, Dict, List, Tuple + +import numpy as np +import torch +import torch.nn as nn + +from sklearn.metrics import classification_report, accuracy_score + +from rationale_benchmark.metrics import ( + PositionScoredDocument, + Rationale, + partial_match_score, + score_hard_rationale_predictions, + score_soft_tokens +) + +from rationale_benchmark.utils import Annotation +from rationale_benchmark.models.model_utils import PaddedSequence + +SentenceEvidence = namedtuple('SentenceEvidence', 'kls ann_id query docid index sentence') + +def token_annotations_to_evidence_classification(annotations: List[Annotation], + documents: Dict[str, List[List[Any]]], + class_interner: Dict[str, int], + ) -> List[SentenceEvidence]: + ret = [] + for ann in annotations: + docid_to_ev = defaultdict(list) + for evidence in ann.all_evidences(): + docid_to_ev[evidence.docid].append(evidence) + for docid, evidences in docid_to_ev.items(): + evidences = sorted(evidences, key=lambda ev: ev.start_token) + text = [] + covered_tokens = set() + doc = list(chain.from_iterable(documents[docid])) + for evidence in evidences: + assert evidence.start_token >= 0 and evidence.end_token > evidence.start_token + assert evidence.start_token < len(doc) and evidence.end_token <= len(doc) + text.extend(evidence.text) + new_tokens = set(range(evidence.start_token, evidence.end_token)) + if len(new_tokens & covered_tokens) > 0: + raise ValueError("Have overlapping token ranges covered in the evidence spans and the implementer was lazy; deal with it") + covered_tokens |= new_tokens + assert len(text) > 0 + ret.append(SentenceEvidence(kls=class_interner[ann.classification], + query=ann.query, + ann_id=ann.annotation_id, + docid=docid, + index=-1, + sentence=tuple(text))) + return ret + +def annotations_to_evidence_classification(annotations: List[Annotation], + documents: Dict[str, List[List[Any]]], + class_interner: Dict[str, int], + include_all: bool + ) -> List[SentenceEvidence]: + """Converts Corpus-Level annotations to Sentence Level relevance judgments. + + As this module is about a pipelined approach for evidence identification, + inputs to both an evidence identifier and evidence classifier need to be to + be on a sentence level, this module converts data to be that form. + + The return type is of the form + annotation id -> docid -> [sentence level annotations] + """ + ret = [] + for ann in annotations: + ann_id = ann.annotation_id + docids = set(ev.docid for ev in chain.from_iterable(ann.evidences)) + annotations_for_doc = defaultdict(list) + for d in docids: + for index, sent in enumerate(documents[d]): + annotations_for_doc[d].append( + SentenceEvidence( + kls=class_interner[ann.classification], + query=ann.query, + ann_id=ann.annotation_id, + docid=d, + index=index, + sentence=tuple(sent))) + if include_all: + ret.extend(chain.from_iterable(annotations_for_doc.values())) + else: + contributes = set() + for ev in chain.from_iterable(ann.evidences): + for index in range(ev.start_sentence, ev.end_sentence): + contributes.add(annotations_for_doc[ev.docid][index]) + ret.extend(contributes) + assert len(ret) > 0 + return ret + + +def annotations_to_evidence_identification(annotations: List[Annotation], + documents: Dict[str, List[List[Any]]] + ) -> Dict[str, Dict[str, List[SentenceEvidence]]]: + """Converts Corpus-Level annotations to Sentence Level relevance judgments. + + As this module is about a pipelined approach for evidence identification, + inputs to both an evidence identifier and evidence classifier need to be to + be on a sentence level, this module converts data to be that form. + + The return type is of the form + annotation id -> docid -> [sentence level annotations] + """ + ret = defaultdict(dict) # annotation id -> docid -> sentences + for ann in annotations: + ann_id = ann.annotation_id + for ev_group in ann.evidences: + for ev in ev_group: + if len(ev.text) == 0: + continue + if ev.docid not in ret[ann_id]: + ret[ann.annotation_id][ev.docid] = [] + # populate the document with "not evidence"; to be filled in later + for index, sent in enumerate(documents[ev.docid]): + ret[ann.annotation_id][ev.docid].append(SentenceEvidence( + kls=0, + query=ann.query, + ann_id=ann.annotation_id, + docid=ev.docid, + index=index, + sentence=sent)) + # define the evidence sections of the document + for s in range(ev.start_sentence, ev.end_sentence): + ret[ann.annotation_id][ev.docid][s] = SentenceEvidence( + kls=1, + ann_id=ann.annotation_id, + query=ann.query, + docid=ev.docid, + index=ret[ann.annotation_id][ev.docid][s].index, + sentence=ret[ann.annotation_id][ev.docid][s].sentence) + return ret + + +def annotations_to_evidence_token_identification(annotations: List[Annotation], + source_documents: Dict[str, List[List[str]]], + interned_documents: Dict[str, List[List[int]]], + token_mapping: Dict[str, List[List[Tuple[int, int]]]] + ) -> Dict[str, Dict[str, List[SentenceEvidence]]]: + # TODO document + # TODO should we simplify to use only source text? + ret = defaultdict(lambda: defaultdict(list)) # annotation id -> docid -> sentences + positive_tokens = 0 + negative_tokens = 0 + for ann in annotations: + annid = ann.annotation_id + docids = set(ev.docid for ev in chain.from_iterable(ann.evidences)) + sentence_offsets = defaultdict(list) # docid -> [(start, end)] + classes = defaultdict(list) # docid -> [token is yea or nay] + for docid in docids: + start = 0 + assert len(source_documents[docid]) == len(interned_documents[docid]) + for whole_token_sent, wordpiece_sent in zip(source_documents[docid], interned_documents[docid]): + classes[docid].extend([0 for _ in wordpiece_sent]) + end = start + len(wordpiece_sent) + sentence_offsets[docid].append((start, end)) + start = end + for ev in chain.from_iterable(ann.evidences): + if len(ev.text) == 0: + continue + flat_token_map = list(chain.from_iterable(token_mapping[ev.docid])) + if ev.start_token != -1: + #start, end = token_mapping[ev.docid][ev.start_token][0], token_mapping[ev.docid][ev.end_token][1] + start, end = flat_token_map[ev.start_token][0], flat_token_map[ev.end_token - 1][1] + else: + start = flat_token_map[sentence_offsets[ev.start_sentence][0]][0] + end = flat_token_map[sentence_offsets[ev.end_sentence - 1][1]][1] + for i in range(start, end): + classes[ev.docid][i] = 1 + for docid, offsets in sentence_offsets.items(): + token_assignments = classes[docid] + positive_tokens += sum(token_assignments) + negative_tokens += len(token_assignments) - sum(token_assignments) + for s, (start, end) in enumerate(offsets): + sent = interned_documents[docid][s] + ret[annid][docid].append(SentenceEvidence(kls=tuple(token_assignments[start:end]), + query=ann.query, + ann_id=ann.annotation_id, + docid=docid, + index=s, + sentence=sent)) + logging.info(f"Have {positive_tokens} positive wordpiece tokens, {negative_tokens} negative wordpiece tokens") + return ret + + +def make_preds_batch(classifier: nn.Module, + batch_elements: List[SentenceEvidence], + device=None, + criterion: nn.Module = None, + tensorize_model_inputs: bool = True) -> Tuple[float, List[float], List[int], List[int]]: + """Batch predictions + + Args: + classifier: a module that looks like an AttentiveClassifier + batch_elements: a list of elements to make predictions over. These must be SentenceEvidence objects. + device: Optional; what compute device this should run on + criterion: Optional; a loss function + tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization + """ + # delete any "None" padding, if any (imposed by the use of the "grouper") + batch_elements = filter(lambda x: x is not None, batch_elements) + targets, queries, sentences = zip(*[(s.kls, s.query, s.sentence) for s in batch_elements]) + ids = [(s.ann_id, s.docid, s.index) for s in batch_elements] + targets = torch.tensor(targets, dtype=torch.long, device=device) + if tensorize_model_inputs: + queries = [torch.tensor(q, dtype=torch.long) for q in queries] + sentences = [torch.tensor(s, dtype=torch.long) for s in sentences] + preds = classifier(queries, ids, sentences) + targets = targets.to(device=preds.device) + if criterion: + loss = criterion(preds, targets) + else: + loss = None + # .float() because pytorch 1.3 introduces a bug where argmax is unsupported for float16 + hard_preds = torch.argmax(preds.float(), dim=-1) + return loss, preds, hard_preds, targets + + +def make_preds_epoch(classifier: nn.Module, + data: List[SentenceEvidence], + batch_size: int, + device=None, + criterion: nn.Module = None, + tensorize_model_inputs: bool = True): + """Predictions for more than one batch. + + Args: + classifier: a module that looks like an AttentiveClassifier + data: a list of elements to make predictions over. These must be SentenceEvidence objects. + batch_size: the biggest chunk we can fit in one batch. + device: Optional; what compute device this should run on + criterion: Optional; a loss function + tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization + """ + epoch_loss = 0 + epoch_soft_pred = [] + epoch_hard_pred = [] + epoch_truth = [] + batches = _grouper(data, batch_size) + classifier.eval() + for batch in batches: + loss, soft_preds, hard_preds, targets = make_preds_batch(classifier, batch, device, criterion=criterion, + tensorize_model_inputs=tensorize_model_inputs) + if loss is not None: + epoch_loss += loss.sum().item() + epoch_hard_pred.extend(hard_preds) + epoch_soft_pred.extend(soft_preds.cpu()) + epoch_truth.extend(targets) + epoch_loss /= len(data) + epoch_hard_pred = [x.item() for x in epoch_hard_pred] + epoch_truth = [x.item() for x in epoch_truth] + return epoch_loss, epoch_soft_pred, epoch_hard_pred, epoch_truth + + +def make_token_preds_batch(classifier: nn.Module, + batch_elements: List[SentenceEvidence], + token_mapping: Dict[str, List[List[Tuple[int, int]]]], + device=None, + criterion: nn.Module = None, + tensorize_model_inputs: bool = True) -> Tuple[float, List[float], List[int], List[int]]: + """Batch predictions + + Args: + classifier: a module that looks like an AttentiveClassifier + batch_elements: a list of elements to make predictions over. These must be SentenceEvidence objects. + device: Optional; what compute device this should run on + criterion: Optional; a loss function + tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization + """ + # delete any "None" padding, if any (imposed by the use of the "grouper") + batch_elements = filter(lambda x: x is not None, batch_elements) + targets, queries, sentences = zip(*[(s.kls, s.query, s.sentence) for s in batch_elements]) + ids = [(s.ann_id, s.docid, s.index) for s in batch_elements] + targets = PaddedSequence.autopad([torch.tensor(t, dtype=torch.long, device=device) for t in targets], batch_first=True, device=device) + aggregate_spans = [token_mapping[s.docid][s.index] for s in batch_elements] + if tensorize_model_inputs: + queries = [torch.tensor(q, dtype=torch.long) for q in queries] + sentences = [torch.tensor(s, dtype=torch.long) for s in sentences] + preds = classifier(queries, ids, sentences, aggregate_spans) + targets = targets.to(device=preds.device) + mask = targets.mask(on=1, off=0, device=preds.device, dtype=torch.float) + if criterion: + loss = criterion(preds, (targets.data.to(device=preds.device) * mask).squeeze()).sum() + else: + loss = None + hard_preds = [torch.round(x).to(dtype=torch.int).cpu() for x in targets.unpad(preds)] + targets = [[y.item() for y in x] for x in targets.unpad(targets.data.cpu())] + return loss, preds, hard_preds, targets #targets.unpad(targets.data.cpu()) + + +# TODO fix the arguments +def make_token_preds_epoch(classifier: nn.Module, + data: List[SentenceEvidence], + token_mapping: Dict[str, List[List[Tuple[int, int]]]], + batch_size: int, + device=None, + criterion: nn.Module = None, + tensorize_model_inputs: bool = True): + """Predictions for more than one batch. + + Args: + classifier: a module that looks like an AttentiveClassifier + data: a list of elements to make predictions over. These must be SentenceEvidence objects. + batch_size: the biggest chunk we can fit in one batch. + device: Optional; what compute device this should run on + criterion: Optional; a loss function + tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization + """ + epoch_loss = 0 + epoch_soft_pred = [] + epoch_hard_pred = [] + epoch_truth = [] + batches = _grouper(data, batch_size) + classifier.eval() + for batch in batches: + loss, soft_preds, hard_preds, targets = make_token_preds_batch(classifier, + batch, + token_mapping, + device, + criterion=criterion, + tensorize_model_inputs=tensorize_model_inputs) + if loss is not None: + epoch_loss += loss.sum().item() + epoch_hard_pred.extend(hard_preds) + epoch_soft_pred.extend(soft_preds.cpu().tolist()) + epoch_truth.extend(targets) + epoch_loss /= len(data) + return epoch_loss, epoch_soft_pred, epoch_hard_pred, epoch_truth + + +# copied from https://docs.python.org/3/library/itertools.html#itertools-recipes +def _grouper(iterable, n, fillvalue=None): + "Collect data into fixed-length chunks or blocks" + # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx" + args = [iter(iterable)] * n + return itertools.zip_longest(*args, fillvalue=fillvalue) + + +def score_rationales(truth: List[Annotation], + documents: Dict[str, List[List[int]]], + input_data: List[SentenceEvidence], + scores: List[float] + ) -> dict: + results = {} + doc_to_sent_scores = dict() # (annid, docid) -> [sentence scores] + for sent, score in zip(input_data, scores): + k = (sent.ann_id, sent.docid) + if k not in doc_to_sent_scores: + doc_to_sent_scores[k] = [0.0 for _ in range(len(documents[sent.docid]))] + if not isinstance(score[1], float): + score[1] = score[1].item() + doc_to_sent_scores[(sent.ann_id, sent.docid)][sent.index] = score[1] + # hard rationale scoring + best_sentence = {k: np.argmax(np.array(v)) for k, v in doc_to_sent_scores.items()} + predicted_rationales = [] + for (ann_id, docid), sent_idx in best_sentence.items(): + start_token = sum(len(s) for s in documents[docid][:sent_idx]) + end_token = start_token + len(documents[docid][sent_idx]) + predicted_rationales.append(Rationale(ann_id, docid, start_token, end_token)) + true_rationales = list(chain.from_iterable(Rationale.from_annotation(rat) for rat in truth)) + + results['hard_rationale_scores'] = score_hard_rationale_predictions(true_rationales, predicted_rationales) + results['hard_rationale_partial_match_scores'] = partial_match_score(true_rationales, predicted_rationales, [0.5]) + + # soft rationale scoring + instance_format = [] + for (ann_id, docid), sentences in doc_to_sent_scores.items(): + soft_token_predictions = [] + for sent_score, sent_text in zip(sentences, documents[docid]): + soft_token_predictions.extend(sent_score for _ in range(len(sent_text))) + instance_format.append({ + 'annotation_id': ann_id, + 'rationales': [{ + 'docid': docid, + 'soft_rationale_predictions': soft_token_predictions, + 'soft_sentence_predictions': sentences, + }], + }) + flattened_documents = {k: list(chain.from_iterable(v)) for k, v in documents.items()} + token_scoring_format = PositionScoredDocument.from_results(instance_format, truth, flattened_documents, + use_tokens=True) + results['soft_token_scores'] = score_soft_tokens(token_scoring_format) + sentence_scoring_format = PositionScoredDocument.from_results(instance_format, truth, documents, use_tokens=False) + results['soft_sentence_scores'] = score_soft_tokens(sentence_scoring_format) + return results + + +def decode(evidence_identifier: nn.Module, + evidence_classifier: nn.Module, + train: List[Annotation], + val: List[Annotation], + test: List[Annotation], + docs: Dict[str, List[List[int]]], + class_interner: Dict[str, int], + batch_size: int, + tensorize_model_inputs: bool, + decoding_docs: Dict[str, List[Any]] = None) -> dict: + """Identifies and then classifies evidence + + Args: + evidence_identifier: a module for identifying evidence statements + evidence_classifier: a module for making a classification based on evidence statements + train: A List of interned Annotations + val: A List of interned Annotations + test: A List of interned Annotations + docs: A Dict of Documents, which are interned sentences. + class_interner: Converts an Annotation's final class into ints + batch_size: how big should our batches be? + tensorize_model_inputs: should we convert our data to tensors before passing it to the model? Useful if we have a model that performs its own tokenization + """ + device = None + class_labels = [k for k, v in sorted(class_interner.items(), key=lambda x: x[1])] + if decoding_docs is None: + decoding_docs = docs + + def prep(data: List[Annotation]) -> List[Tuple[SentenceEvidence, SentenceEvidence]]: + """Prepares data for evidence identification and classification. + + Creates paired evaluation data, wherein each (annotation, docid, sentence, kls) + tuplet appears first as the kls determining if the sentence is evidence, and + secondarily what the overall classification for the (annotation/docid) pair is. + This allows selection based on model scores of the evidence_identifier for + input to the evidence_classifier. + """ + identification_data = annotations_to_evidence_identification(data, docs) + classification_data = annotations_to_evidence_classification(data, docs, class_interner, include_all=True) + ann_doc_sents = defaultdict(lambda: defaultdict(dict)) # ann id -> docid -> sent idx -> sent data + ret = [] + for sent_ev in classification_data: + id_data = identification_data[sent_ev.ann_id][sent_ev.docid][sent_ev.index] + ret.append((id_data, sent_ev)) + assert id_data.ann_id == sent_ev.ann_id + assert id_data.docid == sent_ev.docid + assert id_data.index == sent_ev.index + assert len(ret) == len(classification_data) + return ret + + def decode_batch(data: List[Tuple[SentenceEvidence, SentenceEvidence]], name: str, score: bool = False, + annotations: List[Annotation] = None) -> dict: + """Identifies evidence statements and then makes classifications based on it. + + Args: + data: a paired list of SentenceEvidences, differing only in the kls field. + The first corresponds to whether or not something is evidence, and the second corresponds to an evidence class + name: a name for a results dict + """ + + num_uniques = len(set((x.ann_id, x.docid) for x, _ in data)) + logging.info(f'Decoding dataset {name} with {len(data)} sentences, {num_uniques} annotations') + identifier_data, classifier_data = zip(*data) + results = dict() + IdentificationClassificationResult = namedtuple('IdentificationClassificationResult', + 'identification_data classification_data soft_identification hard_identification soft_classification hard_classification') + with torch.no_grad(): + # make predictions for the evidence_identifier + evidence_identifier.eval() + evidence_classifier.eval() + + _, soft_identification_preds, hard_identification_preds, _ = make_preds_epoch(evidence_identifier, + identifier_data, batch_size, + device, + tensorize_model_inputs=tensorize_model_inputs) + assert len(soft_identification_preds) == len(data) + identification_results = defaultdict(list) + for id_data, cls_data, soft_id_pred, hard_id_pred in zip(identifier_data, classifier_data, + soft_identification_preds, + hard_identification_preds): + res = IdentificationClassificationResult(identification_data=id_data, + classification_data=cls_data, + # 1 is p(evidence|sent,query) + soft_identification=soft_id_pred[1].float().item(), + hard_identification=hard_id_pred, + soft_classification=None, + hard_classification=False) + identification_results[(id_data.ann_id, id_data.docid)].append(res) + + best_identification_results = {key: max(value, key=lambda x: x.soft_identification) for key, value in + identification_results.items()} + logging.info( + f'Selected the best sentence for {len(identification_results)} examples from a total of {len(soft_identification_preds)} sentences') + ids, classification_data = zip( + *[(k, v.classification_data) for k, v in best_identification_results.items()]) + _, soft_classification_preds, hard_classification_preds, classification_truth = make_preds_epoch( + evidence_classifier, classification_data, batch_size, device, + tensorize_model_inputs=tensorize_model_inputs) + classification_results = dict() + for eyeD, soft_class, hard_class in zip(ids, soft_classification_preds, hard_classification_preds): + input_id_result = best_identification_results[eyeD] + res = IdentificationClassificationResult(identification_data=input_id_result.identification_data, + classification_data=input_id_result.classification_data, + soft_identification=input_id_result.soft_identification, + hard_identification=input_id_result.hard_identification, + soft_classification=soft_class, + hard_classification=hard_class) + classification_results[eyeD] = res + + if score: + truth = [] + pred = [] + for res in classification_results.values(): + truth.append(res.classification_data.kls) + pred.append(res.hard_classification) + # results[f'{name}_f1'] = classification_report(classification_truth, pred, target_names=class_labels, output_dict=True) + results[f'{name}_f1'] = classification_report(classification_truth, hard_classification_preds, + target_names=class_labels, output_dict=True) + results[f'{name}_acc'] = accuracy_score(classification_truth, hard_classification_preds) + results[f'{name}_rationale'] = score_rationales(annotations, decoding_docs, identifier_data, + soft_identification_preds) + + # turn the above results into a format suitable for scoring via the rationale scorer + # n.b. the sentence-level evidence predictions (hard and soft) are + # broadcast to the token level for scoring. The comprehensiveness class + # score is also a lie since the pipeline model above is faithful by + # design. + decoded = dict() + decoded_scores = defaultdict(list) + for (ann_id, docid), pred in classification_results.items(): + sentence_prediction_scores = [x.soft_identification for x in identification_results[(ann_id, docid)]] + sentence_start_token = sum(len(s) for s in decoding_docs[docid][:pred.identification_data.index]) + sentence_end_token = sentence_start_token + len(decoding_docs[docid][pred.classification_data.index]) + hard_rationale_predictions = [{'start_token': sentence_start_token, 'end_token': sentence_end_token}] + soft_rationale_predictions = [] + for sent_result in sorted(identification_results[(ann_id, docid)], + key=lambda x: x.identification_data.index): + soft_rationale_predictions.extend(sent_result.soft_identification for _ in range(len( + decoding_docs[sent_result.identification_data.docid][sent_result.identification_data.index]))) + if ann_id not in decoded: + decoded[ann_id] = { + "annotation_id": ann_id, + "rationales": [], + "classification": class_labels[pred.hard_classification], + "classification_scores": {class_labels[i]: s.item() for i, s in + enumerate(pred.soft_classification)}, + # TODO this should turn into the data distribution for the predicted class + # "comprehensiveness_classification_scores": 0.0, + "truth": pred.classification_data.kls, + } + decoded[ann_id]['rationales'].append({ + "docid": docid, + "hard_rationale_predictions": hard_rationale_predictions, + "soft_rationale_predictions": soft_rationale_predictions, + "soft_sentence_predictions": sentence_prediction_scores, + }) + decoded_scores[ann_id].append(pred.soft_classification) + + # in practice, this is always a single element operation: + # in evidence inference (prompt is really a prompt + document), fever (we split documents into two classifications), movies (you only have one opinion about a movie), or boolQ (single document prompts) + # this exists to support weird models we *might* implement for cose/esnli + for ann_id, scores_list in decoded_scores.items(): + scores = torch.stack(scores_list) + score_avg = torch.mean(scores, dim=0) + # .float() because pytorch 1.3 introduces a bug where argmax is unsupported for float16 + hard_pred = torch.argmax(score_avg.float()).item() + decoded[ann_id]['classification'] = class_labels[hard_pred] + decoded[ann_id]['classification_scores'] = {class_labels[i]: s.item() for i, s in enumerate(score_avg)} + return results, list(decoded.values()) + + test_results, test_decoded = decode_batch(prep(test), 'test', score=False) + val_results, val_decoded = dict(), [] + train_results, train_decoded = dict(), [] + #val_results, val_decoded = decode_batch(prep(val), 'val', score=True, annotations=val) + #train_results, train_decoded = decode_batch(prep(train), 'train', score=True, annotations=train) + return dict(**train_results, **val_results, **test_results), train_decoded, val_decoded, test_decoded + +def decode_evidence_tokens_and_classify(evidence_token_identifier: nn.Module, + evidence_classifier: nn.Module, + train: List[Annotation], + val: List[Annotation], + test: List[Annotation], + docs: Dict[str, List[List[int]]], + source_documents: Dict[str, List[List[str]]], + token_mapping: Dict[str, List[List[Tuple[int, int]]]], + class_interner: Dict[str, int], + batch_size: int, + decoding_docs: Dict[str, List[Any]], + use_cose_hack: bool=False) -> dict: + """Identifies and then classifies evidence + + Args: + evidence_token_identifier: a module for identifying evidence statements + evidence_classifier: a module for making a classification based on evidence statements + train: A List of interned Annotations + val: A List of interned Annotations + test: A List of interned Annotations + docs: A Dict of Documents, which are interned sentences. + class_interner: Converts an Annotation's final class into ints + batch_size: how big should our batches be? + """ + device = None + class_labels = [k for k, v in sorted(class_interner.items(), key=lambda x: x[1])] + if decoding_docs is None: + decoding_docs = docs + + def prep(data: List[Annotation]) -> List[Tuple[SentenceEvidence, SentenceEvidence]]: + """Prepares data for evidence identification and classification. + + Creates paired evaluation data, wherein each (annotation, docid, sentence, kls) + tuplet appears first as the kls determining if the sentence is evidence, and + secondarily what the overall classification for the (annotation/docid) pair is. + This allows selection based on model scores of the evidence_token_identifier for + input to the evidence_classifier. + """ + #identification_data = annotations_to_evidence_identification(data, docs) + classification_data = token_annotations_to_evidence_classification(data, docs, class_interner) + # annotation id -> docid -> [SentenceEvidence]) + identification_data = annotations_to_evidence_token_identification(data, + source_documents=decoding_docs, + interned_documents=docs, + token_mapping=token_mapping) + ann_doc_sents = defaultdict(lambda: defaultdict(dict)) # ann id -> docid -> sent idx -> sent data + ret = [] + for sent_ev in classification_data: + id_data = identification_data[sent_ev.ann_id][sent_ev.docid][sent_ev.index] + ret.append((id_data, sent_ev)) + assert id_data.ann_id == sent_ev.ann_id + assert id_data.docid == sent_ev.docid + #assert id_data.index == sent_ev.index + assert len(ret) == len(classification_data) + return ret + + def decode_batch(data: List[Tuple[SentenceEvidence, SentenceEvidence]], name: str, score: bool = False, + annotations: List[Annotation] = None, class_labels: dict=class_labels) -> dict: + """Identifies evidence statements and then makes classifications based on it. + + Args: + data: a paired list of SentenceEvidences, differing only in the kls field. + The first corresponds to whether or not something is evidence, and the second corresponds to an evidence class + name: a name for a results dict + """ + + num_uniques = len(set((x.ann_id, x.docid) for x, _ in data)) + logging.info(f'Decoding dataset {name} with {len(data)} sentences, {num_uniques} annotations') + identifier_data, classifier_data = zip(*data) + results = dict() + with torch.no_grad(): + # make predictions for the evidence_token_identifier + evidence_token_identifier.eval() + evidence_classifier.eval() + + _, soft_identification_preds, hard_identification_preds, id_preds_truth = make_token_preds_epoch(evidence_token_identifier, + identifier_data, + token_mapping, + batch_size, + device, + tensorize_model_inputs=True) + assert len(soft_identification_preds) == len(data) + evidence_only_cls = [] + for id_data, cls_data, soft_id_pred, hard_id_pred in zip(identifier_data, + classifier_data, + soft_identification_preds, + hard_identification_preds): + assert cls_data.ann_id == id_data.ann_id + sent = [] + for (start, end) in token_mapping[cls_data.docid][0]: + if bool(hard_id_pred[start]): + sent.extend(id_data.sentence[start:end]) + #assert len(sent) > 0 + new_cls_data = SentenceEvidence(cls_data.kls, + cls_data.ann_id, + cls_data.query, + cls_data.docid, + cls_data.index, + tuple(sent)) + evidence_only_cls.append(new_cls_data) + _, soft_classification_preds, hard_classification_preds, classification_truth = make_preds_epoch( + evidence_classifier, evidence_only_cls, batch_size, device, + tensorize_model_inputs=True) + + if use_cose_hack: + logging.info('Reformatting identification and classification results to fit COS-E') + grouping = 5 + new_soft_identification_preds = [] + new_hard_identification_preds = [] + new_id_preds_truth = [] + new_soft_classification_preds = [] + new_hard_classification_preds = [] + new_classification_truth = [] + new_identifier_data = [] + class_labels = [] + + # TODO fix the labels for COS-E + for i in range(0, len(soft_identification_preds), grouping): + cls_scores = torch.stack(soft_classification_preds[i:i + grouping]) + cls_scores = nn.functional.softmax(cls_scores, dim=-1) + cls_scores = cls_scores[:,1] + choice = torch.argmax(cls_scores) + cls_labels = [x.ann_id.split('_')[-1] for x in evidence_only_cls[i:i + grouping]] + class_labels = cls_labels # we need to update the class labels because of the terrible hackery used to train this + cls_truths = [x.kls for x in evidence_only_cls[i:i + grouping]] + #cls_choice = evidence_only_cls[i + choice].ann_id.split('_')[-1] + cls_truth = np.argmax(cls_truths) + new_soft_identification_preds.append(soft_identification_preds[i + choice]) + new_hard_identification_preds.append(hard_identification_preds[i + choice]) + new_id_preds_truth.append(id_preds_truth[i + choice]) + new_soft_classification_preds.append(soft_classification_preds[i + choice]) + new_hard_classification_preds.append(choice) + new_identifier_data.append(identifier_data[i + choice]) + #new_hard_classification_preds.append(hard_classification_preds[i + choice]) + #new_classification_truth.append(classification_truth[i + choice]) + new_classification_truth.append(cls_truth) + + soft_identification_preds = new_soft_identification_preds + hard_identification_preds = new_hard_identification_preds + id_preds_truth = new_id_preds_truth + soft_classification_preds = new_soft_classification_preds + hard_classification_preds = new_hard_classification_preds + classification_truth = new_classification_truth + identifier_data = new_identifier_data + if score: + results[f'{name}_f1'] = classification_report(classification_truth, hard_classification_preds, + target_names=class_labels, output_dict=True) + results[f'{name}_acc'] = accuracy_score(classification_truth, hard_classification_preds) + results[f'{name}_token_pred_acc'] = accuracy_score(list(chain.from_iterable(id_preds_truth)), + list(chain.from_iterable(hard_identification_preds))) + results[f'{name}_token_pred_f1'] = classification_report(list(chain.from_iterable(id_preds_truth)), + list(chain.from_iterable(hard_identification_preds)), + output_dict=True) + # TODO for token level stuff! + soft_id_scores = [[1-x, x] for x in chain.from_iterable(soft_identification_preds)] + results[f'{name}_rationale'] = score_rationales(annotations, + decoding_docs, + identifier_data, + soft_id_scores) + logging.info(f'Results: {results}') + + # turn the above results into a format suitable for scoring via the rationale scorer + # n.b. the sentence-level evidence predictions (hard and soft) are + # broadcast to the token level for scoring. The comprehensiveness class + # score is also a lie since the pipeline model above is faithful by + # design. + decoded = dict() + scores = [] + assert len(identifier_data) == len(soft_identification_preds) + for id_data, soft_id_pred, hard_id_pred, soft_cls_preds, hard_cls_pred in zip(identifier_data, + soft_identification_preds, + hard_identification_preds, + soft_classification_preds, + hard_classification_preds): + docid = id_data.docid + if use_cose_hack: + docid = '_'.join(docid.split('_')[0:-1]) + assert len(docid) > 0 + rationales = { + "docid": docid, + "hard_rationale_predictions": [], + # token level classifications, a value must be provided per-token + # in an ideal world, these correspond to the hard-decoding above. + "soft_rationale_predictions": [], + # sentence level classifications, a value must be provided for every + # sentence in each document, or not at all + "soft_sentence_predictions": [1.0] + } + last = -1 + start_span = -1 + for pos, (start, _) in enumerate(token_mapping[id_data.docid][0]): + rationales['soft_rationale_predictions'].append(soft_id_pred[start]) + if bool(hard_id_pred[start]): + if start_span == -1: + start_span = pos + last = pos + else: + if start_span != -1: + rationales['hard_rationale_predictions'].append({ + "start_token": start_span, + "end_token": last + 1, + }) + last = -1 + start_span = -1 + if start_span != -1: + rationales['hard_rationale_predictions'].append({ + "start_token": start_span, + "end_token": last + 1, + }) + + ann_id = id_data.ann_id + if use_cose_hack: + ann_id = '_'.join(ann_id.split('_')[0:-1]) + soft_cls_preds = nn.functional.softmax(soft_cls_preds) + decoded[id_data.ann_id] = { + "annotation_id": ann_id, + "rationales": [rationales], + "classification": class_labels[hard_cls_pred], + "classification_scores": {class_labels[i]:score.item() for i,score in enumerate(soft_cls_preds)} + } + return results, list(decoded.values()) + + #test_results, test_decoded = dict(), [] + #val_results, val_decoded = dict(), [] + train_results, train_decoded = dict(), [] + val_results, val_decoded = decode_batch(prep(val), 'val', score=True, annotations=val, class_labels=class_labels) + test_results, test_decoded = decode_batch(prep(test), 'test', score=False, class_labels=class_labels) + #train_results, train_decoded = decode_batch(prep(train), 'train', score=True, annotations=train, class_labels=class_labels) + return dict(**train_results, **val_results, **test_results), train_decoded, val_decoded, test_decoded diff --git a/scripts/eraserbenchmark/rationale_benchmark/models/sequence_taggers.py b/scripts/eraserbenchmark/rationale_benchmark/models/sequence_taggers.py new file mode 100644 index 0000000..fbf14d1 --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/models/sequence_taggers.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn + +from typing import List, Tuple, Any + +from transformers import BertModel + +from rationale_benchmark.models.model_utils import PaddedSequence + + +class BertTagger(nn.Module): + def __init__(self, + bert_dir: str, + pad_token_id: int, + cls_token_id: int, + sep_token_id: int, + max_length: int=512, + use_half_precision=True): + super(BertTagger, self).__init__() + self.sep_token_id = sep_token_id + self.cls_token_id = cls_token_id + self.pad_token_id = pad_token_id + self.max_length = max_length + bert = BertModel.from_pretrained(bert_dir) + if use_half_precision: + import apex + bert = bert.half() + self.bert = bert + self.relevance_tagger = nn.Sequential( + nn.Linear(self.bert.config.hidden_size, 1), + nn.Sigmoid() + ) + + def forward(self, + query: List[torch.tensor], + docids: List[Any], + document_batch: List[torch.tensor], + aggregate_spans: List[Tuple[int, int]]): + assert len(query) == len(document_batch) + # note about device management: since distributed training is enabled, the inputs to this module can be on + # *any* device (preferably cpu, since we wrap and unwrap the module) we want to keep these params on the + # input device (assuming CPU) for as long as possible for cheap memory access + target_device = next(self.parameters()).device + #cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device) + sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device) + input_tensors = [] + query_lengths = [] + for q, d in zip(query, document_batch): + if len(q) + len(d) + 1 > self.max_length: + d = d[:(self.max_length - len(q) - 1)] + input_tensors.append(torch.cat([q, sep_token, d])) + query_lengths.append(q.size()[0]) + bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, device=target_device) + outputs = self.bert(bert_input.data, attention_mask=bert_input.mask(on=0.0, off=float('-inf'), dtype=torch.float, device=target_device)) + hidden = outputs[0] + classes = self.relevance_tagger(hidden) + ret = [] + for ql, cls, doc in zip(query_lengths, classes, document_batch): + start = ql + 1 + end = start + len(doc) + ret.append(cls[ql + 1:end]) + return PaddedSequence.autopad(ret, batch_first=True, padding_value=0, device=target_device).data.squeeze(dim=-1) diff --git a/scripts/eraserbenchmark/rationale_benchmark/utils.py b/scripts/eraserbenchmark/rationale_benchmark/utils.py new file mode 100644 index 0000000..bf48474 --- /dev/null +++ b/scripts/eraserbenchmark/rationale_benchmark/utils.py @@ -0,0 +1,226 @@ +import json +import os + +from dataclasses import dataclass, asdict, is_dataclass +from itertools import chain +from typing import Dict, List, Set, Tuple, Union, FrozenSet + + +@dataclass(eq=True, frozen=True) +class Evidence: + """ + (docid, start_token, end_token) form the only official Evidence; sentence level annotations are for convenience. + Args: + text: Some representation of the evidence text + docid: Some identifier for the document + start_token: The canonical start token, inclusive + end_token: The canonical end token, exclusive + start_sentence: Best guess start sentence, inclusive + end_sentence: Best guess end sentence, exclusive + """ + text: Union[str, Tuple[int], Tuple[str]] + docid: str + start_token: int = -1 + end_token: int = -1 + start_sentence: int = -1 + end_sentence: int = -1 + + +@dataclass(eq=True, frozen=True) +class Annotation: + """ + Args: + annotation_id: unique ID for this annotation element + query: some representation of a query string + evidences: a set of "evidence groups". + Each evidence group is: + * sufficient to respond to the query (or justify an answer) + * composed of one or more Evidences + * may have multiple documents in it (depending on the dataset) + - e-snli has multiple documents + - other datasets do not + classification: str + query_type: Optional str, additional information about the query + docids: a set of docids in which one may find evidence. + """ + annotation_id: str + query: Union[str, Tuple[int]] + evidences: Union[Set[Tuple[Evidence]], FrozenSet[Tuple[Evidence]]] + classification: str + query_type: str = None + docids: Set[str] = None + + def all_evidences(self) -> Tuple[Evidence]: + return tuple(list(chain.from_iterable(self.evidences))) + + +def annotations_to_jsonl(annotations, output_file): + with open(output_file, 'w') as of: + for ann in sorted(annotations, key=lambda x: x.annotation_id): + as_json = _annotation_to_dict(ann) + as_str = json.dumps(as_json, sort_keys=True) + of.write(as_str) + of.write('\n') + + +def _annotation_to_dict(dc): + # convenience method + if is_dataclass(dc): + d = asdict(dc) + ret = dict() + for k, v in d.items(): + ret[k] = _annotation_to_dict(v) + return ret + elif isinstance(dc, dict): + ret = dict() + for k, v in dc.items(): + k = _annotation_to_dict(k) + v = _annotation_to_dict(v) + ret[k] = v + return ret + elif isinstance(dc, str): + return dc + elif isinstance(dc, (set, frozenset, list, tuple)): + ret = [] + for x in dc: + ret.append(_annotation_to_dict(x)) + return tuple(ret) + else: + return dc + + +def load_jsonl(fp: str) -> List[dict]: + ret = [] + with open(fp, 'r') as inf: + for line in inf: + content = json.loads(line) + ret.append(content) + return ret + + +def write_jsonl(jsonl, output_file): + with open(output_file, 'w') as of: + for js in jsonl: + as_str = json.dumps(js, sort_keys=True) + of.write(as_str) + of.write('\n') + + +def annotations_from_jsonl(fp: str) -> List[Annotation]: + ret = [] + with open(fp, 'r') as inf: + for line in inf: + content = json.loads(line) + ev_groups = [] + for ev_group in content['evidences']: + ev_group = tuple([Evidence(**ev) for ev in ev_group]) + ev_groups.append(ev_group) + content['evidences'] = frozenset(ev_groups) + ret.append(Annotation(**content)) + return ret + + +def load_datasets(data_dir: str) -> Tuple[List[Annotation], List[Annotation], List[Annotation]]: + """Loads a training, validation, and test dataset + + Each dataset is assumed to have been serialized by annotations_to_jsonl, + that is it is a list of json-serialized Annotation instances. + """ + train_data = annotations_from_jsonl(os.path.join(data_dir, 'train.jsonl')) + val_data = annotations_from_jsonl(os.path.join(data_dir, 'val.jsonl')) + test_data = annotations_from_jsonl(os.path.join(data_dir, 'test.jsonl')) + return train_data, val_data, test_data + + +def load_documents(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]: + """Loads a subset of available documents from disk. + + Each document is assumed to be serialized as newline ('\n') separated sentences. + Each sentence is assumed to be space (' ') joined tokens. + """ + if os.path.exists(os.path.join(data_dir, 'docs.jsonl')): + assert not os.path.exists(os.path.join(data_dir, 'docs')) + return load_documents_from_file(data_dir, docids) + + docs_dir = os.path.join(data_dir, 'docs') + res = dict() + if docids is None: + docids = sorted(os.listdir(docs_dir)) + else: + docids = sorted(set(str(d) for d in docids)) + for d in docids: + with open(os.path.join(docs_dir, d), 'r', errors="ignore") as inf: + lines = [l.strip() for l in inf.readlines()] + lines = list(filter(lambda x: bool(len(x)), lines)) + tokenized = [list(filter(lambda x: bool(len(x)), line.strip().split(' '))) for line in lines] + res[d] = tokenized + return res + + +def load_flattened_documents(data_dir: str, docids: Set[str]) -> Dict[str, List[str]]: + """Loads a subset of available documents from disk. + + Returns a tokenized version of the document. + """ + unflattened_docs = load_documents(data_dir, docids) + flattened_docs = dict() + for doc, unflattened in unflattened_docs.items(): + flattened_docs[doc] = list(chain.from_iterable(unflattened)) + return flattened_docs + + +def intern_documents(documents: Dict[str, List[List[str]]], word_interner: Dict[str, int], unk_token: str): + """ + Replaces every word with its index in an embeddings file. + + If a word is not found, uses the unk_token instead + """ + ret = dict() + unk = word_interner[unk_token] + for docid, sentences in documents.items(): + ret[docid] = [[word_interner.get(w, unk) for w in s] for s in sentences] + return ret + + +def intern_annotations(annotations: List[Annotation], word_interner: Dict[str, int], unk_token: str): + ret = [] + for ann in annotations: + ev_groups = [] + for ev_group in ann.evidences: + evs = [] + for ev in ev_group: + evs.append(Evidence( + text=tuple([word_interner.get(t, word_interner[unk_token]) for t in ev.text.split()]), + docid=ev.docid, + start_token=ev.start_token, + end_token=ev.end_token, + start_sentence=ev.start_sentence, + end_sentence=ev.end_sentence)) + ev_groups.append(tuple(evs)) + ret.append(Annotation(annotation_id=ann.annotation_id, + query=tuple([word_interner.get(t, word_interner[unk_token]) for t in ann.query.split()]), + evidences=frozenset(ev_groups), + classification=ann.classification, + query_type=ann.query_type)) + return ret + + +def load_documents_from_file(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]: + """Loads a subset of available documents from 'docs.jsonl' file on disk. + + Each document is assumed to be serialized as newline ('\n') separated sentences. + Each sentence is assumed to be space (' ') joined tokens. + """ + docs_file = os.path.join(data_dir, 'docs.jsonl') + documents = load_jsonl(docs_file) + documents = {doc['docid']: doc['document'] for doc in documents} + res = dict() + if docids is None: + docids = sorted(list(documents.keys())) + else: + docids = sorted(set(str(d) for d in docids)) + for d in docids: + lines = documents[d].split('\n') + tokenized = [line.strip().split(' ') for line in lines] + res[d] = tokenized + return res diff --git a/scripts/eraserbenchmark/requirements.txt b/scripts/eraserbenchmark/requirements.txt new file mode 100644 index 0000000..db7eeb8 --- /dev/null +++ b/scripts/eraserbenchmark/requirements.txt @@ -0,0 +1,12 @@ +ftfy==5.5.1 +gensim==3.7.1 +numpy==1.16.2 +pandas==0.24.2 +pytorch-transformers==1.1.0 +scipy==1.2.1 +scispacy==0.2.3 +scikit-learn==0.20.3 +spacy==2.1.8 +tensorflow-gpu==1.14 +torch==1.3.0 +tqdm==4.31.1 diff --git a/scripts/evaluate_hatexplain.py b/scripts/evaluate_hatexplain.py index b8c05b3..d54b2d5 100644 --- a/scripts/evaluate_hatexplain.py +++ b/scripts/evaluate_hatexplain.py @@ -8,8 +8,11 @@ from xpotato.graph_extractor.extract import FeatureEvaluator from xpotato.dataset.explainable_dataset import ExplainableDataset +from hatexplain_to_eraser import data_tsv_to_eraser, prediction_to_eraser +from call_eraser import call_eraser def print_classification_report(df: DataFrame, stats: Dict[str, List]): + #print([(n > 0) * 1 for n in np.sum([p for p in stats["Predicted"]], axis=0)]) print( classification_report( df.label_id, @@ -66,22 +69,27 @@ def find_good_features( valid_files = [] evaluate(feature_file=save_features, files=train_files + valid_files, target=target) - def evaluate(feature_file: str, files: List[str], target: str): with open(feature_file) as feature_json: features = json.load(feature_json) evaluator = FeatureEvaluator() if target is None: target = list(features.keys())[0] - + for file in files: - print(f"File: {file}") + #print(f"File: {file}") potato = ExplainableDataset(path=file, label_vocab={"None": 0, target: 1}) df = potato.to_dataframe() stats = evaluator.evaluate_feature(target, features[target], df)[0] print_classification_report(df, stats) print("------------------------") - + matched_result = evaluator.match_features(df, features[target]) + subgraphs = matched_result["Matched rule"] + labels = matched_result["Predicted label"] + data_tsv_to_eraser(file) + prediction_to_eraser(file, subgraphs, labels, labels, labels, target) + call_eraser("./hatexplain", "train", "./hatexplain/train_prediction.jsonl") + print("------------------------") if __name__ == "__main__": argparser = ArgumentParser() diff --git a/scripts/hatexplain_to_eraser.py b/scripts/hatexplain_to_eraser.py new file mode 100644 index 0000000..e072570 --- /dev/null +++ b/scripts/hatexplain_to_eraser.py @@ -0,0 +1,247 @@ +import json +import os +import csv + +import more_itertools as mit + +save_path='hatexplain/' + +if not os.path.exists(save_path+'docs'): + os.makedirs(save_path+'docs') + +###https://github.com/hate-alert/HateXplain/blob/master/Explainability_Calculation_NB.ipynb +# https://stackoverflow.com/questions/2154249/identify-groups-of-continuous-numbers-in-a-list +def find_ranges(iterable): + """Yield range of consecutive numbers.""" + for group in mit.consecutive_groups(iterable): + group = list(group) + if len(group) == 1: + yield group[0] + else: + yield group[0], group[-1] + + +# used to convert our rationale list to the explanations mask +def rationale_to_explanations(text, rationale_list): + explanations = [] + #print(rationale_list) + # make sure it is a list + if(not isinstance(rationale_list, list)): + #rationale_list = [rationale_list] + print(rationale_list, " is not a list") + for i, word in enumerate(text.split()): + if word in rationale_list: + explanations.append(1) + else: + explanations.append(0) + #print(explanations) + return explanations + +def line_to_list(line): + str = line + str = str[1:-1] + str = str.replace('\'', '') + strlist = str.split(", ") + return strlist + +# Convert dataset into ERASER format: https://github.com/jayded/eraserbenchmark/blob/master/rationale_benchmark/utils.py +def get_evidence(post_id, text, rationale_list): + output = [] + #explanations = [1,1,0,0,1,1] + explanations = rationale_to_explanations(text, rationale_list) + anno_text = text.split() + indexes = sorted([i for i, each in enumerate(explanations) if each==1]) + span_list = list(find_ranges(indexes)) + #print(([i for i, each in enumerate(explanations) if each==1])) + for each in span_list: + if type(each)== int: + start = each + end = each+1 + elif len(each) == 2: + start = each[0] + end = each[1]+1 + else: + print('error') + + output.append({"docid":post_id, + "end_sentence": -1, + "end_token": end, + "start_sentence": -1, + "start_token": start, + "text": ' '.join([str(x) for x in anno_text[start:end]])}) + #if(len(output)>0): + return output + #else: + #return empty_evidence(post_id) +### + +def empty_evidence(post_id): + output = [] + output.append({"docid":post_id, + "end_sentence": -1, + "end_token": -1, + "start_sentence": -1, + "start_token": -1, + "text": ' '}) + return output + +def data_tsv_to_eraser(tsvfile): + train_tsv = open(tsvfile, encoding="utf8") + read_tsv = csv.reader(train_tsv, delimiter="\t") + newname = tsvfile.replace(".tsv", "") + write_eraser = open(save_path+newname+'.jsonl', 'w') + # 0 text + # 1 label + # 2 label_id + # 3 rationale + # 4 graph + data_set = [] + id = 2 + skip_first = False + for row in read_tsv: + if(skip_first == False): + skip_first = True + continue + + if(row[1] == "None"): # cut out GT None + id = id+1 # do not forget that + continue # rip None labels ;( + + entry = {} + #print(row[2]) + id_string = "tsv_line_"+str(id)+".txt" + write_doc = open(save_path+"docs/"+id_string, 'w', encoding="utf8") + write_doc.write(row[0]); + write_doc.close() + entry['annotation_id'] = id_string + entry['classification'] = row[1] + entry['docids'] = 'null' + entry['evidences'] = [get_evidence(id_string, row[0], line_to_list(row[3]))] + entry['query'] = "What is the class?" + entry['query_type'] = None + data_set.append(entry) + + #if(id == 350): + if(False): + print(row[0], row[3]) + print(rationale_to_explanations(row[0], line_to_list(row[3]))) + print(get_evidence(id_string, row[0], line_to_list(row[3]))) + + write_eraser.write(json.dumps(entry)+'\n') + id = id+1 + + train_tsv.close() + write_eraser.close() + +from xpotato.dataset.utils import default_pn_to_graph + +def penman_to_nodenames(penman): + g = default_pn_to_graph(penman)[0] + return list([(x[1]['name']) for x in g.nodes(data=True)]) + +# datatsvfile: the tsv file +# subgraphs: the graphs +# labels: m(xi)j +# labelswithoutr: m(xi\ri)j +# labelsonlyr: m(ri)j +# target: the target class +def prediction_to_eraser(datatsvfile, subgraphs, labels, labelswithoutr, labelsonlyr, target): + train_tsv = open(datatsvfile, encoding="utf8") + read_tsv = csv.reader(train_tsv, delimiter="\t") + newname = datatsvfile.replace(".tsv", "") + write_eraser = open(save_path+newname+'_prediction.jsonl', 'w') + # 0 text + # 1 label + # 2 label_id + # 3 rationale + # 4 graph + #print(subgraphs) + data_set = [] + # same id as data_tsv to get docs/ filenames right + id = 2 + skip_first = False + for row in read_tsv: + if(skip_first == False): + skip_first = True + continue + + subgraph = subgraphs[id-2] + label = labels[id-2] + labelwithoutr = labelswithoutr[id-2] + labelonlyr = labelsonlyr[id-2] + + if(row[1] == "None"): # cut out GT "None" + id = id+1 # do not forget that + continue # rip None labels ;( + + if(not label): + label = "None" # might be a problem with porting + predicted_rationales = [] + if(subgraph): #check if a rule matched (else predicted rationales are empty list) + predicted_rationales = penman_to_nodenames(subgraph[0][0]) # get a list + + entry = {} + #print(row[2]) + id_string = "tsv_line_"+str(id)+".txt" + write_doc = open(save_path+"docs/"+id_string, 'w', encoding="utf8") + write_doc.write(row[0]); + write_doc.close() + entry['annotation_id'] = id_string + entry['classification'] = label + + # m_xi = m(xi)j + m_xi = 0 + if(target == label): + m_xi = 1 + # m_xi_minus_ri = m(xi\ri)j + m_xi_minus_ri = 0 + if(target == labelwithoutr): + m_xi_minus_ri = 1 + # m_ri = m(ri)j + m_ri = 0 + if(target == labelonlyr): + m_ri = 1 + + # normal classification: P(target) = m(xi)j + ptarget = m_xi + classification_scores_dict = { + target: ptarget, + "None": 1-ptarget + } + entry['classification_scores'] = classification_scores_dict + # comprehensiveness = m(xi)j - m(xi\ri)j + pcomprehensiveness = m_xi - m_xi_minus_ri + comprehensiveness_classification_scores_dict = { + target: pcomprehensiveness, + "None": 1-pcomprehensiveness + } + entry['comprehensiveness_classification_scores'] = comprehensiveness_classification_scores_dict + # sufficiency = m(xi)j - m(ri)j + psufficiency = m_xi - m_ri + sufficiency_classification_scores_dict = { + target: psufficiency, + "None": 1-psufficiency + } + entry['sufficiency_classification_scores'] = sufficiency_classification_scores_dict + + entry['docids'] = 'null' + rationales_dict = { + "docid": id_string, + "hard_rationale_predictions": get_evidence(id_string, row[0], predicted_rationales) + } + entry['rationales'] = [rationales_dict] + entry['query'] = "What is the class?" + entry['query_type'] = None + data_set.append(entry) + + #if(id == 350): + if(False): + print(row[0], row[3]) + print(rationale_to_explanations(row[0], line_to_list(row[3]))) + print(get_evidence(id_string, row[0], line_to_list(row[3]))) + + write_eraser.write(json.dumps(entry)+'\n') + id = id+1 + + train_tsv.close() + write_eraser.close() \ No newline at end of file diff --git a/scripts/read_hatexplain.py b/scripts/read_hatexplain.py index e81d744..559fa7e 100644 --- a/scripts/read_hatexplain.py +++ b/scripts/read_hatexplain.py @@ -204,7 +204,7 @@ def process( args = argparser.parse_args() if args.mode != "distinct" and args.target is None: - raise ArgumentError( + raise ArgumentError(None, "Target is not given! If you want to produce a POTATO dataset " "(by running this code in process or both mode), you should specify the target." ) @@ -216,7 +216,7 @@ def process( else os.path.join(args.data_path, "dataset.json") ) if not os.path.isfile(dataset): - raise ArgumentError( + raise ArgumentError(None, "The specified data path is not a file and does not contain a dataset.json file. " "If your file has a different name, please specify." ) From 3ee4a3de84562deed218fc53278ee95f65ae2690 Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Mon, 13 Jun 2022 15:13:32 +0200 Subject: [PATCH 03/12] multi label matching 1 --- scripts/call_eraser.py | 23 ++++++++------- scripts/evaluate_hatexplain.py | 10 ++++--- scripts/hatexplain_to_eraser.py | 51 ++++++++++++++++++++++++++++----- 3 files changed, 62 insertions(+), 22 deletions(-) diff --git a/scripts/call_eraser.py b/scripts/call_eraser.py index 8ab40c0..42f4f99 100644 --- a/scripts/call_eraser.py +++ b/scripts/call_eraser.py @@ -30,16 +30,17 @@ def call_eraser(datadir, testtrainorval, pathtopredictions, silent=False): logger = logging.getLogger() logger.disabled = True #dir(rationale_benchmark.metrics) - eraser.runEvaluation(data_dir=datadir, #data dir - split=testtrainorval, # split - results=pathtopredictions, # results - score_file="eraser_output.json", # score - strict=False) # strict - #iou_thresholds=[0.5], # iou - #aopc_thresholds=[0.01, 0.05, 0.1, 0.2, 0.5]) # aopc + eraser.runEvaluation("None", # neutralclassname + data_dir=datadir, # data dir + split=testtrainorval, # split + results=pathtopredictions, # results + score_file=datadir+"/eraser_output.json", # score + strict=False) # strict + #iou_thresholds=[0.5], # iou + #aopc_thresholds=[0.01, 0.05, 0.1, 0.2, 0.5]) # aopc if silent: logger.disabled = False - print_eraser_results() + print_eraser_results(datadir) """ def call_eraser(datadir, testtrainorval, pathtopredictions): @@ -62,10 +63,10 @@ def call_eraser(datadir, testtrainorval, pathtopredictions): print_eraser_results() """ -def print_eraser_results(): +def print_eraser_results(datadir): # print the required results import json - with open('./eraser_output.json') as fp: + with open(datadir+'/eraser_output.json') as fp: output_data = json.load(fp) print('\nPlausibility') @@ -84,4 +85,4 @@ def print_eraser_results(): print('--') if __name__ == "__main__": - call_eraser("./hatexplain", "train", "./hatexplain/train_prediction.jsonl") \ No newline at end of file + call_eraser("./hatexplain", "val", "./hatexplain/val_prediction.jsonl") \ No newline at end of file diff --git a/scripts/evaluate_hatexplain.py b/scripts/evaluate_hatexplain.py index d54b2d5..072bc25 100644 --- a/scripts/evaluate_hatexplain.py +++ b/scripts/evaluate_hatexplain.py @@ -8,7 +8,7 @@ from xpotato.graph_extractor.extract import FeatureEvaluator from xpotato.dataset.explainable_dataset import ExplainableDataset -from hatexplain_to_eraser import data_tsv_to_eraser, prediction_to_eraser +from hatexplain_to_eraser import data_tsv_to_eraser, prediction_to_eraser, get_rationales from call_eraser import call_eraser def print_classification_report(df: DataFrame, stats: Dict[str, List]): @@ -83,12 +83,14 @@ def evaluate(feature_file: str, files: List[str], target: str): stats = evaluator.evaluate_feature(target, features[target], df)[0] print_classification_report(df, stats) print("------------------------") - matched_result = evaluator.match_features(df, features[target]) - subgraphs = matched_result["Matched rule"] + matched_result = evaluator.match_features(df, features[target], multi=True, return_subgraphs=True) + #subgraphs = matched_result["Matched rule"] + subgraphs = matched_result["Matched subgraph"] + #matched_result.to_csv("this_wow_x.csv") labels = matched_result["Predicted label"] data_tsv_to_eraser(file) prediction_to_eraser(file, subgraphs, labels, labels, labels, target) - call_eraser("./hatexplain", "train", "./hatexplain/train_prediction.jsonl") + call_eraser("./hatexplain", "val", "./hatexplain/val_prediction.jsonl") print("------------------------") if __name__ == "__main__": diff --git a/scripts/hatexplain_to_eraser.py b/scripts/hatexplain_to_eraser.py index e072570..b0021f2 100644 --- a/scripts/hatexplain_to_eraser.py +++ b/scripts/hatexplain_to_eraser.py @@ -4,6 +4,8 @@ import more_itertools as mit +from tuw_nlp.graph.utils import graph_to_pn + save_path='hatexplain/' if not os.path.exists(save_path+'docs'): @@ -139,6 +141,36 @@ def penman_to_nodenames(penman): g = default_pn_to_graph(penman)[0] return list([(x[1]['name']) for x in g.nodes(data=True)]) +#get rationales from subgraph prediction +def get_rationales(datatsvfile, subgraphs): + train_tsv = open(datatsvfile, encoding="utf8") + read_tsv = csv.reader(train_tsv, delimiter="\t") + # 0 text + # 1 label + # 2 label_id + # 3 rationale + # 4 graph + + # same id as data_tsv to get docs/ filenames right + id = 2 + skip_first = False + for row in read_tsv: + if(skip_first == False): + skip_first = True + continue + + subgraphlist = subgraphs[id-2] + + if(row[1] == "None"): # cut out GT "None" + id = id+1 # do not forget that + continue # rip None labels ;( + + predicted_rationales = [] + if(subgraphlist): #check if a rule matched (else predicted rationales are empty list) + for subgraph in subgraphlist[0]: + #predicted_rationales.append(penman_to_nodenames(subgraph[0][0])[0]) # get a list + predicted_rationales.append(penman_to_nodenames(graph_to_pn(subgraph))[0]) # get a list + # datatsvfile: the tsv file # subgraphs: the graphs # labels: m(xi)j @@ -164,9 +196,9 @@ def prediction_to_eraser(datatsvfile, subgraphs, labels, labelswithoutr, labelso if(skip_first == False): skip_first = True continue - - subgraph = subgraphs[id-2] - label = labels[id-2] + + subgraphlist = subgraphs[id-2] + labellist = labels[id-2] labelwithoutr = labelswithoutr[id-2] labelonlyr = labelsonlyr[id-2] @@ -174,12 +206,18 @@ def prediction_to_eraser(datatsvfile, subgraphs, labels, labelswithoutr, labelso id = id+1 # do not forget that continue # rip None labels ;( + label = "" + if(labellist): + label = labellist[0] # first label if(not label): label = "None" # might be a problem with porting + predicted_rationales = [] - if(subgraph): #check if a rule matched (else predicted rationales are empty list) - predicted_rationales = penman_to_nodenames(subgraph[0][0]) # get a list - + if(subgraphlist): #check if a rule matched (else predicted rationales are empty list) + for subgraph in subgraphlist[0]: + #predicted_rationales.append(penman_to_nodenames(subgraph[0][0])[0]) # get a list + predicted_rationales.append(penman_to_nodenames(graph_to_pn(subgraph))[0]) # get a list + print(predicted_rationales) entry = {} #print(row[2]) id_string = "tsv_line_"+str(id)+".txt" @@ -188,7 +226,6 @@ def prediction_to_eraser(datatsvfile, subgraphs, labels, labelswithoutr, labelso write_doc.close() entry['annotation_id'] = id_string entry['classification'] = label - # m_xi = m(xi)j m_xi = 0 if(target == label): From 5f772daf3a4e265ee1a37ed57885fc6454b46055 Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Wed, 15 Jun 2022 23:58:05 +0200 Subject: [PATCH 04/12] Add faithfulness, multi-rule rationale prediction --- scripts/call_eraser.py | 23 +++--- .../rationale_benchmark/metrics.py | 11 +-- scripts/evaluate_hatexplain.py | 58 +++++++++++-- scripts/hatexplain_to_eraser.py | 82 +++++++++++-------- 4 files changed, 119 insertions(+), 55 deletions(-) diff --git a/scripts/call_eraser.py b/scripts/call_eraser.py index 42f4f99..d9fa0e8 100644 --- a/scripts/call_eraser.py +++ b/scripts/call_eraser.py @@ -20,7 +20,7 @@ def nostdout(): #--results : The location of the model output file in eraser format #--score_file : The file name and location to write the output -def call_eraser(datadir, testtrainorval, pathtopredictions, silent=False): +def call_eraser(neutralclassname, datadir, testtrainorval, pathtopredictions, silent=False): import sys, os pkgpath = os.getcwd()+"\\eraserbenchmark" print(pkgpath) @@ -30,7 +30,7 @@ def call_eraser(datadir, testtrainorval, pathtopredictions, silent=False): logger = logging.getLogger() logger.disabled = True #dir(rationale_benchmark.metrics) - eraser.runEvaluation("None", # neutralclassname + eraser.runEvaluation(neutralclassname, # neutralclassname data_dir=datadir, # data dir split=testtrainorval, # split results=pathtopredictions, # results @@ -68,21 +68,24 @@ def print_eraser_results(datadir): import json with open(datadir+'/eraser_output.json') as fp: output_data = json.load(fp) - - print('\nPlausibility') + + print("\n------------------------") + + print('Plausibility') if 'iou_scores' in output_data: - print('IOU F1 :', output_data['iou_scores'][0]['macro']['f1']) - print('Token F1 :', output_data['token_prf']['instance_macro']['f1']) + print('IOU F1 :', round(output_data['iou_scores'][0]['macro']['f1'], 3)) + print('Token F1 :', round(output_data['token_prf']['instance_macro']['f1'], 3)) if 'token_soft_metrics' in output_data: - print('AUPRC :', output_data['token_soft_metrics']['auprc']) + print('AUPRC :', round(output_data['token_soft_metrics']['auprc'], 3)) print('\nFaithfulness') if 'classification_scores' in output_data: - print('Comprehensiveness :', output_data['classification_scores']['comprehensiveness']) - print('Sufficiency :', output_data['classification_scores']['sufficiency']) + print('Comprehensiveness :', round(output_data['classification_scores']['comprehensiveness'], 3)) + print('Sufficiency :', round(output_data['classification_scores']['sufficiency'], 3)) else: print('--') + print("") if __name__ == "__main__": - call_eraser("./hatexplain", "val", "./hatexplain/val_prediction.jsonl") \ No newline at end of file + call_eraser("None", "./hatexplain", "val", "./hatexplain/val_prediction.jsonl") \ No newline at end of file diff --git a/scripts/eraserbenchmark/rationale_benchmark/metrics.py b/scripts/eraserbenchmark/rationale_benchmark/metrics.py index f41296d..86ea03a 100644 --- a/scripts/eraserbenchmark/rationale_benchmark/metrics.py +++ b/scripts/eraserbenchmark/rationale_benchmark/metrics.py @@ -278,14 +278,15 @@ def compute_aopc_scores(instances: List[dict], aopc_thresholds: List[float]): aopc_sufficiency_score, aopc_sufficiency_points = _instances_aopc(instances, aopc_thresholds, 'sufficiency_classification_scores') return aopc_thresholds, aopc_comprehensiveness_score, aopc_comprehensiveness_points, aopc_sufficiency_score, aopc_sufficiency_points -def score_classifications(instances: List[dict], annotations: List[Annotation], docs: Dict[str, List[str]], aopc_thresholds: List[float]) -> Dict[str, float]: +def score_classifications(neutralclassname: str, instances: List[dict], annotations: List[Annotation], docs: Dict[str, List[str]], aopc_thresholds: List[float]) -> Dict[str, float]: def compute_kl(cls_scores_, faith_scores_): keys = list(cls_scores_.keys()) cls_scores_ = [cls_scores_[k] for k in keys] faith_scores_ = [faith_scores_[k] for k in keys] return entropy(faith_scores_, cls_scores_) labels = list(set(x.classification for x in annotations)) - labels +=['None'] + if(neutralclassname): + labels +=[neutralclassname] #print("UIUIUIIU") label_to_int = {l:i for i,l in enumerate(labels)} key_to_instances = {inst['annotation_id']:inst for inst in instances} @@ -665,7 +666,7 @@ def main(): if has_final_predictions: flattened_documents = load_flattened_documents(args.data_dir, docids) - class_results = score_classifications(results, annotations, flattened_documents, args.aopc_thresholds) + class_results = score_classifications(None, results, annotations, flattened_documents, args.aopc_thresholds) scores['classification_scores'] = class_results else: logging.info("No classification scores detected, skipping classification") @@ -679,7 +680,7 @@ def main(): json.dump(scores, of, indent=4, sort_keys=True) ### COPY TO RUN FROM PYTHON FILE -def runEvaluation(data_dir, split, results, score_file, strict=False, iou_thresholds=[0.5], aopc_thresholds=[0.01, 0.05, 0.1, 0.2, 0.5]): +def runEvaluation(neutralclassname, data_dir, split, results, score_file, strict=False, iou_thresholds=[0.5], aopc_thresholds=[0.01, 0.05, 0.1, 0.2, 0.5]): #print(results) args = type("args", (object, ), { # data members @@ -743,7 +744,7 @@ def runEvaluation(data_dir, split, results, score_file, strict=False, iou_thresh if has_final_predictions: flattened_documents = load_flattened_documents(args.data_dir, docids) - class_results = score_classifications(results, annotations, flattened_documents, args.aopc_thresholds) + class_results = score_classifications(neutralclassname, results, annotations, flattened_documents, args.aopc_thresholds) scores['classification_scores'] = class_results else: logging.info("No classification scores detected, skipping classification") diff --git a/scripts/evaluate_hatexplain.py b/scripts/evaluate_hatexplain.py index 072bc25..24f3687 100644 --- a/scripts/evaluate_hatexplain.py +++ b/scripts/evaluate_hatexplain.py @@ -2,11 +2,13 @@ import json import numpy as np from pandas import DataFrame +import pandas import logging from argparse import ArgumentParser, ArgumentError from sklearn.metrics import classification_report from xpotato.graph_extractor.extract import FeatureEvaluator from xpotato.dataset.explainable_dataset import ExplainableDataset +from xpotato.dataset.utils import save_dataframe from hatexplain_to_eraser import data_tsv_to_eraser, prediction_to_eraser, get_rationales from call_eraser import call_eraser @@ -69,6 +71,20 @@ def find_good_features( valid_files = [] evaluate(feature_file=save_features, files=train_files + valid_files, target=target) +def remove_rationals(x, rationals): + text = x + for rational in rationals: + text = text.replace(rational, "") + return text + +def concat_rationals(rationals): + if(not rationals): + return "UNK" + text = "" + for rational in rationals: + text = text + rational + " " + return text + def evaluate(feature_file: str, files: List[str], target: str): with open(feature_file) as feature_json: features = json.load(feature_json) @@ -81,16 +97,44 @@ def evaluate(feature_file: str, files: List[str], target: str): potato = ExplainableDataset(path=file, label_vocab={"None": 0, target: 1}) df = potato.to_dataframe() stats = evaluator.evaluate_feature(target, features[target], df)[0] - print_classification_report(df, stats) - print("------------------------") - matched_result = evaluator.match_features(df, features[target], multi=True, return_subgraphs=True) - #subgraphs = matched_result["Matched rule"] + + matched_result = evaluator.match_features(df, features[target], multi=True, return_subgraphs=True, allow_multi_graph=True) + matched_result.to_csv('temp_matched_result.tsv', sep="\t") subgraphs = matched_result["Matched subgraph"] - #matched_result.to_csv("this_wow_x.csv") labels = matched_result["Predicted label"] + rationale_as_text_list = get_rationales(file, subgraphs) + + df_without_rationales = df.copy() + for i in range(df_without_rationales['text'].size): + df_without_rationales['text'][i] = remove_rationals(df_without_rationales['text'][i], rationale_as_text_list[i]) + save_dataframe(df_without_rationales, 'temp_df_without_rationales.tsv') + potato_without_rationales = ExplainableDataset(path='temp_df_without_rationales.tsv', label_vocab={"None": 0, target: 1}) + potato_without_rationales.set_graphs(potato_without_rationales.parse_graphs(graph_format="ud")) + df_without_rationales = potato_without_rationales.to_dataframe() + save_dataframe(df_without_rationales, 'temp_df_without_rationales.tsv') # to view graphs + matched_result = evaluator.match_features(df_without_rationales, features[target], multi=True, return_subgraphs=True, allow_multi_graph=True) + labels_without_rationales = matched_result["Predicted label"] + + df_only_rationales = df.copy() + for i in range(df_only_rationales['text'].size): + df_only_rationales['text'][i] = concat_rationals(rationale_as_text_list[i]) + save_dataframe(df_only_rationales, 'temp_df_only_rationales.tsv') + potato_only_rationales = ExplainableDataset(path='temp_df_only_rationales.tsv', label_vocab={"None": 0, target: 1}) + potato_only_rationales.set_graphs(potato_only_rationales.parse_graphs(graph_format="ud")) + df_only_rationales = potato_only_rationales.to_dataframe() + save_dataframe(df_only_rationales, 'temp_df_only_rationales.tsv') # to view graphs + matched_result = evaluator.match_features(df_only_rationales, features[target], multi=True, return_subgraphs=True, allow_multi_graph=True) + labels_only_rationales = matched_result["Predicted label"] + + #print(labels) + #print(labels_without_rationales) + #print(labels_only_rationales) + data_tsv_to_eraser(file) - prediction_to_eraser(file, subgraphs, labels, labels, labels, target) - call_eraser("./hatexplain", "val", "./hatexplain/val_prediction.jsonl") + prediction_to_eraser(file, rationale_as_text_list, labels, labels_without_rationales, labels_only_rationales, target) + call_eraser("None", "./hatexplain", "val", "./hatexplain/val_prediction.jsonl") + print("------------------------") + print_classification_report(df, stats) print("------------------------") if __name__ == "__main__": diff --git a/scripts/hatexplain_to_eraser.py b/scripts/hatexplain_to_eraser.py index b0021f2..abf4531 100644 --- a/scripts/hatexplain_to_eraser.py +++ b/scripts/hatexplain_to_eraser.py @@ -31,7 +31,7 @@ def rationale_to_explanations(text, rationale_list): if(not isinstance(rationale_list, list)): #rationale_list = [rationale_list] print(rationale_list, " is not a list") - for i, word in enumerate(text.split()): + for i, word in enumerate(text.replace(",", "").split()): if word in rationale_list: explanations.append(1) else: @@ -51,6 +51,8 @@ def get_evidence(post_id, text, rationale_list): output = [] #explanations = [1,1,0,0,1,1] explanations = rationale_to_explanations(text, rationale_list) + #print(rationale_list) + #print("for text " + text + ":= " + str(explanations)) anno_text = text.split() indexes = sorted([i for i, each in enumerate(explanations) if each==1]) span_list = list(find_ranges(indexes)) @@ -108,7 +110,6 @@ def data_tsv_to_eraser(tsvfile): if(row[1] == "None"): # cut out GT None id = id+1 # do not forget that continue # rip None labels ;( - entry = {} #print(row[2]) id_string = "tsv_line_"+str(id)+".txt" @@ -122,7 +123,6 @@ def data_tsv_to_eraser(tsvfile): entry['query'] = "What is the class?" entry['query_type'] = None data_set.append(entry) - #if(id == 350): if(False): print(row[0], row[3]) @@ -150,34 +150,50 @@ def get_rationales(datatsvfile, subgraphs): # 2 label_id # 3 rationale # 4 graph - + rationals = [] # same id as data_tsv to get docs/ filenames right id = 2 skip_first = False for row in read_tsv: + #print(id) + if(skip_first == False): skip_first = True continue - - subgraphlist = subgraphs[id-2] - if(row[1] == "None"): # cut out GT "None" - id = id+1 # do not forget that - continue # rip None labels ;( - + subgraphlist = subgraphs[id-2] predicted_rationales = [] if(subgraphlist): #check if a rule matched (else predicted rationales are empty list) - for subgraph in subgraphlist[0]: - #predicted_rationales.append(penman_to_nodenames(subgraph[0][0])[0]) # get a list - predicted_rationales.append(penman_to_nodenames(graph_to_pn(subgraph))[0]) # get a list - + for subgraph in subgraphlist: + #print(subgraph[0]) + #print(graph_to_pn(subgraph[0])) + graph = subgraph[0] + words = [graph.nodes[node]["name"] for node in graph.nodes()] + #print(words) + predicted_rationales = predicted_rationales + words + #predicted_rationales.append(penman_to_nodenames(graph_to_pn(subgraph[0]))[0]) # get a list + + rationals.append(predicted_rationales) + + id = id+1 + return rationals + +# take first label +def conclude_label_from_labellist(neutralclassname, labellist): + label = "" + if(labellist): + label = labellist[0] # first label + if(not label): + label = "None" # might be a problem with porting + return label + # datatsvfile: the tsv file # subgraphs: the graphs # labels: m(xi)j # labelswithoutr: m(xi\ri)j # labelsonlyr: m(ri)j # target: the target class -def prediction_to_eraser(datatsvfile, subgraphs, labels, labelswithoutr, labelsonlyr, target): +def prediction_to_eraser(datatsvfile, rationals, labels, labelswithoutr, labelsonlyr, target): train_tsv = open(datatsvfile, encoding="utf8") read_tsv = csv.reader(train_tsv, delimiter="\t") newname = datatsvfile.replace(".tsv", "") @@ -197,27 +213,20 @@ def prediction_to_eraser(datatsvfile, subgraphs, labels, labelswithoutr, labelso skip_first = True continue - subgraphlist = subgraphs[id-2] - labellist = labels[id-2] - labelwithoutr = labelswithoutr[id-2] - labelonlyr = labelsonlyr[id-2] + rationallist = rationals[id-2] + #print(rationallist) + curlabel = labels[id-2] + curlabelwithoutr = labelswithoutr[id-2] + curlabelonlyr = labelsonlyr[id-2] if(row[1] == "None"): # cut out GT "None" id = id+1 # do not forget that continue # rip None labels ;( - label = "" - if(labellist): - label = labellist[0] # first label - if(not label): - label = "None" # might be a problem with porting + label = conclude_label_from_labellist("None", curlabel) + labelwithoutr = conclude_label_from_labellist("None", curlabelwithoutr) + labelonlyr = conclude_label_from_labellist("None", curlabelonlyr) - predicted_rationales = [] - if(subgraphlist): #check if a rule matched (else predicted rationales are empty list) - for subgraph in subgraphlist[0]: - #predicted_rationales.append(penman_to_nodenames(subgraph[0][0])[0]) # get a list - predicted_rationales.append(penman_to_nodenames(graph_to_pn(subgraph))[0]) # get a list - print(predicted_rationales) entry = {} #print(row[2]) id_string = "tsv_line_"+str(id)+".txt" @@ -238,6 +247,8 @@ def prediction_to_eraser(datatsvfile, subgraphs, labels, labelswithoutr, labelso m_ri = 0 if(target == labelonlyr): m_ri = 1 + + #print("m_xi = " + str(m_xi) + " m_xi_minus_ri = " + str(m_xi_minus_ri) + " m_ri = " + str(m_ri)) # normal classification: P(target) = m(xi)j ptarget = m_xi @@ -247,24 +258,29 @@ def prediction_to_eraser(datatsvfile, subgraphs, labels, labelswithoutr, labelso } entry['classification_scores'] = classification_scores_dict # comprehensiveness = m(xi)j - m(xi\ri)j - pcomprehensiveness = m_xi - m_xi_minus_ri + #pcomprehensiveness = m_xi - m_xi_minus_ri + pcomprehensiveness = m_xi_minus_ri comprehensiveness_classification_scores_dict = { target: pcomprehensiveness, "None": 1-pcomprehensiveness } entry['comprehensiveness_classification_scores'] = comprehensiveness_classification_scores_dict # sufficiency = m(xi)j - m(ri)j - psufficiency = m_xi - m_ri + #psufficiency = m_xi - m_ri + psufficiency = m_ri sufficiency_classification_scores_dict = { target: psufficiency, "None": 1-psufficiency } entry['sufficiency_classification_scores'] = sufficiency_classification_scores_dict + #print("pcomprehensiveness = " + str(pcomprehensiveness)) + #print("psufficiency = " + str(psufficiency)) + entry['docids'] = 'null' rationales_dict = { "docid": id_string, - "hard_rationale_predictions": get_evidence(id_string, row[0], predicted_rationales) + "hard_rationale_predictions": get_evidence(id_string, row[0], rationallist) } entry['rationales'] = [rationales_dict] entry['query'] = "What is the class?" From 2fd50fe806754853d4896e30ea6842ff76eee4ce Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Tue, 5 Jul 2022 13:30:02 +0200 Subject: [PATCH 05/12] Add random sample inspector --- scripts/evaluate_hatexplain.py | 10 +++--- scripts/inspect_hatexplain.py | 62 ++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 6 deletions(-) create mode 100644 scripts/inspect_hatexplain.py diff --git a/scripts/evaluate_hatexplain.py b/scripts/evaluate_hatexplain.py index 24f3687..5db0af1 100644 --- a/scripts/evaluate_hatexplain.py +++ b/scripts/evaluate_hatexplain.py @@ -99,11 +99,12 @@ def evaluate(feature_file: str, files: List[str], target: str): stats = evaluator.evaluate_feature(target, features[target], df)[0] matched_result = evaluator.match_features(df, features[target], multi=True, return_subgraphs=True, allow_multi_graph=True) - matched_result.to_csv('temp_matched_result.tsv', sep="\t") subgraphs = matched_result["Matched subgraph"] labels = matched_result["Predicted label"] rationale_as_text_list = get_rationales(file, subgraphs) - + matched_result["Predicted rational"] = rationale_as_text_list + matched_result.to_csv('temp_matched_result.tsv', sep="\t") + df_without_rationales = df.copy() for i in range(df_without_rationales['text'].size): df_without_rationales['text'][i] = remove_rationals(df_without_rationales['text'][i], rationale_as_text_list[i]) @@ -126,15 +127,12 @@ def evaluate(feature_file: str, files: List[str], target: str): matched_result = evaluator.match_features(df_only_rationales, features[target], multi=True, return_subgraphs=True, allow_multi_graph=True) labels_only_rationales = matched_result["Predicted label"] - #print(labels) - #print(labels_without_rationales) - #print(labels_only_rationales) - data_tsv_to_eraser(file) prediction_to_eraser(file, rationale_as_text_list, labels, labels_without_rationales, labels_only_rationales, target) call_eraser("None", "./hatexplain", "val", "./hatexplain/val_prediction.jsonl") print("------------------------") print_classification_report(df, stats) + print("------------------------") if __name__ == "__main__": diff --git a/scripts/inspect_hatexplain.py b/scripts/inspect_hatexplain.py new file mode 100644 index 0000000..a217217 --- /dev/null +++ b/scripts/inspect_hatexplain.py @@ -0,0 +1,62 @@ +from typing import List, Dict +import json +import random +import numpy as np +from pandas import DataFrame +import pandas +import logging +from argparse import ArgumentParser, ArgumentError +from sklearn.metrics import classification_report +from xpotato.graph_extractor.extract import FeatureEvaluator +from xpotato.dataset.explainable_dataset import ExplainableDataset +from xpotato.dataset.utils import save_dataframe + +from hatexplain_to_eraser import data_tsv_to_eraser, prediction_to_eraser, get_rationales +from call_eraser import call_eraser + + + +if __name__ == "__main__": + argparser = ArgumentParser() + + numarg = argparser.add_argument( + "--number", + "-n", + help="The number of random examples to inspect.", + default=10, + type=int, + ) + #argparser.add_argument( + # "--train", "-t", help="The train file in potato format", nargs="+" + #) + + args = argparser.parse_args() + print("Printing "+str(args.number)+" random examples...") + print("------------------------") + numbers = [] + frame = pandas.read_csv("temp_matched_result.tsv", sep="\t") + frame_without_rationales = pandas.read_csv("temp_df_without_rationales.tsv", sep="\t") + if(args.number > frame.shape[0]): + raise ArgumentError( + message="number is bigger than dataframe rows", + argument=numarg + ) + + + for i in range(args.number): + num = random.randint(0, frame.shape[0]) + while(num in numbers): + num = random.randint(0, frame.shape[0]) + numbers.append(num) + #print(frame[num]) + row = frame.iloc[num] + row2 = frame_without_rationales.iloc[num] + print("Sentence:") + print(row["Sentence"]) + print("Labels:") + print("GT: " + str(row2["label"]) + " vs Prediction: "+str(row["Predicted label"])) + print("Matched rule: "+ str(row["Matched rule"])) + print("Rationals:") + print("GT: " + row2["rationale"] + " vs Prediction: "+row["Predicted rational"]) + print("------------------------") + #print(num) \ No newline at end of file From 633a3fdfaf7637b5773ad2e57d1dd79988007ab2 Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Fri, 8 Jul 2022 13:18:47 +0200 Subject: [PATCH 06/12] Add random seed, consecutive, FP/TP/FN/TN, paging, color to inspection tool --- scripts/inspect_hatexplain.py | 155 +++++++++++++++++++++++++++++++--- 1 file changed, 142 insertions(+), 13 deletions(-) diff --git a/scripts/inspect_hatexplain.py b/scripts/inspect_hatexplain.py index a217217..d1d1265 100644 --- a/scripts/inspect_hatexplain.py +++ b/scripts/inspect_hatexplain.py @@ -14,49 +14,178 @@ from hatexplain_to_eraser import data_tsv_to_eraser, prediction_to_eraser, get_rationales from call_eraser import call_eraser +def check_if(limit, num, frame, frame_without_rationales): + row = frame.iloc[num] + row2 = frame_without_rationales.iloc[num] + + if(limit == "NO"): + return True + + # GT, predicted + if(limit == "TP" and str(row2["label"])!="None" and str(row["Predicted label"])!="nan"): + return True + + if(limit == "FP" and str(row2["label"])=="None" and str(row["Predicted label"])!="nan"): + return True + + if(limit == "TN" and str(row2["label"])=="None" and str(row["Predicted label"])=="nan"): + return True + if(limit == "FN" and str(row2["label"])!="None" and str(row["Predicted label"])=="nan"): + return True + + return False + +def get_confusion_class(num, frame, frame_without_rationales): + row = frame.iloc[num] + row2 = frame_without_rationales.iloc[num] + + # GT, predicted + if(str(row2["label"])!="None" and str(row["Predicted label"])!="nan"): + return "TP" + + if(str(row2["label"])=="None" and str(row["Predicted label"])!="nan"): + return "FP" + + if(str(row2["label"])=="None" and str(row["Predicted label"])=="nan"): + return "TN" + + if(str(row2["label"])!="None" and str(row["Predicted label"])=="nan"): + return "FN" + + +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' if __name__ == "__main__": argparser = ArgumentParser() + numarg = argparser.add_argument( + "--limit", + "-l", + help="Limit sample type to TP, FP, TN, FN.", + choices=["NO", "TP", "FP", "TN", "FN"], + default="NO", + type=str, + ) + numarg = argparser.add_argument( "--number", "-n", - help="The number of random examples to inspect.", + help="The number of examples to inspect.", default=10, type=int, ) + + numarg = argparser.add_argument( + "--page", + "-p", + help="Not the first but p-th n samples.", + default=1, + type=int, + ) + + numarg = argparser.add_argument( + "--random", + "-r", + help="Not first n samples but n random samples.", + action='store_true', + ) + + numarg = argparser.add_argument( + "--seed", + "-s", + help="Random seed for random mode.", + default=0, + type=int, + ) #argparser.add_argument( # "--train", "-t", help="The train file in potato format", nargs="+" #) args = argparser.parse_args() - print("Printing "+str(args.number)+" random examples...") + + if(args.random): + random.seed(args.seed) + + outtext = " " + if(args.limit != "NO"): + outtext = outtext + args.limit + " " + if(args.random): + outtext = outtext + "random " + print("Printing "+str(args.number)+ outtext+"examples...") print("------------------------") numbers = [] - frame = pandas.read_csv("temp_matched_result.tsv", sep="\t") + + frame = None + try: + frame = pandas.read_csv("temp_matched_result.tsv", sep="\t") + except IOError as e: + print("Evaluation dataframe files not found! Please run evaluate_hatexplain.py first.") frame_without_rationales = pandas.read_csv("temp_df_without_rationales.tsv", sep="\t") if(args.number > frame.shape[0]): raise ArgumentError( - message="number is bigger than dataframe rows", + message="Number is bigger than dataframe rows!", argument=numarg ) + num = -1 + random_attempts=0 + for i in range(args.number*args.page): + fitting_num = False + + # choose fitting num + while(fitting_num == False): + # choose random num + if(args.random): + num = random.randint(0, frame.shape[0]-1) + while(num in numbers): + num = random.randint(0, frame.shape[0]-1) + random_attempts=random_attempts+1 + if(random_attempts >= 10*frame.shape[0]): + print("No more examples found after "+str(random_attempts-1)+" random attempts.") + print("------------------------") + break + # choose next num + else: + num = num+1 + if(num >= frame.shape[0]): + print("No more examples found.") + print("------------------------") + break + + if(check_if(args.limit, num, frame, frame_without_rationales)): + fitting_num = True - for i in range(args.number): - num = random.randint(0, frame.shape[0]) - while(num in numbers): - num = random.randint(0, frame.shape[0]) + # no more fitting nums + if(fitting_num == False): + break + + # page + if(i < args.number*(args.page-1)): + continue + numbers.append(num) - #print(frame[num]) + + # select num row = frame.iloc[num] row2 = frame_without_rationales.iloc[num] - print("Sentence:") + print(bcolors.HEADER+"Sentence: ("+str(num)+")"+bcolors.ENDC) print(row["Sentence"]) print("Labels:") - print("GT: " + str(row2["label"]) + " vs Prediction: "+str(row["Predicted label"])) + print("GT: " + bcolors.OKBLUE+str(row2["label"])+bcolors.ENDC + " vs Prediction: "+bcolors.OKBLUE+str(row["Predicted label"])+bcolors.ENDC+" (it's "+get_confusion_class(num, frame, frame_without_rationales)+")") print("Matched rule: "+ str(row["Matched rule"])) print("Rationals:") - print("GT: " + row2["rationale"] + " vs Prediction: "+row["Predicted rational"]) + print("GT: " + bcolors.OKBLUE+row2["rationale"]+bcolors.ENDC + " vs Prediction: "+bcolors.OKBLUE+row["Predicted rational"]+bcolors.ENDC) print("------------------------") - #print(num) \ No newline at end of file + #print(num) + + From c0b016c0e4db7e21f577b1051f34dcf423547b77 Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Wed, 20 Jul 2022 12:21:17 +0200 Subject: [PATCH 07/12] Add simple latex table exporter --- scripts/export_eraser_latex.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 scripts/export_eraser_latex.py diff --git a/scripts/export_eraser_latex.py b/scripts/export_eraser_latex.py new file mode 100644 index 0000000..f9ba891 --- /dev/null +++ b/scripts/export_eraser_latex.py @@ -0,0 +1,34 @@ +from tabulate import tabulate +import json + +datadir="./hatexplain" + +with open(datadir+'/eraser_output.json') as fp: + output_data = json.load(fp) + +output = [] + +output.append(['Plausibility']) +if 'iou_scores' in output_data: + output.append(['IOU F1 :', round(output_data['iou_scores'][0]['macro']['f1'], 3)]) + output.append(['Token F1 :', round(output_data['token_prf']['instance_macro']['f1'], 3)]) + +if 'token_soft_metrics' in output_data: + output.append(['AUPRC :', round(output_data['token_soft_metrics']['auprc'], 3)]) + +output.append(['Faithfulness']) +if 'classification_scores' in output_data: + output.append(['Comprehensiveness :', round(output_data['classification_scores']['comprehensiveness'], 3)]) + output.append(['Sufficiency :', round(output_data['classification_scores']['sufficiency'], 3)]) +else: + output.append('--') +#output.append("") + +#print(tabulate(output)) +print("") +print("\\begin{table}") +print("\label{tab:example}") +print("\centering") +print(tabulate(output, headers=["Metric", "Value"], tablefmt="latex")) +print("\caption{Latex ERASER Table Example}") +print("\end{table}") \ No newline at end of file From af9b7bcdfaebaff9feacea7b81791a5344fcbf47 Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Wed, 20 Jul 2022 12:41:08 +0200 Subject: [PATCH 08/12] Add explaining comments for eraser evaluation script --- scripts/evaluate_hatexplain.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/evaluate_hatexplain.py b/scripts/evaluate_hatexplain.py index 5db0af1..d0740db 100644 --- a/scripts/evaluate_hatexplain.py +++ b/scripts/evaluate_hatexplain.py @@ -98,6 +98,7 @@ def evaluate(feature_file: str, files: List[str], target: str): df = potato.to_dataframe() stats = evaluator.evaluate_feature(target, features[target], df)[0] + # get labels and predicted rationals from matched_results matched_result = evaluator.match_features(df, features[target], multi=True, return_subgraphs=True, allow_multi_graph=True) subgraphs = matched_result["Matched subgraph"] labels = matched_result["Predicted label"] @@ -105,6 +106,7 @@ def evaluate(feature_file: str, files: List[str], target: str): matched_result["Predicted rational"] = rationale_as_text_list matched_result.to_csv('temp_matched_result.tsv', sep="\t") + # get labels by removing the predicted rationals df_without_rationales = df.copy() for i in range(df_without_rationales['text'].size): df_without_rationales['text'][i] = remove_rationals(df_without_rationales['text'][i], rationale_as_text_list[i]) @@ -116,6 +118,7 @@ def evaluate(feature_file: str, files: List[str], target: str): matched_result = evaluator.match_features(df_without_rationales, features[target], multi=True, return_subgraphs=True, allow_multi_graph=True) labels_without_rationales = matched_result["Predicted label"] + # get labels with only the rationals df_only_rationales = df.copy() for i in range(df_only_rationales['text'].size): df_only_rationales['text'][i] = concat_rationals(rationale_as_text_list[i]) @@ -127,6 +130,7 @@ def evaluate(feature_file: str, files: List[str], target: str): matched_result = evaluator.match_features(df_only_rationales, features[target], multi=True, return_subgraphs=True, allow_multi_graph=True) labels_only_rationales = matched_result["Predicted label"] + # convert the data to eraser format and run eraser data_tsv_to_eraser(file) prediction_to_eraser(file, rationale_as_text_list, labels, labels_without_rationales, labels_only_rationales, target) call_eraser("None", "./hatexplain", "val", "./hatexplain/val_prediction.jsonl") From 32398c12d1569dc652e439068bdbefc8be6964df Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Tue, 16 Aug 2022 15:31:22 +0200 Subject: [PATCH 09/12] Adapt evaluate and inspect to use rationale_lemma --- scripts/hatexplain_to_eraser.py | 8 ++++---- scripts/inspect_hatexplain.py | 2 +- xpotato/dataset/explainable_dataset.py | 3 ++- xpotato/dataset/explainable_sample.py | 18 ++++++++++-------- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/scripts/hatexplain_to_eraser.py b/scripts/hatexplain_to_eraser.py index abf4531..c980f64 100644 --- a/scripts/hatexplain_to_eraser.py +++ b/scripts/hatexplain_to_eraser.py @@ -119,15 +119,15 @@ def data_tsv_to_eraser(tsvfile): entry['annotation_id'] = id_string entry['classification'] = row[1] entry['docids'] = 'null' - entry['evidences'] = [get_evidence(id_string, row[0], line_to_list(row[3]))] + entry['evidences'] = [get_evidence(id_string, row[0], line_to_list(row[5]))] entry['query'] = "What is the class?" entry['query_type'] = None data_set.append(entry) #if(id == 350): if(False): - print(row[0], row[3]) - print(rationale_to_explanations(row[0], line_to_list(row[3]))) - print(get_evidence(id_string, row[0], line_to_list(row[3]))) + print(row[0], row[5]) + print(rationale_to_explanations(row[0], line_to_list(row[5]))) + print(get_evidence(id_string, row[0], line_to_list(row[5]))) write_eraser.write(json.dumps(entry)+'\n') id = id+1 diff --git a/scripts/inspect_hatexplain.py b/scripts/inspect_hatexplain.py index d1d1265..282a178 100644 --- a/scripts/inspect_hatexplain.py +++ b/scripts/inspect_hatexplain.py @@ -184,7 +184,7 @@ class bcolors: print("GT: " + bcolors.OKBLUE+str(row2["label"])+bcolors.ENDC + " vs Prediction: "+bcolors.OKBLUE+str(row["Predicted label"])+bcolors.ENDC+" (it's "+get_confusion_class(num, frame, frame_without_rationales)+")") print("Matched rule: "+ str(row["Matched rule"])) print("Rationals:") - print("GT: " + bcolors.OKBLUE+row2["rationale"]+bcolors.ENDC + " vs Prediction: "+bcolors.OKBLUE+row["Predicted rational"]+bcolors.ENDC) + print("GT: " + bcolors.OKBLUE+row2["rationale_lemma"]+bcolors.ENDC + " vs Prediction: "+bcolors.OKBLUE+row["Predicted rational"]+bcolors.ENDC) print("------------------------") #print(num) diff --git a/xpotato/dataset/explainable_dataset.py b/xpotato/dataset/explainable_dataset.py index 75f8f2d..896c5eb 100644 --- a/xpotato/dataset/explainable_dataset.py +++ b/xpotato/dataset/explainable_dataset.py @@ -47,7 +47,8 @@ def read_dataset( df = pd.read_csv(path, sep="\t") samples = [ ExplainableSample( - (example["text"], example["label"], example["rationale"]), + (example["text"], example["label"], example["rationale"], + example["rationale_id"], example["rationale_lemma"]), potato_graph=PotatoGraph(graph_str=example["graph"]), label_id=example["label_id"], ) diff --git a/xpotato/dataset/explainable_sample.py b/xpotato/dataset/explainable_sample.py index dc6e893..11899ba 100644 --- a/xpotato/dataset/explainable_sample.py +++ b/xpotato/dataset/explainable_sample.py @@ -7,15 +7,17 @@ class ExplainableSample(Sample): def __init__( - self, - example: Tuple[str, str], - potato_graph: PotatoGraph = None, - label_id: int = None, + self, + example: Tuple[str, str], + potato_graph: PotatoGraph = None, + label_id: int = None, ) -> None: - super().__init__(example=example, potato_graph=potato_graph, label_id=label_id) - self.rationale = example[2] - self.rationale_id = example[3] - self.potato_graph = example[4] + super().__init__(example=example, potato_graph=potato_graph, label_id=label_id) + self.rationale = example[2] + self.rationale_id = example[3] + self.rationale_lemma = example[4] + if len(example) >= 6: + self.potato_graph = example[5] def _postprocess(self, graph: PotatoGraph) -> PotatoGraph: rationale_bool = [] From 5999e4f1ba8be9c6eb5e5a13087eea81132875c2 Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Sat, 3 Sep 2022 12:25:46 +0200 Subject: [PATCH 10/12] Add auto script to evaluate multiple files at once --- scripts/auto_evaluate.sh | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100755 scripts/auto_evaluate.sh diff --git a/scripts/auto_evaluate.sh b/scripts/auto_evaluate.sh new file mode 100755 index 0000000..ed597f4 --- /dev/null +++ b/scripts/auto_evaluate.sh @@ -0,0 +1,38 @@ +# potato +function evaluate() { + path=$1 + rules=$2 + + # process path to folder + foldersign="/" + minus="-" + foldername=${path//$foldersign/$minus} + tsv=".tsv" + empty="" + foldername=${foldername//$tsv/$empty} + + mkdir $foldername + cp $path val.tsv + python evaluate_hatexplain.py -f $rules -t val.tsv | tee out.txt + cp temp_matched_result.tsv $foldername + cp temp_df_without_rationales.tsv $foldername + cp temp_df_only_rationales.tsv $foldername + mv out.txt $foldername + rm val.tsv +} + +#evaluate women/minority_val_all.tsv sexism_rules.json +#evaluate women/majority_val_all.tsv sexism_rules.json +#evaluate homosexual/minority_val_all.tsv homophobia_rules.json +#evaluate homosexual/majority_val_all.tsv homophobia_rules.json + +evaluate women/minority_val_one_majority.tsv sexism_rules.json +evaluate women/majority_val_one_majority.tsv sexism_rules.json +evaluate homosexual/minority_val_one_majority.tsv homophobia_rules.json +evaluate homosexual/majority_val_one_majority.tsv homophobia_rules.json + +evaluate women/minority_val_pure.tsv sexism_rules.json +evaluate women/majority_val_pure.tsv sexism_rules.json +evaluate homosexual/minority_val_pure.tsv homophobia_rules.json +evaluate homosexual/majority_val_pure.tsv homophobia_rules.json + From 68e63df010ceb9f8f1562d98ce3680bf71698363 Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Sat, 10 Sep 2022 00:06:12 +0200 Subject: [PATCH 11/12] Change table format and add p&r for plausibility for rules --- scripts/auto_evaluate.sh | 9 +++++---- scripts/call_eraser.py | 2 ++ scripts/evaluate_hatexplain.py | 19 ++++++++++++++----- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/scripts/auto_evaluate.sh b/scripts/auto_evaluate.sh index ed597f4..4739104 100755 --- a/scripts/auto_evaluate.sh +++ b/scripts/auto_evaluate.sh @@ -10,6 +10,7 @@ function evaluate() { tsv=".tsv" empty="" foldername=${foldername//$tsv/$empty} + foldername="$foldername-rules" mkdir $foldername cp $path val.tsv @@ -21,10 +22,10 @@ function evaluate() { rm val.tsv } -#evaluate women/minority_val_all.tsv sexism_rules.json -#evaluate women/majority_val_all.tsv sexism_rules.json -#evaluate homosexual/minority_val_all.tsv homophobia_rules.json -#evaluate homosexual/majority_val_all.tsv homophobia_rules.json +evaluate women/minority_val_all.tsv sexism_rules.json +evaluate women/majority_val_all.tsv sexism_rules.json +evaluate homosexual/minority_val_all.tsv homophobia_rules.json +evaluate homosexual/majority_val_all.tsv homophobia_rules.json evaluate women/minority_val_one_majority.tsv sexism_rules.json evaluate women/majority_val_one_majority.tsv sexism_rules.json diff --git a/scripts/call_eraser.py b/scripts/call_eraser.py index d9fa0e8..54a1892 100644 --- a/scripts/call_eraser.py +++ b/scripts/call_eraser.py @@ -74,7 +74,9 @@ def print_eraser_results(datadir): print('Plausibility') if 'iou_scores' in output_data: print('IOU F1 :', round(output_data['iou_scores'][0]['macro']['f1'], 3)) + print('( P :', round(output_data['iou_scores'][0]['macro']['p'], 3), ', R :', round(output_data['iou_scores'][0]['macro']['r'], 3), ')') print('Token F1 :', round(output_data['token_prf']['instance_macro']['f1'], 3)) + print('( P :', round(output_data['token_prf']['instance_macro']['p'], 3), ', R :', round(output_data['token_prf']['instance_macro']['r'], 3), ')') if 'token_soft_metrics' in output_data: print('AUPRC :', round(output_data['token_soft_metrics']['auprc'], 3)) diff --git a/scripts/evaluate_hatexplain.py b/scripts/evaluate_hatexplain.py index d0740db..74eece5 100644 --- a/scripts/evaluate_hatexplain.py +++ b/scripts/evaluate_hatexplain.py @@ -5,10 +5,11 @@ import pandas import logging from argparse import ArgumentParser, ArgumentError -from sklearn.metrics import classification_report +from sklearn.metrics import classification_report, confusion_matrix from xpotato.graph_extractor.extract import FeatureEvaluator from xpotato.dataset.explainable_dataset import ExplainableDataset from xpotato.dataset.utils import save_dataframe +from tuw_nlp.common.eval import * from hatexplain_to_eraser import data_tsv_to_eraser, prediction_to_eraser, get_rationales from call_eraser import call_eraser @@ -86,6 +87,7 @@ def concat_rationals(rationals): return text def evaluate(feature_file: str, files: List[str], target: str): + #feature_file="sexism_rules.json" with open(feature_file) as feature_json: features = json.load(feature_json) evaluator = FeatureEvaluator() @@ -93,7 +95,8 @@ def evaluate(feature_file: str, files: List[str], target: str): target = list(features.keys())[0] for file in files: - #print(f"File: {file}") + #file="women/minority_val_pure.tsv" + #target="Women" potato = ExplainableDataset(path=file, label_vocab={"None": 0, target: 1}) df = potato.to_dataframe() stats = evaluator.evaluate_feature(target, features[target], df)[0] @@ -135,9 +138,15 @@ def evaluate(feature_file: str, files: List[str], target: str): prediction_to_eraser(file, rationale_as_text_list, labels, labels_without_rationales, labels_only_rationales, target) call_eraser("None", "./hatexplain", "val", "./hatexplain/val_prediction.jsonl") print("------------------------") - print_classification_report(df, stats) - - print("------------------------") + #print_classification_report(df, stats) + #print("------------------------") + + tn0, fp0, fn0, tp0 = confusion_matrix(1-df.label_id, [1-(n > 0) * 1 for n in np.sum([p for p in stats["Predicted"]], axis=0)]).ravel() + confusion_dict_neutral={"TN":tn0,"FP":fp0,"FN":fn0,"TP":tp0} + tn1, fp1, fn1, tp1 = confusion_matrix(df.label_id, [(n > 0) * 1 for n in np.sum([p for p in stats["Predicted"]], axis=0)]).ravel() + confusion_dict_target={"TN":tn1,"FP":fp1,"FN":fn1,"TP":tp1} + cat_stats={target : confusion_dict_target, "None" : confusion_dict_neutral} + print_cat_stats(cat_stats) #s,tablefmt="latex_booktabs" if __name__ == "__main__": argparser = ArgumentParser() From 22cb4d9f6d661d03e094799deead984c6b1f3e83 Mon Sep 17 00:00:00 2001 From: Markus Reichel Date: Mon, 12 Sep 2022 19:32:06 +0200 Subject: [PATCH 12/12] Synchronize to out.txt --- scripts/auto_evaluate.sh | 12 +++++-- scripts/call_eraser.py | 4 +-- scripts/eraserbenchmark/print_eraser.py | 46 +++++++++++++++++++++++++ scripts/evaluate_hatexplain.py | 16 +++++++++ 4 files changed, 74 insertions(+), 4 deletions(-) create mode 100644 scripts/eraserbenchmark/print_eraser.py diff --git a/scripts/auto_evaluate.sh b/scripts/auto_evaluate.sh index 4739104..4cd7ea6 100755 --- a/scripts/auto_evaluate.sh +++ b/scripts/auto_evaluate.sh @@ -14,11 +14,19 @@ function evaluate() { mkdir $foldername cp $path val.tsv - python evaluate_hatexplain.py -f $rules -t val.tsv | tee out.txt + python evaluate_hatexplain.py -f $rules -t val.tsv | tee evaluate.txt + + cd eraserbenchmark + python print_eraser.py $foldername | tee out.txt + cd .. + cp temp_matched_result.tsv $foldername cp temp_df_without_rationales.tsv $foldername cp temp_df_only_rationales.tsv $foldername - mv out.txt $foldername + mv evaluate.txt $foldername + mv eraserbenchmark/out.txt $foldername + mv cat_stats.json $foldername + mv eraser_output.json $foldername rm val.tsv } diff --git a/scripts/call_eraser.py b/scripts/call_eraser.py index 54a1892..da2e01f 100644 --- a/scripts/call_eraser.py +++ b/scripts/call_eraser.py @@ -34,13 +34,13 @@ def call_eraser(neutralclassname, datadir, testtrainorval, pathtopredictions, si data_dir=datadir, # data dir split=testtrainorval, # split results=pathtopredictions, # results - score_file=datadir+"/eraser_output.json", # score + score_file=""+"./eraser_output.json", # score strict=False) # strict #iou_thresholds=[0.5], # iou #aopc_thresholds=[0.01, 0.05, 0.1, 0.2, 0.5]) # aopc if silent: logger.disabled = False - print_eraser_results(datadir) + print_eraser_results(".") """ def call_eraser(datadir, testtrainorval, pathtopredictions): diff --git a/scripts/eraserbenchmark/print_eraser.py b/scripts/eraserbenchmark/print_eraser.py new file mode 100644 index 0000000..58cabdb --- /dev/null +++ b/scripts/eraserbenchmark/print_eraser.py @@ -0,0 +1,46 @@ +import json +import argparse +from tuw_nlp.common.eval import * + +my_parser = argparse.ArgumentParser(description='Which model to use') + +# Add the arguments +my_parser.add_argument('foldername', + metavar='--foldername', + nargs='?', + type=str, + default="evaluation", + help='foldername of the data for description') + +args = my_parser.parse_args() + +# print the required results +with open('../eraser_output.json') as fp: + output_data = json.load(fp) + +print("\n"+args.foldername+":") +print("------------------------") + +print('Plausibility') +if 'iou_scores' in output_data: + print('IOU F1 :', round(output_data['iou_scores'][0]['macro']['f1'], 3)) + print('( P :', round(output_data['iou_scores'][0]['macro']['p'], 3), ', R :', round(output_data['iou_scores'][0]['macro']['r'], 3), ')') + print('Token F1 :', round(output_data['token_prf']['instance_macro']['f1'], 3)) + print('( P :', round(output_data['token_prf']['instance_macro']['p'], 3), ', R :', round(output_data['token_prf']['instance_macro']['r'], 3), ')') + +if 'token_soft_metrics' in output_data: + print('AUPRC :', round(output_data['token_soft_metrics']['auprc'], 3)) + +print('\nFaithfulness') +if 'classification_scores' in output_data: + print('Comprehensiveness :', round(output_data['classification_scores']['comprehensiveness'], 3)) + print('Sufficiency :', round(output_data['classification_scores']['sufficiency'], 3)) +else: + print('--') +print("") + +print("------------------------") + +with open('../cat_stats.json') as fp: + cat_stats = json.load(fp) + print_cat_stats(cat_stats) #s,tablefmt="latex_booktabs" \ No newline at end of file diff --git a/scripts/evaluate_hatexplain.py b/scripts/evaluate_hatexplain.py index 74eece5..21e31ab 100644 --- a/scripts/evaluate_hatexplain.py +++ b/scripts/evaluate_hatexplain.py @@ -86,6 +86,20 @@ def concat_rationals(rationals): text = text + rational + " " return text +class NumpyEncoder(json.JSONEncoder): + """ Special json encoder for numpy types """ + def default(self, obj): + if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, + np.int16, np.int32, np.int64, np.uint8, + np.uint16, np.uint32, np.uint64)): + return int(obj) + elif isinstance(obj, (np.float_, np.float16, np.float32, + np.float64)): + return float(obj) + elif isinstance(obj,(np.ndarray,)): #### This is the fix + return obj.tolist() + return json.JSONEncoder.default(self, obj) + def evaluate(feature_file: str, files: List[str], target: str): #feature_file="sexism_rules.json" with open(feature_file) as feature_json: @@ -147,6 +161,8 @@ def evaluate(feature_file: str, files: List[str], target: str): confusion_dict_target={"TN":tn1,"FP":fp1,"FN":fn1,"TP":tp1} cat_stats={target : confusion_dict_target, "None" : confusion_dict_neutral} print_cat_stats(cat_stats) #s,tablefmt="latex_booktabs" + with open("cat_stats.json", 'w') as fp: + fp.write((json.dumps(cat_stats,cls=NumpyEncoder))) if __name__ == "__main__": argparser = ArgumentParser()