"""
    sphinx.pycode.ast
    ~~~~~~~~~~~~~~~~~

    Helpers for AST (Abstract Syntax Tree).

    :copyright: Copyright 2007-2021 by the Sphinx team, see AUTHORS.
    :license: BSD, see LICENSE for details.
"""

import sys
from typing import Dict, List, Optional, Type

if sys.version_info > (3, 8):
    import ast
else:
    try:
        # use typed_ast module if installed
        from typed_ast import ast3 as ast
    except ImportError:
        import ast  # type: ignore


# Source-text spelling of every operator node the unparser can render.
OPERATORS = {
    ast.Add: "+",
    ast.And: "and",
    ast.BitAnd: "&",
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Invert: "~",
    ast.LShift: "<<",
    ast.MatMult: "@",
    ast.Mult: "*",
    ast.Mod: "%",
    ast.Not: "not",
    ast.Pow: "**",
    ast.Or: "or",
    ast.RShift: ">>",
    ast.Sub: "-",
    ast.UAdd: "+",
    ast.USub: "-",
}  # type: Dict[Type[ast.AST], str]


def parse(code: str, mode: str = 'exec') -> "ast.AST":
    """Parse the *code* using built-in ast or typed_ast.

    This enables "type_comments" feature if possible.

    :param code: Python source text to parse.
    :param mode: parsing mode passed through to ``ast.parse``
                 (e.g. ``'exec'`` or ``'eval'``).
    :raises SyntaxError: if *code* cannot be parsed even without type comments.
    """
    try:
        # type_comments parameter is available on py38+
        return ast.parse(code, mode=mode, type_comments=True)  # type: ignore
    except SyntaxError:
        # Some syntax error found. To ignore invalid type comments, retry parsing without
        # type_comments parameter (refs: https://github.com/sphinx-doc/sphinx/issues/8652).
        return ast.parse(code, mode=mode)
    except TypeError:
        # fallback to ast module.
        # typed_ast is used to parse type_comments if installed.
        return ast.parse(code, mode=mode)


def unparse(node: Optional[ast.AST], code: str = '') -> Optional[str]:
    """Unparse an AST to string.

    :param node: the node to render. ``None`` is passed through as ``None``,
                 and a plain ``str`` is returned unchanged.
    :param code: the source *node* was parsed from. On Python 3.8+ numeric
                 literals are then rendered verbatim from it (e.g. ``0x10``
                 stays ``0x10``) instead of via ``repr``.
    :raises NotImplementedError: for node types the unparser does not support.
    """
    if node is None:
        return None
    elif isinstance(node, str):
        return node
    return _UnparseVisitor(code).visit(node)


# a greatly cut-down version of `ast._Unparser`
class _UnparseVisitor(ast.NodeVisitor):
    """Render a subset of expression AST nodes back to Python source text.

    Every ``visit_*`` method returns a ``str``. Nodes without a dedicated
    visitor fall through to :meth:`generic_visit`, which raises.
    """

    def __init__(self, code: str = '') -> None:
        # Original source text, used to reproduce numeric literals verbatim.
        self.code = code

    def _visit_op(self, node: ast.AST) -> str:
        return OPERATORS[node.__class__]
    # Register ``_visit_op`` as ``visit_<OperatorName>`` for every operator.
    # Assigning through ``locals()`` works here because the class body
    # namespace is a real dict.
    for _op in OPERATORS:
        locals()['visit_{}'.format(_op.__name__)] = _visit_op
    # Otherwise the loop variable would remain as a stray class attribute.
    del _op

    def visit_arg(self, node: ast.arg) -> str:
        if node.annotation:
            return "%s: %s" % (node.arg, self.visit(node.annotation))
        else:
            return node.arg

    def _visit_arg_with_default(self, arg: ast.arg, default: Optional[ast.AST]) -> str:
        """Unparse a single argument to a string."""
        name = self.visit(arg)
        if default:
            # PEP 8: spaces around ``=`` only when the argument is annotated.
            if arg.annotation:
                name += " = %s" % self.visit(default)
            else:
                name += "=%s" % self.visit(default)
        return name

    def visit_arguments(self, node: ast.arguments) -> str:
        """Render a full parameter list (positional-only, regular, ``*args``,
        keyword-only and ``**kwargs``) without the surrounding parentheses."""
        # ``defaults`` only covers the trailing positional parameters; pad it
        # on the left with ``None`` so it lines up index-for-index with the
        # positional-only + regular parameters.
        defaults = list(node.defaults)
        positionals = len(node.args)
        posonlyargs = 0
        if hasattr(node, "posonlyargs"):  # for py38+
            posonlyargs += len(node.posonlyargs)  # type:ignore
            positionals += posonlyargs
        for _ in range(len(defaults), positionals):
            defaults.insert(0, None)

        kw_defaults = list(node.kw_defaults)
        for _ in range(len(kw_defaults), len(node.kwonlyargs)):
            kw_defaults.insert(0, None)

        args = []  # type: List[str]
        if hasattr(node, "posonlyargs"):  # for py38+
            for i, arg in enumerate(node.posonlyargs):  # type: ignore
                args.append(self._visit_arg_with_default(arg, defaults[i]))

            if node.posonlyargs:  # type: ignore
                args.append('/')

        for i, arg in enumerate(node.args):
            args.append(self._visit_arg_with_default(arg, defaults[i + posonlyargs]))

        if node.vararg:
            args.append("*" + self.visit(node.vararg))

        # A bare ``*`` separator is needed only when there is no ``*args``.
        if node.kwonlyargs and not node.vararg:
            args.append('*')
        for i, arg in enumerate(node.kwonlyargs):
            args.append(self._visit_arg_with_default(arg, kw_defaults[i]))

        if node.kwarg:
            args.append("**" + self.visit(node.kwarg))

        return ", ".join(args)

    def visit_Attribute(self, node: ast.Attribute) -> str:
        return "%s.%s" % (self.visit(node.value), node.attr)

    def visit_BinOp(self, node: ast.BinOp) -> str:
        # NOTE: no parentheses are emitted, so operator precedence is not
        # preserved for nested operations (e.g. ``(a + b) * c``).
        return " ".join(self.visit(e) for e in [node.left, node.op, node.right])

    def visit_BoolOp(self, node: ast.BoolOp) -> str:
        op = " %s " % self.visit(node.op)
        return op.join(self.visit(e) for e in node.values)

    def visit_Call(self, node: ast.Call) -> str:
        args = [self.visit(e) for e in node.args]  # type: List[str]
        for keyword in node.keywords:
            if keyword.arg is None:
                # ``f(**mapping)``: a keyword without a name is dict unpacking.
                args.append("**%s" % self.visit(keyword.value))
            else:
                args.append("%s=%s" % (keyword.arg, self.visit(keyword.value)))
        return "%s(%s)" % (self.visit(node.func), ", ".join(args))

    def visit_Dict(self, node: ast.Dict) -> str:
        items = []  # type: List[str]
        for key, value in zip(node.keys, node.values):
            if key is None:
                # ``{**mapping}``: a ``None`` key marks dict unpacking.
                items.append("**" + self.visit(value))
            else:
                items.append(self.visit(key) + ": " + self.visit(value))
        return "{" + ", ".join(items) + "}"

    def visit_Index(self, node: ast.Index) -> str:
        return self.visit(node.value)

    def visit_Lambda(self, node: ast.Lambda) -> str:
        # The body is intentionally elided; only the signature is rendered.
        return "lambda %s: ..." % self.visit(node.args)

    def visit_List(self, node: ast.List) -> str:
        return "[" + ", ".join(self.visit(e) for e in node.elts) + "]"

    def visit_Name(self, node: ast.Name) -> str:
        return node.id

    def visit_Set(self, node: ast.Set) -> str:
        return "{" + ", ".join(self.visit(e) for e in node.elts) + "}"

    def visit_Subscript(self, node: ast.Subscript) -> str:
        def is_simple_tuple(value: ast.AST) -> bool:
            return (
                isinstance(value, ast.Tuple) and
                bool(value.elts) and
                not any(isinstance(elt, ast.Starred) for elt in value.elts)
            )

        # Render ``x[a, b]`` rather than ``x[(a, b)]``. The slice is a bare
        # Tuple on py39+, and wrapped in ``ast.Index`` on older versions.
        if is_simple_tuple(node.slice):
            elts = ", ".join(self.visit(e) for e in node.slice.elts)  # type: ignore
            return "%s[%s]" % (self.visit(node.value), elts)
        elif isinstance(node.slice, ast.Index) and is_simple_tuple(node.slice.value):
            elts = ", ".join(self.visit(e) for e in node.slice.value.elts)  # type: ignore
            return "%s[%s]" % (self.visit(node.value), elts)
        else:
            return "%s[%s]" % (self.visit(node.value), self.visit(node.slice))

    def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
        # Only the keyword operator ``not`` needs a separating space; the
        # symbolic operators ``+``, ``-`` and ``~`` are written adjacent to
        # their operand (``-1``, not ``- 1``).
        if isinstance(node.op, ast.Not):
            return "%s %s" % (self.visit(node.op), self.visit(node.operand))
        return "%s%s" % (self.visit(node.op), self.visit(node.operand))

    def visit_Tuple(self, node: ast.Tuple) -> str:
        if not node.elts:
            return "()"
        elif len(node.elts) == 1:
            # A one-element tuple needs its trailing comma; ``(x)`` is just x.
            return "(%s,)" % self.visit(node.elts[0])
        else:
            return "(" + ", ".join(self.visit(e) for e in node.elts) + ")"

    if sys.version_info >= (3, 6):
        def visit_Constant(self, node: ast.Constant) -> str:
            if node.value is Ellipsis:
                return "..."
            elif isinstance(node.value, (int, float, complex)):
                if self.code and sys.version_info > (3, 8):
                    # Prefer the literal exactly as written (e.g. ``0x10``,
                    # ``1_000``). get_source_segment returns None for nodes
                    # without position info; fall back to repr then.
                    segment = ast.get_source_segment(self.code, node)  # type: ignore
                    if segment is not None:
                        return segment
                return repr(node.value)
            else:
                return repr(node.value)

    if sys.version_info < (3, 8):
        # these ast nodes were deprecated in python 3.8
        def visit_Bytes(self, node: ast.Bytes) -> str:
            return repr(node.s)

        def visit_Ellipsis(self, node: ast.Ellipsis) -> str:
            return "..."

        def visit_NameConstant(self, node: ast.NameConstant) -> str:
            return repr(node.value)

        def visit_Num(self, node: ast.Num) -> str:
            return repr(node.n)

        def visit_Str(self, node: ast.Str) -> str:
            return repr(node.s)

    def generic_visit(self, node):
        raise NotImplementedError('Unable to parse %s object' % type(node).__name__)