from functools import partial
from typing import Callable, Optional, List

from mypy import message_registry
from mypy.nodes import Expression, StrExpr, IntExpr, DictExpr, UnaryExpr
from mypy.plugin import (
    Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext,
    CheckerPluginInterface,
)
from mypy.plugins.common import try_getting_str_literals
from mypy.types import (
    Type, Instance, AnyType, TypeOfAny, CallableType, NoneType, TypedDictType,
    TypeVarDef, TypeVarType, TPDICT_FB_NAMES, get_proper_type, LiteralType
)
from mypy.subtypes import is_subtype
from mypy.typeops import make_simplified_union
from mypy.checkexpr import is_literal_type_like


# Fully qualified names of the TypedDict fallback methods special-cased by
# DefaultPlugin. They only depend on TPDICT_FB_NAMES, so build the lookup sets
# once at import time instead of rebuilding them on every hook lookup.
_TPDICT_SETDEFAULT_NAMES = frozenset(n + '.setdefault' for n in TPDICT_FB_NAMES)
_TPDICT_POP_NAMES = frozenset(n + '.pop' for n in TPDICT_FB_NAMES)
_TPDICT_UPDATE_NAMES = frozenset(n + '.update' for n in TPDICT_FB_NAMES)
_TPDICT_DELITEM_NAMES = frozenset(n + '.__delitem__' for n in TPDICT_FB_NAMES)


class DefaultPlugin(Plugin):
    """Type checker plugin that is enabled by default."""

    def get_function_hook(self, fullname: str
                          ) -> Optional[Callable[[FunctionContext], Type]]:
        """Return a return-type callback for the function ``fullname``, or None."""
        from mypy.plugins import ctypes

        if fullname == 'contextlib.contextmanager':
            return contextmanager_callback
        elif fullname == 'builtins.open' and self.python_version[0] == 3:
            return open_callback
        elif fullname == 'ctypes.Array':
            return ctypes.array_constructor_callback
        return None

    def get_method_signature_hook(self, fullname: str
                                  ) -> Optional[Callable[[MethodSigContext], CallableType]]:
        """Return a signature-adjusting callback for the method ``fullname``, or None."""
        from mypy.plugins import ctypes

        if fullname == 'typing.Mapping.get':
            return typed_dict_get_signature_callback
        elif fullname in _TPDICT_SETDEFAULT_NAMES:
            return typed_dict_setdefault_signature_callback
        elif fullname in _TPDICT_POP_NAMES:
            return typed_dict_pop_signature_callback
        elif fullname in _TPDICT_UPDATE_NAMES:
            return typed_dict_update_signature_callback
        elif fullname in _TPDICT_DELITEM_NAMES:
            return typed_dict_delitem_signature_callback
        elif fullname == 'ctypes.Array.__setitem__':
            return ctypes.array_setitem_callback
        return None

    def get_method_hook(self, fullname: str
                        ) -> Optional[Callable[[MethodContext], Type]]:
        """Return a return-type callback for the method ``fullname``, or None."""
        from mypy.plugins import ctypes

        if fullname == 'typing.Mapping.get':
            return typed_dict_get_callback
        elif fullname == 'builtins.int.__pow__':
            return int_pow_callback
        elif fullname == 'builtins.int.__neg__':
            return int_neg_callback
        elif fullname in _TPDICT_SETDEFAULT_NAMES:
            return typed_dict_setdefault_callback
        elif fullname in _TPDICT_POP_NAMES:
            return typed_dict_pop_callback
        elif fullname in _TPDICT_DELITEM_NAMES:
            return typed_dict_delitem_callback
        elif fullname == 'ctypes.Array.__getitem__':
            return ctypes.array_getitem_callback
        elif fullname == 'ctypes.Array.__iter__':
            return ctypes.array_iter_callback
        elif fullname == 'pathlib.Path.open':
            return path_open_callback
        return None

    def get_attribute_hook(self, fullname: str
                           ) -> Optional[Callable[[AttributeContext], Type]]:
        """Return a type callback for the attribute ``fullname``, or None."""
        from mypy.plugins import ctypes
        from mypy.plugins import enums

        if fullname == 'ctypes.Array.value':
            return ctypes.array_value_callback
        elif fullname == 'ctypes.Array.raw':
            return ctypes.array_raw_callback
        elif fullname in enums.ENUM_NAME_ACCESS:
            return enums.enum_name_callback
        elif fullname in enums.ENUM_VALUE_ACCESS:
            return enums.enum_value_callback
        return None

    def get_class_decorator_hook(self, fullname: str
                                 ) -> Optional[Callable[[ClassDefContext], None]]:
        """Return a class-transforming callback for the decorator ``fullname``, or None."""
        from mypy.plugins import attrs
        from mypy.plugins import dataclasses
        # This is mypy's own functools plugin module; it shadows the stdlib
        # name only inside this method (``partial`` was imported at the top).
        from mypy.plugins import functools

        if fullname in attrs.attr_class_makers:
            return attrs.attr_class_maker_callback
        elif fullname in attrs.attr_dataclass_makers:
            return partial(
                attrs.attr_class_maker_callback,
                auto_attribs_default=True,
            )
        elif fullname in attrs.attr_frozen_makers:
            return partial(
                attrs.attr_class_maker_callback,
                auto_attribs_default=None,
                frozen_default=True,
            )
        elif fullname in attrs.attr_define_makers:
            return partial(
                attrs.attr_class_maker_callback,
                auto_attribs_default=None,
            )
        elif fullname in dataclasses.dataclass_makers:
            return dataclasses.dataclass_class_maker_callback
        elif fullname in functools.functools_total_ordering_makers:
            return functools.functools_total_ordering_maker_callback

        return None


def open_callback(ctx: FunctionContext) -> Type:
    """Infer a better return type for 'open'.

    The mode is the second formal argument of ``open`` (index 1).
    """
    return _analyze_open_signature(
        arg_types=ctx.arg_types,
        args=ctx.args,
        mode_arg_index=1,
        default_return_type=ctx.default_return_type,
        api=ctx.api,
    )


def path_open_callback(ctx: MethodContext) -> Type:
    """Infer a better return type for 'pathlib.Path.open'.

    The mode is the first formal argument after ``self`` (index 0).
    """
    return _analyze_open_signature(
        arg_types=ctx.arg_types,
        args=ctx.args,
        mode_arg_index=0,
        default_return_type=ctx.default_return_type,
        api=ctx.api,
    )


def _analyze_open_signature(arg_types: List[List[Type]],
                            args: List[List[Expression]],
                            mode_arg_index: int,
                            default_return_type: Type,
                            api: CheckerPluginInterface,
                            ) -> Type:
    """A helper for analyzing any function that has approximately
    the same signature as the builtin 'open(...)' function.

    Currently, the only thing the caller can customize is the index
    of the 'mode' argument. If the mode argument is omitted or is a
    string literal, we refine the return type to either 'TextIO' or
    'BinaryIO' as appropriate.

    Args:
        arg_types: Actual argument types, grouped per formal argument.
        args: Actual argument expressions, grouped per formal argument.
        mode_arg_index: Index of the formal argument holding the mode.
        default_return_type: Return type inferred from the signature.
        api: Checker API used to build the refined return type.

    Returns:
        ``typing.BinaryIO`` if the mode contains 'b', ``typing.TextIO`` if
        it does not (an omitted mode counts as 'r'), and
        ``default_return_type`` when the mode is not a string literal.
    """
    mode = None
    if not arg_types or len(arg_types[mode_arg_index]) != 1:
        # Mode not given as a single actual argument: assume the default 'r'.
        mode = 'r'
    else:
        mode_expr = args[mode_arg_index][0]
        if isinstance(mode_expr, StrExpr):
            mode = mode_expr.value
    if mode is not None:
        # Sanity check on the declared return type before it gets replaced.
        assert isinstance(default_return_type, Instance)  # type: ignore
        if 'b' in mode:
            return api.named_generic_type('typing.BinaryIO', [])
        else:
            return api.named_generic_type('typing.TextIO', [])
    return default_return_type


def contextmanager_callback(ctx: FunctionContext) -> Type:
    """Infer a better return type for 'contextlib.contextmanager'."""
    # Be defensive, just in case.
    if ctx.arg_types and len(ctx.arg_types[0]) == 1:
        arg_type = get_proper_type(ctx.arg_types[0][0])
        default_return = get_proper_type(ctx.default_return_type)
        if (isinstance(arg_type, CallableType)
                and isinstance(default_return, CallableType)):
            # The stub signature doesn't preserve information about arguments so
            # add them back here.
            return default_return.copy_modified(
                arg_types=arg_type.arg_types,
                arg_kinds=arg_type.arg_kinds,
                arg_names=arg_type.arg_names,
                variables=arg_type.variables,
                is_ellipsis_args=arg_type.is_ellipsis_args)
    return ctx.default_return_type


def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.get.

    This is used to get better type context for the second argument that
    depends on a TypedDict value type.
    """
    signature = ctx.default_signature
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.args) == 2
            and len(ctx.args[0]) == 1
            and isinstance(ctx.args[0][0], StrExpr)
            and len(signature.arg_types) == 2
            and len(signature.variables) == 1
            and len(ctx.args[1]) == 1):
        key = ctx.args[0][0].value
        value_type = get_proper_type(ctx.type.items.get(key))
        ret_type = signature.ret_type
        if value_type:
            default_arg = ctx.args[1][0]
            if (isinstance(value_type, TypedDictType)
                    and isinstance(default_arg, DictExpr)
                    and len(default_arg.items) == 0):
                # Caller has empty dict {} as default for typed dict.
                value_type = value_type.copy_modified(required_keys=set())
            # Tweak the signature to include the value type as context. It's
            # only needed for type inference since there's a union with a type
            # variable that accepts everything.
            assert isinstance(signature.variables[0], TypeVarDef)
            tv = TypeVarType(signature.variables[0])
            return signature.copy_modified(
                arg_types=[signature.arg_types[0],
                           make_simplified_union([value_type, tv])],
                ret_type=ret_type)
    return signature


def typed_dict_get_callback(ctx: MethodContext) -> Type:
    """Infer a precise return type for TypedDict.get with literal first argument.

    For every possible literal key the result contains the key's value type,
    plus None when no default is passed, or the default's type when one is.
    An empty dict literal passed as the default for a TypedDict-valued key is
    special-cased to that TypedDict with no required keys. If the key cannot
    be resolved to string literal(s), or names an unknown key, the default
    return type is kept.
    """
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.arg_types) >= 1
            and len(ctx.arg_types[0]) == 1):
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            return ctx.default_return_type

        output_types = []  # type: List[Type]
        for key in keys:
            value_type = get_proper_type(ctx.type.items.get(key))
            if value_type is None:
                return ctx.default_return_type

            if len(ctx.arg_types) == 1:
                output_types.append(value_type)
            elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
                  and len(ctx.args[1]) == 1):
                default_arg = ctx.args[1][0]
                if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0
                        and isinstance(value_type, TypedDictType)):
                    # Special case '{}' as the default for a typed dict type.
                    output_types.append(value_type.copy_modified(required_keys=set()))
                else:
                    output_types.append(value_type)
                    output_types.append(ctx.arg_types[1][0])

        if len(ctx.arg_types) == 1:
            # No default argument: a missing key makes get() return None.
            output_types.append(NoneType())

        return make_simplified_union(output_types)
    return ctx.default_return_type


def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.pop.

    This is used to get better type context for the second argument that
    depends on a TypedDict value type.
    """
    signature = ctx.default_signature
    str_type = ctx.api.named_generic_type('builtins.str', [])
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.args) == 2
            and len(ctx.args[0]) == 1
            and isinstance(ctx.args[0][0], StrExpr)
            and len(signature.arg_types) == 2
            and len(signature.variables) == 1
            and len(ctx.args[1]) == 1):
        key = ctx.args[0][0].value
        value_type = ctx.type.items.get(key)
        if value_type:
            # Tweak the signature to include the value type as context. It's
            # only needed for type inference since there's a union with a type
            # variable that accepts everything.
            assert isinstance(signature.variables[0], TypeVarDef)
            tv = TypeVarType(signature.variables[0])
            typ = make_simplified_union([value_type, tv])
            return signature.copy_modified(
                arg_types=[str_type, typ],
                ret_type=typ)
    return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])


def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    Reports an error if the key is not a string literal, names a required key
    (required keys cannot be removed), or names an unknown key. The result is
    the union of the popped value types, plus the default's type if one is
    passed.
    """
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.arg_types) >= 1
            and len(ctx.arg_types[0]) == 1):
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
            return AnyType(TypeOfAny.from_error)

        value_types = []  # type: List[Type]
        for key in keys:
            if key in ctx.type.required_keys:
                ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)

            value_type = ctx.type.items.get(key)
            if value_type:
                value_types.append(value_type)
            else:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
                return AnyType(TypeOfAny.from_error)

        # The guard above only ensures the key formal exists, so don't index
        # the default formal unless it is actually there.
        if len(ctx.args) < 2 or len(ctx.args[1]) == 0:
            return make_simplified_union(value_types)
        elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1
              and len(ctx.args[1]) == 1):
            return make_simplified_union([*value_types, ctx.arg_types[1][0]])
    return ctx.default_return_type


def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.setdefault.

    This is used to get better type context for the second argument that
    depends on a TypedDict value type.
    """
    signature = ctx.default_signature
    str_type = ctx.api.named_generic_type('builtins.str', [])
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.args) == 2
            and len(ctx.args[0]) == 1
            and isinstance(ctx.args[0][0], StrExpr)
            and len(signature.arg_types) == 2
            and len(ctx.args[1]) == 1):
        key = ctx.args[0][0].value
        value_type = ctx.type.items.get(key)
        if value_type:
            return signature.copy_modified(arg_types=[str_type, value_type])
    return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])


def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
    """Type check TypedDict.setdefault and infer a precise return type."""
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.arg_types) == 2
            and len(ctx.arg_types[0]) == 1
            and len(ctx.arg_types[1]) == 1):
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
            return AnyType(TypeOfAny.from_error)

        default_type = ctx.arg_types[1][0]

        value_types = []  # type: List[Type]
        for key in keys:
            value_type = ctx.type.items.get(key)

            if value_type is None:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
                return AnyType(TypeOfAny.from_error)

            # The signature_callback above can't always infer the right signature
            # (e.g. when the expression is a variable that happens to be a Literal str)
            # so we need to handle the check ourselves here and make sure the provided
            # default can be assigned to all key-value pairs we're updating.
            if not is_subtype(default_type, value_type):
                ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
                    default_type, value_type, ctx.context)
                return AnyType(TypeOfAny.from_error)

            value_types.append(value_type)

        return make_simplified_union(value_types)
    return ctx.default_return_type


def typed_dict_delitem_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Make the key argument of TypedDict.__delitem__ accept ``str``."""
    # Replace NoReturn as the argument type.
    str_type = ctx.api.named_generic_type('builtins.str', [])
    return ctx.default_signature.copy_modified(arg_types=[str_type])


def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
    """Type check TypedDict.__delitem__.

    Reports an error for non-literal keys, for required keys and for keys the
    TypedDict does not declare. Otherwise the default return type is kept.
    """
    if (isinstance(ctx.type, TypedDictType)
            and len(ctx.arg_types) == 1
            and len(ctx.arg_types[0]) == 1):
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context)
            return AnyType(TypeOfAny.from_error)

        for key in keys:
            if key in ctx.type.required_keys:
                ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, ctx.context)
            elif key not in ctx.type.items:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, ctx.context)
    return ctx.default_return_type


def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.update.

    The argument becomes an anonymous copy of the TypedDict type with no
    required keys, so a partial update is accepted.
    """
    signature = ctx.default_signature
    if (isinstance(ctx.type, TypedDictType)
            and len(signature.arg_types) == 1):
        arg_type = get_proper_type(signature.arg_types[0])
        assert isinstance(arg_type, TypedDictType)
        arg_type = arg_type.as_anonymous()
        arg_type = arg_type.copy_modified(required_keys=set())
        return signature.copy_modified(arg_types=[arg_type])
    return signature


def int_pow_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for int.__pow__.

    With an int literal exponent (optionally negated), a non-negative exponent
    gives ``int`` and a negative one gives ``float``.
    """
    # int.__pow__ has an optional modulo argument,
    # so we expect 2 argument positions
    if (len(ctx.arg_types) == 2
            and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types[1]) == 0):
        arg = ctx.args[0][0]
        if isinstance(arg, IntExpr):
            exponent = arg.value
        elif isinstance(arg, UnaryExpr) and arg.op == '-' and isinstance(arg.expr, IntExpr):
            exponent = -arg.expr.value
        else:
            # Right operand not an int literal or a negated literal -- give up.
            return ctx.default_return_type
        if exponent >= 0:
            return ctx.api.named_generic_type('builtins.int', [])
        else:
            return ctx.api.named_generic_type('builtins.float', [])
    return ctx.default_return_type


def int_neg_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for int.__neg__.

    This is mainly used to infer the return type as LiteralType
    if the original underlying object is a LiteralType object
    """
    if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
        value = ctx.type.last_known_value.value
        fallback = ctx.type.last_known_value.fallback
        if isinstance(value, int):
            if is_literal_type_like(ctx.api.type_context[-1]):
                return LiteralType(value=-value, fallback=fallback)
            else:
                # Keep the Instance but record the negated value. The new literal
                # must use the literal's own fallback: ``ctx.type`` itself still
                # carries the original, un-negated last_known_value.
                return ctx.type.copy_modified(last_known_value=LiteralType(
                    value=-value,
                    fallback=fallback,
                    line=ctx.type.line,
                    column=ctx.type.column,
                ))
    elif isinstance(ctx.type, LiteralType):
        value = ctx.type.value
        fallback = ctx.type.fallback
        if isinstance(value, int):
            return LiteralType(value=-value, fallback=fallback)
    return ctx.default_return_type