diff --git a/README.md b/README.md index f3de009..a3d4a01 100644 --- a/README.md +++ b/README.md @@ -7,50 +7,40 @@ [![Downloads](https://img.shields.io/pypi/dm/astlab.svg)](https://pypistats.org/packages/astlab) [![GitHub stars](https://img.shields.io/github/stars/zerlok/astlab)](https://github.com/zerlok/astlab/stargazers) -**astlab** is a Python library that provides an intuitive API for building and manipulating Abstract Syntax Trees (ASTs) to generate Python code. With **astlab**, you can easily create Python modules, classes, fields, and more using a simple and readable syntax, and convert the AST back into executable Python code. +**astlab** is a Python library that provides an intuitive API for building and manipulating Abstract Syntax Trees (ASTs) to generate Python code. With **astlab**, you can easily construct Python modules, classes, functions, type aliases, and generics using a fluent API — then render them into valid, executable Python code. ## Features -- **Easy AST construction**: Build Python code using a fluent and intuitive API. -- **Code generation**: Convert your AST into valid Python code, forget about jinja templates. -- **Supports nested scopes & auto imports**: Create nested classes, methods, and fields effortlessly. Reference types from other modules easily. -- **Highly customizable**: Extend and modify the API to suit your needs. +* **Easy AST construction**: Build Python code using a fluent, structured API. +* **Code generation**: Generate fully valid, formatted Python source without templates. +* **Supports nested scopes & auto imports**: Create classes, methods, and nested modules with automatic import resolution. +* **Type system support**: Define and use **type variables**, **generic classes**, and **type aliases** compatible with Python 3.9–3.14 syntax. +* **Highly customizable**: Extend the builder model for any Python AST use case. ## Installation -You can install **astlab** from PyPI using `pip`: - ```bash pip install astlab ``` ## Usage -### Simple example - -Here's a basic example of how to use **astlab** to create a Python module with a dataclass. +### Simple Example ```python import ast import astlab -# Create a new Python module with astlab.module("foo") as foo: - # Build a "Bar" dataclass with foo.class_def("Bar").dataclass() as bar: - # Define a field "spam" of type int bar.field_def("spam", int) -# Generate and print the Python code from the AST print(foo.render()) -# Or you can just get the AST print(ast.dump(foo.build(), indent=4)) ``` #### Output -Render: - ```python import builtins import dataclasses @@ -60,39 +50,9 @@ class Bar: spam: builtins.int ``` -Dump built AST: +--- -```python -Module( - body=[ - Import( - names=[ - alias(name='builtins')]), - Import( - names=[ - alias(name='dataclasses')]), - ClassDef( - name='Bar', - bases=[], - keywords=[], - body=[ - AnnAssign( - target=Name(id='spam'), - annotation=Attribute( - value=Name(id='builtins'), - attr='int'), - simple=1)], - decorator_list=[ - Call( - func=Attribute( - value=Name(id='dataclasses'), - attr='dataclass'), - args=[], - keywords=[])])], - type_ignores=[]) -``` - -### Func def & call example +### Function Definition & Call Example ```python import astlab @@ -118,7 +78,9 @@ class Bar: return result ``` -### Type reference example +--- + +### Type Reference Example ```python import astlab @@ -146,4 +108,158 @@ class Eggs(main.foo.Bar): def do_stuff(self) -> typing.Optional[main.foo.Bar]: pass +``` + +--- + +### Generics and Type Variables + +**astlab** supports defining type variables and generic classes. +Both the legacy (`typing.TypeVar`) and modern (`class Node[T: int]`) syntaxes are supported depending on Python version. + +#### Example + +```python +import astlab + +with astlab.module("generic") as mod: + with mod.class_def("Node") as node, node.type_var("T").lower(int) as T: + node.field_def("value", T) + node.field_def("parent", node.ref().type_params(type_var).optional(), mod.none()) + +print(mod.render()) +``` + +#### Output (python < 3.12) + +```python +import builtins +import typing + +T = typing.TypeVar('T', bound=builtins.int) + +class Node(typing.Generic[T]): + value: T + parent: typing.Optional['Node[T]'] = None +``` + +#### Output (python 3.12, 3.13) + +```python +import builtins +import typing + +class Node[T: builtins.int]: + value: T + parent: typing.Optional['Node[T]'] = None +``` + +#### Output (python ≥ 3.14) + +```python +import builtins +import typing + +class Node[T: builtins.int]: + value: T + parent: typing.Optional[Node[T]] = None +``` + +--- + +### Type Aliases + +**astlab** allows declarative creation of type aliases, including recursive and generic aliases. +It automatically emits valid syntax for both `typing.TypeAlias` (pre-3.12) and `type X = Y` (3.12+). + +#### Example + +```python +import astlab +from astlab.types import predef + +with astlab.module("alias") as mod: + mod.type_alias("MyInt").assign(int) + + with mod.type_alias("Json") as json_alias: + json_alias.assign( + json_alias.union_type( + None, + bool, + int, + float, + str, + mod.list_type(json_alias), + mod.dict_type(str, json_alias), + ) + ) + + with ( + mod.type_alias("Nested") as nested_alias, + nested_alias.type_var("T") as T, + ): + nested_alias.assign( + nested_alias.union_type( + T + nested_alias.sequence_type(nested_alias.type_params(T)), + ) + ) +``` + +#### Output (python < 3.12) + +```python +import builtins +import typing + +MyInt: typing.TypeAlias = builtins.int +Json: typing.TypeAlias = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list['Json'], + builtins.dict[builtins.str, 'Json'], +] +T = typing.TypeVar("T") +Nested: typing.TypeAlias = typing.Union[T, typing.Sequence['Nested[T]']] +``` + +#### Output (python 3.12, 3.13) + +```python +import builtins +import typing + +type MyInt = builtins.int +type Json = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list['Json'], + builtins.dict[builtins.str, 'Json'], +] +type Nested[T] = typing.Union[T, typing.Sequence['Nested[T]']] +``` + +#### Output (python ≥ 3.14) + +```python +import builtins +import typing + +type MyInt = builtins.int +type Json = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list[Json], + builtins.dict[builtins.str, Json], +] +type Nested[T] = typing.Union[T, typing.Sequence[Nested[T]]] ``` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0f38bbe..c1f1e13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "astlab" -version = "0.4.5" +version = "0.5.0" description = "provides an intuitive API for building and manipulating Abstract Syntax Trees (ASTs) to generate Python code." authors = ["zerlok "] readme = "README.md" diff --git a/src/astlab/abc.py b/src/astlab/abc.py index a168cbd..ee5149e 100644 --- a/src/astlab/abc.py +++ b/src/astlab/abc.py @@ -8,6 +8,7 @@ "Expr", "Stmt", "TypeDefinitionBuilder", + "TypeExpr", "TypeRef", ] @@ -31,10 +32,6 @@ class ASTExpressionBuilder(metaclass=abc.ABCMeta): def build_expr(self) -> ast.expr: raise NotImplementedError - @abc.abstractmethod - def build_annotation(self) -> ast.expr: - raise NotImplementedError - class ASTStatementBuilder(metaclass=abc.ABCMeta): @abc.abstractmethod @@ -55,21 +52,13 @@ def ref(self) -> ASTExpressionBuilder: Expr: TypeAlias = t.Union[ast.expr, ASTExpressionBuilder] Stmt: TypeAlias = t.Union[ast.stmt, ASTStatementBuilder, Expr] -TypeRef: TypeAlias = t.Union[ - Expr, - RuntimeType, - TypeInfo, - TypeDefinitionBuilder, -] +TypeRef: TypeAlias = t.Union[RuntimeType, TypeInfo, TypeDefinitionBuilder] +TypeExpr: TypeAlias = t.Union[Expr, TypeRef] class ASTResolver(metaclass=abc.ABCMeta): @abc.abstractmethod - def resolve_expr(self, ref: TypeRef, *tail: str) -> ast.expr: - raise NotImplementedError - - @abc.abstractmethod - def resolve_annotation(self, ref: TypeRef) -> ast.expr: + def resolve_expr(self, expr: TypeExpr, *tail: str) -> ast.expr: raise NotImplementedError @abc.abstractmethod diff --git a/src/astlab/builder.py b/src/astlab/builder.py index aa1de98..780e953 100644 --- a/src/astlab/builder.py +++ b/src/astlab/builder.py @@ -5,7 +5,6 @@ "CallASTBuilder", "ClassScopeASTBuilder", "ClassStatementASTBuilder", - "ClassTypeRefBuilder", "Comprehension", "ForStatementASTBuilder", "FuncArgInfo", @@ -17,6 +16,7 @@ "PackageASTBuilder", "ScopeASTBuilder", "TryStatementASTBuilder", + "TypeRefBuilder", "WhileStatementASTBuilder", "WithStatementASTBuilder", "build_module", @@ -45,18 +45,18 @@ Expr, Stmt, TypeDefinitionBuilder, + TypeExpr, TypeRef, ) from astlab.context import BuildContext from astlab.resolver import DefaultASTResolver from astlab.types import ( - LiteralTypeInfo, ModuleInfo, NamedTypeInfo, PackageInfo, - RuntimeType, TypeInfo, TypeInspector, + TypeVarVariance, predef, ) from astlab.writer import render_module, write_module @@ -124,12 +124,9 @@ def __init__(self, context: BuildContext) -> None: def _scope(self) -> ScopeASTBuilder: return ScopeASTBuilder(self._context) - def _normalize_expr(self, expr: TypeRef, *tail: str) -> ast.expr: + def _normalize_expr(self, expr: TypeExpr, *tail: str) -> ast.expr: return self._context.resolver.resolve_expr(expr, *tail) - def _normalize_annotation(self, expr: TypeRef) -> ast.expr: - return self._context.resolver.resolve_annotation(expr) - def _normalize_body(self, body: t.Sequence[Stmt], docs: t.Optional[t.Sequence[str]] = None) -> list[ast.stmt]: return self._context.resolver.resolve_stmts(*body, docs=docs, pass_if_empty=True) @@ -171,10 +168,6 @@ def __truediv__(self, other: Expr) -> Self: def build_expr(self) -> ast.expr: return self._normalize_expr(self.__factory()) - @override - def build_annotation(self) -> ast.expr: - return self._normalize_annotation(self.__factory()) - def stmt(self, *, append: bool = True) -> ast.stmt: node = ast.Expr(value=self.build_expr()) if append: @@ -232,7 +225,7 @@ class AttrASTBuilder(BaseASTExpressionBuilder): def __init__( self, context: BuildContext, - head: t.Union[str, TypeRef], + head: t.Union[str, TypeExpr], *tail: str, is_awaited: bool = False, ) -> None: @@ -309,7 +302,7 @@ class CallASTBuilder(BaseASTExpressionBuilder): def __init__( self, context: BuildContext, - func: TypeRef, + func: TypeExpr, args: t.Optional[t.Sequence[Expr]] = None, kwargs: t.Optional[t.Mapping[str, Expr]] = None, ) -> None: @@ -384,7 +377,7 @@ class SliceASTBuilder(BaseASTExpressionBuilder): def __init__( self, context: BuildContext, - value: TypeRef, + value: TypeExpr, index: t.Optional[t.Union[Expr, Slice]] = None, *, is_awaited: bool = False, @@ -440,199 +433,188 @@ def __create_expr(self) -> ast.expr: return node -class ClassTypeRefBuilder(_BaseBuilder, ASTExpressionBuilder): +class TypeRefBuilder(ASTExpressionBuilder): def __init__( self, context: BuildContext, info: TypeInfo, - transform: t.Optional[t.Callable[[BuildContext, TypeInfo], Expr]] = None, + transform: t.Optional[t.Callable[[TypeInfo], TypeInfo]] = None, ) -> None: - super().__init__(context) + self.__context = context self.__info = info - self.__transform: t.Callable[[BuildContext, TypeInfo], Expr] = ( - transform if transform is not None else self.__ident - ) + self.__transform: t.Callable[[TypeInfo], TypeInfo] = transform if transform is not None else self.__ident @property def info(self) -> TypeInfo: return self.__info - def optional(self) -> ClassTypeRefBuilder: + def optional(self) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.optional_type(inner(context, info)) + ) -> TypeInfo: + return predef().optional.with_type_params(inner(info)) return self.__wrap(transform) - def collection(self) -> ClassTypeRefBuilder: + def collection(self) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.collection_type(inner(context, info)) + ) -> TypeInfo: + return predef().collection.with_type_params(inner(info)) return self.__wrap(transform) - def set(self) -> ClassTypeRefBuilder: + def set(self) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.generic_type(set, inner(context, info)) + ) -> TypeInfo: + return predef().set.with_type_params(inner(info)) return self.__wrap(transform) - def sequence(self, *, mutable: bool = False) -> ClassTypeRefBuilder: + def sequence(self, *, mutable: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.sequence_type(inner(context, info), mutable=mutable) + ) -> TypeInfo: + return (predef().mutable_sequence if mutable else predef().sequence).with_type_params(inner(info)) return self.__wrap(transform) - def list(self) -> ClassTypeRefBuilder: + def list(self) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.generic_type(list, inner(context, info)) + ) -> TypeInfo: + return predef().list.with_type_params(inner(info)) return self.__wrap(transform) - def mapping_key(self, value: TypeRef, *, mutable: bool = False) -> ClassTypeRefBuilder: + def mapping_key(self, value: TypeRef, *, mutable: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.mapping_type(inner(context, info), value, mutable=mutable) + ) -> TypeInfo: + return (predef().mutable_mapping if mutable else predef().mapping).with_type_params( + inner(info), self.__context.inspector.inspect(value) + ) return self.__wrap(transform) - def dict_key(self, value: TypeRef) -> ClassTypeRefBuilder: + def dict_key(self, value: TypeRef) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.generic_type(dict, inner(context, info), value) + ) -> TypeInfo: + return predef().dict.with_type_params(inner(info), self.__context.inspector.inspect(value)) return self.__wrap(transform) - def mapping_value(self, key: TypeRef, *, mutable: bool = False) -> ClassTypeRefBuilder: + def mapping_value(self, key: TypeRef, *, mutable: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.mapping_type(key, inner(context, info), mutable=mutable) + ) -> TypeInfo: + return (predef().mutable_mapping if mutable else predef().mapping).with_type_params( + self.__context.inspector.inspect(key), inner(info) + ) return self.__wrap(transform) - def dict_value(self, key: TypeRef) -> ClassTypeRefBuilder: + def dict_value(self, key: TypeRef) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.generic_type(dict, key, inner(context, info)) + ) -> TypeInfo: + return predef().dict.with_type_params(self.__context.inspector.inspect(key), inner(info)) return self.__wrap(transform) - def context_manager(self, *, is_async: bool = False) -> ClassTypeRefBuilder: + def context_manager(self, *, is_async: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.context_manager_type(inner(context, info), is_async=is_async) + ) -> TypeInfo: + return (predef().async_context_manager if is_async else predef().context_manager).with_type_params( + inner(info) + ) return self.__wrap(transform) - def iterator(self, *, is_async: bool = False) -> ClassTypeRefBuilder: + def iterator(self, *, is_async: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.iterable_type(inner(context, info), is_async=is_async) + ) -> TypeInfo: + return (predef().async_iterator if is_async else predef().iterator).with_type_params(inner(info)) return self.__wrap(transform) - def iterable(self, *, is_async: bool = False) -> ClassTypeRefBuilder: + def iterable(self, *, is_async: bool = False) -> TypeRefBuilder: def transform( - inner: t.Callable[[BuildContext, TypeInfo], Expr], - context: BuildContext, + inner: t.Callable[[TypeInfo], TypeInfo], info: TypeInfo, - ) -> Expr: - return self._scope.iterable_type(inner(context, info), is_async=is_async) + ) -> TypeInfo: + return (predef().async_iterable if is_async else predef().iterable).with_type_params(inner(info)) + + return self.__wrap(transform) + + def type_params(self, *params: TypeRef) -> TypeRefBuilder: + def transform( + inner: t.Callable[[TypeInfo], TypeInfo], + info: TypeInfo, + ) -> TypeInfo: + origin = inner(info) + + if not isinstance(origin, NamedTypeInfo): + msg = "named type info was expected to apply type params" + raise TypeError(msg, origin, params) + + return origin.with_type_params(*(self.__context.inspector.inspect(param) for param in params)) return self.__wrap(transform) def attr(self, *tail: str) -> AttrASTBuilder: - return AttrASTBuilder(self._context, self, *tail) + return AttrASTBuilder(self.__context, self, *tail) def init( self, args: t.Optional[t.Sequence[Expr]] = None, kwargs: t.Optional[t.Mapping[str, Expr]] = None, ) -> CallASTBuilder: - return CallASTBuilder(self._context, self, args, kwargs) + return CallASTBuilder(self.__context, self, args, kwargs) @override def build_expr(self) -> ast.expr: - return self._normalize_expr(self.__transform(self._context, self.__info)) - - @override - def build_annotation(self) -> ast.expr: - return self._normalize_annotation(self.__transform(self._context, self.__info)) + return self.__context.resolver.resolve_expr(self.__transform(self.__info)) - def __ident(self, context: BuildContext, info: TypeInfo) -> Expr: - return context.resolver.resolve_expr(info) + def __ident(self, info: TypeInfo) -> TypeInfo: + return info def __wrap( self, - transform: t.Callable[[t.Callable[[BuildContext, TypeInfo], Expr], BuildContext, TypeInfo], Expr], - ) -> ClassTypeRefBuilder: - return self.__class__(self._context, self.__info, partial(transform, self.__transform)) - - -@dataclass(frozen=True) -class Comprehension: - target: Expr - items: Expr - predicates: t.Sequence[Expr] = field(default_factory=list) - is_async: bool = False + transform: t.Callable[[t.Callable[[TypeInfo], TypeInfo], TypeInfo], TypeInfo], + ) -> TypeRefBuilder: + return self.__class__(self.__context, self.__info, partial(transform, self.__transform)) -# noinspection PyTypeChecker -class ScopeASTBuilder(_BaseBuilder): - def type_ref(self, origin: t.Union[TypeInfo, RuntimeType]) -> ClassTypeRefBuilder: - return ClassTypeRefBuilder( - context=self._context, - info=origin - if isinstance(origin, (NamedTypeInfo, LiteralTypeInfo)) - else self._context.inspector.inspect(origin), - ) +class AnnotationASTBuilder(_BaseBuilder): + def type_ref(self, origin: TypeRef) -> TypeRefBuilder: + return TypeRefBuilder(self._context, self._context.inspector.inspect(origin)) - if sys.version_info < (3, 10): + if sys.version_info >= (3, 10): @_ast_expr_builder - def const(self, value: t.Union[str, bytes, bool, complex, None]) -> Expr: # noqa: FBT001 + def const(self, value: t.Union[str, bytes, bool, complex, EllipsisType, None]) -> Expr: # noqa: FBT001 return ast.Constant(value=value) else: @_ast_expr_builder - def const(self, value: t.Union[str, bytes, bool, complex, EllipsisType, None]) -> Expr: # noqa: FBT001 + def const(self, value: t.Union[str, bytes, bool, complex, None]) -> Expr: # noqa: FBT001 return ast.Constant(value=value) def none(self) -> Expr: @@ -641,6 +623,80 @@ def none(self) -> Expr: def ellipsis(self) -> Expr: return ast.Constant(value=...) + def generic_type(self, generic: TypeExpr, *params: TypeExpr) -> Expr: + if len(params) == 0: + return self._normalize_expr(generic) + + return ast.Subscript( + value=self._normalize_expr(generic), + slice=self._normalize_expr(self.tuple_type(*params, normalize=True)), + ) + + def literal_type(self, *values: t.Union[str, Expr]) -> Expr: + if not values: + return self._normalize_expr(predef().no_return) + + return self.generic_type( + predef().literal, + *(self.const(val) if isinstance(val, str) else val for val in values), + ) + + def optional_type(self, of_type: TypeExpr) -> Expr: + return self.generic_type(predef().optional, of_type) + + def union_type(self, *params: TypeExpr, normalize: bool = False) -> Expr: + if not params: + return self._normalize_expr(predef().no_return) + + if normalize and len(params) == 1: + return self._normalize_expr(params[0]) + + return self.generic_type(predef().union, *params) + + def tuple_type(self, *params: TypeExpr, normalize: bool = False) -> Expr: + if normalize and len(params) == 1: + return self._normalize_expr(params[0]) + + return ast.Tuple(elts=[self._normalize_expr(item) for item in params]) + + def collection_type(self, of_type: TypeExpr) -> Expr: + return self.generic_type(predef().collection, of_type) + + def sequence_type(self, of_type: TypeExpr, *, mutable: bool = False) -> Expr: + return self.generic_type(predef().mutable_sequence if mutable else predef().sequence, of_type) + + def list_type(self, of_type: TypeExpr) -> Expr: + return self.generic_type(predef().list, of_type) + + def mapping_type(self, key_type: TypeExpr, value_type: TypeExpr, *, mutable: bool = False) -> Expr: + return self.generic_type(predef().mutable_mapping if mutable else predef().mapping, key_type, value_type) + + def dict_type(self, key_type: TypeExpr, value_type: TypeExpr) -> Expr: + return self.generic_type(predef().dict, key_type, value_type) + + def iterator_type(self, of_type: TypeExpr, *, is_async: bool = False) -> Expr: + return self.generic_type(predef().async_iterator if is_async else predef().iterator, of_type) + + def iterable_type(self, of_type: TypeExpr, *, is_async: bool = False) -> Expr: + return self.generic_type(predef().async_iterable if is_async else predef().iterable, of_type) + + def context_manager_type(self, of_type: TypeExpr, *, is_async: bool = False) -> Expr: + return self.generic_type( + predef().async_context_manager if is_async else predef().context_manager, + of_type, + ) + + +@dataclass(frozen=True) +class Comprehension: + target: Expr + items: Expr + predicates: t.Sequence[Expr] = field(default_factory=list) + is_async: bool = False + + +# noinspection PyTypeChecker +class ScopeASTBuilder(AnnotationASTBuilder): @_ast_expr_builder def await_(self, expr: Expr) -> Expr: return ast.Await(self._normalize_expr(expr)) @@ -680,7 +736,7 @@ def ternary_not_none_expr( ) @_ast_expr_builder - def tuple_expr(self, *items: TypeRef, normalize: bool = False) -> Expr: + def tuple_expr(self, *items: TypeExpr, normalize: bool = False) -> Expr: if normalize and len(items) == 1: return self._normalize_expr(items[0]) @@ -866,20 +922,23 @@ def compare_not_in_expr(self, left: Expr, right: Expr) -> Expr: comparators=[self._normalize_expr(right)], ) - def attr(self, head: t.Union[str, TypeRef], *tail: str) -> AttrASTBuilder: + def attr(self, head: t.Union[str, TypeExpr], *tail: str) -> AttrASTBuilder: return AttrASTBuilder(self._context, head, *tail) def call( self, - func: TypeRef, + func: TypeExpr, args: t.Optional[t.Sequence[Expr]] = None, kwargs: t.Optional[t.Mapping[str, Expr]] = None, ) -> CallASTBuilder: return CallASTBuilder(self._context, func, args, kwargs) @_ast_expr_builder - def subscript(self, value: TypeRef, slice_: Expr) -> Expr: - return ast.Subscript(value=self._normalize_expr(value), slice=self._normalize_expr(slice_)) + def subscript(self, value: TypeExpr, *slice_: TypeExpr) -> Expr: + return ast.Subscript( + value=self._normalize_expr(value), + slice=self._normalize_expr(self.tuple_expr(*slice_, normalize=True)), + ) @_ast_expr_builder def slice( @@ -894,60 +953,6 @@ def slice( step=self._normalize_expr(step) if step is not None else None, ) - def generic_type(self, generic: TypeRef, *args: TypeRef) -> Expr: - if len(args) == 0: - return self._normalize_annotation(generic) - - return ast.Subscript( - value=self._normalize_annotation(generic), - slice=self._normalize_annotation(self.tuple_type(*args, normalize=True)), - ) - - def literal_type(self, *args: t.Union[str, Expr]) -> Expr: - if not args: - return self._normalize_annotation(predef().no_return) - - return self.generic_type( - predef().literal, - *(self.const(arg) if isinstance(arg, str) else arg for arg in args), - ) - - def optional_type(self, of_type: TypeRef) -> Expr: - return self.generic_type(predef().optional, of_type) - - def union_type(self, *args: TypeRef) -> Expr: - if not args: - return self._normalize_annotation(predef().no_return) - - return self.generic_type(predef().union, *args) - - def tuple_type(self, *items: TypeRef, normalize: bool = False) -> Expr: - if normalize and len(items) == 1: - return self._normalize_annotation(items[0]) - - return ast.Tuple(elts=[self._normalize_annotation(item) for item in items]) - - def collection_type(self, of_type: TypeRef) -> Expr: - return self.generic_type(predef().collection, of_type) - - def sequence_type(self, of_type: TypeRef, *, mutable: bool = False) -> Expr: - return self.generic_type(predef().mutable_sequence if mutable else predef().sequence, of_type) - - def mapping_type(self, key_type: TypeRef, value_type: TypeRef, *, mutable: bool = False) -> Expr: - return self.generic_type(predef().mutable_mapping if mutable else predef().mapping, key_type, value_type) - - def iterator_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: - return self.generic_type(predef().async_iterator if is_async else predef().iterator, of_type) - - def iterable_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: - return self.generic_type(predef().async_iterable if is_async else predef().iterable, of_type) - - def context_manager_type(self, of_type: TypeRef, *, is_async: bool = False) -> Expr: - return self.generic_type( - predef().async_context_manager if is_async else predef().context_manager, - of_type, - ) - def stmt(self, *stmts: t.Optional[Stmt]) -> None: self._context.extend_body(self._context.resolver.resolve_stmts(*stmts)) @@ -958,14 +963,17 @@ def func_def(self, name: str) -> FuncStatementASTBuilder: return FuncStatementASTBuilder(self._context, name) @_ast_stmt_builder - def field_def(self, name: str, annotation: TypeRef, default: t.Optional[Expr] = None) -> ast.stmt: + def field_def(self, name: str, annotation: TypeExpr, default: t.Optional[Expr] = None) -> ast.stmt: return ast.AnnAssign( target=ast.Name(id=name), - annotation=self._normalize_annotation(annotation), + annotation=self._normalize_expr(annotation), value=self._normalize_expr(default) if default is not None else None, simple=1, ) + def type_alias(self, name: str) -> TypeAliasStatementASTBuilder: + return TypeAliasStatementASTBuilder(self._context, name) + @_ast_stmt_builder def assign_stmt(self, target: t.Union[str, Expr], value: Expr) -> ast.stmt: return ast.Assign( @@ -1088,7 +1096,7 @@ def __enter_block(self, body: list[ast.stmt]) -> t.Iterator[ScopeASTBuilder]: def __enter_implicitly(self, body: list[ast.stmt]) -> t.Iterator[ScopeASTBuilder]: if not self.__allow_implicit_enter: msg = "can't enter into the nested block implicitly" - raise RuntimeError(msg, self, body) + raise ASTBuildError(msg, self, body) with self, self.__enter_block(body) as scope: yield scope @@ -1239,7 +1247,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: class TryStatementASTBuilder(_NestedBlockASTBuilder): @dataclass() class ExceptHandler: - types: t.Sequence[TypeRef] + types: t.Sequence[TypeExpr] name: t.Optional[str] body: list[ast.stmt] @@ -1253,7 +1261,7 @@ def __init__(self, context: BuildContext) -> None: def body(self) -> t.ContextManager[ScopeASTBuilder]: return self._block(self.__body) - def except_(self, *types: TypeRef, name: t.Optional[str] = None) -> t.ContextManager[ScopeASTBuilder]: + def except_(self, *types: TypeExpr, name: t.Optional[str] = None) -> t.ContextManager[ScopeASTBuilder]: body = list[ast.stmt]() self.__handlers.append( @@ -1297,10 +1305,199 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: ] -@dataclass(frozen=True) -class TypeVar: - name: str - bound: t.Optional[TypeRef] = None +class TypeVarBuilder(_BaseBuilder, TypeDefinitionBuilder, ASTStatementBuilder): + def __init__( + self, + context: BuildContext, + name: str, + variance: t.Optional[TypeVarVariance] = None, + constraints: t.Optional[t.Sequence[TypeExpr]] = None, + lower: t.Optional[TypeExpr] = None, + ) -> None: + super().__init__(context) + self.__name = name + self.__module = self._context.module + self.__namespace = self._context.namespace if sys.version_info >= (3, 12) else self._context.namespace[:-1] + self.__variance = variance + self.__constraints = list[TypeExpr](constraints or ()) + self.__lower = lower + + self.__info = NamedTypeInfo( + name=self.__name, + module=self.__module, + namespace=self.__namespace, + ) + + def __enter__(self) -> TypeInfo: + return self.__info + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: + pass + + @override + @property + def info(self) -> TypeInfo: + return self.__info + + @override + def ref(self) -> ASTExpressionBuilder: + return TypeRefBuilder(self._context, self.__info) + + def invariant(self) -> Self: + self.__variance = "invariant" + return self + + def covariant(self) -> Self: + self.__variance = "covariant" + return self + + def contravariant(self) -> Self: + self.__variance = "contravariant" + return self + + def constraints(self, *types: TypeExpr) -> Self: + self.__constraints.extend(types) + return self + + def lower(self, type_: TypeExpr) -> Self: + self.__lower = type_ + return self + + @override + def build_stmt(self) -> t.Sequence[ast.stmt]: + args: list[ast.expr] = [ast.Constant(value=self.__name)] + keywords = list[ast.keyword]() + + if self.__variance is None or self.__variance == "invariant": + pass + elif self.__variance == "covariant": + keywords.append(ast.keyword(arg="covariant", value=ast.Constant(value=True))) + elif self.__variance == "contravariant": + keywords.append(ast.keyword(arg="contravariant", value=ast.Constant(value=True))) + else: + assert_never(self.__variance) + + if self.__lower is not None: + keywords.append(ast.keyword(arg="bound", value=self._normalize_expr(self.__lower))) + + return [ + ast.Assign( + targets=[ast.Name(id=self.__name)], + value=ast.Call( + func=self._normalize_expr(predef().type_var), + args=args, + keywords=keywords, + ), + lineno=0, + ), + ] + + # NOTE: workaround for passing mypy typings in CI for python 3.12 + if sys.version_info >= (3, 12): + + def build_type_param(self) -> ast.type_param: + return ast.TypeVar( + name=self.__name, + bound=self._normalize_expr(self.__lower) if self.__lower is not None else None, + ) + + +class TypeAliasExpressionBuilder(AnnotationASTBuilder, TypeDefinitionBuilder, ASTStatementBuilder): + def __init__(self, context: BuildContext, info: NamedTypeInfo) -> None: + super().__init__(context) + self.__info = info + self.__expr: t.Optional[TypeExpr] = None + self.__type_vars = list[TypeVarBuilder]() + + @override + @property + def info(self) -> TypeInfo: + return self.__info + + @override + def ref(self) -> TypeRefBuilder: + return TypeRefBuilder(self._context, self.__info) + + def assign(self, expr: TypeExpr) -> None: + self.__expr = expr + + def type_var(self, name: str) -> TypeVarBuilder: + type_var = TypeVarBuilder(self._context, name) + self.__type_vars.append(type_var) + return type_var + + def type_params(self, *params: TypeExpr) -> Expr: + return self.ref().type_params(*params) + + # NOTE: workaround for passing mypy typings in CI for python 3.12 + if sys.version_info >= (3, 12): + + @override + def build_stmt(self) -> t.Sequence[ast.stmt]: + if self.__expr is None: + msg = "type alias expression is not set" + raise IncompleteStatementError(msg, self) + + return [ + ast.TypeAlias( + name=ast.Name(id=self.__info.name), + value=self._normalize_expr(self.__expr), + type_params=[tv.build_type_param() for tv in self.__type_vars], + ) + ] + + else: + + @override + def build_stmt(self) -> t.Sequence[ast.stmt]: + if self.__expr is None: + msg = "type alias expression is not set" + raise IncompleteStatementError(msg, self) + + stmts = [stmt for tv in self.__type_vars for stmt in tv.build_stmt()] + stmts.append( + ast.AnnAssign( + target=ast.Name(id=self.__info.name), + annotation=self._normalize_expr(predef().type_alias), + value=self._normalize_expr(self.__expr), + simple=1, + ) + ) + + return stmts + + +class TypeAliasStatementASTBuilder(_BaseBuilder, ASTStatementBuilder, TypeDefinitionBuilder): + def __init__(self, context: BuildContext, name: str) -> None: + super().__init__(context) + self.__info = NamedTypeInfo(name=name, module=self._context.module, namespace=self._context.namespace) + self.__annotation = TypeAliasExpressionBuilder(context=self._context, info=self.__info) + + def __enter__(self) -> TypeAliasExpressionBuilder: + self._context.enter_scope(self.__annotation.info.name, []) + return self.__annotation + + def __exit__(self, exc_type: object, exc_value: object, exc_traceback: object) -> None: + if exc_type is None: + self._context.leave_scope() + self._context.extend_body(self.build_stmt()) + + @override + @property + def info(self) -> TypeInfo: + return self.__info + + @override + def ref(self) -> ASTExpressionBuilder: + return TypeRefBuilder(self._context, self.__info) + + def assign(self, expr: TypeExpr) -> None: + self.__annotation.assign(expr) + self._context.extend_body(self.build_stmt()) + + @override + def build_stmt(self) -> t.Sequence[ast.stmt]: + return self.__annotation.build_stmt() class ClassScopeASTBuilder(ScopeASTBuilder, TypeDefinitionBuilder): @@ -1314,9 +1511,12 @@ def info(self) -> TypeInfo: return self.__header.info @override - def ref(self) -> ClassTypeRefBuilder: + def ref(self) -> TypeRefBuilder: return self.__header.ref() + def type_var(self, name: str) -> TypeVarBuilder: + return self.__header.type_var(name) + def method_def(self, name: str) -> MethodStatementASTBuilder: return MethodStatementASTBuilder(self._context, name) @@ -1327,7 +1527,7 @@ def init_def(self) -> MethodStatementASTBuilder: return self.method_def("__init__").returns(self.const(None)) @contextmanager - def init_self_attrs_def(self, attrs: t.Mapping[str, TypeRef]) -> t.Iterator[MethodScopeASTBuilder]: + def init_self_attrs_def(self, attrs: t.Mapping[str, TypeExpr]) -> t.Iterator[MethodScopeASTBuilder]: init_def = self.init_def() for name, annotation in attrs.items(): @@ -1389,10 +1589,10 @@ class ClassStatementASTBuilder( def __init__(self, context: BuildContext, name: str) -> None: super().__init__(context) self.__info = NamedTypeInfo(name=name, module=self._context.module, namespace=self._context.namespace) - self.__bases = list[TypeRef]() - self.__decorators = list[TypeRef]() - self.__keywords = dict[str, TypeRef]() - self.__type_vars = list[TypeVar]() + self.__bases = list[TypeExpr]() + self.__decorators = list[TypeExpr]() + self.__keywords = dict[str, TypeExpr]() + self.__type_vars = list[TypeVarBuilder]() self.__docs = list[str]() self.__body = list[ast.stmt]() @@ -1413,14 +1613,13 @@ def info(self) -> TypeInfo: return self.__info @override - def ref(self) -> ClassTypeRefBuilder: - return ClassTypeRefBuilder(self._context, self.__info) - - if sys.version_info >= (3, 12): + def ref(self) -> TypeRefBuilder: + return TypeRefBuilder(self._context, self.__info) - def type_param(self, name: str, bound: t.Optional[TypeRef] = None) -> Self: - self.__type_vars.append(TypeVar(name=name, bound=bound)) - return self + def type_var(self, name: str) -> TypeVarBuilder: + type_var = TypeVarBuilder(self._context, name) + self.__type_vars.append(type_var) + return type_var def docstring(self, value: t.Optional[str]) -> Self: if value: @@ -1445,15 +1644,15 @@ def dataclass(self, *, frozen: bool = False, kw_only: bool = False) -> Self: return self.decorators(dc) - def inherits(self, *bases: t.Optional[TypeRef]) -> Self: + def inherits(self, *bases: t.Optional[TypeExpr]) -> Self: self.__bases.extend(base for base in bases if base is not None) return self - def decorators(self, *items: t.Optional[TypeRef]) -> Self: + def decorators(self, *items: t.Optional[TypeExpr]) -> Self: self.__decorators.extend(item for item in items if item is not None) return self - def keywords(self, **keywords: t.Optional[TypeRef]) -> Self: + def keywords(self, **keywords: t.Optional[TypeExpr]) -> Self: self.__keywords.update({key: value for key, value in keywords.items() if value is not None}) return self @@ -1469,13 +1668,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: keywords=self.__build_keywords(), body=self._normalize_body(self.__body, self.__docs), decorator_list=self.__build_decorators(), - type_params=[ - ast.TypeVar( - name=type_var.name, - bound=self._normalize_expr(type_var.bound) if type_var.bound is not None else None, - ) - for type_var in self.__type_vars - ], + type_params=[tv.build_type_param() for tv in self.__type_vars], ), ] @@ -1483,7 +1676,12 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: # noinspection PyArgumentList @override def build_stmt(self) -> t.Sequence[ast.stmt]: - return [ + stmts = [stmt for tv in self.__type_vars for stmt in tv.build_stmt()] + + if self.__type_vars: + self.__bases.insert(0, predef().generic.with_type_params(*(tv.info for tv in self.__type_vars))) + + stmts.append( ast.ClassDef( name=self.__info.name, bases=self.__build_bases(), @@ -1491,7 +1689,9 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: body=self._normalize_body(self.__body, self.__docs), decorator_list=self.__build_decorators(), ), - ] + ) + + return stmts def __build_bases(self) -> list[ast.expr]: return [self._normalize_expr(base) for base in self.__bases] @@ -1512,16 +1712,12 @@ def __init__(self, context: BuildContext, info: NamedTypeInfo) -> None: def build_expr(self) -> ast.expr: return self.__context.resolver.resolve_expr(self.__info) - @override - def build_annotation(self) -> ast.expr: - return self.__context.resolver.resolve_annotation(self.__info) - @dataclass(frozen=True) class FuncArgInfo: name: str kind: t.Literal["positional-only", "positional-or-keyword", "var-positional", "keyword-only", "var-keyword"] - annotation: t.Optional[TypeRef] = None + annotation: t.Optional[TypeExpr] = None default: t.Optional[Expr] = None @@ -1535,9 +1731,9 @@ class FuncStatementASTBuilder( def __init__(self, context: BuildContext, name: str) -> None: super().__init__(context) self.__info = NamedTypeInfo(name=name, module=self._context.module, namespace=self._context.namespace) - self.__decorators = list[TypeRef]() + self.__decorators = list[TypeExpr]() self.__args = list[FuncArgInfo]() - self.__returns: t.Optional[TypeRef] = None + self.__returns: t.Optional[TypeExpr] = None self.__is_async = False self.__is_abstract = False self.__is_override = False @@ -1584,14 +1780,14 @@ def docstring(self, value: t.Optional[str]) -> Self: self.__docs.append(value) return self - def decorators(self, *items: t.Optional[TypeRef]) -> Self: + def decorators(self, *items: t.Optional[TypeExpr]) -> Self: self.__decorators.extend(item for item in items if item is not None) return self def arg( self, name: str, - annotation: t.Optional[TypeRef] = None, + annotation: t.Optional[TypeExpr] = None, default: t.Optional[Expr] = None, ) -> Self: return self.args(FuncArgInfo(name=name, kind="positional-or-keyword", annotation=annotation, default=default)) @@ -1599,7 +1795,7 @@ def arg( def kwarg( self, name: str, - annotation: t.Optional[TypeRef] = None, + annotation: t.Optional[TypeExpr] = None, default: t.Optional[Expr] = None, ) -> Self: return self.args(FuncArgInfo(name=name, kind="keyword-only", annotation=annotation, default=default)) @@ -1614,7 +1810,7 @@ def args(self, *args: t.Union[FuncArgInfo, t.Sequence[FuncArgInfo]]) -> Self: self.__args.extend(chain.from_iterable((part,) if isinstance(part, FuncArgInfo) else part for part in args)) return self - def returns(self, ret: t.Optional[TypeRef]) -> Self: + def returns(self, ret: t.Optional[TypeExpr]) -> Self: if ret is not None: self.__returns = ret return self @@ -1635,8 +1831,6 @@ def not_implemented(self) -> Self: def build_stmt(self) -> t.Sequence[ast.stmt]: node: ast.stmt - scope = ScopeASTBuilder(self._context) - if self.__is_async: # noinspection PyArgumentList node = ast.AsyncFunctionDef( # type: ignore[call-overload,no-any-return,unused-ignore] @@ -1644,7 +1838,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: name=self.__info.name, decorator_list=self.__build_decorators(), args=self.__build_args(), - returns=self.__build_returns(scope), + returns=self.__build_returns(), body=self.__build_body(), lineno=0, ) @@ -1656,7 +1850,7 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: name=self.__info.name, decorator_list=self.__build_decorators(), args=self.__build_args(), - returns=self.__build_returns(scope), + returns=self.__build_returns(), body=self.__build_body(), lineno=0, ) @@ -1664,8 +1858,8 @@ def build_stmt(self) -> t.Sequence[ast.stmt]: return [node] def __build_decorators(self) -> list[ast.expr]: - head_decorators: list[TypeRef] = [] - last_decorators: list[TypeRef] = [] + head_decorators: list[TypeExpr] = [] + last_decorators: list[TypeExpr] = [] if self.__is_override: head_decorators.append(predef().override_decorator) @@ -1694,7 +1888,7 @@ def __build_args(self) -> ast.arguments: for info in self.__args: arg = ast.arg( arg=info.name, - annotation=self._normalize_annotation(info.annotation) if info.annotation is not None else None, + annotation=self._normalize_expr(info.annotation) if info.annotation is not None else None, ) if info.kind == "positional-only": @@ -1720,13 +1914,16 @@ def __build_args(self) -> ast.arguments: return node - def __build_returns(self, scope: ScopeASTBuilder) -> t.Optional[ast.expr]: + def __build_returns(self) -> t.Optional[ast.expr]: if self.__returns is None: return None ret = self.__returns if self.__iterator_cm: - ret = scope.iterator_type(ret, is_async=self.__is_async) + ret = ast.Subscript( + value=self._normalize_expr(predef().async_iterator if self.__is_async else predef().iterator), + slice=self._normalize_expr(ret), + ) return self._normalize_expr(ret) diff --git a/src/astlab/reader.py b/src/astlab/reader.py index 408fa40..d904224 100644 --- a/src/astlab/reader.py +++ b/src/astlab/reader.py @@ -38,7 +38,12 @@ def import_module_path(path: Path) -> ModuleType: # Avoid `.py` in last part. qualname = ".".join((*relpath.parts[:-1], relpath.stem)) - return importlib.import_module(qualname) + try: + return importlib.import_module(qualname) + + except ImportError as err: + msg = "can't import module from path" + raise ImportError(msg, path) from err def _get_rel_path_parts_count(args: tuple[Path, Path]) -> int: diff --git a/src/astlab/resolver.py b/src/astlab/resolver.py index 95e873e..817d71a 100644 --- a/src/astlab/resolver.py +++ b/src/astlab/resolver.py @@ -5,65 +5,53 @@ ] import ast +import sys import typing as t from dataclasses import replace from itertools import chain from astlab._typing import assert_never, override -from astlab.abc import ASTExpressionBuilder, ASTResolver, ASTStatementBuilder, Stmt, TypeDefinitionBuilder, TypeRef +from astlab.abc import ASTExpressionBuilder, ASTResolver, ASTStatementBuilder, Stmt, TypeDefinitionBuilder, TypeExpr +from astlab.traverse import traverse_dfs_post_order from astlab.types import ( LiteralTypeInfo, ModuleInfo, NamedTypeInfo, + TypeAnnotator, TypeInfo, TypeInspector, ellipsis_type_info, none_type_info, ) +from astlab.types.model import EnumTypeInfo, TypeVarInfo class DefaultASTResolver(ASTResolver): - def __init__(self, inspector: t.Optional[TypeInspector] = None) -> None: + def __init__( + self, + inspector: t.Optional[TypeInspector] = None, + annotator: t.Optional[TypeAnnotator] = None, + ) -> None: self.__module: t.Optional[ModuleInfo] = None self.__namespace: t.Sequence[str] = () self.__dependencies: t.MutableSet[ModuleInfo] = set[ModuleInfo]() self.__inspector = inspector if inspector is not None else TypeInspector() + self.__annotator = annotator if annotator is not None else TypeAnnotator() @override - def resolve_expr(self, ref: TypeRef, *tail: str) -> ast.expr: - if isinstance(ref, ast.expr): - return self.__chain_attr(ref, *tail) + def resolve_expr(self, expr: TypeExpr, *tail: str) -> ast.expr: + if isinstance(expr, ast.expr): + return self.__chain_attr(expr, *tail) - elif isinstance(ref, ASTExpressionBuilder): - return self.__chain_attr(ref.build_expr(), *tail) + elif isinstance(expr, ASTExpressionBuilder): + return self.__chain_attr(expr.build_expr(), *tail) - elif isinstance(ref, (NamedTypeInfo, LiteralTypeInfo)): - return self.__type_info_expr(ref, tail) - - elif isinstance(ref, TypeDefinitionBuilder): - return self.__type_info_expr(ref.info, tail) + elif isinstance(expr, TypeDefinitionBuilder): + return self.__resolve_info(expr.info, tail) else: - info = self.__inspector.inspect(ref) - return self.__type_info_expr(info, tail) - - @override - def resolve_annotation(self, ref: TypeRef) -> ast.expr: - if isinstance(ref, ast.expr): - return ref - - elif isinstance(ref, ASTExpressionBuilder): - return ref.build_annotation() - - elif isinstance(ref, (NamedTypeInfo, LiteralTypeInfo)): - return self.__type_info_expr(ref, is_annotation=True) - - elif isinstance(ref, TypeDefinitionBuilder): - return self.__type_info_expr(ref.info, is_annotation=True) - - else: - info = self.__inspector.inspect(ref) - return self.__type_info_expr(info, is_annotation=True) + info = self.__inspector.inspect(expr) + return self.__resolve_info(info, tail) @override def resolve_stmts( @@ -103,47 +91,26 @@ def set_current_scope( self.__namespace = namespace self.__dependencies = dependencies - def __type_info_expr(self, info: TypeInfo, tail: t.Sequence[str] = (), *, is_annotation: bool = False) -> ast.expr: - if isinstance(info, NamedTypeInfo) and info.type_vars: - msg = "can't build expr for type with type vars" - raise ValueError(msg, info) + def __resolve_info(self, root: TypeInfo, tail: t.Sequence[str] = ()) -> ast.expr: + nodes = dict[TypeInfo, ast.expr]() - resolved_info = self.__resolve_dependency(info) - return self.__type_info_attr(resolved_info, tail, is_annotation=is_annotation) + for info in traverse_dfs_post_order(root, self.__get_children): + resolved_info = self.__resolve_dependency(info) - def __type_info_attr(self, info: TypeInfo, tail: t.Sequence[str] = (), *, is_annotation: bool = False) -> ast.expr: - if is_annotation and info == none_type_info(): - return ast.Constant(value=None) - - if is_annotation and info == ellipsis_type_info(): - return ast.Constant(value=...) + node = self.__build_expr(resolved_info) + if isinstance(info, NamedTypeInfo) and info.type_params: + params = [nodes[tp] for tp in info.type_params] + node = ast.Subscript( + value=node, + slice=ast.Tuple(elts=params) if len(params) > 1 else params[0], + ) - parts = self.__module.parts if self.__module is not None else () - head, *middle = ( - (info.parts[len(parts) :] if info.module == self.__module else info.parts) - if not isinstance(info, ModuleInfo) - else info.parts - ) + if self.__is_forward_ref(resolved_info): + node = ast.Constant(value=ast.unparse(node)) - origin = self.__chain_attr(ast.Name(id=head), *middle, *tail) - args = ( - ( - [self.__type_info_attr(param, is_annotation=is_annotation) for param in info.type_params] - if isinstance(info, NamedTypeInfo) - else [ast.Constant(value=value) for value in info.values] - ) - if not isinstance(info, ModuleInfo) - else [] - ) + nodes[info] = node - return ( - ast.Subscript( - value=origin, - slice=ast.Tuple(elts=args) if len(args) > 1 else args[0], - ) - if args - else origin - ) + return self.__chain_attr(nodes[root], *tail) def __resolve_dependency(self, info: TypeInfo) -> TypeInfo: if isinstance(info, ModuleInfo): @@ -153,33 +120,62 @@ def __resolve_dependency(self, info: TypeInfo) -> TypeInfo: return info - elif isinstance(info, NamedTypeInfo): - if info.module == self.__module: - ns = ( - info.namespace[len(self.__namespace) :] - if info.namespace[: len(self.__namespace)] == self.__namespace - else info.namespace - ) - - else: + elif isinstance(info, (TypeVarInfo, NamedTypeInfo, EnumTypeInfo)): + if info.module != self.__module: self.__dependencies.add(info.module) - ns = info.namespace + return info - return replace( - info, - namespace=ns, - type_params=tuple(self.__resolve_dependency(param) for param in info.type_params), - ) + elif info.namespace[: len(self.__namespace)] == self.__namespace: + # use shorten namespace for a type in nested namespace of the current scope + return replace(info, namespace=info.namespace[len(self.__namespace) :]) + + else: + return info elif isinstance(info, LiteralTypeInfo): - self.__dependencies.add(info.module) + if info.module != self.__module: + self.__dependencies.add(info.module) + return info else: assert_never(info) + def __build_expr(self, info: TypeInfo) -> ast.expr: + if info == none_type_info(): + return ast.Constant(value=None) + + if info == ellipsis_type_info(): + return ast.Constant(value=...) + + parts = self.__module.parts if self.__module is not None else () + head, *tail = ( + (info.parts[len(parts) :] if info.module == self.__module else info.parts) + if not isinstance(info, ModuleInfo) + else info.parts + ) + + return self.__chain_attr(ast.Name(id=head), *tail) + + if sys.version_info >= (3, 14): + + def __is_forward_ref(self, _: TypeInfo) -> bool: + return False + + else: + + def __is_forward_ref(self, info: TypeInfo) -> bool: + return ( + not isinstance(info, ModuleInfo) + and info.module == self.__module + and (*info.namespace, info.name) == self.__namespace + ) + def __chain_attr(self, expr: ast.expr, *tail: str) -> ast.expr: for attr in tail: expr = ast.Attribute(attr=attr, value=expr) return expr + + def __get_children(self, info: TypeInfo) -> t.Iterable[TypeInfo]: + return info.type_params if isinstance(info, NamedTypeInfo) else () diff --git a/src/astlab/traverse.py b/src/astlab/traverse.py new file mode 100644 index 0000000..0d4423b --- /dev/null +++ b/src/astlab/traverse.py @@ -0,0 +1,19 @@ +import typing as t +from collections import deque + +T = t.TypeVar("T", bound=t.Hashable) + + +def traverse_dfs_post_order(root: T, children: t.Callable[[T], t.Iterable[T]]) -> t.Iterable[T]: + stack = deque[tuple[T, bool]]([(root, False)]) + + while stack: + item, processed = stack.pop() + if processed: + yield item + + else: + stack.append((item, True)) + + for child in children(item): + stack.append((child, False)) diff --git a/src/astlab/types/__init__.py b/src/astlab/types/__init__.py index 2a20b44..4618371 100644 --- a/src/astlab/types/__init__.py +++ b/src/astlab/types/__init__.py @@ -1,4 +1,6 @@ __all__ = [ + "EnumTypeInfo", + "EnumTypeValue", "LiteralTypeInfo", "LiteralTypeValue", "ModuleInfo", @@ -10,6 +12,9 @@ "TypeInfo", "TypeInspector", "TypeLoader", + "TypeLoaderError", + "TypeVarInfo", + "TypeVarVariance", "builtins_module_info", "ellipsis_type_info", "none_type_info", @@ -19,8 +24,10 @@ from astlab.types.annotator import TypeAnnotator from astlab.types.inspector import TypeInspector -from astlab.types.loader import ModuleLoader, TypeLoader +from astlab.types.loader import ModuleLoader, TypeLoader, TypeLoaderError from astlab.types.model import ( + EnumTypeInfo, + EnumTypeValue, LiteralTypeInfo, LiteralTypeValue, ModuleInfo, @@ -28,6 +35,8 @@ PackageInfo, RuntimeType, TypeInfo, + TypeVarInfo, + TypeVarVariance, builtins_module_info, ellipsis_type_info, none_type_info, diff --git a/src/astlab/types/annotator.py b/src/astlab/types/annotator.py index 130abb2..a7e289b 100644 --- a/src/astlab/types/annotator.py +++ b/src/astlab/types/annotator.py @@ -5,19 +5,24 @@ ] import ast +import enum import typing as t from collections import deque from dataclasses import replace from astlab._typing import assert_never, override from astlab.cache import lru_cache_method -from astlab.types.loader import ModuleLoader +from astlab.types.loader import TypeLoader, TypeLoaderError from astlab.types.model import ( + EnumTypeInfo, + EnumTypeValue, LiteralTypeInfo, LiteralTypeValue, ModuleInfo, NamedTypeInfo, + RuntimeType, TypeInfo, + TypeVarInfo, builtins_module_info, ellipsis_type_info, none_type_info, @@ -27,35 +32,44 @@ class TypeAnnotator: """Provides annotation string form type info and vice versa (parses annotation to type info).""" - def __init__(self, loader: t.Optional[ModuleLoader] = None) -> None: - self.__loader = loader or ModuleLoader() + def __init__(self, loader: t.Optional[TypeLoader] = None) -> None: + self.__loader = loader or TypeLoader() @lru_cache_method() def annotate(self, info: TypeInfo) -> str: - if isinstance(info, ModuleInfo): - return "builtins.module" + if info == none_type_info(): + return "None" - elif isinstance(info, NamedTypeInfo): - if info == none_type_info(): - return "None" + if info == ellipsis_type_info(): + return "..." + + if isinstance(info, ModuleInfo): + annotation = "builtins.module" - if info == ellipsis_type_info(): - return "..." + elif isinstance(info, TypeVarInfo): + annotation = info.qualname - if not info.type_params: - return info.qualname + elif isinstance(info, NamedTypeInfo): + annotation = info.qualname - # TODO: fix recursive type - params = ", ".join(self.annotate(tp) for tp in info.type_params) - return f"{info.qualname}[{params}]" + if info.type_params: + # TODO: fix recursive type + params = ", ".join(self.annotate(tp) for tp in info.type_params) + annotation = f"{annotation}[{params}]" elif isinstance(info, LiteralTypeInfo): vals = ", ".join(repr(v) for v in info.values) - return f"typing.Literal[{vals}]" + annotation = f"{info.qualname}[{vals}]" + + elif isinstance(info, EnumTypeInfo): + annotation = info.qualname else: assert_never(info) + return annotation + + @lru_cache_method() def parse(self, qualname: str) -> TypeInfo: node = ast.parse(qualname) @@ -67,7 +81,7 @@ def parse(self, qualname: str) -> TypeInfo: class _ExprParser(ast.NodeVisitor): - def __init__(self, loader: ModuleLoader) -> None: + def __init__(self, loader: TypeLoader) -> None: self.__loader = loader self.__parts = deque[str]() self.__info: t.Optional[TypeInfo] = None @@ -89,16 +103,36 @@ def visit_Constant(self, node: ast.Constant) -> None: def visit_Name(self, node: ast.Name) -> None: self.__parts.appendleft(node.id) *parts, name = self.__parts + named_type_info = self.__extract_named_type_info(node, parts, name) + rtt: RuntimeType = self.__loader.load(named_type_info) + + if isinstance(rtt, t.TypeVar): # type: ignore[misc] + self.__set_result( + TypeVarInfo( + name=named_type_info.name, + module=named_type_info.module, + namespace=named_type_info.namespace, + ) + ) + + elif isinstance(rtt, type) and issubclass(rtt, enum.Enum): # type: ignore[misc] + self.__set_result( + EnumTypeInfo( + name=named_type_info.name, + module=named_type_info.module, + namespace=named_type_info.namespace, + values=tuple( + EnumTypeValue( + name=enum_value.name, + value=enum_value.value, # type: ignore[misc] + ) + for enum_value in rtt + ), + ) + ) - module = self.__extract_module_info(node, parts) - - info = NamedTypeInfo( - name=name, - module=module, - namespace=tuple(parts[len(module.parts) :]), - ) - - self.__set_result(info) + else: + self.__set_result(named_type_info) @override def visit_Attribute(self, node: ast.Attribute) -> None: @@ -121,7 +155,6 @@ def visit_Subscript(self, node: ast.Subscript) -> None: if isinstance(node.slice, ast.Tuple) else self.__parse_type_params(node.slice) ), - type_vars=(), ) def parse(self, node: ast.AST) -> TypeInfo: @@ -140,18 +173,30 @@ def __set_result(self, info: TypeInfo) -> None: self.__info = info - def __extract_module_info(self, node: ast.AST, parts: t.Sequence[str]) -> ModuleInfo: + def __extract_named_type_info( + self, + node: ast.AST, + parts: t.Sequence[str], + name: str, + ) -> NamedTypeInfo: if not parts: - return builtins_module_info() + return NamedTypeInfo( + name=name, + module=builtins_module_info(), + ) for i in range(len(parts), 0, -1): module = ModuleInfo.build(*parts[:i]) try: self.__loader.load(module) - except ImportError: + except TypeLoaderError: continue else: - return module + return NamedTypeInfo( + name=name, + module=module, + namespace=tuple(parts[len(module.parts) :]), + ) msg = "invalid module parts" raise ValueError(msg, parts, ast.dump(node)) diff --git a/src/astlab/types/inspector.py b/src/astlab/types/inspector.py index 00efb0e..1868f9c 100644 --- a/src/astlab/types/inspector.py +++ b/src/astlab/types/inspector.py @@ -7,18 +7,34 @@ "TypeInspector", ] +import enum import sys import typing as t +from dataclasses import replace from astlab.cache import lru_cache_method -from astlab.types.model import LiteralTypeInfo, ModuleInfo, NamedTypeInfo, RuntimeType, TypeInfo, typing_module_info +from astlab.types.model import ( + EnumTypeInfo, + EnumTypeValue, + LiteralTypeInfo, + ModuleInfo, + NamedTypeInfo, + RuntimeType, + TypeInfo, + TypeVarInfo, + typing_module_info, +) +from astlab.types.predef import get_predef class TypeInspector: """Provides type info from runtime type.""" @lru_cache_method() - def inspect(self, type_: RuntimeType) -> TypeInfo: + def inspect(self, type_: t.Union[TypeInfo, RuntimeType]) -> TypeInfo: + if isinstance(type_, (ModuleInfo, TypeVarInfo, NamedTypeInfo, LiteralTypeInfo, EnumTypeInfo)): + return type_ + if isinstance( type_, t._LiteralGenericAlias, # type: ignore[attr-defined] # noqa: SLF001 @@ -33,6 +49,33 @@ def inspect(self, type_: RuntimeType) -> TypeInfo: origin, type_params = self.__unpack_generic(type_) module, namespace, name = self.__get_module_naming(origin) + if isinstance(origin, t.TypeVar): + return TypeVarInfo( + name=origin.__name__, + module=module, + namespace=namespace, + variance=( + "covariant" + if origin.__covariant__ + else "contravariant" + if origin.__contravariant__ + else "invariant" + ), + lower=self.inspect(origin.__bound__) + if origin.__bound__ is not None + else replace(get_predef().union, type_params=tuple(self.inspect(co) for co in origin.__constraints__)) + if origin.__constraints__ + else None, + ) + + if isinstance(origin, type) and issubclass(origin, enum.Enum): + return EnumTypeInfo( + name=name, + module=module, + namespace=tuple(namespace), + values=tuple(EnumTypeValue(name=enum_value.name, value=enum_value.value) for enum_value in origin), + ) + return NamedTypeInfo( name=name, module=module, diff --git a/src/astlab/types/loader.py b/src/astlab/types/loader.py index 816e67f..60fc806 100644 --- a/src/astlab/types/loader.py +++ b/src/astlab/types/loader.py @@ -3,6 +3,7 @@ __all__ = [ "ModuleLoader", "TypeLoader", + "TypeLoaderError", ] import importlib @@ -16,12 +17,14 @@ from astlab.cache import lru_cache_method from astlab.reader import import_module_path from astlab.types.model import ( + EnumTypeInfo, LiteralTypeInfo, ModuleInfo, NamedTypeInfo, PackageInfo, RuntimeType, TypeInfo, + TypeVarInfo, ellipsis_type_info, none_type_info, ) @@ -54,6 +57,10 @@ def clear_cache(self) -> None: importlib.invalidate_caches() +class TypeLoaderError(Exception): + pass + + class TypeLoader: """Loads runtime type from provided info.""" @@ -63,32 +70,79 @@ def __init__(self, module: t.Optional[ModuleLoader] = None) -> None: @lru_cache_method() def load(self, info: TypeInfo) -> RuntimeType: if isinstance(info, ModuleInfo): - return self.__module.load(info) + rtt = self.__load_module(info) + + elif isinstance(info, TypeVarInfo): + rtt = self.__load_type_var(info) elif isinstance(info, NamedTypeInfo): if info == none_type_info(): - return None + rtt = None elif info == ellipsis_type_info(): - return Ellipsis + rtt = Ellipsis - value: object = self.__module.load(info.module) + else: + rtt = self.__load_named_type(info) - for name in info.namespace: - value = getattr(value, name) + elif isinstance(info, LiteralTypeInfo): + rtt = getitem(t.Literal, info.values) - # NOTE: need to check that we loaded a type. - type_: object = getattr(value, info.name) + elif isinstance(info, EnumTypeInfo): + rtt = self.__load_type_by_name(info) - if not info.type_params: - return type_ + else: + assert_never(info) - # TODO: fix recursive type - type_params = tuple(self.load(tp) for tp in info.type_params) - return getitem(type_, type_params) if len(type_params) > 1 else getitem(type_, type_params[0]) # type: ignore[call-overload, misc] + return rtt - elif isinstance(info, LiteralTypeInfo): - return getitem(t.Literal, info.values) + def clear_cache(self) -> None: + self.load.cache_clear() # type: ignore[attr-defined] + self.__module.clear_cache() - else: - assert_never(info) + def __load_module(self, info: ModuleInfo) -> RuntimeType: + try: + return self.__module.load(info) + + except ImportError as err: + msg = "module can't be loaded" + raise TypeLoaderError(msg, info) from err + + def __load_type_var(self, info: TypeVarInfo) -> RuntimeType: + # noinspection PyTypeHints + return t.TypeVar( + name=info.name, + bound=self.load(info.lower) if info.lower else None, + covariant=info.variance == "covariant", + contravariant=info.variance == "contravariant", + ) + + def __load_named_type(self, info: NamedTypeInfo) -> RuntimeType: + rtt = self.__load_type_by_name(info) + if not info.type_params: + return rtt + + # TODO: fix recursive type load + loaded_type_params = tuple(self.load(tp) for tp in info.type_params) + + try: + return ( + getitem(rtt, loaded_type_params) if len(loaded_type_params) > 1 else getitem(rtt, loaded_type_params[0]) # type: ignore[arg-type,misc] + ) + + except TypeError as err: + msg = "type params can't be applied to type" + raise TypeLoaderError(msg, info) from err + + def __load_type_by_name(self, info: t.Union[NamedTypeInfo, EnumTypeInfo]) -> RuntimeType: + container: object = self.load(info.module) + try: + for name in info.namespace: + container = getattr(container, name) + + # NOTE: need to check that we loaded a type. + return getattr(container, info.name) # type: ignore[misc] + + except AttributeError as err: + msg = "module has not attribute" + raise TypeLoaderError(msg, info) from err diff --git a/src/astlab/types/model.py b/src/astlab/types/model.py index 1ad90b3..957d709 100644 --- a/src/astlab/types/model.py +++ b/src/astlab/types/model.py @@ -1,6 +1,8 @@ from __future__ import annotations __all__ = [ + "EnumTypeInfo", + "EnumTypeValue", "LiteralTypeInfo", "LiteralTypeValue", "ModuleInfo", @@ -8,6 +10,8 @@ "PackageInfo", "RuntimeType", "TypeInfo", + "TypeVarInfo", + "TypeVarVariance", "builtins_module_info", "ellipsis_type_info", "none_type_info", @@ -144,13 +148,32 @@ def typing_module_info() -> ModuleInfo: return ModuleInfo("typing") +TypeVarVariance: TypeAlias = t.Literal["invariant", "covariant", "contravariant"] + + +@dataclass(frozen=True) +class TypeVarInfo: + name: str + module: ModuleInfo + namespace: t.Sequence[str] = field(default_factory=tuple) + variance: TypeVarVariance = "invariant" + lower: t.Optional[TypeInfo] = None + + @cached_property + def parts(self) -> t.Sequence[str]: + return *self.module.parts, *self.namespace, self.name + + @cached_property + def qualname(self) -> str: + return ".".join(self.parts) + + @dataclass(frozen=True) class NamedTypeInfo: name: str module: ModuleInfo namespace: t.Sequence[str] = field(default_factory=tuple) type_params: t.Sequence[TypeInfo] = field(default_factory=tuple) - type_vars: t.Sequence[str] = field(default_factory=tuple) @classmethod def build( @@ -158,7 +181,6 @@ def build( module: t.Union[str, t.Sequence[str], ModuleInfo], name: str, type_params: t.Optional[t.Sequence[TypeInfo]] = None, - type_vars: t.Optional[t.Sequence[str]] = None, ) -> NamedTypeInfo: *namespace, type_name = name.split(".") if not type_name: @@ -174,7 +196,6 @@ def build( else ModuleInfo.build(*module), namespace=tuple(namespace), type_params=tuple(type_params or ()), - type_vars=tuple(type_vars or ()), ) @cached_property @@ -185,19 +206,8 @@ def parts(self) -> t.Sequence[str]: def qualname(self) -> str: return ".".join(self.parts) - def with_type_params(self, *infos: TypeInfo) -> NamedTypeInfo: - if not infos: - return self - - if len(infos) > len(self.type_vars): - msg = "too many type parameters" - raise ValueError(msg, infos, self) - - return replace( - self, - type_params=(*self.type_params, *infos), - type_vars=self.type_vars[len(infos) :], - ) + def with_type_params(self, *type_params: TypeInfo) -> NamedTypeInfo: + return replace(self, type_params=type_params) @cache # type: ignore[misc] @@ -215,7 +225,6 @@ def ellipsis_type_info() -> NamedTypeInfo: @dataclass(frozen=True) class LiteralTypeInfo: - # TODO: enum values values: t.Sequence[LiteralTypeValue] @cached_property @@ -239,4 +248,49 @@ def qualname(self) -> str: return ".".join(self.parts) -TypeInfo: TypeAlias = t.Union[ModuleInfo, NamedTypeInfo, LiteralTypeInfo] +@dataclass(frozen=True) +class EnumTypeValue: + name: str + value: LiteralTypeValue + + +@dataclass(frozen=True) +class EnumTypeInfo: + name: str + module: ModuleInfo + namespace: t.Sequence[str] = field(default_factory=tuple) + values: t.Sequence[EnumTypeValue] = field(default_factory=tuple) + + @classmethod + def build( + cls, + module: t.Union[str, t.Sequence[str], ModuleInfo], + name: str, + values: t.Sequence[EnumTypeValue], + ) -> EnumTypeInfo: + *namespace, type_name = name.split(".") + if not type_name: + msg = "type name can't be empty" + raise ValueError(msg, name) + + return EnumTypeInfo( + name=type_name, + module=module + if isinstance(module, ModuleInfo) + else ModuleInfo.from_str(module) + if isinstance(module, str) + else ModuleInfo.build(*module), + namespace=tuple(namespace), + values=values, + ) + + @cached_property + def parts(self) -> t.Sequence[str]: + return *self.module.parts, *self.namespace, self.name + + @cached_property + def qualname(self) -> str: + return ".".join(self.parts) + + +TypeInfo: TypeAlias = t.Union[ModuleInfo, TypeVarInfo, NamedTypeInfo, LiteralTypeInfo, EnumTypeInfo] diff --git a/src/astlab/types/predef.py b/src/astlab/types/predef.py index 4ee2c95..99acfb6 100644 --- a/src/astlab/types/predef.py +++ b/src/astlab/types/predef.py @@ -30,6 +30,10 @@ class Predef: def typing_module(self) -> ModuleInfo: return typing_module_info() + @cached_property + def enum_module(self) -> ModuleInfo: + return ModuleInfo("enum") + @cached_property def dataclasses_module(self) -> ModuleInfo: return ModuleInfo("dataclasses") @@ -59,7 +63,7 @@ def object(self) -> NamedTypeInfo: return NamedTypeInfo("object", self.builtins_module) @cached_property - def none_type(self) -> NamedTypeInfo: + def none(self) -> NamedTypeInfo: return none_type_info() @cached_property @@ -142,6 +146,14 @@ def final_decorator(self) -> NamedTypeInfo: def final(self) -> NamedTypeInfo: return NamedTypeInfo("Final", self.typing_module) + @cached_property + def type_alias(self) -> NamedTypeInfo: + return NamedTypeInfo("TypeAlias", self.typing_module) + + @cached_property + def type_var(self) -> NamedTypeInfo: + return NamedTypeInfo("TypeVar", self.typing_module) + @cached_property def class_var(self) -> NamedTypeInfo: return NamedTypeInfo("ClassVar", self.typing_module) @@ -222,6 +234,14 @@ def async_iterable(self) -> NamedTypeInfo: def literal(self) -> NamedTypeInfo: return NamedTypeInfo("Literal", self.typing_module) + @cached_property + def enum(self) -> NamedTypeInfo: + return NamedTypeInfo("Enum", self.enum_module) + + @cached_property + def enum_auto(self) -> NamedTypeInfo: + return NamedTypeInfo("auto", self.enum_module) + @cached_property def no_return(self) -> NamedTypeInfo: return NamedTypeInfo("NoReturn", self.typing_module) diff --git a/tests/conftest.py b/tests/conftest.py index 93f19e6..4c7a266 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,8 +30,8 @@ def type_loader() -> TypeLoader: @pytest.fixture -def type_annotator(module_loader: ModuleLoader) -> TypeAnnotator: - return TypeAnnotator(module_loader) +def type_annotator(type_loader: TypeLoader) -> TypeAnnotator: + return TypeAnnotator(type_loader) @pytest.fixture diff --git a/tests/stub/types.py b/tests/stub/types.py index 4f410ed..61237c9 100644 --- a/tests/stub/types.py +++ b/tests/stub/types.py @@ -1,3 +1,4 @@ +import enum import typing as t from dataclasses import dataclass @@ -22,6 +23,11 @@ class Z: pass +class StubEnum(enum.Enum): + FOO = enum.auto() + BAR = enum.auto() + + class StubCM(t.ContextManager["StubCM"]): @override def __exit__(self, exc_type: object, exc_value: object, traceback: object, /) -> None: @@ -38,5 +44,7 @@ class StubNode(t.Generic[T]): StubUnionAlias: TypeAlias = t.Union[StubFoo, StubBar[StubInt], StubX] +StubRecursive: TypeAlias = t.Union[T, t.Sequence["StubRecursive[T]"]] + # TODO: enable after python 3.9, 3.10, 3.11 version support stop drop. -# type StubNumber = int | float # noqa: ERA001 +# type StubRecursive[T] = T | t.Sequence[StubRecursive[T]] # noqa: ERA001 diff --git a/tests/unit/test_builder.py b/tests/unit/test_builder.py index b60715b..a89880c 100644 --- a/tests/unit/test_builder.py +++ b/tests/unit/test_builder.py @@ -3,38 +3,48 @@ import typing as t import pytest -from _pytest.mark import ParameterSet +from _pytest.mark import MarkDecorator, ParameterSet from astlab.builder import Comprehension, ModuleASTBuilder, build_module, build_package from astlab.reader import parse_module -from astlab.types import none_type_info +from astlab.types import none_type_info, predef _PARAMS: t.Final[list[ParameterSet]] = [] -def _to_module_param(func: t.Callable[[], ModuleASTBuilder]) -> ParameterSet: - expected_code = inspect.getdoc(func) - assert expected_code is not None +@pytest.mark.parametrize(("builder", "expected_code"), _PARAMS) +def test_module_build(builder: t.Callable[[], ModuleASTBuilder], normalized_expected_code: str) -> None: + assert builder().render() == normalized_expected_code - param = pytest.param(func(), parse_module(expected_code), id=func.__name__) - _PARAMS.append(param) - return param +@pytest.fixture +def normalized_expected_code(expected_code: str) -> str: + return ast.unparse(parse_module(expected_code)) -@pytest.mark.parametrize(("builder", "expected"), _PARAMS) -def test_module_build(builder: ModuleASTBuilder, expected: ast.Module) -> None: - assert builder.render() == ast.unparse(expected) +def _to_param( + marks: t.Optional[t.Sequence[MarkDecorator]] = None, +) -> t.Callable[[t.Callable[[], ModuleASTBuilder]], t.Callable[[], ModuleASTBuilder]]: + def inner(func: t.Callable[[], ModuleASTBuilder]) -> t.Callable[[], ModuleASTBuilder]: + expected_code = inspect.getdoc(func) + assert expected_code is not None + param = pytest.param(func, expected_code, id=func.__name__, marks=marks or []) + _PARAMS.append(param) -@_to_module_param + return func + + return inner + + +@_to_param() def build_empty_module() -> ModuleASTBuilder: """""" with build_module("simple") as mod: return mod -@_to_module_param +@_to_param() def build_simple_module() -> ModuleASTBuilder: # noinspection PySingleQuotedDocstring ''' @@ -89,7 +99,7 @@ def do_buzz(self) -> builtins.object: return mod -@_to_module_param +@_to_param() def build_bar_impl_module() -> ModuleASTBuilder: """ import builtins @@ -124,7 +134,7 @@ def do_stuff(self, spam: builtins.str) -> typing.Iterator[builtins.str]: return bar -@_to_module_param +@_to_param() def build_optionals() -> ModuleASTBuilder: """ import builtins @@ -135,6 +145,7 @@ class MyOptions: my_optional_str: typing.Optional[builtins.str] my_optional_list_of_int: typing.Optional[builtins.list[builtins.int]] my_list_of_optional_int: builtins.list[typing.Optional[builtins.int]] + my_optional_str_default: typing.Optional[builtins.str] = None """ with build_module("opts") as mod, mod.class_def("MyOptions") as opt: @@ -142,11 +153,12 @@ class MyOptions: opt.field_def("my_optional_str", opt.type_ref(str).optional()) opt.field_def("my_optional_list_of_int", opt.type_ref(int).list().optional()) opt.field_def("my_list_of_optional_int", opt.type_ref(int).optional().list()) + opt.field_def("my_optional_str_default", opt.type_ref(str).optional(), opt.none()) return mod -@_to_module_param +@_to_param() def build_unions() -> ModuleASTBuilder: """ import builtins @@ -168,7 +180,7 @@ class MyOptions: return mod -@_to_module_param +@_to_param() def build_runtime_types() -> ModuleASTBuilder: """ import builtins @@ -185,7 +197,7 @@ def build_runtime_types() -> ModuleASTBuilder: return mod -@_to_module_param +@_to_param() def build_is_not_none_expr() -> ModuleASTBuilder: """ maybe = body if test is not None else None @@ -197,7 +209,7 @@ def build_is_not_none_expr() -> ModuleASTBuilder: return mod -@_to_module_param +@_to_param() def build_list_const() -> ModuleASTBuilder: """ result = [1, 2, foo, bar] @@ -209,7 +221,7 @@ def build_list_const() -> ModuleASTBuilder: return mod -@_to_module_param +@_to_param() def build_list_compr_expr() -> ModuleASTBuilder: """ result = [item for target in iterable] @@ -224,7 +236,7 @@ def build_list_compr_expr() -> ModuleASTBuilder: return mod -@_to_module_param +@_to_param() def build_try_except_else() -> ModuleASTBuilder: """ import builtins @@ -256,7 +268,7 @@ def build_try_except_else() -> ModuleASTBuilder: return mod -@_to_module_param +@_to_param() def build_index_slice() -> ModuleASTBuilder: """ list[str] @@ -279,3 +291,259 @@ def build_index_slice() -> ModuleASTBuilder: mod.stmt(mod.attr("arr").index(mod.tuple_expr(mod.slice(), mod.slice(), mod.const(0)))) return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info >= (3, 12)", + reason="new type var syntax can be used", + ), + ), +) +def build_generic_class_before_312() -> ModuleASTBuilder: + """ + import builtins + import typing + + T = typing.TypeVar('T', bound=builtins.int) + + class Node(typing.Generic[T]): + value: T + parent: typing.Optional['Node[T]'] = None + """ + + with build_module("generic") as mod: + with mod.class_def("Node") as node, node.type_var("T").lower(int) as type_var: + node.field_def("value", type_var) + node.field_def("parent", node.ref().type_params(type_var).optional(), mod.none()) + + return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info < (3, 12) or sys.version_info >= (3, 14)", + reason="type var syntax is available since python version 3.12", + ), + ), +) +def build_generic_class_312_313() -> ModuleASTBuilder: + """ + import builtins + import typing + + class Node[T : builtins.int]: + value: T + parent: typing.Optional['Node[T]'] = None + """ + + with build_module("generic") as mod: + with mod.class_def("Node") as node, node.type_var("T").lower(int) as type_var: + node.field_def("value", type_var) + node.field_def("parent", node.ref().type_params(type_var).optional(), mod.none()) + + return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info < (3, 14)", + reason="type var syntax is available since python version 3.12 and forward ref is supported since 3.14", + ), + ), +) +def build_generic_class_after_314() -> ModuleASTBuilder: + """ + import builtins + import typing + + class Node[T : builtins.int]: + value: T + parent: typing.Optional[Node[T]] = None + """ + + with build_module("generic") as mod: + with mod.class_def("Node") as node, node.type_var("T").lower(int) as type_var: + node.field_def("value", type_var) + node.field_def("parent", node.ref().type_params(type_var).optional(), mod.none()) + + return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info >= (3, 12)", + reason="new type alias syntax can be used", + ), + ), +) +def build_type_alias_before_syntax_312() -> ModuleASTBuilder: + """ + import builtins + import typing + + MyInt: typing.TypeAlias = builtins.int + Json: typing.TypeAlias = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list['Json'], + builtins.dict[builtins.str, 'Json'], + ] + T1 = typing.TypeVar('T1') + T2 = typing.TypeVar('T2') + Nested: typing.TypeAlias = typing.Union[T1, T2, typing.Sequence['Nested[T1, T2]']] + """ + + with build_module("alias") as mod: + mod.type_alias("MyInt").assign(predef().int) + + with mod.type_alias("Json") as json_alias: + json_alias.assign( + json_alias.union_type( + predef().none, + predef().bool, + predef().int, + predef().float, + predef().str, + mod.list_type(json_alias), + mod.dict_type(predef().str, json_alias), + ) + ) + + with ( + mod.type_alias("Nested") as nested_alias, + nested_alias.type_var("T1") as type_var_1, + nested_alias.type_var("T2") as type_var_2, + ): + nested_alias.assign( + nested_alias.union_type( + type_var_1, + type_var_2, + nested_alias.sequence_type(nested_alias.type_params(type_var_1, type_var_2)), + ) + ) + + return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info < (3, 12) or sys.version_info >= (3, 14)", + reason="type alias syntax is available since python version 3.12", + ), + ), +) +def build_type_alias_syntax_312() -> ModuleASTBuilder: + """ + import builtins + import typing + + type MyInt = builtins.int + type Json = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list['Json'], + builtins.dict[builtins.str, 'Json'] + ] + type Nested[T1, T2] = typing.Union[T1, T2, typing.Sequence['Nested[T1, T2]']] + """ + + with build_module("alias") as mod: + mod.type_alias("MyInt").assign(predef().int) + + with mod.type_alias("Json") as json_alias: + json_alias.assign( + json_alias.union_type( + predef().none, + predef().bool, + predef().int, + predef().float, + predef().str, + mod.list_type(json_alias), + mod.dict_type(predef().str, json_alias), + ) + ) + + with ( + mod.type_alias("Nested") as nested_alias, + nested_alias.type_var("T1") as type_var_1, + nested_alias.type_var("T2") as type_var_2, + ): + nested_alias.assign( + nested_alias.union_type( + type_var_1, + type_var_2, + nested_alias.sequence_type(nested_alias.type_params(type_var_1, type_var_2)), + ) + ) + + return mod + + +@_to_param( + marks=( + pytest.mark.skipif( + condition="sys.version_info < (3, 14)", + reason="type alias syntax is available since python version 3.12 and forward ref is supported since 3.14", + ), + ), +) +def build_type_alias_syntax_314() -> ModuleASTBuilder: + """ + import builtins + import typing + + type MyInt = builtins.int + type Json = typing.Union[ + None, + builtins.bool, + builtins.int, + builtins.float, + builtins.str, + builtins.list[Json], + builtins.dict[builtins.str, Json] + ] + type Nested[T1, T2] = typing.Union[T1, T2, typing.Sequence[Nested[T1, T2]]] + """ + + with build_module("alias") as mod: + mod.type_alias("MyInt").assign(predef().int) + + with mod.type_alias("Json") as json_alias: + json_alias.assign( + json_alias.union_type( + predef().none, + predef().bool, + predef().int, + predef().float, + predef().str, + mod.list_type(json_alias), + mod.dict_type(predef().str, json_alias), + ) + ) + + with ( + mod.type_alias("Nested") as nested_alias, + nested_alias.type_var("T1") as type_var_1, + nested_alias.type_var("T2") as type_var_2, + ): + nested_alias.assign( + nested_alias.union_type( + type_var_1, + type_var_2, + nested_alias.sequence_type(nested_alias.type_params(type_var_1, type_var_2)), + ) + ) + + return mod diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 860c4f9..c1f66b8 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -6,9 +6,10 @@ import pytest from astlab import abc as astlab_abc +from astlab.types import EnumTypeInfo, EnumTypeValue from astlab.types.annotator import TypeAnnotator from astlab.types.inspector import TypeInspector -from astlab.types.loader import TypeLoader +from astlab.types.loader import TypeLoader, TypeLoaderError from astlab.types.model import ( LiteralTypeInfo, ModuleInfo, @@ -19,7 +20,7 @@ builtins_module_info, none_type_info, ) -from tests.stub.types import StubBar, StubCM, StubFoo, StubInt, StubNode, StubUnionAlias, StubX +from tests.stub.types import StubBar, StubCM, StubEnum, StubFoo, StubInt, StubNode, StubUnionAlias, StubX class TestPackageInfo: @@ -333,6 +334,24 @@ def test_qualname_ok(self, value: ModuleInfo, expected: t.Sequence[str]) -> None namespace=("StubX", "Y"), ), ), + pytest.param( + StubEnum, + "tests.stub.types.StubEnum", + EnumTypeInfo( + name="StubEnum", + module=ModuleInfo("types", PackageInfo("stub", PackageInfo("tests"))), + values=( + EnumTypeValue( + name="FOO", + value=1, + ), + EnumTypeValue( + name="BAR", + value=2, + ), + ), + ), + ), pytest.param( StubCM, "tests.stub.types.StubCM", @@ -473,11 +492,11 @@ def test_load_ok( [ pytest.param( NamedTypeInfo("NonExistingType", builtins_module_info()), - AttributeError, + TypeLoaderError, ), pytest.param( NamedTypeInfo("SomeType", ModuleInfo("non_existing_module")), - ModuleNotFoundError, + TypeLoaderError, ), ], )