1"""
2    sphinx.pycode.ast
3    ~~~~~~~~~~~~~~~~~
4
5    Helpers for AST (Abstract Syntax Tree).
6
7    :copyright: Copyright 2007-2021 by the Sphinx team, see AUTHORS.
8    :license: BSD, see LICENSE for details.
9"""
10
11import sys
12from typing import Dict, List, Optional, Type
13
14if sys.version_info > (3, 8):
15    import ast
16else:
17    try:
18        # use typed_ast module if installed
19        from typed_ast import ast3 as ast
20    except ImportError:
21        import ast  # type: ignore
22
23
# Maps every supported operator node class to the way that operator is
# spelled in source code.  Unary and binary forms of "+" and "-" are separate
# node classes, hence the repeated spellings.
OPERATORS = {
    ast.Add: '+',         # binary addition
    ast.And: 'and',       # boolean conjunction
    ast.BitAnd: '&',      # bitwise and
    ast.BitOr: '|',       # bitwise or
    ast.BitXor: '^',      # bitwise exclusive or
    ast.Div: '/',         # true division
    ast.FloorDiv: '//',   # floor division
    ast.Invert: '~',      # unary bitwise inversion
    ast.LShift: '<<',     # left shift
    ast.MatMult: '@',     # matrix multiplication
    ast.Mult: '*',        # multiplication
    ast.Mod: '%',         # modulo
    ast.Not: 'not',       # unary boolean negation
    ast.Pow: '**',        # exponentiation
    ast.Or: 'or',         # boolean disjunction
    ast.RShift: '>>',     # right shift
    ast.Sub: '-',         # binary subtraction
    ast.UAdd: '+',        # unary plus
    ast.USub: '-',        # unary minus
}  # type: Dict[Type[ast.AST], str]
45
46
def parse(code: str, mode: str = 'exec') -> "ast.AST":
    """Turn *code* into an AST, keeping type comments whenever possible.

    The first attempt asks the parser to record type comments.  If that
    attempt fails, *code* is parsed once more without the option; an error
    raised by this second attempt is propagated to the caller.

    :param code: the source text to parse
    :param mode: the compilation mode handed to ``ast.parse``
    :return: the root node of the parsed tree
    """
    try:
        # the type_comments parameter is available on py38+
        tree = ast.parse(code, mode=mode, type_comments=True)  # type: ignore
    except (SyntaxError, TypeError):
        # SyntaxError: some syntax error was found.  To ignore invalid type
        # comments, retry without the type_comments parameter
        # (refs: https://github.com/sphinx-doc/sphinx/issues/8652).
        # TypeError: the parser does not accept the parameter; fall back to a
        # plain call (typed_ast parses type comments by itself if installed).
        tree = ast.parse(code, mode=mode)
    return tree
63
64
def unparse(node: Optional[ast.AST], code: str = '') -> Optional[str]:
    """Convert an AST node back into a source string.

    :param node: the node to render; ``None`` and plain strings are handed
                 back unchanged
    :param code: passed on to the unparse visitor
    :return: the rendered text, or ``None`` when *node* is ``None``
    """
    if node is None or isinstance(node, str):
        # nothing to render: pass the value straight through
        return node
    visitor = _UnparseVisitor(code)
    return visitor.visit(node)
72
73
74# a greatly cut-down version of `ast._Unparser`
class _UnparseVisitor(ast.NodeVisitor):
    """Render a supported subset of AST nodes back to source text.

    Every ``visit_*`` method returns the string for its node.  Node types
    without a dedicated method end up in :meth:`generic_visit`, which raises
    :exc:`NotImplementedError`.
    """

    def __init__(self, code: str = '') -> None:
        # Source text; when non-empty, numeric constants are copied from it
        # instead of being re-formatted (see visit_Constant).
        self.code = code

    def _visit_op(self, node: ast.AST) -> str:
        """Return the source spelling of an operator node."""
        return OPERATORS[node.__class__]

    # Register the handler above as ``visit_<Operator>`` for every operator
    # class; inside a class body ``locals()`` is the namespace being built.
    for _op in OPERATORS:
        locals()['visit_{}'.format(_op.__name__)] = _visit_op

    def visit_arg(self, node: ast.arg) -> str:
        """Unparse a parameter name together with its annotation, if any."""
        if node.annotation:
            return "%s: %s" % (node.arg, self.visit(node.annotation))
        else:
            return node.arg

    def _visit_arg_with_default(self, arg: ast.arg, default: Optional[ast.AST]) -> str:
        """Unparse a single argument to a string."""
        name = self.visit(arg)
        if default is not None:
            # annotated parameters get spaces around "=", bare ones do not
            if arg.annotation:
                name += " = %s" % self.visit(default)
            else:
                name += "=%s" % self.visit(default)
        return name

    def visit_arguments(self, node: ast.arguments) -> str:
        """Unparse a whole parameter list (without the parentheses)."""
        # ``posonlyargs`` only exists on py38+; treat it as empty elsewhere.
        posonlyargs = list(getattr(node, "posonlyargs", []))  # type: List[ast.arg]
        positionals = posonlyargs + list(node.args)

        # ``defaults`` lines up with the *last* positional parameters, so it
        # is left-padded with None to give every parameter a slot.
        missing = len(positionals) - len(node.defaults)
        defaults = [None] * missing + list(node.defaults)  # type: List[Optional[ast.AST]]

        kw_missing = len(node.kwonlyargs) - len(node.kw_defaults)
        kw_defaults = [None] * kw_missing + list(node.kw_defaults)  # type: List[Optional[ast.AST]]  # NOQA

        args = []  # type: List[str]
        for position, (arg, default) in enumerate(zip(positionals, defaults), 1):
            args.append(self._visit_arg_with_default(arg, default))
            if position == len(posonlyargs):
                # end of the positional-only section
                args.append('/')

        if node.vararg:
            args.append("*" + self.visit(node.vararg))

        if node.kwonlyargs and not node.vararg:
            # a bare "*" introduces keyword-only parameters when no *args does
            args.append('*')
        for arg, default in zip(node.kwonlyargs, kw_defaults):
            args.append(self._visit_arg_with_default(arg, default))

        if node.kwarg:
            args.append("**" + self.visit(node.kwarg))

        return ", ".join(args)

    def visit_Attribute(self, node: ast.Attribute) -> str:
        """Unparse an attribute access such as ``value.attr``."""
        return "%s.%s" % (self.visit(node.value), node.attr)

    def visit_BinOp(self, node: ast.BinOp) -> str:
        """Unparse a binary operation as ``left op right``."""
        return " ".join(self.visit(e) for e in [node.left, node.op, node.right])

    def visit_BoolOp(self, node: ast.BoolOp) -> str:
        """Unparse a chain of ``and``/``or`` operands."""
        op = " %s " % self.visit(node.op)
        return op.join(self.visit(e) for e in node.values)

    def visit_Call(self, node: ast.Call) -> str:
        """Unparse a call with its positional and keyword arguments."""
        args = [self.visit(e) for e in node.args]
        for keyword in node.keywords:
            value = self.visit(keyword.value)
            if keyword.arg is None:
                # ``f(**mapping)`` is stored as a keyword without a name
                args.append("**" + value)
            else:
                args.append("%s=%s" % (keyword.arg, value))
        return "%s(%s)" % (self.visit(node.func), ", ".join(args))

    def visit_Dict(self, node: ast.Dict) -> str:
        """Unparse a dict display."""
        keys = (self.visit(k) for k in node.keys)
        values = (self.visit(v) for v in node.values)
        items = (k + ": " + v for k, v in zip(keys, values))
        return "{" + ", ".join(items) + "}"

    def visit_Index(self, node: ast.Index) -> str:
        """Unparse the value wrapped by an ``Index`` node."""
        return self.visit(node.value)

    def visit_Lambda(self, node: ast.Lambda) -> str:
        """Unparse a lambda; its body is always abbreviated to ``...``."""
        return "lambda %s: ..." % self.visit(node.args)

    def visit_List(self, node: ast.List) -> str:
        """Unparse a list display."""
        return "[" + ", ".join(self.visit(e) for e in node.elts) + "]"

    def visit_Name(self, node: ast.Name) -> str:
        """Unparse a plain identifier."""
        return node.id

    def visit_Set(self, node: ast.Set) -> str:
        """Unparse a set display."""
        return "{" + ", ".join(self.visit(e) for e in node.elts) + "}"

    def visit_Subscript(self, node: ast.Subscript) -> str:
        """Unparse a subscription such as ``value[a, b]``.

        A non-empty tuple subscript without starred elements is written
        without its surrounding parentheses.
        """
        def is_simple_tuple(value: ast.AST) -> bool:
            return (
                isinstance(value, ast.Tuple) and
                bool(value.elts) and
                not any(isinstance(elt, ast.Starred) for elt in value.elts)
            )

        if is_simple_tuple(node.slice):
            elts = ", ".join(self.visit(e) for e in node.slice.elts)  # type: ignore
            return "%s[%s]" % (self.visit(node.value), elts)
        elif isinstance(node.slice, ast.Index) and is_simple_tuple(node.slice.value):
            # same case, with the tuple wrapped in an ``Index`` node
            elts = ", ".join(self.visit(e) for e in node.slice.value.elts)  # type: ignore
            return "%s[%s]" % (self.visit(node.value), elts)
        else:
            return "%s[%s]" % (self.visit(node.value), self.visit(node.slice))

    def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
        """Unparse a unary operation.

        Only ``not`` needs a space before its operand; ``+x``, ``-x`` and
        ``~x`` are written without one.
        """
        if isinstance(node.op, ast.Not):
            return "%s %s" % (self.visit(node.op), self.visit(node.operand))
        return "%s%s" % (self.visit(node.op), self.visit(node.operand))

    def visit_Tuple(self, node: ast.Tuple) -> str:
        """Unparse a tuple display, always in parentheses."""
        if not node.elts:
            return "()"
        elif len(node.elts) == 1:
            # the trailing comma is what makes a one-element tuple a tuple
            return "(%s,)" % self.visit(node.elts[0])
        else:
            return "(" + ", ".join(self.visit(e) for e in node.elts) + ")"

    if sys.version_info >= (3, 6):
        def visit_Constant(self, node: ast.Constant) -> str:
            """Unparse a literal constant."""
            if node.value is Ellipsis:
                return "..."
            elif isinstance(node.value, (int, float, complex)):
                if self.code and sys.version_info > (3, 8):
                    # Prefer the literal as written in the source text.
                    # get_source_segment() returns None when the node has no
                    # location info; fall back to repr() so that a string is
                    # always returned.
                    return ast.get_source_segment(self.code, node) or repr(node.value)
                else:
                    return repr(node.value)
            else:
                return repr(node.value)

    if sys.version_info < (3, 8):
        # these ast nodes were deprecated in python 3.8
        def visit_Bytes(self, node: ast.Bytes) -> str:
            return repr(node.s)

        def visit_Ellipsis(self, node: ast.Ellipsis) -> str:
            return "..."

        def visit_NameConstant(self, node: ast.NameConstant) -> str:
            return repr(node.value)

        def visit_Num(self, node: ast.Num) -> str:
            return repr(node.n)

        def visit_Str(self, node: ast.Str) -> str:
            return repr(node.s)

    def generic_visit(self, node):
        """Reject every node type that has no dedicated ``visit_*`` method.

        :raises NotImplementedError: always
        """
        raise NotImplementedError('Unable to parse %s object' % type(node).__name__)
231