Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Optional, Sequence, Set, Tuple, TypeVar, Union

from pie_core import AnnotationLayer, Document
from pie_core.utils.hydra import resolve_type

from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from pie_modules.utils import resolve_type

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/pie_modules/document/processing/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
)

from pie_core import Annotation
from pie_core.utils.hydra import resolve_type
from transformers import PreTrainedTokenizer

from pie_modules.annotations import MultiSpan, Span
from pie_modules.documents import TextBasedDocument, TokenBasedDocument
from pie_modules.utils import resolve_type

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/pie_modules/metrics/span_coverage_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import Any, Dict, List, Optional, Set, Type, Union

from pie_core import Document, DocumentStatistic
from pie_core.utils.hydra import resolve_type
from transformers import AutoTokenizer, PreTrainedTokenizer

from pie_modules.annotations import LabeledMultiSpan, Span
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TextBasedDocument, TokenBasedDocument
from pie_modules.utils import resolve_type

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/pie_modules/metrics/span_length_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from typing import Any, Callable, Dict, List, Optional, Type, Union

from pie_core import Document, DocumentStatistic
from pie_core.utils.hydra import resolve_type
from transformers import AutoTokenizer, PreTrainedTokenizer

from pie_modules.annotations import Span
from pie_modules.document.processing import tokenize_document
from pie_modules.documents import TextBasedDocument, TokenBasedDocument
from pie_modules.utils import resolve_type

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion src/pie_modules/models/simple_generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, Optional, Tuple, Type, Union

import torch
from pie_core.utils.hydra import resolve_type
from pytorch_ie import PyTorchIEModel
from pytorch_lightning.utilities.types import OptimizerLRScheduler
from torch import FloatTensor, LongTensor
Expand All @@ -12,7 +13,6 @@
from typing_extensions import TypeAlias

from pie_modules.models.common import ModelWithBoilerplate
from pie_modules.utils import resolve_type

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TargetEncoding,
TaskBatchEncoding,
)
from pie_core.utils.hydra import resolve_type
from torchmetrics import Metric
from transformers import AutoTokenizer, LogitsProcessorList, PreTrainedTokenizer
from typing_extensions import TypeAlias
Expand All @@ -45,7 +46,6 @@
)

from ..document.processing import token_based_document_to_text_based, tokenize_document
from ..utils import resolve_type
from .common import BatchableMixin, get_first_occurrence_index
from .metrics import (
PrecisionRecallAndF1ForLabeledAnnotations,
Expand Down
2 changes: 1 addition & 1 deletion src/pie_modules/taskmodules/text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
TargetEncoding,
TaskBatchEncoding,
)
from pie_core.utils.hydra import resolve_type
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizer
from typing_extensions import TypeAlias
Expand All @@ -38,7 +39,6 @@
tokenize_document,
)
from pie_modules.documents import TextBasedDocument, TokenBasedDocument
from pie_modules.utils import resolve_type

from .common import BatchableMixin, get_first_occurrence_index
from .metrics import WrappedMetricWithPrepareFunction
Expand Down
2 changes: 2 additions & 0 deletions src/pie_modules/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .dictionary import flatten_dict, list_of_dicts2dict_of_lists

# backwards compatibility
from .hydra import resolve_type
27 changes: 2 additions & 25 deletions src/pie_modules/utils/hydra.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,2 @@
from typing import Optional, Type, TypeVar, Union

from pie_core import Document
from pie_core.utils.hydra import resolve_target

T = TypeVar("T", bound=Document)
T_super = TypeVar("T_super", bound=Document)


def resolve_type(
type_or_str: Union[str, Type[T]], expected_super_type: Optional[Type[T_super]] = None
) -> Type[T]:
if isinstance(type_or_str, str):
dt = resolve_target(type_or_str) # type: ignore
else:
dt = type_or_str
if not (
isinstance(dt, type)
and (expected_super_type is None or issubclass(dt, expected_super_type))
):
raise TypeError(
f"type must be a subclass of {expected_super_type} or a string that resolves to that, "
f"but got {dt}"
)
return dt
# backwards compatibility
from pie_core.utils.hydra import resolve_target, resolve_type
2 changes: 1 addition & 1 deletion tests/utils/test_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import pytest
from pie_core import AnnotationLayer, annotation_field
from pie_core.utils.hydra import resolve_type

from pie_modules.annotations import LabeledSpan, Span
from pie_modules.documents import TextBasedDocument
from pie_modules.utils import resolve_type


@dataclasses.dataclass
Expand Down
Loading