1"""Simple type inference for decorated functions during semantic analysis."""
2
3from typing import Optional
4
5from mypy.nodes import Expression, Decorator, CallExpr, FuncDef, RefExpr, Var, ARG_POS
6from mypy.types import (
7    Type, CallableType, AnyType, TypeOfAny, TypeVarType, ProperType, get_proper_type
8)
9from mypy.typeops import function_type
10from mypy.typevars import has_no_typevars
11from mypy.semanal_shared import SemanticAnalyzerInterface
12
13
def infer_decorator_signature_if_simple(dec: Decorator,
                                        analyzer: SemanticAnalyzerInterface) -> None:
    """Set the type of a decorated function's variable when that is cheap to work out.

    Module top levels can't be deferred, so a reference to a decorated
    function may need its type during type checking before that type would
    otherwise be available. This is a tiny special-purpose inference pass
    for decorators only. It assigns ``dec.var.type`` in these cases:

    * Property: the function's own type if it is a callable type, or an
      Any-based single-argument callable if the function has no type.
    * Every decorator refers to a function with an identity signature
      (``T -> T``): the plain type of the decorated function.
    * The outermost decorator is known to return Any: Any.
    * The outermost decorator returns a fixed callable type: that callable,
      renamed after the decorated function's own signature.

    The non-property rules are applied in the order listed, so a later
    match overrides an earlier one. If nothing matches, ``dec.var.type`` is
    left untouched.
    """
    if dec.var.is_property:
        # Odd as it is, a property is still expected to carry a callable type.
        declared = dec.func.type
        if declared is None:
            self_arg = AnyType(TypeOfAny.special_form)
            result = AnyType(TypeOfAny.special_form)
            fallback = analyzer.named_type('__builtins__.function')
            dec.var.type = CallableType(
                [self_arg], [ARG_POS], [None], result, fallback, name=dec.var.name)
        elif isinstance(declared, CallableType):
            dec.var.type = declared
        return

    # True when each decorator is a direct reference to a function whose
    # signature is T -> T; also (vacuously) true when there are no decorators.
    only_identity_decorators = all(
        isinstance(deco, RefExpr)
        and isinstance(deco.node, FuncDef)
        and deco.node.type
        and is_identity_signature(deco.node.type)
        for deco in dec.decorators
    )
    if only_identity_decorators:
        # Nothing changes the signature, so the function's own type is the answer.
        dec.var.type = function_type(dec.func, analyzer.named_type('__builtins__.function'))

    if not dec.decorators:
        return

    outermost_result = calculate_return_type(dec.decorators[0])
    if outermost_result and isinstance(outermost_result, AnyType):
        # Whatever the outermost decorator is given, it hands back Any.
        dec.var.type = AnyType(TypeOfAny.from_another_any, source_any=outermost_result)

    sig = find_fixed_callable_return(dec.decorators[0])
    if not sig:
        return
    # The outermost decorator always produces the same kind of callable, so
    # that callable is the type of the decorated function; borrow the name
    # from the decorated function's own signature.
    own_sig = function_type(dec.func, analyzer.named_type('__builtins__.function'))
    sig.name = own_sig.items()[0].name
    dec.var.type = sig
64
65
def is_identity_signature(sig: Type) -> bool:
    """Tell whether a type is a callable of the shape T -> T, with T a type variable.

    The callable must take exactly one positional argument, and both that
    argument and the return type must be the same type variable (compared
    by id).
    """
    proper = get_proper_type(sig)
    if not isinstance(proper, CallableType) or proper.arg_kinds != [ARG_POS]:
        return False
    param = proper.arg_types[0]
    result = proper.ret_type
    if not (isinstance(param, TypeVarType) and isinstance(result, TypeVarType)):
        return False
    return param.id == result.id
73
74
def calculate_return_type(expr: Expression) -> Optional[ProperType]:
    """Work out the return type of the callable that an expression refers to.

    Only information available during semantic analysis is used (type
    inference hasn't run yet), so None is returned whenever that is not
    enough to decide.
    """
    # Peel off call layers (f(...)(...) -> f) until a plain reference is
    # reached; anything that is neither a reference nor a call is unknown.
    while not isinstance(expr, RefExpr):
        if not isinstance(expr, CallExpr):
            return None
        expr = expr.callee

    target = expr.node
    if isinstance(target, FuncDef):
        declared = target.type
        if declared is None:
            # An unannotated function is treated as returning Any.
            return AnyType(TypeOfAny.unannotated)
        if isinstance(declared, CallableType):
            # Use the declared return type (which may be an explicit Any).
            return get_proper_type(declared.ret_type)
        return None
    if isinstance(target, Var):
        return get_proper_type(target.type)
    return None
97
98
def find_fixed_callable_return(expr: Expression) -> Optional[CallableType]:
    """Find the callable type returned by the callable that an expression refers to.

    For a reference to a function, the declared return type is used, but
    only if it contains no type variables. For a call expression, the
    callee is resolved the same way first and the return type of that
    result is used. None is returned unless the outcome is a callable type.

    This is a rough approximation: it is meant to run before type checking,
    when full type information is not available yet.
    """
    if isinstance(expr, RefExpr):
        target = expr.node
        if not isinstance(target, FuncDef):
            return None
        declared = target.type
        if not (declared
                and isinstance(declared, CallableType)
                and has_no_typevars(declared.ret_type)):
            return None
        candidate = get_proper_type(declared.ret_type)
    elif isinstance(expr, CallExpr):
        callee_result = find_fixed_callable_return(expr.callee)
        if not callee_result:
            return None
        # Calling the callee's result goes one level deeper: take its return type.
        candidate = get_proper_type(callee_result.ret_type)
    else:
        return None

    if isinstance(candidate, CallableType):
        return candidate
    return None
121