1"""Type inference constraints."""
2
3from typing import Iterable, List, Optional, Sequence
4from typing_extensions import Final
5
6from mypy.types import (
7    CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance,
8    TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
9    UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
10    ProperType, get_proper_type, TypeAliasType, TypeGuardType
11)
12from mypy.maptype import map_instance_to_supertype
13import mypy.subtypes
14import mypy.sametypes
15import mypy.typeops
16from mypy.erasetype import erase_typevars
17from mypy.nodes import COVARIANT, CONTRAVARIANT
18from mypy.argmap import ArgTypeExpander
19from mypy.typestate import TypeState
20
# Constraint directions, used as Constraint.op and as the 'direction' argument
# of infer_constraints(): SUBTYPE_OF means 'T <: target' and SUPERTYPE_OF means
# 'T :> target' (see Constraint.__repr__).
SUBTYPE_OF = 0  # type: Final
SUPERTYPE_OF = 1  # type: Final
23
24
class Constraint:
    """A single inferred bound on a type variable.

    With T the type variable identified by 'type_var', the constraint reads
    'T <: target' when 'op' is SUBTYPE_OF and 'T :> target' when 'op' is
    SUPERTYPE_OF.
    """

    type_var = None  # type: TypeVarId
    op = 0           # SUBTYPE_OF or SUPERTYPE_OF
    target = None    # type: Type

    def __init__(self, type_var: TypeVarId, op: int, target: Type) -> None:
        self.type_var, self.op, self.target = type_var, op, target

    def __repr__(self) -> str:
        # Render as e.g. 'T <: int' or 'T :> int'.
        op_str = ':>' if self.op == SUPERTYPE_OF else '<:'
        return '{} {} {}'.format(self.type_var, op_str, self.target)
45
46
def infer_constraints_for_callable(
        callee: CallableType, arg_types: Sequence[Optional[Type]], arg_kinds: List[int],
        formal_to_actual: List[List[int]]) -> List[Constraint]:
    """Collect type variable constraints implied by a call to 'callee'.

    'formal_to_actual' maps each formal parameter index of 'callee' to the
    indices of the actual arguments bound to it; 'arg_types' and 'arg_kinds'
    are indexed by actual argument. Actuals whose type is None are skipped.
    Every remaining actual type is expanded against its formal with an
    ArgTypeExpander, and the formal type is then matched against it with
    direction SUPERTYPE_OF.

    Return the concatenated list of constraints.
    """
    expander = ArgTypeExpander()
    result = []  # type: List[Constraint]
    for formal_index, actual_indices in enumerate(formal_to_actual):
        for actual_index in actual_indices:
            arg_type = arg_types[actual_index]
            if arg_type is None:
                # No type is known for this actual; nothing to infer from it.
                continue
            expanded = expander.expand_actual_type(
                arg_type, arg_kinds[actual_index],
                callee.arg_names[formal_index], callee.arg_kinds[formal_index])
            result += infer_constraints(callee.arg_types[formal_index], expanded,
                                        SUPERTYPE_OF)
    return result
69
70
def infer_constraints(template: Type, actual: Type,
                      direction: int) -> List[Constraint]:
    """Infer type constraints.

    Match a template type, which may contain type variable references,
    recursively against a type which does not contain (the same) type
    variable references. The result is a list of type constraints of
    form 'T is a supertype/subtype of x', where T is a type variable
    present in the template and x is a type without reference to type
    variables present in the template.

    Assume T and S are type variables. Now the following results can be
    calculated (read as '(template, actual) --> result'):

      (T, X)            -->  T :> X
      (X[T], X[Y])      -->  T <: Y and T :> Y
      ((T, T), (X, Y))  -->  T :> X and T :> Y
      ((T, S), (X, Y))  -->  T :> X and S :> Y
      (X[T], Any)       -->  T <: Any and T :> Any

    The constraints are represented as Constraint objects.

    Recursive type aliases are tracked on TypeState._inferring; if the
    template is already being inferred further up the call stack, no
    constraints are produced for it (this breaks infinite recursion).
    """
    if TypeState._inferring:
        proper_template = get_proper_type(template)
        if any(proper_template == get_proper_type(t) for t in TypeState._inferring):
            return []
    if isinstance(template, TypeAliasType) and template.is_recursive:
        # This case requires special care because it may cause infinite recursion.
        TypeState._inferring.append(template)
        try:
            return _infer_constraints(template, actual, direction)
        finally:
            # Always unwind the guard: a stale entry left behind by an exception
            # would silently suppress all later inference for this template.
            TypeState._inferring.pop()
    return _infer_constraints(template, actual, direction)
102
103
def _infer_constraints(template: Type, actual: Type,
                       direction: int) -> List[Constraint]:
    """Worker for infer_constraints(); performs no recursion guarding itself.

    Handles the cases that need whole-type reasoning (type variables and
    unions) and delegates everything else to ConstraintBuilderVisitor.
    """
    # Remember the template as given: when recursing into the items of a union
    # 'actual' we pass it unexpanded, so a recursive alias stays recognizable
    # to the guard in infer_constraints().
    unexpanded_template = template
    template = get_proper_type(template)
    actual = get_proper_type(actual)

    # Type inference shouldn't be affected by whether union types have been simplified.
    # We however keep any ErasedType items, so that the caller will see it when using
    # checkexpr.has_erased_component().
    if isinstance(template, UnionType):
        template = mypy.typeops.make_simplified_union(template.items, keep_erased=True)
    if isinstance(actual, UnionType):
        actual = mypy.typeops.make_simplified_union(actual.items, keep_erased=True)

    # Any produced by the type suggestion engine is ignored, so that it doesn't make
    # us infer Any where a better answer is available. This is a heuristic and
    # may occasionally produce false positives.
    if isinstance(actual, AnyType) and actual.type_of_any == TypeOfAny.suggestion_engine:
        return []

    # A bare type variable template yields exactly one constraint. This must come
    # before the union handling below:
    #  1. "T <: Union[U1, U2]" is not "T <: U1 or T <: U2" -- T may itself be
    #     a union (notably Union[U1, U2]).
    #  2. "T :> Union[U1, U2]" is logically "T :> U1 and T :> U2", but the
    #     solver never introduces new unions (it uses join()), so the two forms
    #     differ in practice.
    if isinstance(template, TypeVarType):
        return [Constraint(template.id, direction, actual)]

    if direction == SUBTYPE_OF:
        if isinstance(template, UnionType):
            # A union is a subtype of X only if every item is: concatenate.
            return [constraint
                    for t_item in template.items
                    for constraint in infer_constraints(t_item, actual, direction)]
        if isinstance(actual, UnionType):
            # The template is neither a union nor a type variable here, so it must
            # be a subtype of *some* item. Incomplete items are dropped when
            # possible, and constraints are chosen eagerly.
            candidates = simplify_away_incomplete_types(actual.items)
            return any_constraints(
                [infer_constraints_if_possible(template, a_item, direction)
                 for a_item in candidates],
                eager=True)
    elif direction == SUPERTYPE_OF:
        if isinstance(actual, UnionType):
            # Every item of the actual union must fit the template: concatenate.
            return [constraint
                    for a_item in actual.items
                    for constraint in infer_constraints(unexpanded_template, a_item,
                                                        direction)]
        if isinstance(template, UnionType):
            # Some template item must be a supertype of the actual. Leaving some
            # type variables indeterminate is acceptable here; this helps in a few
            # special cases, though it isn't very principled.
            return any_constraints(
                [infer_constraints_if_possible(t_item, actual, direction)
                 for t_item in template.items],
                eager=False)

    # Everything else is structural and handled by the visitor.
    return template.accept(ConstraintBuilderVisitor(actual, direction))
175
176
def infer_constraints_if_possible(template: Type, actual: Type,
                                  direction: int) -> Optional[List[Constraint]]:
    """Like infer_constraints(), but report unsatisfiable relations as None.

    Return None when the requested relation between 'template' (with its type
    variables erased) and 'actual' is known not to hold, e.g. template=List[T]
    with actual=int. infer_constraints() alone would return [] in that case,
    indistinguishable from an automatically satisfied relation such as
    template=List[T] with actual=object.
    """
    is_subtype = mypy.subtypes.is_subtype
    if direction == SUBTYPE_OF:
        if not is_subtype(erase_typevars(template), actual):
            return None
    elif direction == SUPERTYPE_OF:
        if not is_subtype(actual, erase_typevars(template)):
            return None
        # erase_typevars() would return 'Any' for a bare type variable, so the
        # check above does not catch this case; compare with the upper bound.
        if (isinstance(template, TypeVarType)
                and not is_subtype(actual, erase_typevars(template.upper_bound))):
            return None
    return infer_constraints(template, actual, direction)
196
197
def any_constraints(options: List[Optional[List[Constraint]]], eager: bool) -> List[Constraint]:
    """Deduce what we can from alternatives of which at least one must hold.

    A None option is unsatisfiable and is dropped. When 'eager' is true, empty
    option lists are dropped as well, since they are trivially satisfiable.
    If exactly one option remains, or all remaining options are the same
    constraint set, return it; otherwise deduce nothing and return [].
    """
    if eager:
        candidates = [opt for opt in options if opt]
    else:
        candidates = [opt for opt in options if opt is not None]
    if not candidates:
        return []
    first, rest = candidates[0], candidates[1:]
    # TODO: More generally, if a given (variable, direction) pair appears in
    #       every option, combine the bounds with meet/join.
    if all(is_same_constraints(first, other) for other in rest):
        return first
    # Multiple inconsistent options: give up.
    return []
223
224
def is_same_constraints(x: List[Constraint], y: List[Constraint]) -> bool:
    """Return True if every constraint in x has a match in y and vice versa.

    Order and multiplicity are ignored; matching uses is_same_constraint().
    """
    def covered(source: List[Constraint], pool: List[Constraint]) -> bool:
        # Each constraint in 'source' must equal some constraint in 'pool'.
        return all(any(is_same_constraint(a, b) for b in pool) for a in source)

    return covered(x, y) and covered(y, x)
233
234
def is_same_constraint(c1: Constraint, c2: Constraint) -> bool:
    """Return True if c1 and c2 bound the same variable, in the same direction, by the same type."""
    if not (c1.type_var == c2.type_var and c1.op == c2.op):
        return False
    return mypy.sametypes.is_same_type(c1.target, c2.target)
239
240
def simplify_away_incomplete_types(types: Iterable[Type]) -> List[Type]:
    """Return the complete types in 'types', or all of them if none is complete.

    'types' is materialized up front. It is typed as an Iterable, and a one-shot
    iterator would otherwise be exhausted by the first pass, leaving the
    fallback branch to return an empty list.
    """
    all_types = list(types)
    complete = [typ for typ in all_types if is_complete_type(typ)]
    return complete if complete else all_types
247
248
def is_complete_type(typ: Type) -> bool:
    """Is a type complete?

    A complete type has no uninhabited type components. The check is
    performed by CompleteTypeVisitor, which rejects UninhabitedType anywhere
    inside 'typ'.
    """
    return typ.accept(CompleteTypeVisitor())
256
257
class CompleteTypeVisitor(TypeQuery[bool]):
    """Type query answering False if the type has an uninhabited component.

    all() is passed to TypeQuery as the strategy for combining the results of
    component types, so a single False component makes the whole result False.
    """

    def visit_uninhabited_type(self, t: UninhabitedType) -> bool:
        # An uninhabited component marks the type as incomplete.
        return False

    def __init__(self) -> None:
        super().__init__(all)
264
265
class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]):
    """Visitor class for inferring type constraints.

    The visitor is applied to a template type; 'self.actual' is the type the
    template is matched against and 'self.direction' (SUBTYPE_OF or
    SUPERTYPE_OF) is the relation being inferred. Type variables and unions are
    handled by _infer_constraints() before the visitor is reached.
    """

    # The type that is compared against a template
    # TODO: The value may be None. Is that actually correct?
    actual = None  # type: ProperType

    def __init__(self, actual: ProperType, direction: int) -> None:
        # Direction must be SUBTYPE_OF or SUPERTYPE_OF.
        self.actual = actual
        self.direction = direction

    # Trivial leaf types

    def visit_unbound_type(self, template: UnboundType) -> List[Constraint]:
        return []

    def visit_any(self, template: AnyType) -> List[Constraint]:
        return []

    def visit_none_type(self, template: NoneType) -> List[Constraint]:
        return []

    def visit_uninhabited_type(self, template: UninhabitedType) -> List[Constraint]:
        return []

    def visit_erased_type(self, template: ErasedType) -> List[Constraint]:
        return []

    def visit_deleted_type(self, template: DeletedType) -> List[Constraint]:
        return []

    def visit_literal_type(self, template: LiteralType) -> List[Constraint]:
        return []

    # Errors

    def visit_partial_type(self, template: PartialType) -> List[Constraint]:
        # We can't do anything useful with a partial type here.
        assert False, "Internal error"

    # Non-trivial leaf type

    def visit_type_var(self, template: TypeVarType) -> List[Constraint]:
        assert False, ("Unexpected TypeVarType in ConstraintBuilderVisitor"
                       " (should have been handled in infer_constraints)")

    # Non-leaf types

    def visit_instance(self, template: Instance) -> List[Constraint]:
        original_actual = actual = self.actual
        res = []  # type: List[Constraint]
        if isinstance(actual, (CallableType, Overloaded)) and template.type.is_protocol:
            if template.type.protocol_members == ['__call__']:
                # Special case: a generic callback protocol
                if not any(mypy.sametypes.is_same_type(template, t)
                           for t in template.type.inferring):
                    template.type.inferring.append(template)
                    try:
                        call = mypy.subtypes.find_member('__call__', template, actual,
                                                         is_operator=True)
                        assert call is not None
                        if mypy.subtypes.is_subtype(actual, erase_typevars(call)):
                            subres = infer_constraints(call, actual, self.direction)
                            res.extend(subres)
                    finally:
                        # Unwind the recursion guard even if inference fails.
                        template.type.inferring.pop()
                    return res
        # Callable, overloaded, TypedDict and literal actuals are matched
        # through their fallback instances.
        if isinstance(actual, CallableType) and actual.fallback is not None:
            actual = actual.fallback
        if isinstance(actual, Overloaded) and actual.fallback is not None:
            actual = actual.fallback
        if isinstance(actual, TypedDictType):
            actual = actual.as_anonymous().fallback
        if isinstance(actual, LiteralType):
            actual = actual.fallback
        if isinstance(actual, Instance):
            instance = actual
            erased = erase_typevars(template)
            assert isinstance(erased, Instance)  # type: ignore
            # We always try nominal inference if possible,
            # it is much faster than the structural one.
            if (self.direction == SUBTYPE_OF and
                    template.type.has_base(instance.type.fullname)):
                mapped = map_instance_to_supertype(template, instance.type)
                tvars = mapped.type.defn.type_vars
                # N.B: We use zip instead of indexing because the lengths might have
                # mismatches during daemon reprocessing.
                for tvar, mapped_arg, instance_arg in zip(tvars, mapped.args, instance.args):
                    # The constraints for generic type parameters depend on variance.
                    # Include constraints from both directions if invariant.
                    if tvar.variance != CONTRAVARIANT:
                        res.extend(infer_constraints(
                            mapped_arg, instance_arg, self.direction))
                    if tvar.variance != COVARIANT:
                        res.extend(infer_constraints(
                            mapped_arg, instance_arg, neg_op(self.direction)))
                return res
            elif (self.direction == SUPERTYPE_OF and
                    instance.type.has_base(template.type.fullname)):
                mapped = map_instance_to_supertype(instance, template.type)
                tvars = template.type.defn.type_vars
                # N.B: We use zip instead of indexing because the lengths might have
                # mismatches during daemon reprocessing.
                for tvar, mapped_arg, template_arg in zip(tvars, mapped.args, template.args):
                    # The constraints for generic type parameters depend on variance.
                    # Include constraints from both directions if invariant.
                    if tvar.variance != CONTRAVARIANT:
                        res.extend(infer_constraints(
                            template_arg, mapped_arg, self.direction))
                    if tvar.variance != COVARIANT:
                        res.extend(infer_constraints(
                            template_arg, mapped_arg, neg_op(self.direction)))
                return res
            if (template.type.is_protocol and self.direction == SUPERTYPE_OF and
                    # We avoid infinite recursion for structural subtypes by checking
                    # whether this type already appeared in the inference chain.
                    # This is a conservative way to break the inference cycles.
                    # It never produces any "false" constraints but gives up soon
                    # on purely structural inference cycles, see #3829.
                    # Note that we use is_protocol_implementation instead of is_subtype
                    # because some type may be considered a subtype of a protocol
                    # due to _promote, but still not implement the protocol.
                    not any(mypy.sametypes.is_same_type(template, t)
                            for t in template.type.inferring) and
                    mypy.subtypes.is_protocol_implementation(instance, erased)):
                template.type.inferring.append(template)
                try:
                    self.infer_constraints_from_protocol_members(res, instance, template,
                                                                 original_actual, template)
                finally:
                    # Unwind the recursion guard even if inference fails.
                    template.type.inferring.pop()
                return res
            elif (instance.type.is_protocol and self.direction == SUBTYPE_OF and
                  # We avoid infinite recursion for structural subtypes also here.
                  not any(mypy.sametypes.is_same_type(instance, i)
                          for i in instance.type.inferring) and
                  mypy.subtypes.is_protocol_implementation(erased, instance)):
                instance.type.inferring.append(instance)
                try:
                    self.infer_constraints_from_protocol_members(res, instance, template,
                                                                 template, instance)
                finally:
                    # Unwind the recursion guard even if inference fails.
                    instance.type.inferring.pop()
                return res
        if isinstance(actual, AnyType):
            # IDEA: Include both ways, i.e. add negation as well?
            return self.infer_against_any(template.args, actual)
        if isinstance(actual, TupleType) and self.direction == SUPERTYPE_OF:
            if any(is_named_instance(template, fullname)
                   for fullname in ('typing.Iterable', 'typing.Container',
                                    'typing.Sequence', 'typing.Reversible')):
                # Each tuple item is matched against the single item type argument.
                for item in actual.items:
                    cb = infer_constraints(template.args[0], item, SUPERTYPE_OF)
                    res.extend(cb)
                return res
            return infer_constraints(template,
                                     mypy.typeops.tuple_fallback(actual),
                                     self.direction)
        return []

    def infer_constraints_from_protocol_members(self, res: List[Constraint],
                                                instance: Instance, template: Instance,
                                                subtype: Type, protocol: Instance) -> None:
        """Infer constraints for situations where either 'template' or 'instance' is a protocol.

        The 'protocol' is the one of the two that is an instance of a protocol type;
        'subtype' is the type used to bind self during inference. Currently, we just
        infer constraints for every protocol member type (both ways for settable
        members). Results are appended to 'res' in place.
        """
        for member in protocol.type.protocol_members:
            inst = mypy.subtypes.find_member(member, instance, subtype)
            temp = mypy.subtypes.find_member(member, template, subtype)
            assert inst is not None and temp is not None
            # The above is safe since at this point we know that 'instance' is a subtype
            # of (erased) 'template', therefore it defines all protocol members
            res.extend(infer_constraints(temp, inst, self.direction))
            if (mypy.subtypes.IS_SETTABLE in
                    mypy.subtypes.get_member_flags(member, protocol.type)):
                # Settable members are invariant, add opposite constraints
                res.extend(infer_constraints(temp, inst, neg_op(self.direction)))

    def visit_callable_type(self, template: CallableType) -> List[Constraint]:
        if isinstance(self.actual, CallableType):
            cactual = self.actual
            # TODO: verify argument counts
            # TODO: what if one of the functions is generic
            res = []  # type: List[Constraint]

            # We can't infer constraints from arguments if the template is Callable[..., T] (with
            # literal '...').
            if not template.is_ellipsis_args:
                # The lengths should match, but don't crash (it will error elsewhere).
                for t, a in zip(template.arg_types, cactual.arg_types):
                    # Negate direction due to function argument type contravariance.
                    res.extend(infer_constraints(t, a, neg_op(self.direction)))
            # A TypeGuard replaces the return type for inference purposes.
            template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
            if template.type_guard is not None:
                template_ret_type = template.type_guard
            if cactual.type_guard is not None:
                cactual_ret_type = cactual.type_guard
            res.extend(infer_constraints(template_ret_type, cactual_ret_type,
                                         self.direction))
            return res
        elif isinstance(self.actual, AnyType):
            # TODO: what if generic
            res = self.infer_against_any(template.arg_types, self.actual)
            any_type = AnyType(TypeOfAny.from_another_any, source_any=self.actual)
            res.extend(infer_constraints(template.ret_type, any_type, self.direction))
            return res
        elif isinstance(self.actual, Overloaded):
            return self.infer_against_overloaded(self.actual, template)
        elif isinstance(self.actual, TypeType):
            return infer_constraints(template.ret_type, self.actual.item, self.direction)
        elif isinstance(self.actual, Instance):
            # Instances with __call__ method defined are considered structural
            # subtypes of Callable with a compatible signature.
            call = mypy.subtypes.find_member('__call__', self.actual, self.actual,
                                             is_operator=True)
            if call:
                return infer_constraints(template, call, self.direction)
            else:
                return []
        else:
            return []

    def infer_against_overloaded(self, overloaded: Overloaded,
                                 template: CallableType) -> List[Constraint]:
        # Create constraints by matching an overloaded type against a template.
        # This is tricky to do in general. We cheat by only matching against
        # the first overload item that is callable compatible. This
        # seems to work somewhat well, but we should really use a more
        # reliable technique.
        item = find_matching_overload_item(overloaded, template)
        return infer_constraints(template, item, self.direction)

    def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
        actual = self.actual
        if isinstance(actual, TupleType) and len(actual.items) == len(template.items):
            # Same-length tuples are matched item by item.
            res = []  # type: List[Constraint]
            for template_item, actual_item in zip(template.items, actual.items):
                res.extend(infer_constraints(template_item, actual_item, self.direction))
            return res
        elif isinstance(actual, AnyType):
            return self.infer_against_any(template.items, actual)
        else:
            return []

    def visit_typeddict_type(self, template: TypedDictType) -> List[Constraint]:
        actual = self.actual
        if isinstance(actual, TypedDictType):
            res = []  # type: List[Constraint]
            # NOTE: Non-matching keys are ignored. Compatibility is checked
            #       elsewhere so this shouldn't be unsafe.
            for _, template_item_type, actual_item_type in template.zip(actual):
                res.extend(infer_constraints(template_item_type,
                                             actual_item_type,
                                             self.direction))
            return res
        elif isinstance(actual, AnyType):
            return self.infer_against_any(template.items.values(), actual)
        else:
            return []

    def visit_union_type(self, template: UnionType) -> List[Constraint]:
        assert False, ("Unexpected UnionType in ConstraintBuilderVisitor"
                       " (should have been handled in infer_constraints)")

    def visit_type_alias_type(self, template: TypeAliasType) -> List[Constraint]:
        assert False, "This should be never called, got {}".format(template)

    def visit_type_guard_type(self, template: TypeGuardType) -> List[Constraint]:
        assert False, "This should be never called, got {}".format(template)

    def infer_against_any(self, types: Iterable[Type], any_type: AnyType) -> List[Constraint]:
        """Match each of 'types' against 'any_type' and concatenate the constraints."""
        res = []  # type: List[Constraint]
        for t in types:
            res.extend(infer_constraints(t, any_type, self.direction))
        return res

    def visit_overloaded(self, template: Overloaded) -> List[Constraint]:
        # Every overload item of the template is matched against the actual.
        res = []  # type: List[Constraint]
        for t in template.items():
            res.extend(infer_constraints(t, self.actual, self.direction))
        return res

    def visit_type_type(self, template: TypeType) -> List[Constraint]:
        if isinstance(self.actual, CallableType):
            return infer_constraints(template.item, self.actual.ret_type, self.direction)
        elif isinstance(self.actual, Overloaded):
            return infer_constraints(template.item, self.actual.items()[0].ret_type,
                                     self.direction)
        elif isinstance(self.actual, TypeType):
            return infer_constraints(template.item, self.actual.item, self.direction)
        elif isinstance(self.actual, AnyType):
            return infer_constraints(template.item, self.actual, self.direction)
        else:
            return []
564
565
def neg_op(op: int) -> int:
    """Return the opposite constraint direction (SUBTYPE_OF <-> SUPERTYPE_OF).

    Raises ValueError for any other value.
    """
    opposite = {SUBTYPE_OF: SUPERTYPE_OF, SUPERTYPE_OF: SUBTYPE_OF}
    if op not in opposite:
        raise ValueError('Invalid operator {}'.format(op))
    return opposite[op]
575
576
def find_matching_overload_item(overloaded: Overloaded, template: CallableType) -> CallableType:
    """Pick the overload item of 'overloaded' to match against 'template'.

    Return the first item that is callable-compatible with 'template' (return
    types are ignored, since the template's may still be indeterminate), or the
    first item if none is. An overload with no items raises IndexError.
    """
    candidates = overloaded.items()
    compatible = (
        item for item in candidates
        if mypy.subtypes.is_callable_compatible(item, template,
                                                is_compat=mypy.subtypes.is_subtype,
                                                ignore_return=True))
    # The first-item fallback is arbitrary -- maybe we should just bail out instead.
    return next(compatible, candidates[0])
590