"""Type inference constraints."""

from typing import Iterable, List, Optional, Sequence
from typing_extensions import Final

from mypy.types import (
    CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance,
    TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
    UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
    ProperType, get_proper_type, TypeAliasType, TypeGuardType
)
from mypy.maptype import map_instance_to_supertype
import mypy.subtypes
import mypy.sametypes
import mypy.typeops
from mypy.erasetype import erase_typevars
from mypy.nodes import COVARIANT, CONTRAVARIANT
from mypy.argmap import ArgTypeExpander
from mypy.typestate import TypeState

# Possible values of Constraint.op, and of the 'direction' argument used
# throughout this module (see neg_op() for flipping between them).
SUBTYPE_OF = 0  # type: Final
SUPERTYPE_OF = 1  # type: Final


class Constraint:
    """A representation of a type constraint.

    It can be either T <: type or T :> type (T is a type variable).
    """

    type_var = None  # type: TypeVarId
    op = 0  # SUBTYPE_OF or SUPERTYPE_OF
    target = None  # type: Type

    def __init__(self, type_var: TypeVarId, op: int, target: Type) -> None:
        self.type_var = type_var
        self.op = op
        self.target = target

    def __repr__(self) -> str:
        op_str = '<:'
        if self.op == SUPERTYPE_OF:
            op_str = ':>'
        return '{} {} {}'.format(self.type_var, op_str, self.target)


def infer_constraints_for_callable(
        callee: CallableType, arg_types: Sequence[Optional[Type]], arg_kinds: List[int],
        formal_to_actual: List[List[int]]) -> List[Constraint]:
    """Infer type variable constraints for a callable and actual arguments.

    For each formal argument i of 'callee', every actual argument index listed in
    formal_to_actual[i] is considered. Actual argument types that are None are
    skipped. The others are expanded with ArgTypeExpander.expand_actual_type()
    (using the actual and formal argument kinds and the formal name) and then
    matched with the formal type as a supertype of the expanded actual type.

    Return a list of constraints.
    """
    constraints = []  # type: List[Constraint]
    mapper = ArgTypeExpander()

    for i, actuals in enumerate(formal_to_actual):
        for actual in actuals:
            actual_arg_type = arg_types[actual]
            if actual_arg_type is None:
                continue

            actual_type = mapper.expand_actual_type(actual_arg_type, arg_kinds[actual],
                                                    callee.arg_names[i], callee.arg_kinds[i])
            c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
            constraints.extend(c)

    return constraints


def infer_constraints(template: Type, actual: Type,
                      direction: int) -> List[Constraint]:
    """Infer type constraints.

    Match a template type, which may contain type variable references,
    recursively against a type which does not contain (the same) type
    variable references. The result is a list of type constraints of
    the form 'T is a supertype/subtype of x', where T is a type variable
    present in the template and x is a type without reference to type
    variables present in the template.

    Assume T and S are type variables. Now the following results can be
    calculated (read as '(template, actual) --> result'):

      (T, X)            -->  T :> X
      (X[T], X[Y])      -->  T <: Y and T :> Y
      ((T, T), (X, Y))  -->  T :> X and T :> Y
      ((T, S), (X, Y))  -->  T :> X and S :> Y
      (X[T], Any)       -->  T <: Any and T :> Any

    The constraints are represented as Constraint objects.
    """
    # Compute the proper form of the template once instead of once per stack entry.
    proper_template = get_proper_type(template)
    if any(proper_template == get_proper_type(t) for t in TypeState._inferring):
        # This template is already being processed further up the call stack
        # (through a recursive type alias): give up here instead of recursing forever.
        return []
    if isinstance(template, TypeAliasType) and template.is_recursive:
        # This case requires special care because it may cause infinite recursion.
        TypeState._inferring.append(template)
        try:
            return _infer_constraints(template, actual, direction)
        finally:
            # Always pop, so that an exception can't leave a stale entry behind
            # that would silently disable inference for this alias later on.
            TypeState._inferring.pop()
    return _infer_constraints(template, actual, direction)


def _infer_constraints(template: Type, actual: Type,
                       direction: int) -> List[Constraint]:
    """Implementation of infer_constraints(), without the recursive-alias guard."""

    # Keep the unexpanded template: recursive calls for union items of 'actual'
    # pass it on so that the recursive-alias guard in infer_constraints() still applies.
    orig_template = template
    template = get_proper_type(template)
    actual = get_proper_type(actual)

    # Type inference shouldn't be affected by whether union types have been simplified.
    # We however keep any ErasedType items, so that the caller will see it when using
    # checkexpr.has_erased_component().
    if isinstance(template, UnionType):
        template = mypy.typeops.make_simplified_union(template.items, keep_erased=True)
    if isinstance(actual, UnionType):
        actual = mypy.typeops.make_simplified_union(actual.items, keep_erased=True)

    # Ignore Any types from the type suggestion engine to avoid them
    # causing us to infer Any in situations where a better job could
    # be done otherwise. (This can produce false positives but that
    # doesn't really matter because it is all heuristic anyway.)
    if isinstance(actual, AnyType) and actual.type_of_any == TypeOfAny.suggestion_engine:
        return []

    # If the template is simply a type variable, emit a Constraint directly.
    # We need to handle this case before handling Unions for two reasons:
    #  1. "T <: Union[U1, U2]" is not equivalent to "T <: U1 or T <: U2",
    #     because T can itself be a union (notably, Union[U1, U2] itself).
    #  2. "T :> Union[U1, U2]" is logically equivalent to "T :> U1 and
    #     T :> U2", but they are not equivalent to the constraint solver,
    #     which never introduces new Union types (it uses join() instead).
    if isinstance(template, TypeVarType):
        return [Constraint(template.id, direction, actual)]

    # Now handle the case of either template or actual being a Union.
    # For a Union to be a subtype of another type, every item of the Union
    # must be a subtype of that type, so concatenate the constraints.
    if direction == SUBTYPE_OF and isinstance(template, UnionType):
        res = []
        for t_item in template.items:
            res.extend(infer_constraints(t_item, actual, direction))
        return res
    if direction == SUPERTYPE_OF and isinstance(actual, UnionType):
        res = []
        for a_item in actual.items:
            res.extend(infer_constraints(orig_template, a_item, direction))
        return res

    # Now the potential subtype is known not to be a Union or a type
    # variable that we are solving for. In that case, for a Union to
    # be a supertype of the potential subtype, some item of the Union
    # must be a supertype of it.
    if direction == SUBTYPE_OF and isinstance(actual, UnionType):
        # If some of the items are not complete types, disregard them.
        items = simplify_away_incomplete_types(actual.items)
        # We infer constraints eagerly -- try to find constraints for a type
        # variable if possible. This seems to help with some real-world
        # use cases.
        return any_constraints(
            [infer_constraints_if_possible(template, a_item, direction)
             for a_item in items],
            eager=True)
    if direction == SUPERTYPE_OF and isinstance(template, UnionType):
        # When the template is a union, we are okay with leaving some
        # type variables indeterminate. This helps with some special
        # cases, though this isn't very principled.
        return any_constraints(
            [infer_constraints_if_possible(t_item, actual, direction)
             for t_item in template.items],
            eager=False)

    # Remaining cases are handled by ConstraintBuilderVisitor.
    return template.accept(ConstraintBuilderVisitor(actual, direction))


def infer_constraints_if_possible(template: Type, actual: Type,
                                  direction: int) -> Optional[List[Constraint]]:
    """Like infer_constraints, but return None if the input relation is
    known to be unsatisfiable, for example if template=List[T] and actual=int.
    (In this case infer_constraints would return [], just like it would for
    an automatically satisfied relation like template=List[T] and actual=object.)
    """
    if (direction == SUBTYPE_OF and
            not mypy.subtypes.is_subtype(erase_typevars(template), actual)):
        return None
    if (direction == SUPERTYPE_OF and
            not mypy.subtypes.is_subtype(actual, erase_typevars(template))):
        return None
    if (direction == SUPERTYPE_OF and isinstance(template, TypeVarType) and
            not mypy.subtypes.is_subtype(actual, erase_typevars(template.upper_bound))):
        # This is not caught by the above branch because of the erase_typevars() call,
        # that would return 'Any' for a type variable.
        return None
    return infer_constraints(template, actual, direction)


def any_constraints(options: List[Optional[List[Constraint]]], eager: bool) -> List[Constraint]:
    """Deduce what we can from a collection of constraint lists.

    It's a given that at least one of the lists must be satisfied. A
    None element in the list of options represents an unsatisfiable
    constraint and is ignored. Ignore empty constraint lists if eager
    is true -- they are always trivially satisfiable.
    """
    if eager:
        valid_options = [option for option in options if option]
    else:
        valid_options = [option for option in options if option is not None]
    if len(valid_options) == 1:
        return valid_options[0]
    elif (len(valid_options) > 1 and
            all(is_same_constraints(valid_options[0], c)
                for c in valid_options[1:])):
        # Multiple sets of constraints that are all the same. Just pick any one of them.
        # TODO: More generally, if a given (variable, direction) pair appears in
        #       every option, combine the bounds with meet/join.
        return valid_options[0]

    # Otherwise, there are either no valid options or multiple, inconsistent valid
    # options. Give up and deduce nothing.
    return []


def is_same_constraints(x: List[Constraint], y: List[Constraint]) -> bool:
    """Return True if every constraint in x has a matching constraint in y and vice versa.

    Matching uses is_same_constraint(); duplicates and ordering are not significant.
    """
    for c1 in x:
        if not any(is_same_constraint(c1, c2) for c2 in y):
            return False
    for c1 in y:
        if not any(is_same_constraint(c1, c2) for c2 in x):
            return False
    return True


def is_same_constraint(c1: Constraint, c2: Constraint) -> bool:
    """Return True if both constraints have the same variable, operator and target type."""
    return (c1.type_var == c2.type_var
            and c1.op == c2.op
            and mypy.sametypes.is_same_type(c1.target, c2.target))


def simplify_away_incomplete_types(types: Iterable[Type]) -> List[Type]:
    """Return the complete types among 'types', or all of them if none is complete.

    'types' is materialized once up front, so a one-shot iterator is handled
    correctly. (Iterating it a second time for the fallback would otherwise
    yield nothing.)
    """
    all_types = list(types)
    complete = [typ for typ in all_types if is_complete_type(typ)]
    if complete:
        return complete
    else:
        return all_types


def is_complete_type(typ: Type) -> bool:
    """Is a type complete?

    A complete type doesn't have uninhabited type components (see
    CompleteTypeVisitor, which rejects only UninhabitedType).
    """
    return typ.accept(CompleteTypeVisitor())


class CompleteTypeVisitor(TypeQuery[bool]):
    """Type query that is False if any component of the type is UninhabitedType."""

    def __init__(self) -> None:
        # Combine the results for component types with all(): a single
        # incomplete component makes the whole type incomplete.
        super().__init__(all)

    def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
        return False


class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]):
    """Visitor class for inferring type constraints.

    The visited type is the template; self.actual is matched against it in
    self.direction (SUBTYPE_OF or SUPERTYPE_OF).
    """

    # The type that is compared against a template
    # TODO: The value may be None. Is that actually correct?
    actual = None  # type: ProperType

    def __init__(self, actual: ProperType, direction: int) -> None:
        # Direction must be SUBTYPE_OF or SUPERTYPE_OF.
        self.actual = actual
        self.direction = direction

    # Trivial leaf types

    def visit_unbound_type(self, template: UnboundType) -> List[Constraint]:
        return []

    def visit_any(self, template: AnyType) -> List[Constraint]:
        return []

    def visit_none_type(self, template: NoneType) -> List[Constraint]:
        return []

    def visit_uninhabited_type(self, template: UninhabitedType) -> List[Constraint]:
        return []

    def visit_erased_type(self, template: ErasedType) -> List[Constraint]:
        return []

    def visit_deleted_type(self, template: DeletedType) -> List[Constraint]:
        return []

    def visit_literal_type(self, template: LiteralType) -> List[Constraint]:
        return []

    # Errors

    def visit_partial_type(self, template: PartialType) -> List[Constraint]:
        # We can't do anything useful with a partial type here.
        assert False, "Internal error"

    # Non-trivial leaf type

    def visit_type_var(self, template: TypeVarType) -> List[Constraint]:
        assert False, ("Unexpected TypeVarType in ConstraintBuilderVisitor"
                       " (should have been handled in infer_constraints)")

    # Non-leaf types

    def visit_instance(self, template: Instance) -> List[Constraint]:
        """Infer constraints for an Instance template.

        Handled cases, in order:
        - generic callback protocols matched against callables;
        - nominal inference through map_instance_to_supertype();
        - structural (protocol) inference;
        - Any actuals;
        - tuple actuals.
        """
        original_actual = actual = self.actual
        res = []  # type: List[Constraint]
        if isinstance(actual, (CallableType, Overloaded)) and template.type.is_protocol:
            if template.type.protocol_members == ['__call__']:
                # Special case: a generic callback protocol
                if not any(mypy.sametypes.is_same_type(template, t)
                           for t in template.type.inferring):
                    template.type.inferring.append(template)
                    try:
                        call = mypy.subtypes.find_member('__call__', template, actual,
                                                         is_operator=True)
                        assert call is not None
                        if mypy.subtypes.is_subtype(actual, erase_typevars(call)):
                            subres = infer_constraints(call, actual, self.direction)
                            res.extend(subres)
                    finally:
                        # Keep the inference stack balanced even if inference raises.
                        template.type.inferring.pop()
                    return res
        # Match callables, TypedDicts and literals via their fallback instances.
        if isinstance(actual, CallableType) and actual.fallback is not None:
            actual = actual.fallback
        if isinstance(actual, Overloaded) and actual.fallback is not None:
            actual = actual.fallback
        if isinstance(actual, TypedDictType):
            actual = actual.as_anonymous().fallback
        if isinstance(actual, LiteralType):
            actual = actual.fallback
        if isinstance(actual, Instance):
            instance = actual
            erased = erase_typevars(template)
            assert isinstance(erased, Instance)  # type: ignore
            # We always try nominal inference if possible,
            # it is much faster than the structural one.
            if (self.direction == SUBTYPE_OF and
                    template.type.has_base(instance.type.fullname)):
                mapped = map_instance_to_supertype(template, instance.type)
                tvars = mapped.type.defn.type_vars
                # N.B: We use zip instead of indexing because the lengths might have
                # mismatches during daemon reprocessing.
                for tvar, mapped_arg, instance_arg in zip(tvars, mapped.args, instance.args):
                    # The constraints for generic type parameters depend on variance.
                    # Include constraints from both directions if invariant.
                    if tvar.variance != CONTRAVARIANT:
                        res.extend(infer_constraints(
                            mapped_arg, instance_arg, self.direction))
                    if tvar.variance != COVARIANT:
                        res.extend(infer_constraints(
                            mapped_arg, instance_arg, neg_op(self.direction)))
                return res
            elif (self.direction == SUPERTYPE_OF and
                    instance.type.has_base(template.type.fullname)):
                mapped = map_instance_to_supertype(instance, template.type)
                tvars = template.type.defn.type_vars
                # N.B: We use zip instead of indexing because the lengths might have
                # mismatches during daemon reprocessing.
                for tvar, mapped_arg, template_arg in zip(tvars, mapped.args, template.args):
                    # The constraints for generic type parameters depend on variance.
                    # Include constraints from both directions if invariant.
                    if tvar.variance != CONTRAVARIANT:
                        res.extend(infer_constraints(
                            template_arg, mapped_arg, self.direction))
                    if tvar.variance != COVARIANT:
                        res.extend(infer_constraints(
                            template_arg, mapped_arg, neg_op(self.direction)))
                return res
            if (template.type.is_protocol and self.direction == SUPERTYPE_OF and
                    # We avoid infinite recursion for structural subtypes by checking
                    # whether this type already appeared in the inference chain.
                    # This is a conservative way to break the inference cycles.
                    # It never produces any "false" constraints but gives up soon
                    # on purely structural inference cycles, see #3829.
                    # Note that we use is_protocol_implementation instead of is_subtype
                    # because some type may be considered a subtype of a protocol
                    # due to _promote, but still not implement the protocol.
                    not any(mypy.sametypes.is_same_type(template, t)
                            for t in template.type.inferring) and
                    mypy.subtypes.is_protocol_implementation(instance, erased)):
                template.type.inferring.append(template)
                try:
                    self.infer_constraints_from_protocol_members(res, instance, template,
                                                                 original_actual, template)
                finally:
                    template.type.inferring.pop()
                return res
            elif (instance.type.is_protocol and self.direction == SUBTYPE_OF and
                  # We avoid infinite recursion for structural subtypes also here.
                  not any(mypy.sametypes.is_same_type(instance, i)
                          for i in instance.type.inferring) and
                  mypy.subtypes.is_protocol_implementation(erased, instance)):
                instance.type.inferring.append(instance)
                try:
                    self.infer_constraints_from_protocol_members(res, instance, template,
                                                                 template, instance)
                finally:
                    instance.type.inferring.pop()
                return res
        if isinstance(actual, AnyType):
            # IDEA: Include both ways, i.e. add negation as well?
            return self.infer_against_any(template.args, actual)
        if (isinstance(actual, TupleType) and
                (is_named_instance(template, 'typing.Iterable') or
                 is_named_instance(template, 'typing.Container') or
                 is_named_instance(template, 'typing.Sequence') or
                 is_named_instance(template, 'typing.Reversible'))
                and self.direction == SUPERTYPE_OF):
            # Every tuple item constrains the single type argument of the template.
            for item in actual.items:
                cb = infer_constraints(template.args[0], item, SUPERTYPE_OF)
                res.extend(cb)
            return res
        elif isinstance(actual, TupleType) and self.direction == SUPERTYPE_OF:
            return infer_constraints(template,
                                     mypy.typeops.tuple_fallback(actual),
                                     self.direction)
        else:
            return []

    def infer_constraints_from_protocol_members(self, res: List[Constraint],
                                                instance: Instance, template: Instance,
                                                subtype: Type, protocol: Instance) -> None:
        """Infer constraints for situations where either 'template' or 'instance' is a protocol.

        The 'protocol' is the one of the two that is an instance of a protocol type;
        'subtype' is the type used to bind self during inference. Currently, we just
        infer constraints for every protocol member type (both ways for settable
        members). Results are appended to 'res' in place.
        """
        for member in protocol.type.protocol_members:
            inst = mypy.subtypes.find_member(member, instance, subtype)
            temp = mypy.subtypes.find_member(member, template, subtype)
            assert inst is not None and temp is not None
            # The above is safe since at this point we know that 'instance' is a subtype
            # of (erased) 'template', therefore it defines all protocol members
            res.extend(infer_constraints(temp, inst, self.direction))
            if (mypy.subtypes.IS_SETTABLE in
                    mypy.subtypes.get_member_flags(member, protocol.type)):
                # Settable members are invariant, add opposite constraints
                res.extend(infer_constraints(temp, inst, neg_op(self.direction)))

    def visit_callable_type(self, template: CallableType) -> List[Constraint]:
        if isinstance(self.actual, CallableType):
            cactual = self.actual
            # FIX verify argument counts
            # FIX what if one of the functions is generic
            res = []  # type: List[Constraint]

            # We can't infer constraints from arguments if the template is Callable[..., T] (with
            # literal '...').
            if not template.is_ellipsis_args:
                # The lengths should match, but don't crash (it will error elsewhere).
                for t, a in zip(template.arg_types, cactual.arg_types):
                    # Negate direction due to function argument type contravariance.
                    res.extend(infer_constraints(t, a, neg_op(self.direction)))
            # A TypeGuard, when present, is matched instead of the plain return type.
            template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
            if template.type_guard is not None:
                template_ret_type = template.type_guard
            if cactual.type_guard is not None:
                cactual_ret_type = cactual.type_guard
            res.extend(infer_constraints(template_ret_type, cactual_ret_type,
                                         self.direction))
            return res
        elif isinstance(self.actual, AnyType):
            # FIX what if generic
            res = self.infer_against_any(template.arg_types, self.actual)
            any_type = AnyType(TypeOfAny.from_another_any, source_any=self.actual)
            res.extend(infer_constraints(template.ret_type, any_type, self.direction))
            return res
        elif isinstance(self.actual, Overloaded):
            return self.infer_against_overloaded(self.actual, template)
        elif isinstance(self.actual, TypeType):
            return infer_constraints(template.ret_type, self.actual.item, self.direction)
        elif isinstance(self.actual, Instance):
            # Instances with __call__ method defined are considered structural
            # subtypes of Callable with a compatible signature.
            call = mypy.subtypes.find_member('__call__', self.actual, self.actual,
                                             is_operator=True)
            if call:
                return infer_constraints(template, call, self.direction)
            else:
                return []
        else:
            return []

    def infer_against_overloaded(self, overloaded: Overloaded,
                                 template: CallableType) -> List[Constraint]:
        # Create constraints by matching an overloaded type against a template.
        # This is tricky to do in general. We cheat by only matching against
        # the first overload item that is callable compatible. This
        # seems to work somewhat well, but we should really use a more
        # reliable technique.
        item = find_matching_overload_item(overloaded, template)
        return infer_constraints(template, item, self.direction)

    def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
        actual = self.actual
        if isinstance(actual, TupleType) and len(actual.items) == len(template.items):
            # Same arity: match the items position by position.
            res = []  # type: List[Constraint]
            for i in range(len(template.items)):
                res.extend(infer_constraints(template.items[i],
                                             actual.items[i],
                                             self.direction))
            return res
        elif isinstance(actual, AnyType):
            return self.infer_against_any(template.items, actual)
        else:
            return []

    def visit_typeddict_type(self, template: TypedDictType) -> List[Constraint]:
        actual = self.actual
        if isinstance(actual, TypedDictType):
            res = []  # type: List[Constraint]
            # NOTE: Non-matching keys are ignored. Compatibility is checked
            #       elsewhere so this shouldn't be unsafe.
            for (item_name, template_item_type, actual_item_type) in template.zip(actual):
                res.extend(infer_constraints(template_item_type,
                                             actual_item_type,
                                             self.direction))
            return res
        elif isinstance(actual, AnyType):
            return self.infer_against_any(template.items.values(), actual)
        else:
            return []

    def visit_union_type(self, template: UnionType) -> List[Constraint]:
        assert False, ("Unexpected UnionType in ConstraintBuilderVisitor"
                       " (should have been handled in infer_constraints)")

    def visit_type_alias_type(self, template: TypeAliasType) -> List[Constraint]:
        assert False, "This should be never called, got {}".format(template)

    def visit_type_guard_type(self, template: TypeGuardType) -> List[Constraint]:
        assert False, "This should be never called, got {}".format(template)

    def infer_against_any(self, types: Iterable[Type], any_type: AnyType) -> List[Constraint]:
        """Match each of 'types' against the same Any type and concatenate the results."""
        res = []  # type: List[Constraint]
        for t in types:
            res.extend(infer_constraints(t, any_type, self.direction))
        return res

    def visit_overloaded(self, template: Overloaded) -> List[Constraint]:
        res = []  # type: List[Constraint]
        for t in template.items():
            res.extend(infer_constraints(t, self.actual, self.direction))
        return res

    def visit_type_type(self, template: TypeType) -> List[Constraint]:
        # A callable (or the first item of an overload) is matched through its return type.
        if isinstance(self.actual, CallableType):
            return infer_constraints(template.item, self.actual.ret_type, self.direction)
        elif isinstance(self.actual, Overloaded):
            return infer_constraints(template.item, self.actual.items()[0].ret_type,
                                     self.direction)
        elif isinstance(self.actual, TypeType):
            return infer_constraints(template.item, self.actual.item, self.direction)
        elif isinstance(self.actual, AnyType):
            return infer_constraints(template.item, self.actual, self.direction)
        else:
            return []


def neg_op(op: int) -> int:
    """Map SubtypeOf to SupertypeOf and vice versa."""

    if op == SUBTYPE_OF:
        return SUPERTYPE_OF
    elif op == SUPERTYPE_OF:
        return SUBTYPE_OF
    else:
        raise ValueError('Invalid operator {}'.format(op))


def find_matching_overload_item(overloaded: Overloaded, template: CallableType) -> CallableType:
    """Disambiguate overload item against a template."""
    items = overloaded.items()
    for item in items:
        # Return type may be indeterminate in the template, so ignore it when performing a
        # subtype check.
        if mypy.subtypes.is_callable_compatible(item, template,
                                                is_compat=mypy.subtypes.is_subtype,
                                                ignore_return=True):
            return item
    # Fall back to the first item if we can't find a match. This is totally arbitrary --
    # maybe we should just bail out at this point.
    return items[0]