1"""Utilities for inline type annotations."""
2
3import collections
4import itertools
5import sys
6
7from pytype import abstract
8from pytype import abstract_utils
9from pytype import class_mixin
10from pytype import mixin
11from pytype import utils
12from pytype.overlays import typing_overlay
13
14
15class AnnotationsUtil(utils.VirtualMachineWeakrefMixin):
16  """Utility class for inline type annotations."""
17
18  def sub_annotations(self, node, annotations, substs, instantiate_unbound):
19    """Apply type parameter substitutions to a dictionary of annotations."""
20    if substs and all(substs):
21      return {name: self.sub_one_annotation(node, annot, substs,
22                                            instantiate_unbound)
23              for name, annot in annotations.items()}
24    return annotations
25
26  def sub_one_annotation(self, node, annot, substs, instantiate_unbound=True):
27    """Apply type parameter substitutions to an annotation."""
28    if isinstance(annot, abstract.TypeParameter):
29      # We use the given substitutions to bind the annotation if
30      # (1) every subst provides at least one binding, and
31      # (2) none of the bindings are ambiguous, and
32      # (3) at least one binding is non-empty.
33      if all(annot.full_name in subst and subst[annot.full_name].bindings
34             for subst in substs):
35        vals = sum((subst[annot.full_name].data for subst in substs), [])
36      else:
37        vals = None
38      if (vals is None or
39          any(isinstance(v, abstract.AMBIGUOUS) for v in vals) or
40          all(isinstance(v, abstract.Empty) for v in vals)):
41        if instantiate_unbound:
42          vals = annot.instantiate(node).data
43        else:
44          vals = [annot]
45      return self.vm.convert.merge_classes(vals)
46    elif isinstance(annot, mixin.NestedAnnotation):
47      inner_types = [(key, self.sub_one_annotation(node, val, substs,
48                                                   instantiate_unbound))
49                     for key, val in annot.get_inner_types()]
50      return annot.replace(inner_types)
51    return annot
52
53  def get_late_annotations(self, annot):
54    if annot.is_late_annotation() and not annot.resolved:
55      yield annot
56    elif isinstance(annot, mixin.NestedAnnotation):
57      for _, typ in annot.get_inner_types():
58        yield from self.get_late_annotations(typ)
59
60  def remove_late_annotations(self, annot):
61    """Replace unresolved late annotations with unsolvables."""
62    if annot.is_late_annotation() and not annot.resolved:
63      return self.vm.convert.unsolvable
64    elif isinstance(annot, mixin.NestedAnnotation):
65      inner_types = [(key, self.remove_late_annotations(val))
66                     for key, val in annot.get_inner_types()]
67      return annot.replace(inner_types)
68    return annot
69
70  def add_scope(self, annot, types, module):
71    """Add scope for type parameters.
72
73    In original type class, all type parameters that should be added a scope
74    will be replaced with a new copy.
75
76    Args:
77      annot: The type class.
78      types: A type name list that should be added a scope.
79      module: Module name.
80
81    Returns:
82      The type with fresh type parameters that have been added the scope.
83    """
84    if isinstance(annot, abstract.TypeParameter):
85      if annot.name in types:
86        new_annot = annot.copy()
87        new_annot.module = module
88        return new_annot
89      return annot
90    elif isinstance(annot, abstract.TupleClass):
91      params = {}
92      for name, param in annot.formal_type_parameters.items():
93        params[name] = self.add_scope(param, types, module)
94      return abstract.TupleClass(
95          annot.base_cls, params, self.vm, annot.template)
96    elif isinstance(annot, mixin.NestedAnnotation):
97      inner_types = [(key, self.add_scope(typ, types, module))
98                     for key, typ in annot.get_inner_types()]
99      return annot.replace(inner_types)
100    return annot
101
102  def get_type_parameters(self, annot, seen=None):
103    """Returns all the TypeParameter instances that appear in the annotation.
104
105    Note that if you just need to know whether or not the annotation contains
106    type parameters, you can check its `.formal` attribute.
107
108    Args:
109      annot: An annotation.
110      seen: A seen set.
111    """
112    seen = seen or set()
113    if annot in seen:
114      return []
115    if isinstance(annot, abstract.ParameterizedClass):
116      # We track parameterized classes to avoid recursion errors when a class
117      # contains itself.
118      seen = seen | {annot}
119    if isinstance(annot, abstract.TypeParameter):
120      return [annot]
121    elif isinstance(annot, abstract.TupleClass):
122      annots = []
123      for idx in range(annot.tuple_length):
124        annots.extend(self.get_type_parameters(
125            annot.formal_type_parameters[idx], seen))
126      return annots
127    elif isinstance(annot, mixin.NestedAnnotation):
128      return sum((self.get_type_parameters(t, seen)
129                  for _, t in annot.get_inner_types()), [])
130    return []
131
132  def get_callable_type_parameter_names(self, var):
133    """Gets all TypeParameter names that appear in a Callable in 'var'."""
134    type_params = set()
135    seen = set()
136    stack = list(var.data)
137    while stack:
138      annot = stack.pop()
139      if annot in seen:
140        continue
141      seen.add(annot)
142      if annot.full_name == "typing.Callable":
143        params = collections.Counter(self.get_type_parameters(annot))
144        if isinstance(annot, abstract.CallableClass):
145          # pytype represents Callable[[T1, T2], None] as
146          # CallableClass({0: T1, 1: T2, ARGS: Union[T1, T2], RET: None}),
147          # so we have to fix double-counting of argument type parameters.
148          params -= collections.Counter(self.get_type_parameters(
149              annot.formal_type_parameters[abstract_utils.ARGS]))
150        # Type parameters that appear only once in a function signature are
151        # invalid, so ignore them.
152        type_params.update(p.name for p, n in params.items() if n > 1)
153      elif isinstance(annot, mixin.NestedAnnotation):
154        stack.extend(v for _, v in annot.get_inner_types())
155    return type_params
156
157  def convert_function_type_annotation(self, name, typ):
158    visible = typ.data
159    if len(visible) > 1:
160      self.vm.errorlog.ambiguous_annotation(self.vm.frames, visible, name)
161      return None
162    else:
163      return visible[0]
164
165  def convert_function_annotations(self, node, raw_annotations):
166    """Convert raw annotations to a {name: annotation} dict."""
167    if raw_annotations:
168      # {"i": int, "return": str} is stored as (int, str, ("i", "return"))
169      names = abstract_utils.get_atomic_python_constant(raw_annotations[-1])
170      type_list = raw_annotations[:-1]
171      annotations_list = []
172      for name, t in zip(names, type_list):
173        name = abstract_utils.get_atomic_python_constant(name)
174        t = self.convert_function_type_annotation(name, t)
175        annotations_list.append((name, t))
176      return self.convert_annotations_list(node, annotations_list)
177    else:
178      return {}
179
180  def convert_annotations_list(self, node, annotations_list):
181    """Convert a (name, raw_annot) list to a {name: annotation} dict."""
182    annotations = {}
183    for name, t in annotations_list:
184      if t is None:
185        continue
186      annot = self._process_one_annotation(
187          node, t, name, self.vm.simple_stack())
188      if annot is not None:
189        annotations[name] = annot
190    return annotations
191
192  def convert_class_annotations(self, node, raw_annotations):
193    """Convert a name -> raw_annot dict to annotations."""
194    annotations = {}
195    raw_items = raw_annotations.items()
196    if sys.version_info[:2] < (3, 6):
197      # Make sure annotation errors are reported in a deterministic order.
198      raw_items = sorted(raw_items, key=str)
199    for name, t in raw_items:
200      # Don't use the parameter name, since it's often something unhelpful
201      # like `0`.
202      annot = self._process_one_annotation(
203          node, t, None, self.vm.simple_stack())
204      annotations[name] = annot or self.vm.convert.unsolvable
205    return annotations
206
207  def init_annotation(self, node, name, annot, container=None, extra_key=None):
208    node, value = self.vm.init_class(
209        node, annot, container=container, extra_key=extra_key)
210    for d in value.data:
211      d.from_annotation = name
212    return node, value
213
214  def extract_and_init_annotation(self, node, name, var):
215    """Extracts an annotation from var and instantiates it."""
216    frame = self.vm.frame
217    substs = frame.substs
218    if frame.func and isinstance(frame.func.data, abstract.BoundFunction):
219      self_var = frame.f_locals.pyval.get("self")
220      if self_var:
221        type_params = []
222        for v in self_var.data:
223          if v.cls:
224            # Normalize type parameter names by dropping the scope.
225            type_params.extend(p.with_module(None) for p in v.cls.template)
226        self_substs = tuple(
227            abstract_utils.get_type_parameter_substitutions(v, type_params)
228            for v in self_var.data)
229        substs = abstract_utils.combine_substs(substs, self_substs)
230    allowed_type_params = set(
231        itertools.chain(*substs, self.get_callable_type_parameter_names(var)))
232    typ = self.extract_annotation(
233        node, var, name, self.vm.simple_stack(),
234        allowed_type_params=allowed_type_params)
235    if typ.formal:
236      resolved_type = self.sub_one_annotation(node, typ, substs,
237                                              instantiate_unbound=False)
238      _, value = self.init_annotation(node, name, resolved_type)
239    else:
240      _, value = self.init_annotation(node, name, typ)
241    return typ, value
242
243  def apply_annotation(self, node, op, name, value):
244    """If there is an annotation for the op, return its value."""
245    assert op is self.vm.frame.current_opcode
246    if op.code.co_filename != self.vm.filename:
247      return None, value
248    if not op.annotation:
249      return None, value
250    annot = op.annotation
251    frame = self.vm.frame
252    with self.vm.generate_late_annotations(self.vm.simple_stack()):
253      var, errorlog = abstract_utils.eval_expr(
254          self.vm, node, frame.f_globals, frame.f_locals, annot)
255    if errorlog:
256      self.vm.errorlog.invalid_annotation(
257          self.vm.frames, annot, details=errorlog.details)
258    return self.extract_and_init_annotation(node, name, var)
259
260  def extract_annotation(
261      self, node, var, name, stack, allowed_type_params=None):
262    """Returns an annotation extracted from 'var'.
263
264    Args:
265      node: The current node.
266      var: The variable to extract from.
267      name: The annotated name.
268      stack: The frame stack.
269      allowed_type_params: Type parameters that are allowed to appear in the
270        annotation. 'None' means all are allowed.
271    """
272    try:
273      typ = abstract_utils.get_atomic_value(var)
274    except abstract_utils.ConversionError:
275      self.vm.errorlog.ambiguous_annotation(self.vm.frames, None, name)
276      return self.vm.convert.unsolvable
277    typ = self._process_one_annotation(node, typ, name, stack)
278    if not typ:
279      return self.vm.convert.unsolvable
280    if typ.formal and allowed_type_params is not None:
281      illegal_params = [x.name for x in self.get_type_parameters(typ)
282                        if x.name not in allowed_type_params]
283      if illegal_params:
284        details = "TypeVar(s) %s not in scope" % ", ".join(
285            repr(p) for p in utils.unique_list(illegal_params))
286        if self.vm.frame.func:
287          method = self.vm.frame.func.data
288          if isinstance(method, abstract.BoundFunction):
289            desc = "class"
290            frame_name = method.name.rsplit(".", 1)[0]
291          else:
292            desc = "class" if method.is_class_builder else "method"
293            frame_name = method.name
294          details += f" for {desc} {frame_name!r}"
295        if "AnyStr" in illegal_params:
296          str_type = "Union[str, bytes]"
297          details += (f"\nNote: For all string types, use {str_type}.")
298        self.vm.errorlog.invalid_annotation(stack, typ, details, name)
299        return self.vm.convert.unsolvable
300    return typ
301
302  def eval_multi_arg_annotation(self, node, func, annot, stack):
303    """Evaluate annotation for multiple arguments (from a type comment)."""
304    args, errorlog = self._eval_expr_as_tuple(node, annot, stack)
305    if errorlog:
306      self.vm.errorlog.invalid_function_type_comment(
307          stack, annot, details=errorlog.details)
308    code = func.code
309    expected = code.get_arg_count()
310    names = code.co_varnames
311
312    # This is a hack.  Specifying the type of the first arg is optional in
313    # class and instance methods.  There is no way to tell at this time
314    # how the function will be used, so if the first arg is self or cls we
315    # make it optional.  The logic is somewhat convoluted because we don't
316    # want to count the skipped argument in an error message.
317    if len(args) != expected:
318      if expected and names[0] in ["self", "cls"]:
319        expected -= 1
320        names = names[1:]
321
322    if len(args) != expected:
323      self.vm.errorlog.invalid_function_type_comment(
324          stack, annot,
325          details="Expected %d args, %d given" % (expected, len(args)))
326      return
327    for name, arg in zip(names, args):
328      resolved = self._process_one_annotation(node, arg, name, stack)
329      if resolved is not None:
330        func.signature.set_annotation(name, resolved)
331
332  def _process_one_annotation(self, node, annotation, name, stack):
333    """Change annotation / record errors where required."""
334    # Make sure we pass in a frozen snapshot of the frame stack, rather than the
335    # actual stack, since late annotations need to snapshot the stack at time of
336    # creation in order to get the right line information for error messages.
337    assert isinstance(stack, tuple), "stack must be an immutable sequence"
338
339    if isinstance(annotation, abstract.AnnotationContainer):
340      annotation = annotation.base_cls
341
342    if isinstance(annotation, typing_overlay.Union):
343      self.vm.errorlog.invalid_annotation(
344          stack, annotation, "Needs options", name)
345      return None
346    elif (name is not None and name != "return"
347          and isinstance(annotation, typing_overlay.NoReturn)):
348      self.vm.errorlog.invalid_annotation(
349          stack, annotation, "NoReturn is not allowed", name)
350      return None
351    elif isinstance(annotation, abstract.Instance) and (
352        annotation.cls == self.vm.convert.str_type or
353        annotation.cls == self.vm.convert.unicode_type
354    ):
355      # String annotations : Late evaluation
356      if isinstance(annotation, mixin.PythonConstant):
357        expr = annotation.pyval
358        if not expr:
359          self.vm.errorlog.invalid_annotation(
360              stack, annotation, "Cannot be an empty string", name)
361          return None
362        frame = self.vm.frame
363        # Immediately try to evaluate the reference, generating LateAnnotation
364        # objects as needed. We don't store the entire string as a
365        # LateAnnotation because:
366        # - With __future__.annotations, all annotations look like forward
367        #   references - most of them don't need to be late evaluated.
368        # - Given an expression like "Union[str, NotYetDefined]", we want to
369        #   evaluate the union immediately so we don't end up with a complex
370        #   LateAnnotation, which can lead to bugs when instantiated.
371        with self.vm.generate_late_annotations(stack):
372          v, errorlog = abstract_utils.eval_expr(
373              self.vm, node, frame.f_globals, frame.f_locals, expr)
374        if errorlog:
375          self.vm.errorlog.copy_from(errorlog.errors, stack)
376        if len(v.data) == 1:
377          return self._process_one_annotation(node, v.data[0], name, stack)
378      self.vm.errorlog.ambiguous_annotation(stack, [annotation], name)
379      return None
380    elif annotation.cls == self.vm.convert.none_type:
381      # PEP 484 allows to write "NoneType" as "None"
382      return self.vm.convert.none_type
383    elif isinstance(annotation, mixin.NestedAnnotation):
384      if annotation.processed:
385        return annotation
386      annotation.processed = True
387      for key, typ in annotation.get_inner_types():
388        processed = self._process_one_annotation(node, typ, name, stack)
389        if processed is None:
390          return None
391        elif isinstance(processed, typing_overlay.NoReturn):
392          self.vm.errorlog.invalid_annotation(
393              stack, typ, "NoReturn is not allowed as inner type", name)
394          return None
395        annotation.update_inner_type(key, processed)
396      return annotation
397    elif isinstance(annotation, (class_mixin.Class,
398                                 abstract.AMBIGUOUS_OR_EMPTY,
399                                 abstract.TypeParameter,
400                                 typing_overlay.NoReturn)):
401      return annotation
402    else:
403      self.vm.errorlog.invalid_annotation(stack, annotation, "Not a type", name)
404      return None
405
406  def _eval_expr_as_tuple(self, node, expr, stack):
407    """Evaluate an expression as a tuple."""
408    if not expr:
409      return (), None
410
411    f_globals, f_locals = self.vm.frame.f_globals, self.vm.frame.f_locals
412    with self.vm.generate_late_annotations(stack):
413      result_var, errorlog = abstract_utils.eval_expr(
414          self.vm, node, f_globals, f_locals, expr)
415    result = abstract_utils.get_atomic_value(result_var)
416    # If the result is a tuple, expand it.
417    if (isinstance(result, mixin.PythonConstant) and
418        isinstance(result.pyval, tuple)):
419      return (tuple(abstract_utils.get_atomic_value(x) for x in result.pyval),
420              errorlog)
421    else:
422      return (result,), errorlog
423
424  def deformalize(self, value):
425    # TODO(rechen): Instead of doing this, call sub_one_annotation() to replace
426    # type parameters with their bound/constraints.
427    while value.formal:
428      if isinstance(value, abstract.ParameterizedClass):
429        value = value.base_cls
430      else:
431        value = self.vm.convert.unsolvable
432    return value
433