Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import logging
from typing import Dict, Generic, List, Optional, Set, TypeVar, Union

from pie_core.utils.dictionary import flatten_dict_s
from pytorch_ie import PyTorchIEModel
from torchmetrics import Metric, MetricCollection

from pie_modules.utils import flatten_dict

from .has_taskmodule import HasTaskmodule
from .stages import TESTING, TRAINING, VALIDATION

Expand Down Expand Up @@ -143,7 +142,7 @@ def log_metric(self, stage: str, reset: bool = True) -> None:
values = metric.compute()
log_kwargs = {"on_step": False, "on_epoch": True, "sync_dist": True}
if isinstance(values, dict):
values_flat = flatten_dict(values, sep="/")
values_flat = flatten_dict_s(values, sep="/")
for key, value in values_flat.items():
self.log(f"metric/{key}/{stage}", value, **log_kwargs)
else:
Expand Down
2 changes: 1 addition & 1 deletion src/pie_modules/taskmodules/cross_text_binary_coref.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch
from pie_core import Annotation, TaskEncoding, TaskModule
from pie_core.utils.dictionary import list_of_dicts2dict_of_lists
from pytorch_ie.utils.window import get_window_around_slice
from torchmetrics import MetricCollection
from torchmetrics.classification import (
Expand All @@ -32,7 +33,6 @@
)
from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin
from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction
from pie_modules.utils import list_of_dicts2dict_of_lists
from pie_modules.utils.tokenization import (
SpanNotAlignedWithTokenException,
get_aligned_token_span,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import torch
from pie_core import AnnotationLayer, TaskEncoding, TaskModule
from pie_core.utils.dictionary import list_of_dicts2dict_of_lists
from tokenizers import Encoding
from torchmetrics import F1Score, Metric, MetricCollection, Precision, Recall
from transformers import AutoTokenizer
Expand All @@ -48,7 +49,6 @@
PrecisionRecallAndF1ForLabeledAnnotations,
WrappedMetricWithPrepareFunction,
)
from pie_modules.utils import list_of_dicts2dict_of_lists
from pie_modules.utils.sequence_tagging import tag_sequence_to_token_spans

DocumentType: TypeAlias = TextBasedDocument
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

import torch
from pie_core import Annotation
from pie_core.utils.dictionary import flatten_dict_s
from torch import LongTensor

from pie_modules.utils import flatten_dict

from .common import MetricWithArbitraryCounts

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -133,6 +132,6 @@ def compute(self) -> Union[Dict[str, Any], Dict[Optional[str], dict[str, float]]
result = {f"{self.prefix}{k}": v for k, v in result.items()}

if self.flatten_result_with_sep is not None:
return flatten_dict(result, sep=self.flatten_result_with_sep)
return flatten_dict_s(result, sep=self.flatten_result_with_sep)
else:
return result
3 changes: 1 addition & 2 deletions src/pie_modules/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .dictionary import flatten_dict, list_of_dicts2dict_of_lists

# backwards compatibility
from .dictionary import flatten_dict, list_of_dicts2dict_of_lists
from .hydra import resolve_type
27 changes: 3 additions & 24 deletions src/pie_modules/utils/dictionary.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,3 @@
from typing import Any, Iterable, MutableMapping, Tuple


def list_of_dicts2dict_of_lists(list_of_dicts: list[dict]) -> dict[str, list]:
return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0].keys()}


def _flatten_dict_gen(d, parent_key, sep) -> Iterable[Tuple[str, Any]]:
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, MutableMapping):
yield from _flatten_dict_gen(v, new_key, sep=sep)
else:
yield new_key, v


def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = "/"):
"""Flatten a nested dictionary.

Example:
d = {"a": {"b": 1, "c": 2}, "d": 3}
flatten_nested_dict(d) == {"a/b": 1, "a/c": 2, "d": 3}
"""
return dict(_flatten_dict_gen(d, parent_key, sep))
# backwards compatibility
from pie_core.utils.dictionary import flatten_dict_s as flatten_dict
from pie_core.utils.dictionary import list_of_dicts2dict_of_lists
6 changes: 3 additions & 3 deletions tests/taskmodules/test_cross_text_binary_coref.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
import torch.testing
from pie_core.utils.dictionary import flatten_dict_s, list_of_dicts2dict_of_lists
from torch import tensor
from torchmetrics import Metric, MetricCollection

Expand All @@ -13,7 +14,6 @@
TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
)
from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule
from pie_modules.utils import flatten_dict, list_of_dicts2dict_of_lists
from tests import FIXTURES_ROOT, _config_to_str

TOKENIZER_NAME_OR_PATH = "bert-base-cased"
Expand Down Expand Up @@ -305,9 +305,9 @@ def test_create_annotation_from_output(taskmodule, task_encodings, unbatched_out

def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]:
if isinstance(metric_or_collection, Metric):
return flatten_dict(metric_or_collection.metric_state)
return flatten_dict_s(metric_or_collection.metric_state)
elif isinstance(metric_or_collection, MetricCollection):
return flatten_dict({k: get_metric_state(v) for k, v in metric_or_collection.items()})
return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()})
else:
raise ValueError(f"unsupported type: {type(metric_or_collection)}")

Expand Down
8 changes: 5 additions & 3 deletions tests/taskmodules/test_re_span_pair_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import pytest
import torch
from pie_core import AnnotationLayer, annotation_field
from pie_core.utils.dictionary import flatten_dict_s
from torch import tensor
from torchmetrics import Metric, MetricCollection

from pie_modules.annotations import BinaryRelation, LabeledSpan
from pie_modules.documents import TextBasedDocument
from pie_modules.taskmodules import RESpanPairClassificationTaskModule
from pie_modules.utils import flatten_dict
from pie_modules.utils.span import distance
from tests import _config_to_str

Expand Down Expand Up @@ -557,9 +557,11 @@ def test_create_annotations_from_output(

def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]:
if isinstance(metric_or_collection, Metric):
return {k: v.tolist() for k, v in flatten_dict(metric_or_collection.metric_state).items()}
return {
k: v.tolist() for k, v in flatten_dict_s(metric_or_collection.metric_state).items()
}
elif isinstance(metric_or_collection, MetricCollection):
return flatten_dict({k: get_metric_state(v) for k, v in metric_or_collection.items()})
return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()})
else:
raise ValueError(f"unsupported type: {type(metric_or_collection)}")

Expand Down
8 changes: 5 additions & 3 deletions tests/taskmodules/test_re_text_classification_with_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
TaskEncoding,
annotation_field,
)
from pie_core.utils.dictionary import flatten_dict_s
from torch import tensor
from torchmetrics import Metric, MetricCollection

Expand All @@ -30,7 +31,6 @@
get_relation_argument_spans_and_roles,
span_distance,
)
from pie_modules.utils import flatten_dict
from pie_modules.utils.span import distance_inner
from tests import _config_to_str
from tests.conftest import _TABULATE_AVAILABLE, TestDocument
Expand Down Expand Up @@ -2292,9 +2292,11 @@ def test_get_global_attention(taskmodule, batch, cfg):

def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]:
if isinstance(metric_or_collection, Metric):
return {k: v.tolist() for k, v in flatten_dict(metric_or_collection.metric_state).items()}
return {
k: v.tolist() for k, v in flatten_dict_s(metric_or_collection.metric_state).items()
}
elif isinstance(metric_or_collection, MetricCollection):
return flatten_dict({k: get_metric_state(v) for k, v in metric_or_collection.items()})
return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()})
else:
raise ValueError(f"unsupported type: {type(metric_or_collection)}")

Expand Down
7 changes: 0 additions & 7 deletions tests/utils/test_dictionary.py

This file was deleted.

Loading