Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,11 @@ def _parse_converter(
converter_type = converter_expr.node.type
elif isinstance(converter_expr.node, TypeInfo):
converter_type = type_object_type(converter_expr.node)
elif isinstance(converter_expr.node, Var) and converter_expr.node.type:
# The converter is a variable annotated with a callable type.
var_type = get_proper_type(converter_expr.node.type)
if isinstance(var_type, FunctionLike):
converter_type = var_type
elif (
isinstance(converter_expr, IndexExpr)
and isinstance(converter_expr.analyzed, TypeApplication)
Expand All @@ -751,6 +756,10 @@ def _parse_converter(
)
else:
converter_type = None
elif isinstance(converter_expr, CallExpr):
# The converter is the result of a call, e.g. `converter=make_converter(arg)`.
# Use the return type of the callee as the converter type.
converter_type = _callable_return_type(converter_expr)

if isinstance(converter_expr, LambdaExpr):
# TODO: should we send a fail if converter_expr.min_args > 1?
Expand Down Expand Up @@ -794,6 +803,44 @@ def _parse_converter(
return converter_info


def _callable_return_type(call: CallExpr) -> Type | None:
"""Return the return type of `call` if it is statically known to be callable.

This is used to support converters created by higher-order functions, e.g.
`converter=make_converter(arg)`. We don't perform full type inference at the
call site; we just look at the statically declared return type of the callee.
Generic returns are returned as-is and may contain unresolved type variables.
"""
callee = call.callee
callee_type: Type | None = None
if isinstance(callee, RefExpr) and callee.node:
if isinstance(callee.node, (FuncDef, OverloadedFuncDef)):
callee_type = callee.node.type
elif isinstance(callee.node, Var):
callee_type = callee.node.type
elif isinstance(callee, CallExpr):
# Chained calls like `factory()(arg)`.
callee_type = _callable_return_type(callee)
if callee_type is None:
return None
callee_type = get_proper_type(callee_type)
if isinstance(callee_type, CallableType):
ret = get_proper_type(callee_type.ret_type)
if isinstance(ret, FunctionLike):
return ret
elif isinstance(callee_type, Overloaded):
# Without type inference at the call site we can't pick the correct
# overload. As a heuristic, take the first overload whose return type is
# itself a callable. This matches helpers like `attrs.converters.pipe`
# and `attrs.converters.default_if_none`, whose first overload is the
# most specific callable form.
for item in callee_type.items:
ret = get_proper_type(item.ret_type)
if isinstance(ret, FunctionLike):
return ret
return None


def is_valid_overloaded_converter(defn: OverloadedFuncDef) -> bool:
return all(
(not isinstance(item, Decorator) or isinstance(item.func.type, FunctionLike))
Expand Down
70 changes: 70 additions & 0 deletions test-data/unit/check-plugin-attrs.test
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,76 @@ class C:
reveal_type(C) # N: Revealed type is "def (x: Any, y: Any, z: Any) -> __main__.C"
[builtins fixtures/list.pyi]

[case testAttrsUsingHigherOrderConverter]
# Regression test for https://github.com/python/mypy/issues/15736
from typing import Any, Callable
from attrs import define, field

def make_converter(_length: int) -> Callable[[str], str]:
def converter(val: str) -> str:
return val
return converter

def make_untyped_converter(_length: int) -> Callable[[Any], Any]:
def f(val: Any) -> Any:
return val
return f

@define
class C:
a: str = field(converter=make_converter(40))
b: str = field(converter=make_untyped_converter(40))

reveal_type(C) # N: Revealed type is "def (a: builtins.str, b: Any) -> __main__.C"
reveal_type(C("hi", 5).a) # N: Revealed type is "builtins.str"
[builtins fixtures/list.pyi]

[case testAttrsUsingCallableVariableConverter]
from typing import Callable
from attrs import define, field

def to_str(x: int) -> str:
return ""
my_converter: Callable[[int], str] = to_str

@define
class C:
x: str = field(converter=my_converter)

reveal_type(C) # N: Revealed type is "def (x: builtins.int) -> __main__.C"
reveal_type(C(15).x) # N: Revealed type is "builtins.str"
[builtins fixtures/list.pyi]

[case testAttrsUsingHigherOrderConverterChainedCall]
from typing import Callable
from attrs import define, field

def outer() -> Callable[[int], Callable[[str], str]]:
def middle(_n: int) -> Callable[[str], str]:
def inner(v: str) -> str:
return v
return inner
return middle

@define
class C:
x: str = field(converter=outer()(40))

reveal_type(C) # N: Revealed type is "def (x: builtins.str) -> __main__.C"
[builtins fixtures/list.pyi]

[case testAttrsUsingDefaultIfNoneConverter]
from typing import Optional
from attrs import define, field
from attrs.converters import default_if_none

@define
class C:
x: int = field(default=None, converter=default_if_none(0))

reveal_type(C) # N: Revealed type is "def (x: Any =) -> __main__.C"
[builtins fixtures/plugin_attrs.pyi]

[case testAttrsUsingConverterAndSubclass]
import attr

Expand Down
Loading