1from functools import partial
2from typing import Callable, Optional, List
3
4from mypy import message_registry
5from mypy.nodes import Expression, StrExpr, IntExpr, DictExpr, UnaryExpr
6from mypy.plugin import (
7    Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext,
8    CheckerPluginInterface,
9)
10from mypy.plugins.common import try_getting_str_literals
11from mypy.types import (
12    Type, Instance, AnyType, TypeOfAny, CallableType, NoneType, TypedDictType,
13    TypeVarDef, TypeVarType, TPDICT_FB_NAMES, get_proper_type, LiteralType
14)
15from mypy.subtypes import is_subtype
16from mypy.typeops import make_simplified_union
17from mypy.checkexpr import is_literal_type_like
18
19
# Fully qualified names of the TypedDict fallback methods that get special
# treatment. They depend only on TPDICT_FB_NAMES, so they are built once here
# instead of being rebuilt as fresh sets on every hook lookup.
_TPDICT_SETDEFAULT_NAMES = frozenset(n + '.setdefault' for n in TPDICT_FB_NAMES)
_TPDICT_POP_NAMES = frozenset(n + '.pop' for n in TPDICT_FB_NAMES)
_TPDICT_UPDATE_NAMES = frozenset(n + '.update' for n in TPDICT_FB_NAMES)
_TPDICT_DELITEM_NAMES = frozenset(n + '.__delitem__' for n in TPDICT_FB_NAMES)


class DefaultPlugin(Plugin):
    """Type checker plugin that is enabled by default.

    Each ``get_*_hook`` method maps a fully qualified name to the callback
    that refines the relevant type or signature. It returns None when this
    plugin has nothing to contribute for that name.
    """

    def get_function_hook(self, fullname: str
                          ) -> Optional[Callable[[FunctionContext], Type]]:
        """Return a callback refining the return type of calls to ``fullname``."""
        # Plugin submodules are imported inside the hooks rather than at module
        # top level (presumably to avoid import cycles -- confirm before moving).
        from mypy.plugins import ctypes

        if fullname == 'contextlib.contextmanager':
            return contextmanager_callback
        elif fullname == 'builtins.open' and self.python_version[0] == 3:
            # The 'open' refinement only applies when targeting Python 3.
            return open_callback
        elif fullname == 'ctypes.Array':
            return ctypes.array_constructor_callback
        return None

    def get_method_signature_hook(self, fullname: str
                                  ) -> Optional[Callable[[MethodSigContext], CallableType]]:
        """Return a callback that adjusts the signature of method ``fullname``."""
        from mypy.plugins import ctypes

        if fullname == 'typing.Mapping.get':
            return typed_dict_get_signature_callback
        elif fullname in _TPDICT_SETDEFAULT_NAMES:
            return typed_dict_setdefault_signature_callback
        elif fullname in _TPDICT_POP_NAMES:
            return typed_dict_pop_signature_callback
        elif fullname in _TPDICT_UPDATE_NAMES:
            return typed_dict_update_signature_callback
        elif fullname in _TPDICT_DELITEM_NAMES:
            return typed_dict_delitem_signature_callback
        elif fullname == 'ctypes.Array.__setitem__':
            return ctypes.array_setitem_callback
        return None

    def get_method_hook(self, fullname: str
                        ) -> Optional[Callable[[MethodContext], Type]]:
        """Return a callback refining the return type of method ``fullname``."""
        from mypy.plugins import ctypes

        if fullname == 'typing.Mapping.get':
            return typed_dict_get_callback
        elif fullname == 'builtins.int.__pow__':
            return int_pow_callback
        elif fullname == 'builtins.int.__neg__':
            return int_neg_callback
        elif fullname in _TPDICT_SETDEFAULT_NAMES:
            return typed_dict_setdefault_callback
        elif fullname in _TPDICT_POP_NAMES:
            return typed_dict_pop_callback
        elif fullname in _TPDICT_DELITEM_NAMES:
            return typed_dict_delitem_callback
        elif fullname == 'ctypes.Array.__getitem__':
            return ctypes.array_getitem_callback
        elif fullname == 'ctypes.Array.__iter__':
            return ctypes.array_iter_callback
        elif fullname == 'pathlib.Path.open':
            return path_open_callback
        return None

    def get_attribute_hook(self, fullname: str
                           ) -> Optional[Callable[[AttributeContext], Type]]:
        """Return a callback refining the type of attribute ``fullname``."""
        from mypy.plugins import ctypes
        from mypy.plugins import enums

        if fullname == 'ctypes.Array.value':
            return ctypes.array_value_callback
        elif fullname == 'ctypes.Array.raw':
            return ctypes.array_raw_callback
        elif fullname in enums.ENUM_NAME_ACCESS:
            return enums.enum_name_callback
        elif fullname in enums.ENUM_VALUE_ACCESS:
            return enums.enum_value_callback
        return None

    def get_class_decorator_hook(self, fullname: str
                                 ) -> Optional[Callable[[ClassDefContext], None]]:
        """Return a callback processing classes decorated with ``fullname``."""
        from mypy.plugins import attrs
        from mypy.plugins import dataclasses
        from mypy.plugins import functools

        if fullname in attrs.attr_class_makers:
            return attrs.attr_class_maker_callback
        elif fullname in attrs.attr_dataclass_makers:
            return partial(
                attrs.attr_class_maker_callback,
                auto_attribs_default=True,
            )
        elif fullname in attrs.attr_frozen_makers:
            return partial(
                attrs.attr_class_maker_callback,
                auto_attribs_default=None,
                frozen_default=True,
            )
        elif fullname in attrs.attr_define_makers:
            return partial(
                attrs.attr_class_maker_callback,
                auto_attribs_default=None,
            )
        elif fullname in dataclasses.dataclass_makers:
            return dataclasses.dataclass_class_maker_callback
        elif fullname in functools.functools_total_ordering_makers:
            return functools.functools_total_ordering_maker_callback

        return None
122
123
def open_callback(ctx: FunctionContext) -> Type:
    """Infer a better return type for 'open'.

    Thin adapter over _analyze_open_signature: for the builtin ``open`` the
    'mode' parameter is the second formal parameter (index 1).
    """
    return _analyze_open_signature(ctx.arg_types, ctx.args, 1,
                                   ctx.default_return_type, ctx.api)
133
134
def path_open_callback(ctx: MethodContext) -> Type:
    """Infer a better return type for 'pathlib.Path.open'.

    Thin adapter over _analyze_open_signature: for ``Path.open`` the 'mode'
    parameter is the first formal parameter seen by the hook (index 0).
    """
    return _analyze_open_signature(ctx.arg_types, ctx.args, 0,
                                   ctx.default_return_type, ctx.api)
144
145
def _analyze_open_signature(arg_types: List[List[Type]],
                            args: List[List[Expression]],
                            mode_arg_index: int,
                            default_return_type: Type,
                            api: CheckerPluginInterface,
                            ) -> Type:
    """A helper for analyzing any function that has approximately
    the same signature as the builtin 'open(...)' function.

    Currently, the only thing the caller can customize is the index
    of the 'mode' argument. If the mode argument is omitted, it is treated
    as 'r'. If the mode is omitted or is a string literal, the return type
    is refined to 'BinaryIO' when the mode contains 'b' and to 'TextIO'
    otherwise.

    In every other case the default return type is returned unchanged:
    - the mode is a non-literal expression;
    - several actual arguments (e.g. ``*args`` together with ``**kwargs``)
      may supply it;
    - the default return type is not an instance type.

    Args:
        arg_types: Actual argument types, indexed by formal parameter position.
        args: Actual argument expressions, indexed the same way.
        mode_arg_index: Position of the formal 'mode' parameter.
        default_return_type: Return type inferred from the stub signature.
        api: Checker API used to build the refined return type.
    """
    mode = None  # type: Optional[str]
    if len(arg_types) <= mode_arg_index or not arg_types[mode_arg_index]:
        # No actual argument maps to 'mode', so it takes its default.
        mode = 'r'
    elif len(arg_types[mode_arg_index]) == 1:
        mode_expr = args[mode_arg_index][0]
        if isinstance(mode_expr, StrExpr):
            mode = mode_expr.value
    # If more than one actual maps to 'mode', its value cannot be determined
    # statically; 'mode' stays None and we keep the stub's return type.
    if mode is None:
        return default_return_type
    if not isinstance(get_proper_type(default_return_type), Instance):
        # Only refine ordinary instance return types; anything else (e.g. Any)
        # is kept as-is instead of crashing the checker.
        return default_return_type
    if 'b' in mode:
        return api.named_generic_type('typing.BinaryIO', [])
    return api.named_generic_type('typing.TextIO', [])
174
175
def contextmanager_callback(ctx: FunctionContext) -> Type:
    """Infer a better return type for 'contextlib.contextmanager'.

    The stub's return type is a callable that does not carry the decorated
    function's parameters. This copies the argument types, kinds and names,
    the type variables and the ellipsis flag from the decorated callable onto
    the stub's callable. Otherwise the default return type is returned.
    """
    # Be defensive: require exactly one actual for the first formal parameter.
    if not ctx.arg_types or len(ctx.arg_types[0]) != 1:
        return ctx.default_return_type

    decorated = get_proper_type(ctx.arg_types[0][0])
    result = get_proper_type(ctx.default_return_type)
    if not isinstance(decorated, CallableType) or not isinstance(result, CallableType):
        return ctx.default_return_type

    # Add the argument information that the stub signature drops.
    return result.copy_modified(
        arg_types=decorated.arg_types,
        arg_kinds=decorated.arg_kinds,
        arg_names=decorated.arg_names,
        variables=decorated.variables,
        is_ellipsis_args=decorated.is_ellipsis_args)
193
194
def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.get.

    When called as ``td.get('literal_key', default)`` on a TypedDict with a
    known key, the second parameter becomes ``Union[value_type, T]``. Here
    ``T`` is the signature's single type variable. This gives the default
    argument useful type context. An empty ``{}`` default for a
    TypedDict-valued key uses that TypedDict with no required keys. In all
    other cases the default signature is returned unchanged.
    """
    sig = ctx.default_signature
    if not (isinstance(ctx.type, TypedDictType)
            and len(ctx.args) == 2
            and len(ctx.args[0]) == 1
            and isinstance(ctx.args[0][0], StrExpr)
            and len(sig.arg_types) == 2
            and len(sig.variables) == 1
            and len(ctx.args[1]) == 1):
        return sig

    key_name = ctx.args[0][0].value
    item_type = get_proper_type(ctx.type.items.get(key_name))
    if not item_type:
        return sig

    fallback_expr = ctx.args[1][0]
    if (isinstance(item_type, TypedDictType)
            and isinstance(fallback_expr, DictExpr)
            and not fallback_expr.items):
        # An empty {} default is acceptable for a nested typed dict.
        item_type = item_type.copy_modified(required_keys=set())

    # Only type inference needs this context: the union with a type variable
    # still accepts any argument.
    tv_def = sig.variables[0]
    assert isinstance(tv_def, TypeVarDef)
    widened = make_simplified_union([item_type, TypeVarType(tv_def)])
    return sig.copy_modified(arg_types=[sig.arg_types[0], widened],
                             ret_type=sig.ret_type)
229
230
def typed_dict_get_callback(ctx: MethodContext) -> Type:
    """Infer a precise return type for TypedDict.get with literal first argument.

    The key argument may be a string literal or a union of string literals.
    If it names known keys, the result is the union of their value types plus:
    - None, when no default is passed;
    - the default's type, when one is passed.

    An empty ``{}`` default for a TypedDict-valued key instead yields that
    TypedDict with every key made optional. Non-literal keys, unknown keys
    and unrecognized argument shapes fall back to the default return type.
    """
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.arg_types) >= 1
            and len(ctx.arg_types[0]) == 1):
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            return ctx.default_return_type

        # Classify the call shape once, before looking at individual keys.
        if len(ctx.arg_types) == 1:
            default_arg = None  # type: Optional[Expression]
        elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
              and len(ctx.args[1]) == 1):
            default_arg = ctx.args[1][0]
        else:
            # Unrecognized argument shape: no per-key type would be collected
            # and the union below would be empty, so keep the stub's type.
            return ctx.default_return_type

        output_types = []  # type: List[Type]
        for key in keys:
            value_type = get_proper_type(ctx.type.items.get(key))
            if value_type is None:
                return ctx.default_return_type

            if default_arg is None:
                output_types.append(value_type)
            elif (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
                    and isinstance(value_type, TypedDictType)):
                # Special case '{}' as the default for a typed dict type.
                output_types.append(value_type.copy_modified(required_keys=set()))
            else:
                output_types.append(value_type)
                output_types.append(ctx.arg_types[1][0])

        if default_arg is None:
            # Without a default, get() may return None for an absent key.
            output_types.append(NoneType())

        return make_simplified_union(output_types)
    return ctx.default_return_type
264
265
def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.pop.

    The key parameter is always typed as ``str``. For ``td.pop('literal_key',
    default)`` with a known key, both the default parameter and the return
    type become ``Union[value_type, T]``. Here ``T`` is the signature's single
    type variable, which gives the default argument useful type context.
    """
    sig = ctx.default_signature
    key_type = ctx.api.named_generic_type('builtins.str', [])
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.args) == 2
            and len(ctx.args[0]) == 1
            and isinstance(ctx.args[0][0], StrExpr)
            and len(sig.arg_types) == 2
            and len(sig.variables) == 1
            and len(ctx.args[1]) == 1):
        item_type = ctx.type.items.get(ctx.args[0][0].value)
        if item_type:
            # Only type inference needs this context: the union with a type
            # variable still accepts any argument.
            tv_def = sig.variables[0]
            assert isinstance(tv_def, TypeVarDef)
            widened = make_simplified_union([item_type, TypeVarType(tv_def)])
            return sig.copy_modified(arg_types=[key_type, widened],
                                     ret_type=widened)
    # Otherwise only the key parameter is replaced.
    return sig.copy_modified(arg_types=[key_type, sig.arg_types[1]])
294
295
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    The key must be a string literal or a union of string literals; otherwise
    an error is reported. Popping a required key is reported as an error.
    Popping an unknown key is reported and yields an error type. The result is
    the union of the popped value types, plus the default's type when a single
    default argument is passed.
    """
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.arg_types) >= 1
            and len(ctx.arg_types[0]) == 1):
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
            return AnyType(TypeOfAny.from_error)

        value_types = []  # type: List[Type]
        for key in keys:
            if key in ctx.type.required_keys:
                ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)

            value_type = ctx.type.items.get(key)
            if value_type is None:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
                return AnyType(TypeOfAny.from_error)
            value_types.append(value_type)

        # Guard the index: a signature without a 'default' formal would
        # otherwise raise IndexError here.
        if len(ctx.args) < 2 or len(ctx.args[1]) == 0:
            return make_simplified_union(value_types)
        elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
              and len(ctx.args[1]) == 1):
            return make_simplified_union([*value_types, ctx.arg_types[1][0]])
    return ctx.default_return_type
324
325
def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.setdefault.

    The key parameter is always typed as ``str``. For
    ``td.setdefault('literal_key', default)`` with a known key, the default
    parameter is typed as that key's value type. This gives the second
    argument useful type context. Otherwise the stub's type for the default
    parameter is kept.
    """
    sig = ctx.default_signature
    key_type = ctx.api.named_generic_type('builtins.str', [])

    default_param = None
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.args) == 2
            and len(ctx.args[0]) == 1
            and isinstance(ctx.args[0][0], StrExpr)
            and len(sig.arg_types) == 2
            and len(ctx.args[1]) == 1):
        default_param = ctx.type.items.get(ctx.args[0][0].value)
    if not default_param:
        default_param = sig.arg_types[1]
    return sig.copy_modified(arg_types=[key_type, default_param])
345
346
def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
    """Type check TypedDict.setdefault and infer a precise return type.

    The key must be a string literal or a union of string literals. Every
    named key must exist, and the default must be a subtype of each key's
    value type. Any failure is reported and yields an error type. On success
    the result is the union of the value types.
    """
    if not (isinstance(ctx.type, TypedDictType)
            and len(ctx.arg_types) == 2
            and len(ctx.arg_types[0]) == 1
            and len(ctx.arg_types[1]) == 1):
        return ctx.default_return_type

    keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if keys is None:
        ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
        return AnyType(TypeOfAny.from_error)

    supplied = ctx.arg_types[1][0]
    collected = []  # type: List[Type]
    for key in keys:
        item_type = ctx.type.items.get(key)
        if item_type is None:
            ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
            return AnyType(TypeOfAny.from_error)

        # The signature hook cannot always pin down the key (e.g. a variable
        # whose type is a Literal str), so check here that the default fits
        # every value it might be stored under.
        if not is_subtype(supplied, item_type):
            ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
                supplied, item_type, ctx.context)
            return AnyType(TypeOfAny.from_error)

        collected.append(item_type)

    return make_simplified_union(collected)
381
382
def typed_dict_delitem_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Give TypedDict.__delitem__ a ``str`` key parameter.

    This replaces the stub's argument type for the key (NoReturn, per the
    original comment) so that key validation is left to
    typed_dict_delitem_callback.
    """
    key_type = ctx.api.named_generic_type('builtins.str', [])
    sig = ctx.default_signature
    return sig.copy_modified(arg_types=[key_type])
387
388
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
    """Type check TypedDict.__delitem__.

    Errors are reported in three cases: the key is not a string literal (or
    union of string literals), a required key is deleted, or a key is
    unknown. The default return type is kept, except for the non-literal-key
    case, which yields an error type.
    """
    if not (isinstance(ctx.type, TypedDictType)
            and len(ctx.arg_types) == 1
            and len(ctx.arg_types[0]) == 1):
        return ctx.default_return_type

    keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
    if keys is None:
        ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
        return AnyType(TypeOfAny.from_error)

    for key in keys:
        if key in ctx.type.required_keys:
            ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
        elif key not in ctx.type.items:
            ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
    return ctx.default_return_type
405
406
def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.update.

    For a TypedDict receiver with a single parameter, that parameter's
    TypedDict type is made anonymous and all of its keys become non-required.
    The argument may therefore supply any subset of the items.
    """
    sig = ctx.default_signature
    if not isinstance(ctx.type, TypedDictType) or len(sig.arg_types) != 1:
        return sig
    param_type = get_proper_type(sig.arg_types[0])
    assert isinstance(param_type, TypedDictType)
    relaxed = param_type.as_anonymous().copy_modified(required_keys=set())
    return sig.copy_modified(arg_types=[relaxed])
418
419
def int_pow_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for int.__pow__.

    When the exponent is an int literal (optionally negated) and no modulo is
    given, the result is ``int`` for a non-negative exponent and ``float`` for
    a negative one. Otherwise the default return type is kept.
    """
    # int.__pow__ has an optional modulo argument, so two formal positions
    # are expected; the second must be unused.
    if not (len(ctx.arg_types) == 2
            and len(ctx.arg_types[0]) == 1
            and not ctx.arg_types[1]):
        return ctx.default_return_type

    operand = ctx.args[0][0]
    if isinstance(operand, IntExpr):
        exponent = operand.value
    elif (isinstance(operand, UnaryExpr) and operand.op == '-'
            and isinstance(operand.expr, IntExpr)):
        exponent = -operand.expr.value
    else:
        # Right operand is neither an int literal nor a negated one.
        return ctx.default_return_type

    result_name = 'builtins.int' if exponent >= 0 else 'builtins.float'
    return ctx.api.named_generic_type(result_name, [])
439
440
def int_neg_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for int.__neg__.

    This is mainly used to infer the return type as LiteralType
    if the original underlying object is a LiteralType object.

    - For an Instance with a known int ``last_known_value``: in a literal-like
      type context the negated literal is returned directly. Otherwise the
      Instance is returned with the negated literal as its new
      ``last_known_value``.
    - For an int LiteralType: the negated literal is returned.

    Anything else keeps the default return type.
    """
    if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
        value = ctx.type.last_known_value.value
        fallback = ctx.type.last_known_value.fallback
        if isinstance(value, int):
            if is_literal_type_like(ctx.api.type_context[-1]):
                return LiteralType(value=-value, fallback=fallback)
            else:
                # Use the literal's own fallback, as in the branch above:
                # ctx.type still carries the *old* last_known_value, so using
                # it as the fallback would embed a stale literal in the new one.
                return ctx.type.copy_modified(last_known_value=LiteralType(
                    value=-value,
                    fallback=fallback,
                    line=ctx.type.line,
                    column=ctx.type.column,
                ))
    elif isinstance(ctx.type, LiteralType):
        value = ctx.type.value
        fallback = ctx.type.fallback
        if isinstance(value, int):
            return LiteralType(value=-value, fallback=fallback)
    return ctx.default_return_type
466