#!/usr/bin/env python3
"""Generator of dynamically typed draft stubs for arbitrary modules.

The logic of this script can be split into three steps:
* parsing options and finding sources:
  - use runtime imports by default (to also find C modules)
  - or use mypy's mechanisms, if importing is prohibited
* (optionally) semantically analysing the sources using mypy (as a single set)
* emitting the stubs text:
  - for Python modules: from ASTs using StubGenerator
  - for C modules: using runtime introspection and (optionally) Sphinx docs

During the first and third steps some problematic files can be skipped, but any
blocking error during the second step will cause the whole program to stop.

Basic usage:

  $ stubgen foo.py bar.py some_directory
  => Generate out/foo.pyi, out/bar.pyi, and stubs for some_directory (recursively).

  $ stubgen -m urllib.parse
  => Generate out/urllib/parse.pyi.

  $ stubgen -p urllib
  => Generate stubs for the whole urllib package (recursively).

For Python 2 mode, use --py2:

  $ stubgen --py2 -m textwrap

For C modules, you can get more precise function signatures by parsing .rst (Sphinx)
documentation for extra information. For this, use the --doc-dir option:

  $ stubgen --doc-dir <DIR>/Python-3.4.2/Doc/library -m curses

Note: The generated stubs should be verified manually.

TODO:
 - support stubs for C modules in Python 2 mode
 - detect 'if PY2 / is_py2' etc.
and either preserve those or only include Python 2 or 3 case 41 - maybe use .rst docs also for Python modules 42 - maybe export more imported names if there is no __all__ (this affects ssl.SSLError, for example) 43 - a quick and dirty heuristic would be to turn this on if a module has something like 44 'from x import y as _y' 45 - we don't seem to always detect properties ('closed' in 'io', for example) 46""" 47 48import glob 49import os 50import os.path 51import sys 52import traceback 53import argparse 54from collections import defaultdict 55 56from typing import ( 57 List, Dict, Tuple, Iterable, Mapping, Optional, Set, cast, 58) 59from typing_extensions import Final 60 61import mypy.build 62import mypy.parse 63import mypy.errors 64import mypy.traverser 65import mypy.mixedtraverser 66import mypy.util 67from mypy import defaults 68from mypy.modulefinder import ( 69 ModuleNotFoundReason, FindModuleCache, SearchPaths, BuildSource, default_lib_path 70) 71from mypy.nodes import ( 72 Expression, IntExpr, UnaryExpr, StrExpr, BytesExpr, NameExpr, FloatExpr, MemberExpr, 73 TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr, 74 ClassDef, MypyFile, Decorator, AssignmentStmt, TypeInfo, 75 IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, Block, 76 Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT 77) 78from mypy.stubgenc import generate_stub_for_c_module 79from mypy.stubutil import ( 80 default_py2_interpreter, CantImport, generate_guarded, 81 walk_packages, find_module_path_and_all_py2, find_module_path_and_all_py3, 82 report_missing, fail_missing, remove_misplaced_type_comments, common_dir_prefix 83) 84from mypy.stubdoc import parse_all_signatures, find_unique_signatures, Sig 85from mypy.options import Options as MypyOptions 86from mypy.types import ( 87 Type, TypeStrVisitor, CallableType, UnboundType, NoneType, TupleType, TypeList, Instance, 88 AnyType, get_proper_type 89) 90from mypy.visitor import NodeVisitor 
from mypy.find_sources import create_source_list, InvalidSourceList
from mypy.build import build
from mypy.errors import CompileError, Errors
from mypy.traverser import has_return_statement
from mypy.moduleinspect import ModuleInspect


# Directory names that packages commonly use to bundle ("vendor") other packages.
VENDOR_PACKAGES = ['packages', 'vendor', 'vendored', '_vendor', '_vendored_packages']  # type: Final

# Path fragments of files we avoid because they are unnecessary or likely to
# cause trouble; a trailing '\n' stands for the end of a path.
BLACKLIST = [
    '/six.py\n',             # Likely vendored six; too dynamic for us to handle
    '/vendored/',            # Vendored packages
    '/vendor/',              # Vendored packages
    '/_vendor/',
    '/_vendored_packages/',
]  # type: Final

# Fully qualified names that are special-cased to be re-exported from the stub
# (i.e. emitted as 'from m import y as y').
EXTRA_EXPORTED = {
    'pyasn1_modules.rfc2437.univ',
    'pyasn1_modules.rfc2459.char',
    'pyasn1_modules.rfc2459.univ',
}  # type: Final

# Dunder names that are never written to generated stubs.
IGNORED_DUNDERS = {
    '__all__', '__author__', '__version__', '__about__', '__copyright__',
    '__email__', '__license__', '__summary__', '__title__', '__uri__',
    '__str__', '__repr__', '__getstate__', '__setstate__', '__slots__',
}  # type: Final

# Methods assumed to always return a meaningful (non-None) value.
METHODS_WITH_RETURN_VALUE = {
    '__ne__', '__eq__', '__lt__', '__le__', '__gt__', '__ge__',
    '__hash__', '__iter__',
}  # type: Final


class Options:
    """Plain container for all stubgen settings.

    Attributes are deliberately left mutable so tests can tweak a single
    setting without rebuilding the whole object.
    """
    def __init__(self,
                 pyversion: Tuple[int, int],
                 no_import: bool,
                 doc_dir: str,
                 search_path: List[str],
                 interpreter: str,
                 parse_only: bool,
                 ignore_errors: bool,
                 include_private: bool,
                 output_dir: str,
                 modules: List[str],
                 packages: List[str],
                 files: List[str],
                 verbose: bool,
                 quiet: bool,
                 export_less: bool) -> None:
        # The meaning of each flag is documented where the options are parsed
        # (parse_options).
        # Target Python version and how sources are located.
        self.pyversion = pyversion
        self.no_import = no_import
        self.doc_dir = doc_dir
        self.search_path = search_path
        self.interpreter = interpreter
        # NOTE(review): mirrors `interpreter`; no reader is visible in this
        # chunk -- kept for compatibility, confirm whether it is still used.
        self.decointerpreter = interpreter
        # Analysis behaviour.
        self.parse_only = parse_only
        self.ignore_errors = ignore_errors
        self.include_private = include_private
        # Where to write stubs and what to generate them for.
        self.output_dir = output_dir
        self.modules = modules
        self.packages = packages
        self.files = files
        # Output verbosity and export policy.
        self.verbose = verbose
        self.quiet = quiet
        self.export_less = export_less


class StubSource:
    """One module (Python source or C extension) to generate a stub for.

    Thin wrapper around BuildSource that additionally carries the module's AST
    and the value of __all__ discovered at runtime, if any.
    """
    def __init__(self, module: str, path: Optional[str] = None,
                 runtime_all: Optional[List[str]] = None) -> None:
        wrapped = BuildSource(path, module, None)
        self.source = wrapped
        self.runtime_all = runtime_all
        # No AST yet; presumably assigned by the caller after parsing.
        self.ast = None  # type: Optional[MypyFile]

    @property
    def module(self) -> str:
        """Module name of the wrapped BuildSource."""
        wrapped = self.source
        return wrapped.module

    @property
    def path(self) -> Optional[str]:
        """File path of the wrapped BuildSource, or None if it has none."""
        wrapped = self.source
        return wrapped.path


# What was generated previously in the stub file. We keep track of these to generate
# nicely formatted output (add empty line between non-empty classes, for example).
EMPTY = 'EMPTY'  # type: Final
FUNC = 'FUNC'  # type: Final
CLASS = 'CLASS'  # type: Final
EMPTY_CLASS = 'EMPTY_CLASS'  # type: Final
VAR = 'VAR'  # type: Final
NOT_IN_ALL = 'NOT_IN_ALL'  # type: Final

# Indicates that we failed to generate a reasonable output
# for a given node. These should be manually replaced by a user.

ERROR_MARKER = '<ERROR>'  # type: Final


class AnnotationPrinter(TypeStrVisitor):
    """Visitor used to print existing annotations in a file.

    The main difference from TypeStrVisitor is a better treatment of
    unbound types.

    Notes:
    * This visitor doesn't add imports necessary for annotations, this is done separately
      by ImportTracker.
    * It can print all kinds of types, but the generated strings may not be valid (notably
      callable types) since it prints the same string that reveal_type() does.
    * For Instance types it prints the fully qualified names.
    """
    # TODO: Generate valid string representation for callable types.
    # TODO: Use short names for Instances.
    def __init__(self, stubgen: 'StubGenerator') -> None:
        super().__init__()
        # Used to report every name that appears in a printed annotation.
        self.stubgen = stubgen

    def visit_any(self, t: AnyType) -> str:
        """Print Any and mark its printed name as required by the stub."""
        s = super().visit_any(t)
        self.stubgen.import_tracker.require_name(s)
        return s

    def visit_unbound_type(self, t: UnboundType) -> str:
        """Print an unbound type by its source name plus any type arguments."""
        s = t.name
        # Only the name itself is required; argument names are recorded when
        # list_str visits the arguments.
        self.stubgen.import_tracker.require_name(s)
        if t.args:
            s += '[{}]'.format(self.list_str(t.args))
        return s

    def visit_none_type(self, t: NoneType) -> str:
        return "None"

    def visit_type_list(self, t: TypeList) -> str:
        """Print a bracketed type list, e.g. the argument list of Callable."""
        return '[{}]'.format(self.list_str(t.items))


class AliasPrinter(NodeVisitor[str]):
    """Visitor used to collect type aliases _and_ type variable definitions.

    Visit r.h.s of the definition to get the string representation of type alias.
    Every name encountered is reported to the stub generator's ImportTracker so
    the corresponding import can be emitted.
    """
    def __init__(self, stubgen: 'StubGenerator') -> None:
        self.stubgen = stubgen
        super().__init__()

    def visit_call_expr(self, node: CallExpr) -> str:
        """Print a call such as `TypeVar('T', bound=X)` verbatim.

        Raises ValueError for an argument kind not expected in a call.
        """
        # Call expressions are not usually types, but we also treat `X = TypeVar(...)` as a
        # type alias that has to be preserved (even if TypeVar is not the same as an alias)
        callee = node.callee.accept(self)
        args = []
        for name, arg, kind in zip(node.arg_names, node.args, node.arg_kinds):
            if kind == ARG_POS:
                args.append(arg.accept(self))
            elif kind == ARG_STAR:
                args.append('*' + arg.accept(self))
            elif kind == ARG_STAR2:
                args.append('**' + arg.accept(self))
            elif kind == ARG_NAMED:
                args.append('{}={}'.format(name, arg.accept(self)))
            else:
                raise ValueError("Unknown argument kind %d in call" % kind)
        return "{}({})".format(callee, ", ".join(args))

    def visit_name_expr(self, node: NameExpr) -> str:
        self.stubgen.import_tracker.require_name(node.name)
        return node.name

    def visit_member_expr(self, o: MemberExpr) -> str:
        """Print a dotted name like `a.b.c`, requiring its leftmost component.

        Returns ERROR_MARKER if the chain is not rooted at a plain name.
        """
        node = o  # type: Expression
        trailer = ''
        # Walk the chain right-to-left, accumulating '.attr' suffixes.
        while isinstance(node, MemberExpr):
            trailer = '.' + node.name + trailer
            node = node.expr
        if not isinstance(node, NameExpr):
            return ERROR_MARKER
        self.stubgen.import_tracker.require_name(node.name)
        return node.name + trailer

    def visit_str_expr(self, node: StrExpr) -> str:
        return repr(node.value)

    def visit_index_expr(self, node: IndexExpr) -> str:
        base = node.base.accept(self)
        index = node.index.accept(self)
        return "{}[{}]".format(base, index)

    def visit_tuple_expr(self, node: TupleExpr) -> str:
        """Print tuple items comma-separated (as used inside an index).

        An empty tuple is printed as '()' so that `Tuple[()]`, the spelling of
        the empty tuple type, round-trips instead of becoming the invalid
        `Tuple[]`.
        """
        if not node.items:
            return '()'
        return ", ".join(n.accept(self) for n in node.items)

    def visit_list_expr(self, node: ListExpr) -> str:
        return "[{}]".format(", ".join(n.accept(self) for n in node.items))

    def visit_ellipsis(self, node: EllipsisExpr) -> str:
        return "..."
327 328 329class ImportTracker: 330 """Record necessary imports during stub generation.""" 331 332 def __init__(self) -> None: 333 # module_for['foo'] has the module name where 'foo' was imported from, or None if 334 # 'foo' is a module imported directly; examples 335 # 'from pkg.m import f as foo' ==> module_for['foo'] == 'pkg.m' 336 # 'from m import f' ==> module_for['f'] == 'm' 337 # 'import m' ==> module_for['m'] == None 338 # 'import pkg.m' ==> module_for['pkg.m'] == None 339 # ==> module_for['pkg'] == None 340 self.module_for = {} # type: Dict[str, Optional[str]] 341 342 # direct_imports['foo'] is the module path used when the name 'foo' was added to the 343 # namespace. 344 # import foo.bar.baz ==> direct_imports['foo'] == 'foo.bar.baz' 345 # ==> direct_imports['foo.bar'] == 'foo.bar.baz' 346 # ==> direct_imports['foo.bar.baz'] == 'foo.bar.baz' 347 self.direct_imports = {} # type: Dict[str, str] 348 349 # reverse_alias['foo'] is the name that 'foo' had originally when imported with an 350 # alias; examples 351 # 'import numpy as np' ==> reverse_alias['np'] == 'numpy' 352 # 'import foo.bar as bar' ==> reverse_alias['bar'] == 'foo.bar' 353 # 'from decimal import Decimal as D' ==> reverse_alias['D'] == 'Decimal' 354 self.reverse_alias = {} # type: Dict[str, str] 355 356 # required_names is the set of names that are actually used in a type annotation 357 self.required_names = set() # type: Set[str] 358 359 # Names that should be reexported if they come from another module 360 self.reexports = set() # type: Set[str] 361 362 def add_import_from(self, module: str, names: List[Tuple[str, Optional[str]]]) -> None: 363 for name, alias in names: 364 if alias: 365 # 'from {module} import {name} as {alias}' 366 self.module_for[alias] = module 367 self.reverse_alias[alias] = name 368 else: 369 # 'from {module} import {name}' 370 self.module_for[name] = module 371 self.reverse_alias.pop(name, None) 372 self.direct_imports.pop(alias or name, None) 373 374 def 
add_import(self, module: str, alias: Optional[str] = None) -> None: 375 if alias: 376 # 'import {module} as {alias}' 377 self.module_for[alias] = None 378 self.reverse_alias[alias] = module 379 else: 380 # 'import {module}' 381 name = module 382 # add module and its parent packages 383 while name: 384 self.module_for[name] = None 385 self.direct_imports[name] = module 386 self.reverse_alias.pop(name, None) 387 name = name.rpartition('.')[0] 388 389 def require_name(self, name: str) -> None: 390 self.required_names.add(name.split('.')[0]) 391 392 def reexport(self, name: str) -> None: 393 """Mark a given non qualified name as needed in __all__. 394 395 This means that in case it comes from a module, it should be 396 imported with an alias even is the alias is the same as the name. 397 """ 398 self.require_name(name) 399 self.reexports.add(name) 400 401 def import_lines(self) -> List[str]: 402 """The list of required import lines (as strings with python code).""" 403 result = [] 404 405 # To summarize multiple names imported from a same module, we collect those 406 # in the `module_map` dictionary, mapping a module path to the list of names that should 407 # be imported from it. the names can also be alias in the form 'original as alias' 408 module_map = defaultdict(list) # type: Mapping[str, List[str]] 409 410 for name in sorted(self.required_names): 411 # If we haven't seen this name in an import statement, ignore it 412 if name not in self.module_for: 413 continue 414 415 m = self.module_for[name] 416 if m is not None: 417 # This name was found in a from ... import ... 418 # Collect the name in the module_map 419 if name in self.reverse_alias: 420 name = '{} as {}'.format(self.reverse_alias[name], name) 421 elif name in self.reexports: 422 name = '{} as {}'.format(name, name) 423 module_map[m].append(name) 424 else: 425 # This name was found in an import ... 
426 # We can already generate the import line 427 if name in self.reverse_alias: 428 source = self.reverse_alias[name] 429 result.append("import {} as {}\n".format(source, name)) 430 elif name in self.reexports: 431 assert '.' not in name # Because reexports only has nonqualified names 432 result.append("import {} as {}\n".format(name, name)) 433 else: 434 result.append("import {}\n".format(self.direct_imports[name])) 435 436 # Now generate all the from ... import ... lines collected in module_map 437 for module, names in sorted(module_map.items()): 438 result.append("from {} import {}\n".format(module, ', '.join(sorted(names)))) 439 return result 440 441 442def find_defined_names(file: MypyFile) -> Set[str]: 443 finder = DefinitionFinder() 444 file.accept(finder) 445 return finder.names 446 447 448class DefinitionFinder(mypy.traverser.TraverserVisitor): 449 """Find names of things defined at the top level of a module.""" 450 451 # TODO: Assignment statements etc. 452 453 def __init__(self) -> None: 454 # Short names of things defined at the top level. 455 self.names = set() # type: Set[str] 456 457 def visit_class_def(self, o: ClassDef) -> None: 458 # Don't recurse into classes, as we only keep track of top-level definitions. 459 self.names.add(o.name) 460 461 def visit_func_def(self, o: FuncDef) -> None: 462 # Don't recurse, as we only keep track of top-level definitions. 463 self.names.add(o.name) 464 465 466def find_referenced_names(file: MypyFile) -> Set[str]: 467 finder = ReferenceFinder() 468 file.accept(finder) 469 return finder.refs 470 471 472class ReferenceFinder(mypy.mixedtraverser.MixedTraverserVisitor): 473 """Find all name references (both local and global).""" 474 475 # TODO: Filter out local variable and class attribute references 476 477 def __init__(self) -> None: 478 # Short names of things defined at the top level. 
479 self.refs = set() # type: Set[str] 480 481 def visit_block(self, block: Block) -> None: 482 if not block.is_unreachable: 483 super().visit_block(block) 484 485 def visit_name_expr(self, e: NameExpr) -> None: 486 self.refs.add(e.name) 487 488 def visit_instance(self, t: Instance) -> None: 489 self.add_ref(t.type.fullname) 490 super().visit_instance(t) 491 492 def visit_unbound_type(self, t: UnboundType) -> None: 493 if t.name: 494 self.add_ref(t.name) 495 496 def visit_tuple_type(self, t: TupleType) -> None: 497 # Ignore fallback 498 for item in t.items: 499 item.accept(self) 500 501 def visit_callable_type(self, t: CallableType) -> None: 502 # Ignore fallback 503 for arg in t.arg_types: 504 arg.accept(self) 505 t.ret_type.accept(self) 506 507 def add_ref(self, fullname: str) -> None: 508 self.refs.add(fullname.split('.')[-1]) 509 510 511class StubGenerator(mypy.traverser.TraverserVisitor): 512 """Generate stub text from a mypy AST.""" 513 514 def __init__(self, 515 _all_: Optional[List[str]], pyversion: Tuple[int, int], 516 include_private: bool = False, 517 analyzed: bool = False, 518 export_less: bool = False) -> None: 519 # Best known value of __all__. 520 self._all_ = _all_ 521 self._output = [] # type: List[str] 522 self._decorators = [] # type: List[str] 523 self._import_lines = [] # type: List[str] 524 # Current indent level (indent is hardcoded to 4 spaces). 525 self._indent = '' 526 # Stack of defined variables (per scope). 527 self._vars = [[]] # type: List[List[str]] 528 # What was generated previously in the stub file. 529 self._state = EMPTY 530 self._toplevel_names = [] # type: List[str] 531 self._pyversion = pyversion 532 self._include_private = include_private 533 self.import_tracker = ImportTracker() 534 # Was the tree semantically analysed before? 535 self.analyzed = analyzed 536 # Disable implicit exports of package-internal imports? 
537 self.export_less = export_less 538 # Add imports that could be implicitly generated 539 self.import_tracker.add_import_from("typing", [("NamedTuple", None)]) 540 # Names in __all__ are required 541 for name in _all_ or (): 542 if name not in IGNORED_DUNDERS: 543 self.import_tracker.reexport(name) 544 self.defined_names = set() # type: Set[str] 545 # Short names of methods defined in the body of the current class 546 self.method_names = set() # type: Set[str] 547 548 def visit_mypy_file(self, o: MypyFile) -> None: 549 self.module = o.fullname # Current module being processed 550 self.path = o.path 551 self.defined_names = find_defined_names(o) 552 self.referenced_names = find_referenced_names(o) 553 typing_imports = ["Any", "Optional", "TypeVar"] 554 for t in typing_imports: 555 if t not in self.defined_names: 556 alias = None 557 else: 558 alias = '_' + t 559 self.import_tracker.add_import_from("typing", [(t, alias)]) 560 super().visit_mypy_file(o) 561 undefined_names = [name for name in self._all_ or [] 562 if name not in self._toplevel_names] 563 if undefined_names: 564 if self._state != EMPTY: 565 self.add('\n') 566 self.add('# Names in __all__ with no definition:\n') 567 for name in sorted(undefined_names): 568 self.add('# %s\n' % name) 569 570 def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None: 571 """@property with setters and getters, or @overload chain""" 572 overload_chain = False 573 for item in o.items: 574 if not isinstance(item, Decorator): 575 continue 576 577 if self.is_private_name(item.func.name, item.func.fullname): 578 continue 579 580 is_abstract, is_overload = self.process_decorator(item) 581 582 if not overload_chain: 583 self.visit_func_def(item.func, is_abstract=is_abstract, is_overload=is_overload) 584 if is_overload: 585 overload_chain = True 586 elif overload_chain and is_overload: 587 self.visit_func_def(item.func, is_abstract=is_abstract, is_overload=is_overload) 588 else: 589 # skip the overload implementation and 
clear the decorator we just processed 590 self.clear_decorators() 591 592 def visit_func_def(self, o: FuncDef, is_abstract: bool = False, 593 is_overload: bool = False) -> None: 594 if (self.is_private_name(o.name, o.fullname) 595 or self.is_not_in_all(o.name) 596 or (self.is_recorded_name(o.name) and not is_overload)): 597 self.clear_decorators() 598 return 599 if not self._indent and self._state not in (EMPTY, FUNC) and not o.is_awaitable_coroutine: 600 self.add('\n') 601 if not self.is_top_level(): 602 self_inits = find_self_initializers(o) 603 for init, value in self_inits: 604 if init in self.method_names: 605 # Can't have both an attribute and a method/property with the same name. 606 continue 607 init_code = self.get_init(init, value) 608 if init_code: 609 self.add(init_code) 610 # dump decorators, just before "def ..." 611 for s in self._decorators: 612 self.add(s) 613 self.clear_decorators() 614 self.add("%s%sdef %s(" % (self._indent, 'async ' if o.is_coroutine else '', o.name)) 615 self.record_name(o.name) 616 args = [] # type: List[str] 617 for i, arg_ in enumerate(o.arguments): 618 var = arg_.variable 619 kind = arg_.kind 620 name = var.name 621 annotated_type = (o.unanalyzed_type.arg_types[i] 622 if isinstance(o.unanalyzed_type, CallableType) else None) 623 # I think the name check is incorrect: there are libraries which 624 # name their 0th argument other than self/cls 625 is_self_arg = i == 0 and name == 'self' 626 is_cls_arg = i == 0 and name == 'cls' 627 annotation = "" 628 if annotated_type and not is_self_arg and not is_cls_arg: 629 # Luckily, an argument explicitly annotated with "Any" has 630 # type "UnboundType" and will not match. 
631 if not isinstance(get_proper_type(annotated_type), AnyType): 632 annotation = ": {}".format(self.print_annotation(annotated_type)) 633 if arg_.initializer: 634 if kind in (ARG_NAMED, ARG_NAMED_OPT) and not any(arg.startswith('*') 635 for arg in args): 636 args.append('*') 637 if not annotation: 638 typename = self.get_str_type_of_node(arg_.initializer, True, False) 639 if typename == '': 640 annotation = '=...' 641 else: 642 annotation = ': {} = ...'.format(typename) 643 else: 644 annotation += ' = ...' 645 arg = name + annotation 646 elif kind == ARG_STAR: 647 arg = '*%s%s' % (name, annotation) 648 elif kind == ARG_STAR2: 649 arg = '**%s%s' % (name, annotation) 650 else: 651 arg = name + annotation 652 args.append(arg) 653 retname = None 654 if o.name != '__init__' and isinstance(o.unanalyzed_type, CallableType): 655 if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType): 656 # Luckily, a return type explicitly annotated with "Any" has 657 # type "UnboundType" and will enter the else branch. 658 retname = None # implicit Any 659 else: 660 retname = self.print_annotation(o.unanalyzed_type.ret_type) 661 elif isinstance(o, FuncDef) and (o.is_abstract or o.name in METHODS_WITH_RETURN_VALUE): 662 # Always assume abstract methods return Any unless explicitly annotated. Also 663 # some dunder methods should not have a None return type. 
664 retname = None # implicit Any 665 elif not has_return_statement(o) and not is_abstract: 666 retname = 'None' 667 retfield = '' 668 if retname is not None: 669 retfield = ' -> ' + retname 670 671 self.add(', '.join(args)) 672 self.add("){}: ...\n".format(retfield)) 673 self._state = FUNC 674 675 def visit_decorator(self, o: Decorator) -> None: 676 if self.is_private_name(o.func.name, o.func.fullname): 677 return 678 679 is_abstract, _ = self.process_decorator(o) 680 self.visit_func_def(o.func, is_abstract=is_abstract) 681 682 def process_decorator(self, o: Decorator) -> Tuple[bool, bool]: 683 """Process a series of decorators. 684 685 Only preserve certain special decorators such as @abstractmethod. 686 687 Return a pair of booleans: 688 - True if any of the decorators makes a method abstract. 689 - True if any of the decorators is typing.overload. 690 """ 691 is_abstract = False 692 is_overload = False 693 for decorator in o.original_decorators: 694 if isinstance(decorator, NameExpr): 695 i_is_abstract, i_is_overload = self.process_name_expr_decorator(decorator, o) 696 is_abstract = is_abstract or i_is_abstract 697 is_overload = is_overload or i_is_overload 698 elif isinstance(decorator, MemberExpr): 699 i_is_abstract, i_is_overload = self.process_member_expr_decorator(decorator, o) 700 is_abstract = is_abstract or i_is_abstract 701 is_overload = is_overload or i_is_overload 702 return is_abstract, is_overload 703 704 def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> Tuple[bool, bool]: 705 """Process a function decorator of form @foo. 706 707 Only preserve certain special decorators such as @abstractmethod. 708 709 Return a pair of booleans: 710 - True if the decorator makes a method abstract. 711 - True if the decorator is typing.overload. 
712 """ 713 is_abstract = False 714 is_overload = False 715 name = expr.name 716 if name in ('property', 'staticmethod', 'classmethod'): 717 self.add_decorator(name) 718 elif self.import_tracker.module_for.get(name) in ('asyncio', 719 'asyncio.coroutines', 720 'types'): 721 self.add_coroutine_decorator(context.func, name, name) 722 elif self.refers_to_fullname(name, 'abc.abstractmethod'): 723 self.add_decorator(name) 724 self.import_tracker.require_name(name) 725 is_abstract = True 726 elif self.refers_to_fullname(name, 'abc.abstractproperty'): 727 self.add_decorator('property') 728 self.add_decorator('abc.abstractmethod') 729 is_abstract = True 730 elif self.refers_to_fullname(name, 'typing.overload'): 731 self.add_decorator(name) 732 self.add_typing_import('overload') 733 is_overload = True 734 return is_abstract, is_overload 735 736 def refers_to_fullname(self, name: str, fullname: str) -> bool: 737 module, short = fullname.rsplit('.', 1) 738 return (self.import_tracker.module_for.get(name) == module and 739 (name == short or 740 self.import_tracker.reverse_alias.get(name) == short)) 741 742 def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> Tuple[bool, 743 bool]: 744 """Process a function decorator of form @foo.bar. 745 746 Only preserve certain special decorators such as @abstractmethod. 747 748 Return a pair of booleans: 749 - True if the decorator makes a method abstract. 750 - True if the decorator is typing.overload. 
751 """ 752 is_abstract = False 753 is_overload = False 754 if expr.name == 'setter' and isinstance(expr.expr, NameExpr): 755 self.add_decorator('%s.setter' % expr.expr.name) 756 elif (isinstance(expr.expr, NameExpr) and 757 (expr.expr.name == 'abc' or 758 self.import_tracker.reverse_alias.get(expr.expr.name) == 'abc') and 759 expr.name in ('abstractmethod', 'abstractproperty')): 760 if expr.name == 'abstractproperty': 761 self.import_tracker.require_name(expr.expr.name) 762 self.add_decorator('%s' % ('property')) 763 self.add_decorator('%s.%s' % (expr.expr.name, 'abstractmethod')) 764 else: 765 self.import_tracker.require_name(expr.expr.name) 766 self.add_decorator('%s.%s' % (expr.expr.name, expr.name)) 767 is_abstract = True 768 elif expr.name == 'coroutine': 769 if (isinstance(expr.expr, MemberExpr) and 770 expr.expr.name == 'coroutines' and 771 isinstance(expr.expr.expr, NameExpr) and 772 (expr.expr.expr.name == 'asyncio' or 773 self.import_tracker.reverse_alias.get(expr.expr.expr.name) == 774 'asyncio')): 775 self.add_coroutine_decorator(context.func, 776 '%s.coroutines.coroutine' % 777 (expr.expr.expr.name,), 778 expr.expr.expr.name) 779 elif (isinstance(expr.expr, NameExpr) and 780 (expr.expr.name in ('asyncio', 'types') or 781 self.import_tracker.reverse_alias.get(expr.expr.name) in 782 ('asyncio', 'asyncio.coroutines', 'types'))): 783 self.add_coroutine_decorator(context.func, 784 expr.expr.name + '.coroutine', 785 expr.expr.name) 786 elif (isinstance(expr.expr, NameExpr) and 787 (expr.expr.name == 'typing' or 788 self.import_tracker.reverse_alias.get(expr.expr.name) == 'typing') and 789 expr.name == 'overload'): 790 self.import_tracker.require_name(expr.expr.name) 791 self.add_decorator('%s.%s' % (expr.expr.name, 'overload')) 792 is_overload = True 793 return is_abstract, is_overload 794 795 def visit_class_def(self, o: ClassDef) -> None: 796 self.method_names = find_method_names(o.defs.body) 797 sep = None # type: Optional[int] 798 if not self._indent 
and self._state != EMPTY: 799 sep = len(self._output) 800 self.add('\n') 801 self.add('%sclass %s' % (self._indent, o.name)) 802 self.record_name(o.name) 803 base_types = self.get_base_types(o) 804 if base_types: 805 for base in base_types: 806 self.import_tracker.require_name(base) 807 if isinstance(o.metaclass, (NameExpr, MemberExpr)): 808 meta = o.metaclass.accept(AliasPrinter(self)) 809 base_types.append('metaclass=' + meta) 810 elif self.analyzed and o.info.is_abstract: 811 base_types.append('metaclass=abc.ABCMeta') 812 self.import_tracker.add_import('abc') 813 self.import_tracker.require_name('abc') 814 if base_types: 815 self.add('(%s)' % ', '.join(base_types)) 816 self.add(':\n') 817 n = len(self._output) 818 self._indent += ' ' 819 self._vars.append([]) 820 super().visit_class_def(o) 821 self._indent = self._indent[:-4] 822 self._vars.pop() 823 self._vars[-1].append(o.name) 824 if len(self._output) == n: 825 if self._state == EMPTY_CLASS and sep is not None: 826 self._output[sep] = '' 827 self._output[-1] = self._output[-1][:-1] + ' ...\n' 828 self._state = EMPTY_CLASS 829 else: 830 self._state = CLASS 831 self.method_names = set() 832 833 def get_base_types(self, cdef: ClassDef) -> List[str]: 834 """Get list of base classes for a class.""" 835 base_types = [] # type: List[str] 836 for base in cdef.base_type_exprs: 837 if isinstance(base, NameExpr): 838 if base.name != 'object': 839 base_types.append(base.name) 840 elif isinstance(base, MemberExpr): 841 modname = get_qualified_name(base.expr) 842 base_types.append('%s.%s' % (modname, base.name)) 843 elif isinstance(base, IndexExpr): 844 p = AliasPrinter(self) 845 base_types.append(base.accept(p)) 846 return base_types 847 848 def visit_block(self, o: Block) -> None: 849 # Unreachable statements may be partially uninitialized and that may 850 # cause trouble. 
851 if not o.is_unreachable: 852 super().visit_block(o) 853 854 def visit_assignment_stmt(self, o: AssignmentStmt) -> None: 855 foundl = [] 856 857 for lvalue in o.lvalues: 858 if isinstance(lvalue, NameExpr) and self.is_namedtuple(o.rvalue): 859 assert isinstance(o.rvalue, CallExpr) 860 self.process_namedtuple(lvalue, o.rvalue) 861 continue 862 if (self.is_top_level() and 863 isinstance(lvalue, NameExpr) and not self.is_private_name(lvalue.name) and 864 # it is never an alias with explicit annotation 865 not o.unanalyzed_type and self.is_alias_expression(o.rvalue)): 866 self.process_typealias(lvalue, o.rvalue) 867 continue 868 if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): 869 items = lvalue.items 870 if isinstance(o.unanalyzed_type, TupleType): # type: ignore 871 annotations = o.unanalyzed_type.items # type: Iterable[Optional[Type]] 872 else: 873 annotations = [None] * len(items) 874 else: 875 items = [lvalue] 876 annotations = [o.unanalyzed_type] 877 sep = False 878 found = False 879 for item, annotation in zip(items, annotations): 880 if isinstance(item, NameExpr): 881 init = self.get_init(item.name, o.rvalue, annotation) 882 if init: 883 found = True 884 if not sep and not self._indent and \ 885 self._state not in (EMPTY, VAR): 886 init = '\n' + init 887 sep = True 888 self.add(init) 889 self.record_name(item.name) 890 foundl.append(found) 891 892 if all(foundl): 893 self._state = VAR 894 895 def is_namedtuple(self, expr: Expression) -> bool: 896 if not isinstance(expr, CallExpr): 897 return False 898 callee = expr.callee 899 return ((isinstance(callee, NameExpr) and callee.name.endswith('namedtuple')) or 900 (isinstance(callee, MemberExpr) and callee.name == 'namedtuple')) 901 902 def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: 903 if self._state != EMPTY: 904 self.add('\n') 905 if isinstance(rvalue.args[1], StrExpr): 906 items = rvalue.args[1].value.split(" ") 907 elif isinstance(rvalue.args[1], (ListExpr, 
TupleExpr)): 908 list_items = cast(List[StrExpr], rvalue.args[1].items) 909 items = [item.value for item in list_items] 910 else: 911 self.add('%s%s: Any' % (self._indent, lvalue.name)) 912 self.import_tracker.require_name('Any') 913 return 914 self.import_tracker.require_name('NamedTuple') 915 self.add('{}class {}(NamedTuple):'.format(self._indent, lvalue.name)) 916 if len(items) == 0: 917 self.add(' ...\n') 918 else: 919 self.import_tracker.require_name('Any') 920 self.add('\n') 921 for item in items: 922 self.add('{} {}: Any\n'.format(self._indent, item)) 923 self._state = CLASS 924 925 def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool: 926 """Return True for things that look like target for an alias. 927 928 Used to know if assignments look like type aliases, function alias, 929 or module alias. 930 """ 931 # Assignment of TypeVar(...) are passed through 932 if (isinstance(expr, CallExpr) and 933 isinstance(expr.callee, NameExpr) and 934 expr.callee.name == 'TypeVar'): 935 return True 936 elif isinstance(expr, EllipsisExpr): 937 return not top_level 938 elif isinstance(expr, NameExpr): 939 if expr.name in ('True', 'False'): 940 return False 941 elif expr.name == 'None': 942 return not top_level 943 else: 944 return not self.is_private_name(expr.name) 945 elif isinstance(expr, MemberExpr) and self.analyzed: 946 # Also add function and module aliases. 
947 return ((top_level and isinstance(expr.node, (FuncDef, Decorator, MypyFile)) 948 or isinstance(expr.node, TypeInfo)) and 949 not self.is_private_member(expr.node.fullname)) 950 elif (isinstance(expr, IndexExpr) and isinstance(expr.base, NameExpr) and 951 not self.is_private_name(expr.base.name)): 952 if isinstance(expr.index, TupleExpr): 953 indices = expr.index.items 954 else: 955 indices = [expr.index] 956 if expr.base.name == 'Callable' and len(indices) == 2: 957 args, ret = indices 958 if isinstance(args, EllipsisExpr): 959 indices = [ret] 960 elif isinstance(args, ListExpr): 961 indices = args.items + [ret] 962 else: 963 return False 964 return all(self.is_alias_expression(i, top_level=False) for i in indices) 965 else: 966 return False 967 968 def process_typealias(self, lvalue: NameExpr, rvalue: Expression) -> None: 969 p = AliasPrinter(self) 970 self.add("{} = {}\n".format(lvalue.name, rvalue.accept(p))) 971 self.record_name(lvalue.name) 972 self._vars[-1].append(lvalue.name) 973 974 def visit_if_stmt(self, o: IfStmt) -> None: 975 # Ignore if __name__ == '__main__'. 976 expr = o.expr[0] 977 if (isinstance(expr, ComparisonExpr) and 978 isinstance(expr.operands[0], NameExpr) and 979 isinstance(expr.operands[1], StrExpr) and 980 expr.operands[0].name == '__name__' and 981 '__main__' in expr.operands[1].value): 982 return 983 super().visit_if_stmt(o) 984 985 def visit_import_all(self, o: ImportAll) -> None: 986 self.add_import_line('from %s%s import *\n' % ('.' 
* o.relative, o.id)) 987 988 def visit_import_from(self, o: ImportFrom) -> None: 989 exported_names = set() # type: Set[str] 990 import_names = [] 991 module, relative = translate_module_name(o.id, o.relative) 992 if self.module: 993 full_module, ok = mypy.util.correct_relative_import( 994 self.module, relative, module, self.path.endswith('.__init__.py') 995 ) 996 if not ok: 997 full_module = module 998 else: 999 full_module = module 1000 if module == '__future__': 1001 return # Not preserved 1002 for name, as_name in o.names: 1003 if name == 'six': 1004 # Vendored six -- translate into plain 'import six'. 1005 self.visit_import(Import([('six', None)])) 1006 continue 1007 exported = False 1008 if as_name is None and self.module and (self.module + '.' + name) in EXTRA_EXPORTED: 1009 # Special case certain names that should be exported, against our general rules. 1010 exported = True 1011 is_private = self.is_private_name(name, full_module + '.' + name) 1012 if (as_name is None 1013 and name not in self.referenced_names 1014 and (not self._all_ or name in IGNORED_DUNDERS) 1015 and not is_private 1016 and module not in ('abc', 'typing', 'asyncio')): 1017 # An imported name that is never referenced in the module is assumed to be 1018 # exported, unless there is an explicit __all__. Note that we need to special 1019 # case 'abc' since some references are deleted during semantic analysis. 1020 exported = True 1021 top_level = full_module.split('.')[0] 1022 if (as_name is None 1023 and not self.export_less 1024 and (not self._all_ or name in IGNORED_DUNDERS) 1025 and self.module 1026 and not is_private 1027 and top_level in (self.module.split('.')[0], 1028 '_' + self.module.split('.')[0])): 1029 # Export imports from the same package, since we can't reliably tell whether they 1030 # are part of the public API. 
1031 exported = True 1032 if exported: 1033 self.import_tracker.reexport(name) 1034 as_name = name 1035 import_names.append((name, as_name)) 1036 self.import_tracker.add_import_from('.' * relative + module, import_names) 1037 self._vars[-1].extend(alias or name for name, alias in import_names) 1038 for name, alias in import_names: 1039 self.record_name(alias or name) 1040 1041 if self._all_: 1042 # Include import froms that import names defined in __all__. 1043 names = [name for name, alias in o.names 1044 if name in self._all_ and alias is None and name not in IGNORED_DUNDERS] 1045 exported_names.update(names) 1046 1047 def visit_import(self, o: Import) -> None: 1048 for id, as_id in o.ids: 1049 self.import_tracker.add_import(id, as_id) 1050 if as_id is None: 1051 target_name = id.split('.')[0] 1052 else: 1053 target_name = as_id 1054 self._vars[-1].append(target_name) 1055 self.record_name(target_name) 1056 1057 def get_init(self, lvalue: str, rvalue: Expression, 1058 annotation: Optional[Type] = None) -> Optional[str]: 1059 """Return initializer for a variable. 1060 1061 Return None if we've generated one already or if the variable is internal. 1062 """ 1063 if lvalue in self._vars[-1]: 1064 # We've generated an initializer already for this variable. 1065 return None 1066 # TODO: Only do this at module top level. 1067 if self.is_private_name(lvalue) or self.is_not_in_all(lvalue): 1068 return None 1069 self._vars[-1].append(lvalue) 1070 if annotation is not None: 1071 typename = self.print_annotation(annotation) 1072 if (isinstance(annotation, UnboundType) and not annotation.args and 1073 annotation.name == 'Final' and 1074 self.import_tracker.module_for.get('Final') in ('typing', 1075 'typing_extensions')): 1076 # Final without type argument is invalid in stubs. 
1077 final_arg = self.get_str_type_of_node(rvalue) 1078 typename += '[{}]'.format(final_arg) 1079 else: 1080 typename = self.get_str_type_of_node(rvalue) 1081 return '%s%s: %s\n' % (self._indent, lvalue, typename) 1082 1083 def add(self, string: str) -> None: 1084 """Add text to generated stub.""" 1085 self._output.append(string) 1086 1087 def add_decorator(self, name: str) -> None: 1088 if not self._indent and self._state not in (EMPTY, FUNC): 1089 self._decorators.append('\n') 1090 self._decorators.append('%s@%s\n' % (self._indent, name)) 1091 1092 def clear_decorators(self) -> None: 1093 self._decorators.clear() 1094 1095 def typing_name(self, name: str) -> str: 1096 if name in self.defined_names: 1097 # Avoid name clash between name from typing and a name defined in stub. 1098 return '_' + name 1099 else: 1100 return name 1101 1102 def add_typing_import(self, name: str) -> None: 1103 """Add a name to be imported from typing, unless it's imported already. 1104 1105 The import will be internal to the stub. 
1106 """ 1107 name = self.typing_name(name) 1108 self.import_tracker.require_name(name) 1109 1110 def add_import_line(self, line: str) -> None: 1111 """Add a line of text to the import section, unless it's already there.""" 1112 if line not in self._import_lines: 1113 self._import_lines.append(line) 1114 1115 def add_coroutine_decorator(self, func: FuncDef, name: str, require_name: str) -> None: 1116 func.is_awaitable_coroutine = True 1117 self.add_decorator(name) 1118 self.import_tracker.require_name(require_name) 1119 1120 def output(self) -> str: 1121 """Return the text for the stub.""" 1122 imports = '' 1123 if self._import_lines: 1124 imports += ''.join(self._import_lines) 1125 imports += ''.join(self.import_tracker.import_lines()) 1126 if imports and self._output: 1127 imports += '\n' 1128 return imports + ''.join(self._output) 1129 1130 def is_not_in_all(self, name: str) -> bool: 1131 if self.is_private_name(name): 1132 return False 1133 if self._all_: 1134 return self.is_top_level() and name not in self._all_ 1135 return False 1136 1137 def is_private_name(self, name: str, fullname: Optional[str] = None) -> bool: 1138 if self._include_private: 1139 return False 1140 if fullname in EXTRA_EXPORTED: 1141 return False 1142 return name.startswith('_') and (not name.endswith('__') 1143 or name in IGNORED_DUNDERS) 1144 1145 def is_private_member(self, fullname: str) -> bool: 1146 parts = fullname.split('.') 1147 for part in parts: 1148 if self.is_private_name(part): 1149 return True 1150 return False 1151 1152 def get_str_type_of_node(self, rvalue: Expression, 1153 can_infer_optional: bool = False, 1154 can_be_any: bool = True) -> str: 1155 if isinstance(rvalue, IntExpr): 1156 return 'int' 1157 if isinstance(rvalue, StrExpr): 1158 return 'str' 1159 if isinstance(rvalue, BytesExpr): 1160 return 'bytes' 1161 if isinstance(rvalue, FloatExpr): 1162 return 'float' 1163 if isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr): 1164 return 'int' 1165 if 
isinstance(rvalue, NameExpr) and rvalue.name in ('True', 'False'): 1166 return 'bool' 1167 if can_infer_optional and \ 1168 isinstance(rvalue, NameExpr) and rvalue.name == 'None': 1169 self.add_typing_import('Any') 1170 return '{} | None'.format(self.typing_name('Any')) 1171 if can_be_any: 1172 self.add_typing_import('Any') 1173 return self.typing_name('Any') 1174 else: 1175 return '' 1176 1177 def print_annotation(self, t: Type) -> str: 1178 printer = AnnotationPrinter(self) 1179 return t.accept(printer) 1180 1181 def is_top_level(self) -> bool: 1182 """Are we processing the top level of a file?""" 1183 return self._indent == '' 1184 1185 def record_name(self, name: str) -> None: 1186 """Mark a name as defined. 1187 1188 This only does anything if at the top level of a module. 1189 """ 1190 if self.is_top_level(): 1191 self._toplevel_names.append(name) 1192 1193 def is_recorded_name(self, name: str) -> bool: 1194 """Has this name been recorded previously?""" 1195 return self.is_top_level() and name in self._toplevel_names 1196 1197 1198def find_method_names(defs: List[Statement]) -> Set[str]: 1199 # TODO: Traverse into nested definitions 1200 result = set() 1201 for defn in defs: 1202 if isinstance(defn, FuncDef): 1203 result.add(defn.name) 1204 elif isinstance(defn, Decorator): 1205 result.add(defn.func.name) 1206 elif isinstance(defn, OverloadedFuncDef): 1207 for item in defn.items: 1208 result.update(find_method_names([item])) 1209 return result 1210 1211 1212class SelfTraverser(mypy.traverser.TraverserVisitor): 1213 def __init__(self) -> None: 1214 self.results = [] # type: List[Tuple[str, Expression]] 1215 1216 def visit_assignment_stmt(self, o: AssignmentStmt) -> None: 1217 lvalue = o.lvalues[0] 1218 if (isinstance(lvalue, MemberExpr) and 1219 isinstance(lvalue.expr, NameExpr) and 1220 lvalue.expr.name == 'self'): 1221 self.results.append((lvalue.name, o.rvalue)) 1222 1223 1224def find_self_initializers(fdef: FuncBase) -> List[Tuple[str, Expression]]: 1225 
"""Find attribute initializers in a method. 1226 1227 Return a list of pairs (attribute name, r.h.s. expression). 1228 """ 1229 traverser = SelfTraverser() 1230 fdef.accept(traverser) 1231 return traverser.results 1232 1233 1234def get_qualified_name(o: Expression) -> str: 1235 if isinstance(o, NameExpr): 1236 return o.name 1237 elif isinstance(o, MemberExpr): 1238 return '%s.%s' % (get_qualified_name(o.expr), o.name) 1239 else: 1240 return ERROR_MARKER 1241 1242 1243def remove_blacklisted_modules(modules: List[StubSource]) -> List[StubSource]: 1244 return [module for module in modules 1245 if module.path is None or not is_blacklisted_path(module.path)] 1246 1247 1248def is_blacklisted_path(path: str) -> bool: 1249 return any(substr in (normalize_path_separators(path) + '\n') 1250 for substr in BLACKLIST) 1251 1252 1253def normalize_path_separators(path: str) -> str: 1254 if sys.platform == 'win32': 1255 return path.replace('\\', '/') 1256 return path 1257 1258 1259def collect_build_targets(options: Options, mypy_opts: MypyOptions) -> Tuple[List[StubSource], 1260 List[StubSource]]: 1261 """Collect files for which we need to generate stubs. 1262 1263 Return list of Python modules and C modules. 1264 """ 1265 if options.packages or options.modules: 1266 if options.no_import: 1267 py_modules = find_module_paths_using_search(options.modules, 1268 options.packages, 1269 options.search_path, 1270 options.pyversion) 1271 c_modules = [] # type: List[StubSource] 1272 else: 1273 # Using imports is the default, since we can also find C modules. 1274 py_modules, c_modules = find_module_paths_using_imports(options.modules, 1275 options.packages, 1276 options.interpreter, 1277 options.pyversion, 1278 options.verbose, 1279 options.quiet) 1280 else: 1281 # Use mypy native source collection for files and directories. 
1282 try: 1283 source_list = create_source_list(options.files, mypy_opts) 1284 except InvalidSourceList as e: 1285 raise SystemExit(str(e)) from e 1286 py_modules = [StubSource(m.module, m.path) for m in source_list] 1287 c_modules = [] 1288 1289 py_modules = remove_blacklisted_modules(py_modules) 1290 1291 return py_modules, c_modules 1292 1293 1294def find_module_paths_using_imports(modules: List[str], 1295 packages: List[str], 1296 interpreter: str, 1297 pyversion: Tuple[int, int], 1298 verbose: bool, 1299 quiet: bool) -> Tuple[List[StubSource], 1300 List[StubSource]]: 1301 """Find path and runtime value of __all__ (if possible) for modules and packages. 1302 1303 This function uses runtime Python imports to get the information. 1304 """ 1305 with ModuleInspect() as inspect: 1306 py_modules = [] # type: List[StubSource] 1307 c_modules = [] # type: List[StubSource] 1308 found = list(walk_packages(inspect, packages, verbose)) 1309 modules = modules + found 1310 modules = [mod 1311 for mod in modules 1312 if not is_non_library_module(mod)] # We don't want to run any tests or scripts 1313 for mod in modules: 1314 try: 1315 if pyversion[0] == 2: 1316 result = find_module_path_and_all_py2(mod, interpreter) 1317 else: 1318 result = find_module_path_and_all_py3(inspect, mod, verbose) 1319 except CantImport as e: 1320 tb = traceback.format_exc() 1321 if verbose: 1322 sys.stdout.write(tb) 1323 if not quiet: 1324 report_missing(mod, e.message, tb) 1325 continue 1326 if not result: 1327 c_modules.append(StubSource(mod)) 1328 else: 1329 path, runtime_all = result 1330 py_modules.append(StubSource(mod, path, runtime_all)) 1331 return py_modules, c_modules 1332 1333 1334def is_non_library_module(module: str) -> bool: 1335 """Does module look like a test module or a script?""" 1336 if module.endswith(( 1337 '.tests', 1338 '.test', 1339 '.testing', 1340 '_tests', 1341 '_test_suite', 1342 'test_util', 1343 'test_utils', 1344 'test_base', 1345 '.__main__', 1346 '.conftest', # Used 
by pytest 1347 '.setup', # Typically an install script 1348 )): 1349 return True 1350 if module.split('.')[-1].startswith('test_'): 1351 return True 1352 if ('.tests.' in module 1353 or '.test.' in module 1354 or '.testing.' in module 1355 or '.SelfTest.' in module): 1356 return True 1357 return False 1358 1359 1360def translate_module_name(module: str, relative: int) -> Tuple[str, int]: 1361 for pkg in VENDOR_PACKAGES: 1362 for alt in 'six.moves', 'six': 1363 substr = '{}.{}'.format(pkg, alt) 1364 if (module.endswith('.' + substr) 1365 or (module == substr and relative)): 1366 return alt, 0 1367 if '.' + substr + '.' in module: 1368 return alt + '.' + module.partition('.' + substr + '.')[2], 0 1369 return module, relative 1370 1371 1372def find_module_paths_using_search(modules: List[str], packages: List[str], 1373 search_path: List[str], 1374 pyversion: Tuple[int, int]) -> List[StubSource]: 1375 """Find sources for modules and packages requested. 1376 1377 This function just looks for source files at the file system level. 1378 This is used if user passes --no-import, and will not find C modules. 1379 Exit if some of the modules or packages can't be found. 
1380 """ 1381 result = [] # type: List[StubSource] 1382 typeshed_path = default_lib_path(mypy.build.default_data_dir(), pyversion, None) 1383 search_paths = SearchPaths(('.',) + tuple(search_path), (), (), tuple(typeshed_path)) 1384 cache = FindModuleCache(search_paths, fscache=None, options=None) 1385 for module in modules: 1386 m_result = cache.find_module(module) 1387 if isinstance(m_result, ModuleNotFoundReason): 1388 fail_missing(module, m_result) 1389 module_path = None 1390 else: 1391 module_path = m_result 1392 result.append(StubSource(module, module_path)) 1393 for package in packages: 1394 p_result = cache.find_modules_recursive(package) 1395 if p_result: 1396 fail_missing(package, ModuleNotFoundReason.NOT_FOUND) 1397 sources = [StubSource(m.module, m.path) for m in p_result] 1398 result.extend(sources) 1399 1400 result = [m for m in result if not is_non_library_module(m.module)] 1401 1402 return result 1403 1404 1405def mypy_options(stubgen_options: Options) -> MypyOptions: 1406 """Generate mypy options using the flag passed by user.""" 1407 options = MypyOptions() 1408 options.follow_imports = 'skip' 1409 options.incremental = False 1410 options.ignore_errors = True 1411 options.semantic_analysis_only = True 1412 options.python_version = stubgen_options.pyversion 1413 options.show_traceback = True 1414 options.transform_source = remove_misplaced_type_comments 1415 return options 1416 1417 1418def parse_source_file(mod: StubSource, mypy_options: MypyOptions) -> None: 1419 """Parse a source file. 1420 1421 On success, store AST in the corresponding attribute of the stub source. 1422 If there are syntax errors, print them and exit. 
1423 """ 1424 assert mod.path is not None, "Not found module was not skipped" 1425 with open(mod.path, 'rb') as f: 1426 data = f.read() 1427 source = mypy.util.decode_python_encoding(data, mypy_options.python_version) 1428 errors = Errors() 1429 mod.ast = mypy.parse.parse(source, fnam=mod.path, module=mod.module, 1430 errors=errors, options=mypy_options) 1431 mod.ast._fullname = mod.module 1432 if errors.is_blockers(): 1433 # Syntax error! 1434 for m in errors.new_messages(): 1435 sys.stderr.write('%s\n' % m) 1436 sys.exit(1) 1437 1438 1439def generate_asts_for_modules(py_modules: List[StubSource], 1440 parse_only: bool, 1441 mypy_options: MypyOptions, 1442 verbose: bool) -> None: 1443 """Use mypy to parse (and optionally analyze) source files.""" 1444 if not py_modules: 1445 return # Nothing to do here, but there may be C modules 1446 if verbose: 1447 print('Processing %d files...' % len(py_modules)) 1448 if parse_only: 1449 for mod in py_modules: 1450 parse_source_file(mod, mypy_options) 1451 return 1452 # Perform full semantic analysis of the source set. 1453 try: 1454 res = build([module.source for module in py_modules], mypy_options) 1455 except CompileError as e: 1456 raise SystemExit("Critical error during semantic analysis: {}".format(e)) from e 1457 1458 for mod in py_modules: 1459 mod.ast = res.graph[mod.module].tree 1460 # Use statically inferred __all__ if there is no runtime one. 1461 if mod.runtime_all is None: 1462 mod.runtime_all = res.manager.semantic_analyzer.export_map[mod.module] 1463 1464 1465def generate_stub_from_ast(mod: StubSource, 1466 target: str, 1467 parse_only: bool = False, 1468 pyversion: Tuple[int, int] = defaults.PYTHON3_VERSION, 1469 include_private: bool = False, 1470 export_less: bool = False) -> None: 1471 """Use analysed (or just parsed) AST to generate type stub for single file. 1472 1473 If directory for target doesn't exist it will created. Existing stub 1474 will be overwritten. 
1475 """ 1476 gen = StubGenerator(mod.runtime_all, 1477 pyversion=pyversion, 1478 include_private=include_private, 1479 analyzed=not parse_only, 1480 export_less=export_less) 1481 assert mod.ast is not None, "This function must be used only with analyzed modules" 1482 mod.ast.accept(gen) 1483 1484 # Write output to file. 1485 subdir = os.path.dirname(target) 1486 if subdir and not os.path.isdir(subdir): 1487 os.makedirs(subdir) 1488 with open(target, 'w') as file: 1489 file.write(''.join(gen.output())) 1490 1491 1492def collect_docs_signatures(doc_dir: str) -> Tuple[Dict[str, str], Dict[str, str]]: 1493 """Gather all function and class signatures in the docs. 1494 1495 Return a tuple (function signatures, class signatures). 1496 Currently only used for C modules. 1497 """ 1498 all_sigs = [] # type: List[Sig] 1499 all_class_sigs = [] # type: List[Sig] 1500 for path in glob.glob('%s/*.rst' % doc_dir): 1501 with open(path) as f: 1502 loc_sigs, loc_class_sigs = parse_all_signatures(f.readlines()) 1503 all_sigs += loc_sigs 1504 all_class_sigs += loc_class_sigs 1505 sigs = dict(find_unique_signatures(all_sigs)) 1506 class_sigs = dict(find_unique_signatures(all_class_sigs)) 1507 return sigs, class_sigs 1508 1509 1510def generate_stubs(options: Options) -> None: 1511 """Main entry point for the program.""" 1512 mypy_opts = mypy_options(options) 1513 py_modules, c_modules = collect_build_targets(options, mypy_opts) 1514 1515 # Collect info from docs (if given): 1516 sigs = class_sigs = None # type: Optional[Dict[str, str]] 1517 if options.doc_dir: 1518 sigs, class_sigs = collect_docs_signatures(options.doc_dir) 1519 1520 # Use parsed sources to generate stubs for Python modules. 
1521 generate_asts_for_modules(py_modules, options.parse_only, mypy_opts, options.verbose) 1522 files = [] 1523 for mod in py_modules: 1524 assert mod.path is not None, "Not found module was not skipped" 1525 target = mod.module.replace('.', '/') 1526 if os.path.basename(mod.path) == '__init__.py': 1527 target += '/__init__.pyi' 1528 else: 1529 target += '.pyi' 1530 target = os.path.join(options.output_dir, target) 1531 files.append(target) 1532 with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): 1533 generate_stub_from_ast(mod, target, 1534 options.parse_only, options.pyversion, 1535 options.include_private, 1536 options.export_less) 1537 1538 # Separately analyse C modules using different logic. 1539 for mod in c_modules: 1540 if any(py_mod.module.startswith(mod.module + '.') 1541 for py_mod in py_modules + c_modules): 1542 target = mod.module.replace('.', '/') + '/__init__.pyi' 1543 else: 1544 target = mod.module.replace('.', '/') + '.pyi' 1545 target = os.path.join(options.output_dir, target) 1546 files.append(target) 1547 with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): 1548 generate_stub_for_c_module(mod.module, target, sigs=sigs, class_sigs=class_sigs) 1549 num_modules = len(py_modules) + len(c_modules) 1550 if not options.quiet and num_modules > 0: 1551 print('Processed %d modules' % num_modules) 1552 if len(files) == 1: 1553 print('Generated %s' % files[0]) 1554 else: 1555 print('Generated files under %s' % common_dir_prefix(files) + os.sep) 1556 1557 1558HEADER = """%(prog)s [-h] [--py2] [more options, see -h] 1559 [-m MODULE] [-p PACKAGE] [files ...]""" 1560 1561DESCRIPTION = """ 1562Generate draft stubs for modules. 1563 1564Stubs are generated in directory ./out, to avoid overriding files with 1565manual changes. This directory is assumed to exist. 
1566""" 1567 1568 1569def parse_options(args: List[str]) -> Options: 1570 parser = argparse.ArgumentParser(prog='stubgen', 1571 usage=HEADER, 1572 description=DESCRIPTION) 1573 1574 parser.add_argument('--py2', action='store_true', 1575 help="run in Python 2 mode (default: Python 3 mode)") 1576 parser.add_argument('--ignore-errors', action='store_true', 1577 help="ignore errors when trying to generate stubs for modules") 1578 parser.add_argument('--no-import', action='store_true', 1579 help="don't import the modules, just parse and analyze them " 1580 "(doesn't work with C extension modules and might not " 1581 "respect __all__)") 1582 parser.add_argument('--parse-only', action='store_true', 1583 help="don't perform semantic analysis of sources, just parse them " 1584 "(only applies to Python modules, might affect quality of stubs)") 1585 parser.add_argument('--include-private', action='store_true', 1586 help="generate stubs for objects and members considered private " 1587 "(single leading underscore and no trailing underscores)") 1588 parser.add_argument('--export-less', action='store_true', 1589 help=("don't implicitly export all names imported from other modules " 1590 "in the same package")) 1591 parser.add_argument('-v', '--verbose', action='store_true', 1592 help="show more verbose messages") 1593 parser.add_argument('-q', '--quiet', action='store_true', 1594 help="show fewer messages") 1595 parser.add_argument('--doc-dir', metavar='PATH', default='', 1596 help="use .rst documentation in PATH (this may result in " 1597 "better stubs in some cases; consider setting this to " 1598 "DIR/Python-X.Y.Z/Doc/library)") 1599 parser.add_argument('--search-path', metavar='PATH', default='', 1600 help="specify module search directories, separated by ':' " 1601 "(currently only used if --no-import is given)") 1602 parser.add_argument('--python-executable', metavar='PATH', dest='interpreter', default='', 1603 help="use Python interpreter at PATH (only works for " 1604 
"Python 2 right now)") 1605 parser.add_argument('-o', '--output', metavar='PATH', dest='output_dir', default='out', 1606 help="change the output directory [default: %(default)s]") 1607 parser.add_argument('-m', '--module', action='append', metavar='MODULE', 1608 dest='modules', default=[], 1609 help="generate stub for module; can repeat for more modules") 1610 parser.add_argument('-p', '--package', action='append', metavar='PACKAGE', 1611 dest='packages', default=[], 1612 help="generate stubs for package recursively; can be repeated") 1613 parser.add_argument(metavar='files', nargs='*', dest='files', 1614 help="generate stubs for given files or directories") 1615 1616 ns = parser.parse_args(args) 1617 1618 pyversion = defaults.PYTHON2_VERSION if ns.py2 else defaults.PYTHON3_VERSION 1619 if not ns.interpreter: 1620 ns.interpreter = sys.executable if pyversion[0] == 3 else default_py2_interpreter() 1621 if ns.modules + ns.packages and ns.files: 1622 parser.error("May only specify one of: modules/packages or files.") 1623 if ns.quiet and ns.verbose: 1624 parser.error('Cannot specify both quiet and verbose messages') 1625 1626 # Create the output folder if it doesn't already exist. 1627 if not os.path.exists(ns.output_dir): 1628 os.makedirs(ns.output_dir) 1629 1630 return Options(pyversion=pyversion, 1631 no_import=ns.no_import, 1632 doc_dir=ns.doc_dir, 1633 search_path=ns.search_path.split(':'), 1634 interpreter=ns.interpreter, 1635 ignore_errors=ns.ignore_errors, 1636 parse_only=ns.parse_only, 1637 include_private=ns.include_private, 1638 output_dir=ns.output_dir, 1639 modules=ns.modules, 1640 packages=ns.packages, 1641 files=ns.files, 1642 verbose=ns.verbose, 1643 quiet=ns.quiet, 1644 export_less=ns.export_less) 1645 1646 1647def main() -> None: 1648 mypy.util.check_python_version('stubgen') 1649 # Make sure that the current directory is in sys.path so that 1650 # stubgen can be run on packages in the current directory. 1651 if not ('' in sys.path or '.' 
in sys.path): 1652 sys.path.insert(0, '') 1653 1654 options = parse_options(sys.argv[1:]) 1655 generate_stubs(options) 1656 1657 1658if __name__ == '__main__': 1659 main() 1660