#! /usr/bin/env python
"""Generate C code from an ASDL description."""

import os
import sys

from argparse import ArgumentParser
from contextlib import contextmanager
from pathlib import Path

import asdl

TABSIZE = 4
MAX_COL = 80
AUTOGEN_MESSAGE = "/* File automatically generated by {}. */\n\n"

def get_c_type(name):
    """Return a string for the C name of the type.

    This function special cases the default types provided by asdl.
    """
    if name in asdl.builtin_types:
        return name
    else:
        return "%s_ty" % name

def reflow_lines(s, depth):
    """Reflow the line s indented depth tabs.

    Return a sequence of lines where no line extends beyond MAX_COL
    when properly indented.  The first line is properly indented based
    exclusively on depth * TABSIZE.  All following lines -- these are
    the reflowed lines generated by this function -- start at the same
    column as the first character beyond the opening { in the first
    line (or, when the first line has no {, just beyond its first ().
    """
    size = MAX_COL - depth * TABSIZE
    if len(s) < size:
        return [s]

    lines = []
    cur = s
    padding = ""
    while len(cur) > size:
        i = cur.rfind(' ', 0, size)
        # XXX this should be fixed for real
        if i == -1 and 'GeneratorExp' in cur:
            i = size + 3
        assert i != -1, "Impossible line %d to reflow: %r" % (size, s)
        lines.append(padding + cur[:i])
        if len(lines) == 1:
            # find new size based on brace
            j = cur.find('{', 0, i)
            if j >= 0:
                j += 2 # account for the brace and the space after it
                size -= j
                padding = " " * j
            else:
                j = cur.find('(', 0, i)
                if j >= 0:
                    j += 1 # account for the paren (no space after it)
                    size -= j
                    padding = " " * j
        cur = cur[i+1:]
    else:
        lines.append(padding + cur)
    return lines

def reflow_c_string(s, depth):
    """Return s as a C string literal, split at every newline.

    Each embedded newline becomes an escaped \\n that closes the current
    literal; the next literal starts on a new line indented by
    depth * TABSIZE spaces, relying on C's adjacent-literal concatenation.
    """
    return '"%s"' % s.replace('\n', '\\n"\n%s"' % (' ' * depth * TABSIZE))

def is_simple(sum):
    """Return True if a sum is simple.

    A sum is simple if its types have no fields, e.g.
    unaryop = Invert | Not | UAdd | USub
    """
    return not any(t.fields for t in sum.types)

def asdl_of(name, obj):
    """Return ASDL source text describing obj under the given name.

    Products and constructors render as ``name(field, ...)``; the
    parenthesised part is omitted when there are no fields.  Anything else
    is treated as a sum and renders as ``name = A | B``; for non-simple
    sums each constructor goes on its own line, aligned under the ``=``.
    """
    if isinstance(obj, (asdl.Product, asdl.Constructor)):
        fields = ", ".join(map(str, obj.fields))
        if fields:
            fields = "({})".format(fields)
        return "{}{}".format(name, fields)
    else:
        if is_simple(obj):
            types = " | ".join(type.name for type in obj.types)
        else:
            sep = "\n{}| ".format(" " * (len(name) + 1))
            types = sep.join(
                asdl_of(type.name, type) for type in obj.types
            )
        return "{} = {}".format(name, types)

class EmitVisitor(asdl.VisitorBase):
    """Visitor that emits lines of C code to a file.

    Besides writing lines, subclasses record the identifiers, singletons
    and type objects they reference in the ``identifiers``, ``singletons``
    and ``types`` sets.
    """

    def __init__(self, file):
        self.file = file
        self.identifiers = set()
        self.singletons = set()
        self.types = set()
        super(EmitVisitor, self).__init__()

    def emit_identifier(self, name):
        self.identifiers.add(str(name))

    def emit_singleton(self, name):
        self.singletons.add(str(name))

    def emit_type(self, name):
        self.types.add(str(name))

    def emit(self, s, depth, reflow=True):
        """Write s to the output file, indented by depth levels.

        Each level is TABSIZE spaces.  When reflow is true, lines that would
        pass MAX_COL are wrapped with reflow_lines().  Empty lines are
        written without indentation.
        """
        if reflow:
            lines = reflow_lines(s, depth)
        else:
            lines = [s]
        for line in lines:
            if line:
                line = (" " * TABSIZE * depth) + line
            self.file.write(line + "\n")


class TypeDefVisitor(EmitVisitor):
    """Emit C typedefs for every ASDL type.

    Simple sums become an enum.  Every other sum, and every product,
    becomes a pointer typedef to ``struct _<name>``.
    """

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        if is_simple(sum):
            self.simple_sum(sum, name, depth)
        else:
            self.sum_with_constructors(sum, name, depth)

    def simple_sum(self, sum, name, depth):
        # Enum values are numbered from 1, in declaration order.
        enum = ["%s=%d" % (t.name, i) for i, t in enumerate(sum.types, 1)]
        enums = ", ".join(enum)
        ctype = get_c_type(name)
        s = "typedef enum _%s { %s } %s;" % (name, enums, ctype)
        self.emit(s, depth)
        self.emit("", depth)

    def sum_with_constructors(self, sum, name, depth):
        ctype = get_c_type(name)
        s = "typedef struct _%(name)s *%(ctype)s;" % locals()
        self.emit(s, depth)
        self.emit("", depth)

    def visitProduct(self, product, name, depth):
        ctype = get_c_type(name)
        s = "typedef struct _%(name)s *%(ctype)s;" % locals()
        self.emit(s, depth)
        self.emit("", depth)


class StructVisitor(EmitVisitor):
    """Visitor to generate typedefs for AST."""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        if not is_simple(sum):
            self.sum_with_constructors(sum, name, depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit the kind enum and the tagged-union struct for a sum type."""
        # Every template below only interpolates %(name)s, so format with an
        # explicit mapping instead of peeking at the caller's frame locals.
        def emit(s, depth=depth):
            self.emit(s % {"name": name}, depth)
        enum = ["%s_kind=%d" % (t.name, i)
                for i, t in enumerate(sum.types, 1)]

        emit("enum _%(name)s_kind {" + ", ".join(enum) + "};")

        emit("struct _%(name)s {")
        emit("enum _%(name)s_kind kind;", depth + 1)
        emit("union {", depth + 1)
        for t in sum.types:
            self.visit(t, depth + 2)
        emit("} v;", depth + 1)
        for field in sum.attributes:
            # rudimentary attribute handling
            type = str(field.type)
            assert type in asdl.builtin_types, type
            emit("%s %s;" % (type, field.name), depth + 1)
        emit("};")
        emit("")

    def visitConstructor(self, cons, depth):
        if cons.fields:
            self.emit("struct {", depth)
            for f in cons.fields:
                self.visit(f, depth + 1)
            self.emit("} %s;" % cons.name, depth)
            self.emit("", depth)

    def visitField(self, field, depth):
        # XXX need to lookup field.type, because it might be something
        # like a builtin...
        ctype = get_c_type(field.type)
        name = field.name
        if field.seq:
            if field.type == 'cmpop':
                self.emit("asdl_int_seq *%(name)s;" % locals(), depth)
            else:
                self.emit("asdl_seq *%(name)s;" % locals(), depth)
        else:
            self.emit("%(ctype)s %(name)s;" % locals(), depth)

    def visitProduct(self, product, name, depth):
        self.emit("struct _%(name)s {" % locals(), depth)
        for f in product.fields:
            self.visit(f, depth + 1)
        for field in product.attributes:
            # rudimentary attribute handling
            type = str(field.type)
            assert type in asdl.builtin_types, type
            self.emit("%s %s;" % (type, field.name), depth + 1)
        self.emit("};", depth)
        self.emit("", depth)


class PrototypeVisitor(EmitVisitor):
    """Generate function prototypes for the .h file"""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        if is_simple(sum):
            pass # XXX
        else:
            for t in sum.types:
                self.visit(t, name, sum.attributes)

    def get_args(self, fields):
        """Return list of C argument info, one for each field.

        Argument info is 3-tuple of a C type, variable name, and flag
        that is true if type can be NULL.
        """
        args = []
        unnamed = {}
        for f in fields:
            if f.name is None:
                name = f.type
                c = unnamed[name] = unnamed.get(name, 0) + 1
                if c > 1:
                    name = "name%d" % (c - 1)
            else:
                name = f.name
            # XXX should extend get_c_type() to handle this
            if f.seq:
                if f.type == 'cmpop':
                    ctype = "asdl_int_seq *"
                else:
                    ctype = "asdl_seq *"
            else:
                ctype = get_c_type(f.type)
            args.append((ctype, name, f.opt or f.seq))
        return args

    def visitConstructor(self, cons, type, attrs):
        args = self.get_args(cons.fields)
        attrs = self.get_args(attrs)
        ctype = get_c_type(type)
        self.emit_function(cons.name, ctype, args, attrs)

    def emit_function(self, name, ctype, args, attrs, union=True):
        """Emit a convenience macro and the _Py_<name> prototype.

        The macro takes one argument per field/attribute plus one for the
        arena.  ``union`` is unused here; subclasses use it.
        """
        args = args + attrs
        if args:
            argstr = ", ".join(["%s %s" % (atype, aname)
                                for atype, aname, opt in args])
            argstr += ", PyArena *arena"
        else:
            argstr = "PyArena *arena"
        # a0..aN: one macro parameter per argument, plus one for the arena.
        margs = ", ".join("a%d" % i for i in range(len(args) + 1))
        self.emit("#define %s(%s) _Py_%s(%s)" % (name, margs, name, margs), 0,
                  reflow=False)
        self.emit("%s _Py_%s(%s);" % (ctype, name, argstr), 0)

    def visitProduct(self, prod, name):
        self.emit_function(name, get_c_type(name),
                           self.get_args(prod.fields),
                           self.get_args(prod.attributes),
                           union=False)


class FunctionVisitor(PrototypeVisitor):
    """Visitor to generate constructor functions for AST."""

    def emit_function(self, name, ctype, args, attrs, union=True):
        """Emit the definition of the constructor function for name.

        The generated function raises ValueError and returns NULL when a
        required (non-optional, non-int) argument is NULL, allocates the
        node from the arena, and fills it in as a tagged union (sum
        constructors) or a plain struct (products).
        """
        def emit(s, depth=0, reflow=True):
            self.emit(s, depth, reflow)
        argstr = ", ".join(["%s %s" % (atype, aname)
                            for atype, aname, opt in args + attrs])
        if argstr:
            argstr += ", PyArena *arena"
        else:
            argstr = "PyArena *arena"
        self.emit("%s" % ctype, 0)
        emit("%s(%s)" % (name, argstr))
        emit("{")
        emit("%s p;" % ctype, 1)
        for argtype, argname, opt in args:
            if not opt and argtype != "int":
                emit("if (!%s) {" % argname, 1)
                emit("PyErr_SetString(PyExc_ValueError,", 2)
                msg = "field '%s' is required for %s" % (argname, name)
                emit('                "%s");' % msg,
                     2, reflow=False)
                emit('return NULL;', 2)
                emit('}', 1)

        emit("p = (%s)PyArena_Malloc(arena, sizeof(*p));" % ctype, 1)
        emit("if (!p)", 1)
        emit("return NULL;", 2)
        if union:
            self.emit_body_union(name, args, attrs)
        else:
            self.emit_body_struct(name, args, attrs)
        emit("return p;", 1)
        emit("}")
        emit("")

    def emit_body_union(self, name, args, attrs):
        def emit(s, depth=0, reflow=True):
            self.emit(s, depth, reflow)
        emit("p->kind = %s_kind;" % name, 1)
        for argtype, argname, opt in args:
            emit("p->v.%s.%s = %s;" % (name, argname, argname), 1)
        for argtype, argname, opt in attrs:
            emit("p->%s = %s;" % (argname, argname), 1)

    def emit_body_struct(self, name, args, attrs):
        def emit(s, depth=0, reflow=True):
            self.emit(s, depth, reflow)
        for argtype, argname, opt in args:
            emit("p->%s = %s;" % (argname, argname), 1)
        for argtype, argname, opt in attrs:
            emit("p->%s = %s;" % (argname, argname), 1)


class PickleVisitor(EmitVisitor):
    """Base visitor: walks every definition, with no-op visit hooks."""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        pass

    def visitProduct(self, sum, name):
        pass

    def visitConstructor(self, cons, name):
        pass

    def visitField(self, sum):
        pass


class Obj2ModPrototypeVisitor(PickleVisitor):
    """Emit static forward declarations for the obj2ast_<type> functions."""

    def visitProduct(self, prod, name):
        code = "static int obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena);"
        self.emit(code % (name, get_c_type(name)), 0)

    visitSum = visitProduct


class Obj2ModVisitor(PickleVisitor):
    """Emit obj2ast_<type> functions converting Python AST objects to C nodes.

    The generated functions return 0 on success and 1 on failure.
    """

    @contextmanager
    def recursive_call(self, node, level):
        """Wrap emitted code in Py_EnterRecursiveCall/Py_LeaveRecursiveCall."""
        self.emit('if (Py_EnterRecursiveCall(" while traversing \'%s\' node")) {' % node, level, reflow=False)
        self.emit('goto failed;', level + 1)
        self.emit('}', level)
        yield
        self.emit('Py_LeaveRecursiveCall();', level)

    def funcHeader(self, name):
        ctype = get_c_type(name)
        self.emit("int", 0)
        self.emit("obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
        self.emit("{", 0)
        self.emit("int isinstance;", 1)
        self.emit("", 0)

    def sumTrailer(self, name, add_label=False):
        self.emit("", 0)
        # there's really nothing more we can do if this fails ...
        error = "expected some sort of %s, but got %%R" % name
        format = "PyErr_Format(PyExc_TypeError, \"%s\", obj);"
        self.emit(format % error, 1, reflow=False)
        if add_label:
            self.emit("failed:", 1)
            self.emit("Py_XDECREF(tmp);", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def simpleSum(self, sum, name):
        """Emit a converter that maps each singleton type to its enum value."""
        self.funcHeader(name)
        for t in sum.types:
            line = ("isinstance = PyObject_IsInstance(obj, "
                    "state->%s_type);")
            self.emit(line % (t.name,), 1)
            self.emit("if (isinstance == -1) {", 1)
            self.emit("return 1;", 2)
            self.emit("}", 1)
            self.emit("if (isinstance) {", 1)
            self.emit("*out = %s;" % t.name, 2)
            self.emit("return 0;", 2)
            self.emit("}", 1)
        self.sumTrailer(name)

    def buildArgs(self, fields):
        return ", ".join(fields + ["arena"])

    def complexSum(self, sum, name):
        """Emit a converter that dispatches on the constructor's type."""
        self.funcHeader(name)
        self.emit("PyObject *tmp = NULL;", 1)
        self.emit("PyObject *tp;", 1)
        for a in sum.attributes:
            self.visitAttributeDeclaration(a, name, sum=sum)
        self.emit("", 0)
        # XXX: should we only do this for 'expr'?
        self.emit("if (obj == Py_None) {", 1)
        self.emit("*out = NULL;", 2)
        self.emit("return 0;", 2)
        self.emit("}", 1)
        for a in sum.attributes:
            self.visitField(a, name, sum=sum, depth=1)
        for t in sum.types:
            self.emit("tp = state->%s_type;" % (t.name,), 1)
            self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1)
            self.emit("if (isinstance == -1) {", 1)
            self.emit("return 1;", 2)
            self.emit("}", 1)
            self.emit("if (isinstance) {", 1)
            for f in t.fields:
                self.visitFieldDeclaration(f, t.name, sum=sum, depth=2)
            self.emit("", 0)
            for f in t.fields:
                self.visitField(f, t.name, sum=sum, depth=2)
            args = [f.name for f in t.fields] + [a.name for a in sum.attributes]
            self.emit("*out = %s(%s);" % (t.name, self.buildArgs(args)), 2)
            self.emit("if (*out == NULL) goto failed;", 2)
            self.emit("return 0;", 2)
            self.emit("}", 1)
        self.sumTrailer(name, True)

    def visitAttributeDeclaration(self, a, name, sum=None):
        ctype = get_c_type(a.type)
        self.emit("%s %s;" % (ctype, a.name), 1)

    def visitSum(self, sum, name):
        if is_simple(sum):
            self.simpleSum(sum, name)
        else:
            self.complexSum(sum, name)

    def visitProduct(self, prod, name):
        ctype = get_c_type(name)
        self.emit("int", 0)
        self.emit("obj2ast_%s(astmodulestate *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
        self.emit("{", 0)
        self.emit("PyObject* tmp = NULL;", 1)
        for f in prod.fields:
            self.visitFieldDeclaration(f, name, prod=prod, depth=1)
        for a in prod.attributes:
            self.visitFieldDeclaration(a, name, prod=prod, depth=1)
        self.emit("", 0)
        for f in prod.fields:
            self.visitField(f, name, prod=prod, depth=1)
        for a in prod.attributes:
            self.visitField(a, name, prod=prod, depth=1)
        args = [f.name for f in prod.fields]
        args.extend([a.name for a in prod.attributes])
        self.emit("*out = %s(%s);" % (name, self.buildArgs(args)), 1)
        self.emit("return 0;", 1)
        self.emit("failed:", 0)
        self.emit("Py_XDECREF(tmp);", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0):
        """Emit the C local variable declaration that will hold field."""
        if field.seq:
            if self.isSimpleType(field):
                self.emit("asdl_int_seq* %s;" % field.name, depth)
            else:
                self.emit("asdl_seq* %s;" % field.name, depth)
        else:
            ctype = get_c_type(field.type)
            self.emit("%s %s;" % (ctype, field.name), depth)

    def isSimpleSum(self, field):
        # XXX can the members of this list be determined automatically?
        return field.type in ('expr_context', 'boolop', 'operator',
                              'unaryop', 'cmpop')

    def isNumeric(self, field):
        return get_c_type(field.type) in ("int", "bool")

    def isSimpleType(self, field):
        return self.isSimpleSum(field) or self.isNumeric(field)

    def visitField(self, field, name, sum=None, prod=None, depth=0):
        """Emit code converting attribute field.name of obj into a C local.

        A missing required field raises TypeError; a missing or None
        optional field defaults to 0 (numeric) or NULL.  Sequence fields
        must be lists and are converted element by element.
        """
        ctype = get_c_type(field.type)
        line = "if (_PyObject_LookupAttr(obj, state->%s, &tmp) < 0) {"
        self.emit(line % field.name, depth)
        self.emit("return 1;", depth+1)
        self.emit("}", depth)
        if not field.opt:
            self.emit("if (tmp == NULL) {", depth)
            message = "required field \\\"%s\\\" missing from %s" % (field.name, name)
            format = "PyErr_SetString(PyExc_TypeError, \"%s\");"
            self.emit(format % message, depth+1, reflow=False)
            self.emit("return 1;", depth+1)
        else:
            self.emit("if (tmp == NULL || tmp == Py_None) {", depth)
            self.emit("Py_CLEAR(tmp);", depth+1)
            if self.isNumeric(field):
                self.emit("%s = 0;" % field.name, depth+1)
            elif not self.isSimpleType(field):
                self.emit("%s = NULL;" % field.name, depth+1)
            else:
                raise TypeError("could not determine the default value for %s" % field.name)
        self.emit("}", depth)
        self.emit("else {", depth)

        self.emit("int res;", depth+1)
        if field.seq:
            self.emit("Py_ssize_t len;", depth+1)
            self.emit("Py_ssize_t i;", depth+1)
            self.emit("if (!PyList_Check(tmp)) {", depth+1)
            self.emit("PyErr_Format(PyExc_TypeError, \"%s field \\\"%s\\\" must "
                      "be a list, not a %%.200s\", _PyType_Name(Py_TYPE(tmp)));" %
                      (name, field.name),
                      depth+2, reflow=False)
            self.emit("goto failed;", depth+2)
            self.emit("}", depth+1)
            self.emit("len = PyList_GET_SIZE(tmp);", depth+1)
            if self.isSimpleType(field):
                self.emit("%s = _Py_asdl_int_seq_new(len, arena);" % field.name, depth+1)
            else:
                self.emit("%s = _Py_asdl_seq_new(len, arena);" % field.name, depth+1)
            self.emit("if (%s == NULL) goto failed;" % field.name, depth+1)
            self.emit("for (i = 0; i < len; i++) {", depth+1)
            self.emit("%s val;" % ctype, depth+2)
            self.emit("PyObject *tmp2 = PyList_GET_ITEM(tmp, i);", depth+2)
            self.emit("Py_INCREF(tmp2);", depth+2)
            with self.recursive_call(name, depth+2):
                self.emit("res = obj2ast_%s(state, tmp2, &val, arena);" %
                          field.type, depth+2, reflow=False)
            self.emit("Py_DECREF(tmp2);", depth+2)
            self.emit("if (res != 0) goto failed;", depth+2)
            self.emit("if (len != PyList_GET_SIZE(tmp)) {", depth+2)
            self.emit("PyErr_SetString(PyExc_RuntimeError, \"%s field \\\"%s\\\" "
                      "changed size during iteration\");" %
                      (name, field.name),
                      depth+3, reflow=False)
            self.emit("goto failed;", depth+3)
            self.emit("}", depth+2)
            self.emit("asdl_seq_SET(%s, i, val);" % field.name, depth+2)
            self.emit("}", depth+1)
        else:
            with self.recursive_call(name, depth+1):
                self.emit("res = obj2ast_%s(state, tmp, &%s, arena);" %
                          (field.type, field.name), depth+1)
            self.emit("if (res != 0) goto failed;", depth+1)

        self.emit("Py_CLEAR(tmp);", depth+1)
        self.emit("}", depth)


class MarshalPrototypeVisitor(PickleVisitor):
    """Emit static prototypes for the marshal_write_<type> functions."""

    def prototype(self, sum, name):
        ctype = get_c_type(name)
        self.emit("static int marshal_write_%s(PyObject **, int *, %s);"
                  % (name, ctype), 0)

    visitProduct = visitSum = prototype


class PyTypesDeclareVisitor(PickleVisitor):
    """Emit ast2obj_<type> prototypes and the _fields/_attributes name arrays.

    Also records the identifiers, type objects and singletons the
    generated code references.
    """

    def visitProduct(self, prod, name):
        self.emit_type("%s_type" % name)
        self.emit("static PyObject* ast2obj_%s(astmodulestate *state, void*);" % name, 0)
        if prod.attributes:
            for a in prod.attributes:
                self.emit_identifier(a.name)
            self.emit("static const char * const %s_attributes[] = {" % name, 0)
            for a in prod.attributes:
                self.emit('"%s",' % a.name, 1)
            self.emit("};", 0)
        if prod.fields:
            for f in prod.fields:
                self.emit_identifier(f.name)
            self.emit("static const char * const %s_fields[]={" % name,0)
            for f in prod.fields:
                self.emit('"%s",' % f.name, 1)
            self.emit("};", 0)

    def visitSum(self, sum, name):
        self.emit_type("%s_type" % name)
        if sum.attributes:
            for a in sum.attributes:
                self.emit_identifier(a.name)
            self.emit("static const char * const %s_attributes[] = {" % name, 0)
            for a in sum.attributes:
                self.emit('"%s",' % a.name, 1)
            self.emit("};", 0)
        ptype = "void*"
        if is_simple(sum):
            ptype = get_c_type(name)
            for t in sum.types:
                self.emit_singleton("%s_singleton" % t.name)
        self.emit("static PyObject* ast2obj_%s(astmodulestate *state, %s);" % (name, ptype), 0)
        for t in sum.types:
            self.visitConstructor(t, name)

    def visitConstructor(self, cons, name):
        if cons.fields:
            for t in cons.fields:
                self.emit_identifier(t.name)
            self.emit("static const char * const %s_fields[]={" % cons.name, 0)
            for t in cons.fields:
                self.emit('"%s",' % t.name, 1)
            self.emit("};",0)

class PyTypesVisitor(PickleVisitor):
    """Emit the AST base-type support code and init_types().

    init_types() creates the Python type object for every product, sum
    and constructor, plus the singleton instances of simple-sum members.
    """

    def visitModule(self, mod):
        self.emit("""

typedef struct {
    PyObject_HEAD
    PyObject *dict;
} AST_object;

static void
ast_dealloc(AST_object *self)
{
    /* bpo-31095: UnTrack is needed before calling any callbacks */
    PyTypeObject *tp = Py_TYPE(self);
    PyObject_GC_UnTrack(self);
    Py_CLEAR(self->dict);
    freefunc free_func = PyType_GetSlot(tp, Py_tp_free);
    assert(free_func != NULL);
    free_func(self);
    Py_DECREF(tp);
}

static int
ast_traverse(AST_object *self, visitproc visit, void *arg)
{
    Py_VISIT(Py_TYPE(self));
    Py_VISIT(self->dict);
    return 0;
}

static int
ast_clear(AST_object *self)
{
    Py_CLEAR(self->dict);
    return 0;
}

static int
ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
{
    astmodulestate *state = get_global_ast_state();
    if (state == NULL) {
        return -1;
    }

    Py_ssize_t i, numfields = 0;
    int res = -1;
    PyObject *key, *value, *fields;
    if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
        goto cleanup;
    }
    if (fields) {
        numfields = PySequence_Size(fields);
        if (numfields == -1) {
            goto cleanup;
        }
    }

    res = 0; /* if no error occurs, this stays 0 to the end */
    if (numfields < PyTuple_GET_SIZE(args)) {
        PyErr_Format(PyExc_TypeError, "%.400s constructor takes at most "
                     "%zd positional argument%s",
                     _PyType_Name(Py_TYPE(self)),
                     numfields, numfields == 1 ? "" : "s");
        res = -1;
        goto cleanup;
    }
    for (i = 0; i < PyTuple_GET_SIZE(args); i++) {
        /* cannot be reached when fields is NULL */
        PyObject *name = PySequence_GetItem(fields, i);
        if (!name) {
            res = -1;
            goto cleanup;
        }
        res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i));
        Py_DECREF(name);
        if (res < 0) {
            goto cleanup;
        }
    }
    if (kw) {
        i = 0;  /* needed by PyDict_Next */
        while (PyDict_Next(kw, &i, &key, &value)) {
            int contains = PySequence_Contains(fields, key);
            if (contains == -1) {
                res = -1;
                goto cleanup;
            } else if (contains == 1) {
                Py_ssize_t p = PySequence_Index(fields, key);
                if (p == -1) {
                    res = -1;
                    goto cleanup;
                }
                if (p < PyTuple_GET_SIZE(args)) {
                    PyErr_Format(PyExc_TypeError,
                        "%.400s got multiple values for argument '%U'",
                        Py_TYPE(self)->tp_name, key);
                    res = -1;
                    goto cleanup;
                }
            }
            res = PyObject_SetAttr(self, key, value);
            if (res < 0) {
                goto cleanup;
            }
        }
    }
  cleanup:
    Py_XDECREF(fields);
    return res;
}

/* Pickling support */
static PyObject *
ast_type_reduce(PyObject *self, PyObject *unused)
{
    astmodulestate *state = get_global_ast_state();
    if (state == NULL) {
        return NULL;
    }

    PyObject *dict;
    if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) {
        return NULL;
    }
    if (dict) {
        return Py_BuildValue("O()N", Py_TYPE(self), dict);
    }
    return Py_BuildValue("O()", Py_TYPE(self));
}

static PyMemberDef ast_type_members[] = {
    {"__dictoffset__", T_PYSSIZET, offsetof(AST_object, dict), READONLY},
    {NULL}  /* Sentinel */
};

static PyMethodDef ast_type_methods[] = {
    {"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
    {NULL}
};

static PyGetSetDef ast_type_getsets[] = {
    {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict},
    {NULL}
};

static PyType_Slot AST_type_slots[] = {
    {Py_tp_dealloc, ast_dealloc},
    {Py_tp_getattro, PyObject_GenericGetAttr},
    {Py_tp_setattro, PyObject_GenericSetAttr},
    {Py_tp_traverse, ast_traverse},
    {Py_tp_clear, ast_clear},
    {Py_tp_members, ast_type_members},
    {Py_tp_methods, ast_type_methods},
    {Py_tp_getset, ast_type_getsets},
    {Py_tp_init, ast_type_init},
    {Py_tp_alloc, PyType_GenericAlloc},
    {Py_tp_new, PyType_GenericNew},
    {Py_tp_free, PyObject_GC_Del},
    {0, 0},
};

static PyType_Spec AST_type_spec = {
    "ast.AST",
    sizeof(AST_object),
    0,
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,
    AST_type_slots
};

static PyObject *
make_type(astmodulestate *state, const char *type, PyObject* base,
          const char* const* fields, int num_fields, const char *doc)
{
    PyObject *fnames, *result;
    int i;
    fnames = PyTuple_New(num_fields);
    if (!fnames) return NULL;
    for (i = 0; i < num_fields; i++) {
        PyObject *field = PyUnicode_InternFromString(fields[i]);
        if (!field) {
            Py_DECREF(fnames);
            return NULL;
        }
        PyTuple_SET_ITEM(fnames, i, field);
    }
    result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOs}",
                    type, base,
                    state->_fields, fnames,
                    state->__module__,
                    state->ast,
                    state->__doc__, doc);
    Py_DECREF(fnames);
    return result;
}

static int
add_attributes(astmodulestate *state, PyObject *type, const char * const *attrs, int num_fields)
{
    int i, result;
    PyObject *s, *l = PyTuple_New(num_fields);
    if (!l)
        return 0;
    for (i = 0; i < num_fields; i++) {
        s = PyUnicode_InternFromString(attrs[i]);
        if (!s) {
            Py_DECREF(l);
            return 0;
        }
        PyTuple_SET_ITEM(l, i, s);
    }
    result = PyObject_SetAttr(type, state->_attributes, l) >= 0;
    Py_DECREF(l);
    return result;
}

/* Conversion AST -> Python */

static PyObject* ast2obj_list(astmodulestate *state, asdl_seq *seq, PyObject* (*func)(astmodulestate *state, void*))
{
    Py_ssize_t i, n = asdl_seq_LEN(seq);
    PyObject *result = PyList_New(n);
    PyObject *value;
    if (!result)
        return NULL;
    for (i = 0; i < n; i++) {
        value = func(state, asdl_seq_GET(seq, i));
        if (!value) {
            Py_DECREF(result);
            return NULL;
        }
        PyList_SET_ITEM(result, i, value);
    }
    return result;
}

static PyObject* ast2obj_object(astmodulestate *Py_UNUSED(state), void *o)
{
    if (!o)
        o = Py_None;
    Py_INCREF((PyObject*)o);
    return (PyObject*)o;
}
#define ast2obj_constant ast2obj_object
#define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object

static PyObject* ast2obj_int(astmodulestate *Py_UNUSED(state), long b)
{
    return PyLong_FromLong(b);
}

/* Conversion Python -> AST */

static int obj2ast_object(astmodulestate *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
{
    if (obj == Py_None)
        obj = NULL;
    if (obj) {
        if (PyArena_AddPyObject(arena, obj) < 0) {
            *out = NULL;
            return -1;
        }
        Py_INCREF(obj);
    }
    *out = obj;
    return 0;
}

static int obj2ast_constant(astmodulestate *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
{
    if (PyArena_AddPyObject(arena, obj) < 0) {
        *out = NULL;
        return -1;
    }
    Py_INCREF(obj);
    *out = obj;
    return 0;
}

static int obj2ast_identifier(astmodulestate *state, PyObject* obj, PyObject** out, PyArena* arena)
{
    if (!PyUnicode_CheckExact(obj) && obj != Py_None) {
        PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str");
        return 1;
    }
    return obj2ast_object(state, obj, out, arena);
}

static int obj2ast_string(astmodulestate *state, PyObject* obj, PyObject** out, PyArena* arena)
{
    if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) {
        PyErr_SetString(PyExc_TypeError, "AST string must be of type str");
        return 1;
    }
    return obj2ast_object(state, obj, out, arena);
}

static int obj2ast_int(astmodulestate* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena)
{
    int i;
    if (!PyLong_Check(obj)) {
        PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj);
        return 1;
    }

    i = _PyLong_AsInt(obj);
    if (i == -1 && PyErr_Occurred())
        return 1;
    *out = i;
    return 0;
}

static int add_ast_fields(astmodulestate *state)
{
    PyObject *empty_tuple;
    empty_tuple = PyTuple_New(0);
    if (!empty_tuple ||
        PyObject_SetAttrString(state->AST_type, "_fields", empty_tuple) < 0 ||
        PyObject_SetAttrString(state->AST_type, "_attributes", empty_tuple) < 0) {
        Py_XDECREF(empty_tuple);
        return -1;
    }
    Py_DECREF(empty_tuple);
    return 0;
}

""", 0, reflow=False)

        self.emit("static int init_types(astmodulestate *state)",0)
        self.emit("{", 0)
        self.emit("if (state->initialized) return 1;", 1)
        self.emit("if (init_identifiers(state) < 0) return 0;", 1)
        self.emit("state->AST_type = PyType_FromSpec(&AST_type_spec);", 1)
        self.emit("if (!state->AST_type) return 0;", 1)
        self.emit("if (add_ast_fields(state) < 0) return 0;", 1)
        for dfn in mod.dfns:
            self.visit(dfn)
        self.emit("state->initialized = 1;", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)

    def visitProduct(self, prod, name):
        if prod.fields:
            fields = name+"_fields"
        else:
            fields = "NULL"
        self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' %
                        (name, name, fields, len(prod.fields)), 1)
        self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False)
        self.emit("if (!state->%s_type) return 0;" % name, 1)
        self.emit_type("AST_type")
        self.emit_type("%s_type" % name)
        if prod.attributes:
            self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" %
                            (name, name, len(prod.attributes)), 1)
        else:
            self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
        self.emit_defaults(name, prod.fields, 1)
        self.emit_defaults(name, prod.attributes, 1)

    def visitSum(self, sum, name):
        self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' %
                  (name, name), 1)
        self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False)
        self.emit_type("%s_type" % name)
        self.emit("if (!state->%s_type) return 0;" % name, 1)
        if sum.attributes:
            self.emit("if (!add_attributes(state, state->%s_type, %s_attributes, %d)) return 0;" %
                            (name, name, len(sum.attributes)), 1)
        else:
            self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
        self.emit_defaults(name, sum.attributes, 1)
        simple = is_simple(sum)
        for t in sum.types:
            self.visitConstructor(t, name, simple)

    def visitConstructor(self, cons, name, simple):
        if cons.fields:
            fields = cons.name+"_fields"
        else:
            fields = "NULL"
        self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' %
                            (cons.name, cons.name, name, fields, len(cons.fields)), 1)
        self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False)
        self.emit("if (!state->%s_type) return 0;" % cons.name, 1)
        self.emit_type("%s_type" % cons.name)
        self.emit_defaults(cons.name, cons.fields, 1)
        if simple:
            # Members of a simple sum are shared singleton instances.
            self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)"
                             "state->%s_type, NULL, NULL);" %
                             (cons.name, cons.name), 1)
            self.emit("if (!state->%s_singleton) return 0;" % cons.name, 1)

    def emit_defaults(self, name, fields, depth):
        """Emit code setting each optional field's class attribute to None."""
        for field in fields:
            if field.opt:
                self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)' %
                            (name, field.name), depth)
                self.emit("return 0;", depth+1)

class ASTModuleVisitor(PickleVisitor):
    """Emit astmodule_exec(), the _ast module definition and PyInit__ast()."""

    def visitModule(self, mod):
        self.emit("static int", 0)
        self.emit("astmodule_exec(PyObject *m)", 0)
        self.emit("{", 0)
        # get_global_ast_state() runs init_types() and returns NULL when that
        # fails, so the pointer must be checked before anything dereferences it.
        self.emit('astmodulestate *state = get_global_ast_state();', 1)
        self.emit("if (state == NULL) {", 1)
        self.emit("return -1;", 2)
        self.emit("}", 1)
        self.emit("", 0)
        self.addObj("AST")
        for macro in ("PyCF_ALLOW_TOP_LEVEL_AWAIT", "PyCF_ONLY_AST",
                      "PyCF_TYPE_COMMENTS"):
            self.emit("if (PyModule_AddIntMacro(m, %s) < 0) {" % macro, 1)
            self.emit("return -1;", 2)
            self.emit("}", 1)
        for dfn in mod.dfns:
            self.visit(dfn)
        self.emit("return 0;", 1)
        self.emit("}", 0)
        self.emit("", 0)
        self.emit("""
static PyModuleDef_Slot astmodule_slots[] = {
    {Py_mod_exec, astmodule_exec},
    {0, NULL}
};

static struct PyModuleDef _astmodule = {
    PyModuleDef_HEAD_INIT,
    .m_name = "_ast",
    // The _ast module uses a global state (global_ast_state).
    .m_size = 0,
    .m_slots = astmodule_slots,
};

PyMODINIT_FUNC
PyInit__ast(void)
{
    return PyModuleDef_Init(&_astmodule);
}
""".strip(), 0, reflow=False)

    def visitProduct(self, prod, name):
        self.addObj(name)

    def visitSum(self, sum, name):
        self.addObj(name)
        for t in sum.types:
            self.visitConstructor(t, name)

    def visitConstructor(self, cons, name):
        self.addObj(cons.name)

    def addObj(self, name):
        """Emit code adding state-><name>_type to the module as <name>."""
        self.emit("if (PyModule_AddObject(m, \"%s\", "
                  "state->%s_type) < 0) {" % (name, name), 1)
        self.emit("return -1;", 2)
        self.emit('}', 1)
        # PyModule_AddObject() steals a reference only on success; take a new
        # one so the module state keeps its own reference to the type.
        self.emit("Py_INCREF(state->%s_type);" % name, 1)


_SPECIALIZED_SEQUENCES = ('stmt', 'expr')

def find_sequence(fields, doing_specialization):
    """Return True if any field uses a sequence.

    When doing_specialization is true, sequences whose element type is listed
    in _SPECIALIZED_SEQUENCES are not counted.
    """
    return any(
        f.seq and (not doing_specialization
                   or str(f.type) not in _SPECIALIZED_SEQUENCES)
        for f in fields
    )

def has_sequence(types, doing_specialization):
    """Return True if any constructor in types has a (counted) sequence field."""
    return any(find_sequence(t.fields, doing_specialization) for t in types)


class StaticVisitor(PickleVisitor):
    CODE = '''Very simple, always emit this static code.  Override CODE'''

    def visit(self, object):
        self.emit(self.CODE, 0, reflow=False)


class ObjVisitor(PickleVisitor):
    """Emit the ast2obj_*() functions.

    Each generated function builds a Python object (via PyType_GenericNew on
    the matching state-><name>_type) from a C AST value.
    """

    def func_begin(self, name):
        ctype = get_c_type(name)
        self.emit("PyObject*", 0)
        self.emit("ast2obj_%s(astmodulestate *state, void* _o)" % (name), 0)
        self.emit("{", 0)
        self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
        self.emit("PyObject *result = NULL, *value = NULL;", 1)
        self.emit("PyTypeObject *tp;", 1)
        self.emit('if (!o) {', 1)
        self.emit("Py_RETURN_NONE;", 2)
        self.emit("}", 1)

    def func_end(self):
        self.emit("return result;", 1)
        self.emit("failed:", 0)
        self.emit("Py_XDECREF(value);", 1)
        self.emit("Py_XDECREF(result);", 1)
        self.emit("return NULL;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def _emit_attributes(self, attributes):
        """Emit code converting each ASDL attribute of o and setting it on result."""
        for a in attributes:
            self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
            self.emit("if (!value) goto failed;", 1)
            self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
            self.emit("goto failed;", 2)
            self.emit("Py_DECREF(value);", 1)

    def visitSum(self, sum, name):
        if is_simple(sum):
            self.simpleSum(sum, name)
            return
        self.func_begin(name)
        self.emit("switch (o->kind) {", 1)
        for i, t in enumerate(sum.types, 1):
            self.visitConstructor(t, i, name)
        self.emit("}", 1)
        self._emit_attributes(sum.attributes)
        self.func_end()

    def simpleSum(self, sum, name):
        """Emit a converter that returns a new reference to the constructor's singleton."""
        self.emit("PyObject* ast2obj_%s(astmodulestate *state, %s_ty o)" % (name, name), 0)
        self.emit("{", 0)
        self.emit("switch(o) {", 1)
        for t in sum.types:
            self.emit("case %s:" % t.name, 2)
            self.emit("Py_INCREF(state->%s_singleton);" % t.name, 3)
            self.emit("return state->%s_singleton;" % t.name, 3)
        self.emit("}", 1)
        self.emit("Py_UNREACHABLE();", 1)
        self.emit("}", 0)

    def visitProduct(self, prod, name):
        self.func_begin(name)
        self.emit("tp = (PyTypeObject *)state->%s_type;" % name, 1)
        self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 1)
        self.emit("if (!result) return NULL;", 1)
        for field in prod.fields:
            self.visitField(field, name, 1, True)
        self._emit_attributes(prod.attributes)
        self.func_end()

    def visitConstructor(self, cons, enum, name):
        self.emit("case %s_kind:" % cons.name, 1)
        self.emit("tp = (PyTypeObject *)state->%s_type;" % cons.name, 2)
        self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 2)
        self.emit("if (!result) goto failed;", 2)
        for f in cons.fields:
            self.visitField(f, cons.name, 2, False)
        self.emit("break;", 2)

    def visitField(self, field, name, depth, product):
        def emit(s, d):
            self.emit(s, depth + d)
        if product:
            value = "o->%s" % field.name
        else:
            value = "o->v.%s.%s" % (name, field.name)
        self.set(field, value, depth)
        emit("if (!value) goto failed;", 0)
        emit("if (PyObject_SetAttr(result, state->%s, value) == -1)" % field.name, 0)
        emit("goto failed;", 1)
        emit("Py_DECREF(value);", 0)

    def emitSeq(self, field, value, depth, emit):
        """Emit an explicit loop converting the sequence value into a list.

        NOTE(review): no caller is visible in this section. The previous body
        called self.set() with four arguments (it takes three), so it could
        never have run; this version emits the value1 loop it was written for.
        """
        emit("seq = %s;" % value, 0)
        emit("n = asdl_seq_LEN(seq);", 0)
        emit("value = PyList_New(n);", 0)
        emit("if (!value) goto failed;", 0)
        emit("for (i = 0; i < n; i++) {", 0)
        emit("value1 = ast2obj_%s(state, asdl_seq_GET(seq, i));" % field.type, 1)
        emit("if (!value1) goto failed;", 1)
        emit("PyList_SET_ITEM(value, i, value1);", 1)
        emit("value1 = NULL;", 1)
        emit("}", 0)

    def set(self, field, value, depth):
        """Emit code storing the Python conversion of the C expression value into `value`."""
        if field.seq:
            # XXX should really check for is_simple, but that requires a symbol table
            if field.type == "cmpop":
                # While the sequence elements are stored as void*,
                # ast2obj_cmpop expects an enum
                self.emit("{", depth)
                self.emit("Py_ssize_t i, n = asdl_seq_LEN(%s);" % value, depth+1)
                self.emit("value = PyList_New(n);", depth+1)
                self.emit("if (!value) goto failed;", depth+1)
                self.emit("for(i = 0; i < n; i++)", depth+1)
                # This cannot fail, so no need for error handling
                self.emit("PyList_SET_ITEM(value, i, ast2obj_cmpop(state, (cmpop_ty)asdl_seq_GET(%s, i)));" % value,
                          depth+2, reflow=False)
                self.emit("}", depth)
            else:
                self.emit("value = ast2obj_list(state, %s, ast2obj_%s);" % (value, field.type), depth)
        else:
            self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)


class PartingShots(StaticVisitor):
    """Emit the public PyAST_mod2obj(), PyAST_obj2mod() and PyAST_Check()."""

    CODE = """
PyObject* PyAST_mod2obj(mod_ty t)
{
    astmodulestate *state = get_global_ast_state();
    if (state == NULL) {
        return NULL;
    }
    return ast2obj_mod(state, t);
}

/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
{
    const char * const req_name[] = {"Module", "Expression", "Interactive"};
    int isinstance;

    if (PySys_Audit("compile", "OO", ast, Py_None) < 0) {
        return NULL;
    }

    astmodulestate *state = get_global_ast_state();
    if (state == NULL) {
        return NULL;
    }
    PyObject *req_type[3];
    req_type[0] = state->Module_type;
    req_type[1] = state->Expression_type;
    req_type[2] = state->Interactive_type;

    assert(0 <= mode && mode <= 2);

    isinstance = PyObject_IsInstance(ast, req_type[mode]);
    if (isinstance == -1)
        return NULL;
    if (!isinstance) {
        PyErr_Format(PyExc_TypeError, "expected %s node, got %.400s",
                     req_name[mode], _PyType_Name(Py_TYPE(ast)));
        return NULL;
    }

    mod_ty res = NULL;
    if (obj2ast_mod(state, ast, &res, arena) != 0)
        return NULL;
    else
        return res;
}

int PyAST_Check(PyObject* obj)
{
    astmodulestate *state = get_global_ast_state();
    if (state == NULL) {
        return -1;
    }
    return PyObject_IsInstance(obj, state->AST_type);
}
"""

class ChainOfVisitors:
    """Run several visitors in order, emitting a blank line after each one."""

    def __init__(self, *visitors):
        self.visitors = visitors

    def visit(self, object):
        for v in self.visitors:
            v.visit(object)
            v.emit("", 0)


def generate_module_def(f, mod):
    """Write the astmodulestate struct and its helper functions to f.

    The identifiers, singletons and type objects needed for the struct are
    collected by running PyTypesDeclareVisitor and PyTypesVisitor against
    os.devnull, so their emitted code is discarded.
    """
    # Gather all the data needed for ModuleSpec
    visitors = []
    with open(os.devnull, "w") as devnull:
        for visitor_class in (PyTypesDeclareVisitor, PyTypesVisitor):
            visitor = visitor_class(devnull)
            visitor.visit(mod)
            visitors.append(visitor)

    # Strings interned by init_identifiers(); each also gets a state slot.
    state_strings = {
        "ast",
        "_fields",
        "__doc__",
        "__dict__",
        "__module__",
        "_attributes",
    }
    module_state = set(state_strings)
    for visitor in visitors:
        state_strings.update(visitor.identifiers)
        module_state.update(visitor.identifiers, visitor.singletons,
                            visitor.types)
    # Sorted so the generated file is deterministic.
    state_strings = sorted(state_strings)
    module_state = sorted(module_state)

    f.write('typedef struct {\n')
    f.write('    int initialized;\n')
    for s in module_state:
        f.write('    PyObject *' + s + ';\n')
    f.write('} astmodulestate;\n\n')
    f.write("""
// Forward declaration
static int init_types(astmodulestate *state);

// bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.
static astmodulestate global_ast_state = {0};

static astmodulestate*
get_global_ast_state(void)
{
    astmodulestate* state = &global_ast_state;
    if (!init_types(state)) {
        return NULL;
    }
    return state;
}

static astmodulestate*
get_ast_state(PyObject* Py_UNUSED(module))
{
    astmodulestate* state = get_global_ast_state();
    // get_ast_state() must only be called after _ast module is imported,
    // and astmodule_exec() calls init_types() (via get_global_ast_state())
    assert(state != NULL);
    return state;
}

void _PyAST_Fini(void)
{
    astmodulestate* state = &global_ast_state;
""")
    for s in module_state:
        f.write("    Py_CLEAR(state->" + s + ');\n')
    f.write("""
    state->initialized = 0;
}

""")
    f.write('static int init_identifiers(astmodulestate *state)\n')
    f.write('{\n')
    for identifier in state_strings:
        f.write('    if ((state->' + identifier)
        f.write(' = PyUnicode_InternFromString("')
        f.write(identifier + '")) == NULL) return 0;\n')
    f.write('    return 1;\n')
    # No trailing ';': an empty declaration after a function body is not ISO C.
    f.write('}\n\n')

def write_header(f, mod):
    """Write the C header (type definitions and prototypes) for mod to f."""
    f.write('#ifndef Py_PYTHON_AST_H\n')
    f.write('#define Py_PYTHON_AST_H\n')
    f.write('#ifdef __cplusplus\n')
    f.write('extern "C" {\n')
    f.write('#endif\n')
    f.write('\n')
    f.write('#ifndef Py_LIMITED_API\n')
    f.write('#include "asdl.h"\n')
    f.write('\n')
    f.write('#undef Yield   /* undefine macro conflicting with <winbase.h> */\n')
    f.write('\n')
    c = ChainOfVisitors(TypeDefVisitor(f),
                        StructVisitor(f))
    c.visit(mod)
    f.write("// Note: these macros affect function definitions, not only call sites.\n")
    PrototypeVisitor(f).visit(mod)
    f.write("\n")
    f.write("PyObject* PyAST_mod2obj(mod_ty t);\n")
    f.write("mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode);\n")
    f.write("int PyAST_Check(PyObject* obj);\n")
    f.write("#endif /* !Py_LIMITED_API */\n")
    f.write('\n')
    f.write('#ifdef __cplusplus\n')
    f.write('}\n')
    f.write('#endif\n')
    f.write('#endif /* !Py_PYTHON_AST_H */\n')

def write_source(f, mod):
    """Write the C implementation file for mod to f."""
    f.write('#include <stddef.h>\n')
    f.write('\n')
    f.write('#include "Python.h"\n')
    f.write('#include "%s-ast.h"\n' % mod.name)
    f.write('#include "structmember.h"         // PyMemberDef\n')
    f.write('\n')

    generate_module_def(f, mod)

    v = ChainOfVisitors(
        PyTypesDeclareVisitor(f),
        PyTypesVisitor(f),
        Obj2ModPrototypeVisitor(f),
        FunctionVisitor(f),
        ObjVisitor(f),
        Obj2ModVisitor(f),
        ASTModuleVisitor(f),
        PartingShots(f),
    )
    v.visit(mod)

def main(input_file, c_file, h_file, dump_module=False):
    """Parse input_file and regenerate the requested C source and/or header.

    Exits with status 1 if asdl.check() rejects the parsed module. c_file and
    h_file are Path objects, or None to skip that output.
    """
    auto_gen_msg = AUTOGEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
    mod = asdl.parse(input_file)
    if dump_module:
        print('Parsed Module:')
        print(mod)
    if not asdl.check(mod):
        sys.exit(1)
    for file, writer in (c_file, write_source), (h_file, write_header):
        if file is not None:
            with file.open("w") as f:
                f.write(auto_gen_msg)
                writer(f, mod)
            print(file, "regenerated.")

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("input_file", type=Path)
    parser.add_argument("-C", "--c-file", type=Path, default=None)
    parser.add_argument("-H", "--h-file", type=Path, default=None)
    parser.add_argument("-d", "--dump-module", action="store_true")

    options = parser.parse_args()
    main(**vars(options))