Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/pie_modules/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .confusion_matrix import ConfusionMatrix
from .f1 import F1Metric
from .relation_argument_distance_collector import RelationArgumentDistanceCollector
from .span_coverage_collector import SpanCoverageCollector
Expand Down
156 changes: 156 additions & 0 deletions src/pie_modules/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import logging
from collections import defaultdict
from typing import Callable, Dict, Optional, Tuple, Union

import pandas as pd
from pytorch_ie.core import Annotation, Document, DocumentMetric
from pytorch_ie.utils.hydra import resolve_target

logger = logging.getLogger(__name__)


class ConfusionMatrix(DocumentMetric):
"""Computes the confusion matrix for a given annotation layer that contains labeled
annotations.

Args:
layer: The layer to compute the confusion matrix for.
label_field: The field to use for the label. Defaults to "label".
unassignable_label: The label to use for false negative annotations. Defaults to "UNASSIGNABLE".
undetected_label: The label to use for false positive annotations. Defaults to "UNDETECTED".
strict: If True, raises an error if a base annotation has multiple gold labels. If False, logs a warning.
show_as_markdown: If True, logs the confusion matrix as markdown on the console when calling compute().
annotation_processor: A callable that processes the annotations before calculating the confusion matrix.
"""

def __init__(
self,
layer: str,
label_field: str = "label",
show_as_markdown: bool = False,
unassignable_label: str = "UNASSIGNABLE",
undetected_label: str = "UNDETECTED",
strict: bool = True,
annotation_processor: Optional[Union[Callable[[Annotation], Annotation], str]] = None,
):
super().__init__()
self.layer = layer
self.label_field = label_field
self.unassignable_label = unassignable_label
self.undetected_label = undetected_label
self.strict = strict
self.show_as_markdown = show_as_markdown
if isinstance(annotation_processor, str):
self.annotation_processor = resolve_target(annotation_processor)
else:
self.annotation_processor = annotation_processor

def reset(self) -> None:
self.counts: Dict[Tuple[str, str], int] = defaultdict(int)

def calculate_counts(
self,
document: Document,
annotation_filter: Optional[Callable[[Annotation], bool]] = None,
annotation_processor: Optional[Callable[[Annotation], Annotation]] = None,
) -> Dict[Tuple[str, str], int]:
annotation_processor = annotation_processor or (lambda ann: ann)
annotation_filter = annotation_filter or (lambda ann: True)
predicted_annotations = {
annotation_processor(ann)
for ann in document[self.layer].predictions
if annotation_filter(ann)
}
gold_annotations = {
annotation_processor(ann) for ann in document[self.layer] if annotation_filter(ann)
}
base2gold = defaultdict(list)
for ann in gold_annotations:
base_ann_kwargs = {self.label_field: "DUMMY_LABEL"}
base_ann = ann.copy(**base_ann_kwargs)
base2gold[base_ann].append(ann)
base2pred = defaultdict(list)
for ann in predicted_annotations:
base_ann_kwargs = {self.label_field: "DUMMY_LABEL"}
base_ann = ann.copy(**base_ann_kwargs)
base2pred[base_ann].append(ann)

# (gold_label, pred_label) -> count
counts: Dict[Tuple[str, str], int] = defaultdict(int)
for base_ann in set(base2gold) | set(base2pred):
gold_labels = [getattr(ann, self.label_field) for ann in base2gold[base_ann]]
pred_labels = [getattr(ann, self.label_field) for ann in base2pred[base_ann]]

if self.undetected_label in gold_labels:
raise ValueError(
f"The gold annotation has the label '{self.undetected_label}' for undetected instances. "
f"Set a different undetected_label."
)
if self.unassignable_label in pred_labels:
raise ValueError(
f"The predicted annotation has the label '{self.unassignable_label}' for unassignable predictions. "
f"Set a different unassignable_label."
)

if len(gold_labels) > 1:
msg = f"The base annotation {base_ann} has multiple gold labels: {sorted(gold_labels)}."
if self.strict:
raise ValueError(msg)
else:
logger.warning(msg + " Skip this base annotation.")
continue

# use placeholder labels for empty gold or prediction labels
if len(gold_labels) == 0:
gold_labels.append(self.undetected_label)
if len(pred_labels) == 0:
pred_labels.append(self.unassignable_label)

# main logic
for gold_label in gold_labels:
for pred_label in pred_labels:
counts[(gold_label, pred_label)] += 1

return counts

def add_counts(self, counts: Dict[Tuple[str, str], int]) -> None:
for key, value in counts.items():
self.counts[key] += value

def _update(self, document: Document) -> None:
new_counts = self.calculate_counts(
document=document,
annotation_processor=self.annotation_processor,
)
self.add_counts(new_counts)

def _compute(self) -> Dict[str, Dict[str, int]]:

res: Dict[str, Dict[str, int]] = defaultdict(dict)
for gold_label, pred_label in sorted(self.counts):
res[gold_label][pred_label] = self.counts[(gold_label, pred_label)]

if self.show_as_markdown:
res_df = pd.DataFrame(res).fillna(0)
# index is prediction, columns is gold
gold_labels = res_df.columns
pred_labels = res_df.index

# re-arrange index and columns: sort and put undetected_label and unassignable_label at the end
gold_labels_sorted = sorted(
[gold_label for gold_label in gold_labels if gold_label != self.undetected_label]
)
# re-add undetected_label at the end, if it was in the gold labels
if self.undetected_label in gold_labels:
gold_labels_sorted = gold_labels_sorted + [self.undetected_label]
pred_labels_sorted = sorted(
[pred_label for pred_label in pred_labels if pred_label != self.unassignable_label]
)
# re-add unassignable_label at the end, if it was in the pred labels
if self.unassignable_label in pred_labels:
pred_labels_sorted = pred_labels_sorted + [self.unassignable_label]
res_df_sorted = res_df.loc[pred_labels_sorted, gold_labels_sorted]

# transpose and show as markdown: index is now gold, columns is prediction
logger.info(f"\n{self.layer}:\n{res_df_sorted.T.to_markdown()}")
return res
177 changes: 177 additions & 0 deletions tests/metrics/test_confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import logging
from dataclasses import dataclass
from functools import partial
from typing import Dict

import pytest
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import Annotation, AnnotationLayer, annotation_field
from pytorch_ie.documents import TextBasedDocument

from pie_modules.metrics import ConfusionMatrix


@pytest.fixture
def documents():
@dataclass
class TextDocumentWithEntities(TextBasedDocument):
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")

# a test sentence with two entities
doc1 = TextDocumentWithEntities(
text="The quick brown fox jumps over the lazy dog.",
)
doc1.entities.append(LabeledSpan(start=4, end=19, label="animal"))
doc1.entities.append(LabeledSpan(start=35, end=43, label="animal"))
assert str(doc1.entities[0]) == "quick brown fox"
assert str(doc1.entities[1]) == "lazy dog"

# a second test sentence with a different text and a single entity (a company)
doc2 = TextDocumentWithEntities(text="Apple is a great company.")
doc2.entities.append(LabeledSpan(start=0, end=5, label="company"))
assert str(doc2.entities[0]) == "Apple"

documents = [doc1, doc2]

# add predictions
# correct
documents[0].entities.predictions.append(LabeledSpan(start=4, end=19, label="animal"))
# wrong label
documents[0].entities.predictions.append(LabeledSpan(start=35, end=43, label="cat"))
# correct
documents[1].entities.predictions.append(LabeledSpan(start=0, end=5, label="company"))
# wrong span
documents[1].entities.predictions.append(LabeledSpan(start=10, end=15, label="company"))

return documents


def test_confusion_matrix(documents):
metric = ConfusionMatrix(layer="entities")
metric(documents)
# (gold_label, predicted_label): count
assert dict(metric.counts) == {
("animal", "animal"): 1,
("animal", "cat"): 1,
("UNDETECTED", "company"): 1,
("company", "company"): 1,
}
assert metric.compute() == {
"animal": {"animal": 1, "cat": 1},
"UNDETECTED": {"company": 1},
"company": {"company": 1},
}


def test_undetected_is_gold_label(documents):
metric = ConfusionMatrix(layer="entities", undetected_label="animal")
with pytest.raises(ValueError) as exception:
metric(documents)

assert str(exception.value).startswith("The gold annotation has the label")


def test_unassignable_is_pred_label(documents):
metric = ConfusionMatrix(layer="entities", unassignable_label="cat")
with pytest.raises(ValueError) as exception:
metric(documents)

assert str(exception.value).startswith("The predicted annotation has the label")


@pytest.fixture
def documents_with_several_gold_labels(documents):
doc1 = documents[0].copy()
doc2 = documents[1].copy()
doc1.entities.append(LabeledSpan(start=4, end=19, label="cat"))

return [doc1, doc2]


def test_documents_with_several_gold_labels(documents_with_several_gold_labels, caplog):
metric = ConfusionMatrix(layer="entities")
with pytest.raises(ValueError):
metric(documents_with_several_gold_labels)

metric = ConfusionMatrix(layer="entities", strict=False)
metric(documents_with_several_gold_labels)
assert caplog.messages[0].startswith(
"The base annotation LabeledSpan(start=4, end=19, label='DUMMY_LABEL', score=1.0) has multiple gold labels: ['animal', 'cat']. Skip this base annotation."
)


@pytest.fixture
def documents_without_predictions(documents):
doc1 = documents[0].copy()
doc2 = documents[1].copy()

doc1.entities.predictions.clear()
doc2.entities.predictions.clear()

return [doc1, doc2]


def test_documents_without_predictions(documents_without_predictions):
metric = ConfusionMatrix(layer="entities")
metric(documents_without_predictions)
assert dict(metric.counts) == {("animal", "UNASSIGNABLE"): 2, ("company", "UNASSIGNABLE"): 1}


def test_show_as_markdown(documents, caplog):
caplog.set_level(logging.INFO)
metric = ConfusionMatrix(layer="entities", show_as_markdown=True)
metric(documents)

markdown = [
"\nentities:\n| | animal | cat | company |\n|:-----------|---------:|------:|----------:|\n| animal | 1 | 1 | 0 |\n| company | 0 | 0 | 1 |\n| UNDETECTED | 0 | 0 | 1 |"
]

assert caplog.messages == markdown


def test_show_as_markdown_without_predictions(documents_without_predictions, caplog):
caplog.set_level(logging.INFO)
metric = ConfusionMatrix(layer="entities", show_as_markdown=True)
metric(documents_without_predictions)

markdown = [
"\nentities:\n| | UNASSIGNABLE |\n|:--------|---------------:|\n| animal | 2 |\n| company | 1 |"
]

assert caplog.messages == markdown


MAPPING = {"cat": "My beautified Cat", "company": "The Super-Company", "animal": "Hrrr"}


def relabel_annotation(ann: Annotation, mapping: Dict[str, str] = MAPPING) -> Annotation:
return ann.copy(label=mapping[ann.label])


def test_annotation_processor(documents):
annotation_processor = partial(
relabel_annotation,
mapping=MAPPING,
)
metric = ConfusionMatrix(layer="entities", annotation_processor=annotation_processor)
metric(documents)

assert dict(metric.counts) == {
("Hrrr", "My beautified Cat"): 1,
("Hrrr", "Hrrr"): 1,
("The Super-Company", "The Super-Company"): 1,
("UNDETECTED", "The Super-Company"): 1,
}


def test_annotation_processor_str(documents):
annotation_processor = "tests.metrics.test_confusion_matrix.relabel_annotation"
metric = ConfusionMatrix(layer="entities", annotation_processor=annotation_processor)
metric(documents)

assert dict(metric.counts) == {
("Hrrr", "My beautified Cat"): 1,
("Hrrr", "Hrrr"): 1,
("The Super-Company", "The Super-Company"): 1,
("UNDETECTED", "The Super-Company"): 1,
}
Loading