"""Utilities for inline type annotations."""

import collections
import itertools
import sys

from pytype import abstract
from pytype import abstract_utils
from pytype import class_mixin
from pytype import mixin
from pytype import utils
from pytype.overlays import typing_overlay


class AnnotationsUtil(utils.VirtualMachineWeakrefMixin):
  """Utility class for inline type annotations."""

  def sub_annotations(self, node, annotations, substs, instantiate_unbound):
    """Apply type parameter substitutions to a dictionary of annotations.

    Args:
      node: The current node.
      annotations: A {name: annotation} dict.
      substs: A sequence of type parameter substitution mappings.
      instantiate_unbound: Passed through to sub_one_annotation().

    Returns:
      A new {name: annotation} dict with substitutions applied, or the
      original `annotations` object unchanged if `substs` is empty or any of
      its elements is falsy.
    """
    if substs and all(substs):
      return {name: self.sub_one_annotation(node, annot, substs,
                                            instantiate_unbound)
              for name, annot in annotations.items()}
    return annotations

  def sub_one_annotation(self, node, annot, substs, instantiate_unbound=True):
    """Apply type parameter substitutions to an annotation.

    Args:
      node: The current node.
      annot: The annotation to substitute into.
      substs: A sequence of mappings from type parameter full names to
        variables.
      instantiate_unbound: When a TypeParameter cannot be bound from `substs`,
        whether to replace it with an instance of the parameter (True) or to
        keep the parameter itself (False).

    Returns:
      For a TypeParameter, the merged class of its bound values. For a
      NestedAnnotation, the result of annot.replace() with substituted inner
      types. Anything else is returned unchanged.
    """
    if isinstance(annot, abstract.TypeParameter):
      # We use the given substitutions to bind the annotation if
      # (1) every subst provides at least one binding, and
      # (2) none of the bindings are ambiguous, and
      # (3) at least one binding is non-empty.
      if all(annot.full_name in subst and subst[annot.full_name].bindings
             for subst in substs):
        # chain.from_iterable flattens in linear time; sum(lists, []) is
        # quadratic in the number of substs.
        vals = list(itertools.chain.from_iterable(
            subst[annot.full_name].data for subst in substs))
      else:
        vals = None
      if (vals is None or
          any(isinstance(v, abstract.AMBIGUOUS) for v in vals) or
          all(isinstance(v, abstract.Empty) for v in vals)):
        if instantiate_unbound:
          vals = annot.instantiate(node).data
        else:
          vals = [annot]
      return self.vm.convert.merge_classes(vals)
    elif isinstance(annot, mixin.NestedAnnotation):
      inner_types = [(key, self.sub_one_annotation(node, val, substs,
                                                   instantiate_unbound))
                     for key, val in annot.get_inner_types()]
      return annot.replace(inner_types)
    return annot

  def get_late_annotations(self, annot):
    """Yield every unresolved late annotation contained in `annot`.

    Args:
      annot: An annotation. Nested annotations are searched recursively.

    Yields:
      Late annotations whose `resolved` attribute is falsy.
    """
    if annot.is_late_annotation() and not annot.resolved:
      yield annot
    elif isinstance(annot, mixin.NestedAnnotation):
      for _, typ in annot.get_inner_types():
        yield from self.get_late_annotations(typ)

  def remove_late_annotations(self, annot):
    """Replace unresolved late annotations with unsolvables.

    Args:
      annot: An annotation. Nested annotations are processed recursively and
        rebuilt via annot.replace().

    Returns:
      The annotation with every unresolved late annotation replaced by
      vm.convert.unsolvable.
    """
    if annot.is_late_annotation() and not annot.resolved:
      return self.vm.convert.unsolvable
    elif isinstance(annot, mixin.NestedAnnotation):
      inner_types = [(key, self.remove_late_annotations(val))
                     for key, val in annot.get_inner_types()]
      return annot.replace(inner_types)
    return annot

  def add_scope(self, annot, types, module):
    """Add scope for type parameters.

    Every type parameter in `annot` whose name is in `types` is replaced with
    a copy whose `module` attribute is set to `module`.

    Args:
      annot: The type class.
      types: A list of the type parameter names that should be given a scope.
      module: Module name.

    Returns:
      The type with fresh, scoped copies of the selected type parameters.
    """
    if isinstance(annot, abstract.TypeParameter):
      if annot.name in types:
        new_annot = annot.copy()
        new_annot.module = module
        return new_annot
      return annot
    elif isinstance(annot, abstract.TupleClass):
      params = {}
      for name, param in annot.formal_type_parameters.items():
        params[name] = self.add_scope(param, types, module)
      return abstract.TupleClass(
          annot.base_cls, params, self.vm, annot.template)
    elif isinstance(annot, mixin.NestedAnnotation):
      inner_types = [(key, self.add_scope(typ, types, module))
                     for key, typ in annot.get_inner_types()]
      return annot.replace(inner_types)
    return annot

  def get_type_parameters(self, annot, seen=None):
    """Returns all the TypeParameter instances that appear in the annotation.

    Note that if you just need to know whether or not the annotation contains
    type parameters, you can check its `.formal` attribute.

    Args:
      annot: An annotation.
      seen: A seen set. It is never mutated; a new set is created when a
        ParameterizedClass is entered.

    Returns:
      A list of TypeParameter instances in traversal order. Duplicates are
      kept, because callers count occurrences.
    """
    seen = seen or set()
    if annot in seen:
      return []
    if isinstance(annot, abstract.ParameterizedClass):
      # We track parameterized classes to avoid recursion errors when a class
      # contains itself.
      seen = seen | {annot}
    if isinstance(annot, abstract.TypeParameter):
      return [annot]
    elif isinstance(annot, abstract.TupleClass):
      annots = []
      for idx in range(annot.tuple_length):
        annots.extend(self.get_type_parameters(
            annot.formal_type_parameters[idx], seen))
      return annots
    elif isinstance(annot, mixin.NestedAnnotation):
      # Linear-time flatten (sum(lists, []) would be quadratic).
      return list(itertools.chain.from_iterable(
          self.get_type_parameters(t, seen)
          for _, t in annot.get_inner_types()))
    return []

  def get_callable_type_parameter_names(self, var):
    """Gets all TypeParameter names that appear in a Callable in 'var'.

    Args:
      var: A variable whose data are searched (including nested annotations).

    Returns:
      A set of type parameter names that occur more than once in some
      typing.Callable found in `var`.
    """
    type_params = set()
    seen = set()
    stack = list(var.data)
    while stack:
      annot = stack.pop()
      if annot in seen:
        continue
      seen.add(annot)
      if annot.full_name == "typing.Callable":
        params = collections.Counter(self.get_type_parameters(annot))
        if isinstance(annot, abstract.CallableClass):
          # pytype represents Callable[[T1, T2], None] as
          # CallableClass({0: T1, 1: T2, ARGS: Union[T1, T2], RET: None}),
          # so we have to fix double-counting of argument type parameters.
          params -= collections.Counter(self.get_type_parameters(
              annot.formal_type_parameters[abstract_utils.ARGS]))
        # Type parameters that appear only once in a function signature are
        # invalid, so ignore them.
        type_params.update(p.name for p, n in params.items() if n > 1)
      elif isinstance(annot, mixin.NestedAnnotation):
        stack.extend(v for _, v in annot.get_inner_types())
    return type_params

  def convert_function_type_annotation(self, name, typ):
    """Return the single value held by the annotation variable `typ`.

    Args:
      name: The annotated name, used for error reporting.
      typ: A variable holding the raw annotation.

    Returns:
      The sole element of typ.data. If typ.data has more than one element,
      an ambiguous_annotation error is logged and None is returned.
    """
    visible = typ.data
    if len(visible) > 1:
      self.vm.errorlog.ambiguous_annotation(self.vm.frames, visible, name)
      return None
    else:
      # NOTE(review): an empty typ.data raises IndexError here. Confirm that
      # callers never pass an empty variable.
      return visible[0]

  def convert_function_annotations(self, node, raw_annotations):
    """Convert raw annotations to a {name: annotation} dict.

    Args:
      node: The current node.
      raw_annotations: A sequence of annotation variables followed by a final
        variable holding the tuple of annotated names.

    Returns:
      A {name: annotation} dict. Invalid annotations are omitted.
    """
    if raw_annotations:
      # {"i": int, "return": str} is stored as (int, str, ("i", "return"))
      names = abstract_utils.get_atomic_python_constant(raw_annotations[-1])
      type_list = raw_annotations[:-1]
      annotations_list = []
      for name, t in zip(names, type_list):
        name = abstract_utils.get_atomic_python_constant(name)
        t = self.convert_function_type_annotation(name, t)
        annotations_list.append((name, t))
      return self.convert_annotations_list(node, annotations_list)
    else:
      return {}

  def convert_annotations_list(self, node, annotations_list):
    """Convert a (name, raw_annot) list to a {name: annotation} dict.

    Entries whose raw annotation is None, or whose processing fails, are
    omitted from the result.
    """
    annotations = {}
    for name, t in annotations_list:
      if t is None:
        continue
      annot = self._process_one_annotation(
          node, t, name, self.vm.simple_stack())
      if annot is not None:
        annotations[name] = annot
    return annotations

  def convert_class_annotations(self, node, raw_annotations):
    """Convert a name -> raw_annot dict to annotations.

    Unlike convert_annotations_list(), invalid annotations are kept as
    vm.convert.unsolvable instead of being dropped.
    """
    annotations = {}
    raw_items = raw_annotations.items()
    # NOTE(review): this module uses f-strings (Python 3.6+), so this branch
    # appears unreachable. It is kept only for safety.
    if sys.version_info[:2] < (3, 6):
      # Make sure annotation errors are reported in a deterministic order.
      raw_items = sorted(raw_items, key=str)
    for name, t in raw_items:
      # Don't use the parameter name, since it's often something unhelpful
      # like `0`.
      annot = self._process_one_annotation(
          node, t, None, self.vm.simple_stack())
      annotations[name] = annot or self.vm.convert.unsolvable
    return annotations

  def init_annotation(self, node, name, annot, container=None, extra_key=None):
    """Instantiate `annot` and tag each resulting value with `name`.

    Args:
      node: The current node.
      name: Stored on each instantiated value as `from_annotation`.
      annot: The annotation class to instantiate via vm.init_class().
      container: Passed through to vm.init_class().
      extra_key: Passed through to vm.init_class().

    Returns:
      A (node, value) tuple from vm.init_class().
    """
    node, value = self.vm.init_class(
        node, annot, container=container, extra_key=extra_key)
    for d in value.data:
      d.from_annotation = name
    return node, value

  def extract_and_init_annotation(self, node, name, var):
    """Extracts an annotation from var and instantiates it.

    Type parameters from the current frame's substitutions are allowed. When
    executing a bound method, so are those from the classes of `self`, plus
    any that appear more than once in a Callable within `var`. Formal
    annotations are substituted, without instantiating unbound parameters,
    before being instantiated.

    Returns:
      A (typ, value) tuple of the extracted annotation and its instance.
    """
    frame = self.vm.frame
    substs = frame.substs
    if frame.func and isinstance(frame.func.data, abstract.BoundFunction):
      self_var = frame.f_locals.pyval.get("self")
      if self_var:
        type_params = []
        for v in self_var.data:
          if v.cls:
            # Normalize type parameter names by dropping the scope.
            type_params.extend(p.with_module(None) for p in v.cls.template)
        self_substs = tuple(
            abstract_utils.get_type_parameter_substitutions(v, type_params)
            for v in self_var.data)
        substs = abstract_utils.combine_substs(substs, self_substs)
    allowed_type_params = set(
        itertools.chain(*substs, self.get_callable_type_parameter_names(var)))
    typ = self.extract_annotation(
        node, var, name, self.vm.simple_stack(),
        allowed_type_params=allowed_type_params)
    if typ.formal:
      resolved_type = self.sub_one_annotation(node, typ, substs,
                                              instantiate_unbound=False)
      _, value = self.init_annotation(node, name, resolved_type)
    else:
      _, value = self.init_annotation(node, name, typ)
    return typ, value

  def apply_annotation(self, node, op, name, value):
    """If there is an annotation for the op, return its value.

    Returns:
      (None, value) if `op` has no annotation or comes from a file other than
      vm.filename. Otherwise, the (typ, value) tuple from
      extract_and_init_annotation() for the evaluated annotation.
    """
    assert op is self.vm.frame.current_opcode
    if op.code.co_filename != self.vm.filename:
      return None, value
    if not op.annotation:
      return None, value
    annot = op.annotation
    frame = self.vm.frame
    with self.vm.generate_late_annotations(self.vm.simple_stack()):
      var, errorlog = abstract_utils.eval_expr(
          self.vm, node, frame.f_globals, frame.f_locals, annot)
    if errorlog:
      self.vm.errorlog.invalid_annotation(
          self.vm.frames, annot, details=errorlog.details)
    return self.extract_and_init_annotation(node, name, var)

  def extract_annotation(
      self, node, var, name, stack, allowed_type_params=None):
    """Returns an annotation extracted from 'var'.

    Args:
      node: The current node.
      var: The variable to extract from.
      name: The annotated name.
      stack: The frame stack.
      allowed_type_params: Type parameters that are allowed to appear in the
        annotation. 'None' means all are allowed.

    Returns:
      The processed annotation. vm.convert.unsolvable is returned, after an
      error is logged where applicable, if `var` does not hold exactly one
      value, the annotation is invalid, or it uses disallowed type parameters.
    """
    try:
      typ = abstract_utils.get_atomic_value(var)
    except abstract_utils.ConversionError:
      self.vm.errorlog.ambiguous_annotation(self.vm.frames, None, name)
      return self.vm.convert.unsolvable
    typ = self._process_one_annotation(node, typ, name, stack)
    if not typ:
      return self.vm.convert.unsolvable
    if typ.formal and allowed_type_params is not None:
      illegal_params = [x.name for x in self.get_type_parameters(typ)
                        if x.name not in allowed_type_params]
      if illegal_params:
        details = "TypeVar(s) %s not in scope" % ", ".join(
            repr(p) for p in utils.unique_list(illegal_params))
        if self.vm.frame.func:
          method = self.vm.frame.func.data
          if isinstance(method, abstract.BoundFunction):
            desc = "class"
            # Strip the trailing ".method" to report the owning class.
            frame_name = method.name.rsplit(".", 1)[0]
          else:
            desc = "class" if method.is_class_builder else "method"
            frame_name = method.name
          details += f" for {desc} {frame_name!r}"
        if "AnyStr" in illegal_params:
          str_type = "Union[str, bytes]"
          details += (f"\nNote: For all string types, use {str_type}.")
        self.vm.errorlog.invalid_annotation(stack, typ, details, name)
        return self.vm.convert.unsolvable
    return typ

  def eval_multi_arg_annotation(self, node, func, annot, stack):
    """Evaluate annotation for multiple arguments (from a type comment).

    On success, each resolved argument annotation is recorded on
    func.signature. An argument-count mismatch logs an
    invalid_function_type_comment error and records nothing.
    """
    args, errorlog = self._eval_expr_as_tuple(node, annot, stack)
    if errorlog:
      self.vm.errorlog.invalid_function_type_comment(
          stack, annot, details=errorlog.details)
    code = func.code
    expected = code.get_arg_count()
    names = code.co_varnames

    # This is a hack.  Specifying the type of the first arg is optional in
    # class and instance methods.  There is no way to tell at this time
    # how the function will be used, so if the first arg is self or cls we
    # make it optional.  The logic is somewhat convoluted because we don't
    # want to count the skipped argument in an error message.
    if len(args) != expected:
      if expected and names[0] in ["self", "cls"]:
        expected -= 1
        names = names[1:]

    if len(args) != expected:
      self.vm.errorlog.invalid_function_type_comment(
          stack, annot,
          details="Expected %d args, %d given" % (expected, len(args)))
      return
    for name, arg in zip(names, args):
      resolved = self._process_one_annotation(node, arg, name, stack)
      if resolved is not None:
        func.signature.set_annotation(name, resolved)

  def _process_one_annotation(self, node, annotation, name, stack):
    """Change annotation / record errors where required.

    Returns:
      The processed annotation, or None after logging an error when the
      annotation is invalid.
    """
    # Make sure we pass in a frozen snapshot of the frame stack, rather than the
    # actual stack, since late annotations need to snapshot the stack at time of
    # creation in order to get the right line information for error messages.
    assert isinstance(stack, tuple), "stack must be an immutable sequence"

    if isinstance(annotation, abstract.AnnotationContainer):
      annotation = annotation.base_cls

    if isinstance(annotation, typing_overlay.Union):
      self.vm.errorlog.invalid_annotation(
          stack, annotation, "Needs options", name)
      return None
    elif (name is not None and name != "return"
          and isinstance(annotation, typing_overlay.NoReturn)):
      self.vm.errorlog.invalid_annotation(
          stack, annotation, "NoReturn is not allowed", name)
      return None
    elif isinstance(annotation, abstract.Instance) and (
        annotation.cls == self.vm.convert.str_type or
        annotation.cls == self.vm.convert.unicode_type
    ):
      # String annotations : Late evaluation
      if isinstance(annotation, mixin.PythonConstant):
        expr = annotation.pyval
        if not expr:
          self.vm.errorlog.invalid_annotation(
              stack, annotation, "Cannot be an empty string", name)
          return None
        frame = self.vm.frame
        # Immediately try to evaluate the reference, generating LateAnnotation
        # objects as needed. We don't store the entire string as a
        # LateAnnotation because:
        # - With __future__.annotations, all annotations look like forward
        #   references - most of them don't need to be late evaluated.
        # - Given an expression like "Union[str, NotYetDefined]", we want to
        #   evaluate the union immediately so we don't end up with a complex
        #   LateAnnotation, which can lead to bugs when instantiated.
        with self.vm.generate_late_annotations(stack):
          v, errorlog = abstract_utils.eval_expr(
              self.vm, node, frame.f_globals, frame.f_locals, expr)
        if errorlog:
          self.vm.errorlog.copy_from(errorlog.errors, stack)
        if len(v.data) == 1:
          return self._process_one_annotation(node, v.data[0], name, stack)
      # Reached for non-constant strings and for expressions that evaluate to
      # more than one value.
      self.vm.errorlog.ambiguous_annotation(stack, [annotation], name)
      return None
    elif annotation.cls == self.vm.convert.none_type:
      # PEP 484 allows to write "NoneType" as "None"
      return self.vm.convert.none_type
    elif isinstance(annotation, mixin.NestedAnnotation):
      if annotation.processed:
        return annotation
      # The flag is set before visiting inner types, so a re-entrant visit of
      # the same object returns early (presumably a guard against recursive
      # annotations). NOTE(review): the flag stays set if an inner type is
      # invalid and None is returned below.
      annotation.processed = True
      for key, typ in annotation.get_inner_types():
        processed = self._process_one_annotation(node, typ, name, stack)
        if processed is None:
          return None
        elif isinstance(processed, typing_overlay.NoReturn):
          self.vm.errorlog.invalid_annotation(
              stack, typ, "NoReturn is not allowed as inner type", name)
          return None
        annotation.update_inner_type(key, processed)
      return annotation
    elif isinstance(annotation, (class_mixin.Class,
                                 abstract.AMBIGUOUS_OR_EMPTY,
                                 abstract.TypeParameter,
                                 typing_overlay.NoReturn)):
      return annotation
    else:
      self.vm.errorlog.invalid_annotation(stack, annotation, "Not a type", name)
      return None

  def _eval_expr_as_tuple(self, node, expr, stack):
    """Evaluate an expression as a tuple.

    Returns:
      A (values, errorlog) pair. If `expr` evaluates to a tuple constant, its
      elements are unpacked into `values`; otherwise `values` is a 1-tuple of
      the result. An empty `expr` yields ((), None).
    """
    if not expr:
      return (), None

    f_globals, f_locals = self.vm.frame.f_globals, self.vm.frame.f_locals
    with self.vm.generate_late_annotations(stack):
      result_var, errorlog = abstract_utils.eval_expr(
          self.vm, node, f_globals, f_locals, expr)
    # NOTE(review): get_atomic_value can raise abstract_utils.ConversionError
    # (extract_annotation catches it), but it is not caught here.
    result = abstract_utils.get_atomic_value(result_var)
    # If the result is a tuple, expand it.
    if (isinstance(result, mixin.PythonConstant) and
        isinstance(result.pyval, tuple)):
      return (tuple(abstract_utils.get_atomic_value(x) for x in result.pyval),
              errorlog)
    else:
      return (result,), errorlog

  def deformalize(self, value):
    """Strip type parameters from `value` until it is no longer formal.

    A formal ParameterizedClass is replaced by its base class, repeatedly.
    Any other formal value becomes vm.convert.unsolvable.
    """
    # TODO(rechen): Instead of doing this, call sub_one_annotation() to replace
    # type parameters with their bound/constraints.
    while value.formal:
      if isinstance(value, abstract.ParameterizedClass):
        value = value.base_cls
      else:
        value = self.vm.convert.unsolvable
    return value