diff --git a/src/pie_modules/metrics/__init__.py b/src/pie_modules/metrics/__init__.py
index 2b9f82750..7038ade07 100644
--- a/src/pie_modules/metrics/__init__.py
+++ b/src/pie_modules/metrics/__init__.py
@@ -1,3 +1,4 @@
+from .confusion_matrix import ConfusionMatrix
 from .f1 import F1Metric
 from .relation_argument_distance_collector import RelationArgumentDistanceCollector
 from .span_coverage_collector import SpanCoverageCollector
diff --git a/src/pie_modules/metrics/confusion_matrix.py b/src/pie_modules/metrics/confusion_matrix.py
new file mode 100644
index 000000000..3defba19e
--- /dev/null
+++ b/src/pie_modules/metrics/confusion_matrix.py
@@ -0,0 +1,190 @@
+import logging
+from collections import defaultdict
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import pandas as pd
+from pytorch_ie.core import Annotation, Document, DocumentMetric
+from pytorch_ie.utils.hydra import resolve_target
+
+logger = logging.getLogger(__name__)
+
+
+class ConfusionMatrix(DocumentMetric):
+    """Computes the confusion matrix for a given annotation layer that contains labeled
+    annotations.
+
+    Gold and predicted annotations are matched via their "base annotation", i.e. a copy of the
+    annotation where the label field is set to a fixed dummy value. Each combination of gold
+    label and predicted label that share a base annotation is counted once.
+
+    Args:
+        layer: The layer to compute the confusion matrix for.
+        label_field: The field to use for the label. Defaults to "label".
+        show_as_markdown: If True, logs the confusion matrix as markdown when it is computed.
+        unassignable_label: The placeholder used as predicted label for gold annotations without
+            a matching prediction (false negatives). Defaults to "UNASSIGNABLE".
+        undetected_label: The placeholder used as gold label for predictions without a matching
+            gold annotation (false positives). Defaults to "UNDETECTED".
+        strict: If True, raises an error if a base annotation has multiple gold labels. If False,
+            logs a warning and skips that base annotation.
+        annotation_processor: A callable that processes the annotations before calculating the
+            confusion matrix, or a string that is resolved to such a callable via resolve_target.
+    """
+
+    def __init__(
+        self,
+        layer: str,
+        label_field: str = "label",
+        show_as_markdown: bool = False,
+        unassignable_label: str = "UNASSIGNABLE",
+        undetected_label: str = "UNDETECTED",
+        strict: bool = True,
+        annotation_processor: Optional[Union[Callable[[Annotation], Annotation], str]] = None,
+    ):
+        super().__init__()
+        self.layer = layer
+        self.label_field = label_field
+        self.unassignable_label = unassignable_label
+        self.undetected_label = undetected_label
+        self.strict = strict
+        self.show_as_markdown = show_as_markdown
+        # a string is handed to resolve_target to obtain the actual processor
+        if isinstance(annotation_processor, str):
+            self.annotation_processor = resolve_target(annotation_processor)
+        else:
+            self.annotation_processor = annotation_processor
+
+    def reset(self) -> None:
+        """Replace the accumulated counts with a fresh, empty counter."""
+        self.counts: Dict[Tuple[str, str], int] = defaultdict(int)
+
+    def _group_by_base(
+        self, annotations: Iterable[Annotation]
+    ) -> Dict[Annotation, List[Annotation]]:
+        """Group the annotations by their base annotation (a copy with a dummy label)."""
+        base2anns: Dict[Annotation, List[Annotation]] = defaultdict(list)
+        for ann in annotations:
+            base_ann = ann.copy(**{self.label_field: "DUMMY_LABEL"})
+            base2anns[base_ann].append(ann)
+        return base2anns
+
+    def calculate_counts(
+        self,
+        document: Document,
+        annotation_filter: Optional[Callable[[Annotation], bool]] = None,
+        annotation_processor: Optional[Callable[[Annotation], Annotation]] = None,
+    ) -> Dict[Tuple[str, str], int]:
+        """Count the (gold label, predicted label) pairs of a single document.
+
+        Args:
+            document: The document whose gold and predicted annotations are compared.
+            annotation_filter: If set, only annotations for which it returns True are used.
+            annotation_processor: If set, it is applied to each annotation that passed the filter.
+
+        Returns:
+            A mapping from (gold_label, pred_label) to the number of occurrences.
+
+        Raises:
+            ValueError: If a gold annotation is labeled with undetected_label, if a predicted
+                annotation is labeled with unassignable_label, or, if strict is True, if a base
+                annotation has multiple gold labels.
+        """
+        annotation_processor = annotation_processor or (lambda ann: ann)
+        annotation_filter = annotation_filter or (lambda ann: True)
+        # collect into sets, so identical annotations are considered only once
+        predicted_annotations = {
+            annotation_processor(ann)
+            for ann in document[self.layer].predictions
+            if annotation_filter(ann)
+        }
+        gold_annotations = {
+            annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
+        }
+        base2gold = self._group_by_base(gold_annotations)
+        base2pred = self._group_by_base(predicted_annotations)
+
+        # (gold_label, pred_label) -> count
+        counts: Dict[Tuple[str, str], int] = defaultdict(int)
+        for base_ann in set(base2gold) | set(base2pred):
+            gold_labels = [getattr(ann, self.label_field) for ann in base2gold[base_ann]]
+            pred_labels = [getattr(ann, self.label_field) for ann in base2pred[base_ann]]
+
+            # the placeholder labels must not occur as real labels
+            if self.undetected_label in gold_labels:
+                raise ValueError(
+                    f"The gold annotation has the label '{self.undetected_label}' for undetected instances. "
+                    f"Set a different undetected_label."
+                )
+            if self.unassignable_label in pred_labels:
+                raise ValueError(
+                    f"The predicted annotation has the label '{self.unassignable_label}' for unassignable predictions. "
+                    f"Set a different unassignable_label."
+                )
+
+            if len(gold_labels) > 1:
+                msg = f"The base annotation {base_ann} has multiple gold labels: {sorted(gold_labels)}."
+                if self.strict:
+                    raise ValueError(msg)
+                logger.warning(msg + " Skip this base annotation.")
+                continue
+
+            # use placeholder labels for empty gold or prediction labels
+            if not gold_labels:
+                gold_labels.append(self.undetected_label)
+            if not pred_labels:
+                pred_labels.append(self.unassignable_label)
+
+            # main logic: count each combination of gold and predicted label
+            for gold_label in gold_labels:
+                for pred_label in pred_labels:
+                    counts[(gold_label, pred_label)] += 1
+
+        return counts
+
+    def add_counts(self, counts: Dict[Tuple[str, str], int]) -> None:
+        """Add the given counts to the accumulated counts of this metric."""
+        for key, value in counts.items():
+            self.counts[key] += value
+
+    def _update(self, document: Document) -> None:
+        """Calculate the counts for the document and add them to the accumulated counts."""
+        new_counts = self.calculate_counts(
+            document=document,
+            annotation_processor=self.annotation_processor,
+        )
+        self.add_counts(new_counts)
+
+    def _compute(self) -> Dict[str, Dict[str, int]]:
+        """Return the accumulated counts as nested mapping: gold label -> predicted label -> count.
+
+        If show_as_markdown is True, the matrix is also logged as markdown table (rows are gold
+        labels, columns are predicted labels; label pairs that never occurred are shown as 0).
+        """
+        res: Dict[str, Dict[str, int]] = defaultdict(dict)
+        for gold_label, pred_label in sorted(self.counts):
+            res[gold_label][pred_label] = self.counts[(gold_label, pred_label)]
+
+        if self.show_as_markdown:
+            res_df = pd.DataFrame(res).fillna(0)
+            # index is prediction, columns is gold
+            gold_labels = res_df.columns
+            pred_labels = res_df.index
+
+            # re-arrange index and columns: sort and put undetected_label and unassignable_label at the end
+            gold_labels_sorted = sorted(
+                [gold_label for gold_label in gold_labels if gold_label != self.undetected_label]
+            )
+            # re-add undetected_label at the end, if it was in the gold labels
+            if self.undetected_label in gold_labels:
+                gold_labels_sorted = gold_labels_sorted + [self.undetected_label]
+            pred_labels_sorted = sorted(
+                [pred_label for pred_label in pred_labels if pred_label != self.unassignable_label]
+            )
+            # re-add unassignable_label at the end, if it was in the pred labels
+            if self.unassignable_label in pred_labels:
+                pred_labels_sorted = pred_labels_sorted + [self.unassignable_label]
+            res_df_sorted = res_df.loc[pred_labels_sorted, gold_labels_sorted]
+
+            # transpose and show as markdown: index is now gold, columns is prediction
+            logger.info(f"\n{self.layer}:\n{res_df_sorted.T.to_markdown()}")
+        return res
diff --git a/tests/metrics/test_confusion_matrix.py b/tests/metrics/test_confusion_matrix.py
new file mode 100644
index 000000000..8a768e806
--- /dev/null
+++ b/tests/metrics/test_confusion_matrix.py
@@ -0,0 +1,179 @@
+import logging
+from dataclasses import dataclass
+from functools import partial
+from typing import Dict
+
+import pytest
+from pytorch_ie.annotations import LabeledSpan
+from pytorch_ie.core import Annotation, AnnotationLayer, annotation_field
+from pytorch_ie.documents import TextBasedDocument
+
+from pie_modules.metrics import ConfusionMatrix
+
+
+@pytest.fixture
+def documents():
+    @dataclass
+    class TextDocumentWithEntities(TextBasedDocument):
+        entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
+
+    # a test sentence with two entities
+    doc1 = TextDocumentWithEntities(
+        text="The quick brown fox jumps over the lazy dog.",
+    )
+    doc1.entities.append(LabeledSpan(start=4, end=19, label="animal"))
+    doc1.entities.append(LabeledSpan(start=35, end=43, label="animal"))
+    assert str(doc1.entities[0]) == "quick brown fox"
+    assert str(doc1.entities[1]) == "lazy dog"
+
+    # a second test sentence with a different text and a single entity (a company)
+    doc2 = TextDocumentWithEntities(text="Apple is a great company.")
+    doc2.entities.append(LabeledSpan(start=0, end=5, label="company"))
+    assert str(doc2.entities[0]) == "Apple"
+
+    documents = [doc1, doc2]
+
+    # add predictions
+    # correct
+    documents[0].entities.predictions.append(LabeledSpan(start=4, end=19, label="animal"))
+    # wrong label
+    documents[0].entities.predictions.append(LabeledSpan(start=35, end=43, label="cat"))
+    # correct
+    documents[1].entities.predictions.append(LabeledSpan(start=0, end=5, label="company"))
+    # wrong span
+    documents[1].entities.predictions.append(LabeledSpan(start=10, end=15, label="company"))
+
+    return documents
+
+
+def test_confusion_matrix(documents):
+    metric = ConfusionMatrix(layer="entities")
+    metric(documents)
+    # (gold_label, predicted_label): count
+    assert dict(metric.counts) == {
+        ("animal", "animal"): 1,
+        ("animal", "cat"): 1,
+        ("UNDETECTED", "company"): 1,
+        ("company", "company"): 1,
+    }
+    assert metric.compute() == {
+        "animal": {"animal": 1, "cat": 1},
+        "UNDETECTED": {"company": 1},
+        "company": {"company": 1},
+    }
+
+
+def test_undetected_is_gold_label(documents):
+    metric = ConfusionMatrix(layer="entities", undetected_label="animal")
+    with pytest.raises(ValueError) as exception:
+        metric(documents)
+
+    assert str(exception.value).startswith("The gold annotation has the label")
+
+
+def test_unassignable_is_pred_label(documents):
+    metric = ConfusionMatrix(layer="entities", unassignable_label="cat")
+    with pytest.raises(ValueError) as exception:
+        metric(documents)
+
+    assert str(exception.value).startswith("The predicted annotation has the label")
+
+
+@pytest.fixture
+def documents_with_several_gold_labels(documents):
+    doc1 = documents[0].copy()
+    doc2 = documents[1].copy()
+    doc1.entities.append(LabeledSpan(start=4, end=19, label="cat"))
+
+    return [doc1, doc2]
+
+
+def test_documents_with_several_gold_labels(documents_with_several_gold_labels, caplog):
+    metric = ConfusionMatrix(layer="entities")
+    with pytest.raises(ValueError):
+        metric(documents_with_several_gold_labels)
+
+    metric = ConfusionMatrix(layer="entities", strict=False)
+    metric(documents_with_several_gold_labels)
+    assert caplog.messages[0].startswith(
+        "The base annotation LabeledSpan(start=4, end=19, label='DUMMY_LABEL', score=1.0) has multiple gold labels: ['animal', 'cat']. Skip this base annotation."
+    )
+
+
+@pytest.fixture
+def documents_without_predictions(documents):
+    doc1 = documents[0].copy()
+    doc2 = documents[1].copy()
+
+    doc1.entities.predictions.clear()
+    doc2.entities.predictions.clear()
+
+    return [doc1, doc2]
+
+
+def test_documents_without_predictions(documents_without_predictions):
+    metric = ConfusionMatrix(layer="entities")
+    metric(documents_without_predictions)
+    assert dict(metric.counts) == {("animal", "UNASSIGNABLE"): 2, ("company", "UNASSIGNABLE"): 1}
+
+
+def test_show_as_markdown(documents, caplog):
+    caplog.set_level(logging.INFO)
+    metric = ConfusionMatrix(layer="entities", show_as_markdown=True)
+    metric(documents)
+
+    # the log message is compared verbatim, i.e. including the column padding of the table
+    markdown = [
+        "\nentities:\n| | animal | cat | company |\n|:-----------|---------:|------:|----------:|\n| animal | 1 | 1 | 0 |\n| company | 0 | 0 | 1 |\n| UNDETECTED | 0 | 0 | 1 |"
+    ]
+
+    assert caplog.messages == markdown
+
+
+def test_show_as_markdown_without_predictions(documents_without_predictions, caplog):
+    caplog.set_level(logging.INFO)
+    metric = ConfusionMatrix(layer="entities", show_as_markdown=True)
+    metric(documents_without_predictions)
+
+    # the log message is compared verbatim, i.e. including the column padding of the table
+    markdown = [
+        "\nentities:\n| | UNASSIGNABLE |\n|:--------|---------------:|\n| animal | 2 |\n| company | 1 |"
+    ]
+
+    assert caplog.messages == markdown
+
+
+MAPPING = {"cat": "My beautified Cat", "company": "The Super-Company", "animal": "Hrrr"}
+
+
+def relabel_annotation(ann: Annotation, mapping: Dict[str, str] = MAPPING) -> Annotation:
+    return ann.copy(label=mapping[ann.label])
+
+
+def test_annotation_processor(documents):
+    annotation_processor = partial(
+        relabel_annotation,
+        mapping=MAPPING,
+    )
+    metric = ConfusionMatrix(layer="entities", annotation_processor=annotation_processor)
+    metric(documents)
+
+    assert dict(metric.counts) == {
+        ("Hrrr", "My beautified Cat"): 1,
+        ("Hrrr", "Hrrr"): 1,
+        ("The Super-Company", "The Super-Company"): 1,
+        ("UNDETECTED", "The Super-Company"): 1,
+    }
+
+
+def test_annotation_processor_str(documents):
+    annotation_processor = "tests.metrics.test_confusion_matrix.relabel_annotation"
+    metric = ConfusionMatrix(layer="entities", annotation_processor=annotation_processor)
+    metric(documents)
+
+    assert dict(metric.counts) == {
+        ("Hrrr", "My beautified Cat"): 1,
+        ("Hrrr", "Hrrr"): 1,
+        ("The Super-Company", "The Super-Company"): 1,
+        ("UNDETECTED", "The Super-Company"): 1,
+    }