1"""Representation of Python function headers and calls."""
3import collections
4import itertools
5import logging
7from pytype import abstract_utils
8from pytype import datatypes
9from pytype import utils
10from pytype.pytd import pytd
11from pytype.pytd import pytd_utils
13log = logging.getLogger(__name__)
def argname(i):
  """Return the synthetic name ("_<i>") of the unnamed positional arg at i."""
  return "_{}".format(i)
def _print(t):
  """Print the result of `t.get_instance_type()` with pytd_utils.Print."""
  instance_type = t.get_instance_type()
  return pytd_utils.Print(instance_type)
class Signature:
  """Representation of a Python function signature.

  Attributes:
    name: Name of the function.
    param_names: A tuple of positional parameter names.
    varargs_name: Name of the varargs parameter. (The "args" in *args)
    kwonly_params: Tuple of keyword-only parameters. (Python 3)
      E.g. ("x", "y") for "def f(a, *, x, y=2)". These do NOT appear in
      param_names. Ordered like in the source file.
    kwargs_name: Name of the kwargs parameter. (The "kwargs" in **kwargs)
    defaults: Dictionary, name to value, for all parameters with default values.
    annotations: A dictionary of type annotations. (string to type)
    excluded_types: A set of type names that will be ignored when checking the
      count of type parameters.
    type_params: The set of type parameter names that appear in annotations.
    has_return_annotation: Whether the function has a return annotation.
    has_param_annotations: Whether the function has parameter annotations.
  """

  # NOTE: _ATTRIBUTES (below) is derived from the parameter list of __init__,
  # so the parameter names here double as the set of replaceable attributes.
  def __init__(self, name, param_names, varargs_name, kwonly_params,
               kwargs_name, defaults, annotations,
               postprocess_annotations=True):
    """Initializes the signature.

    Args:
      name: Name of the function.
      param_names: A tuple of positional parameter names.
      varargs_name: Name of the *args parameter, or None.
      kwonly_params: Tuple of keyword-only parameter names.
      kwargs_name: Name of the **kwargs parameter, or None.
      defaults: Dictionary, name to value, of parameter defaults.
      annotations: Dictionary, name to type, of annotations. It is stored by
        reference, and its values are rewritten in place when
        postprocess_annotations is true.
      postprocess_annotations: Whether to run every annotation through
        _postprocess_annotation (which rewrites the *args/**kwargs ones).
    """
    self.name = name
    self.param_names = param_names
    self.varargs_name = varargs_name
    self.kwonly_params = kwonly_params
    self.kwargs_name = kwargs_name
    self.defaults = defaults
    self.annotations = annotations
    self.excluded_types = set()
    if postprocess_annotations:
      # Only values of existing keys are replaced, so mutating while iterating
      # is safe here.
      for k, annot in self.annotations.items():
        self.annotations[k] = self._postprocess_annotation(k, annot)
    self.type_params = set()
    for annot in self.annotations.values():
      self.type_params.update(
          p.name for p in annot.vm.annotations_util.get_type_parameters(annot))

  @property
  def has_return_annotation(self):
    """Whether a "return" entry exists in the annotations."""
    return "return" in self.annotations

  @property
  def has_param_annotations(self):
    """Whether any annotation other than "return" exists."""
    return bool(self.annotations.keys() - {"return"})

  def add_scope(self, module):
    """Add scope for type parameters in annotations."""
    annotations = {}
    for key, val in self.annotations.items():
      annotations[key] = val.vm.annotations_util.add_scope(
          val, self.excluded_types, module)
    self.annotations = annotations

  def _postprocess_annotation(self, name, annotation):
    """Convert a *args/**kwargs annotation; return any other one unchanged."""
    if name == self.varargs_name:
      return annotation.vm.convert.create_new_varargs_value(annotation)
    elif name == self.kwargs_name:
      return annotation.vm.convert.create_new_kwargs_value(annotation)
    else:
      return annotation

  def set_annotation(self, name, annotation):
    """Store the (postprocessed) annotation for parameter `name`."""
    self.annotations[name] = self._postprocess_annotation(name, annotation)

  def del_annotation(self, name):
    """Remove the annotation for `name`."""
    del self.annotations[name]  # Raises KeyError if annotation does not exist.

  def check_type_parameter_count(self, stack):
    """Check the count of type parameters in function."""
    c = collections.Counter()
    for annot in self.annotations.values():
      c.update(annot.vm.annotations_util.get_type_parameters(annot))
    for param, count in c.items():
      if param.name in self.excluded_types:
        # skip all the type parameters in `excluded_types`
        continue
      # A parameter without constraints, bound or variance that is used only
      # once is reported as an invalid annotation.
      if count == 1 and not (param.constraints or param.bound or
                             param.covariant or param.contravariant):
        param.vm.errorlog.invalid_annotation(
            stack, param, "Appears only once in the signature")

  def drop_first_parameter(self):
    """Return a copy of this signature without its first positional param."""
    return self._replace(param_names=self.param_names[1:])

  def mandatory_param_count(self):
    """Count positional and keyword-only parameters that have no default."""
    all_names = itertools.chain(self.param_names, self.kwonly_params)
    return sum(1 for name in all_names if name not in self.defaults)

  def maximum_param_count(self):
    """Return the max number of params, or None if *args/**kwargs exist."""
    if self.varargs_name or self.kwargs_name:
      return None
    return len(self.param_names) + len(self.kwonly_params)

  @classmethod
  def from_pytd(cls, vm, name, sig):
    """Construct an abstract signature from a pytd signature."""
    pytd_annotations = [(p.name, p.type)
                        for p in sig.params + (sig.starargs, sig.starstarargs)
                        if p is not None]
    pytd_annotations.append(("return", sig.return_type))
    def param_to_var(p):
      return vm.convert.constant_to_var(
          p.type, subst=datatypes.AliasingDict(), node=vm.root_node)

    return cls(
        name=name,
        param_names=tuple(p.name for p in sig.params if not p.kwonly),
        varargs_name=None if sig.starargs is None else sig.starargs.name,
        kwonly_params=tuple(p.name for p in sig.params if p.kwonly),
        kwargs_name=None if sig.starstarargs is None else sig.starstarargs.name,
        defaults={p.name: param_to_var(p) for p in sig.params if p.optional},
        annotations={
            name: vm.convert.constant_to_value(
                typ, subst=datatypes.AliasingDict(), node=vm.root_node)
            for name, typ in pytd_annotations
        },
        postprocess_annotations=False,
    )

  @classmethod
  def from_callable(cls, val):
    """Construct a signature with one anonymous param per callable argument.

    Args:
      val: An object exposing `num_args` and an index-addressable
        `formal_type_parameters`.

    Returns:
      A signature named "<callable>" whose params are "_0", "_1", ... in
      positional order.
    """
    # Generate the names in positional order. Sorting the names as strings
    # would put "_10" before "_2" once there are more than ten arguments.
    param_names = tuple(argname(i) for i in range(val.num_args))
    annotations = {name: val.formal_type_parameters[i]
                   for i, name in enumerate(param_names)}
    return cls(
        name="<callable>",
        param_names=param_names,
        varargs_name=None,
        kwonly_params=(),
        kwargs_name=None,
        defaults={},
        annotations=annotations,
    )

  @classmethod
  def from_param_names(cls, name, param_names):
    """Construct a minimal signature from a name and a list of param names."""
    return cls(
        name=name,
        param_names=tuple(param_names),
        varargs_name=None,
        kwonly_params=(),
        kwargs_name=None,
        defaults={},
        annotations={},
    )

  def has_param(self, name):
    """Whether `name` is a positional, keyword-only, *args or **kwargs param."""
    return name in self.param_names or name in self.kwonly_params or (
        name == self.varargs_name or name == self.kwargs_name)

  def insert_varargs_and_kwargs(self, arg_dict):
    """Insert varargs and kwargs from arg_dict into the signature.

    Args:
      arg_dict: A name->binding dictionary of passed args.

    Returns:
      A copy of this signature with the passed varargs and kwargs inserted.
    """
    varargs_names = []
    kwargs_names = []
    for name in arg_dict:
      if self.has_param(name):
        continue
      if pytd_utils.ANON_PARAM.match(name):
        varargs_names.append(name)
      else:
        kwargs_names.append(name)
    # Anonymous names produced by argname() look like "_<index>". Ordering by
    # (length, text) keeps them in numeric order ("_2" before "_10"), which a
    # plain string sort does not.
    varargs_names.sort(key=lambda anon: (len(anon), anon))
    new_param_names = (self.param_names + tuple(varargs_names) +
                       tuple(sorted(kwargs_names)))
    return self._replace(param_names=new_param_names)

  # Names of the constructor arguments that _replace may override: every
  # __init__ parameter except `self` and `postprocess_annotations`.
  _ATTRIBUTES = (
      set(__init__.__code__.co_varnames[:__init__.__code__.co_argcount]) -
      {"self", "postprocess_annotations"})

  def _replace(self, **kwargs):
    """Returns a copy of the signature with the specified values replaced."""
    assert not set(kwargs) - self._ATTRIBUTES
    for attr in self._ATTRIBUTES:
      if attr not in kwargs:
        kwargs[attr] = getattr(self, attr)
    # The annotations being copied were already postprocessed (or deliberately
    # left raw) by the original constructor call.
    kwargs["postprocess_annotations"] = False
    return type(self)(**kwargs)

  def iter_args(self, args):
    """Iterates through the given args, attaching names and expected types.

    Args:
      args: An object with posargs, namedargs, starargs and starstarargs.

    Yields:
      (name, argument, annotation or None) triples: positional args first,
      then named args in sorted order, then *args and **kwargs if both the
      signature and the call have them.
    """
    for i, posarg in enumerate(args.posargs):
      if i < len(self.param_names):
        name = self.param_names[i]
        yield (name, posarg, self.annotations.get(name))
      elif self.varargs_name and self.varargs_name in self.annotations:
        # Extra positional args are checked against the element type of the
        # *args annotation.
        varargs_type = self.annotations[self.varargs_name]
        formal = varargs_type.vm.convert.get_element_type(varargs_type)
        yield (argname(i), posarg, formal)
      else:
        yield (argname(i), posarg, None)
    for name, namedarg in sorted(args.namedargs.items()):
      formal = self.annotations.get(name)
      if formal is None and self.kwargs_name:
        # Fall back to the element type of the **kwargs annotation.
        kwargs_type = self.annotations.get(self.kwargs_name)
        if kwargs_type:
          formal = kwargs_type.vm.convert.get_element_type(kwargs_type)
      yield (name, namedarg, formal)
    if self.varargs_name is not None and args.starargs is not None:
      yield (self.varargs_name, args.starargs,
             self.annotations.get(self.varargs_name))
    if self.kwargs_name is not None and args.starstarargs is not None:
      yield (self.kwargs_name, args.starstarargs,
             self.annotations.get(self.kwargs_name))

  def check_defaults(self):
    """Returns the first non-default param following a default."""
    # TODO(mdemello): We should raise an error here, analogous to
    # the python-compiler-error we would get if analyzing the signature from a
    # source file, but this class does not have access to the vm, and the
    # exception hierarchy in this module derives from FailedFunctionCall.
    has_default = False
    for name in self.param_names:
      if name in self.defaults:
        has_default = True
      elif has_default:
        return name
    return None

  def _yield_arguments(self):
    """Yield all the function arguments."""
    names = list(self.param_names)
    if self.varargs_name:
      names.append("*" + self.varargs_name)
    elif self.kwonly_params:
      # A bare "*" separates positional from keyword-only parameters.
      names.append("*")
    # Keyword-only parameters are printed in sorted order.
    names.extend(sorted(self.kwonly_params))
    if self.kwargs_name:
      names.append("**" + self.kwargs_name)
    for name in names:
      base_name = name.lstrip("*")
      annot = self._print_annot(base_name)
      default = self._print_default(base_name)
      yield name + (": " + annot if annot else "") + (
          " = " + default if default else "")

  def _print_annot(self, name):
    """Return the printed annotation of `name`, or None if it has none."""
    return _print(self.annotations[name]) if name in self.annotations else None

  def _print_default(self, name):
    """Return the printed default of `name`, or None if it has none."""
    if name in self.defaults:
      values = self.defaults[name].data
      if len(values) > 1:
        return "Union[%s]" % ", ".join(_print(v) for v in values)
      else:
        return _print(values[0])
    else:
      return None

  def __repr__(self):
    args = ", ".join(self._yield_arguments())
    ret = self._print_annot("return")
    return "def {name}({args}) -> {ret}".format(
        name=self.name, args=args, ret=ret if ret else "Any")

  def get_first_arg(self, callargs):
    """Look up the first positional param in callargs (None if no params)."""
    return callargs.get(self.param_names[0]) if self.param_names else None
class Args(collections.namedtuple(
    "Args", ["posargs", "namedargs", "starargs", "starstarargs"])):
  """Represents the parameters of a function call."""

  def __new__(cls, posargs, namedargs=None, starargs=None, starstarargs=None):
    """Create arguments for a function under analysis.

    Args:
      posargs: The positional arguments. A tuple of cfg.Variable.
      namedargs: The keyword arguments. A dictionary, mapping strings to
        cfg.Variable.
      starargs: The *args parameter, or None.
      starstarargs: The **kwargs parameter, or None.
    Returns:
      An Args instance.
    """
    assert isinstance(posargs, tuple), posargs
    return super().__new__(
        cls,
        posargs=posargs,
        namedargs=namedargs or {},
        starargs=starargs,
        starstarargs=starstarargs)

  def replace(self, **kwargs):
    """Return a copy of these args with the given fields replaced.

    Public alias of the namedtuple `_replace` method. It is defined on the
    class so that it is available before any instance has been created.
    """
    return self._replace(**kwargs)

  def is_empty(self):
    """Whether no positional, named, *args or **kwargs values were passed."""
    if self.posargs or self.starargs or self.starstarargs:
      return False
    if isinstance(self.namedargs, dict):
      return not self.namedargs
    else:
      # Non-dict namedargs are expected to expose their contents as `pyval`.
      return not self.namedargs.pyval

  def starargs_as_tuple(self, node, vm):
    """Return starargs as a tuple of variables, if it is a constant tuple.

    Args:
      node: The current CFG node.
      vm: The virtual machine.

    Returns:
      The falsy value (None or an empty result) when starargs is missing,
      empty, or cannot be converted; otherwise a tuple in which every variable
      without bindings is replaced by a new "empty" variable.
    """
    try:
      args = self.starargs and abstract_utils.get_atomic_python_constant(
          self.starargs, tuple)
    except abstract_utils.ConversionError:
      args = None
    if not args:
      return args
    return tuple(var if var.bindings else vm.convert.empty.to_variable(node)
                 for var in args)

  def starstarargs_as_dict(self):
    """Return starstarargs as a dict constant, or None if not convertible."""
    try:
      args = self.starstarargs and abstract_utils.get_atomic_python_constant(
          self.starstarargs, dict)
    except abstract_utils.ConversionError:
      args = None
    return args

  def _expand_typed_star(self, vm, node, star, count):
    """Convert *xs: Sequence[T] -> [T, T, ...]."""
    if not count:
      return []
    p = abstract_utils.merged_type_parameter(node, star, abstract_utils.T)
    if not p.bindings:
      # TODO(b/159052609): This shouldn't happen. For some reason,
      # namedtuple instances don't have any bindings in T; see
      # tests/test_unpack:TestUnpack.test_unpack_namedtuple.
      return [vm.new_unsolvable(node) for _ in range(count)]
    return [p.AssignToNewVariable(node) for _ in range(count)]

  def _unpack_and_match_args(self, node, vm, match_signature, starargs_tuple):
    """Match args against a signature with unpacking.

    Args:
      node: The current CFG node.
      vm: The virtual machine.
      match_signature: The signature to match against; its param_names,
        defaults and varargs_name are consulted.
      starargs_tuple: The tuple of variables obtained from starargs.

    Returns:
      A (posargs, starargs) pair, where posargs is a tuple and starargs is a
      variable or None.
    """
    posargs = self.posargs
    namedargs = self.namedargs
    # As we have the function signature we will attempt to adjust the
    # starargs into the missing posargs.
    pre = []
    post = []
    stars = collections.deque(starargs_tuple)
    # Peel the non-splat entries off both ends; whatever remains in `stars`
    # starts and ends with a splat.
    while stars and not abstract_utils.is_var_splat(stars[0]):
      pre.append(stars.popleft())
    while stars and not abstract_utils.is_var_splat(stars[-1]):
      post.append(stars.pop())
    post.reverse()
    n_matched = len(posargs) + len(pre) + len(post)
    # Count the leading params that must be filled positionally: stop at the
    # first one that is passed by name or has a default.
    required_posargs = 0
    for p in match_signature.param_names:
      if p in namedargs or p in match_signature.defaults:
        break
      required_posargs += 1
    posarg_delta = required_posargs - n_matched

    if stars and not post:
      star = stars[-1]
      if match_signature.varargs_name:
        # If the invocation ends with `*args`, return it to match against *args
        # in the function signature. For f(<k args>, *xs, ..., *ys), transform
        # to f(<k args>, *ys) since ys is an indefinite tuple anyway and will
        # match against all remaining posargs.
        return posargs + tuple(pre), abstract_utils.unwrap_splat(star)
      else:
        # If we do not have a `*args` in match_signature, just expand the
        # terminal splat to as many args as needed and then drop it.
        mid = self._expand_typed_star(vm, node, star, posarg_delta)
        return posargs + tuple(pre + mid), None
    elif posarg_delta <= len(stars):
      # We have too many args; don't do *xs expansion. Go back to matching from
      # the start and treat every entry in starargs_tuple as length 1.
      n_params = len(match_signature.param_names)
      all_args = posargs + starargs_tuple
      if not match_signature.varargs_name:
        # If the function sig has no *args, return everything in posargs
        pos = _splats_to_any(all_args, vm)
        return pos, None
      # Don't unwrap splats here because f(*xs, y) is not the same as f(xs, y).
      # TODO(mdemello): Ideally, since we are matching call f(*xs, y) against
      # sig f(x, y) we should raise an error here.
      pos = _splats_to_any(all_args[:n_params], vm)
      star = []
      for var in all_args[n_params:]:
        if abstract_utils.is_var_splat(var):
          star.append(
              abstract_utils.merged_type_parameter(node, var, abstract_utils.T))
        else:
          star.append(var)
      if star:
        return pos, vm.convert.tuple_to_value(star).to_variable(node)
      else:
        return pos, None
    elif stars:
      if len(stars) == 1:
        # Special case (<pre>, *xs) and (*xs, <post>) to fill in the type of xs
        # in every remaining arg.
        mid = self._expand_typed_star(vm, node, stars[0], posarg_delta)
      else:
        # If we have (*xs, <k args>, *ys) remaining, and more than k+2 params to
        # match, don't try to match the intermediate params to any range, just
        # match all k+2 to Any
        mid = [vm.new_unsolvable(node) for _ in range(posarg_delta)]
      return posargs + tuple(pre + mid + post), None
    else:
      # We have **kwargs but no *args in the invocation
      return posargs + tuple(pre), None

  def simplify(self, node, vm, match_signature=None):
    """Try to insert part of *args, **kwargs into posargs / namedargs.

    Args:
      node: The current CFG node.
      vm: The virtual machine.
      match_signature: Optional signature used to decide how *args entries are
        distributed over the positional parameters.

    Returns:
      A new Args instance with as much of *args/**kwargs as possible moved
      into posargs/namedargs.
    """
    # TODO(rechen): When we have type information about *args/**kwargs,
    # we need to check it before doing this simplification.
    posargs = self.posargs
    namedargs = self.namedargs
    starargs = self.starargs
    starstarargs = self.starstarargs
    # Unpack starstarargs into namedargs. We need to do this first so we can see
    # what posargs are still required.
    starstarargs_as_dict = self.starstarargs_as_dict()
    if starstarargs_as_dict is not None:
      # Unlike varargs below, we do not adjust starstarargs into namedargs when
      # the function signature has matching param_names because we have not
      # found a benefit in doing so.
      if namedargs is None:
        namedargs = starstarargs_as_dict
      elif isinstance(namedargs, dict):
        # A plain dict has no node-aware update(node, other), so merge into a
        # copy instead; this also leaves the mapping held by this Args intact.
        namedargs = dict(namedargs)
        namedargs.update(starstarargs_as_dict)
      else:
        namedargs.update(node, starstarargs_as_dict)
      starstarargs = None
    starargs_as_tuple = self.starargs_as_tuple(node, vm)
    if starargs_as_tuple is not None:
      if match_signature:
        posargs, starargs = self._unpack_and_match_args(
            node, vm, match_signature, starargs_as_tuple)
      elif (starargs_as_tuple and
            abstract_utils.is_var_splat(starargs_as_tuple[-1])):
        # If the last arg is an indefinite iterable keep it in starargs. Convert
        # any other splats to Any.
        # TODO(mdemello): If there are multiple splats should we just fall
        # through to the next case (setting them all to Any), and only hit this
        # case for a *single* splat in terminal position?
        posargs = self.posargs + _splats_to_any(starargs_as_tuple[:-1], vm)
        starargs = abstract_utils.unwrap_splat(starargs_as_tuple[-1])
      else:
        # Don't try to unpack iterables in any other position since we don't
        # have a signature to match. Just set all splats to Any.
        posargs = self.posargs + _splats_to_any(starargs_as_tuple, vm)
        starargs = None
    return Args(posargs, namedargs, starargs, starstarargs)

  def get_variables(self):
    """Return a list of all passed values: posargs, namedargs, *args, **kw."""
    variables = list(self.posargs) + list(self.namedargs.values())
    if self.starargs is not None:
      variables.append(self.starargs)
    if self.starstarargs is not None:
      variables.append(self.starstarargs)
    return variables
class ReturnValueMixin:
  """Mixin for exceptions that hold a return node and variable."""

  # Class-level defaults. When this mixin comes after Exception in the bases
  # (e.g. `class Foo(Exception, ReturnValueMixin)`), BaseException.__init__
  # does not chain to the next __init__ in the MRO, so the instance attributes
  # set in __init__ below would never be created. The defaults guarantee that
  # the attributes can always be read.
  return_node = None
  return_variable = None

  def __init__(self):
    super().__init__()
    self.return_node = None
    self.return_variable = None

  def set_return(self, node, var):
    """Record the node and variable to return."""
    self.return_node = node
    self.return_variable = var

  def get_return(self, state):
    """Return (state moved to the recorded node, recorded variable)."""
    return state.change_cfg_node(self.return_node), self.return_variable
499# These names are chosen to match pytype error classes.
500# pylint: disable=g-bad-exception-name
class FailedFunctionCall(Exception, ReturnValueMixin):
  """Exception for failed function calls."""

  def __gt__(self, other):
    # A failure compares greater than None only; any other operand yields
    # False here (subclasses may refine the ordering).
    return other is None
class NotCallable(FailedFunctionCall):
  """For objects that don't have __call__."""

  def __init__(self, obj):
    """Initializes the error.

    Args:
      obj: The object that is not callable; stored as `self.obj`.
    """
    super().__init__()
    self.obj = obj
class UndefinedParameterError(FailedFunctionCall):
  """Function called with an undefined variable."""

  def __init__(self, name):
    """Initializes the error.

    Args:
      name: Stored as `self.name`; presumably the name of the undefined
        variable (confirm against callers).
    """
    super().__init__()
    self.name = name
class DictKeyMissing(Exception, ReturnValueMixin):
  """When retrieving a key that does not exist in a dict."""

  def __init__(self, name):
    """Initializes the error.

    Args:
      name: Stored as `self.name`; presumably the missing key (confirm
        against callers).
    """
    super().__init__()
    self.name = name

  def __gt__(self, other):
    # Same ordering rule as FailedFunctionCall.__gt__: greater than None only.
    return other is None
# Record of a failed call: the signature, the passed args and the bad param.
BadCall = collections.namedtuple("_", "sig passed_args bad_param")
class BadParam(
    collections.namedtuple(
        "_", "name expected protocol_error noniterable_str_error")):
  """A (name, expected, protocol_error, noniterable_str_error) record.

  The two error fields are optional and default to None.
  """

  def __new__(cls, name, expected, protocol_error=None,
              noniterable_str_error=None):
    return super().__new__(
        cls,
        name=name,
        expected=expected,
        protocol_error=protocol_error,
        noniterable_str_error=noniterable_str_error)
class InvalidParameters(FailedFunctionCall):
  """Exception for functions called with an incorrect parameter combination."""

  def __init__(self, sig, passed_args, vm, bad_param=None):
    """Initializes the error.

    Args:
      sig: The signature that was called; provides `name` and `iter_args`.
      passed_args: The arguments of the failed call.
      vm: The virtual machine; used to merge the values of each argument.
      bad_param: Optional description of the offending parameter.
    """
    super().__init__()
    self.name = sig.name
    # Pair every argument name with the merged value of its bindings.
    merged_args = []
    for arg_name, arg, _ in sig.iter_args(passed_args):
      merged_args.append((arg_name, vm.merge_values(arg.data)))
    self.bad_call = BadCall(
        sig=sig, passed_args=merged_args, bad_param=bad_param)
class WrongArgTypes(InvalidParameters):
  """For functions that were called with the wrong types."""

  def __gt__(self, other):
    # Outranks None and every failed call that is not itself a WrongArgTypes.
    if other is None:
      return True
    if not isinstance(other, FailedFunctionCall):
      return False
    return not isinstance(other, WrongArgTypes)
class WrongArgCount(InvalidParameters):
  """For functions called with the wrong number of arguments.

  E.g. if a function expecting 4 parameters is called with 3.
  """
class WrongKeywordArgs(InvalidParameters):
  """E.g. an arg "x" is passed to a function that doesn't have an "x" param."""

  def __init__(self, sig, passed_args, vm, extra_keywords):
    """Initializes the error.

    Args:
      sig: The signature that was called.
      passed_args: The arguments of the failed call.
      vm: The virtual machine.
      extra_keywords: Iterable of the unexpected keyword names; stored as a
        tuple in `self.extra_keywords`.
    """
    super().__init__(sig, passed_args, vm)
    self.extra_keywords = tuple(extra_keywords)
class DuplicateKeyword(InvalidParameters):
  """E.g. an arg "x" is passed to a function as both a posarg and a kwarg."""

  def __init__(self, sig, passed_args, vm, duplicate):
    """Initializes the error.

    Args:
      sig: The signature that was called.
      passed_args: The arguments of the failed call.
      vm: The virtual machine.
      duplicate: The name of the argument that was passed twice; stored as
        `self.duplicate`.
    """
    super().__init__(sig, passed_args, vm)
    self.duplicate = duplicate
class MissingParameter(InvalidParameters):
  """E.g. a function requires parameter 'x' but 'x' isn't passed."""

  def __init__(self, sig, passed_args, vm, missing_parameter):
    """Initializes the error.

    Args:
      sig: The signature that was called.
      passed_args: The arguments of the failed call.
      vm: The virtual machine.
      missing_parameter: The name of the parameter that was not passed; stored
        as `self.missing_parameter`.
    """
    super().__init__(sig, passed_args, vm)
    self.missing_parameter = missing_parameter
594# pylint: enable=g-bad-exception-name
class Mutation(collections.namedtuple("_", ["instance", "name", "value"])):
  """An (instance, name, value) record compared by the data of its value.

  Two mutations are equal when their instance and name are equal and their
  values hold the same set of data, regardless of the identity of the value
  objects themselves.
  """

  def __eq__(self, other):
    return (self.instance == other.instance and
            self.name == other.name and
            frozenset(self.value.data) == frozenset(other.value.data))

  def __ne__(self, other):
    # tuple provides its own __ne__, which would otherwise be inherited and
    # compare the raw fields, disagreeing with the data-based __eq__ above.
    return not self == other

  def __hash__(self):
    return hash((self.instance, self.name, frozenset(self.value.data)))
608class PyTDSignature(utils.VirtualMachineWeakrefMixin):
609  """A PyTD function type (signature).
611  This represents instances of functions with specific arguments and return
612  type.
613  """
615  def __init__(self, name, pytd_sig, vm):
616    super().__init__(vm)
617    self.name = name
618    self.pytd_sig = pytd_sig
619    self.param_types = [
620        self.vm.convert.constant_to_value(
621            p.type, subst=datatypes.AliasingDict(), node=self.vm.root_node)
622        for p in self.pytd_sig.params
623    ]
624    self.signature = Signature.from_pytd(vm, name, pytd_sig)
626  def _map_args(self, args, view):
627    """Map the passed arguments to a name->binding dictionary.
629    Args:
630      args: The passed arguments.
631      view: A variable->binding dictionary.
633    Returns:
634      A tuple of:
635        a list of formal arguments, each a (name, abstract value) pair;
636        a name->binding dictionary of the passed arguments.
638    Raises:
639      InvalidParameters: If the passed arguments don't match this signature.
640    """
641    formal_args = [(p.name, self.signature.annotations[p.name])
642                   for p in self.pytd_sig.params]
643    arg_dict = {}
645    # positional args
646    for name, arg in zip(self.signature.param_names, args.posargs):
647      arg_dict[name] = view[arg]
648    num_expected_posargs = len(self.signature.param_names)
649    if len(args.posargs) > num_expected_posargs and not self.pytd_sig.starargs:
650      raise WrongArgCount(self.signature, args, self.vm)
651    # Extra positional args are passed via the *args argument.
652    varargs_type = self.signature.annotations.get(self.signature.varargs_name)
653    if varargs_type and varargs_type.isinstance_ParameterizedClass():
654      for (i, vararg) in enumerate(args.posargs[num_expected_posargs:]):
655        name = argname(num_expected_posargs + i)
656        arg_dict[name] = view[vararg]
657        formal_args.append(
658            (name, varargs_type.get_formal_type_parameter(abstract_utils.T)))
660    # named args
661    for name, arg in args.namedargs.items():
662      if name in arg_dict:
663        raise DuplicateKeyword(self.signature, args, self.vm, name)
664      arg_dict[name] = view[arg]
665    extra_kwargs = set(args.namedargs) - {p.name for p in self.pytd_sig.params}
666    if extra_kwargs and not self.pytd_sig.starstarargs:
667      raise WrongKeywordArgs(self.signature, args, self.vm, extra_kwargs)
668    # Extra keyword args are passed via the **kwargs argument.
669    kwargs_type = self.signature.annotations.get(self.signature.kwargs_name)
670    if kwargs_type and kwargs_type.isinstance_ParameterizedClass():
671      # We sort the kwargs so that matching always happens in the same order.
672      for name in sorted(extra_kwargs):
673        formal_args.append(
674            (name, kwargs_type.get_formal_type_parameter(abstract_utils.V)))
676    # packed args
677    packed_args = [("starargs", self.signature.varargs_name),
678                   ("starstarargs", self.signature.kwargs_name)]
679    for arg_type, name in packed_args:
680      actual = getattr(args, arg_type)
681      pytd_val = getattr(self.pytd_sig, arg_type)
682      if actual and pytd_val:
683        arg_dict[name] = view[actual]
684        # The annotation is Tuple or Dict, but the passed arg only has to be
685        # Iterable or Mapping.
686        typ = self.vm.convert.widen_type(self.signature.annotations[name])
687        formal_args.append((name, typ))
689    return formal_args, arg_dict
691  def _fill_in_missing_parameters(self, node, args, arg_dict):
692    for p in self.pytd_sig.params:
693      if p.name not in arg_dict:
694        if (not p.optional and args.starargs is None and
695            args.starstarargs is None):
696          raise MissingParameter(self.signature, args, self.vm, p.name)
697        # Assume the missing parameter is filled in by *args or **kwargs.
698        # Unfortunately, we can't easily use *args or **kwargs to fill in
699        # something more precise, since we need a Value, not a Variable.
700        arg_dict[p.name] = self.vm.convert.unsolvable.to_binding(node)
702  def substitute_formal_args(self, node, args, view, alias_map):
703    """Substitute matching args into this signature. Used by PyTDFunction."""
704    formal_args, arg_dict = self._map_args(args, view)
705    self._fill_in_missing_parameters(node, args, arg_dict)
706    subst, bad_arg = self.vm.matcher(node).compute_subst(
707        formal_args, arg_dict, view, alias_map)
708    if subst is None:
709      if self.signature.has_param(bad_arg.name):
710        signature = self.signature
711      else:
712        signature = self.signature.insert_varargs_and_kwargs(arg_dict)
713      raise WrongArgTypes(signature, args, self.vm, bad_param=bad_arg)
714    if log.isEnabledFor(logging.DEBUG):
715      log.debug("Matched arguments against sig%s",
716                pytd_utils.Print(self.pytd_sig))
717    for nr, p in enumerate(self.pytd_sig.params):
718      log.info("param %d) %s: %s <=> %s", nr, p.name, p.type, arg_dict[p.name])
719    for name, var in sorted(subst.items()):
720      log.debug("Using %s=%r %r", name, var, var.data)
722    return arg_dict, subst
724  def instantiate_return(self, node, subst, sources):
725    return_type = self.pytd_sig.return_type
726    # Type parameter values, which are instantiated by the matcher, will end up
727    # in the return value. Since the matcher does not call __init__, we need to
728    # do that now. The one exception is that Type[X] does not instantiate X, so
729    # we do not call X.__init__.
730    if return_type.name != "builtins.type":
731      for param in pytd_utils.GetTypeParameters(return_type):
732        if param.full_name in subst:
733          node = self.vm.call_init(node, subst[param.full_name])
734    try:
735      ret = self.vm.convert.constant_to_var(
736          abstract_utils.AsReturnValue(return_type), subst, node,
737          source_sets=[sources])
738    except self.vm.convert.TypeParameterError:
739      # The return type contains a type parameter without a substitution.
740      subst = subst.copy()
741      for t in pytd_utils.GetTypeParameters(return_type):
742        if t.full_name not in subst:
743          subst[t.full_name] = self.vm.convert.empty.to_variable(node)
744      return node, self.vm.convert.constant_to_var(
745          abstract_utils.AsReturnValue(return_type), subst, node,
746          source_sets=[sources])
747    if not ret.bindings and isinstance(return_type, pytd.TypeParameter):
748      ret.AddBinding(self.vm.convert.empty, [], node)
749    return node, ret
751  def call_with_args(self, node, func, arg_dict,
752                     subst, ret_map, alias_map=None):
753    """Call this signature. Used by PyTDFunction."""
754    t = (self.pytd_sig.return_type, subst)
755    sources = [func] + list(arg_dict.values())
756    if t not in ret_map:
757      node, ret_map[t] = self.instantiate_return(node, subst, sources)
758    else:
759      # add the new sources
760      for data in ret_map[t].data:
761        ret_map[t].AddBinding(data, sources, node)
762    mutations = self._get_mutation(node, arg_dict, subst, ret_map[t])
763    self.vm.trace_call(node, func, (self,),
764                       tuple(arg_dict[p.name] for p in self.pytd_sig.params),
765                       {}, ret_map[t])
766    return node, ret_map[t], mutations
768  @classmethod
769  def _collect_mutated_parameters(cls, typ, mutated_type):
770    if (isinstance(typ, pytd.UnionType) and
771        isinstance(mutated_type, pytd.UnionType)):
772      if len(typ.type_list) != len(mutated_type.type_list):
773        raise ValueError(
774            "Type list lengths do not match:\nOld: %s\nNew: %s" %
775            (typ.type_list, mutated_type.type_list))
776      return itertools.chain.from_iterable(
777          cls._collect_mutated_parameters(t1, t2)
778          for t1, t2 in zip(typ.type_list, mutated_type.type_list))
779    if typ == mutated_type and isinstance(typ, pytd.ClassType):
780      return []  # no mutation needed
781    if (not isinstance(typ, pytd.GenericType) or
782        not isinstance(mutated_type, pytd.GenericType) or
783        typ.base_type != mutated_type.base_type or
784        not isinstance(typ.base_type, pytd.ClassType) or
785        not typ.base_type.cls):
786      raise ValueError("Unsupported mutation:\n%r ->\n%r" %
787                       (typ, mutated_type))
788    return [zip(mutated_type.base_type.cls.template, mutated_type.parameters)]
790  def _get_mutation(self, node, arg_dict, subst, retvar):
791    """Mutation for changing the type parameters of mutable arguments.
793    This will adjust the type parameters as needed for pytd functions like:
794      def append_float(x: list[int]):
795        x = list[int or float]
796    This is called after all the signature matching has succeeded, and we
797    know we're actually calling this function.
799    Args:
800      node: The current CFG node.
801      arg_dict: A map of strings to pytd.Bindings instances.
802      subst: Current type parameters.
803      retvar: A variable of the return value.
804    Returns:
805      A list of Mutation instances.
806    Raises:
807      ValueError: If the pytd contains invalid information for mutated params.
808    """
809    # Handle mutable parameters using the information type parameters
810    mutations = []
811    # It's possible that the signature contains type parameters that are used
812    # in mutations but are not filled in by the arguments, e.g. when starargs
813    # and starstarargs have type parameters but are not in the args. Check that
814    # subst has an entry for every type parameter, adding any that are missing.
815    if any(f.mutated_type for f in self.pytd_sig.params):
816      subst = subst.copy()
817      for t in pytd_utils.GetTypeParameters(self.pytd_sig):
818        if t.full_name not in subst:
819          subst[t.full_name] = self.vm.convert.empty.to_variable(node)
820    for formal in self.pytd_sig.params:
821      actual = arg_dict[formal.name]
822      arg = actual.data
823      if (formal.mutated_type is not None and arg.isinstance_SimpleValue()):
824        try:
825          all_names_actuals = self._collect_mutated_parameters(
826              formal.type, formal.mutated_type)
827        except ValueError as e:
828          log.error("Old: %s", pytd_utils.Print(formal.type))
829          log.error("New: %s", pytd_utils.Print(formal.mutated_type))
830          log.error("Actual: %r", actual)
831          raise ValueError("Mutable parameters setting a type to a "
832                           "different base type is not allowed.") from e
833        for names_actuals in all_names_actuals:
834          for tparam, type_actual in names_actuals:
835            log.info("Mutating %s to %s",
836                     tparam.name,
837                     pytd_utils.Print(type_actual))
838            type_actual_val = self.vm.convert.constant_to_var(
839                abstract_utils.AsInstance(type_actual), subst, node,
840                discard_concrete_values=True)
841            mutations.append(Mutation(arg, tparam.full_name, type_actual_val))
842    if self.name == "__new__":
843      # This is a constructor, so check whether the constructed instance needs
844      # to be mutated.
845      for ret in retvar.data:
846        if ret.cls:
847          for t in ret.cls.template:
848            if t.full_name in subst:
849              mutations.append(Mutation(ret, t.full_name, subst[t.full_name]))
850    return mutations
852  def get_positional_names(self):
853    return [p.name for p in self.pytd_sig.params
854            if not p.kwonly]
856  def set_defaults(self, defaults):
857    """Set signature's default arguments. Requires rebuilding PyTD signature.
859    Args:
860      defaults: An iterable of function argument defaults.
862    Returns:
863      Self with an updated signature.
864    """
865    defaults = list(defaults)
866    params = []
867    for param in reversed(self.pytd_sig.params):
868      if defaults:
869        defaults.pop()  # Discard the default. Unless we want to update type?
870        params.append(pytd.Parameter(
871            name=param.name,
872            type=param.type,
873            kwonly=param.kwonly,
874            optional=True,
875            mutated_type=param.mutated_type
876        ))
877      else:
878        params.append(pytd.Parameter(
879            name=param.name,
880            type=param.type,
881            kwonly=param.kwonly,
882            optional=False,  # Reset any previously-set defaults
883            mutated_type=param.mutated_type
884        ))
885    new_sig = pytd.Signature(
886        params=tuple(reversed(params)),
887        starargs=self.pytd_sig.starargs,
888        starstarargs=self.pytd_sig.starstarargs,
889        return_type=self.pytd_sig.return_type,
890        exceptions=self.pytd_sig.exceptions,
891        template=self.pytd_sig.template
892    )
893    # Now update self
894    self.pytd_sig = new_sig
895    self.param_types = [
896        self.vm.convert.constant_to_value(
897            p.type, subst=datatypes.AliasingDict(), node=self.vm.root_node)
898        for p in self.pytd_sig.params
899    ]
900    self.signature = Signature.from_pytd(self.vm, self.name, self.pytd_sig)
901    return self
903  def __repr__(self):
904    return pytd_utils.Print(self.pytd_sig)
907def _splats_to_any(seq, vm):
908  return tuple(
909      vm.new_unsolvable(vm.root_node) if abstract_utils.is_var_splat(v) else v
910      for v in seq)