1#!/usr/bin/env python3
2"""Generator of dynamically typed draft stubs for arbitrary modules.
3
4The logic of this script can be split in three steps:
5* parsing options and finding sources:
  - use runtime imports by default (to also find C modules)
7  - or use mypy's mechanisms, if importing is prohibited
8* (optionally) semantically analysing the sources using mypy (as a single set)
9* emitting the stubs text:
10  - for Python modules: from ASTs using StubGenerator
11  - for C modules using runtime introspection and (optionally) Sphinx docs
12
13During first and third steps some problematic files can be skipped, but any
14blocking error during second step will cause the whole program to stop.
15
16Basic usage:
17
18  $ stubgen foo.py bar.py some_directory
19  => Generate out/foo.pyi, out/bar.pyi, and stubs for some_directory (recursively).
20
21  $ stubgen -m urllib.parse
22  => Generate out/urllib/parse.pyi.
23
24  $ stubgen -p urllib
  => Generate stubs for whole urllib package (recursively).
26
27For Python 2 mode, use --py2:
28
29  $ stubgen --py2 -m textwrap
30
31For C modules, you can get more precise function signatures by parsing .rst (Sphinx)
32documentation for extra information. For this, use the --doc-dir option:
33
34  $ stubgen --doc-dir <DIR>/Python-3.4.2/Doc/library -m curses
35
36Note: The generated stubs should be verified manually.
37
38TODO:
39 - support stubs for C modules in Python 2 mode
40 - detect 'if PY2 / is_py2' etc. and either preserve those or only include Python 2 or 3 case
41 - maybe use .rst docs also for Python modules
42 - maybe export more imported names if there is no __all__ (this affects ssl.SSLError, for example)
43   - a quick and dirty heuristic would be to turn this on if a module has something like
44     'from x import y as _y'
45 - we don't seem to always detect properties ('closed' in 'io', for example)
46"""
47
48import glob
49import os
50import os.path
51import sys
52import traceback
53import argparse
54from collections import defaultdict
55
56from typing import (
57    List, Dict, Tuple, Iterable, Mapping, Optional, Set, cast,
58)
59from typing_extensions import Final
60
61import mypy.build
62import mypy.parse
63import mypy.errors
64import mypy.traverser
65import mypy.mixedtraverser
66import mypy.util
67from mypy import defaults
68from mypy.modulefinder import (
69    ModuleNotFoundReason, FindModuleCache, SearchPaths, BuildSource, default_lib_path
70)
71from mypy.nodes import (
72    Expression, IntExpr, UnaryExpr, StrExpr, BytesExpr, NameExpr, FloatExpr, MemberExpr,
73    TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr,
74    ClassDef, MypyFile, Decorator, AssignmentStmt, TypeInfo,
75    IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, Block,
76    Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT
77)
78from mypy.stubgenc import generate_stub_for_c_module
79from mypy.stubutil import (
80    default_py2_interpreter, CantImport, generate_guarded,
81    walk_packages, find_module_path_and_all_py2, find_module_path_and_all_py3,
82    report_missing, fail_missing, remove_misplaced_type_comments, common_dir_prefix
83)
84from mypy.stubdoc import parse_all_signatures, find_unique_signatures, Sig
85from mypy.options import Options as MypyOptions
86from mypy.types import (
87    Type, TypeStrVisitor, CallableType, UnboundType, NoneType, TupleType, TypeList, Instance,
88    AnyType, get_proper_type
89)
90from mypy.visitor import NodeVisitor
91from mypy.find_sources import create_source_list, InvalidSourceList
92from mypy.build import build
93from mypy.errors import CompileError, Errors
94from mypy.traverser import has_return_statement
95from mypy.moduleinspect import ModuleInspect
96
97
# Package names that projects commonly use to hold vendored (bundled) modules.
VENDOR_PACKAGES = [
    'packages', 'vendor', 'vendored',
    '_vendor', '_vendored_packages',
]  # type: Final

# Path fragments of files that are unnecessary or likely to cause trouble.
# A trailing \n stands for the end of the path.
BLACKLIST = [
    # A vendored copy of six is too dynamic for us to handle.
    '/six.py\n',
    # Directories that hold vendored packages.
    '/vendored/', '/vendor/', '/_vendor/', '/_vendored_packages/',
]  # type: Final

# Special-cased names that the stub exports implicitly (from m import y as y).
EXTRA_EXPORTED = {
    'pyasn1_modules.rfc2437.univ',
    'pyasn1_modules.rfc2459.univ',
    'pyasn1_modules.rfc2459.char',
}  # type: Final

# Dunder names that are left out of generated stubs.
IGNORED_DUNDERS = {
    # Package metadata.
    '__about__', '__author__', '__copyright__', '__email__', '__license__',
    '__summary__', '__title__', '__uri__', '__version__',
    # Module and object plumbing.
    '__all__', '__slots__',
    '__str__', '__repr__',
    '__getstate__', '__setstate__',
}  # type: Final

# Methods that are expected to always return a non-trivial value.
METHODS_WITH_RETURN_VALUE = {
    '__eq__', '__ne__',
    '__lt__', '__le__', '__gt__', '__ge__',
    '__hash__', '__iter__',
}  # type: Final
153
154
class Options:
    """Settings that drive a single stubgen invocation.

    Instances are intentionally mutable, which keeps tests simple: they can
    build one object and then adjust individual attributes.
    """
    def __init__(self,
                 pyversion: Tuple[int, int],
                 no_import: bool,
                 doc_dir: str,
                 search_path: List[str],
                 interpreter: str,
                 parse_only: bool,
                 ignore_errors: bool,
                 include_private: bool,
                 output_dir: str,
                 modules: List[str],
                 packages: List[str],
                 files: List[str],
                 verbose: bool,
                 quiet: bool,
                 export_less: bool) -> None:
        # Each flag is described in parse_options; values are stored verbatim.

        # What to generate stubs for, and where the results go.
        self.modules, self.packages, self.files = modules, packages, files
        self.output_dir = output_dir

        # How sources are located and loaded.
        self.pyversion = pyversion
        self.search_path = search_path
        self.doc_dir = doc_dir
        self.no_import = no_import
        self.interpreter = interpreter
        # Second attribute holding the same interpreter value.
        self.decointerpreter = interpreter

        # How sources are processed.
        self.parse_only = parse_only
        self.ignore_errors = ignore_errors
        self.include_private = include_private
        self.export_less = export_less

        # Reporting.
        self.verbose, self.quiet = verbose, quiet
193
194
class StubSource:
    """One module to generate a stub for (either a Python or a C module).

    Wraps a BuildSource and additionally holds the parsed AST and the
    value of __all__ that was detected at runtime.
    """
    def __init__(self, module: str, path: Optional[str] = None,
                 runtime_all: Optional[List[str]] = None) -> None:
        # No AST is available at construction time, so it starts out as None.
        self.ast = None  # type: Optional[MypyFile]
        self.runtime_all = runtime_all
        self.source = BuildSource(path, module, None)

    @property
    def path(self) -> Optional[str]:
        """Path stored on the wrapped BuildSource (may be None)."""
        return self.source.path

    @property
    def module(self) -> str:
        """Module name stored on the wrapped BuildSource."""
        return self.source.module
214
215
# Kinds of output most recently written to the stub file. Tracking the previous
# kind lets the generator produce nicely formatted output (for example, an empty
# line between non-empty classes).
EMPTY = "EMPTY"  # type: Final
FUNC = "FUNC"  # type: Final
CLASS = "CLASS"  # type: Final
EMPTY_CLASS = "EMPTY_CLASS"  # type: Final
VAR = "VAR"  # type: Final
NOT_IN_ALL = "NOT_IN_ALL"  # type: Final

# Placeholder emitted when no reasonable output could be generated for a node.
# Users are expected to replace it manually.

ERROR_MARKER = "<ERROR>"  # type: Final
229
230
class AnnotationPrinter(TypeStrVisitor):
    """Render annotations that already exist in a file as strings.

    Compared to TypeStrVisitor, unbound types are handled better: they are
    printed by name (plus arguments) and the name is reported to the import
    tracker of the owning stub generator.

    Notes:
    * The imports needed by the annotations are not added here; ImportTracker
      takes care of that separately.
    * All kinds of types can be printed, but the result is not always valid
      source (notably for callable types), because it is the same string that
      reveal_type() shows.
    * Instance types are printed with their fully qualified names.
    """
    # TODO: Generate valid string representation for callable types.
    # TODO: Use short names for Instances.
    def __init__(self, stubgen: 'StubGenerator') -> None:
        super().__init__()
        self.stubgen = stubgen

    def visit_any(self, t: AnyType) -> str:
        # Whatever name the base class prints must be importable in the stub.
        text = super().visit_any(t)
        self.stubgen.import_tracker.require_name(text)
        return text

    def visit_unbound_type(self, t: UnboundType) -> str:
        text = t.name
        self.stubgen.import_tracker.require_name(text)
        if not t.args:
            return text
        # Generic form: Name[arg1, arg2, ...]
        return text + '[{}]'.format(self.list_str(t.args))

    def visit_none_type(self, t: NoneType) -> str:
        return "None"

    def visit_type_list(self, t: TypeList) -> str:
        items = self.list_str(t.items)
        return '[{}]'.format(items)
267
268
class AliasPrinter(NodeVisitor[str]):
    """Render the right-hand side of type alias and type variable definitions.

    Visiting the r.h.s. expression of such a definition yields its source-like
    string representation.
    """
    def __init__(self, stubgen: 'StubGenerator') -> None:
        self.stubgen = stubgen
        super().__init__()

    def visit_call_expr(self, node: CallExpr) -> str:
        # A call is not normally a type, but `X = TypeVar(...)` is treated like a
        # type alias that must be preserved (although a TypeVar is not an alias).
        callee = node.callee.accept(self)
        rendered = []
        for arg_name, arg, kind in zip(node.arg_names, node.args, node.arg_kinds):
            if kind == ARG_POS:
                piece = arg.accept(self)
            elif kind == ARG_STAR:
                piece = '*' + arg.accept(self)
            elif kind == ARG_STAR2:
                piece = '**' + arg.accept(self)
            elif kind == ARG_NAMED:
                piece = '{}={}'.format(arg_name, arg.accept(self))
            else:
                raise ValueError("Unknown argument kind %d in call" % kind)
            rendered.append(piece)
        return "{}({})".format(callee, ", ".join(rendered))

    def visit_name_expr(self, node: NameExpr) -> str:
        name = node.name
        self.stubgen.import_tracker.require_name(name)
        return name

    def visit_member_expr(self, o: MemberExpr) -> str:
        # Walk the attribute chain from the outside in, collecting attribute names
        # in reverse order (a.b.c yields ['c', 'b']).
        node = o  # type: Expression
        attrs = []
        while isinstance(node, MemberExpr):
            attrs.append(node.name)
            node = node.expr
        if not isinstance(node, NameExpr):
            # The chain is not rooted at a plain name, so it cannot be rendered.
            return ERROR_MARKER
        self.stubgen.import_tracker.require_name(node.name)
        attrs.append(node.name)
        return '.'.join(reversed(attrs))

    def visit_str_expr(self, node: StrExpr) -> str:
        return repr(node.value)

    def visit_index_expr(self, node: IndexExpr) -> str:
        return "{}[{}]".format(node.base.accept(self), node.index.accept(self))

    def visit_tuple_expr(self, node: TupleExpr) -> str:
        items = [item.accept(self) for item in node.items]
        return ", ".join(items)

    def visit_list_expr(self, node: ListExpr) -> str:
        items = [item.accept(self) for item in node.items]
        return "[{}]".format(", ".join(items))

    def visit_ellipsis(self, node: EllipsisExpr) -> str:
        return "..."
327
328
class ImportTracker:
    """Keep track of the imports a generated stub needs."""

    def __init__(self) -> None:
        # Maps a name bound by an import to the module it was imported from, or to
        # None when the name is itself a directly imported module. Examples:
        #     'from pkg.m import f as foo' ==> module_for['foo'] == 'pkg.m'
        #     'from m import f' ==> module_for['f'] == 'm'
        #     'import m' ==> module_for['m'] == None
        #     'import pkg.m' ==> module_for['pkg.m'] == None
        #                    ==> module_for['pkg'] == None
        self.module_for = {}  # type: Dict[str, Optional[str]]

        # Maps a name to the full module path of the 'import' statement that put
        # it into the namespace:
        #   import foo.bar.baz  ==> direct_imports['foo'] == 'foo.bar.baz'
        #                       ==> direct_imports['foo.bar'] == 'foo.bar.baz'
        #                       ==> direct_imports['foo.bar.baz'] == 'foo.bar.baz'
        self.direct_imports = {}  # type: Dict[str, str]

        # Maps an alias to the original name it stands for. Examples:
        #     'import numpy as np' ==> reverse_alias['np'] == 'numpy'
        #     'import foo.bar as bar' ==> reverse_alias['bar'] == 'foo.bar'
        #     'from decimal import Decimal as D' ==> reverse_alias['D'] == 'Decimal'
        self.reverse_alias = {}  # type: Dict[str, str]

        # Names that are actually used in a type annotation.
        self.required_names = set()  # type: Set[str]

        # Names that should be reexported if they come from another module.
        self.reexports = set()  # type: Set[str]

    def add_import_from(self, module: str, names: List[Tuple[str, Optional[str]]]) -> None:
        """Record 'from {module} import {name} [as {alias}]' for each (name, alias) pair."""
        for name, alias in names:
            bound = alias or name  # the name that ends up in the namespace
            self.module_for[bound] = module
            if alias:
                self.reverse_alias[alias] = name
            else:
                # A plain import replaces any alias previously stored under this name.
                self.reverse_alias.pop(name, None)
            self.direct_imports.pop(bound, None)

    def add_import(self, module: str, alias: Optional[str] = None) -> None:
        """Record 'import {module}' or 'import {module} as {alias}'."""
        if alias:
            self.module_for[alias] = None
            self.reverse_alias[alias] = module
            return
        # A plain import binds the module and every parent package.
        prefix = module
        while prefix:
            self.module_for[prefix] = None
            self.direct_imports[prefix] = module
            self.reverse_alias.pop(prefix, None)
            prefix, _, _ = prefix.rpartition('.')

    def require_name(self, name: str) -> None:
        """Mark the first dotted component of name as used."""
        head, _, _ = name.partition('.')
        self.required_names.add(head)

    def reexport(self, name: str) -> None:
        """Mark a given non qualified name as needed in __all__.

        If the name comes from a module, it is then imported with an alias
        even if the alias is the same as the name.
        """
        self.require_name(name)
        self.reexports.add(name)

    def import_lines(self) -> List[str]:
        """The list of required import lines (as strings with python code)."""
        result = []

        # Names imported from the same module are gathered here so that a single
        # 'from ... import ...' line can be emitted per module. An entry is either a
        # bare name or has the form 'original as alias'.
        module_map = defaultdict(list)  # type: Mapping[str, List[str]]

        for name in sorted(self.required_names):
            # Names never seen in an import statement are ignored.
            if name not in self.module_for:
                continue

            origin = self.module_for[name]
            has_alias = name in self.reverse_alias
            is_reexport = name in self.reexports

            if origin is None:
                # Bound by 'import ...': the line can be generated right away.
                if has_alias:
                    result.append("import {} as {}\n".format(self.reverse_alias[name], name))
                elif is_reexport:
                    assert '.' not in name  # Because reexports only has nonqualified names
                    result.append("import {} as {}\n".format(name, name))
                else:
                    result.append("import {}\n".format(self.direct_imports[name]))
                continue

            # Bound by 'from ... import ...': collect the entry for later.
            if has_alias:
                entry = '{} as {}'.format(self.reverse_alias[name], name)
            elif is_reexport:
                entry = '{} as {}'.format(name, name)
            else:
                entry = name
            module_map[origin].append(entry)

        # Emit the collected 'from ... import ...' lines, ordered by module.
        result.extend(
            "from {} import {}\n".format(module, ', '.join(sorted(names)))
            for module, names in sorted(module_map.items())
        )
        return result
440
441
def find_defined_names(file: MypyFile) -> Set[str]:
    """Return the names that a DefinitionFinder collects when run over file."""
    visitor = DefinitionFinder()
    file.accept(visitor)
    return visitor.names
446
447
class DefinitionFinder(mypy.traverser.TraverserVisitor):
    """Collect the names of classes and functions defined at module top level."""

    # TODO: Assignment statements etc.

    def __init__(self) -> None:
        # Short (unqualified) names of the top-level definitions seen so far.
        self.names = set()  # type: Set[str]

    def visit_func_def(self, o: FuncDef) -> None:
        # Only the name is recorded; the body is deliberately not traversed, so
        # nested definitions stay out of the result.
        self.names.add(o.name)

    def visit_class_def(self, o: ClassDef) -> None:
        # Only the name is recorded; class bodies are deliberately not traversed,
        # because only top-level definitions are of interest.
        self.names.add(o.name)
464
465
def find_referenced_names(file: MypyFile) -> Set[str]:
    """Return the references that a ReferenceFinder collects when run over file."""
    visitor = ReferenceFinder()
    file.accept(visitor)
    return visitor.refs
470
471
class ReferenceFinder(mypy.mixedtraverser.MixedTraverserVisitor):
    """Collect every name that is referenced (both local and global)."""

    # TODO: Filter out local variable and class attribute references

    def __init__(self) -> None:
        # Short names of everything referenced so far.
        self.refs = set()  # type: Set[str]

    def add_ref(self, fullname: str) -> None:
        """Record the last dotted component of fullname."""
        self.refs.add(fullname.rpartition('.')[2])

    def visit_block(self, block: Block) -> None:
        # Blocks flagged as unreachable are not traversed.
        if block.is_unreachable:
            return
        super().visit_block(block)

    def visit_name_expr(self, e: NameExpr) -> None:
        self.refs.add(e.name)

    def visit_instance(self, t: Instance) -> None:
        self.add_ref(t.type.fullname)
        super().visit_instance(t)

    def visit_unbound_type(self, t: UnboundType) -> None:
        if not t.name:
            return
        self.add_ref(t.name)

    def visit_tuple_type(self, t: TupleType) -> None:
        # Only the item types are visited; the fallback is ignored.
        for item_type in t.items:
            item_type.accept(self)

    def visit_callable_type(self, t: CallableType) -> None:
        # Argument types first, then the return type; the fallback is ignored.
        for arg_type in t.arg_types:
            arg_type.accept(self)
        return_type = t.ret_type
        return_type.accept(self)
509
510
511class StubGenerator(mypy.traverser.TraverserVisitor):
512    """Generate stub text from a mypy AST."""
513
514    def __init__(self,
515                 _all_: Optional[List[str]], pyversion: Tuple[int, int],
516                 include_private: bool = False,
517                 analyzed: bool = False,
518                 export_less: bool = False) -> None:
519        # Best known value of __all__.
520        self._all_ = _all_
521        self._output = []  # type: List[str]
522        self._decorators = []  # type: List[str]
523        self._import_lines = []  # type: List[str]
524        # Current indent level (indent is hardcoded to 4 spaces).
525        self._indent = ''
526        # Stack of defined variables (per scope).
527        self._vars = [[]]  # type: List[List[str]]
528        # What was generated previously in the stub file.
529        self._state = EMPTY
530        self._toplevel_names = []  # type: List[str]
531        self._pyversion = pyversion
532        self._include_private = include_private
533        self.import_tracker = ImportTracker()
534        # Was the tree semantically analysed before?
535        self.analyzed = analyzed
536        # Disable implicit exports of package-internal imports?
537        self.export_less = export_less
538        # Add imports that could be implicitly generated
539        self.import_tracker.add_import_from("typing", [("NamedTuple", None)])
540        # Names in __all__ are required
541        for name in _all_ or ():
542            if name not in IGNORED_DUNDERS:
543                self.import_tracker.reexport(name)
544        self.defined_names = set()  # type: Set[str]
545        # Short names of methods defined in the body of the current class
546        self.method_names = set()  # type: Set[str]
547
548    def visit_mypy_file(self, o: MypyFile) -> None:
549        self.module = o.fullname  # Current module being processed
550        self.path = o.path
551        self.defined_names = find_defined_names(o)
552        self.referenced_names = find_referenced_names(o)
553        typing_imports = ["Any", "Optional", "TypeVar"]
554        for t in typing_imports:
555            if t not in self.defined_names:
556                alias = None
557            else:
558                alias = '_' + t
559            self.import_tracker.add_import_from("typing", [(t, alias)])
560        super().visit_mypy_file(o)
561        undefined_names = [name for name in self._all_ or []
562                           if name not in self._toplevel_names]
563        if undefined_names:
564            if self._state != EMPTY:
565                self.add('\n')
566            self.add('# Names in __all__ with no definition:\n')
567            for name in sorted(undefined_names):
568                self.add('#   %s\n' % name)
569
570    def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
571        """@property with setters and getters, or @overload chain"""
572        overload_chain = False
573        for item in o.items:
574            if not isinstance(item, Decorator):
575                continue
576
577            if self.is_private_name(item.func.name, item.func.fullname):
578                continue
579
580            is_abstract, is_overload = self.process_decorator(item)
581
582            if not overload_chain:
583                self.visit_func_def(item.func, is_abstract=is_abstract, is_overload=is_overload)
584                if is_overload:
585                    overload_chain = True
586            elif overload_chain and is_overload:
587                self.visit_func_def(item.func, is_abstract=is_abstract, is_overload=is_overload)
588            else:
589                # skip the overload implementation and clear the decorator we just processed
590                self.clear_decorators()
591
592    def visit_func_def(self, o: FuncDef, is_abstract: bool = False,
593                       is_overload: bool = False) -> None:
594        if (self.is_private_name(o.name, o.fullname)
595                or self.is_not_in_all(o.name)
596                or (self.is_recorded_name(o.name) and not is_overload)):
597            self.clear_decorators()
598            return
599        if not self._indent and self._state not in (EMPTY, FUNC) and not o.is_awaitable_coroutine:
600            self.add('\n')
601        if not self.is_top_level():
602            self_inits = find_self_initializers(o)
603            for init, value in self_inits:
604                if init in self.method_names:
605                    # Can't have both an attribute and a method/property with the same name.
606                    continue
607                init_code = self.get_init(init, value)
608                if init_code:
609                    self.add(init_code)
610        # dump decorators, just before "def ..."
611        for s in self._decorators:
612            self.add(s)
613        self.clear_decorators()
614        self.add("%s%sdef %s(" % (self._indent, 'async ' if o.is_coroutine else '', o.name))
615        self.record_name(o.name)
616        args = []  # type: List[str]
617        for i, arg_ in enumerate(o.arguments):
618            var = arg_.variable
619            kind = arg_.kind
620            name = var.name
621            annotated_type = (o.unanalyzed_type.arg_types[i]
622                              if isinstance(o.unanalyzed_type, CallableType) else None)
623            # I think the name check is incorrect: there are libraries which
624            # name their 0th argument other than self/cls
625            is_self_arg = i == 0 and name == 'self'
626            is_cls_arg = i == 0 and name == 'cls'
627            annotation = ""
628            if annotated_type and not is_self_arg and not is_cls_arg:
629                # Luckily, an argument explicitly annotated with "Any" has
630                # type "UnboundType" and will not match.
631                if not isinstance(get_proper_type(annotated_type), AnyType):
632                    annotation = ": {}".format(self.print_annotation(annotated_type))
633            if arg_.initializer:
634                if kind in (ARG_NAMED, ARG_NAMED_OPT) and not any(arg.startswith('*')
635                                                                  for arg in args):
636                    args.append('*')
637                if not annotation:
638                    typename = self.get_str_type_of_node(arg_.initializer, True, False)
639                    if typename == '':
640                        annotation = '=...'
641                    else:
642                        annotation = ': {} = ...'.format(typename)
643                else:
644                    annotation += ' = ...'
645                arg = name + annotation
646            elif kind == ARG_STAR:
647                arg = '*%s%s' % (name, annotation)
648            elif kind == ARG_STAR2:
649                arg = '**%s%s' % (name, annotation)
650            else:
651                arg = name + annotation
652            args.append(arg)
653        retname = None
654        if o.name != '__init__' and isinstance(o.unanalyzed_type, CallableType):
655            if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType):
656                # Luckily, a return type explicitly annotated with "Any" has
657                # type "UnboundType" and will enter the else branch.
658                retname = None  # implicit Any
659            else:
660                retname = self.print_annotation(o.unanalyzed_type.ret_type)
661        elif isinstance(o, FuncDef) and (o.is_abstract or o.name in METHODS_WITH_RETURN_VALUE):
662            # Always assume abstract methods return Any unless explicitly annotated. Also
663            # some dunder methods should not have a None return type.
664            retname = None  # implicit Any
665        elif not has_return_statement(o) and not is_abstract:
666            retname = 'None'
667        retfield = ''
668        if retname is not None:
669            retfield = ' -> ' + retname
670
671        self.add(', '.join(args))
672        self.add("){}: ...\n".format(retfield))
673        self._state = FUNC
674
675    def visit_decorator(self, o: Decorator) -> None:
676        if self.is_private_name(o.func.name, o.func.fullname):
677            return
678
679        is_abstract, _ = self.process_decorator(o)
680        self.visit_func_def(o.func, is_abstract=is_abstract)
681
682    def process_decorator(self, o: Decorator) -> Tuple[bool, bool]:
683        """Process a series of decorators.
684
685        Only preserve certain special decorators such as @abstractmethod.
686
687        Return a pair of booleans:
688        - True if any of the decorators makes a method abstract.
689        - True if any of the decorators is typing.overload.
690        """
691        is_abstract = False
692        is_overload = False
693        for decorator in o.original_decorators:
694            if isinstance(decorator, NameExpr):
695                i_is_abstract, i_is_overload = self.process_name_expr_decorator(decorator, o)
696                is_abstract = is_abstract or i_is_abstract
697                is_overload = is_overload or i_is_overload
698            elif isinstance(decorator, MemberExpr):
699                i_is_abstract, i_is_overload = self.process_member_expr_decorator(decorator, o)
700                is_abstract = is_abstract or i_is_abstract
701                is_overload = is_overload or i_is_overload
702        return is_abstract, is_overload
703
704    def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> Tuple[bool, bool]:
705        """Process a function decorator of form @foo.
706
707        Only preserve certain special decorators such as @abstractmethod.
708
709        Return a pair of booleans:
710        - True if the decorator makes a method abstract.
711        - True if the decorator is typing.overload.
712        """
713        is_abstract = False
714        is_overload = False
715        name = expr.name
716        if name in ('property', 'staticmethod', 'classmethod'):
717            self.add_decorator(name)
718        elif self.import_tracker.module_for.get(name) in ('asyncio',
719                                                          'asyncio.coroutines',
720                                                          'types'):
721            self.add_coroutine_decorator(context.func, name, name)
722        elif self.refers_to_fullname(name, 'abc.abstractmethod'):
723            self.add_decorator(name)
724            self.import_tracker.require_name(name)
725            is_abstract = True
726        elif self.refers_to_fullname(name, 'abc.abstractproperty'):
727            self.add_decorator('property')
728            self.add_decorator('abc.abstractmethod')
729            is_abstract = True
730        elif self.refers_to_fullname(name, 'typing.overload'):
731            self.add_decorator(name)
732            self.add_typing_import('overload')
733            is_overload = True
734        return is_abstract, is_overload
735
736    def refers_to_fullname(self, name: str, fullname: str) -> bool:
737        module, short = fullname.rsplit('.', 1)
738        return (self.import_tracker.module_for.get(name) == module and
739                (name == short or
740                 self.import_tracker.reverse_alias.get(name) == short))
741
742    def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> Tuple[bool,
743                                                                                           bool]:
744        """Process a function decorator of form @foo.bar.
745
746        Only preserve certain special decorators such as @abstractmethod.
747
748        Return a pair of booleans:
749        - True if the decorator makes a method abstract.
750        - True if the decorator is typing.overload.
751        """
752        is_abstract = False
753        is_overload = False
754        if expr.name == 'setter' and isinstance(expr.expr, NameExpr):
755            self.add_decorator('%s.setter' % expr.expr.name)
756        elif (isinstance(expr.expr, NameExpr) and
757              (expr.expr.name == 'abc' or
758               self.import_tracker.reverse_alias.get(expr.expr.name) == 'abc') and
759              expr.name in ('abstractmethod', 'abstractproperty')):
760            if expr.name == 'abstractproperty':
761                self.import_tracker.require_name(expr.expr.name)
762                self.add_decorator('%s' % ('property'))
763                self.add_decorator('%s.%s' % (expr.expr.name, 'abstractmethod'))
764            else:
765                self.import_tracker.require_name(expr.expr.name)
766                self.add_decorator('%s.%s' % (expr.expr.name, expr.name))
767            is_abstract = True
768        elif expr.name == 'coroutine':
769            if (isinstance(expr.expr, MemberExpr) and
770                expr.expr.name == 'coroutines' and
771                isinstance(expr.expr.expr, NameExpr) and
772                    (expr.expr.expr.name == 'asyncio' or
773                     self.import_tracker.reverse_alias.get(expr.expr.expr.name) ==
774                        'asyncio')):
775                self.add_coroutine_decorator(context.func,
776                                             '%s.coroutines.coroutine' %
777                                             (expr.expr.expr.name,),
778                                             expr.expr.expr.name)
779            elif (isinstance(expr.expr, NameExpr) and
780                  (expr.expr.name in ('asyncio', 'types') or
781                   self.import_tracker.reverse_alias.get(expr.expr.name) in
782                    ('asyncio', 'asyncio.coroutines', 'types'))):
783                self.add_coroutine_decorator(context.func,
784                                             expr.expr.name + '.coroutine',
785                                             expr.expr.name)
786        elif (isinstance(expr.expr, NameExpr) and
787              (expr.expr.name == 'typing' or
788               self.import_tracker.reverse_alias.get(expr.expr.name) == 'typing') and
789              expr.name == 'overload'):
790            self.import_tracker.require_name(expr.expr.name)
791            self.add_decorator('%s.%s' % (expr.expr.name, 'overload'))
792            is_overload = True
793        return is_abstract, is_overload
794
795    def visit_class_def(self, o: ClassDef) -> None:
796        self.method_names = find_method_names(o.defs.body)
797        sep = None  # type: Optional[int]
798        if not self._indent and self._state != EMPTY:
799            sep = len(self._output)
800            self.add('\n')
801        self.add('%sclass %s' % (self._indent, o.name))
802        self.record_name(o.name)
803        base_types = self.get_base_types(o)
804        if base_types:
805            for base in base_types:
806                self.import_tracker.require_name(base)
807        if isinstance(o.metaclass, (NameExpr, MemberExpr)):
808            meta = o.metaclass.accept(AliasPrinter(self))
809            base_types.append('metaclass=' + meta)
810        elif self.analyzed and o.info.is_abstract:
811            base_types.append('metaclass=abc.ABCMeta')
812            self.import_tracker.add_import('abc')
813            self.import_tracker.require_name('abc')
814        if base_types:
815            self.add('(%s)' % ', '.join(base_types))
816        self.add(':\n')
817        n = len(self._output)
818        self._indent += '    '
819        self._vars.append([])
820        super().visit_class_def(o)
821        self._indent = self._indent[:-4]
822        self._vars.pop()
823        self._vars[-1].append(o.name)
824        if len(self._output) == n:
825            if self._state == EMPTY_CLASS and sep is not None:
826                self._output[sep] = ''
827            self._output[-1] = self._output[-1][:-1] + ' ...\n'
828            self._state = EMPTY_CLASS
829        else:
830            self._state = CLASS
831        self.method_names = set()
832
833    def get_base_types(self, cdef: ClassDef) -> List[str]:
834        """Get list of base classes for a class."""
835        base_types = []  # type: List[str]
836        for base in cdef.base_type_exprs:
837            if isinstance(base, NameExpr):
838                if base.name != 'object':
839                    base_types.append(base.name)
840            elif isinstance(base, MemberExpr):
841                modname = get_qualified_name(base.expr)
842                base_types.append('%s.%s' % (modname, base.name))
843            elif isinstance(base, IndexExpr):
844                p = AliasPrinter(self)
845                base_types.append(base.accept(p))
846        return base_types
847
848    def visit_block(self, o: Block) -> None:
849        # Unreachable statements may be partially uninitialized and that may
850        # cause trouble.
851        if not o.is_unreachable:
852            super().visit_block(o)
853
854    def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
855        foundl = []
856
857        for lvalue in o.lvalues:
858            if isinstance(lvalue, NameExpr) and self.is_namedtuple(o.rvalue):
859                assert isinstance(o.rvalue, CallExpr)
860                self.process_namedtuple(lvalue, o.rvalue)
861                continue
862            if (self.is_top_level() and
863                    isinstance(lvalue, NameExpr) and not self.is_private_name(lvalue.name) and
864                    # it is never an alias with explicit annotation
865                    not o.unanalyzed_type and self.is_alias_expression(o.rvalue)):
866                self.process_typealias(lvalue, o.rvalue)
867                continue
868            if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr):
869                items = lvalue.items
870                if isinstance(o.unanalyzed_type, TupleType):  # type: ignore
871                    annotations = o.unanalyzed_type.items  # type: Iterable[Optional[Type]]
872                else:
873                    annotations = [None] * len(items)
874            else:
875                items = [lvalue]
876                annotations = [o.unanalyzed_type]
877            sep = False
878            found = False
879            for item, annotation in zip(items, annotations):
880                if isinstance(item, NameExpr):
881                    init = self.get_init(item.name, o.rvalue, annotation)
882                    if init:
883                        found = True
884                        if not sep and not self._indent and \
885                                self._state not in (EMPTY, VAR):
886                            init = '\n' + init
887                            sep = True
888                        self.add(init)
889                        self.record_name(item.name)
890            foundl.append(found)
891
892        if all(foundl):
893            self._state = VAR
894
895    def is_namedtuple(self, expr: Expression) -> bool:
896        if not isinstance(expr, CallExpr):
897            return False
898        callee = expr.callee
899        return ((isinstance(callee, NameExpr) and callee.name.endswith('namedtuple')) or
900                (isinstance(callee, MemberExpr) and callee.name == 'namedtuple'))
901
902    def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
903        if self._state != EMPTY:
904            self.add('\n')
905        if isinstance(rvalue.args[1], StrExpr):
906            items = rvalue.args[1].value.split(" ")
907        elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)):
908            list_items = cast(List[StrExpr], rvalue.args[1].items)
909            items = [item.value for item in list_items]
910        else:
911            self.add('%s%s: Any' % (self._indent, lvalue.name))
912            self.import_tracker.require_name('Any')
913            return
914        self.import_tracker.require_name('NamedTuple')
915        self.add('{}class {}(NamedTuple):'.format(self._indent, lvalue.name))
916        if len(items) == 0:
917            self.add(' ...\n')
918        else:
919            self.import_tracker.require_name('Any')
920            self.add('\n')
921            for item in items:
922                self.add('{}    {}: Any\n'.format(self._indent, item))
923        self._state = CLASS
924
925    def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
926        """Return True for things that look like target for an alias.
927
928        Used to know if assignments look like type aliases, function alias,
929        or module alias.
930        """
931        # Assignment of TypeVar(...) are passed through
932        if (isinstance(expr, CallExpr) and
933                isinstance(expr.callee, NameExpr) and
934                expr.callee.name == 'TypeVar'):
935            return True
936        elif isinstance(expr, EllipsisExpr):
937            return not top_level
938        elif isinstance(expr, NameExpr):
939            if expr.name in ('True', 'False'):
940                return False
941            elif expr.name == 'None':
942                return not top_level
943            else:
944                return not self.is_private_name(expr.name)
945        elif isinstance(expr, MemberExpr) and self.analyzed:
946            # Also add function and module aliases.
947            return ((top_level and isinstance(expr.node, (FuncDef, Decorator, MypyFile))
948                     or isinstance(expr.node, TypeInfo)) and
949                    not self.is_private_member(expr.node.fullname))
950        elif (isinstance(expr, IndexExpr) and isinstance(expr.base, NameExpr) and
951              not self.is_private_name(expr.base.name)):
952            if isinstance(expr.index, TupleExpr):
953                indices = expr.index.items
954            else:
955                indices = [expr.index]
956            if expr.base.name == 'Callable' and len(indices) == 2:
957                args, ret = indices
958                if isinstance(args, EllipsisExpr):
959                    indices = [ret]
960                elif isinstance(args, ListExpr):
961                    indices = args.items + [ret]
962                else:
963                    return False
964            return all(self.is_alias_expression(i, top_level=False) for i in indices)
965        else:
966            return False
967
968    def process_typealias(self, lvalue: NameExpr, rvalue: Expression) -> None:
969        p = AliasPrinter(self)
970        self.add("{} = {}\n".format(lvalue.name, rvalue.accept(p)))
971        self.record_name(lvalue.name)
972        self._vars[-1].append(lvalue.name)
973
974    def visit_if_stmt(self, o: IfStmt) -> None:
975        # Ignore if __name__ == '__main__'.
976        expr = o.expr[0]
977        if (isinstance(expr, ComparisonExpr) and
978                isinstance(expr.operands[0], NameExpr) and
979                isinstance(expr.operands[1], StrExpr) and
980                expr.operands[0].name == '__name__' and
981                '__main__' in expr.operands[1].value):
982            return
983        super().visit_if_stmt(o)
984
985    def visit_import_all(self, o: ImportAll) -> None:
986        self.add_import_line('from %s%s import *\n' % ('.' * o.relative, o.id))
987
    def visit_import_from(self, o: ImportFrom) -> None:
        """Record a `from <module> import ...` statement with the import tracker.

        The module name is first passed through translate_module_name().
        Imports from `__future__` are dropped. For every imported name the
        method decides whether it should be re-exported from the stub; exported
        names are registered with `import_tracker.reexport()` and imported as
        `name as name`. All resulting names are added to the innermost variable
        scope and recorded as defined.
        """
        # NOTE(review): `exported_names` is only written to (at the end of this
        # method) and never read within it.
        exported_names = set()  # type: Set[str]
        import_names = []
        module, relative = translate_module_name(o.id, o.relative)
        if self.module:
            # Resolve a relative import against the module being processed.
            full_module, ok = mypy.util.correct_relative_import(
                self.module, relative, module, self.path.endswith('.__init__.py')
            )
            if not ok:
                full_module = module
        else:
            full_module = module
        if module == '__future__':
            return  # Not preserved
        for name, as_name in o.names:
            if name == 'six':
                # Vendored six -- translate into plain 'import six'.
                self.visit_import(Import([('six', None)]))
                continue
            exported = False
            if as_name is None and self.module and (self.module + '.' + name) in EXTRA_EXPORTED:
                # Special case certain names that should be exported, against our general rules.
                exported = True
            is_private = self.is_private_name(name, full_module + '.' + name)
            if (as_name is None
                    and name not in self.referenced_names
                    and (not self._all_ or name in IGNORED_DUNDERS)
                    and not is_private
                    and module not in ('abc', 'typing', 'asyncio')):
                # An imported name that is never referenced in the module is assumed to be
                # exported, unless there is an explicit __all__. Note that we need to special
                # case 'abc' since some references are deleted during semantic analysis.
                exported = True
            top_level = full_module.split('.')[0]
            if (as_name is None
                    and not self.export_less
                    and (not self._all_ or name in IGNORED_DUNDERS)
                    and self.module
                    and not is_private
                    and top_level in (self.module.split('.')[0],
                                      '_' + self.module.split('.')[0])):
                # Export imports from the same package, since we can't reliably tell whether they
                # are part of the public API.
                exported = True
            if exported:
                self.import_tracker.reexport(name)
                # Spell the import as `name as name`.
                as_name = name
            import_names.append((name, as_name))
        self.import_tracker.add_import_from('.' * relative + module, import_names)
        # The locally bound name is the alias when there is one, else the name itself.
        self._vars[-1].extend(alias or name for name, alias in import_names)
        for name, alias in import_names:
            self.record_name(alias or name)

        if self._all_:
            # Include import froms that import names defined in __all__.
            names = [name for name, alias in o.names
                     if name in self._all_ and alias is None and name not in IGNORED_DUNDERS]
            exported_names.update(names)
1046
1047    def visit_import(self, o: Import) -> None:
1048        for id, as_id in o.ids:
1049            self.import_tracker.add_import(id, as_id)
1050            if as_id is None:
1051                target_name = id.split('.')[0]
1052            else:
1053                target_name = as_id
1054            self._vars[-1].append(target_name)
1055            self.record_name(target_name)
1056
1057    def get_init(self, lvalue: str, rvalue: Expression,
1058                 annotation: Optional[Type] = None) -> Optional[str]:
1059        """Return initializer for a variable.
1060
1061        Return None if we've generated one already or if the variable is internal.
1062        """
1063        if lvalue in self._vars[-1]:
1064            # We've generated an initializer already for this variable.
1065            return None
1066        # TODO: Only do this at module top level.
1067        if self.is_private_name(lvalue) or self.is_not_in_all(lvalue):
1068            return None
1069        self._vars[-1].append(lvalue)
1070        if annotation is not None:
1071            typename = self.print_annotation(annotation)
1072            if (isinstance(annotation, UnboundType) and not annotation.args and
1073                    annotation.name == 'Final' and
1074                    self.import_tracker.module_for.get('Final') in ('typing',
1075                                                                    'typing_extensions')):
1076                # Final without type argument is invalid in stubs.
1077                final_arg = self.get_str_type_of_node(rvalue)
1078                typename += '[{}]'.format(final_arg)
1079        else:
1080            typename = self.get_str_type_of_node(rvalue)
1081        return '%s%s: %s\n' % (self._indent, lvalue, typename)
1082
1083    def add(self, string: str) -> None:
1084        """Add text to generated stub."""
1085        self._output.append(string)
1086
1087    def add_decorator(self, name: str) -> None:
1088        if not self._indent and self._state not in (EMPTY, FUNC):
1089            self._decorators.append('\n')
1090        self._decorators.append('%s@%s\n' % (self._indent, name))
1091
1092    def clear_decorators(self) -> None:
1093        self._decorators.clear()
1094
1095    def typing_name(self, name: str) -> str:
1096        if name in self.defined_names:
1097            # Avoid name clash between name from typing and a name defined in stub.
1098            return '_' + name
1099        else:
1100            return name
1101
1102    def add_typing_import(self, name: str) -> None:
1103        """Add a name to be imported from typing, unless it's imported already.
1104
1105        The import will be internal to the stub.
1106        """
1107        name = self.typing_name(name)
1108        self.import_tracker.require_name(name)
1109
1110    def add_import_line(self, line: str) -> None:
1111        """Add a line of text to the import section, unless it's already there."""
1112        if line not in self._import_lines:
1113            self._import_lines.append(line)
1114
1115    def add_coroutine_decorator(self, func: FuncDef, name: str, require_name: str) -> None:
1116        func.is_awaitable_coroutine = True
1117        self.add_decorator(name)
1118        self.import_tracker.require_name(require_name)
1119
1120    def output(self) -> str:
1121        """Return the text for the stub."""
1122        imports = ''
1123        if self._import_lines:
1124            imports += ''.join(self._import_lines)
1125        imports += ''.join(self.import_tracker.import_lines())
1126        if imports and self._output:
1127            imports += '\n'
1128        return imports + ''.join(self._output)
1129
1130    def is_not_in_all(self, name: str) -> bool:
1131        if self.is_private_name(name):
1132            return False
1133        if self._all_:
1134            return self.is_top_level() and name not in self._all_
1135        return False
1136
1137    def is_private_name(self, name: str, fullname: Optional[str] = None) -> bool:
1138        if self._include_private:
1139            return False
1140        if fullname in EXTRA_EXPORTED:
1141            return False
1142        return name.startswith('_') and (not name.endswith('__')
1143                                         or name in IGNORED_DUNDERS)
1144
1145    def is_private_member(self, fullname: str) -> bool:
1146        parts = fullname.split('.')
1147        for part in parts:
1148            if self.is_private_name(part):
1149                return True
1150        return False
1151
1152    def get_str_type_of_node(self, rvalue: Expression,
1153                             can_infer_optional: bool = False,
1154                             can_be_any: bool = True) -> str:
1155        if isinstance(rvalue, IntExpr):
1156            return 'int'
1157        if isinstance(rvalue, StrExpr):
1158            return 'str'
1159        if isinstance(rvalue, BytesExpr):
1160            return 'bytes'
1161        if isinstance(rvalue, FloatExpr):
1162            return 'float'
1163        if isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr):
1164            return 'int'
1165        if isinstance(rvalue, NameExpr) and rvalue.name in ('True', 'False'):
1166            return 'bool'
1167        if can_infer_optional and \
1168                isinstance(rvalue, NameExpr) and rvalue.name == 'None':
1169            self.add_typing_import('Any')
1170            return '{} | None'.format(self.typing_name('Any'))
1171        if can_be_any:
1172            self.add_typing_import('Any')
1173            return self.typing_name('Any')
1174        else:
1175            return ''
1176
1177    def print_annotation(self, t: Type) -> str:
1178        printer = AnnotationPrinter(self)
1179        return t.accept(printer)
1180
1181    def is_top_level(self) -> bool:
1182        """Are we processing the top level of a file?"""
1183        return self._indent == ''
1184
1185    def record_name(self, name: str) -> None:
1186        """Mark a name as defined.
1187
1188        This only does anything if at the top level of a module.
1189        """
1190        if self.is_top_level():
1191            self._toplevel_names.append(name)
1192
1193    def is_recorded_name(self, name: str) -> bool:
1194        """Has this name been recorded previously?"""
1195        return self.is_top_level() and name in self._toplevel_names
1196
1197
def find_method_names(defs: List[Statement]) -> Set[str]:
    """Collect the names of the functions defined directly in a list of statements.

    Plain functions, decorated functions, and every item of an overloaded
    function are included. Statements of other kinds are ignored.
    """
    # TODO: Traverse into nested definitions
    names = set()
    for statement in defs:
        if isinstance(statement, FuncDef):
            names.add(statement.name)
        elif isinstance(statement, Decorator):
            names.add(statement.func.name)
        elif isinstance(statement, OverloadedFuncDef):
            for part in statement.items:
                names |= find_method_names([part])
    return names
1210
1211
class SelfTraverser(mypy.traverser.TraverserVisitor):
    """Traverser that collects assignments of the form `self.<attr> = <value>`.

    Each match is appended to `results` as an (attribute name, value
    expression) pair. Only the first assignment target of a statement is
    examined.
    """

    def __init__(self) -> None:
        self.results = []  # type: List[Tuple[str, Expression]]

    def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
        target = o.lvalues[0]
        if not isinstance(target, MemberExpr):
            return
        receiver = target.expr
        if isinstance(receiver, NameExpr) and receiver.name == 'self':
            self.results.append((target.name, o.rvalue))
1222
1223
def find_self_initializers(fdef: FuncBase) -> List[Tuple[str, Expression]]:
    """Find attribute initializers in a method.

    Return a list of pairs (attribute name, r.h.s. expression), as collected
    by a SelfTraverser run over the function.
    """
    collector = SelfTraverser()
    fdef.accept(collector)
    return collector.results
1232
1233
def get_qualified_name(o: Expression) -> str:
    """Render a name or a chain of attribute accesses as a dotted name.

    Any other kind of expression yields ERROR_MARKER (also when it occurs at
    the root of an attribute chain).
    """
    if isinstance(o, MemberExpr):
        return '{}.{}'.format(get_qualified_name(o.expr), o.name)
    if isinstance(o, NameExpr):
        return o.name
    return ERROR_MARKER
1241
1242
def remove_blacklisted_modules(modules: List[StubSource]) -> List[StubSource]:
    """Return the modules whose path is not blacklisted, in the original order.

    Modules without a path are always kept.
    """
    kept = []  # type: List[StubSource]
    for source in modules:
        if source.path is not None and is_blacklisted_path(source.path):
            continue
        kept.append(source)
    return kept
1246
1247
def is_blacklisted_path(path: str) -> bool:
    """Return True if the path contains any of the BLACKLIST substrings.

    The path is matched with normalized separators and a trailing newline
    appended.
    """
    # NOTE(review): the appended newline presumably lets a BLACKLIST entry
    # anchor itself to the end of the path -- confirm against BLACKLIST.
    candidate = normalize_path_separators(path) + '\n'
    for fragment in BLACKLIST:
        if fragment in candidate:
            return True
    return False
1251
1252
def normalize_path_separators(path: str) -> str:
    """Return the path with backslashes replaced by '/' on Windows.

    On every other platform the path is returned unchanged.
    """
    return path.replace('\\', '/') if sys.platform == 'win32' else path
1257
1258
def collect_build_targets(options: Options, mypy_opts: MypyOptions) -> Tuple[List[StubSource],
                                                                             List[StubSource]]:
    """Collect files for which we need to generate stubs.

    Return list of Python modules and C modules. Blacklisted paths are
    removed from the Python modules. C modules are only ever found when
    modules or packages are requested and runtime imports are allowed.
    """
    c_modules = []  # type: List[StubSource]
    if not (options.packages or options.modules):
        # Use mypy native source collection for files and directories.
        try:
            source_list = create_source_list(options.files, mypy_opts)
        except InvalidSourceList as e:
            raise SystemExit(str(e)) from e
        py_modules = [StubSource(m.module, m.path) for m in source_list]
    elif options.no_import:
        py_modules = find_module_paths_using_search(options.modules,
                                                    options.packages,
                                                    options.search_path,
                                                    options.pyversion)
    else:
        # Using imports is the default, since we can also find C modules.
        py_modules, c_modules = find_module_paths_using_imports(options.modules,
                                                                options.packages,
                                                                options.interpreter,
                                                                options.pyversion,
                                                                options.verbose,
                                                                options.quiet)

    return remove_blacklisted_modules(py_modules), c_modules
1292
1293
def find_module_paths_using_imports(modules: List[str],
                                    packages: List[str],
                                    interpreter: str,
                                    pyversion: Tuple[int, int],
                                    verbose: bool,
                                    quiet: bool) -> Tuple[List[StubSource],
                                                          List[StubSource]]:
    """Find path and runtime value of __all__ (if possible) for modules and packages.

    This function uses runtime Python imports to get the information.
    The requested modules are extended with everything found by walking the
    requested packages. Modules that look like tests or scripts are skipped,
    and modules that cannot be imported are reported (unless `quiet`) and
    skipped. A module for which the lookup yields no result is returned in
    the second list (C modules), all others in the first (Python modules).
    """
    py_modules = []  # type: List[StubSource]
    c_modules = []  # type: List[StubSource]
    with ModuleInspect() as inspect:
        candidates = modules + list(walk_packages(inspect, packages, verbose))
        for mod in candidates:
            # We don't want to run any tests or scripts
            if is_non_library_module(mod):
                continue
            try:
                if pyversion[0] == 2:
                    result = find_module_path_and_all_py2(mod, interpreter)
                else:
                    result = find_module_path_and_all_py3(inspect, mod, verbose)
            except CantImport as e:
                tb = traceback.format_exc()
                if verbose:
                    sys.stdout.write(tb)
                if not quiet:
                    report_missing(mod, e.message, tb)
                continue
            if result:
                path, runtime_all = result
                py_modules.append(StubSource(mod, path, runtime_all))
            else:
                c_modules.append(StubSource(mod))
        return py_modules, c_modules
1332
1333
def is_non_library_module(module: str) -> bool:
    """Does module look like a test module or a script?

    The decision is made purely from the dotted module name: certain
    suffixes, a last component starting with 'test_', or a test-like package
    anywhere in the middle of the name.
    """
    non_library_suffixes = (
        '.tests',
        '.test',
        '.testing',
        '_tests',
        '_test_suite',
        'test_util',
        'test_utils',
        'test_base',
        '.__main__',
        '.conftest',  # Used by pytest
        '.setup',  # Typically an install script
    )
    if module.endswith(non_library_suffixes):
        return True
    last_component = module.rsplit('.', 1)[-1]
    if last_component.startswith('test_'):
        return True
    test_packages = ('.tests.', '.test.', '.testing.', '.SelfTest.')
    return any(marker in module for marker in test_packages)
1358
1359
def translate_module_name(module: str, relative: int) -> Tuple[str, int]:
    """Map imports of a vendored `six` to the plain `six` module.

    For every package in VENDOR_PACKAGES, a module name that ends in (or,
    for relative imports, equals) `<pkg>.six` or `<pkg>.six.moves` becomes an
    absolute import of `six` / `six.moves`. A name that contains such a
    vendored prefix followed by further components keeps those components.
    Any other module name and its relative level are returned unchanged.
    """
    for pkg in VENDOR_PACKAGES:
        for alt in ('six.moves', 'six'):
            vendored = '{}.{}'.format(pkg, alt)
            dotted = '.' + vendored
            if module.endswith(dotted) or (module == vendored and relative):
                return alt, 0
            infix = dotted + '.'
            if infix in module:
                # Keep whatever follows the vendored prefix.
                return alt + '.' + module.partition(infix)[2], 0
    return module, relative
1370
1371
def find_module_paths_using_search(modules: List[str], packages: List[str],
                                   search_path: List[str],
                                   pyversion: Tuple[int, int]) -> List[StubSource]:
    """Find sources for modules and packages requested.

    This function just looks for source files at the file system level.
    This is used if user passes --no-import, and will not find C modules.
    Exit if some of the modules or packages can't be found.

    Modules are looked up in the current directory, the given search path and
    the default library path. Packages are expanded recursively into all the
    modules they contain. Modules that look like tests or scripts are left
    out of the result.
    """
    result = []  # type: List[StubSource]
    typeshed_path = default_lib_path(mypy.build.default_data_dir(), pyversion, None)
    search_paths = SearchPaths(('.',) + tuple(search_path), (), (), tuple(typeshed_path))
    cache = FindModuleCache(search_paths, fscache=None, options=None)
    for module in modules:
        m_result = cache.find_module(module)
        if isinstance(m_result, ModuleNotFoundReason):
            fail_missing(module, m_result)
            module_path = None
        else:
            module_path = m_result
        result.append(StubSource(module, module_path))
    for package in packages:
        p_result = cache.find_modules_recursive(package)
        if not p_result:
            # No module was found for the package: report it as missing.
            # (Reporting when the result is non-empty would reject exactly
            # the packages that were located.)
            fail_missing(package, ModuleNotFoundReason.NOT_FOUND)
        result.extend(StubSource(m.module, m.path) for m in p_result)

    return [m for m in result if not is_non_library_module(m.module)]
1403
1404
def mypy_options(stubgen_options: Options) -> MypyOptions:
    """Generate mypy options using the flag passed by user.

    Only the Python version is taken from the stubgen options; every other
    setting assigned here is a fixed value.
    """
    opts = MypyOptions()
    opts.python_version = stubgen_options.pyversion
    # Fixed settings, independent of the user's flags.
    opts.semantic_analysis_only = True
    opts.follow_imports = 'skip'
    opts.ignore_errors = True
    opts.incremental = False
    opts.show_traceback = True
    opts.transform_source = remove_misplaced_type_comments
    return opts
1416
1417
def parse_source_file(mod: StubSource, mypy_options: MypyOptions) -> None:
    """Parse a source file.

    On success, store AST in the corresponding attribute of the stub source.
    If there are syntax errors, print them and exit.

    The file is read as bytes and decoded according to the configured Python
    version. Error messages go to stderr and the exit status is 1.
    """
    assert mod.path is not None, "Not found module was not skipped"
    with open(mod.path, 'rb') as source_file:
        raw = source_file.read()
    text = mypy.util.decode_python_encoding(raw, mypy_options.python_version)
    errors = Errors()
    mod.ast = mypy.parse.parse(text, fnam=mod.path, module=mod.module,
                               errors=errors, options=mypy_options)
    mod.ast._fullname = mod.module
    if not errors.is_blockers():
        return
    # Syntax error!
    for message in errors.new_messages():
        sys.stderr.write('%s\n' % message)
    sys.exit(1)
1437
1438
def generate_asts_for_modules(py_modules: List[StubSource],
                              parse_only: bool,
                              mypy_options: MypyOptions,
                              verbose: bool) -> None:
    """Use mypy to parse (and optionally analyze) source files.

    With `parse_only`, every module is parsed on its own. Otherwise all
    modules are built together as one source set; a compile error aborts the
    program. The resulting AST is stored on each module, and a module without
    a runtime `__all__` gets the statically inferred one.
    """
    if not py_modules:
        return  # Nothing to do here, but there may be C modules
    if verbose:
        print('Processing %d files...' % len(py_modules))
    if parse_only:
        for mod in py_modules:
            parse_source_file(mod, mypy_options)
        return
    # Perform full semantic analysis of the source set.
    try:
        build_result = build([module.source for module in py_modules], mypy_options)
    except CompileError as e:
        raise SystemExit("Critical error during semantic analysis: {}".format(e)) from e

    for mod in py_modules:
        mod.ast = build_result.graph[mod.module].tree
        if mod.runtime_all is not None:
            continue
        # Use statically inferred __all__ if there is no runtime one.
        mod.runtime_all = build_result.manager.semantic_analyzer.export_map[mod.module]
1463
1464
def generate_stub_from_ast(mod: StubSource,
                           target: str,
                           parse_only: bool = False,
                           pyversion: Tuple[int, int] = defaults.PYTHON3_VERSION,
                           include_private: bool = False,
                           export_less: bool = False) -> None:
    """Use analysed (or just parsed) AST to generate type stub for single file.

    Arguments:
        mod: stub source whose 'ast' attribute has already been populated
        target: path of the stub file to write
        parse_only: True if the AST was only parsed, not semantically analysed
        pyversion, include_private, export_less: passed on to StubGenerator

    If the directory for target doesn't exist it will be created. An existing
    stub will be overwritten. The stub is always written as UTF-8, regardless
    of the locale's preferred encoding.
    """
    gen = StubGenerator(mod.runtime_all,
                        pyversion=pyversion,
                        include_private=include_private,
                        analyzed=not parse_only,
                        export_less=export_less)
    assert mod.ast is not None, "This function must be used only with analyzed modules"
    mod.ast.accept(gen)

    # Write output to file.
    subdir = os.path.dirname(target)
    if subdir:
        # exist_ok avoids failing if the directory appears between a check
        # and the creation; a non-directory at this path still raises.
        os.makedirs(subdir, exist_ok=True)
    # Use an explicit encoding so that the generated stub does not depend on
    # the platform default (which may be unable to encode some characters).
    with open(target, 'w', encoding='utf-8') as file:
        file.write(''.join(gen.output()))
1490
1491
def collect_docs_signatures(doc_dir: str) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Gather all function and class signatures in the docs.

    Every '*.rst' file directly inside doc_dir is read as UTF-8 and scanned
    with parse_all_signatures(); the collected signatures are then filtered
    through find_unique_signatures() and converted to dicts.

    Return a tuple (function signatures, class signatures).
    Currently only used for C modules.
    """
    all_sigs = []  # type: List[Sig]
    all_class_sigs = []  # type: List[Sig]
    for path in glob.glob('%s/*.rst' % doc_dir):
        # Decode explicitly instead of relying on the locale's preferred
        # encoding, which may fail on non-ASCII text in the documentation.
        with open(path, encoding='utf-8') as f:
            loc_sigs, loc_class_sigs = parse_all_signatures(f.readlines())
        all_sigs += loc_sigs
        all_class_sigs += loc_class_sigs
    sigs = dict(find_unique_signatures(all_sigs))
    class_sigs = dict(find_unique_signatures(all_class_sigs))
    return sigs, class_sigs
1508
1509
def generate_stubs(options: Options) -> None:
    """Main entry point for the program.

    Collect the requested Python and C modules, build ASTs for the Python
    ones, write one stub file per module below options.output_dir and,
    unless options.quiet is set, print a short summary.
    """
    mypy_opts = mypy_options(options)
    py_modules, c_modules = collect_build_targets(options, mypy_opts)

    # Signatures from the docs are optional; they are only handed to the
    # C module generator below.
    sigs = class_sigs = None  # type: Optional[Dict[str, str]]
    if options.doc_dir:
        sigs, class_sigs = collect_docs_signatures(options.doc_dir)

    # Python modules: build the ASTs first, then emit a stub from each one.
    generate_asts_for_modules(py_modules, options.parse_only, mypy_opts, options.verbose)
    files = []
    for py_mod in py_modules:
        assert py_mod.path is not None, "Not found module was not skipped"
        is_package = os.path.basename(py_mod.path) == '__init__.py'
        stub_suffix = '/__init__.pyi' if is_package else '.pyi'
        target = os.path.join(options.output_dir,
                              py_mod.module.replace('.', '/') + stub_suffix)
        files.append(target)
        with generate_guarded(py_mod.module, target, options.ignore_errors, options.verbose):
            generate_stub_from_ast(py_mod, target,
                                   options.parse_only, options.pyversion,
                                   options.include_private,
                                   options.export_less)

    # C modules go through separate logic.  A C module is written as a
    # package stub when some other collected module lives underneath it.
    all_modules = py_modules + c_modules
    for c_mod in c_modules:
        submodule_prefix = c_mod.module + '.'
        has_submodules = any(other.module.startswith(submodule_prefix)
                             for other in all_modules)
        stub_suffix = '/__init__.pyi' if has_submodules else '.pyi'
        target = os.path.join(options.output_dir,
                              c_mod.module.replace('.', '/') + stub_suffix)
        files.append(target)
        with generate_guarded(c_mod.module, target, options.ignore_errors, options.verbose):
            generate_stub_for_c_module(c_mod.module, target, sigs=sigs, class_sigs=class_sigs)

    module_count = len(py_modules) + len(c_modules)
    if options.quiet or module_count <= 0:
        return
    print('Processed %d modules' % module_count)
    if len(files) == 1:
        print('Generated %s' % files[0])
    else:
        print('Generated files under %s' % common_dir_prefix(files) + os.sep)
1556
1557
# Usage line shown by argparse; %(prog)s is substituted with the program name.
HEADER = """%(prog)s [-h] [--py2] [more options, see -h]
                     [-m MODULE] [-p PACKAGE] [files ...]"""

# Description shown at the top of the --help output.
DESCRIPTION = """
Generate draft stubs for modules.

Stubs are generated in directory ./out by default (use -o to change it), to
avoid overriding files with manual changes.  The output directory is created
if it does not already exist.
"""
1567
1568
def parse_options(args: List[str]) -> Options:
    """Parse command line arguments into an Options object.

    Arguments:
        args: the command line arguments, without the program name

    If no interpreter is given, the running interpreter is used in Python 3
    mode and default_py2_interpreter() in Python 2 mode.  Combining
    modules/packages with files, or --quiet with --verbose, is reported
    through parser.error(), which exits the program.

    As a side effect, the output directory is created if it doesn't exist.
    """
    parser = argparse.ArgumentParser(prog='stubgen',
                                     usage=HEADER,
                                     description=DESCRIPTION)

    parser.add_argument('--py2', action='store_true',
                        help="run in Python 2 mode (default: Python 3 mode)")
    parser.add_argument('--ignore-errors', action='store_true',
                        help="ignore errors when trying to generate stubs for modules")
    parser.add_argument('--no-import', action='store_true',
                        help="don't import the modules, just parse and analyze them "
                             "(doesn't work with C extension modules and might not "
                             "respect __all__)")
    parser.add_argument('--parse-only', action='store_true',
                        help="don't perform semantic analysis of sources, just parse them "
                             "(only applies to Python modules, might affect quality of stubs)")
    parser.add_argument('--include-private', action='store_true',
                        help="generate stubs for objects and members considered private "
                             "(single leading underscore and no trailing underscores)")
    parser.add_argument('--export-less', action='store_true',
                        help=("don't implicitly export all names imported from other modules "
                              "in the same package"))
    parser.add_argument('-v', '--verbose', action='store_true',
                        help="show more verbose messages")
    parser.add_argument('-q', '--quiet', action='store_true',
                        help="show fewer messages")
    parser.add_argument('--doc-dir', metavar='PATH', default='',
                        help="use .rst documentation in PATH (this may result in "
                             "better stubs in some cases; consider setting this to "
                             "DIR/Python-X.Y.Z/Doc/library)")
    parser.add_argument('--search-path', metavar='PATH', default='',
                        help="specify module search directories, separated by ':' "
                             "(currently only used if --no-import is given)")
    parser.add_argument('--python-executable', metavar='PATH', dest='interpreter', default='',
                        help="use Python interpreter at PATH (only works for "
                             "Python 2 right now)")
    parser.add_argument('-o', '--output', metavar='PATH', dest='output_dir', default='out',
                        help="change the output directory [default: %(default)s]")
    parser.add_argument('-m', '--module', action='append', metavar='MODULE',
                        dest='modules', default=[],
                        help="generate stub for module; can repeat for more modules")
    parser.add_argument('-p', '--package', action='append', metavar='PACKAGE',
                        dest='packages', default=[],
                        help="generate stubs for package recursively; can be repeated")
    parser.add_argument(metavar='files', nargs='*', dest='files',
                        help="generate stubs for given files or directories")

    ns = parser.parse_args(args)

    pyversion = defaults.PYTHON2_VERSION if ns.py2 else defaults.PYTHON3_VERSION
    if not ns.interpreter:
        ns.interpreter = sys.executable if pyversion[0] == 3 else default_py2_interpreter()
    if ns.modules + ns.packages and ns.files:
        parser.error("May only specify one of: modules/packages or files.")
    if ns.quiet and ns.verbose:
        parser.error('Cannot specify both quiet and verbose messages')

    # Create the output folder if it doesn't already exist.  exist_ok avoids
    # the race between a separate existence check and the creation, and still
    # raises if the path exists but is not a directory.
    os.makedirs(ns.output_dir, exist_ok=True)

    return Options(pyversion=pyversion,
                   no_import=ns.no_import,
                   doc_dir=ns.doc_dir,
                   search_path=ns.search_path.split(':'),
                   interpreter=ns.interpreter,
                   ignore_errors=ns.ignore_errors,
                   parse_only=ns.parse_only,
                   include_private=ns.include_private,
                   output_dir=ns.output_dir,
                   modules=ns.modules,
                   packages=ns.packages,
                   files=ns.files,
                   verbose=ns.verbose,
                   quiet=ns.quiet,
                   export_less=ns.export_less)
1645
1646
def main() -> None:
    """Run stubgen with the arguments given on the command line."""
    mypy.util.check_python_version('stubgen')
    # Packages in the current directory must be importable, so put the
    # current directory on sys.path unless it is already listed there.
    if '' not in sys.path and '.' not in sys.path:
        sys.path.insert(0, '')

    generate_stubs(parse_options(sys.argv[1:]))
1656
1657
# Run stubgen only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
1660