"""Representation of Python function headers and calls."""

import collections
import itertools
import logging

from pytype import abstract_utils
from pytype import datatypes
from pytype import utils
from pytype.pytd import pytd
from pytype.pytd import pytd_utils

log = logging.getLogger(__name__)


def argname(i):
  """Get a name for an unnamed positional argument, given its position."""
  return "_" + str(i)


def _anon_param_sort_key(name):
  """Sort key that orders argname()-style names numerically ("_2" < "_10").

  argname() never emits leading zeros, so among names of the form "_<digits>"
  a shorter name always has a smaller index; ties are broken lexically.
  """
  return (len(name), name)


def _print(t):
  """Pretty-print the pytd instance type of the abstract value `t`."""
  return pytd_utils.Print(t.get_instance_type())


class Signature:
  """Representation of a Python function signature.

  Attributes:
    name: Name of the function.
    param_names: A tuple of positional parameter names.
    varargs_name: Name of the varargs parameter. (The "args" in *args)
    kwonly_params: Tuple of keyword-only parameters. (Python 3)
      E.g. ("x", "y") for "def f(a, *, x, y=2)". These do NOT appear in
      param_names. Ordered like in the source file.
    kwargs_name: Name of the kwargs parameter. (The "kwargs" in **kwargs)
    defaults: Dictionary, name to value, for all parameters with default values.
    annotations: A dictionary of type annotations. (string to type)
    excluded_types: A set of type names that will be ignored when checking the
      count of type parameters.
    type_params: The set of type parameter names that appear in annotations.
    has_return_annotation: Whether the function has a return annotation.
    has_param_annotations: Whether the function has parameter annotations.
  """

  def __init__(self, name, param_names, varargs_name, kwonly_params,
               kwargs_name, defaults, annotations,
               postprocess_annotations=True):
    self.name = name
    self.param_names = param_names
    self.varargs_name = varargs_name
    self.kwonly_params = kwonly_params
    self.kwargs_name = kwargs_name
    self.defaults = defaults
    self.annotations = annotations
    self.excluded_types = set()
    if postprocess_annotations:
      # Only values are replaced (never keys added/removed), so updating the
      # dict while iterating over it is safe.
      for k, annot in self.annotations.items():
        self.annotations[k] = self._postprocess_annotation(k, annot)
    self.type_params = set()
    for annot in self.annotations.values():
      self.type_params.update(
          p.name for p in annot.vm.annotations_util.get_type_parameters(annot))

  @property
  def has_return_annotation(self):
    """Whether there is a "return" annotation."""
    return "return" in self.annotations

  @property
  def has_param_annotations(self):
    """Whether there is any annotation other than "return"."""
    return bool(self.annotations.keys() - {"return"})

  def add_scope(self, module):
    """Add scope for type parameters in annotations."""
    annotations = {}
    for key, val in self.annotations.items():
      annotations[key] = val.vm.annotations_util.add_scope(
          val, self.excluded_types, module)
    self.annotations = annotations

  def _postprocess_annotation(self, name, annotation):
    """Convert the annotation of *args / **kwargs; return others unchanged.

    The *args annotation goes through convert.create_new_varargs_value and the
    **kwargs annotation through convert.create_new_kwargs_value.
    """
    if name == self.varargs_name:
      return annotation.vm.convert.create_new_varargs_value(annotation)
    elif name == self.kwargs_name:
      return annotation.vm.convert.create_new_kwargs_value(annotation)
    else:
      return annotation

  def set_annotation(self, name, annotation):
    """Set the annotation of `name`, postprocessing *args / **kwargs."""
    self.annotations[name] = self._postprocess_annotation(name, annotation)

  def del_annotation(self, name):
    """Remove the annotation of `name`."""
    del self.annotations[name]  # Raises KeyError if annotation does not exist.

  def check_type_parameter_count(self, stack):
    """Check the count of type parameters in function."""
    c = collections.Counter()
    for annot in self.annotations.values():
      c.update(annot.vm.annotations_util.get_type_parameters(annot))
    for param, count in c.items():
      if param.name in self.excluded_types:
        # skip all the type parameters in `excluded_types`
        continue
      if count == 1 and not (param.constraints or param.bound or
                             param.covariant or param.contravariant):
        param.vm.errorlog.invalid_annotation(
            stack, param, "Appears only once in the signature")

  def drop_first_parameter(self):
    """Return a copy of this signature without its first positional param."""
    return self._replace(param_names=self.param_names[1:])

  def mandatory_param_count(self):
    """Count positional and keyword-only params that have no default."""
    num = len([name
               for name in self.param_names if name not in self.defaults])
    num += len([name
                for name in self.kwonly_params if name not in self.defaults])
    return num

  def maximum_param_count(self):
    """Max number of params, or None if there is *args or **kwargs."""
    if self.varargs_name or self.kwargs_name:
      return None
    return len(self.param_names) + len(self.kwonly_params)

  @classmethod
  def from_pytd(cls, vm, name, sig):
    """Construct an abstract signature from a pytd signature."""
    pytd_annotations = [(p.name, p.type)
                        for p in sig.params + (sig.starargs, sig.starstarargs)
                        if p is not None]
    pytd_annotations.append(("return", sig.return_type))
    def param_to_var(p):
      return vm.convert.constant_to_var(
          p.type, subst=datatypes.AliasingDict(), node=vm.root_node)

    return cls(
        name=name,
        param_names=tuple(p.name for p in sig.params if not p.kwonly),
        varargs_name=None if sig.starargs is None else sig.starargs.name,
        kwonly_params=tuple(p.name for p in sig.params if p.kwonly),
        kwargs_name=None if sig.starstarargs is None else sig.starstarargs.name,
        defaults={p.name: param_to_var(p) for p in sig.params if p.optional},
        annotations={
            param_name: vm.convert.constant_to_value(
                typ, subst=datatypes.AliasingDict(), node=vm.root_node)
            for param_name, typ in pytd_annotations
        },
        postprocess_annotations=False,
    )

  @classmethod
  def from_callable(cls, val):
    """Construct a signature from a Callable value.

    Parameters are named argname(0) .. argname(val.num_args - 1), in that
    order, and annotated with val.formal_type_parameters[i].
    """
    # Build the names in index order; sorting the strings would put "_10"
    # before "_2" for callables with more than ten arguments.
    param_names = tuple(argname(i) for i in range(val.num_args))
    annotations = {param_name: val.formal_type_parameters[i]
                   for i, param_name in enumerate(param_names)}
    return cls(
        name="<callable>",
        param_names=param_names,
        varargs_name=None,
        kwonly_params=(),
        kwargs_name=None,
        defaults={},
        annotations=annotations,
    )

  @classmethod
  def from_param_names(cls, name, param_names):
    """Construct a minimal signature from a name and a list of param names."""
    return cls(
        name=name,
        param_names=tuple(param_names),
        varargs_name=None,
        kwonly_params=(),
        kwargs_name=None,
        defaults={},
        annotations={},
    )

  def has_param(self, name):
    """Whether `name` is a positional, keyword-only, *args or **kwargs param."""
    return name in self.param_names or name in self.kwonly_params or (
        name == self.varargs_name or name == self.kwargs_name)

  def insert_varargs_and_kwargs(self, arg_dict):
    """Insert varargs and kwargs from arg_dict into the signature.

    Args:
      arg_dict: A name->binding dictionary of passed args.

    Returns:
      A copy of this signature with the passed varargs and kwargs inserted.
    """
    varargs_names = []
    kwargs_names = []
    for name in arg_dict:
      if self.has_param(name):
        continue
      if pytd_utils.ANON_PARAM.match(name):
        varargs_names.append(name)
      else:
        kwargs_names.append(name)
    # Order anonymous positional names by index ("_2" before "_10").
    new_param_names = (
        self.param_names +
        tuple(sorted(varargs_names, key=_anon_param_sort_key)) +
        tuple(sorted(kwargs_names)))
    return self._replace(param_names=new_param_names)

  # The constructor arguments that _replace() copies over from self.
  _ATTRIBUTES = (
      set(__init__.__code__.co_varnames[:__init__.__code__.co_argcount]) -
      {"self", "postprocess_annotations"})

  def _replace(self, **kwargs):
    """Returns a copy of the signature with the specified values replaced."""
    assert not set(kwargs) - self._ATTRIBUTES
    for attr in self._ATTRIBUTES:
      if attr not in kwargs:
        kwargs[attr] = getattr(self, attr)
    kwargs["postprocess_annotations"] = False
    return type(self)(**kwargs)

  def iter_args(self, args):
    """Iterates through the given args, attaching names and expected types."""
    for i, posarg in enumerate(args.posargs):
      if i < len(self.param_names):
        name = self.param_names[i]
        yield (name, posarg, self.annotations.get(name))
      elif self.varargs_name and self.varargs_name in self.annotations:
        varargs_type = self.annotations[self.varargs_name]
        formal = varargs_type.vm.convert.get_element_type(varargs_type)
        yield (argname(i), posarg, formal)
      else:
        yield (argname(i), posarg, None)
    for name, namedarg in sorted(args.namedargs.items()):
      formal = self.annotations.get(name)
      if formal is None and self.kwargs_name:
        kwargs_type = self.annotations.get(self.kwargs_name)
        if kwargs_type:
          formal = kwargs_type.vm.convert.get_element_type(kwargs_type)
      yield (name, namedarg, formal)
    if self.varargs_name is not None and args.starargs is not None:
      yield (self.varargs_name, args.starargs,
             self.annotations.get(self.varargs_name))
    if self.kwargs_name is not None and args.starstarargs is not None:
      yield (self.kwargs_name, args.starstarargs,
             self.annotations.get(self.kwargs_name))

  def check_defaults(self):
    """Returns the first non-default param following a default."""
    # TODO(mdemello): We should raise an error here, analogous to
    # the python-compiler-error we would get if analyzing the signature from a
    # source file, but this class does not have access to the vm, and the
    # exception hierarchy in this module derives from FailedFunctionCall.
    has_default = False
    for name in self.param_names:
      if name in self.defaults:
        has_default = True
      elif has_default:
        return name
    return None

  def _yield_arguments(self):
    """Yield all the function arguments."""
    names = list(self.param_names)
    if self.varargs_name:
      names.append("*" + self.varargs_name)
    elif self.kwonly_params:
      names.append("*")
    names.extend(sorted(self.kwonly_params))
    if self.kwargs_name:
      names.append("**" + self.kwargs_name)
    for name in names:
      base_name = name.lstrip("*")
      annot = self._print_annot(base_name)
      default = self._print_default(base_name)
      yield name + (": " + annot if annot else "") + (
          " = " + default if default else "")

  def _print_annot(self, name):
    """Printed annotation of `name`, or None if it has no annotation."""
    return _print(self.annotations[name]) if name in self.annotations else None

  def _print_default(self, name):
    """Printed default of `name` (a Union if ambiguous), or None if absent."""
    if name in self.defaults:
      values = self.defaults[name].data
      if len(values) > 1:
        return "Union[%s]" % ", ".join(_print(v) for v in values)
      else:
        return _print(values[0])
    else:
      return None

  def __repr__(self):
    args = ", ".join(self._yield_arguments())
    ret = self._print_annot("return")
    return "def {name}({args}) -> {ret}".format(
        name=self.name, args=args, ret=ret if ret else "Any")

  def get_first_arg(self, callargs):
    """Look up the first positional param in `callargs`, if there is one."""
    return callargs.get(self.param_names[0]) if self.param_names else None


class Args(collections.namedtuple(
    "Args", ["posargs", "namedargs", "starargs", "starstarargs"])):
  """Represents the parameters of a function call."""

  def __new__(cls, posargs, namedargs=None, starargs=None, starstarargs=None):
    """Create arguments for a function under analysis.

    Args:
      posargs: The positional arguments. A tuple of cfg.Variable.
      namedargs: The keyword arguments. A dictionary, mapping strings to
        cfg.Variable.
      starargs: The *args parameter, or None.
      starstarargs: The **kwargs parameter, or None.
    Returns:
      An Args instance.
    """
    assert isinstance(posargs, tuple), posargs
    return super().__new__(
        cls,
        posargs=posargs,
        namedargs=namedargs or {},
        starargs=starargs,
        starstarargs=starstarargs)

  def replace(self, **kwargs):
    """Public alias of namedtuple._replace: a copy with fields replaced."""
    return self._replace(**kwargs)

  def is_empty(self):
    """Whether no positional, named, *args or **kwargs arguments are given."""
    if self.posargs or self.starargs or self.starstarargs:
      return False
    if isinstance(self.namedargs, dict):
      return not self.namedargs
    else:
      return not self.namedargs.pyval

  def starargs_as_tuple(self, node, vm):
    """Unpack *args into a tuple of variables if it is a constant tuple.

    Returns:
      None if there is no *args or it is not a constant tuple; an empty tuple
      as is; otherwise the tuple with binding-less variables replaced by
      `vm.convert.empty`.
    """
    try:
      args = self.starargs and abstract_utils.get_atomic_python_constant(
          self.starargs, tuple)
    except abstract_utils.ConversionError:
      args = None
    if not args:
      return args
    return tuple(var if var.bindings else vm.convert.empty.to_variable(node)
                 for var in args)

  def starstarargs_as_dict(self):
    """Unpack **kwargs into a dict if it is a constant dict, else None."""
    try:
      args = self.starstarargs and abstract_utils.get_atomic_python_constant(
          self.starstarargs, dict)
    except abstract_utils.ConversionError:
      args = None
    return args

  def _expand_typed_star(self, vm, node, star, count):
    """Convert *xs: Sequence[T] -> [T, T, ...]."""
    if not count:
      return []
    p = abstract_utils.merged_type_parameter(node, star, abstract_utils.T)
    if not p.bindings:
      # TODO(b/159052609): This shouldn't happen. For some reason,
      # namedtuple instances don't have any bindings in T; see
      # tests/test_unpack:TestUnpack.test_unpack_namedtuple.
      return [vm.new_unsolvable(node) for _ in range(count)]
    return [p.AssignToNewVariable(node) for _ in range(count)]

  def _unpack_and_match_args(self, node, vm, match_signature, starargs_tuple):
    """Match args against a signature with unpacking."""
    posargs = self.posargs
    namedargs = self.namedargs
    # As we have the function signature we will attempt to adjust the
    # starargs into the missing posargs.
    pre = []
    post = []
    stars = collections.deque(starargs_tuple)
    while stars and not abstract_utils.is_var_splat(stars[0]):
      pre.append(stars.popleft())
    while stars and not abstract_utils.is_var_splat(stars[-1]):
      post.append(stars.pop())
    post.reverse()
    n_matched = len(posargs) + len(pre) + len(post)
    required_posargs = 0
    for p in match_signature.param_names:
      if p in namedargs or p in match_signature.defaults:
        break
      required_posargs += 1
    posarg_delta = required_posargs - n_matched

    if stars and not post:
      star = stars[-1]
      if match_signature.varargs_name:
        # If the invocation ends with `*args`, return it to match against *args
        # in the function signature. For f(<k args>, *xs, ..., *ys), transform
        # to f(<k args>, *ys) since ys is an indefinite tuple anyway and will
        # match against all remaining posargs.
        return posargs + tuple(pre), abstract_utils.unwrap_splat(star)
      else:
        # If we do not have a `*args` in match_signature, just expand the
        # terminal splat to as many args as needed and then drop it.
        mid = self._expand_typed_star(vm, node, star, posarg_delta)
        return posargs + tuple(pre + mid), None
    elif posarg_delta <= len(stars):
      # We have too many args; don't do *xs expansion. Go back to matching from
      # the start and treat every entry in starargs_tuple as length 1.
      n_params = len(match_signature.param_names)
      all_args = posargs + starargs_tuple
      if not match_signature.varargs_name:
        # If the function sig has no *args, return everything in posargs
        pos = _splats_to_any(all_args, vm)
        return pos, None
      # Don't unwrap splats here because f(*xs, y) is not the same as f(xs, y).
      # TODO(mdemello): Ideally, since we are matching call f(*xs, y) against
      # sig f(x, y) we should raise an error here.
      pos = _splats_to_any(all_args[:n_params], vm)
      star = []
      for var in all_args[n_params:]:
        if abstract_utils.is_var_splat(var):
          star.append(
              abstract_utils.merged_type_parameter(node, var, abstract_utils.T))
        else:
          star.append(var)
      if star:
        return pos, vm.convert.tuple_to_value(star).to_variable(node)
      else:
        return pos, None
    elif stars:
      if len(stars) == 1:
        # Special case (<pre>, *xs) and (*xs, <post>) to fill in the type of xs
        # in every remaining arg.
        mid = self._expand_typed_star(vm, node, stars[0], posarg_delta)
      else:
        # If we have (*xs, <k args>, *ys) remaining, and more than k+2 params to
        # match, don't try to match the intermediate params to any range, just
        # match all k+2 to Any
        mid = [vm.new_unsolvable(node) for _ in range(posarg_delta)]
      return posargs + tuple(pre + mid + post), None
    else:
      # We have **kwargs but no *args in the invocation
      return posargs + tuple(pre), None

  def simplify(self, node, vm, match_signature=None):
    """Try to insert part of *args, **kwargs into posargs / namedargs."""
    # TODO(rechen): When we have type information about *args/**kwargs,
    # we need to check it before doing this simplification.
    posargs = self.posargs
    namedargs = self.namedargs
    starargs = self.starargs
    starstarargs = self.starstarargs
    # Unpack starstarargs into namedargs. We need to do this first so we can see
    # what posargs are still required.
    starstarargs_as_dict = self.starstarargs_as_dict()
    if starstarargs_as_dict is not None:
      # Unlike varargs below, we do not adjust starstarargs into namedargs when
      # the function signature has matching param_names because we have not
      # found a benefit in doing so.
      if namedargs is None:
        namedargs = starstarargs_as_dict
      elif isinstance(namedargs, dict):
        # A plain dict (the default from __new__) has no update(node, other)
        # method; merge into a fresh dict so self.namedargs is left untouched.
        namedargs = {**namedargs, **starstarargs_as_dict}
      else:
        namedargs.update(node, starstarargs_as_dict)
      starstarargs = None
    starargs_as_tuple = self.starargs_as_tuple(node, vm)
    if starargs_as_tuple is not None:
      if match_signature:
        posargs, starargs = self._unpack_and_match_args(
            node, vm, match_signature, starargs_as_tuple)
      elif (starargs_as_tuple and
            abstract_utils.is_var_splat(starargs_as_tuple[-1])):
        # If the last arg is an indefinite iterable keep it in starargs. Convert
        # any other splats to Any.
        # TODO(mdemello): If there are multiple splats should we just fall
        # through to the next case (setting them all to Any), and only hit this
        # case for a *single* splat in terminal position?
        posargs = self.posargs + _splats_to_any(starargs_as_tuple[:-1], vm)
        starargs = abstract_utils.unwrap_splat(starargs_as_tuple[-1])
      else:
        # Don't try to unpack iterables in any other position since we don't
        # have a signature to match. Just set all splats to Any.
        posargs = self.posargs + _splats_to_any(starargs_as_tuple, vm)
        starargs = None
    return Args(posargs, namedargs, starargs, starstarargs)

  def get_variables(self):
    """All argument variables: posargs, namedargs values, *args, **kwargs."""
    variables = list(self.posargs) + list(self.namedargs.values())
    if self.starargs is not None:
      variables.append(self.starargs)
    if self.starstarargs is not None:
      variables.append(self.starstarargs)
    return variables


class ReturnValueMixin:
  """Mixin for exceptions that hold a return node and variable."""

  # Class-level defaults. Subclasses below list Exception before this mixin,
  # so their super().__init__() ends at BaseException.__init__, which does not
  # continue the cooperative __init__ chain; without these, get_return() would
  # raise AttributeError until set_return() was called.
  return_node = None
  return_variable = None

  def __init__(self):
    super().__init__()
    self.return_node = None
    self.return_variable = None

  def set_return(self, node, var):
    """Record the CFG node and variable to return."""
    self.return_node = node
    self.return_variable = var

  def get_return(self, state):
    """Return (state moved to the return node, the return variable)."""
    return state.change_cfg_node(self.return_node), self.return_variable


# These names are chosen to match pytype error classes.
# pylint: disable=g-bad-exception-name
class FailedFunctionCall(Exception, ReturnValueMixin):
  """Exception for failed function calls."""

  def __gt__(self, other):
    return other is None


class NotCallable(FailedFunctionCall):
  """For objects that don't have __call__."""

  def __init__(self, obj):
    super().__init__()
    self.obj = obj


class UndefinedParameterError(FailedFunctionCall):
  """Function called with an undefined variable."""

  def __init__(self, name):
    super().__init__()
    self.name = name


class DictKeyMissing(Exception, ReturnValueMixin):
  """When retrieving a key that does not exist in a dict."""

  def __init__(self, name):
    super().__init__()
    self.name = name

  def __gt__(self, other):
    return other is None


BadCall = collections.namedtuple("_", ["sig", "passed_args", "bad_param"])


class BadParam(
    collections.namedtuple("_", ["name", "expected", "protocol_error",
                                 "noniterable_str_error"])):
  """A mismatched parameter; the two error fields default to None."""

  def __new__(cls, name, expected, protocol_error=None,
              noniterable_str_error=None):
    return super().__new__(cls, name, expected, protocol_error,
                           noniterable_str_error)


class InvalidParameters(FailedFunctionCall):
  """Exception for functions called with an incorrect parameter combination."""

  def __init__(self, sig, passed_args, vm, bad_param=None):
    super().__init__()
    self.name = sig.name
    passed_args = [(name, vm.merge_values(arg.data))
                   for name, arg, _ in sig.iter_args(passed_args)]
    self.bad_call = BadCall(sig=sig, passed_args=passed_args,
                            bad_param=bad_param)


class WrongArgTypes(InvalidParameters):
  """For functions that were called with the wrong types."""

  def __gt__(self, other):
    return other is None or (isinstance(other, FailedFunctionCall) and
                             not isinstance(other, WrongArgTypes))


class WrongArgCount(InvalidParameters):
  """E.g. if a function expecting 4 parameters is called with 3."""


class WrongKeywordArgs(InvalidParameters):
  """E.g. an arg "x" is passed to a function that doesn't have an "x" param."""

  def __init__(self, sig, passed_args, vm, extra_keywords):
    super().__init__(sig, passed_args, vm)
    self.extra_keywords = tuple(extra_keywords)


class DuplicateKeyword(InvalidParameters):
  """E.g. an arg "x" is passed to a function as both a posarg and a kwarg."""

  def __init__(self, sig, passed_args, vm, duplicate):
    super().__init__(sig, passed_args, vm)
    self.duplicate = duplicate


class MissingParameter(InvalidParameters):
  """E.g. a function requires parameter 'x' but 'x' isn't passed."""

  def __init__(self, sig, passed_args, vm, missing_parameter):
    super().__init__(sig, passed_args, vm)
    self.missing_parameter = missing_parameter
# pylint: enable=g-bad-exception-name


class Mutation(collections.namedtuple("_", ["instance", "name", "value"])):
  """A change of `name` on `instance` to the variable `value`.

  Equality and hashing compare `value` by the set of its data, not by the
  variable object itself.
  """

  def __eq__(self, other):
    if not isinstance(other, Mutation):
      return NotImplemented
    return (self.instance == other.instance and
            self.name == other.name and
            frozenset(self.value.data) == frozenset(other.value.data))

  def __ne__(self, other):
    # tuple defines its own __ne__, which would otherwise be inherited and
    # disagree with the __eq__ above.
    eq = self.__eq__(other)
    return eq if eq is NotImplemented else not eq

  def __hash__(self):
    return hash((self.instance, self.name, frozenset(self.value.data)))


class PyTDSignature(utils.VirtualMachineWeakrefMixin):
  """A PyTD function type (signature).

  This represents instances of functions with specific arguments and return
  type.
  """

  def __init__(self, name, pytd_sig, vm):
    super().__init__(vm)
    self.name = name
    self.pytd_sig = pytd_sig
    self.param_types = [
        self.vm.convert.constant_to_value(
            p.type, subst=datatypes.AliasingDict(), node=self.vm.root_node)
        for p in self.pytd_sig.params
    ]
    self.signature = Signature.from_pytd(vm, name, pytd_sig)

  def _map_args(self, args, view):
    """Map the passed arguments to a name->binding dictionary.

    Args:
      args: The passed arguments.
      view: A variable->binding dictionary.

    Returns:
      A tuple of:
        a list of formal arguments, each a (name, abstract value) pair;
        a name->binding dictionary of the passed arguments.

    Raises:
      InvalidParameters: If the passed arguments don't match this signature.
    """
    formal_args = [(p.name, self.signature.annotations[p.name])
                   for p in self.pytd_sig.params]
    arg_dict = {}

    # positional args
    for name, arg in zip(self.signature.param_names, args.posargs):
      arg_dict[name] = view[arg]
    num_expected_posargs = len(self.signature.param_names)
    if len(args.posargs) > num_expected_posargs and not self.pytd_sig.starargs:
      raise WrongArgCount(self.signature, args, self.vm)
    # Extra positional args are passed via the *args argument.
    varargs_type = self.signature.annotations.get(self.signature.varargs_name)
    if varargs_type and varargs_type.isinstance_ParameterizedClass():
      for (i, vararg) in enumerate(args.posargs[num_expected_posargs:]):
        name = argname(num_expected_posargs + i)
        arg_dict[name] = view[vararg]
        formal_args.append(
            (name, varargs_type.get_formal_type_parameter(abstract_utils.T)))

    # named args
    for name, arg in args.namedargs.items():
      if name in arg_dict:
        raise DuplicateKeyword(self.signature, args, self.vm, name)
      arg_dict[name] = view[arg]
    extra_kwargs = set(args.namedargs) - {p.name for p in self.pytd_sig.params}
    if extra_kwargs and not self.pytd_sig.starstarargs:
      raise WrongKeywordArgs(self.signature, args, self.vm, extra_kwargs)
    # Extra keyword args are passed via the **kwargs argument.
    kwargs_type = self.signature.annotations.get(self.signature.kwargs_name)
    if kwargs_type and kwargs_type.isinstance_ParameterizedClass():
      # We sort the kwargs so that matching always happens in the same order.
      for name in sorted(extra_kwargs):
        formal_args.append(
            (name, kwargs_type.get_formal_type_parameter(abstract_utils.V)))

    # packed args
    packed_args = [("starargs", self.signature.varargs_name),
                   ("starstarargs", self.signature.kwargs_name)]
    for arg_type, name in packed_args:
      actual = getattr(args, arg_type)
      pytd_val = getattr(self.pytd_sig, arg_type)
      if actual and pytd_val:
        arg_dict[name] = view[actual]
        # The annotation is Tuple or Dict, but the passed arg only has to be
        # Iterable or Mapping.
        typ = self.vm.convert.widen_type(self.signature.annotations[name])
        formal_args.append((name, typ))

    return formal_args, arg_dict

  def _fill_in_missing_parameters(self, node, args, arg_dict):
    """Bind unsolvable to params missing from arg_dict.

    Raises:
      MissingParameter: If a required param is missing and the call passes
        neither *args nor **kwargs.
    """
    for p in self.pytd_sig.params:
      if p.name not in arg_dict:
        if (not p.optional and args.starargs is None and
            args.starstarargs is None):
          raise MissingParameter(self.signature, args, self.vm, p.name)
        # Assume the missing parameter is filled in by *args or **kwargs.
        # Unfortunately, we can't easily use *args or **kwargs to fill in
        # something more precise, since we need a Value, not a Variable.
        arg_dict[p.name] = self.vm.convert.unsolvable.to_binding(node)

  def substitute_formal_args(self, node, args, view, alias_map):
    """Substitute matching args into this signature. Used by PyTDFunction."""
    formal_args, arg_dict = self._map_args(args, view)
    self._fill_in_missing_parameters(node, args, arg_dict)
    subst, bad_arg = self.vm.matcher(node).compute_subst(
        formal_args, arg_dict, view, alias_map)
    if subst is None:
      if self.signature.has_param(bad_arg.name):
        signature = self.signature
      else:
        signature = self.signature.insert_varargs_and_kwargs(arg_dict)
      raise WrongArgTypes(signature, args, self.vm, bad_param=bad_arg)
    if log.isEnabledFor(logging.DEBUG):
      log.debug("Matched arguments against sig%s",
                pytd_utils.Print(self.pytd_sig))
      for nr, p in enumerate(self.pytd_sig.params):
        log.debug("param %d) %s: %s <=> %s", nr, p.name, p.type,
                  arg_dict[p.name])
      for name, var in sorted(subst.items()):
        log.debug("Using %s=%r %r", name, var, var.data)

    return arg_dict, subst

  def instantiate_return(self, node, subst, sources):
    """Instantiate the pytd return type under `subst`; returns (node, var)."""
    return_type = self.pytd_sig.return_type
    # Type parameter values, which are instantiated by the matcher, will end up
    # in the return value. Since the matcher does not call __init__, we need to
    # do that now. The one exception is that Type[X] does not instantiate X, so
    # we do not call X.__init__.
    if return_type.name != "builtins.type":
      for param in pytd_utils.GetTypeParameters(return_type):
        if param.full_name in subst:
          node = self.vm.call_init(node, subst[param.full_name])
    try:
      ret = self.vm.convert.constant_to_var(
          abstract_utils.AsReturnValue(return_type), subst, node,
          source_sets=[sources])
    except self.vm.convert.TypeParameterError:
      # The return type contains a type parameter without a substitution.
      subst = subst.copy()
      for t in pytd_utils.GetTypeParameters(return_type):
        if t.full_name not in subst:
          subst[t.full_name] = self.vm.convert.empty.to_variable(node)
      return node, self.vm.convert.constant_to_var(
          abstract_utils.AsReturnValue(return_type), subst, node,
          source_sets=[sources])
    if not ret.bindings and isinstance(return_type, pytd.TypeParameter):
      ret.AddBinding(self.vm.convert.empty, [], node)
    return node, ret

  def call_with_args(self, node, func, arg_dict,
                     subst, ret_map, alias_map=None):
    """Call this signature. Used by PyTDFunction."""
    t = (self.pytd_sig.return_type, subst)
    sources = [func] + list(arg_dict.values())
    if t not in ret_map:
      node, ret_map[t] = self.instantiate_return(node, subst, sources)
    else:
      # add the new sources
      for data in ret_map[t].data:
        ret_map[t].AddBinding(data, sources, node)
    mutations = self._get_mutation(node, arg_dict, subst, ret_map[t])
    self.vm.trace_call(node, func, (self,),
                       tuple(arg_dict[p.name] for p in self.pytd_sig.params),
                       {}, ret_map[t])
    return node, ret_map[t], mutations

  @classmethod
  def _collect_mutated_parameters(cls, typ, mutated_type):
    """Pair template params of `typ` with the params of `mutated_type`.

    Raises:
      ValueError: If the two types are not a supported mutation pair.
    """
    if (isinstance(typ, pytd.UnionType) and
        isinstance(mutated_type, pytd.UnionType)):
      if len(typ.type_list) != len(mutated_type.type_list):
        raise ValueError(
            "Type list lengths do not match:\nOld: %s\nNew: %s" %
            (typ.type_list, mutated_type.type_list))
      return itertools.chain.from_iterable(
          cls._collect_mutated_parameters(t1, t2)
          for t1, t2 in zip(typ.type_list, mutated_type.type_list))
    if typ == mutated_type and isinstance(typ, pytd.ClassType):
      return []  # no mutation needed
    if (not isinstance(typ, pytd.GenericType) or
        not isinstance(mutated_type, pytd.GenericType) or
        typ.base_type != mutated_type.base_type or
        not isinstance(typ.base_type, pytd.ClassType) or
        not typ.base_type.cls):
      raise ValueError("Unsupported mutation:\n%r ->\n%r" %
                       (typ, mutated_type))
    return [zip(mutated_type.base_type.cls.template, mutated_type.parameters)]

  def _get_mutation(self, node, arg_dict, subst, retvar):
    """Mutation for changing the type parameters of mutable arguments.

    This will adjust the type parameters as needed for pytd functions like:
      def append_float(x: list[int]):
        x = list[int or float]
    This is called after all the signature matching has succeeded, and we
    know we're actually calling this function.

    Args:
      node: The current CFG node.
      arg_dict: A map of strings to pytd.Bindings instances.
      subst: Current type parameters.
      retvar: A variable of the return value.
    Returns:
      A list of Mutation instances.
    Raises:
      ValueError: If the pytd contains invalid information for mutated params.
    """
    # Handle mutable parameters using the information type parameters
    mutations = []
    # It's possible that the signature contains type parameters that are used
    # in mutations but are not filled in by the arguments, e.g. when starargs
    # and starstarargs have type parameters but are not in the args. Check that
    # subst has an entry for every type parameter, adding any that are missing.
    if any(f.mutated_type for f in self.pytd_sig.params):
      subst = subst.copy()
      for t in pytd_utils.GetTypeParameters(self.pytd_sig):
        if t.full_name not in subst:
          subst[t.full_name] = self.vm.convert.empty.to_variable(node)
    for formal in self.pytd_sig.params:
      actual = arg_dict[formal.name]
      arg = actual.data
      if (formal.mutated_type is not None and arg.isinstance_SimpleValue()):
        try:
          all_names_actuals = self._collect_mutated_parameters(
              formal.type, formal.mutated_type)
        except ValueError as e:
          log.error("Old: %s", pytd_utils.Print(formal.type))
          log.error("New: %s", pytd_utils.Print(formal.mutated_type))
          log.error("Actual: %r", actual)
          raise ValueError("Mutable parameters setting a type to a "
                           "different base type is not allowed.") from e
        for names_actuals in all_names_actuals:
          for tparam, type_actual in names_actuals:
            log.info("Mutating %s to %s",
                     tparam.name,
                     pytd_utils.Print(type_actual))
            type_actual_val = self.vm.convert.constant_to_var(
                abstract_utils.AsInstance(type_actual), subst, node,
                discard_concrete_values=True)
            mutations.append(Mutation(arg, tparam.full_name, type_actual_val))
    if self.name == "__new__":
      # This is a constructor, so check whether the constructed instance needs
      # to be mutated.
      for ret in retvar.data:
        if ret.cls:
          for t in ret.cls.template:
            if t.full_name in subst:
              mutations.append(Mutation(ret, t.full_name, subst[t.full_name]))
    return mutations

  def get_positional_names(self):
    """Names of the params of the pytd signature that are not keyword-only."""
    return [p.name for p in self.pytd_sig.params
            if not p.kwonly]

  def set_defaults(self, defaults):
    """Set signature's default arguments. Requires rebuilding PyTD signature.

    Args:
      defaults: An iterable of function argument defaults.

    Returns:
      Self with an updated signature.
    """
    defaults = list(defaults)
    params = []
    for param in reversed(self.pytd_sig.params):
      if defaults:
        defaults.pop()  # Discard the default. Unless we want to update type?
        params.append(pytd.Parameter(
            name=param.name,
            type=param.type,
            kwonly=param.kwonly,
            optional=True,
            mutated_type=param.mutated_type
        ))
      else:
        params.append(pytd.Parameter(
            name=param.name,
            type=param.type,
            kwonly=param.kwonly,
            optional=False,  # Reset any previously-set defaults
            mutated_type=param.mutated_type
        ))
    new_sig = pytd.Signature(
        params=tuple(reversed(params)),
        starargs=self.pytd_sig.starargs,
        starstarargs=self.pytd_sig.starstarargs,
        return_type=self.pytd_sig.return_type,
        exceptions=self.pytd_sig.exceptions,
        template=self.pytd_sig.template
    )
    # Now update self
    self.pytd_sig = new_sig
    self.param_types = [
        self.vm.convert.constant_to_value(
            p.type, subst=datatypes.AliasingDict(), node=self.vm.root_node)
        for p in self.pytd_sig.params
    ]
    self.signature = Signature.from_pytd(self.vm, self.name, self.pytd_sig)
    return self

  def __repr__(self):
    return pytd_utils.Print(self.pytd_sig)


def _splats_to_any(seq, vm):
  """Replace every splat variable in `seq` with an unsolvable variable."""
  return tuple(
      vm.new_unsolvable(vm.root_node) if abstract_utils.is_var_splat(v) else v
      for v in seq)