diff --git a/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py b/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py index 06ffe2ce9..ea65ffd35 100644 --- a/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py +++ b/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py @@ -1,11 +1,10 @@ import logging from typing import Dict, Generic, List, Optional, Set, TypeVar, Union +from pie_core.utils.dictionary import flatten_dict_s from pytorch_ie import PyTorchIEModel from torchmetrics import Metric, MetricCollection -from pie_modules.utils import flatten_dict - from .has_taskmodule import HasTaskmodule from .stages import TESTING, TRAINING, VALIDATION @@ -143,7 +142,7 @@ def log_metric(self, stage: str, reset: bool = True) -> None: values = metric.compute() log_kwargs = {"on_step": False, "on_epoch": True, "sync_dist": True} if isinstance(values, dict): - values_flat = flatten_dict(values, sep="/") + values_flat = flatten_dict_s(values, sep="/") for key, value in values_flat.items(): self.log(f"metric/{key}/{stage}", value, **log_kwargs) else: diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py index 6b42c8c5b..cff2fa139 100644 --- a/src/pie_modules/taskmodules/cross_text_binary_coref.py +++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py @@ -16,6 +16,7 @@ import torch from pie_core import Annotation, TaskEncoding, TaskModule +from pie_core.utils.dictionary import list_of_dicts2dict_of_lists from pytorch_ie.utils.window import get_window_around_slice from torchmetrics import MetricCollection from torchmetrics.classification import ( @@ -32,7 +33,6 @@ ) from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction -from pie_modules.utils import list_of_dicts2dict_of_lists from pie_modules.utils.tokenization import ( SpanNotAlignedWithTokenException, get_aligned_token_span, diff --git a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py b/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py index e3dfacb1c..ca566a024 100644 --- a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py +++ b/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py @@ -25,6 +25,7 @@ import torch from pie_core import AnnotationLayer, TaskEncoding, TaskModule +from pie_core.utils.dictionary import list_of_dicts2dict_of_lists from tokenizers import Encoding from torchmetrics import F1Score, Metric, MetricCollection, Precision, Recall from transformers import AutoTokenizer @@ -48,7 +49,6 @@ PrecisionRecallAndF1ForLabeledAnnotations, WrappedMetricWithPrepareFunction, ) -from pie_modules.utils import list_of_dicts2dict_of_lists from pie_modules.utils.sequence_tagging import tag_sequence_to_token_spans DocumentType: TypeAlias = TextBasedDocument diff --git a/src/pie_modules/taskmodules/metrics/precision_recall_and_f1_for_labeled_annotations.py b/src/pie_modules/taskmodules/metrics/precision_recall_and_f1_for_labeled_annotations.py index ca11d7f41..2bb83eb4c 100644 --- a/src/pie_modules/taskmodules/metrics/precision_recall_and_f1_for_labeled_annotations.py +++ b/src/pie_modules/taskmodules/metrics/precision_recall_and_f1_for_labeled_annotations.py @@ -4,10 +4,9 @@ import torch from pie_core import Annotation +from pie_core.utils.dictionary import flatten_dict_s from torch import LongTensor -from pie_modules.utils import flatten_dict - from .common import MetricWithArbitraryCounts logger = logging.getLogger(__name__) @@ -133,6 +132,6 @@ def compute(self) -> Union[Dict[str, Any], Dict[Optional[str], dict[str, float]] result = {f"{self.prefix}{k}": v for k, v in result.items()} if self.flatten_result_with_sep is not None: - return flatten_dict(result, sep=self.flatten_result_with_sep) + return flatten_dict_s(result, sep=self.flatten_result_with_sep) else: return result diff --git a/src/pie_modules/utils/__init__.py b/src/pie_modules/utils/__init__.py index 18ce75525..c8017711c 100644 --- a/src/pie_modules/utils/__init__.py +++ b/src/pie_modules/utils/__init__.py @@ -1,4 +1,3 @@ -from .dictionary import flatten_dict, list_of_dicts2dict_of_lists - # backwards compatibility +from .dictionary import flatten_dict, list_of_dicts2dict_of_lists from .hydra import resolve_type diff --git a/src/pie_modules/utils/dictionary.py b/src/pie_modules/utils/dictionary.py index 7559e1783..34cb20be0 100644 --- a/src/pie_modules/utils/dictionary.py +++ b/src/pie_modules/utils/dictionary.py @@ -1,24 +1,3 @@ -from typing import Any, Iterable, MutableMapping, Tuple - - -def list_of_dicts2dict_of_lists(list_of_dicts: list[dict]) -> dict[str, list]: - return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0].keys()} - - -def _flatten_dict_gen(d, parent_key, sep) -> Iterable[Tuple[str, Any]]: - for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k - if isinstance(v, MutableMapping): - yield from _flatten_dict_gen(v, new_key, sep=sep) - else: - yield new_key, v - - -def flatten_dict(d: MutableMapping, parent_key: str = "", sep: str = "/"): - """Flatten a nested dictionary. - - Example: - d = {"a": {"b": 1, "c": 2}, "d": 3} - flatten_nested_dict(d) == {"a/b": 1, "a/c": 2, "d": 3} - """ - return dict(_flatten_dict_gen(d, parent_key, sep)) +# backwards compatibility +from pie_core.utils.dictionary import flatten_dict_s as flatten_dict +from pie_core.utils.dictionary import list_of_dicts2dict_of_lists diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py index f3677159f..065f9b519 100644 --- a/tests/taskmodules/test_cross_text_binary_coref.py +++ b/tests/taskmodules/test_cross_text_binary_coref.py @@ -3,6 +3,7 @@ import pytest import torch.testing +from pie_core.utils.dictionary import flatten_dict_s, list_of_dicts2dict_of_lists from torch import tensor from torchmetrics import Metric, MetricCollection @@ -13,7 +14,6 @@ TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, ) from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule -from pie_modules.utils import flatten_dict, list_of_dicts2dict_of_lists from tests import FIXTURES_ROOT, _config_to_str TOKENIZER_NAME_OR_PATH = "bert-base-cased" @@ -305,9 +305,9 @@ def test_create_annotation_from_output(taskmodule, task_encodings, unbatched_out def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]: if isinstance(metric_or_collection, Metric): - return flatten_dict(metric_or_collection.metric_state) + return flatten_dict_s(metric_or_collection.metric_state) elif isinstance(metric_or_collection, MetricCollection): - return flatten_dict({k: get_metric_state(v) for k, v in metric_or_collection.items()}) + return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()}) else: raise ValueError(f"unsupported type: {type(metric_or_collection)}") diff --git a/tests/taskmodules/test_re_span_pair_classification.py b/tests/taskmodules/test_re_span_pair_classification.py index de8c5e478..f82554c1e 100644 --- a/tests/taskmodules/test_re_span_pair_classification.py +++ b/tests/taskmodules/test_re_span_pair_classification.py @@ -5,13 +5,13 @@ import pytest import torch from pie_core import AnnotationLayer, annotation_field +from pie_core.utils.dictionary import flatten_dict_s from torch import tensor from torchmetrics import Metric, MetricCollection from pie_modules.annotations import BinaryRelation, LabeledSpan from pie_modules.documents import TextBasedDocument from pie_modules.taskmodules import RESpanPairClassificationTaskModule -from pie_modules.utils import flatten_dict from pie_modules.utils.span import distance from tests import _config_to_str @@ -557,9 +557,11 @@ def test_create_annotations_from_output( def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]: if isinstance(metric_or_collection, Metric): - return {k: v.tolist() for k, v in flatten_dict(metric_or_collection.metric_state).items()} + return { + k: v.tolist() for k, v in flatten_dict_s(metric_or_collection.metric_state).items() + } elif isinstance(metric_or_collection, MetricCollection): - return flatten_dict({k: get_metric_state(v) for k, v in metric_or_collection.items()}) + return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()}) else: raise ValueError(f"unsupported type: {type(metric_or_collection)}") diff --git a/tests/taskmodules/test_re_text_classification_with_indices.py b/tests/taskmodules/test_re_text_classification_with_indices.py index 16bcfe83d..0574e10e5 100644 --- a/tests/taskmodules/test_re_text_classification_with_indices.py +++ b/tests/taskmodules/test_re_text_classification_with_indices.py @@ -14,6 +14,7 @@ TaskEncoding, annotation_field, ) +from pie_core.utils.dictionary import flatten_dict_s from torch import tensor from torchmetrics import Metric, MetricCollection @@ -30,7 +31,6 @@ get_relation_argument_spans_and_roles, span_distance, ) -from pie_modules.utils import flatten_dict from pie_modules.utils.span import distance_inner from tests import _config_to_str from tests.conftest import _TABULATE_AVAILABLE, TestDocument @@ -2292,9 +2292,11 @@ def test_get_global_attention(taskmodule, batch, cfg): def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]: if isinstance(metric_or_collection, Metric): - return {k: v.tolist() for k, v in flatten_dict(metric_or_collection.metric_state).items()} + return { + k: v.tolist() for k, v in flatten_dict_s(metric_or_collection.metric_state).items() + } elif isinstance(metric_or_collection, MetricCollection): - return flatten_dict({k: get_metric_state(v) for k, v in metric_or_collection.items()}) + return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()}) else: raise ValueError(f"unsupported type: {type(metric_or_collection)}") diff --git a/tests/utils/test_dictionary.py b/tests/utils/test_dictionary.py deleted file mode 100644 index a03331f32..000000000 --- a/tests/utils/test_dictionary.py +++ /dev/null @@ -1,7 +0,0 @@ -from pie_modules.utils import flatten_dict - - -def test_flatten_nested_dict(): - d = {"a": {"b": 1, "c": 2}, "d": 3} - assert flatten_dict(d) == {"a/b": 1, "a/c": 2, "d": 3} - assert flatten_dict(d, sep=".") == {"a.b": 1, "a.c": 2, "d": 3}