From 8917711e50f5aeea1b62a514dd1608236095c6c5 Mon Sep 17 00:00:00 2001 From: zerlok Date: Wed, 15 Oct 2025 23:32:58 +0200 Subject: [PATCH 1/8] improve TypeLoader errors --- src/astlab/reader.py | 7 ++++++- src/astlab/types/loader.py | 27 +++++++++++++++++++++------ tests/unit/test_types.py | 6 +++--- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/src/astlab/reader.py b/src/astlab/reader.py index 408fa40..d904224 100644 --- a/src/astlab/reader.py +++ b/src/astlab/reader.py @@ -38,7 +38,12 @@ def import_module_path(path: Path) -> ModuleType: # Avoid `.py` in last part. qualname = ".".join((*relpath.parts[:-1], relpath.stem)) - return importlib.import_module(qualname) + try: + return importlib.import_module(qualname) + + except ImportError as err: + msg = "can't import module from path" + raise ImportError(msg, path) from err def _get_rel_path_parts_count(args: tuple[Path, Path]) -> int: diff --git a/src/astlab/types/loader.py b/src/astlab/types/loader.py index 816e67f..4d90933 100644 --- a/src/astlab/types/loader.py +++ b/src/astlab/types/loader.py @@ -3,6 +3,7 @@ __all__ = [ "ModuleLoader", "TypeLoader", + "TypeLoaderError", ] import importlib @@ -54,6 +55,10 @@ def clear_cache(self) -> None: importlib.invalidate_caches() +class TypeLoaderError(Exception): + pass + + class TypeLoader: """Loads runtime type from provided info.""" @@ -72,18 +77,28 @@ def load(self, info: TypeInfo) -> RuntimeType: elif info == ellipsis_type_info(): return Ellipsis - value: object = self.__module.load(info.module) + try: + value: object = self.__module.load(info.module) + + except ImportError as err: + msg = "can't load module" + raise TypeLoaderError(msg, info) from err + + try: + for name in info.namespace: + value = getattr(value, name) - for name in info.namespace: - value = getattr(value, name) + # NOTE: need to check that we loaded a type. + type_: object = getattr(value, info.name) - # NOTE: need to check that we loaded a type. - type_: object = getattr(value, info.name) + except AttributeError as err: + msg = "can't module has no attribute" + raise TypeLoaderError(msg, info) from err if not info.type_params: return type_ - # TODO: fix recursive type + # TODO: fix recursive type load type_params = tuple(self.load(tp) for tp in info.type_params) return getitem(type_, type_params) if len(type_params) > 1 else getitem(type_, type_params[0]) # type: ignore[call-overload, misc] diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 860c4f9..954caa5 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -8,7 +8,7 @@ from astlab import abc as astlab_abc from astlab.types.annotator import TypeAnnotator from astlab.types.inspector import TypeInspector -from astlab.types.loader import TypeLoader +from astlab.types.loader import TypeLoader, TypeLoaderError from astlab.types.model import ( LiteralTypeInfo, ModuleInfo, @@ -473,11 +473,11 @@ def test_load_ok( [ pytest.param( NamedTypeInfo("NonExistingType", builtins_module_info()), - AttributeError, + TypeLoaderError, ), pytest.param( NamedTypeInfo("SomeType", ModuleInfo("non_existing_module")), - ModuleNotFoundError, + TypeLoaderError, ), ], ) From 66efcfaca33dc0036bb9cf9a1b7bee9857d59b08 Mon Sep 17 00:00:00 2001 From: zerlok Date: Wed, 15 Oct 2025 23:33:51 +0200 Subject: [PATCH 2/8] add type alias AST builder and support type vars --- src/astlab/builder.py | 289 +++++++++++++++++++++++++++---------- src/astlab/types/predef.py | 10 +- tests/unit/test_builder.py | 88 ++++++++--- 3 files changed, 288 insertions(+), 99 deletions(-) diff --git a/src/astlab/builder.py b/src/astlab/builder.py index aa1de98..9a5db4c 100644 --- a/src/astlab/builder.py +++ b/src/astlab/builder.py @@ -523,7 +523,7 @@ def transform( context: BuildContext, info: TypeInfo, ) -> Expr: - return self._scope.generic_type(dict, inner(context, info), value) + return self._scope.dict_type(inner(context, info), value) return self.__wrap(transform) @@ -543,7 +543,7 @@ def transform( context: BuildContext, info: TypeInfo, ) -> Expr: - return self._scope.generic_type(dict, key, inner(context, info)) + return self._scope.dict_type(key, inner(context, info)) return self.__wrap(transform) @@ -577,6 +577,16 @@ def transform( return self.__wrap(transform) + def type_params(self, *params: TypeRef) -> ClassTypeRefBuilder: + def transform( + inner: t.Callable[[BuildContext, TypeInfo], Expr], + context: BuildContext, + info: TypeInfo, + ) -> Expr: + return self._scope.subscript(inner(context, info), self._scope.tuple_type(*params, normalize=True)) + + return self.__wrap(transform) + def attr(self, *tail: str) -> AttrASTBuilder: return AttrASTBuilder(self._context, self, *tail) @@ -605,16 +615,7 @@ def __wrap( return self.__class__(self._context, self.__info, partial(transform, self.__transform)) -@dataclass(frozen=True) -class Comprehension: - target: Expr - items: Expr - predicates: t.Sequence[Expr] = field(default_factory=list) - is_async: bool = False - - -# noinspection PyTypeChecker -class ScopeASTBuilder(_BaseBuilder): +class AnnotationASTBuilder(_BaseBuilder): def type_ref(self, origin: t.Union[TypeInfo, RuntimeType]) -> ClassTypeRefBuilder: return ClassTypeRefBuilder( context=self._context, @@ -623,16 +624,16 @@ def type_ref(self, origin: t.Union[TypeInfo, RuntimeType]) -> ClassTypeRefBuilde else self._context.inspector.inspect(origin), ) - if sys.version_info < (3, 10): + if sys.version_info >= (3, 10): @_ast_expr_builder - def const(self, value: t.Union[str, bytes, bool, complex, None]) -> Expr: # noqa: FBT001 + def const(self, value: t.Union[str, bytes, bool, complex, EllipsisType, None]) -> Expr: # noqa: FBT001 return ast.Constant(value=value) else: @_ast_expr_builder - def const(self, value: t.Union[str, bytes, bool, complex, EllipsisType, None]) -> Expr: # noqa: FBT001 + def const(self, value: t.Union[str, bytes, bool, complex, None]) -> Expr: # noqa: FBT001 return ast.Constant(value=value) def none(self) -> Expr: @@ -641,6 +642,108 @@ def none(self) -> Expr: def ellipsis(self) -> Expr: return ast.Constant(value=...) + def generic_type(self, generic: TypeRef, *params: TypeRef) -> Expr: + if len(params) == 0: + return self._normalize_annotation(generic) + + return ast.Subscript( + value=self._normalize_annotation(generic), + slice=self._normalize_annotation(self.tuple_type(*params, normalize=True)), + ) + + def literal_type(self, *values: t.Union[str, Expr]) -> Expr: + if not values: + return self._normalize_annotation(predef().no_return) + + return self.generic_type( + predef().literal, + *(self.const(val) if isinstance(val, str) else val for val in values), + ) + + def optional_type(self, of_type: TypeRef) -> Expr: + return self.generic_type(predef().optional, of_type) + + def union_type(self, *params: TypeRef) -> Expr: + if not params: + return self._normalize_annotation(predef().no_return) + + return self.generic_type(predef().union, *params) + + def tuple_type(self, *params: TypeRef, normalize: bool = False) -> Expr: + if normalize and len(params) == 1: + return self._normalize_annotation(params[0]) + + return ast.Tuple(elts=[self._normalize_annotation(item) for item in params]) + + def collection_type(self, of_type: TypeRef) -> Expr: + return self.generic_type(predef().collection, of_type) + + def sequence_type(self, of_type: TypeRef, *, mutable: bool = False) -> Expr: + return self.generic_type(predef().mutable_sequence if mutable else predef().sequence, of_type) + + def list_type(self, of_type: TypeRef) -> Expr: + return self.generic_type(predef().list, of_type) + + def mapping_type(self, key_type: TypeRef, value_type: TypeRef, *, mutable: bool = False) -> Expr: + return self.generic_type(predef().mutable_mapping if mutable else predef().mapping, key_type, value_type) + + def dict_type(self, key_type: TypeRef, value_type: TypeRef) -> Expr: + return self.generic_type(predef().dict, key_type, value_type) + + def iterator_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: + return self.generic_type(predef().async_iterator if is_async else predef().iterator, of_type) + + def iterable_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: + return self.generic_type(predef().async_iterable if is_async else predef().iterable, of_type) + + def context_manager_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: + return self.generic_type( + predef().async_context_manager if is_async else predef().context_manager, + of_type, + ) + + def _on_expression(self, expr: Expr) -> None: + pass + + +class TypeVarRefBuilder(_BaseBuilder, ASTExpressionBuilder): + def __init__( + self, + context: BuildContext, + name: str, + mode: t.Literal["invariant", "covariant", "contravariant"] = "invariant", + lower: t.Optional[TypeRef] = None, + ) -> None: + super().__init__(context) + self.__name = name + self.__mode = mode + self.__lower = lower + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: + pass + + @override + def build_expr(self) -> ast.expr: + return ast.Name(id=self.__name) + + @override + def build_annotation(self) -> ast.expr: + return ast.Name(id=self.__name) + + +@dataclass(frozen=True) +class Comprehension: + target: Expr + items: Expr + predicates: t.Sequence[Expr] = field(default_factory=list) + is_async: bool = False + + +# noinspection PyTypeChecker +class ScopeASTBuilder(AnnotationASTBuilder): @_ast_expr_builder def await_(self, expr: Expr) -> Expr: return ast.Await(self._normalize_expr(expr)) @@ -894,60 +997,6 @@ def slice( step=self._normalize_expr(step) if step is not None else None, ) - def generic_type(self, generic: TypeRef, *args: TypeRef) -> Expr: - if len(args) == 0: - return self._normalize_annotation(generic) - - return ast.Subscript( - value=self._normalize_annotation(generic), - slice=self._normalize_annotation(self.tuple_type(*args, normalize=True)), - ) - - def literal_type(self, *args: t.Union[str, Expr]) -> Expr: - if not args: - return self._normalize_annotation(predef().no_return) - - return self.generic_type( - predef().literal, - *(self.const(arg) if isinstance(arg, str) else arg for arg in args), - ) - - def optional_type(self, of_type: TypeRef) -> Expr: - return self.generic_type(predef().optional, of_type) - - def union_type(self, *args: TypeRef) -> Expr: - if not args: - return self._normalize_annotation(predef().no_return) - - return self.generic_type(predef().union, *args) - - def tuple_type(self, *items: TypeRef, normalize: bool = False) -> Expr: - if normalize and len(items) == 1: - return self._normalize_annotation(items[0]) - - return ast.Tuple(elts=[self._normalize_annotation(item) for item in items]) - - def collection_type(self, of_type: TypeRef) -> Expr: - return self.generic_type(predef().collection, of_type) - - def sequence_type(self, of_type: TypeRef, *, mutable: bool = False) -> Expr: - return self.generic_type(predef().mutable_sequence if mutable else predef().sequence, of_type) - - def mapping_type(self, key_type: TypeRef, value_type: TypeRef, *, mutable: bool = False) -> Expr: - return self.generic_type(predef().mutable_mapping if mutable else predef().mapping, key_type, value_type) - - def iterator_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: - return self.generic_type(predef().async_iterator if is_async else predef().iterator, of_type) - - def iterable_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: - return self.generic_type(predef().async_iterable if is_async else predef().iterable, of_type) - - def context_manager_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: - return self.generic_type( - predef().async_context_manager if is_async else predef().context_manager, - of_type, - ) - def stmt(self, *stmts: t.Optional[Stmt]) -> None: self._context.extend_body(self._context.resolver.resolve_stmts(*stmts)) @@ -966,6 +1015,9 @@ def field_def(self, name: str, annotation: TypeRef, default: t.Optional[Expr] = simple=1, ) + def type_alias(self, name: str) -> TypeAliasStatementASTBuilder: + return TypeAliasStatementASTBuilder(self._context, name) + @_ast_stmt_builder def assign_stmt(self, target: t.Union[str, Expr], value: Expr) -> ast.stmt: return ast.Assign( @@ -1059,6 +1111,77 @@ def __build_comprehensions( ] +class TypeAliasStatementASTBuilder(AnnotationASTBuilder, ASTStatementBuilder, TypeDefinitionBuilder): + def __init__(self, context: BuildContext, name: str, annotation: t.Optional[TypeRef] = None) -> None: + super().__init__(context) + self.__info = NamedTypeInfo(name=name, module=self._context.module, namespace=self._context.namespace) + self.__annotation: t.Optional[TypeRef] = None + self.__type_vars = list[TypeVar]() + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: object, exc_value: object, exc_traceback: object) -> None: + pass + + @override + @property + def info(self) -> TypeInfo: + return self.__info + + @override + def ref(self) -> ClassTypeRefBuilder: + return ClassTypeRefBuilder(self._context, self.__info) + + def assign(self, annotation: TypeRef) -> None: + self.__annotation = annotation + self._context.extend_body(self.build_stmt()) + + def type_var(self, name: str) -> TypeVarRefBuilder: + type_var = TypeVarRefBuilder(self._context, name) + self.__type_vars.append(TypeVar(name=name)) + return type_var + + def type_params(self, *params: TypeRef) -> Expr: + return self.ref().type_params(*params) + + # NOTE: workaround for passing mypy typings in CI for python 3.12 + if sys.version_info >= (3, 12): + + @override + def build_stmt(self) -> t.Sequence[ast.stmt]: + if self.__annotation is None: + msg = "type alias is not set" + raise ValueError(msg, self) + + return [ + ast.TypeAlias( + name=ast.Name(id=self.__info.name), + value=self._normalize_annotation(self.__annotation), + type_params=[ + ast.TypeVar( + name=tv.name, + bound=self._normalize_annotation(tv.bound) if tv.bound is not None else None, + ) + for tv in self.__type_vars + ], + ) + ] + + else: + + def build_stmt(self, name: str, annotation: TypeRef) -> t.Sequence[ast.stmt]: + return [ + ast.AnnAssign( + target=ast.Name(id=name), + # TODO: add typing.TypeAlias + annotation=self._normalize_annotation(...), + value=self._normalize_annotation(annotation), + simple=1, + ) + ] + + class _NestedBlockASTBuilder(_BaseBuilder, ASTStatementBuilder, metaclass=abc.ABCMeta): def __init__(self, context: BuildContext, *, allow_implicit_enter: bool = True) -> None: super().__init__(context) @@ -1304,9 +1427,15 @@ class TypeVar: class ClassScopeASTBuilder(ScopeASTBuilder, TypeDefinitionBuilder): - def __init__(self, context: BuildContext, header: ClassStatementASTBuilder) -> None: + def __init__( + self, + context: BuildContext, + header: ClassStatementASTBuilder, + type_vars: t.MutableSequence[TypeVar], + ) -> None: super().__init__(context) self.__header = header + self.__type_vars = type_vars @override @property @@ -1317,6 +1446,11 @@ def info(self) -> TypeInfo: def ref(self) -> ClassTypeRefBuilder: return self.__header.ref() + def type_var(self, name: str) -> TypeVarRefBuilder: + type_var = TypeVarRefBuilder(self._context, name) + self.__type_vars.append(TypeVar(name=name)) + return type_var + def method_def(self, name: str) -> MethodStatementASTBuilder: return MethodStatementASTBuilder(self._context, name) @@ -1399,7 +1533,7 @@ def __init__(self, context: BuildContext, name: str) -> None: @override def __enter__(self) -> ClassScopeASTBuilder: self._context.enter_scope(self.__info.name, self.__body) - return ClassScopeASTBuilder(self._context, self) + return ClassScopeASTBuilder(self._context, self, self.__type_vars) @override def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: @@ -1635,8 +1769,6 @@ def not_implemented(self) -> Self: def build_stmt(self) -> t.Sequence[ast.stmt]: node: ast.stmt - scope = ScopeASTBuilder(self._context) - if self.__is_async: # noinspection PyArgumentList node = ast.AsyncFunctionDef( # type: ignore[call-overload,no-any-return,unused-ignore] @@ -1644,7 +1776,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: name=self.__info.name, decorator_list=self.__build_decorators(), args=self.__build_args(), - returns=self.__build_returns(scope), + returns=self.__build_returns(), body=self.__build_body(), lineno=0, ) @@ -1656,7 +1788,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: name=self.__info.name, decorator_list=self.__build_decorators(), args=self.__build_args(), - returns=self.__build_returns(scope), + returns=self.__build_returns(), body=self.__build_body(), lineno=0, ) @@ -1720,13 +1852,16 @@ def __build_args(self) -> ast.arguments: return node - def __build_returns(self, scope: ScopeASTBuilder) -> t.Optional[ast.expr]: + def __build_returns(self) -> t.Optional[ast.expr]: if self.__returns is None: return None ret = self.__returns if self.__iterator_cm: - ret = scope.iterator_type(ret, is_async=self.__is_async) + ret = ast.Subscript( + value=self._normalize_annotation(predef().async_iterator if self.__is_async else predef().iterator), + slice=self._normalize_annotation(ret), + ) return self._normalize_expr(ret) diff --git a/src/astlab/types/predef.py b/src/astlab/types/predef.py index 4ee2c95..0955ae7 100644 --- a/src/astlab/types/predef.py +++ b/src/astlab/types/predef.py @@ -59,7 +59,7 @@ def object(self) -> NamedTypeInfo: return NamedTypeInfo("object", self.builtins_module) @cached_property - def none_type(self) -> NamedTypeInfo: + def none(self) -> NamedTypeInfo: return none_type_info() @cached_property @@ -142,6 +142,14 @@ def final_decorator(self) -> NamedTypeInfo: def final(self) -> NamedTypeInfo: return NamedTypeInfo("Final", self.typing_module) + @cached_property + def type_alias(self) -> NamedTypeInfo: + return NamedTypeInfo("TypeAlias", self.typing_module) + + @cached_property + def type_var(self) -> NamedTypeInfo: + return NamedTypeInfo("TypeVar", self.typing_module) + @cached_property def class_var(self) -> NamedTypeInfo: return NamedTypeInfo("ClassVar", self.typing_module) diff --git a/tests/unit/test_builder.py b/tests/unit/test_builder.py index b60715b..b0b8e4f 100644 --- a/tests/unit/test_builder.py +++ b/tests/unit/test_builder.py @@ -3,38 +3,43 @@ import typing as t import pytest -from _pytest.mark import ParameterSet +from _pytest.mark import MarkDecorator, ParameterSet from astlab.builder import Comprehension, ModuleASTBuilder, build_module, build_package from astlab.reader import parse_module -from astlab.types import none_type_info +from astlab.types import none_type_info, predef _PARAMS: t.Final[list[ParameterSet]] = [] -def _to_module_param(func: t.Callable[[], ModuleASTBuilder]) -> ParameterSet: - expected_code = inspect.getdoc(func) - assert expected_code is not None +def _to_param( + marks: t.Optional[t.Sequence[MarkDecorator]] = None, +) -> t.Callable[[t.Callable[[], ModuleASTBuilder]], t.Callable[[], ModuleASTBuilder]]: + def inner(func: t.Callable[[], ModuleASTBuilder]) -> t.Callable[[], ModuleASTBuilder]: + expected_code = inspect.getdoc(func) + assert expected_code is not None - param = pytest.param(func(), parse_module(expected_code), id=func.__name__) - _PARAMS.append(param) + param = pytest.param(func, expected_code, id=func.__name__, marks=marks or []) + _PARAMS.append(param) - return param + return func + + return inner @pytest.mark.parametrize(("builder", "expected"), _PARAMS) -def test_module_build(builder: ModuleASTBuilder, expected: ast.Module) -> None: - assert builder.render() == ast.unparse(expected) +def test_module_build(builder: t.Callable[[], ModuleASTBuilder], expected: str) -> None: + assert builder().render() == ast.unparse(parse_module(expected)) -@_to_module_param +@_to_param() def build_empty_module() -> ModuleASTBuilder: """""" with build_module("simple") as mod: return mod -@_to_module_param +@_to_param() def build_simple_module() -> ModuleASTBuilder: # noinspection PySingleQuotedDocstring ''' @@ -89,7 +94,7 @@ def do_buzz(self) -> builtins.object: return mod -@_to_module_param +@_to_param() def build_bar_impl_module() -> ModuleASTBuilder: """ import builtins @@ -124,7 +129,7 @@ def do_stuff(self, spam: builtins.str) -> typing.Iterator[builtins.str]: return bar -@_to_module_param +@_to_param() def build_optionals() -> ModuleASTBuilder: """ import builtins @@ -146,7 +151,7 @@ class MyOptions: return mod -@_to_module_param +@_to_param() def build_unions() -> ModuleASTBuilder: """ import builtins @@ -168,7 +173,7 @@ class MyOptions: return mod -@_to_module_param +@_to_param() def build_runtime_types() -> ModuleASTBuilder: """ import builtins @@ -185,7 +190,7 @@ def build_runtime_types() -> ModuleASTBuilder: return mod -@_to_module_param +@_to_param() def build_is_not_none_expr() -> ModuleASTBuilder: """ maybe = body if test is not None else None @@ -197,7 +202,7 @@ def build_is_not_none_expr() -> ModuleASTBuilder: return mod -@_to_module_param +@_to_param() def build_list_const() -> ModuleASTBuilder: """ result = [1, 2, foo, bar] @@ -209,7 +214,7 @@ def build_list_const() -> ModuleASTBuilder: return mod -@_to_module_param +@_to_param() def build_list_compr_expr() -> ModuleASTBuilder: """ result = [item for target in iterable] @@ -224,7 +229,7 @@ def build_list_compr_expr() -> ModuleASTBuilder: return mod -@_to_module_param +@_to_param() def build_try_except_else() -> ModuleASTBuilder: """ import builtins @@ -256,7 +261,7 @@ def build_try_except_else() -> ModuleASTBuilder: return mod -@_to_module_param +@_to_param() def build_index_slice() -> ModuleASTBuilder: """ list[str] @@ -279,3 +284,44 @@ def build_index_slice() -> ModuleASTBuilder: mod.stmt(mod.attr("arr").index(mod.tuple_expr(mod.slice(), mod.slice(), mod.const(0)))) return mod + + +@_to_param() +def build_type_alias() -> ModuleASTBuilder: + """ + import builtins + import typing + + type MyInt = builtins.int + type Json = typing.Union[None, builtins.bool, builtins.int, builtins.float, builtins.str, builtins.list[Json], builtins.dict[builtins.str, Json]] + type Nested[T1, T2] = typing.Union[T1, T2, typing.Sequence[Nested[T1, T2]]] + """ + + with build_module("alias") as mod: + mod.type_alias("MyInt").assign(predef().int) + + with mod.type_alias("Json") as json_alias: + json_alias.assign( + json_alias.union_type( + predef().none, + predef().bool, + predef().int, + predef().float, + predef().str, + mod.list_type(json_alias), + mod.dict_type(predef().str, json_alias), + ) + ) + + with ( + mod.type_alias("Nested") as nested_alias, + nested_alias.type_var("T1") as type_var_1, + nested_alias.type_var("T2") as type_var_2, + ): + nested_alias.assign( + nested_alias.union_type( + type_var_1, type_var_2, nested_alias.sequence_type(nested_alias.type_params(type_var_1, type_var_2)) + ) + ) + + return mod From 3062178b1d87da0ff6a4de60f727566a51fb5a0c Mon Sep 17 00:00:00 2001 From: zerlok Date: Fri, 17 Oct 2025 01:20:01 +0200 Subject: [PATCH 3/8] add enum & type var types, support type var and type alias AST --- src/astlab/builder.py | 372 ++++++++++++++++++++-------------- src/astlab/resolver.py | 29 +-- src/astlab/types/__init__.py | 11 +- src/astlab/types/annotator.py | 101 ++++++--- src/astlab/types/inspector.py | 42 +++- src/astlab/types/loader.py | 89 +++++--- src/astlab/types/model.py | 91 +++++++-- src/astlab/types/predef.py | 12 ++ tests/conftest.py | 4 +- tests/stub/types.py | 10 +- tests/unit/test_builder.py | 151 +++++++++++++- tests/unit/test_types.py | 21 +- 12 files changed, 680 insertions(+), 253 deletions(-) diff --git a/src/astlab/builder.py b/src/astlab/builder.py index 9a5db4c..45a7568 100644 --- a/src/astlab/builder.py +++ b/src/astlab/builder.py @@ -5,7 +5,6 @@ "CallASTBuilder", "ClassScopeASTBuilder", "ClassStatementASTBuilder", - "ClassTypeRefBuilder", "Comprehension", "ForStatementASTBuilder", "FuncArgInfo", @@ -17,6 +16,7 @@ "PackageASTBuilder", "ScopeASTBuilder", "TryStatementASTBuilder", + "TypeRefBuilder", "WhileStatementASTBuilder", "WithStatementASTBuilder", "build_module", @@ -57,6 +57,7 @@ RuntimeType, TypeInfo, TypeInspector, + TypeVarVariance, predef, ) from astlab.writer import render_module, write_module @@ -440,7 +441,7 @@ def __create_expr(self) -> ast.expr: return node -class ClassTypeRefBuilder(_BaseBuilder, ASTExpressionBuilder): +class TypeRefBuilder(_BaseBuilder, ASTExpressionBuilder): def __init__( self, context: BuildContext, @@ -457,7 +458,7 @@ def __init__( def info(self) -> TypeInfo: return self.__info - def optional(self) -> ClassTypeRefBuilder: + def optional(self) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -467,7 +468,7 @@ def transform( return self.__wrap(transform) - def collection(self) -> ClassTypeRefBuilder: + def collection(self) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -477,7 +478,7 @@ def transform( return self.__wrap(transform) - def set(self) -> ClassTypeRefBuilder: + def set(self) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -487,7 +488,7 @@ def transform( return self.__wrap(transform) - def sequence(self, *, mutable: bool = False) -> ClassTypeRefBuilder: + def sequence(self, *, mutable: bool = False) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -497,7 +498,7 @@ def transform( return self.__wrap(transform) - def list(self) -> ClassTypeRefBuilder: + def list(self) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -507,7 +508,7 @@ def transform( return self.__wrap(transform) - def mapping_key(self, value: TypeRef, *, mutable: bool = False) -> ClassTypeRefBuilder: + def mapping_key(self, value: TypeRef, *, mutable: bool = False) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -517,7 +518,7 @@ def transform( return self.__wrap(transform) - def dict_key(self, value: TypeRef) -> ClassTypeRefBuilder: + def dict_key(self, value: TypeRef) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -527,7 +528,7 @@ def transform( return self.__wrap(transform) - def mapping_value(self, key: TypeRef, *, mutable: bool = False) -> ClassTypeRefBuilder: + def mapping_value(self, key: TypeRef, *, mutable: bool = False) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -537,7 +538,7 @@ def transform( return self.__wrap(transform) - def dict_value(self, key: TypeRef) -> ClassTypeRefBuilder: + def dict_value(self, key: TypeRef) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -547,7 +548,7 @@ def transform( return self.__wrap(transform) - def context_manager(self, *, is_async: bool = False) -> ClassTypeRefBuilder: + def context_manager(self, *, is_async: bool = False) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -557,7 +558,7 @@ def transform( return self.__wrap(transform) - def iterator(self, *, is_async: bool = False) -> ClassTypeRefBuilder: + def iterator(self, *, is_async: bool = False) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -567,7 +568,7 @@ def transform( return self.__wrap(transform) - def iterable(self, *, is_async: bool = False) -> ClassTypeRefBuilder: + def iterable(self, *, is_async: bool = False) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -577,7 +578,7 @@ def transform( return self.__wrap(transform) - def type_params(self, *params: TypeRef) -> ClassTypeRefBuilder: + def type_params(self, *params: TypeRef) -> TypeRefBuilder: def transform( inner: t.Callable[[BuildContext, TypeInfo], Expr], context: BuildContext, @@ -611,13 +612,13 @@ def __ident(self, context: BuildContext, info: TypeInfo) -> Expr: def __wrap( self, transform: t.Callable[[t.Callable[[BuildContext, TypeInfo], Expr], BuildContext, TypeInfo], Expr], - ) -> ClassTypeRefBuilder: + ) -> TypeRefBuilder: return self.__class__(self._context, self.__info, partial(transform, self.__transform)) class AnnotationASTBuilder(_BaseBuilder): - def type_ref(self, origin: t.Union[TypeInfo, RuntimeType]) -> ClassTypeRefBuilder: - return ClassTypeRefBuilder( + def type_ref(self, origin: t.Union[TypeInfo, RuntimeType]) -> TypeRefBuilder: + return TypeRefBuilder( context=self._context, info=origin if isinstance(origin, (NamedTypeInfo, LiteralTypeInfo)) @@ -663,10 +664,13 @@ def literal_type(self, *values: t.Union[str, Expr]) -> Expr: def optional_type(self, of_type: TypeRef) -> Expr: return self.generic_type(predef().optional, of_type) - def union_type(self, *params: TypeRef) -> Expr: + def union_type(self, *params: TypeRef, normalize: bool = False) -> Expr: if not params: return self._normalize_annotation(predef().no_return) + if normalize and len(params) == 1: + return self._normalize_annotation(params[0]) + return self.generic_type(predef().union, *params) def tuple_type(self, *params: TypeRef, normalize: bool = False) -> Expr: @@ -706,34 +710,6 @@ def _on_expression(self, expr: Expr) -> None: pass -class TypeVarRefBuilder(_BaseBuilder, ASTExpressionBuilder): - def __init__( - self, - context: BuildContext, - name: str, - mode: t.Literal["invariant", "covariant", "contravariant"] = "invariant", - lower: t.Optional[TypeRef] = None, - ) -> None: - super().__init__(context) - self.__name = name - self.__mode = mode - self.__lower = lower - - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: - pass - - @override - def build_expr(self) -> ast.expr: - return ast.Name(id=self.__name) - - @override - def build_annotation(self) -> ast.expr: - return ast.Name(id=self.__name) - - @dataclass(frozen=True) class Comprehension: target: Expr @@ -1111,77 +1087,6 @@ def __build_comprehensions( ] -class TypeAliasStatementASTBuilder(AnnotationASTBuilder, ASTStatementBuilder, TypeDefinitionBuilder): - def __init__(self, context: BuildContext, name: str, annotation: t.Optional[TypeRef] = None) -> None: - super().__init__(context) - self.__info = NamedTypeInfo(name=name, module=self._context.module, namespace=self._context.namespace) - self.__annotation: t.Optional[TypeRef] = None - self.__type_vars = list[TypeVar]() - - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type: object, exc_value: object, exc_traceback: object) -> None: - pass - - @override - @property - def info(self) -> TypeInfo: - return self.__info - - @override - def ref(self) -> ClassTypeRefBuilder: - return ClassTypeRefBuilder(self._context, self.__info) - - def assign(self, annotation: TypeRef) -> None: - self.__annotation = annotation - self._context.extend_body(self.build_stmt()) - - def type_var(self, name: str) -> TypeVarRefBuilder: - type_var = TypeVarRefBuilder(self._context, name) - self.__type_vars.append(TypeVar(name=name)) - return type_var - - def type_params(self, *params: TypeRef) -> Expr: - return self.ref().type_params(*params) - - # NOTE: workaround for passing mypy typings in CI for python 3.12 - if sys.version_info >= (3, 12): - - @override - def build_stmt(self) -> t.Sequence[ast.stmt]: - if self.__annotation is None: - msg = "type alias is not set" - raise ValueError(msg, self) - - return [ - ast.TypeAlias( - name=ast.Name(id=self.__info.name), - value=self._normalize_annotation(self.__annotation), - type_params=[ - ast.TypeVar( - name=tv.name, - bound=self._normalize_annotation(tv.bound) if tv.bound is not None else None, - ) - for tv in self.__type_vars - ], - ) - ] - - else: - - def build_stmt(self, name: str, annotation: TypeRef) -> t.Sequence[ast.stmt]: - return [ - ast.AnnAssign( - target=ast.Name(id=name), - # TODO: add typing.TypeAlias - annotation=self._normalize_annotation(...), - value=self._normalize_annotation(annotation), - simple=1, - ) - ] - - class _NestedBlockASTBuilder(_BaseBuilder, ASTStatementBuilder, metaclass=abc.ABCMeta): def __init__(self, context: BuildContext, *, allow_implicit_enter: bool = True) -> None: super().__init__(context) @@ -1211,7 +1116,7 @@ def __enter_block(self, body: list[ast.stmt]) -> t.Iterator[ScopeASTBuilder]: def __enter_implicitly(self, body: list[ast.stmt]) -> t.Iterator[ScopeASTBuilder]: if not self.__allow_implicit_enter: msg = "can't enter into the nested block implicitly" - raise RuntimeError(msg, self, body) + raise ASTBuildError(msg, self, body) with self, self.__enter_block(body) as scope: yield scope @@ -1420,22 +1325,186 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: ] -@dataclass(frozen=True) -class TypeVar: - name: str - bound: t.Optional[TypeRef] = None - - -class ClassScopeASTBuilder(ScopeASTBuilder, TypeDefinitionBuilder): +class TypeVarBuilder(_BaseBuilder, ASTStatementBuilder): def __init__( self, context: BuildContext, - header: ClassStatementASTBuilder, - type_vars: t.MutableSequence[TypeVar], + name: str, + variance: t.Optional[TypeVarVariance] = None, + lower: t.Optional[t.Sequence[TypeRef]] = None, ) -> None: + super().__init__(context) + self.__name = name + self.__module = self._context.module + self.__namespace = self._context.namespace + self.__variance = variance + self.__lower = lower or [] + + def __enter__(self) -> TypeRef: + return NamedTypeInfo( + name=self.__name, + module=self.__module, + namespace=self.__namespace, + ) + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: + pass + + def invariant(self) -> Self: + self.__variance = "invariant" + return self + + def covariant(self) -> Self: + self.__variance = "covariant" + return self + + def contravariant(self) -> Self: + self.__variance = "contravariant" + return self + + def lower(self, *types: TypeRef) -> Self: + self.__lower = types + return self + + @override + def build_stmt(self) -> t.Sequence[ast.stmt]: + return [ + ast.Assign( + targets=[ast.Name(id=self.__name)], + value=ast.Call( + func=self._normalize_expr(predef().type_var), + args=[ast.Constant(self.__name)], + keywords=[ + ast.keyword( + arg="covariant", + value=ast.Constant(value=self.__variance == "covariant"), + ), + ast.keyword( + arg="contravariant", + value=ast.Constant(value=self.__variance == "contravariant"), + ), + ast.keyword( + arg="bound", + value=self._normalize_expr( + self._scope.union_type(*self.__lower, normalize=True) + if self.__lower + else self._scope.none() + ), + ), + ], + ), + lineno=0, + ), + ] + + # NOTE: workaround for passing mypy typings in CI for python 3.12 + if sys.version_info >= (3, 12): + + def build_type_param(self) -> ast.type_param: + return ast.TypeVar( + name=self.__name, + bound=self._normalize_expr(self._scope.union_type(*self.__lower, normalize=True)) + if self.__lower + else None, + ) + + +class TypeAliasDefinitionBuilder(AnnotationASTBuilder, TypeDefinitionBuilder, ASTStatementBuilder): + def __init__(self, context: BuildContext, info: NamedTypeInfo) -> None: + super().__init__(context) + self.__info = info + self.__expr: t.Optional[TypeRef] = None + self.__type_vars = list[TypeVarBuilder]() + + @override + @property + def info(self) -> TypeInfo: + return self.__info + + @override + def ref(self) -> TypeRefBuilder: + return TypeRefBuilder(self._context, self.__info) + + def assign(self, expr: TypeRef) -> None: + self.__expr = expr + + def type_var(self, name: str) -> TypeVarBuilder: + type_var = TypeVarBuilder(self._context, name) + self.__type_vars.append(type_var) + return type_var + + def type_params(self, *params: TypeRef) -> Expr: + return self.ref().type_params(*params) + + # NOTE: workaround for passing mypy typings in CI for python 3.12 + if sys.version_info >= (3, 12): + + @override + def build_stmt(self) -> t.Sequence[ast.stmt]: + if self.__expr is None: + msg = "type alias expression is not set" + raise IncompleteStatementError(msg, self) + + return [ + ast.TypeAlias( + name=ast.Name(id=self.__info.name), + value=self._normalize_annotation(self.__expr), + type_params=[tv.build_type_param() for tv in self.__type_vars], + ) + ] + + else: + + @override + def build_stmt(self) -> t.Sequence[ast.stmt]: + if self.__expr is None: + msg = "type alias expression is not set" + raise IncompleteStatementError(msg, self) + + stmts = [stmt for tv in self.__type_vars for stmt in tv.build_stmt()] + stmts.append( + ast.AnnAssign( + target=ast.Name(id=self.__info.name), + # TODO: add typing.TypeAlias + annotation=self._normalize_annotation(predef().type_alias), + value=self._normalize_expr(self.__expr), + simple=1, + ) + ) + + return stmts + + +class TypeAliasStatementASTBuilder(_BaseBuilder, ASTStatementBuilder): + def __init__(self, context: BuildContext, name: str) -> None: + super().__init__(context) + self.__annotation = TypeAliasDefinitionBuilder( + context=self._context, + info=NamedTypeInfo(name=name, module=self._context.module, namespace=self._context.namespace), + ) + + def __enter__(self) -> TypeAliasDefinitionBuilder: + self._context.enter_scope(self.__annotation.info.name, []) + return self.__annotation + + def __exit__(self, exc_type: object, exc_value: object, exc_traceback: object) -> None: + if exc_type is None: + self._context.leave_scope() + self._context.extend_body(self.build_stmt()) + + def assign(self, expr: TypeRef) -> None: + self.__annotation.assign(expr) + self._context.extend_body(self.build_stmt()) + + @override + def build_stmt(self) -> t.Sequence[ast.stmt]: + return self.__annotation.build_stmt() + + +class ClassScopeASTBuilder(ScopeASTBuilder, TypeDefinitionBuilder): + def __init__(self, context: BuildContext, header: ClassStatementASTBuilder) -> None: super().__init__(context) self.__header = header - self.__type_vars = type_vars @override @property @@ -1443,13 +1512,11 @@ def info(self) -> TypeInfo: return self.__header.info @override - def ref(self) -> ClassTypeRefBuilder: + def ref(self) -> TypeRefBuilder: return self.__header.ref() - def type_var(self, name: str) -> TypeVarRefBuilder: - type_var = TypeVarRefBuilder(self._context, name) - self.__type_vars.append(TypeVar(name=name)) - return type_var + def type_var(self, name: str) -> TypeVarBuilder: + return self.__header.type_var(name) def method_def(self, name: str) -> MethodStatementASTBuilder: return MethodStatementASTBuilder(self._context, name) @@ -1526,14 +1593,14 @@ def __init__(self, context: BuildContext, name: str) -> None: self.__bases = list[TypeRef]() self.__decorators = list[TypeRef]() self.__keywords = dict[str, TypeRef]() - self.__type_vars = list[TypeVar]() + self.__type_vars = list[TypeVarBuilder]() self.__docs = list[str]() self.__body = list[ast.stmt]() @override def __enter__(self) -> ClassScopeASTBuilder: self._context.enter_scope(self.__info.name, self.__body) - return ClassScopeASTBuilder(self._context, self, self.__type_vars) + return ClassScopeASTBuilder(self._context, self) @override def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: @@ -1547,14 +1614,13 @@ def info(self) -> TypeInfo: return self.__info @override - def ref(self) -> ClassTypeRefBuilder: - return ClassTypeRefBuilder(self._context, self.__info) - - if sys.version_info >= (3, 12): + def ref(self) -> TypeRefBuilder: + return TypeRefBuilder(self._context, self.__info) - def type_param(self, name: str, bound: t.Optional[TypeRef] = None) -> Self: - self.__type_vars.append(TypeVar(name=name, bound=bound)) - return self + def type_var(self, name: str) -> TypeVarBuilder: + type_var = TypeVarBuilder(self._context, name) + self.__type_vars.append(type_var) + return type_var def docstring(self, value: t.Optional[str]) -> Self: if value: @@ -1603,13 +1669,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: keywords=self.__build_keywords(), body=self._normalize_body(self.__body, self.__docs), decorator_list=self.__build_decorators(), - type_params=[ - ast.TypeVar( - name=type_var.name, - bound=self._normalize_expr(type_var.bound) if type_var.bound is not None else None, - ) - for type_var in self.__type_vars - ], + type_params=[tv.build_type_param() for tv in self.__type_vars], ), ] @@ -1617,7 +1677,13 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: # noinspection PyArgumentList @override def build_stmt(self) -> t.Sequence[ast.stmt]: - return [ + stmts = [stmt for tv in self.__type_vars for stmt in tv.build_stmt()] + + if self.__type_vars: + # TODO: add type params to generic + self.__bases.insert(0, predef().generic) + + stmts.append( ast.ClassDef( name=self.__info.name, bases=self.__build_bases(), @@ -1625,7 +1691,9 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: body=self._normalize_body(self.__body, self.__docs), decorator_list=self.__build_decorators(), ), - ] + ) + + return stmts def __build_bases(self) -> list[ast.expr]: return [self._normalize_expr(base) for base in self.__bases] diff --git a/src/astlab/resolver.py b/src/astlab/resolver.py index 95e873e..94fdf46 100644 --- a/src/astlab/resolver.py +++ b/src/astlab/resolver.py @@ -20,6 +20,7 @@ ellipsis_type_info, none_type_info, ) +from astlab.types.model import EnumTypeInfo, TypeVarInfo class DefaultASTResolver(ASTResolver): @@ -104,10 +105,6 @@ def set_current_scope( self.__dependencies = dependencies def __type_info_expr(self, info: TypeInfo, tail: t.Sequence[str] = (), *, is_annotation: bool = False) -> ast.expr: - if isinstance(info, NamedTypeInfo) and info.type_vars: - msg = "can't build expr for type with type vars" - raise ValueError(msg, info) - resolved_info = self.__resolve_dependency(info) return self.__type_info_attr(resolved_info, tail, is_annotation=is_annotation) @@ -126,22 +123,24 @@ def __type_info_attr(self, info: TypeInfo, tail: t.Sequence[str] = (), *, is_ann ) origin = self.__chain_attr(ast.Name(id=head), *middle, *tail) - args = ( + params = ( ( [self.__type_info_attr(param, is_annotation=is_annotation) for param in info.type_params] if isinstance(info, NamedTypeInfo) + else [] + if isinstance(info, EnumTypeInfo) else [ast.Constant(value=value) for value in info.values] ) - if not isinstance(info, ModuleInfo) + if not isinstance(info, (ModuleInfo, TypeVarInfo)) else [] ) return ( ast.Subscript( value=origin, - slice=ast.Tuple(elts=args) if len(args) > 1 else args[0], + slice=ast.Tuple(elts=params) if len(params) > 1 else params[0], ) - if args + if params else origin ) @@ -153,7 +152,7 @@ def __resolve_dependency(self, info: TypeInfo) -> TypeInfo: return info - elif isinstance(info, NamedTypeInfo): + elif isinstance(info, (TypeVarInfo, NamedTypeInfo, EnumTypeInfo)): if info.module == self.__module: ns = ( info.namespace[len(self.__namespace) :] @@ -165,10 +164,14 @@ def __resolve_dependency(self, info: TypeInfo) -> TypeInfo: self.__dependencies.add(info.module) ns = info.namespace - return replace( - info, - namespace=ns, - type_params=tuple(self.__resolve_dependency(param) for param in info.type_params), + return ( + replace( + info, + namespace=ns, + type_params=tuple(self.__resolve_dependency(param) for param in info.type_params), + ) + if isinstance(info, NamedTypeInfo) + else replace(info, namespace=ns) ) elif isinstance(info, LiteralTypeInfo): diff --git a/src/astlab/types/__init__.py b/src/astlab/types/__init__.py index 2a20b44..4618371 100644 --- a/src/astlab/types/__init__.py +++ b/src/astlab/types/__init__.py @@ -1,4 +1,6 @@ __all__ = [ + "EnumTypeInfo", + "EnumTypeValue", "LiteralTypeInfo", "LiteralTypeValue", "ModuleInfo", @@ -10,6 +12,9 @@ "TypeInfo", "TypeInspector", "TypeLoader", + "TypeLoaderError", + "TypeVarInfo", + "TypeVarVariance", "builtins_module_info", "ellipsis_type_info", "none_type_info", @@ -19,8 +24,10 @@ from astlab.types.annotator import TypeAnnotator from astlab.types.inspector import TypeInspector -from astlab.types.loader import ModuleLoader, TypeLoader +from astlab.types.loader import ModuleLoader, TypeLoader, TypeLoaderError from astlab.types.model import ( + EnumTypeInfo, + EnumTypeValue, LiteralTypeInfo, LiteralTypeValue, ModuleInfo, @@ -28,6 +35,8 @@ PackageInfo, RuntimeType, TypeInfo, + TypeVarInfo, + TypeVarVariance, builtins_module_info, ellipsis_type_info, none_type_info, diff --git a/src/astlab/types/annotator.py b/src/astlab/types/annotator.py index 130abb2..0833759 100644 --- a/src/astlab/types/annotator.py +++ b/src/astlab/types/annotator.py @@ -5,19 +5,24 @@ ] import ast +import enum import typing as t from collections import deque from dataclasses import replace from astlab._typing import assert_never, override from astlab.cache import lru_cache_method -from astlab.types.loader import ModuleLoader +from astlab.types.loader import TypeLoader, TypeLoaderError from astlab.types.model import ( + EnumTypeInfo, + EnumTypeValue, LiteralTypeInfo, LiteralTypeValue, ModuleInfo, NamedTypeInfo, + RuntimeType, TypeInfo, + TypeVarInfo, builtins_module_info, ellipsis_type_info, none_type_info, @@ -27,35 +32,44 @@ class TypeAnnotator: """Provides annotation string form type info and vice versa (parses annotation to type info).""" - def __init__(self, loader: t.Optional[ModuleLoader] = None) -> None: - self.__loader = loader or ModuleLoader() + def __init__(self, loader: t.Optional[TypeLoader] = None) -> None: + self.__loader = loader or TypeLoader() @lru_cache_method() def annotate(self, info: TypeInfo) -> str: if isinstance(info, ModuleInfo): - return "builtins.module" + annotation = "builtins.module" + + elif isinstance(info, TypeVarInfo): + annotation = info.name elif isinstance(info, NamedTypeInfo): if info == none_type_info(): - return "None" + annotation = "None" - if info == ellipsis_type_info(): - return "..." + elif info == ellipsis_type_info(): + annotation = "..." - if not info.type_params: - return info.qualname + elif info.type_params: + # TODO: fix recursive type + params = ", ".join(self.annotate(tp) for tp in info.type_params) + annotation = f"{info.qualname}[{params}]" - # TODO: fix recursive type - params = ", ".join(self.annotate(tp) for tp in info.type_params) - return f"{info.qualname}[{params}]" + else: + annotation = info.qualname elif isinstance(info, LiteralTypeInfo): vals = ", ".join(repr(v) for v in info.values) - return f"typing.Literal[{vals}]" + annotation = f"typing.Literal[{vals}]" + + elif isinstance(info, EnumTypeInfo): + annotation = info.qualname else: assert_never(info) + return annotation + def parse(self, qualname: str) -> TypeInfo: node = ast.parse(qualname) @@ -67,7 +81,7 @@ def parse(self, qualname: str) -> TypeInfo: class _ExprParser(ast.NodeVisitor): - def __init__(self, loader: ModuleLoader) -> None: + def __init__(self, loader: TypeLoader) -> None: self.__loader = loader self.__parts = deque[str]() self.__info: t.Optional[TypeInfo] = None @@ -89,16 +103,36 @@ def visit_Constant(self, node: ast.Constant) -> None: def visit_Name(self, node: ast.Name) -> None: self.__parts.appendleft(node.id) *parts, name = self.__parts + named_type_info = self.__extract_named_type_info(node, parts, name) + rtt: RuntimeType = self.__loader.load(named_type_info) + + if isinstance(rtt, t.TypeVar): # type: ignore[misc] + self.__set_result( + TypeVarInfo( + name=named_type_info.name, + module=named_type_info.module, + namespace=named_type_info.namespace, + ) + ) + + elif isinstance(rtt, type) and issubclass(rtt, enum.Enum): # type: ignore[misc] + self.__set_result( + EnumTypeInfo( + name=named_type_info.name, + module=named_type_info.module, + namespace=named_type_info.namespace, + values=tuple( + EnumTypeValue( + name=enum_value.name, + value=enum_value.value, # type: ignore[misc] + ) + for enum_value in rtt + ), + ) + ) - module = self.__extract_module_info(node, parts) - - info = NamedTypeInfo( - name=name, - module=module, - namespace=tuple(parts[len(module.parts) :]), - ) - - self.__set_result(info) + else: + self.__set_result(named_type_info) @override def visit_Attribute(self, node: ast.Attribute) -> None: @@ -121,7 +155,6 @@ def visit_Subscript(self, node: ast.Subscript) -> None: if isinstance(node.slice, ast.Tuple) else self.__parse_type_params(node.slice) ), - type_vars=(), ) def parse(self, node: ast.AST) -> TypeInfo: @@ -140,18 +173,30 @@ def __set_result(self, info: TypeInfo) -> None: self.__info = info - def __extract_module_info(self, node: ast.AST, parts: t.Sequence[str]) -> ModuleInfo: + def __extract_named_type_info( + self, + node: ast.AST, + parts: t.Sequence[str], + name: str, + ) -> NamedTypeInfo: if not parts: - return builtins_module_info() + return NamedTypeInfo( + name=name, + module=builtins_module_info(), + ) for i in range(len(parts), 0, -1): module = ModuleInfo.build(*parts[:i]) try: self.__loader.load(module) - except ImportError: + except TypeLoaderError: continue else: - return module + return NamedTypeInfo( + name=name, + module=module, + namespace=tuple(parts[len(module.parts) :]), + ) msg = "invalid module parts" raise ValueError(msg, parts, ast.dump(node)) diff --git a/src/astlab/types/inspector.py b/src/astlab/types/inspector.py index 00efb0e..3b69889 100644 --- a/src/astlab/types/inspector.py +++ b/src/astlab/types/inspector.py @@ -7,11 +7,24 @@ "TypeInspector", ] +import enum import sys import typing as t +from dataclasses import replace from astlab.cache import lru_cache_method -from astlab.types.model import LiteralTypeInfo, ModuleInfo, NamedTypeInfo, RuntimeType, TypeInfo, typing_module_info +from astlab.types.model import ( + EnumTypeInfo, + EnumTypeValue, + LiteralTypeInfo, + ModuleInfo, + NamedTypeInfo, + RuntimeType, + TypeInfo, + TypeVarInfo, + typing_module_info, +) +from astlab.types.predef import get_predef class TypeInspector: @@ -33,6 +46,33 @@ def inspect(self, type_: RuntimeType) -> TypeInfo: origin, type_params = self.__unpack_generic(type_) module, namespace, name = self.__get_module_naming(origin) + if isinstance(origin, t.TypeVar): + return TypeVarInfo( + name=origin.__name__, + module=module, + namespace=namespace, + variance=( + "covariant" + if origin.__covariant__ + else "contravariant" + if origin.__contravariant__ + else "invariant" + ), + lower=self.inspect(origin.__bound__) + if origin.__bound__ is not None + else replace(get_predef().union, type_params=tuple(self.inspect(co) for co in origin.__constraints__)) + if origin.__constraints__ + else None, + ) + + if isinstance(origin, type) and issubclass(origin, enum.Enum): + return EnumTypeInfo( + name=name, + module=module, + namespace=tuple(namespace), + values=tuple(EnumTypeValue(name=enum_value.name, value=enum_value.value) for enum_value in origin), + ) + return NamedTypeInfo( name=name, module=module, diff --git a/src/astlab/types/loader.py b/src/astlab/types/loader.py index 4d90933..60fc806 100644 --- a/src/astlab/types/loader.py +++ b/src/astlab/types/loader.py @@ -17,12 +17,14 @@ from astlab.cache import lru_cache_method from astlab.reader import import_module_path from astlab.types.model import ( + EnumTypeInfo, LiteralTypeInfo, ModuleInfo, NamedTypeInfo, PackageInfo, RuntimeType, TypeInfo, + TypeVarInfo, ellipsis_type_info, none_type_info, ) @@ -68,42 +70,79 @@ def __init__(self, module: t.Optional[ModuleLoader] = None) -> None: @lru_cache_method() def load(self, info: TypeInfo) -> RuntimeType: if isinstance(info, ModuleInfo): - return self.__module.load(info) + rtt = self.__load_module(info) + + elif isinstance(info, TypeVarInfo): + rtt = self.__load_type_var(info) elif isinstance(info, NamedTypeInfo): if info == none_type_info(): - return None + rtt = None elif info == ellipsis_type_info(): - return Ellipsis + rtt = Ellipsis - try: - value: object = self.__module.load(info.module) + else: + rtt = self.__load_named_type(info) - except ImportError as err: - msg = "can't load module" - raise TypeLoaderError(msg, info) from err + elif isinstance(info, LiteralTypeInfo): + rtt = getitem(t.Literal, info.values) - try: - for name in info.namespace: - value = getattr(value, name) + elif isinstance(info, EnumTypeInfo): + rtt = self.__load_type_by_name(info) - # NOTE: need to check that we loaded a type. - type_: object = getattr(value, info.name) + else: + assert_never(info) - except AttributeError as err: - msg = "can't module has no attribute" - raise TypeLoaderError(msg, info) from err + return rtt - if not info.type_params: - return type_ + def clear_cache(self) -> None: + self.load.cache_clear() # type: ignore[attr-defined] + self.__module.clear_cache() - # TODO: fix recursive type load - type_params = tuple(self.load(tp) for tp in info.type_params) - return getitem(type_, type_params) if len(type_params) > 1 else getitem(type_, type_params[0]) # type: ignore[call-overload, misc] + def __load_module(self, info: ModuleInfo) -> RuntimeType: + try: + return self.__module.load(info) - elif isinstance(info, LiteralTypeInfo): - return getitem(t.Literal, info.values) + except ImportError as err: + msg = "module can't be loaded" + raise TypeLoaderError(msg, info) from err - else: - assert_never(info) + def __load_type_var(self, info: TypeVarInfo) -> RuntimeType: + # noinspection PyTypeHints + return t.TypeVar( + name=info.name, + bound=self.load(info.lower) if info.lower else None, + covariant=info.variance == "covariant", + contravariant=info.variance == "contravariant", + ) + + def __load_named_type(self, info: NamedTypeInfo) -> RuntimeType: + rtt = self.__load_type_by_name(info) + if not info.type_params: + return rtt + + # TODO: fix recursive type load + loaded_type_params = tuple(self.load(tp) for tp in info.type_params) + + try: + return ( + getitem(rtt, loaded_type_params) if len(loaded_type_params) > 1 else getitem(rtt, loaded_type_params[0]) # type: ignore[arg-type,misc] + ) + + except TypeError as err: + msg = "type params can't be applied to type" + raise TypeLoaderError(msg, info) from err + + def __load_type_by_name(self, info: t.Union[NamedTypeInfo, EnumTypeInfo]) -> RuntimeType: + container: object = self.load(info.module) + try: + for name in info.namespace: + container = getattr(container, name) + + # NOTE: need to check that we loaded a type. + return getattr(container, info.name) # type: ignore[misc] + + except AttributeError as err: + msg = "module has not attribute" + raise TypeLoaderError(msg, info) from err diff --git a/src/astlab/types/model.py b/src/astlab/types/model.py index 1ad90b3..1e2cfa4 100644 --- a/src/astlab/types/model.py +++ b/src/astlab/types/model.py @@ -1,6 +1,8 @@ from __future__ import annotations __all__ = [ + "EnumTypeInfo", + "EnumTypeValue", "LiteralTypeInfo", "LiteralTypeValue", "ModuleInfo", @@ -8,6 +10,8 @@ "PackageInfo", "RuntimeType", "TypeInfo", + "TypeVarInfo", + "TypeVarVariance", "builtins_module_info", "ellipsis_type_info", "none_type_info", @@ -15,7 +19,7 @@ ] import typing as t -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field from functools import cache, cached_property from pathlib import Path from types import GenericAlias, ModuleType @@ -144,13 +148,32 @@ def typing_module_info() -> ModuleInfo: return ModuleInfo("typing") +TypeVarVariance: TypeAlias = t.Literal["invariant", "covariant", "contravariant"] + + +@dataclass(frozen=True) +class TypeVarInfo: + name: str + module: ModuleInfo + namespace: t.Sequence[str] = field(default_factory=tuple) + variance: TypeVarVariance = "invariant" + lower: t.Optional[TypeInfo] = None + + @cached_property + def parts(self) -> t.Sequence[str]: + return *self.module.parts, *self.namespace, self.name + + @cached_property + def qualname(self) -> str: + return ".".join(self.parts) + + @dataclass(frozen=True) class NamedTypeInfo: name: str module: ModuleInfo namespace: t.Sequence[str] = field(default_factory=tuple) type_params: t.Sequence[TypeInfo] = field(default_factory=tuple) - type_vars: t.Sequence[str] = field(default_factory=tuple) @classmethod def build( @@ -158,7 +181,6 @@ def build( module: t.Union[str, t.Sequence[str], ModuleInfo], name: str, type_params: t.Optional[t.Sequence[TypeInfo]] = None, - type_vars: t.Optional[t.Sequence[str]] = None, ) -> NamedTypeInfo: *namespace, type_name = name.split(".") if not type_name: @@ -174,7 +196,6 @@ def build( else ModuleInfo.build(*module), namespace=tuple(namespace), type_params=tuple(type_params or ()), - type_vars=tuple(type_vars or ()), ) @cached_property @@ -185,20 +206,6 @@ def parts(self) -> t.Sequence[str]: def qualname(self) -> str: return ".".join(self.parts) - def with_type_params(self, *infos: TypeInfo) -> NamedTypeInfo: - if not infos: - return self - - if len(infos) > len(self.type_vars): - msg = "too many type parameters" - raise ValueError(msg, infos, self) - - return replace( - self, - type_params=(*self.type_params, *infos), - type_vars=self.type_vars[len(infos) :], - ) - @cache # type: ignore[misc] def none_type_info() -> NamedTypeInfo: @@ -215,7 +222,6 @@ def ellipsis_type_info() -> NamedTypeInfo: @dataclass(frozen=True) class LiteralTypeInfo: - # TODO: enum values values: t.Sequence[LiteralTypeValue] @cached_property @@ -239,4 +245,49 @@ def qualname(self) -> str: return ".".join(self.parts) -TypeInfo: TypeAlias = t.Union[ModuleInfo, NamedTypeInfo, LiteralTypeInfo] +@dataclass(frozen=True, kw_only=True) +class EnumTypeValue: + name: str + value: LiteralTypeValue + + +@dataclass(frozen=True) +class EnumTypeInfo: + name: str + module: ModuleInfo + namespace: t.Sequence[str] = field(default_factory=tuple) + values: t.Sequence[EnumTypeValue] = field(default_factory=tuple) + + @classmethod + def build( + cls, + module: t.Union[str, t.Sequence[str], ModuleInfo], + name: str, + values: t.Sequence[EnumTypeValue], + ) -> EnumTypeInfo: + *namespace, type_name = name.split(".") + if not type_name: + msg = "type name can't be empty" + raise ValueError(msg, name) + + return EnumTypeInfo( + name=type_name, + module=module + if isinstance(module, ModuleInfo) + else ModuleInfo.from_str(module) + if isinstance(module, str) + else ModuleInfo.build(*module), + namespace=tuple(namespace), + values=values, + ) + + @cached_property + def parts(self) -> t.Sequence[str]: + return *self.module.parts, *self.namespace, self.name + + @cached_property + def qualname(self) -> str: + return ".".join(self.parts) + + +TypeInfo: TypeAlias = t.Union[ModuleInfo, TypeVarInfo, NamedTypeInfo, LiteralTypeInfo, EnumTypeInfo] diff --git a/src/astlab/types/predef.py b/src/astlab/types/predef.py index 0955ae7..99acfb6 100644 --- a/src/astlab/types/predef.py +++ b/src/astlab/types/predef.py @@ -30,6 +30,10 @@ class Predef: def typing_module(self) -> ModuleInfo: return typing_module_info() + @cached_property + def enum_module(self) -> ModuleInfo: + return ModuleInfo("enum") + @cached_property def dataclasses_module(self) -> ModuleInfo: return ModuleInfo("dataclasses") @@ -230,6 +234,14 @@ def async_iterable(self) -> NamedTypeInfo: def literal(self) -> NamedTypeInfo: return NamedTypeInfo("Literal", self.typing_module) + @cached_property + def enum(self) -> NamedTypeInfo: + return NamedTypeInfo("Enum", self.enum_module) + + @cached_property + def enum_auto(self) -> NamedTypeInfo: + return NamedTypeInfo("auto", self.enum_module) + @cached_property def no_return(self) -> NamedTypeInfo: return NamedTypeInfo("NoReturn", self.typing_module) diff --git a/tests/conftest.py b/tests/conftest.py index 93f19e6..4c7a266 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,8 +30,8 @@ def type_loader() -> TypeLoader: @pytest.fixture -def type_annotator(module_loader: ModuleLoader) -> TypeAnnotator: - return TypeAnnotator(module_loader) +def type_annotator(type_loader: TypeLoader) -> TypeAnnotator: + return TypeAnnotator(type_loader) @pytest.fixture diff --git a/tests/stub/types.py b/tests/stub/types.py index 4f410ed..61237c9 100644 --- a/tests/stub/types.py +++ b/tests/stub/types.py @@ -1,3 +1,4 @@ +import enum import typing as t from dataclasses import dataclass @@ -22,6 +23,11 @@ class Z: pass +class StubEnum(enum.Enum): + FOO = enum.auto() + BAR = enum.auto() + + class StubCM(t.ContextManager["StubCM"]): @override def __exit__(self, exc_type: object, exc_value: object, traceback: object, /) -> None: @@ -38,5 +44,7 @@ class StubNode(t.Generic[T]): StubUnionAlias: TypeAlias = t.Union[StubFoo, StubBar[StubInt], StubX] +StubRecursive: TypeAlias = t.Union[T, t.Sequence["StubRecursive[T]"]] + # TODO: enable after python 3.9, 3.10, 3.11 version support stop drop. -# type StubNumber = int | float # noqa: ERA001 +# type StubRecursive[T] = T | t.Sequence[StubRecursive[T]] # noqa: ERA001 diff --git a/tests/unit/test_builder.py b/tests/unit/test_builder.py index b0b8e4f..62c55f7 100644 --- a/tests/unit/test_builder.py +++ b/tests/unit/test_builder.py @@ -12,6 +12,16 @@ _PARAMS: t.Final[list[ParameterSet]] = [] +@pytest.mark.parametrize(("builder", "expected_code"), _PARAMS) +def test_module_build(builder: t.Callable[[], ModuleASTBuilder], normalized_expected_code: str) -> None: + assert builder().render() == normalized_expected_code + + +@pytest.fixture +def normalized_expected_code(expected_code: str) -> str: + return ast.unparse(parse_module(expected_code)) + + def _to_param( marks: t.Optional[t.Sequence[MarkDecorator]] = None, ) -> t.Callable[[t.Callable[[], ModuleASTBuilder]], t.Callable[[], ModuleASTBuilder]]: @@ -27,11 +37,6 @@ def inner(func: t.Callable[[], ModuleASTBuilder]) -> t.Callable[[], ModuleASTBui return inner -@pytest.mark.parametrize(("builder", "expected"), _PARAMS) -def test_module_build(builder: t.Callable[[], ModuleASTBuilder], expected: str) -> None: - assert builder().render() == ast.unparse(parse_module(expected)) - - @_to_param() def build_empty_module() -> ModuleASTBuilder: """""" @@ -140,6 +145,7 @@ class MyOptions: my_optional_str: typing.Optional[builtins.str] my_optional_list_of_int: typing.Optional[builtins.list[builtins.int]] my_list_of_optional_int: builtins.list[typing.Optional[builtins.int]] + my_optional_str_default: typing.Optional[builtins.str] = None """ with build_module("opts") as mod, mod.class_def("MyOptions") as opt: @@ -147,6 +153,7 @@ class MyOptions: opt.field_def("my_optional_str", opt.type_ref(str).optional()) opt.field_def("my_optional_list_of_int", opt.type_ref(int).list().optional()) opt.field_def("my_list_of_optional_int", opt.type_ref(int).optional().list()) + opt.field_def("my_optional_str_default", opt.type_ref(str).optional(), opt.none()) return mod @@ -286,14 +293,138 @@ def build_index_slice() -> ModuleASTBuilder: return mod -@_to_param() -def build_type_alias() -> ModuleASTBuilder: +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info >= (3, 12)", + reason="syntax `type XXX = YYY` was introduced since python version 3.12", + ), + ), +) +def build_generic_class_before_312() -> ModuleASTBuilder: + """ + import typing + + T = typing.TypeVar('T') + + class Node(typing.Generic[T]): + value: T + parent: 'Node[T]' + """ + + with build_module("generic") as mod: + with mod.class_def("Node") as node, node.type_var("T") as type_var: + node.field_def("value", type_var) + node.field_def("parent", node.ref().type_params(type_var)) + + return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info < (3, 12)", + reason="syntax `type XXX = YYY` was introduced since python version 3.12", + ), + ), +) +def build_generic_class_after_312() -> ModuleASTBuilder: + """ + import builtins + + class Node[T : builtins.int]: + value: T + parent: Node[T] + """ + + with build_module("generic") as mod: + with mod.class_def("Node") as node, node.type_var("T").lower(predef().int) as type_var: + node.field_def("value", type_var) + node.field_def("parent", node.ref().type_params(type_var)) + + return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info >= (3, 12)", + reason="syntax `type XXX = YYY` was introduced since python version 3.12", + ), + ), +) +def build_type_alias_before_syntax_312() -> ModuleASTBuilder: + """ + import builtins + import typing + + MyInt: typing.TypeAlias = builtins.int + Json: typing.TypeAlias = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list['Json'], + builtins.dict[builtins.str, 'Json'], + ] + """ + + with build_module("alias") as mod: + mod.type_alias("MyInt").assign(predef().int) + + with mod.type_alias("Json") as json_alias: + json_alias.assign( + json_alias.union_type( + predef().none, + predef().bool, + predef().int, + predef().float, + predef().str, + mod.list_type(json_alias), + mod.dict_type(predef().str, json_alias), + ) + ) + + with ( + mod.type_alias("Nested") as nested_alias, + nested_alias.type_var("T1") as type_var_1, + nested_alias.type_var("T2") as type_var_2, + ): + nested_alias.assign( + nested_alias.union_type( + type_var_1, + type_var_2, + nested_alias.sequence_type(nested_alias.type_params(type_var_1, type_var_2)), + ) + ) + + return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info < (3, 12)", + reason="syntax `type XXX = YYY` was introduced since python version 3.12", + ), + ), +) +def build_type_alias_syntax_312() -> ModuleASTBuilder: """ import builtins import typing type MyInt = builtins.int - type Json = typing.Union[None, builtins.bool, builtins.int, builtins.float, builtins.str, builtins.list[Json], builtins.dict[builtins.str, Json]] + type Json = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list[Json], + builtins.dict[builtins.str, Json] + ] type Nested[T1, T2] = typing.Union[T1, T2, typing.Sequence[Nested[T1, T2]]] """ @@ -320,7 +451,9 @@ def build_type_alias() -> ModuleASTBuilder: ): nested_alias.assign( nested_alias.union_type( - type_var_1, type_var_2, nested_alias.sequence_type(nested_alias.type_params(type_var_1, type_var_2)) + type_var_1, + type_var_2, + nested_alias.sequence_type(nested_alias.type_params(type_var_1, type_var_2)), ) ) diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 954caa5..c1f66b8 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -6,6 +6,7 @@ import pytest from astlab import abc as astlab_abc +from astlab.types import EnumTypeInfo, EnumTypeValue from astlab.types.annotator import TypeAnnotator from astlab.types.inspector import TypeInspector from astlab.types.loader import TypeLoader, TypeLoaderError @@ -19,7 +20,7 @@ builtins_module_info, none_type_info, ) -from tests.stub.types import StubBar, StubCM, StubFoo, StubInt, StubNode, StubUnionAlias, StubX +from tests.stub.types import StubBar, StubCM, StubEnum, StubFoo, StubInt, StubNode, StubUnionAlias, StubX class TestPackageInfo: @@ -333,6 +334,24 @@ def test_qualname_ok(self, value: ModuleInfo, expected: t.Sequence[str]) -> None namespace=("StubX", "Y"), ), ), + pytest.param( + StubEnum, + "tests.stub.types.StubEnum", + EnumTypeInfo( + name="StubEnum", + module=ModuleInfo("types", PackageInfo("stub", PackageInfo("tests"))), + values=( + EnumTypeValue( + name="FOO", + value=1, + ), + EnumTypeValue( + name="BAR", + value=2, + ), + ), + ), + ), pytest.param( StubCM, "tests.stub.types.StubCM", From cf578d0c3db68665207e808fec4d29bc7cd60465 Mon Sep 17 00:00:00 2001 From: zerlok Date: Fri, 17 Oct 2025 23:29:10 +0200 Subject: [PATCH 4/8] refactor DefaultASTResolver * use dfs post order traverse * partially handle forward ref types --- src/astlab/builder.py | 89 ++++++++++++---------- src/astlab/resolver.py | 135 ++++++++++++++++++---------------- src/astlab/traverse.py | 19 +++++ src/astlab/types/annotator.py | 50 +++++++------ src/astlab/types/model.py | 2 +- tests/unit/test_builder.py | 7 +- 6 files changed, 173 insertions(+), 129 deletions(-) create mode 100644 src/astlab/traverse.py diff --git a/src/astlab/builder.py b/src/astlab/builder.py index 45a7568..de55d1f 100644 --- a/src/astlab/builder.py +++ b/src/astlab/builder.py @@ -584,7 +584,7 @@ def transform( context: BuildContext, info: TypeInfo, ) -> Expr: - return self._scope.subscript(inner(context, info), self._scope.tuple_type(*params, normalize=True)) + return self._scope.subscript(inner(context, info), *params) return self.__wrap(transform) @@ -706,9 +706,6 @@ def context_manager_type(self, of_type: TypeRef, *, is_async: bool = False) -> E of_type, ) - def _on_expression(self, expr: Expr) -> None: - pass - @dataclass(frozen=True) class Comprehension: @@ -957,8 +954,11 @@ def call( return CallASTBuilder(self._context, func, args, kwargs) @_ast_expr_builder - def subscript(self, value: TypeRef, slice_: Expr) -> Expr: - return ast.Subscript(value=self._normalize_expr(value), slice=self._normalize_expr(slice_)) + def subscript(self, value: TypeRef, *slice_: TypeRef) -> Expr: + return ast.Subscript( + value=self._normalize_expr(value), + slice=self._normalize_expr(self.tuple_expr(*slice_, normalize=True)), + ) @_ast_expr_builder def slice( @@ -1331,14 +1331,16 @@ def __init__( context: BuildContext, name: str, variance: t.Optional[TypeVarVariance] = None, - lower: t.Optional[t.Sequence[TypeRef]] = None, + constraints: t.Optional[t.Sequence[TypeRef]] = None, + lower: t.Optional[TypeRef] = None, ) -> None: super().__init__(context) self.__name = name self.__module = self._context.module self.__namespace = self._context.namespace self.__variance = variance - self.__lower = lower or [] + self.__constraints = list[TypeRef](constraints or ()) + self.__lower = lower def __enter__(self) -> TypeRef: return NamedTypeInfo( @@ -1362,36 +1364,38 @@ def contravariant(self) -> Self: self.__variance = "contravariant" return self - def lower(self, *types: TypeRef) -> Self: - self.__lower = types + def constraints(self, *types: TypeRef) -> Self: + self.__constraints.extend(types) + return self + + def lower(self, type_: TypeRef) -> Self: + self.__lower = type_ return self @override def build_stmt(self) -> t.Sequence[ast.stmt]: + args: list[ast.expr] = [ast.Constant(value=self.__name)] + keywords = list[ast.keyword]() + + if self.__variance is None or self.__variance == "invariant": + pass + elif self.__variance == "covariant": + keywords.append(ast.keyword(arg="covariant", value=ast.Constant(value=True))) + elif self.__variance == "contravariant": + keywords.append(ast.keyword(arg="contravariant", value=ast.Constant(value=True))) + else: + assert_never(self.__variance) + + if self.__lower is not None: + keywords.append(ast.keyword(arg="bound", value=self._normalize_expr(self.__lower))) + return [ ast.Assign( targets=[ast.Name(id=self.__name)], value=ast.Call( func=self._normalize_expr(predef().type_var), - args=[ast.Constant(self.__name)], - keywords=[ - ast.keyword( - arg="covariant", - value=ast.Constant(value=self.__variance == "covariant"), - ), - ast.keyword( - arg="contravariant", - value=ast.Constant(value=self.__variance == "contravariant"), - ), - ast.keyword( - arg="bound", - value=self._normalize_expr( - self._scope.union_type(*self.__lower, normalize=True) - if self.__lower - else self._scope.none() - ), - ), - ], + args=args, + keywords=keywords, ), lineno=0, ), @@ -1403,13 +1407,11 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: def build_type_param(self) -> ast.type_param: return ast.TypeVar( name=self.__name, - bound=self._normalize_expr(self._scope.union_type(*self.__lower, normalize=True)) - if self.__lower - else None, + bound=self._normalize_expr(self.__lower) if self.__lower is not None else None, ) -class TypeAliasDefinitionBuilder(AnnotationASTBuilder, TypeDefinitionBuilder, ASTStatementBuilder): +class TypeAliasExpressionBuilder(AnnotationASTBuilder, TypeDefinitionBuilder, ASTStatementBuilder): def __init__(self, context: BuildContext, info: NamedTypeInfo) -> None: super().__init__(context) self.__info = info @@ -1438,6 +1440,7 @@ def type_params(self, *params: TypeRef) -> Expr: # NOTE: workaround for passing mypy typings in CI for python 3.12 if sys.version_info >= (3, 12): + # if False: @override def build_stmt(self) -> t.Sequence[ast.stmt]: @@ -1465,7 +1468,6 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: stmts.append( ast.AnnAssign( target=ast.Name(id=self.__info.name), - # TODO: add typing.TypeAlias annotation=self._normalize_annotation(predef().type_alias), value=self._normalize_expr(self.__expr), simple=1, @@ -1475,15 +1477,13 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: return stmts -class TypeAliasStatementASTBuilder(_BaseBuilder, ASTStatementBuilder): +class TypeAliasStatementASTBuilder(_BaseBuilder, ASTStatementBuilder, TypeDefinitionBuilder): def __init__(self, context: BuildContext, name: str) -> None: super().__init__(context) - self.__annotation = TypeAliasDefinitionBuilder( - context=self._context, - info=NamedTypeInfo(name=name, module=self._context.module, namespace=self._context.namespace), - ) + self.__info = NamedTypeInfo(name=name, module=self._context.module, namespace=self._context.namespace) + self.__annotation = TypeAliasExpressionBuilder(context=self._context, info=self.__info) - def __enter__(self) -> TypeAliasDefinitionBuilder: + def __enter__(self) -> TypeAliasExpressionBuilder: self._context.enter_scope(self.__annotation.info.name, []) return self.__annotation @@ -1492,6 +1492,15 @@ def __exit__(self, exc_type: object, exc_value: object, exc_traceback: object) - self._context.leave_scope() self._context.extend_body(self.build_stmt()) + @override + @property + def info(self) -> TypeInfo: + return self.__info + + @override + def ref(self) -> ASTExpressionBuilder: + return TypeRefBuilder(self._context, self.__info) + def assign(self, expr: TypeRef) -> None: self.__annotation.assign(expr) self._context.extend_body(self.build_stmt()) diff --git a/src/astlab/resolver.py b/src/astlab/resolver.py index 94fdf46..678d7b2 100644 --- a/src/astlab/resolver.py +++ b/src/astlab/resolver.py @@ -5,16 +5,19 @@ ] import ast +import sys import typing as t from dataclasses import replace from itertools import chain from astlab._typing import assert_never, override from astlab.abc import ASTExpressionBuilder, ASTResolver, ASTStatementBuilder, Stmt, TypeDefinitionBuilder, TypeRef +from astlab.traverse import traverse_dfs_post_order from astlab.types import ( LiteralTypeInfo, ModuleInfo, NamedTypeInfo, + TypeAnnotator, TypeInfo, TypeInspector, ellipsis_type_info, @@ -24,11 +27,16 @@ class DefaultASTResolver(ASTResolver): - def __init__(self, inspector: t.Optional[TypeInspector] = None) -> None: + def __init__( + self, + inspector: t.Optional[TypeInspector] = None, + annotator: t.Optional[TypeAnnotator] = None, + ) -> None: self.__module: t.Optional[ModuleInfo] = None self.__namespace: t.Sequence[str] = () self.__dependencies: t.MutableSet[ModuleInfo] = set[ModuleInfo]() self.__inspector = inspector if inspector is not None else TypeInspector() + self.__annotator = annotator if annotator is not None else TypeAnnotator() @override def resolve_expr(self, ref: TypeRef, *tail: str) -> ast.expr: @@ -38,15 +46,15 @@ def resolve_expr(self, ref: TypeRef, *tail: str) -> ast.expr: elif isinstance(ref, ASTExpressionBuilder): return self.__chain_attr(ref.build_expr(), *tail) - elif isinstance(ref, (NamedTypeInfo, LiteralTypeInfo)): - return self.__type_info_expr(ref, tail) + elif isinstance(ref, (TypeVarInfo, NamedTypeInfo, LiteralTypeInfo, EnumTypeInfo)): + return self.__resolve_info(ref, tail) elif isinstance(ref, TypeDefinitionBuilder): - return self.__type_info_expr(ref.info, tail) + return self.__resolve_info(ref.info, tail) else: info = self.__inspector.inspect(ref) - return self.__type_info_expr(info, tail) + return self.__resolve_info(info, tail) @override def resolve_annotation(self, ref: TypeRef) -> ast.expr: @@ -56,15 +64,15 @@ def resolve_annotation(self, ref: TypeRef) -> ast.expr: elif isinstance(ref, ASTExpressionBuilder): return ref.build_annotation() - elif isinstance(ref, (NamedTypeInfo, LiteralTypeInfo)): - return self.__type_info_expr(ref, is_annotation=True) + elif isinstance(ref, (TypeVarInfo, NamedTypeInfo, LiteralTypeInfo, EnumTypeInfo)): + return self.__resolve_info(ref) elif isinstance(ref, TypeDefinitionBuilder): - return self.__type_info_expr(ref.info, is_annotation=True) + return self.__resolve_info(ref.info) else: info = self.__inspector.inspect(ref) - return self.__type_info_expr(info, is_annotation=True) + return self.__resolve_info(info) @override def resolve_stmts( @@ -104,45 +112,30 @@ def set_current_scope( self.__namespace = namespace self.__dependencies = dependencies - def __type_info_expr(self, info: TypeInfo, tail: t.Sequence[str] = (), *, is_annotation: bool = False) -> ast.expr: - resolved_info = self.__resolve_dependency(info) - return self.__type_info_attr(resolved_info, tail, is_annotation=is_annotation) + def __resolve_info(self, root: TypeInfo, tail: t.Sequence[str] = ()) -> ast.expr: + nodes = dict[TypeInfo, ast.expr]() + # forward_refs = dict[TypeInfo, bool]() - def __type_info_attr(self, info: TypeInfo, tail: t.Sequence[str] = (), *, is_annotation: bool = False) -> ast.expr: - if is_annotation and info == none_type_info(): - return ast.Constant(value=None) + node: ast.expr - if is_annotation and info == ellipsis_type_info(): - return ast.Constant(value=...) + for info in traverse_dfs_post_order(root, self.__get_children): + resolved_info = self.__resolve_dependency(info) - parts = self.__module.parts if self.__module is not None else () - head, *middle = ( - (info.parts[len(parts) :] if info.module == self.__module else info.parts) - if not isinstance(info, ModuleInfo) - else info.parts - ) + if self.__is_forward_ref(resolved_info): + node = ast.Constant(value=self.__annotator.annotate(resolved_info, qualified=False)) - origin = self.__chain_attr(ast.Name(id=head), *middle, *tail) - params = ( - ( - [self.__type_info_attr(param, is_annotation=is_annotation) for param in info.type_params] - if isinstance(info, NamedTypeInfo) - else [] - if isinstance(info, EnumTypeInfo) - else [ast.Constant(value=value) for value in info.values] - ) - if not isinstance(info, (ModuleInfo, TypeVarInfo)) - else [] - ) + else: + node = self.__build_expr(resolved_info) + if isinstance(info, NamedTypeInfo) and info.type_params: + params = [nodes[tp] for tp in info.type_params] + node = ast.Subscript( + value=node, + slice=ast.Tuple(elts=params) if len(params) > 1 else params[0], + ) - return ( - ast.Subscript( - value=origin, - slice=ast.Tuple(elts=params) if len(params) > 1 else params[0], - ) - if params - else origin - ) + nodes[info] = node + + return self.__chain_attr(nodes[root], *tail) def __resolve_dependency(self, info: TypeInfo) -> TypeInfo: if isinstance(info, ModuleInfo): @@ -153,36 +146,54 @@ def __resolve_dependency(self, info: TypeInfo) -> TypeInfo: return info elif isinstance(info, (TypeVarInfo, NamedTypeInfo, EnumTypeInfo)): - if info.module == self.__module: - ns = ( - info.namespace[len(self.__namespace) :] - if info.namespace[: len(self.__namespace)] == self.__namespace - else info.namespace - ) + if info.module != self.__module: + self.__dependencies.add(info.module) + return info + + elif info.namespace[: len(self.__namespace)] == self.__namespace: + # use shorten namespace for a type in nested namespace of the current scope + return replace(info, namespace=info.namespace[len(self.__namespace) :]) else: - self.__dependencies.add(info.module) - ns = info.namespace - - return ( - replace( - info, - namespace=ns, - type_params=tuple(self.__resolve_dependency(param) for param in info.type_params), - ) - if isinstance(info, NamedTypeInfo) - else replace(info, namespace=ns) - ) + return info elif isinstance(info, LiteralTypeInfo): - self.__dependencies.add(info.module) + if info.module != self.__module: + self.__dependencies.add(info.module) + return info else: assert_never(info) + def __build_expr(self, info: TypeInfo) -> ast.expr: + if info == none_type_info(): + return ast.Constant(value=None) + + if info == ellipsis_type_info(): + return ast.Constant(value=...) + + parts = self.__module.parts if self.__module is not None else () + head, *tail = ( + (info.parts[len(parts) :] if info.module == self.__module else info.parts) + if not isinstance(info, ModuleInfo) + else info.parts + ) + + return self.__chain_attr(ast.Name(id=head), *tail) + + def __is_forward_ref(self, info: TypeInfo) -> bool: + return ( + sys.version_info < (3, 12) + and info.module == self.__module + and (*info.namespace, info.name) == self.__namespace + ) + def __chain_attr(self, expr: ast.expr, *tail: str) -> ast.expr: for attr in tail: expr = ast.Attribute(attr=attr, value=expr) return expr + + def __get_children(self, info: TypeInfo) -> t.Iterable[TypeInfo]: + return info.type_params if isinstance(info, NamedTypeInfo) else () diff --git a/src/astlab/traverse.py b/src/astlab/traverse.py new file mode 100644 index 0000000..0d4423b --- /dev/null +++ b/src/astlab/traverse.py @@ -0,0 +1,19 @@ +import typing as t +from collections import deque + +T = t.TypeVar("T", bound=t.Hashable) + + +def traverse_dfs_post_order(root: T, children: t.Callable[[T], t.Iterable[T]]) -> t.Iterable[T]: + stack = deque[tuple[T, bool]]([(root, False)]) + + while stack: + item, processed = stack.pop() + if processed: + yield item + + else: + stack.append((item, True)) + + for child in children(item): + stack.append((child, False)) diff --git a/src/astlab/types/annotator.py b/src/astlab/types/annotator.py index 0833759..88f5a54 100644 --- a/src/astlab/types/annotator.py +++ b/src/astlab/types/annotator.py @@ -36,49 +36,53 @@ def __init__(self, loader: t.Optional[TypeLoader] = None) -> None: self.__loader = loader or TypeLoader() @lru_cache_method() - def annotate(self, info: TypeInfo) -> str: + def annotate(self, info: TypeInfo, *, qualified: bool = True) -> str: + if info == none_type_info(): + return "None" + + if info == ellipsis_type_info(): + return "..." + + annotation = self.__annotate(info) + return ".".join((*info.parts[:-1], annotation)) if qualified else annotation + + @lru_cache_method() + def parse(self, qualname: str) -> TypeInfo: + node = ast.parse(qualname) + + if len(node.body) != 1: + msg = "invalid qualified name" + raise ValueError(msg, qualname) + + return _ExprParser(self.__loader).parse(node) + + def __annotate(self, info: TypeInfo) -> str: if isinstance(info, ModuleInfo): - annotation = "builtins.module" + annotation = "module" elif isinstance(info, TypeVarInfo): annotation = info.name elif isinstance(info, NamedTypeInfo): - if info == none_type_info(): - annotation = "None" - - elif info == ellipsis_type_info(): - annotation = "..." + annotation = info.name - elif info.type_params: + if info.type_params: # TODO: fix recursive type params = ", ".join(self.annotate(tp) for tp in info.type_params) - annotation = f"{info.qualname}[{params}]" - - else: - annotation = info.qualname + annotation = f"{annotation}[{params}]" elif isinstance(info, LiteralTypeInfo): vals = ", ".join(repr(v) for v in info.values) - annotation = f"typing.Literal[{vals}]" + annotation = f"Literal[{vals}]" elif isinstance(info, EnumTypeInfo): - annotation = info.qualname + annotation = info.name else: assert_never(info) return annotation - def parse(self, qualname: str) -> TypeInfo: - node = ast.parse(qualname) - - if len(node.body) != 1: - msg = "invalid qualified name" - raise ValueError(msg, qualname) - - return _ExprParser(self.__loader).parse(node) - class _ExprParser(ast.NodeVisitor): def __init__(self, loader: TypeLoader) -> None: diff --git a/src/astlab/types/model.py b/src/astlab/types/model.py index 1e2cfa4..d61f097 100644 --- a/src/astlab/types/model.py +++ b/src/astlab/types/model.py @@ -245,7 +245,7 @@ def qualname(self) -> str: return ".".join(self.parts) -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class EnumTypeValue: name: str value: LiteralTypeValue diff --git a/tests/unit/test_builder.py b/tests/unit/test_builder.py index 62c55f7..6962843 100644 --- a/tests/unit/test_builder.py +++ b/tests/unit/test_builder.py @@ -303,9 +303,10 @@ def build_index_slice() -> ModuleASTBuilder: ) def build_generic_class_before_312() -> ModuleASTBuilder: """ + import builtins import typing - T = typing.TypeVar('T') + T = typing.TypeVar('T', bound=builtins.int) class Node(typing.Generic[T]): value: T @@ -313,7 +314,7 @@ class Node(typing.Generic[T]): """ with build_module("generic") as mod: - with mod.class_def("Node") as node, node.type_var("T") as type_var: + with mod.class_def("Node") as node, node.type_var("T").lower(int) as type_var: node.field_def("value", type_var) node.field_def("parent", node.ref().type_params(type_var)) @@ -338,7 +339,7 @@ class Node[T : builtins.int]: """ with build_module("generic") as mod: - with mod.class_def("Node") as node, node.type_var("T").lower(predef().int) as type_var: + with mod.class_def("Node") as node, node.type_var("T").lower(int) as type_var: node.field_def("value", type_var) node.field_def("parent", node.ref().type_params(type_var)) From f4f31bbca27fae8f6d19503ae5c9fa67880d4266 Mon Sep 17 00:00:00 2001 From: zerlok Date: Sat, 18 Oct 2025 00:00:15 +0200 Subject: [PATCH 5/8] refactor ASTResolver interface * exclude `Expr` from `TypeRef` alias * introduce new `TypeExpr` alias * remove redundant `resolve_annotation` --- src/astlab/abc.py | 19 +-- src/astlab/builder.py | 280 ++++++++++++++++------------------ src/astlab/resolver.py | 40 ++--- src/astlab/types/inspector.py | 5 +- src/astlab/types/model.py | 5 +- 5 files changed, 149 insertions(+), 200 deletions(-) diff --git a/src/astlab/abc.py b/src/astlab/abc.py index a168cbd..ee5149e 100644 --- a/src/astlab/abc.py +++ b/src/astlab/abc.py @@ -8,6 +8,7 @@ "Expr", "Stmt", "TypeDefinitionBuilder", + "TypeExpr", "TypeRef", ] @@ -31,10 +32,6 @@ class ASTExpressionBuilder(metaclass=abc.ABCMeta): def build_expr(self) -> ast.expr: raise NotImplementedError - @abc.abstractmethod - def build_annotation(self) -> ast.expr: - raise NotImplementedError - class ASTStatementBuilder(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -55,21 +52,13 @@ def ref(self) -> ASTExpressionBuilder: Expr: TypeAlias = t.Union[ast.expr, ASTExpressionBuilder] Stmt: TypeAlias = t.Union[ast.stmt, ASTStatementBuilder, Expr] -TypeRef: TypeAlias = t.Union[ - Expr, - RuntimeType, - TypeInfo, - TypeDefinitionBuilder, -] +TypeRef: TypeAlias = t.Union[RuntimeType, TypeInfo, TypeDefinitionBuilder] +TypeExpr: TypeAlias = t.Union[Expr, TypeRef] class ASTResolver(metaclass=abc.ABCMeta): @abc.abstractmethod - def resolve_expr(self, ref: TypeRef, *tail: str) -> ast.expr: - raise NotImplementedError - - @abc.abstractmethod - def resolve_annotation(self, ref: TypeRef) -> ast.expr: + def resolve_expr(self, expr: TypeExpr, *tail: str) -> ast.expr: raise NotImplementedError @abc.abstractmethod diff --git a/src/astlab/builder.py b/src/astlab/builder.py index de55d1f..f083d8c 100644 --- a/src/astlab/builder.py +++ b/src/astlab/builder.py @@ -45,16 +45,15 @@ Expr, Stmt, TypeDefinitionBuilder, + TypeExpr, TypeRef, ) from astlab.context import BuildContext from astlab.resolver import DefaultASTResolver from astlab.types import ( - LiteralTypeInfo, ModuleInfo, NamedTypeInfo, PackageInfo, - RuntimeType, TypeInfo, TypeInspector, TypeVarVariance, @@ -125,12 +124,9 @@ def __init__(self, context: BuildContext) -> None: def _scope(self) -> ScopeASTBuilder: return ScopeASTBuilder(self._context) - def _normalize_expr(self, expr: TypeRef, *tail: str) -> ast.expr: + def _normalize_expr(self, expr: TypeExpr, *tail: str) -> ast.expr: return self._context.resolver.resolve_expr(expr, *tail) - def _normalize_annotation(self, expr: TypeRef) -> ast.expr: - return self._context.resolver.resolve_annotation(expr) - def _normalize_body(self, body: t.Sequence[Stmt], docs: t.Optional[t.Sequence[str]] = None) -> list[ast.stmt]: return self._context.resolver.resolve_stmts(*body, docs=docs, pass_if_empty=True) @@ -172,10 +168,6 @@ def __truediv__(self, other: Expr) -> Self: def build_expr(self) -> ast.expr: return self._normalize_expr(self.__factory()) - @override - def build_annotation(self) -> ast.expr: - return self._normalize_annotation(self.__factory()) - def stmt(self, *, append: bool = True) -> ast.stmt: node = ast.Expr(value=self.build_expr()) if append: @@ -233,7 +225,7 @@ class AttrASTBuilder(BaseASTExpressionBuilder): def __init__( self, context: BuildContext, - head: t.Union[str, TypeRef], + head: t.Union[str, TypeExpr], *tail: str, is_awaited: bool = False, ) -> None: @@ -310,7 +302,7 @@ class CallASTBuilder(BaseASTExpressionBuilder): def __init__( self, context: BuildContext, - func: TypeRef, + func: TypeExpr, args: t.Optional[t.Sequence[Expr]] = None, kwargs: t.Optional[t.Mapping[str, Expr]] = None, ) -> None: @@ -385,7 +377,7 @@ class SliceASTBuilder(BaseASTExpressionBuilder): def __init__( self, context: BuildContext, - value: TypeRef, + value: TypeExpr, index: t.Optional[t.Union[Expr, Slice]] = None, *, is_awaited: bool = False, @@ -441,18 +433,16 @@ def __create_expr(self) -> ast.expr: return node -class TypeRefBuilder(_BaseBuilder, ASTExpressionBuilder): +class TypeRefBuilder(ASTExpressionBuilder): def __init__( self, context: BuildContext, info: TypeInfo, - transform: t.Optional[t.Callable[[BuildContext, TypeInfo], Expr]] = None, + transform: t.Optional[t.Callable[[TypeInfo], TypeInfo]] = None, ) -> None: - super().__init__(context) + self.__context = context self.__info = info - self.__transform: t.Callable[[BuildContext, TypeInfo], Expr] = ( - transform if transform is not None else self.__ident - ) + self.__transform: t.Callable[[TypeInfo], TypeInfo] = transform if transform is not None else self.__ident @property def info(self) -> TypeInfo: @@ -460,170 +450,160 @@ def info(self) -> TypeInfo: def optional(self) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.optional_type(inner(context, info)) + ) -> TypeInfo: + return predef().optional.with_type_params(inner(info)) return self.__wrap(transform) def collection(self) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.collection_type(inner(context, info)) + ) -> TypeInfo: + return predef().collection.with_type_params(inner(info)) return self.__wrap(transform) def set(self) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.generic_type(set, inner(context, info)) + ) -> TypeInfo: + return predef().set.with_type_params(inner(info)) return self.__wrap(transform) def sequence(self, *, mutable: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.sequence_type(inner(context, info), mutable=mutable) + ) -> TypeInfo: + return (predef().mutable_sequence if mutable else predef().sequence).with_type_params(inner(info)) return self.__wrap(transform) def list(self) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.generic_type(list, inner(context, info)) + ) -> TypeInfo: + return predef().list.with_type_params(inner(info)) return self.__wrap(transform) def mapping_key(self, value: TypeRef, *, mutable: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.mapping_type(inner(context, info), value, mutable=mutable) + ) -> TypeInfo: + return (predef().mutable_mapping if mutable else predef().mapping).with_type_params( + inner(info), self.__context.inspector.inspect(value) + ) return self.__wrap(transform) def dict_key(self, value: TypeRef) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.dict_type(inner(context, info), value) + ) -> TypeInfo: + return predef().dict.with_type_params(inner(info), self.__context.inspector.inspect(value)) return self.__wrap(transform) def mapping_value(self, key: TypeRef, *, mutable: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.mapping_type(key, inner(context, info), mutable=mutable) + ) -> TypeInfo: + return (predef().mutable_mapping if mutable else predef().mapping).with_type_params( + self.__context.inspector.inspect(key), inner(info) + ) return self.__wrap(transform) def dict_value(self, key: TypeRef) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.dict_type(key, inner(context, info)) + ) -> TypeInfo: + return predef().dict.with_type_params(self.__context.inspector.inspect(key), inner(info)) return self.__wrap(transform) def context_manager(self, *, is_async: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.context_manager_type(inner(context, info), is_async=is_async) + ) -> TypeInfo: + return (predef().async_context_manager if is_async else predef().context_manager).with_type_params( + inner(info) + ) return self.__wrap(transform) def iterator(self, *, is_async: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.iterable_type(inner(context, info), is_async=is_async) + ) -> TypeInfo: + return (predef().async_iterator if is_async else predef().iterator).with_type_params(inner(info)) return self.__wrap(transform) def iterable(self, *, is_async: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.iterable_type(inner(context, info), is_async=is_async) + ) -> TypeInfo: + return (predef().async_iterable if is_async else predef().iterable).with_type_params(inner(info)) return self.__wrap(transform) def type_params(self, *params: TypeRef) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.subscript(inner(context, info), *params) + ) -> TypeInfo: + origin = inner(info) + + if not isinstance(origin, NamedTypeInfo): + msg = "named type info was expected to apply type params" + raise TypeError(msg, origin, params) + + return origin.with_type_params(*(self.__context.inspector.inspect(param) for param in params)) return self.__wrap(transform) def attr(self, *tail: str) -> AttrASTBuilder: - return AttrASTBuilder(self._context, self, *tail) + return AttrASTBuilder(self.__context, self, *tail) def init( self, args: t.Optional[t.Sequence[Expr]] = None, kwargs: t.Optional[t.Mapping[str, Expr]] = None, ) -> CallASTBuilder: - return CallASTBuilder(self._context, self, args, kwargs) + return CallASTBuilder(self.__context, self, args, kwargs) @override def build_expr(self) -> ast.expr: - return self._normalize_expr(self.__transform(self._context, self.__info)) - - @override - def build_annotation(self) -> ast.expr: - return self._normalize_annotation(self.__transform(self._context, self.__info)) + return self.__context.resolver.resolve_expr(self.__transform(self.__info)) - def __ident(self, context: BuildContext, info: TypeInfo) -> Expr: - return context.resolver.resolve_expr(info) + def __ident(self, info: TypeInfo) -> TypeInfo: + return info def __wrap( self, - transform: t.Callable[[t.Callable[[BuildContext, TypeInfo], Expr], BuildContext, TypeInfo], Expr], + transform: t.Callable[[t.Callable[[TypeInfo], TypeInfo], TypeInfo], TypeInfo], ) -> TypeRefBuilder: - return self.__class__(self._context, self.__info, partial(transform, self.__transform)) + return self.__class__(self.__context, self.__info, partial(transform, self.__transform)) class AnnotationASTBuilder(_BaseBuilder): - def type_ref(self, origin: t.Union[TypeInfo, RuntimeType]) -> TypeRefBuilder: - return TypeRefBuilder( - context=self._context, - info=origin - if isinstance(origin, (NamedTypeInfo, LiteralTypeInfo)) - else self._context.inspector.inspect(origin), - ) + def type_ref(self, origin: TypeRef) -> TypeRefBuilder: + return TypeRefBuilder(self._context, self._context.inspector.inspect(origin)) if sys.version_info >= (3, 10): @@ -643,64 +623,64 @@ def none(self) -> Expr: def ellipsis(self) -> Expr: return ast.Constant(value=...) - def generic_type(self, generic: TypeRef, *params: TypeRef) -> Expr: + def generic_type(self, generic: TypeExpr, *params: TypeExpr) -> Expr: if len(params) == 0: - return self._normalize_annotation(generic) + return self._normalize_expr(generic) return ast.Subscript( - value=self._normalize_annotation(generic), - slice=self._normalize_annotation(self.tuple_type(*params, normalize=True)), + value=self._normalize_expr(generic), + slice=self._normalize_expr(self.tuple_type(*params, normalize=True)), ) def literal_type(self, *values: t.Union[str, Expr]) -> Expr: if not values: - return self._normalize_annotation(predef().no_return) + return self._normalize_expr(predef().no_return) return self.generic_type( predef().literal, *(self.const(val) if isinstance(val, str) else val for val in values), ) - def optional_type(self, of_type: TypeRef) -> Expr: + def optional_type(self, of_type: TypeExpr) -> Expr: return self.generic_type(predef().optional, of_type) - def union_type(self, *params: TypeRef, normalize: bool = False) -> Expr: + def union_type(self, *params: TypeExpr, normalize: bool = False) -> Expr: if not params: - return self._normalize_annotation(predef().no_return) + return self._normalize_expr(predef().no_return) if normalize and len(params) == 1: - return self._normalize_annotation(params[0]) + return self._normalize_expr(params[0]) return self.generic_type(predef().union, *params) - def tuple_type(self, *params: TypeRef, normalize: bool = False) -> Expr: + def tuple_type(self, *params: TypeExpr, normalize: bool = False) -> Expr: if normalize and len(params) == 1: - return self._normalize_annotation(params[0]) + return self._normalize_expr(params[0]) - return ast.Tuple(elts=[self._normalize_annotation(item) for item in params]) + return ast.Tuple(elts=[self._normalize_expr(item) for item in params]) - def collection_type(self, of_type: TypeRef) -> Expr: + def collection_type(self, of_type: TypeExpr) -> Expr: return self.generic_type(predef().collection, of_type) - def sequence_type(self, of_type: TypeRef, *, mutable: bool = False) -> Expr: + def sequence_type(self, of_type: TypeExpr, *, mutable: bool = False) -> Expr: return self.generic_type(predef().mutable_sequence if mutable else predef().sequence, of_type) - def list_type(self, of_type: TypeRef) -> Expr: + def list_type(self, of_type: TypeExpr) -> Expr: return self.generic_type(predef().list, of_type) - def mapping_type(self, key_type: TypeRef, value_type: TypeRef, *, mutable: bool = False) -> Expr: + def mapping_type(self, key_type: TypeExpr, value_type: TypeExpr, *, mutable: bool = False) -> Expr: return self.generic_type(predef().mutable_mapping if mutable else predef().mapping, key_type, value_type) - def dict_type(self, key_type: TypeRef, value_type: TypeRef) -> Expr: + def dict_type(self, key_type: TypeExpr, value_type: TypeExpr) -> Expr: return self.generic_type(predef().dict, key_type, value_type) - def iterator_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: + def iterator_type(self, of_type: TypeExpr, *, is_async: bool = False) -> Expr: return self.generic_type(predef().async_iterator if is_async else predef().iterator, of_type) - def iterable_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: + def iterable_type(self, of_type: TypeExpr, *, is_async: bool = False) -> Expr: return self.generic_type(predef().async_iterable if is_async else predef().iterable, of_type) - def context_manager_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: + def context_manager_type(self, of_type: TypeExpr, *, is_async: bool = False) -> Expr: return self.generic_type( predef().async_context_manager if is_async else predef().context_manager, of_type, @@ -756,7 +736,7 @@ def ternary_not_none_expr( ) @_ast_expr_builder - def tuple_expr(self, *items: TypeRef, normalize: bool = False) -> Expr: + def tuple_expr(self, *items: TypeExpr, normalize: bool = False) -> Expr: if normalize and len(items) == 1: return self._normalize_expr(items[0]) @@ -942,19 +922,19 @@ def compare_not_in_expr(self, left: Expr, right: Expr) -> Expr: comparators=[self._normalize_expr(right)], ) - def attr(self, head: t.Union[str, TypeRef], *tail: str) -> AttrASTBuilder: + def attr(self, head: t.Union[str, TypeExpr], *tail: str) -> AttrASTBuilder: return AttrASTBuilder(self._context, head, *tail) def call( self, - func: TypeRef, + func: TypeExpr, args: t.Optional[t.Sequence[Expr]] = None, kwargs: t.Optional[t.Mapping[str, Expr]] = None, ) -> CallASTBuilder: return CallASTBuilder(self._context, func, args, kwargs) @_ast_expr_builder - def subscript(self, value: TypeRef, *slice_: TypeRef) -> Expr: + def subscript(self, value: TypeExpr, *slice_: TypeExpr) -> Expr: return ast.Subscript( value=self._normalize_expr(value), slice=self._normalize_expr(self.tuple_expr(*slice_, normalize=True)), @@ -983,10 +963,10 @@ def func_def(self, name: str) -> FuncStatementASTBuilder: return FuncStatementASTBuilder(self._context, name) @_ast_stmt_builder - def field_def(self, name: str, annotation: TypeRef, default: t.Optional[Expr] = None) -> ast.stmt: + def field_def(self, name: str, annotation: TypeExpr, default: t.Optional[Expr] = None) -> ast.stmt: return ast.AnnAssign( target=ast.Name(id=name), - annotation=self._normalize_annotation(annotation), + annotation=self._normalize_expr(annotation), value=self._normalize_expr(default) if default is not None else None, simple=1, ) @@ -1267,7 +1247,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: class TryStatementASTBuilder(_NestedBlockASTBuilder): @dataclass() class ExceptHandler: - types: t.Sequence[TypeRef] + types: t.Sequence[TypeExpr] name: t.Optional[str] body: list[ast.stmt] @@ -1281,7 +1261,7 @@ def __init__(self, context: BuildContext) -> None: def body(self) -> t.ContextManager[ScopeASTBuilder]: return self._block(self.__body) - def except_(self, *types: TypeRef, name: t.Optional[str] = None) -> t.ContextManager[ScopeASTBuilder]: + def except_(self, *types: TypeExpr, name: t.Optional[str] = None) -> t.ContextManager[ScopeASTBuilder]: body = list[ast.stmt]() self.__handlers.append( @@ -1331,18 +1311,18 @@ def __init__( context: BuildContext, name: str, variance: t.Optional[TypeVarVariance] = None, - constraints: t.Optional[t.Sequence[TypeRef]] = None, - lower: t.Optional[TypeRef] = None, + constraints: t.Optional[t.Sequence[TypeExpr]] = None, + lower: t.Optional[TypeExpr] = None, ) -> None: super().__init__(context) self.__name = name self.__module = self._context.module self.__namespace = self._context.namespace self.__variance = variance - self.__constraints = list[TypeRef](constraints or ()) + self.__constraints = list[TypeExpr](constraints or ()) self.__lower = lower - def __enter__(self) -> TypeRef: + def __enter__(self) -> TypeExpr: return NamedTypeInfo( name=self.__name, module=self.__module, @@ -1364,11 +1344,11 @@ def contravariant(self) -> Self: self.__variance = "contravariant" return self - def constraints(self, *types: TypeRef) -> Self: + def constraints(self, *types: TypeExpr) -> Self: self.__constraints.extend(types) return self - def lower(self, type_: TypeRef) -> Self: + def lower(self, type_: TypeExpr) -> Self: self.__lower = type_ return self @@ -1415,7 +1395,7 @@ class TypeAliasExpressionBuilder(AnnotationASTBuilder, TypeDefinitionBuilder, AS def __init__(self, context: BuildContext, info: NamedTypeInfo) -> None: super().__init__(context) self.__info = info - self.__expr: t.Optional[TypeRef] = None + self.__expr: t.Optional[TypeExpr] = None self.__type_vars = list[TypeVarBuilder]() @override @@ -1427,7 +1407,7 @@ def info(self) -> TypeInfo: def ref(self) -> TypeRefBuilder: return TypeRefBuilder(self._context, self.__info) - def assign(self, expr: TypeRef) -> None: + def assign(self, expr: TypeExpr) -> None: self.__expr = expr def type_var(self, name: str) -> TypeVarBuilder: @@ -1435,7 +1415,7 @@ def type_var(self, name: str) -> TypeVarBuilder: self.__type_vars.append(type_var) return type_var - def type_params(self, *params: TypeRef) -> Expr: + def type_params(self, *params: TypeExpr) -> Expr: return self.ref().type_params(*params) # NOTE: workaround for passing mypy typings in CI for python 3.12 @@ -1451,7 +1431,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: return [ ast.TypeAlias( name=ast.Name(id=self.__info.name), - value=self._normalize_annotation(self.__expr), + value=self._normalize_expr(self.__expr), type_params=[tv.build_type_param() for tv in self.__type_vars], ) ] @@ -1468,7 +1448,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: stmts.append( ast.AnnAssign( target=ast.Name(id=self.__info.name), - annotation=self._normalize_annotation(predef().type_alias), + annotation=self._normalize_expr(predef().type_alias), value=self._normalize_expr(self.__expr), simple=1, ) @@ -1501,7 +1481,7 @@ def info(self) -> TypeInfo: def ref(self) -> ASTExpressionBuilder: return TypeRefBuilder(self._context, self.__info) - def assign(self, expr: TypeRef) -> None: + def assign(self, expr: TypeExpr) -> None: self.__annotation.assign(expr) self._context.extend_body(self.build_stmt()) @@ -1537,7 +1517,7 @@ def init_def(self) -> MethodStatementASTBuilder: return self.method_def("__init__").returns(self.const(None)) @contextmanager - def init_self_attrs_def(self, attrs: t.Mapping[str, TypeRef]) -> t.Iterator[MethodScopeASTBuilder]: + def init_self_attrs_def(self, attrs: t.Mapping[str, TypeExpr]) -> t.Iterator[MethodScopeASTBuilder]: init_def = self.init_def() for name, annotation in attrs.items(): @@ -1599,9 +1579,9 @@ class ClassStatementASTBuilder( def __init__(self, context: BuildContext, name: str) -> None: super().__init__(context) self.__info = NamedTypeInfo(name=name, module=self._context.module, namespace=self._context.namespace) - self.__bases = list[TypeRef]() - self.__decorators = list[TypeRef]() - self.__keywords = dict[str, TypeRef]() + self.__bases = list[TypeExpr]() + self.__decorators = list[TypeExpr]() + self.__keywords = dict[str, TypeExpr]() self.__type_vars = list[TypeVarBuilder]() self.__docs = list[str]() self.__body = list[ast.stmt]() @@ -1654,15 +1634,15 @@ def dataclass(self, *, frozen: bool = False, kw_only: bool = False) -> Self: return self.decorators(dc) - def inherits(self, *bases: t.Optional[TypeRef]) -> Self: + def inherits(self, *bases: t.Optional[TypeExpr]) -> Self: self.__bases.extend(base for base in bases if base is not None) return self - def decorators(self, *items: t.Optional[TypeRef]) -> Self: + def decorators(self, *items: t.Optional[TypeExpr]) -> Self: self.__decorators.extend(item for item in items if item is not None) return self - def keywords(self, **keywords: t.Optional[TypeRef]) -> Self: + def keywords(self, **keywords: t.Optional[TypeExpr]) -> Self: self.__keywords.update({key: value for key, value in keywords.items() if value is not None}) return self @@ -1723,16 +1703,12 @@ def __init__(self, context: BuildContext, info: NamedTypeInfo) -> None: def build_expr(self) -> ast.expr: return self.__context.resolver.resolve_expr(self.__info) - @override - def build_annotation(self) -> ast.expr: - return self.__context.resolver.resolve_annotation(self.__info) - @dataclass(frozen=True) class FuncArgInfo: name: str kind: t.Literal["positional-only", "positional-or-keyword", "var-positional", "keyword-only", "var-keyword"] - annotation: t.Optional[TypeRef] = None + annotation: t.Optional[TypeExpr] = None default: t.Optional[Expr] = None @@ -1746,9 +1722,9 @@ class FuncStatementASTBuilder( def __init__(self, context: BuildContext, name: str) -> None: super().__init__(context) self.__info = NamedTypeInfo(name=name, module=self._context.module, namespace=self._context.namespace) - self.__decorators = list[TypeRef]() + self.__decorators = list[TypeExpr]() self.__args = list[FuncArgInfo]() - self.__returns: t.Optional[TypeRef] = None + self.__returns: t.Optional[TypeExpr] = None self.__is_async = False self.__is_abstract = False self.__is_override = False @@ -1795,14 +1771,14 @@ def docstring(self, value: t.Optional[str]) -> Self: self.__docs.append(value) return self - def decorators(self, *items: t.Optional[TypeRef]) -> Self: + def decorators(self, *items: t.Optional[TypeExpr]) -> Self: self.__decorators.extend(item for item in items if item is not None) return self def arg( self, name: str, - annotation: t.Optional[TypeRef] = None, + annotation: t.Optional[TypeExpr] = None, default: t.Optional[Expr] = None, ) -> Self: return self.args(FuncArgInfo(name=name, kind="positional-or-keyword", annotation=annotation, default=default)) @@ -1810,7 +1786,7 @@ def arg( def kwarg( self, name: str, - annotation: t.Optional[TypeRef] = None, + annotation: t.Optional[TypeExpr] = None, default: t.Optional[Expr] = None, ) -> Self: return self.args(FuncArgInfo(name=name, kind="keyword-only", annotation=annotation, default=default)) @@ -1825,7 +1801,7 @@ def args(self, *args: t.Union[FuncArgInfo, t.Sequence[FuncArgInfo]]) -> Self: self.__args.extend(chain.from_iterable((part,) if isinstance(part, FuncArgInfo) else part for part in args)) return self - def returns(self, ret: t.Optional[TypeRef]) -> Self: + def returns(self, ret: t.Optional[TypeExpr]) -> Self: if ret is not None: self.__returns = ret return self @@ -1873,8 +1849,8 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: return [node] def __build_decorators(self) -> list[ast.expr]: - head_decorators: list[TypeRef] = [] - last_decorators: list[TypeRef] = [] + head_decorators: list[TypeExpr] = [] + last_decorators: list[TypeExpr] = [] if self.__is_override: head_decorators.append(predef().override_decorator) @@ -1903,7 +1879,7 @@ def __build_args(self) -> ast.arguments: for info in self.__args: arg = ast.arg( arg=info.name, - annotation=self._normalize_annotation(info.annotation) if info.annotation is not None else None, + annotation=self._normalize_expr(info.annotation) if info.annotation is not None else None, ) if info.kind == "positional-only": @@ -1936,8 +1912,8 @@ def __build_returns(self) -> t.Optional[ast.expr]: ret = self.__returns if self.__iterator_cm: ret = ast.Subscript( - value=self._normalize_annotation(predef().async_iterator if self.__is_async else predef().iterator), - slice=self._normalize_annotation(ret), + value=self._normalize_expr(predef().async_iterator if self.__is_async else predef().iterator), + slice=self._normalize_expr(ret), ) return self._normalize_expr(ret) diff --git a/src/astlab/resolver.py b/src/astlab/resolver.py index 678d7b2..94bc07a 100644 --- a/src/astlab/resolver.py +++ b/src/astlab/resolver.py @@ -11,7 +11,7 @@ from itertools import chain from astlab._typing import assert_never, override -from astlab.abc import ASTExpressionBuilder, ASTResolver, ASTStatementBuilder, Stmt, TypeDefinitionBuilder, TypeRef +from astlab.abc import ASTExpressionBuilder, ASTResolver, ASTStatementBuilder, Stmt, TypeDefinitionBuilder, TypeExpr from astlab.traverse import traverse_dfs_post_order from astlab.types import ( LiteralTypeInfo, @@ -39,41 +39,20 @@ def __init__( self.__annotator = annotator if annotator is not None else TypeAnnotator() @override - def resolve_expr(self, ref: TypeRef, *tail: str) -> ast.expr: - if isinstance(ref, ast.expr): - return self.__chain_attr(ref, *tail) + def resolve_expr(self, expr: TypeExpr, *tail: str) -> ast.expr: + if isinstance(expr, ast.expr): + return self.__chain_attr(expr, *tail) - elif isinstance(ref, ASTExpressionBuilder): - return self.__chain_attr(ref.build_expr(), *tail) + elif isinstance(expr, ASTExpressionBuilder): + return self.__chain_attr(expr.build_expr(), *tail) - elif isinstance(ref, (TypeVarInfo, NamedTypeInfo, LiteralTypeInfo, EnumTypeInfo)): - return self.__resolve_info(ref, tail) - - elif isinstance(ref, TypeDefinitionBuilder): - return self.__resolve_info(ref.info, tail) + elif isinstance(expr, TypeDefinitionBuilder): + return self.__resolve_info(expr.info, tail) else: - info = self.__inspector.inspect(ref) + info = self.__inspector.inspect(expr) return self.__resolve_info(info, tail) - @override - def resolve_annotation(self, ref: TypeRef) -> ast.expr: - if isinstance(ref, ast.expr): - return ref - - elif isinstance(ref, ASTExpressionBuilder): - return ref.build_annotation() - - elif isinstance(ref, (TypeVarInfo, NamedTypeInfo, LiteralTypeInfo, EnumTypeInfo)): - return self.__resolve_info(ref) - - elif isinstance(ref, TypeDefinitionBuilder): - return self.__resolve_info(ref.info) - - else: - info = self.__inspector.inspect(ref) - return self.__resolve_info(info) - @override def resolve_stmts( self, @@ -114,7 +93,6 @@ def set_current_scope( def __resolve_info(self, root: TypeInfo, tail: t.Sequence[str] = ()) -> ast.expr: nodes = dict[TypeInfo, ast.expr]() - # forward_refs = dict[TypeInfo, bool]() node: ast.expr diff --git a/src/astlab/types/inspector.py b/src/astlab/types/inspector.py index 3b69889..1868f9c 100644 --- a/src/astlab/types/inspector.py +++ b/src/astlab/types/inspector.py @@ -31,7 +31,10 @@ class TypeInspector: """Provides type info from runtime type.""" @lru_cache_method() - def inspect(self, type_: RuntimeType) -> TypeInfo: + def inspect(self, type_: t.Union[TypeInfo, RuntimeType]) -> TypeInfo: + if isinstance(type_, (ModuleInfo, TypeVarInfo, NamedTypeInfo, LiteralTypeInfo, EnumTypeInfo)): + return type_ + if isinstance( type_, t._LiteralGenericAlias, # type: ignore[attr-defined] # noqa: SLF001 diff --git a/src/astlab/types/model.py b/src/astlab/types/model.py index d61f097..957d709 100644 --- a/src/astlab/types/model.py +++ b/src/astlab/types/model.py @@ -19,7 +19,7 @@ ] import typing as t -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from functools import cache, cached_property from pathlib import Path from types import GenericAlias, ModuleType @@ -206,6 +206,9 @@ def parts(self) -> t.Sequence[str]: def qualname(self) -> str: return ".".join(self.parts) + def with_type_params(self, *type_params: TypeInfo) -> NamedTypeInfo: + return replace(self, type_params=type_params) + @cache # type: ignore[misc] def none_type_info() -> NamedTypeInfo: From 95523acea36f14be82c418770e9dd238420fc561 Mon Sep 17 00:00:00 2001 From: zerlok Date: Sat, 18 Oct 2025 00:22:32 +0200 Subject: [PATCH 6/8] fully support forward ref case and type vars in python < 3.12 --- src/astlab/builder.py | 23 ++++++++++++++------ src/astlab/resolver.py | 40 +++++++++++++++++++---------------- src/astlab/types/annotator.py | 36 ++++++++++++++----------------- tests/unit/test_builder.py | 3 +++ 4 files changed, 57 insertions(+), 45 deletions(-) diff --git a/src/astlab/builder.py b/src/astlab/builder.py index f083d8c..780e953 100644 --- a/src/astlab/builder.py +++ b/src/astlab/builder.py @@ -1305,7 +1305,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: ] -class TypeVarBuilder(_BaseBuilder, ASTStatementBuilder): +class TypeVarBuilder(_BaseBuilder, TypeDefinitionBuilder, ASTStatementBuilder): def __init__( self, context: BuildContext, @@ -1317,21 +1317,32 @@ def __init__( super().__init__(context) self.__name = name self.__module = self._context.module - self.__namespace = self._context.namespace + self.__namespace = self._context.namespace if sys.version_info >= (3, 12) else self._context.namespace[:-1] self.__variance = variance self.__constraints = list[TypeExpr](constraints or ()) self.__lower = lower - def __enter__(self) -> TypeExpr: - return NamedTypeInfo( + self.__info = NamedTypeInfo( name=self.__name, module=self.__module, namespace=self.__namespace, ) + def __enter__(self) -> TypeInfo: + return self.__info + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: pass + @override + @property + def info(self) -> TypeInfo: + return self.__info + + @override + def ref(self) -> ASTExpressionBuilder: + return TypeRefBuilder(self._context, self.__info) + def invariant(self) -> Self: self.__variance = "invariant" return self @@ -1420,7 +1431,6 @@ def type_params(self, *params: TypeExpr) -> Expr: # NOTE: workaround for passing mypy typings in CI for python 3.12 if sys.version_info >= (3, 12): - # if False: @override def build_stmt(self) -> t.Sequence[ast.stmt]: @@ -1669,8 +1679,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: stmts = [stmt for tv in self.__type_vars for stmt in tv.build_stmt()] if self.__type_vars: - # TODO: add type params to generic - self.__bases.insert(0, predef().generic) + self.__bases.insert(0, predef().generic.with_type_params(*(tv.info for tv in self.__type_vars))) stmts.append( ast.ClassDef( diff --git a/src/astlab/resolver.py b/src/astlab/resolver.py index 94bc07a..310c7e8 100644 --- a/src/astlab/resolver.py +++ b/src/astlab/resolver.py @@ -94,22 +94,19 @@ def set_current_scope( def __resolve_info(self, root: TypeInfo, tail: t.Sequence[str] = ()) -> ast.expr: nodes = dict[TypeInfo, ast.expr]() - node: ast.expr - for info in traverse_dfs_post_order(root, self.__get_children): resolved_info = self.__resolve_dependency(info) - if self.__is_forward_ref(resolved_info): - node = ast.Constant(value=self.__annotator.annotate(resolved_info, qualified=False)) + node = self.__build_expr(resolved_info) + if isinstance(info, NamedTypeInfo) and info.type_params: + params = [nodes[tp] for tp in info.type_params] + node = ast.Subscript( + value=node, + slice=ast.Tuple(elts=params) if len(params) > 1 else params[0], + ) - else: - node = self.__build_expr(resolved_info) - if isinstance(info, NamedTypeInfo) and info.type_params: - params = [nodes[tp] for tp in info.type_params] - node = ast.Subscript( - value=node, - slice=ast.Tuple(elts=params) if len(params) > 1 else params[0], - ) + if self.__is_forward_ref(resolved_info): + node = ast.Constant(value=ast.unparse(node)) nodes[info] = node @@ -160,12 +157,19 @@ def __build_expr(self, info: TypeInfo) -> ast.expr: return self.__chain_attr(ast.Name(id=head), *tail) - def __is_forward_ref(self, info: TypeInfo) -> bool: - return ( - sys.version_info < (3, 12) - and info.module == self.__module - and (*info.namespace, info.name) == self.__namespace - ) + if sys.version_info >= (3, 12): + + def __is_forward_ref(self, _: TypeInfo) -> bool: + return False + + else: + + def __is_forward_ref(self, info: TypeInfo) -> bool: + return ( + not isinstance(info, ModuleInfo) + and info.module == self.__module + and (*info.namespace, info.name) == self.__namespace + ) def __chain_attr(self, expr: ast.expr, *tail: str) -> ast.expr: for attr in tail: diff --git a/src/astlab/types/annotator.py b/src/astlab/types/annotator.py index 88f5a54..a7e289b 100644 --- a/src/astlab/types/annotator.py +++ b/src/astlab/types/annotator.py @@ -36,35 +36,21 @@ def __init__(self, loader: t.Optional[TypeLoader] = None) -> None: self.__loader = loader or TypeLoader() @lru_cache_method() - def annotate(self, info: TypeInfo, *, qualified: bool = True) -> str: + def annotate(self, info: TypeInfo) -> str: if info == none_type_info(): return "None" if info == ellipsis_type_info(): return "..." - annotation = self.__annotate(info) - return ".".join((*info.parts[:-1], annotation)) if qualified else annotation - - @lru_cache_method() - def parse(self, qualname: str) -> TypeInfo: - node = ast.parse(qualname) - - if len(node.body) != 1: - msg = "invalid qualified name" - raise ValueError(msg, qualname) - - return _ExprParser(self.__loader).parse(node) - - def __annotate(self, info: TypeInfo) -> str: if isinstance(info, ModuleInfo): - annotation = "module" + annotation = "builtins.module" elif isinstance(info, TypeVarInfo): - annotation = info.name + annotation = info.qualname elif isinstance(info, NamedTypeInfo): - annotation = info.name + annotation = info.qualname if info.type_params: # TODO: fix recursive type @@ -73,16 +59,26 @@ def __annotate(self, info: TypeInfo) -> str: elif isinstance(info, LiteralTypeInfo): vals = ", ".join(repr(v) for v in info.values) - annotation = f"Literal[{vals}]" + annotation = f"{info.qualname}[{vals}]" elif isinstance(info, EnumTypeInfo): - annotation = info.name + annotation = info.qualname else: assert_never(info) return annotation + @lru_cache_method() + def parse(self, qualname: str) -> TypeInfo: + node = ast.parse(qualname) + + if len(node.body) != 1: + msg = "invalid qualified name" + raise ValueError(msg, qualname) + + return _ExprParser(self.__loader).parse(node) + class _ExprParser(ast.NodeVisitor): def __init__(self, loader: TypeLoader) -> None: diff --git a/tests/unit/test_builder.py b/tests/unit/test_builder.py index 6962843..7971f9b 100644 --- a/tests/unit/test_builder.py +++ b/tests/unit/test_builder.py @@ -369,6 +369,9 @@ def build_type_alias_before_syntax_312() -> ModuleASTBuilder: builtins.list['Json'], builtins.dict[builtins.str, 'Json'], ] + T1 = typing.TypeVar('T1') + T2 = typing.TypeVar('T2') + Nested: typing.TypeAlias = typing.Union[T1, T2, typing.Sequence['Nested[T1, T2]']] """ with build_module("alias") as mod: From 4fd217e80cb78f37d9ab7df6a71821967d84eac6 Mon Sep 17 00:00:00 2001 From: zerlok Date: Sat, 18 Oct 2025 00:36:20 +0200 Subject: [PATCH 7/8] update README.md and minor version --- README.md | 187 +++++++++++++++++++++++++++++++++++-------------- pyproject.toml | 2 +- 2 files changed, 137 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index f3de009..f525756 100644 --- a/README.md +++ b/README.md @@ -7,50 +7,40 @@ [![Downloads](https://img.shields.io/pypi/dm/astlab.svg)](https://pypistats.org/packages/astlab) [![GitHub stars](https://img.shields.io/github/stars/zerlok/astlab)](https://github.com/zerlok/astlab/stargazers) -**astlab** is a Python library that provides an intuitive API for building and manipulating Abstract Syntax Trees (ASTs) to generate Python code. With **astlab**, you can easily create Python modules, classes, fields, and more using a simple and readable syntax, and convert the AST back into executable Python code. +**astlab** is a Python library that provides an intuitive API for building and manipulating Abstract Syntax Trees (ASTs) to generate Python code. With **astlab**, you can easily construct Python modules, classes, functions, type aliases, and generics using a fluent API — then render them into valid, executable Python code. ## Features -- **Easy AST construction**: Build Python code using a fluent and intuitive API. -- **Code generation**: Convert your AST into valid Python code, forget about jinja templates. -- **Supports nested scopes & auto imports**: Create nested classes, methods, and fields effortlessly. Reference types from other modules easily. -- **Highly customizable**: Extend and modify the API to suit your needs. +* **Easy AST construction**: Build Python code using a fluent, structured API. +* **Code generation**: Generate fully valid, formatted Python source without templates. +* **Supports nested scopes & auto imports**: Create classes, methods, and nested modules with automatic import resolution. +* **Type system support**: Define and use **type variables**, **generic classes**, and **type aliases** compatible with Python 3.9–3.14 syntax. +* **Highly customizable**: Extend the builder model for any Python AST use case. ## Installation -You can install **astlab** from PyPI using `pip`: - ```bash pip install astlab ``` ## Usage -### Simple example - -Here's a basic example of how to use **astlab** to create a Python module with a dataclass. +### Simple Example ```python import ast import astlab -# Create a new Python module with astlab.module("foo") as foo: - # Build a "Bar" dataclass with foo.class_def("Bar").dataclass() as bar: - # Define a field "spam" of type int bar.field_def("spam", int) -# Generate and print the Python code from the AST print(foo.render()) -# Or you can just get the AST print(ast.dump(foo.build(), indent=4)) ``` #### Output -Render: - ```python import builtins import dataclasses @@ -60,39 +50,9 @@ class Bar: spam: builtins.int ``` -Dump built AST: +--- -```python -Module( - body=[ - Import( - names=[ - alias(name='builtins')]), - Import( - names=[ - alias(name='dataclasses')]), - ClassDef( - name='Bar', - bases=[], - keywords=[], - body=[ - AnnAssign( - target=Name(id='spam'), - annotation=Attribute( - value=Name(id='builtins'), - attr='int'), - simple=1)], - decorator_list=[ - Call( - func=Attribute( - value=Name(id='dataclasses'), - attr='dataclass'), - args=[], - keywords=[])])], - type_ignores=[]) -``` - -### Func def & call example +### Function Definition & Call Example ```python import astlab @@ -118,7 +78,9 @@ class Bar: return result ``` -### Type reference example +--- + +### Type Reference Example ```python import astlab @@ -146,4 +108,127 @@ class Eggs(main.foo.Bar): def do_stuff(self) -> typing.Optional[main.foo.Bar]: pass -``` \ No newline at end of file +``` + +--- + +### Generics and Type Variables + +**astlab** supports defining type variables and generic classes. +Both the legacy (`typing.TypeVar`) and modern (`class Node[T: int]`) syntaxes are supported depending on Python version. + +#### Example + +```python +import astlab + +with astlab.module("generic") as mod: + with mod.class_def("Node") as node, node.type_var("T").lower(int) as T: + node.field_def("value", T) + node.field_def("parent", node.ref().type_params(T)) + +print(mod.render()) +``` + +#### Output (python < 3.12) + +```python +import builtins +import typing + +T = typing.TypeVar('T', bound=builtins.int) + +class Node(typing.Generic[T]): + value: T + parent: 'Node[T]' +``` + +#### Output (python ≥ 3.12) + +```python +import builtins + +class Node[T: builtins.int]: + value: T + parent: Node[T] +``` + +--- + +### Type Aliases + +**astlab** allows declarative creation of type aliases, including recursive and generic aliases. +It automatically emits valid syntax for both `typing.TypeAlias` (pre-3.12) and `type X = Y` (3.12+). + +#### Example + +```python +import astlab +from astlab.types import predef + +with astlab.module("alias") as mod: + mod.type_alias("MyInt").assign(int) + + with mod.type_alias("Json") as json_alias: + json_alias.assign( + json_alias.union_type( + None, + bool, + int, + float, + str, + mod.list_type(json_alias), + mod.dict_type(str, json_alias), + ) + ) + + with ( + mod.type_alias("Nested") as nested_alias, + nested_alias.type_var("T") as T, + ): + nested_alias.assign( + nested_alias.union_type( + T + nested_alias.sequence_type(nested_alias.type_params(T)), + ) + ) +``` + +#### Output (python < 3.12) + +```python +import builtins +import typing + +MyInt: typing.TypeAlias = builtins.int +Json: typing.TypeAlias = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list['Json'], + builtins.dict[builtins.str, 'Json'], +] +T = typing.TypeVar("T") +Nested: typing.TypeAlias = typing.Union[T, typing.Sequence['Nested[T]']] +``` + +#### Output (python ≥ 3.12) + +```python +import builtins +import typing + +type MyInt = builtins.int +type Json = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list[Json], + builtins.dict[builtins.str, Json], +] +type Nested[T] = typing.Union[T, typing.Sequence[Nested[T]]] +``` diff --git a/pyproject.toml b/pyproject.toml index 0f38bbe..c1f1e13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "astlab" -version = "0.4.5" +version = "0.5.0" description = "provides an intuitive API for building and manipulating Abstract Syntax Trees (ASTs) to generate Python code." authors = ["zerlok "] readme = "README.md" From c5c5cb0cb181fd728fd5b09d4b5cd9fe8d0c4d35 Mon Sep 17 00:00:00 2001 From: zerlok Date: Sat, 18 Oct 2025 00:54:06 +0200 Subject: [PATCH 8/8] allow forward refs only from python 3.14 --- README.md | 43 ++++++++++++--- src/astlab/resolver.py | 2 +- tests/unit/test_builder.py | 107 +++++++++++++++++++++++++++++++++---- 3 files changed, 134 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index f525756..a3d4a01 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,7 @@ import astlab with astlab.module("generic") as mod: with mod.class_def("Node") as node, node.type_var("T").lower(int) as T: node.field_def("value", T) - node.field_def("parent", node.ref().type_params(T)) + node.field_def("parent", node.ref().type_params(type_var).optional(), mod.none()) print(mod.render()) ``` @@ -140,17 +140,29 @@ T = typing.TypeVar('T', bound=builtins.int) class Node(typing.Generic[T]): value: T - parent: 'Node[T]' + parent: typing.Optional['Node[T]'] = None ``` -#### Output (python ≥ 3.12) +#### Output (python 3.12, 3.13) ```python import builtins +import typing + +class Node[T: builtins.int]: + value: T + parent: typing.Optional['Node[T]'] = None +``` + +#### Output (python ≥ 3.14) + +```python +import builtins +import typing class Node[T: builtins.int]: value: T - parent: Node[T] + parent: typing.Optional[Node[T]] = None ``` --- @@ -214,7 +226,26 @@ T = typing.TypeVar("T") Nested: typing.TypeAlias = typing.Union[T, typing.Sequence['Nested[T]']] ``` -#### Output (python ≥ 3.12) +#### Output (python 3.12, 3.13) + +```python +import builtins +import typing + +type MyInt = builtins.int +type Json = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list['Json'], + builtins.dict[builtins.str, 'Json'], +] +type Nested[T] = typing.Union[T, typing.Sequence['Nested[T]']] +``` + +#### Output (python ≥ 3.14) ```python import builtins @@ -231,4 +262,4 @@ type Json = typing.Union[ builtins.dict[builtins.str, Json], ] type Nested[T] = typing.Union[T, typing.Sequence[Nested[T]]] -``` +``` \ No newline at end of file diff --git a/src/astlab/resolver.py b/src/astlab/resolver.py index 310c7e8..817d71a 100644 --- a/src/astlab/resolver.py +++ b/src/astlab/resolver.py @@ -157,7 +157,7 @@ def __build_expr(self, info: TypeInfo) -> ast.expr: return self.__chain_attr(ast.Name(id=head), *tail) - if sys.version_info >= (3, 12): + if sys.version_info >= (3, 14): def __is_forward_ref(self, _: TypeInfo) -> bool: return False diff --git a/tests/unit/test_builder.py b/tests/unit/test_builder.py index 7971f9b..a89880c 100644 --- a/tests/unit/test_builder.py +++ b/tests/unit/test_builder.py @@ -297,7 +297,7 @@ def build_index_slice() -> ModuleASTBuilder: marks=( pytest.mark.skipif( condition="sys.version_info >= (3, 12)", - reason="syntax `type XXX = YYY` was introduced since python version 3.12", + reason="new type var syntax can be used", ), ), ) @@ -310,13 +310,13 @@ def build_generic_class_before_312() -> ModuleASTBuilder: class Node(typing.Generic[T]): value: T - parent: 'Node[T]' + parent: typing.Optional['Node[T]'] = None """ with build_module("generic") as mod: with mod.class_def("Node") as node, node.type_var("T").lower(int) as type_var: node.field_def("value", type_var) - node.field_def("parent", node.ref().type_params(type_var)) + node.field_def("parent", node.ref().type_params(type_var).optional(), mod.none()) return mod @@ -324,24 +324,51 @@ class Node(typing.Generic[T]): @_to_param( marks=( pytest.mark.skipif( - condition="sys.version_info < (3, 12)", - reason="syntax `type XXX = YYY` was introduced since python version 3.12", + condition="sys.version_info < (3, 12) or sys.version_info >= (3, 14)", + reason="type var syntax is available since python version 3.12", ), ), ) -def build_generic_class_after_312() -> ModuleASTBuilder: +def build_generic_class_312_313() -> ModuleASTBuilder: """ import builtins + import typing + + class Node[T : builtins.int]: + value: T + parent: typing.Optional['Node[T]'] = None + """ + + with build_module("generic") as mod: + with mod.class_def("Node") as node, node.type_var("T").lower(int) as type_var: + node.field_def("value", type_var) + node.field_def("parent", node.ref().type_params(type_var).optional(), mod.none()) + + return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info < (3, 14)", + reason="type var syntax is available since python version 3.12 and forward ref is supported since 3.14", + ), + ), +) +def build_generic_class_after_314() -> ModuleASTBuilder: + """ + import builtins + import typing class Node[T : builtins.int]: value: T - parent: Node[T] + parent: typing.Optional[Node[T]] = None """ with build_module("generic") as mod: with mod.class_def("Node") as node, node.type_var("T").lower(int) as type_var: node.field_def("value", type_var) - node.field_def("parent", node.ref().type_params(type_var)) + node.field_def("parent", node.ref().type_params(type_var).optional(), mod.none()) return mod @@ -350,7 +377,7 @@ class Node[T : builtins.int]: marks=( pytest.mark.skipif( condition="sys.version_info >= (3, 12)", - reason="syntax `type XXX = YYY` was introduced since python version 3.12", + reason="new type alias syntax can be used", ), ), ) @@ -409,8 +436,8 @@ def build_type_alias_before_syntax_312() -> ModuleASTBuilder: @_to_param( marks=( pytest.mark.skipif( - condition="sys.version_info < (3, 12)", - reason="syntax `type XXX = YYY` was introduced since python version 3.12", + condition="sys.version_info < (3, 12) or sys.version_info >= (3, 14)", + reason="type alias syntax is available since python version 3.12", ), ), ) @@ -419,6 +446,64 @@ def build_type_alias_syntax_312() -> ModuleASTBuilder: import builtins import typing + type MyInt = builtins.int + type Json = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list['Json'], + builtins.dict[builtins.str, 'Json'] + ] + type Nested[T1, T2] = typing.Union[T1, T2, typing.Sequence['Nested[T1, T2]']] + """ + + with build_module("alias") as mod: + mod.type_alias("MyInt").assign(predef().int) + + with mod.type_alias("Json") as json_alias: + json_alias.assign( + json_alias.union_type( + predef().none, + predef().bool, + predef().int, + predef().float, + predef().str, + mod.list_type(json_alias), + mod.dict_type(predef().str, json_alias), + ) + ) + + with ( + mod.type_alias("Nested") as nested_alias, + nested_alias.type_var("T1") as type_var_1, + nested_alias.type_var("T2") as type_var_2, + ): + nested_alias.assign( + nested_alias.union_type( + type_var_1, + type_var_2, + nested_alias.sequence_type(nested_alias.type_params(type_var_1, type_var_2)), + ) + ) + + return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info < (3, 14)", + reason="type alias syntax is available since python version 3.12 and forward ref is supported since 3.14", + ), + ), +) +def build_type_alias_syntax_314() -> ModuleASTBuilder: + """ + import builtins + import typing + type MyInt = builtins.int type Json = typing.Union[ None,