1"""
2Serializes a Cython code tree to Cython code. This is primarily useful for
3debugging and testing purposes.
4
The output is in a strict format; no whitespace or comments from the input
are preserved (they could not be, as they are not present in the code tree).
7"""
8
9from __future__ import absolute_import, print_function
10
11from .Compiler.Visitor import TreeVisitor
12from .Compiler.ExprNodes import *
13
14
class LinesResult(object):
    """Line-oriented text buffer used as the default writer output.

    ``lines`` is the list of finished lines (stored without terminators) and
    ``s`` is the text of the line that is still being assembled.
    """

    def __init__(self):
        self.s = u""
        self.lines = []

    def put(self, s):
        """Append *s* to the line currently being built."""
        self.s = self.s + s

    def newline(self):
        """Move the pending line into ``lines`` and start a fresh empty one."""
        finished, self.s = self.s, u""
        self.lines.append(finished)

    def putline(self, s):
        """Append *s* to the pending line and then finish that line."""
        self.put(s)
        self.newline()
30
class DeclarationWriter(TreeVisitor):
    """Tree visitor that writes declaration-level nodes back out as Cython code.

    Text is accumulated in ``self.result`` (a ``LinesResult`` by default).
    Statement visitors emit whole lines prefixed with the current indentation;
    expression and declarator visitors only append to the pending line.
    """

    indent_string = u"    "

    def __init__(self, result=None):
        """Create a writer; *result* defaults to a fresh ``LinesResult``."""
        super(DeclarationWriter, self).__init__()
        if result is None:
            result = LinesResult()
        self.result = result
        self.numindents = 0
        self.tempnames = {}
        self.tempblockindex = 0

    def write(self, tree):
        """Visit *tree* and return the result object holding the output."""
        self.visit(tree)
        return self.result

    def indent(self):
        """Increase the indentation level by one."""
        self.numindents += 1

    def dedent(self):
        """Decrease the indentation level by one."""
        self.numindents -= 1

    def startline(self, s=u""):
        """Begin a line: write the current indentation followed by *s*."""
        self.result.put(self.indent_string * self.numindents + s)

    def put(self, s):
        """Append *s* to the pending line without any indentation."""
        self.result.put(s)

    def putline(self, s):
        """Write *s* as a complete, indented line."""
        self.result.putline(self.indent_string * self.numindents + s)

    def endline(self, s=u""):
        """Append *s* to the pending line and finish it."""
        self.result.putline(s)

    def line(self, s):
        """Write *s* as a complete, indented line via startline/endline."""
        self.startline(s)
        self.endline()

    def comma_separated_list(self, items, output_rhs=False):
        """Visit each of *items*, separating them with ``", "``.

        When *output_rhs* is true, every item that carries a non-None
        ``default`` is followed by ``" = "`` and its default value. This
        applies to the last item as well; items without a ``default``
        attribute are written without a right-hand side.
        """
        for index, item in enumerate(items):
            if index:
                self.put(u", ")
            self.visit(item)
            if output_rhs:
                # NOTE(review): whether every declarator node type defines
                # ``default`` is not visible here, hence the getattr fallback.
                default = getattr(item, 'default', None)
                if default is not None:
                    self.put(u" = ")
                    self.visit(default)

    def visit_Node(self, node):
        """Fallback for node types that have no dedicated visitor."""
        raise AssertionError("Node not handled by serializer: %r" % node)

    def visit_ModuleNode(self, node):
        self.visitchildren(node)

    def visit_StatListNode(self, node):
        self.visitchildren(node)

    def visit_CDefExternNode(self, node):
        """Write a ``cdef extern from ...:`` block and its indented body."""
        if node.include_file is None:
            include_target = u'*'
        else:
            include_target = u'"%s"' % node.include_file
        self.putline(u"cdef extern from %s:" % include_target)
        self.indent()
        self.visit(node.body)
        self.dedent()

    def visit_CPtrDeclaratorNode(self, node):
        self.put('*')
        self.visit(node.base)

    def visit_CReferenceDeclaratorNode(self, node):
        self.put('&')
        self.visit(node.base)

    def visit_CArrayDeclaratorNode(self, node):
        """Write ``base[dimension]``; the brackets stay empty without a dimension."""
        self.visit(node.base)
        self.put(u'[')
        if node.dimension is not None:
            self.visit(node.dimension)
        self.put(u']')

    def visit_CFuncDeclaratorNode(self, node):
        """Write ``base(args)`` and finish the current line."""
        # TODO: except, gil, etc.
        self.visit(node.base)
        self.put(u'(')
        self.comma_separated_list(node.args)
        self.endline(u')')

    def visit_CNameDeclaratorNode(self, node):
        self.put(node.name)

    def visit_CSimpleBaseTypeNode(self, node):
        """Write a simple base type, including sign and short/long prefixes."""
        # See Parsing.p_sign_and_longness
        if node.is_basic_c_type:
            # node.signed indexes the tuple: 0 -> "unsigned ", 1 -> no
            # prefix, 2 -> "signed ".
            self.put(("unsigned ", "", "signed ")[node.signed])
            # The magnitude of longness is the number of repeated
            # "short "/"long " words.
            if node.longness < 0:
                self.put("short " * -node.longness)
            elif node.longness > 0:
                self.put("long " * node.longness)
        self.put(node.name)

    def visit_CComplexBaseTypeNode(self, node):
        self.put(u'(')
        self.visit(node.base_type)
        self.visit(node.declarator)
        self.put(u')')

    def visit_CNestedBaseTypeNode(self, node):
        self.visit(node.base_type)
        self.put(u'.')
        self.put(node.name)

    def visit_TemplatedTypeNode(self, node):
        """Write ``base[args]`` with positional and keyword arguments."""
        self.visit(node.base_type_node)
        self.put(u'[')
        self.comma_separated_list(node.positional_args + node.keyword_args.key_value_pairs)
        self.put(u']')

    def visit_CVarDefNode(self, node):
        """Write a ``cdef <type> <declarators>`` line, including defaults."""
        self.startline(u"cdef ")
        self.visit(node.base_type)
        self.put(u" ")
        self.comma_separated_list(node.declarators, output_rhs=True)
        self.endline()

    def visit_container_node(self, node, decl, extras, attributes):
        """Shared writer for struct/union, C++ class and enum definitions.

        Writes ``decl [name ["cname"]][extras]:`` followed by the indented
        *attributes*, or ``pass`` when *attributes* is empty.
        """
        # TODO: visibility
        self.startline(decl)
        if node.name:
            self.put(u' ')
            self.put(node.name)
            if node.cname is not None:
                self.put(u' "%s"' % node.cname)
        if extras:
            self.put(extras)
        self.endline(':')
        self.indent()
        if not attributes:
            self.putline('pass')
        else:
            for attribute in attributes:
                self.visit(attribute)
        self.dedent()

    def visit_CStructOrUnionDefNode(self, node):
        """Write a struct/union definition with its declaration modifiers."""
        if node.typedef_flag:
            decl = u'ctypedef '
        else:
            decl = u'cdef '
        if node.visibility == 'public':
            decl += u'public '
        if node.packed:
            decl += u'packed '
        decl += node.kind
        self.visit_container_node(node, decl, None, node.attributes)

    def visit_CppClassNode(self, node):
        """Write a ``cdef cppclass`` with optional templates and base classes."""
        extras = ""
        if node.templates:
            extras = u"[%s]" % ", ".join(node.templates)
        if node.base_classes:
            extras += "(%s)" % ", ".join(node.base_classes)
        self.visit_container_node(node, u"cdef cppclass", extras, node.attributes)

    def visit_CEnumDefNode(self, node):
        self.visit_container_node(node, u"cdef enum", None, node.items)

    def visit_CEnumDefItemNode(self, node):
        """Write one enum item line: ``name ["cname"] [= value]``."""
        self.startline(node.name)
        if node.cname:
            self.put(u' "%s"' % node.cname)
        if node.value:
            self.put(u" = ")
            self.visit(node.value)
        self.endline()

    def visit_CClassDefNode(self, node):
        """Write a ``cdef class`` header, its decorators and indented body."""
        assert not node.module_name
        if node.decorators:
            for decorator in node.decorators:
                self.visit(decorator)
        self.startline(u"cdef class ")
        self.put(node.class_name)
        if node.base_class_name:
            self.put(u"(")
            if node.base_class_module:
                self.put(node.base_class_module)
                self.put(u".")
            self.put(node.base_class_name)
            self.put(u")")
        self.endline(u":")
        self.indent()
        self.visit(node.body)
        self.dedent()

    def visit_CTypeDefNode(self, node):
        self.startline(u"ctypedef ")
        self.visit(node.base_type)
        self.put(u" ")
        self.visit(node.declarator)
        self.endline()

    def visit_FuncDefNode(self, node):
        """Write a ``def name(args):`` header followed by the indented body."""
        self.startline(u"def %s(" % node.name)
        self.comma_separated_list(node.args)
        self.endline(u"):")
        self.indent()
        self.visit(node.body)
        self.dedent()

    def visit_CArgDeclNode(self, node):
        """Write one argument: optional type, declarator, optional default."""
        if node.base_type.name is not None:
            self.visit(node.base_type)
            self.put(u" ")
        self.visit(node.declarator)
        if node.default is not None:
            self.put(u" = ")
            self.visit(node.default)

    def visit_CImportStatNode(self, node):
        self.startline(u"cimport ")
        self.put(node.module_name)
        if node.as_name:
            self.put(u" as ")
            self.put(node.as_name)
        self.endline()

    def visit_FromCImportStatNode(self, node):
        """Write ``from <module> cimport a [as b], ...`` on one line."""
        self.startline(u"from ")
        self.put(node.module_name)
        self.put(u" cimport ")
        for index, (pos, name, as_name, kind) in enumerate(node.imported_names):
            assert kind is None
            if index:
                self.put(u", ")
            self.put(name)
            if as_name:
                self.put(u" as ")
                self.put(as_name)
        self.endline()

    def visit_NameNode(self, node):
        self.put(node.name)

    def visit_IntNode(self, node):
        self.put(node.value)

    def visit_NoneNode(self, node):
        self.put(u"None")

    def visit_NotNode(self, node):
        self.put(u"(not ")
        self.visit(node.operand)
        self.put(u")")

    def visit_DecoratorNode(self, node):
        self.startline("@")
        self.visit(node.decorator)
        self.endline()

    def visit_BinopNode(self, node):
        self.visit(node.operand1)
        self.put(u" %s " % node.operator)
        self.visit(node.operand2)

    def visit_AttributeNode(self, node):
        self.visit(node.obj)
        self.put(u".%s" % node.attribute)

    def visit_BoolNode(self, node):
        self.put(str(node.value))

    # FIXME: represent string nodes correctly
    def visit_StringNode(self, node):
        """Write the ``repr`` of the string value, encoded if it names an encoding."""
        value = node.value
        if value.encoding is not None:
            value = value.encode(value.encoding)
        self.put(repr(value))

    def visit_PassStatNode(self, node):
        self.startline(u"pass")
        self.endline()
325
class CodeWriter(DeclarationWriter):
    """Extends ``DeclarationWriter`` with visitors for statements and calls."""

    def visit_SingleAssignmentNode(self, node):
        self.startline()
        self.visit(node.lhs)
        self.put(u" = ")
        self.visit(node.rhs)
        self.endline()

    def visit_CascadedAssignmentNode(self, node):
        """Write ``a = b = ... = rhs`` on one line."""
        self.startline()
        for lhs in node.lhs_list:
            self.visit(lhs)
            self.put(u" = ")
        self.visit(node.rhs)
        self.endline()

    def visit_PrintStatNode(self, node):
        """Write a print statement; a trailing comma marks a suppressed newline."""
        self.startline(u"print ")
        self.comma_separated_list(node.arg_tuple.args)
        if not node.append_newline:
            self.put(u",")
        self.endline()

    def visit_ForInStatNode(self, node):
        """Write a ``for ... in ...:`` loop with its optional ``else:`` block."""
        self.startline(u"for ")
        self.visit(node.target)
        self.put(u" in ")
        self.visit(node.iterator.sequence)
        self.endline(u":")
        self.indent()
        self.visit(node.body)
        self.dedent()
        if node.else_clause is not None:
            self.line(u"else:")
            self.indent()
            self.visit(node.else_clause)
            self.dedent()

    def visit_IfStatNode(self, node):
        """Write an ``if``/``elif``/``else`` chain."""
        # The IfClauseNode is handled directly without a separate match
        # for clarity.
        self.startline(u"if ")
        self.visit(node.if_clauses[0].condition)
        self.endline(":")
        self.indent()
        self.visit(node.if_clauses[0].body)
        self.dedent()
        for clause in node.if_clauses[1:]:
            self.startline("elif ")
            self.visit(clause.condition)
            self.endline(":")
            self.indent()
            self.visit(clause.body)
            self.dedent()
        if node.else_clause is not None:
            self.line("else:")
            self.indent()
            self.visit(node.else_clause)
            self.dedent()

    def visit_SequenceNode(self, node):
        self.comma_separated_list(node.args) # Might need to discover whether we need () around tuples...hmm...

    def visit_SimpleCallNode(self, node):
        self.visit(node.function)
        self.put(u"(")
        self.comma_separated_list(node.args)
        self.put(")")

    def visit_GeneralCallNode(self, node):
        """Write a call that has positional and/or keyword arguments.

        Raises NotImplementedError (an ``Exception`` subclass) when the
        keyword arguments are not a ``DictNode``.
        """
        self.visit(node.function)
        self.put(u"(")
        posarg = node.positional_args
        if isinstance(posarg, AsTupleNode):
            self.visit(posarg.arg)
            wrote_positional = True
        else:
            self.comma_separated_list(posarg.args)  # TupleNode.args
            wrote_positional = bool(posarg.args)
        if node.keyword_args:
            if isinstance(node.keyword_args, DictNode):
                pairs = node.keyword_args.key_value_pairs
                # Separate the keyword arguments from preceding positional
                # ones; without this the two groups run together.
                if wrote_positional and pairs:
                    self.put(', ')
                for i, (name, value) in enumerate(pairs):
                    if i > 0:
                        self.put(', ')
                    self.visit(name)
                    self.put('=')
                    self.visit(value)
            else:
                raise NotImplementedError("Not implemented yet")
        self.put(u")")

    def visit_ExprStatNode(self, node):
        self.startline()
        self.visit(node.expr)
        self.endline()

    def visit_InPlaceAssignmentNode(self, node):
        self.startline()
        self.visit(node.lhs)
        self.put(u" %s= " % node.operator)
        self.visit(node.rhs)
        self.endline()

    def visit_WithStatNode(self, node):
        """Write ``with manager [as target]:`` followed by the indented body."""
        self.startline()
        self.put(u"with ")
        self.visit(node.manager)
        if node.target is not None:
            self.put(u" as ")
            self.visit(node.target)
        self.endline(u":")
        self.indent()
        self.visit(node.body)
        self.dedent()

    def visit_TryFinallyStatNode(self, node):
        self.line(u"try:")
        self.indent()
        self.visit(node.body)
        self.dedent()
        self.line(u"finally:")
        self.indent()
        self.visit(node.finally_clause)
        self.dedent()

    def visit_TryExceptStatNode(self, node):
        """Write ``try:`` with its except clauses and optional ``else:`` block."""
        self.line(u"try:")
        self.indent()
        self.visit(node.body)
        self.dedent()
        for x in node.except_clauses:
            self.visit(x)
        if node.else_clause is not None:
            # Emit the header and indent the body, as the for/if visitors do;
            # otherwise the else body lands at the ``try`` level unlabeled.
            self.line(u"else:")
            self.indent()
            self.visit(node.else_clause)
            self.dedent()

    def visit_ExceptClauseNode(self, node):
        """Write ``except [pattern][, target]:`` followed by the indented body."""
        self.startline(u"except")
        if node.pattern is not None:
            self.put(u" ")
            self.visit(node.pattern)
        if node.target is not None:
            self.put(u", ")
            self.visit(node.target)
        self.endline(":")
        self.indent()
        self.visit(node.body)
        self.dedent()

    def visit_ReturnStatNode(self, node):
        """Write ``return [value]``; a None value yields a bare ``return``."""
        self.startline(u"return")
        if node.value is not None:
            self.put(u" ")
            self.visit(node.value)
        self.endline()

    def visit_ReraiseStatNode(self, node):
        self.line("raise")

    def visit_ImportNode(self, node):
        self.put(u"(import %s)" % node.module_name.value)

    def visit_TempsBlockNode(self, node):
        """
        Temporaries are output like '$1_1', where the first number is
        an index of the TempsBlockNode and the second number is an index
        of the temporary which that block allocates.
        """
        for idx, handle in enumerate(node.temps):
            self.tempnames[handle] = "$%d_%d" % (self.tempblockindex, idx)
        self.tempblockindex += 1
        self.visit(node.body)

    def visit_TempRefNode(self, node):
        """Write the name previously assigned to this temporary's handle."""
        self.put(self.tempnames[node.handle])
499
500
class PxdWriter(DeclarationWriter):
    """Declaration writer that prints its output and skips other statements."""

    def __call__(self, node):
        """Serialize *node*, print the resulting lines, and return *node*."""
        written = self.write(node)
        print(u'\n'.join(written.lines))
        return node

    def visit_CFuncDefNode(self, node):
        """Write the declaration line of a C function; inline ones are skipped."""
        if 'inline' in node.modifiers:
            return
        keyword = u'cpdef ' if node.overridable else u'cdef '
        self.startline(keyword)
        visibility = node.visibility
        if visibility != 'private':
            self.put(visibility)
            self.put(u' ')
        if node.api:
            self.put(u'api ')
        self.visit(node.declarator)

    def visit_StatNode(self, node):
        """Ignore statements that have no more specific visitor."""
        pass
522
523
class ExpressionWriter(TreeVisitor):
    """Tree visitor that renders an expression tree as a single string.

    The text is accumulated in ``self.result``. ``self.precedence`` is a
    stack of operator precedences used to decide where parentheses go.
    """

    def __init__(self, result=None):
        """Create a writer; *result* is the initial text (empty by default)."""
        super(ExpressionWriter, self).__init__()
        self.result = u"" if result is None else result
        self.precedence = [0]

    def write(self, tree):
        """Visit *tree* and return the accumulated text."""
        self.visit(tree)
        return self.result

    def put(self, s):
        """Append *s* to the output."""
        self.result += s

    def remove(self, s):
        """Strip *s* from the end of the output if the output ends with it."""
        if not self.result.endswith(s):
            return
        self.result = self.result[:-len(s)]

    def comma_separated_list(self, items):
        """Visit each of *items*, separating them with ``", "``."""
        for position, entry in enumerate(items):
            if position:
                self.put(u", ")
            self.visit(entry)

    def visit_Node(self, node):
        """Fallback for node types that have no dedicated visitor."""
        raise AssertionError("Node not handled by serializer: %r" % node)

    def visit_NameNode(self, node):
        self.put(node.name)

    def visit_NoneNode(self, node):
        self.put(u"None")

    def visit_EllipsisNode(self, node):
        self.put(u"...")

    def visit_BoolNode(self, node):
        self.put(str(node.value))

    def visit_ConstNode(self, node):
        self.put(str(node.value))

    def visit_ImagNode(self, node):
        self.put(node.value)
        self.put(u"j")

    def emit_string(self, node, prefix=u""):
        """Write the repr of ``node.value`` with its own prefix replaced by *prefix*."""
        literal = repr(node.value)
        # Drop a leading 'u' or 'b' produced by repr(); the caller's prefix
        # is written instead.
        if literal[0] in 'ub':
            literal = literal[1:]
        self.put(u"%s%s" % (prefix, literal))

    def visit_BytesNode(self, node):
        self.emit_string(node, u"b")

    def visit_StringNode(self, node):
        self.emit_string(node)

    def visit_UnicodeNode(self, node):
        self.emit_string(node, u"u")

    def emit_sequence(self, node, parens=(u"", u"")):
        """Write the sub-expressions of *node*, comma separated, inside *parens*.

        *parens* is any two-element sequence: the opening and closing text.
        """
        opener, closer = parens
        elements = node.subexpr_nodes()
        self.put(opener)
        self.comma_separated_list(elements)
        self.put(closer)

    def visit_ListNode(self, node):
        self.emit_sequence(node, u"[]")

    def visit_TupleNode(self, node):
        self.emit_sequence(node, u"()")

    def visit_SetNode(self, node):
        """Write a set display, or ``set()`` when there are no elements."""
        if not node.subexpr_nodes():
            self.put(u"set()")
        else:
            self.emit_sequence(node, u"{}")

    def visit_DictNode(self, node):
        self.emit_sequence(node, u"{}")

    def visit_DictItemNode(self, node):
        self.visit(node.key)
        self.put(u": ")
        self.visit(node.value)

    # Precedence of unary operators; larger numbers bind tighter.
    unop_precedence = {
        'not': 3,
        '!': 3,
        '+': 11,
        '-': 11,
        '~': 11,
    }
    # Precedence of binary operators; the unary levels (3 and 11) are left
    # out here and live in unop_precedence.
    binop_precedence = {
        'or': 1,
        'and': 2,
        # unary: 'not': 3, '!': 3,
        'in': 4, 'not_in': 4, 'is': 4, 'is_not': 4,
        '<': 4, '<=': 4, '>': 4, '>=': 4, '!=': 4, '==': 4,
        '|': 5,
        '^': 6,
        '&': 7,
        '<<': 8, '>>': 8,
        '+': 9, '-': 9,
        '*': 10, '@': 10, '/': 10, '//': 10, '%': 10,
        # unary: '+': 11, '-': 11, '~': 11
        '**': 12,
    }

    def operator_enter(self, new_prec):
        """Push *new_prec*, opening a parenthesis if the enclosing one is higher."""
        enclosing = self.precedence[-1]
        if enclosing > new_prec:
            self.put(u"(")
        self.precedence.append(new_prec)

    def operator_exit(self):
        """Pop the current precedence, closing the parenthesis opened on entry."""
        enclosing, current = self.precedence[-2:]
        if enclosing > current:
            self.put(u")")
        self.precedence.pop()

    def visit_NotNode(self, node):
        self.operator_enter(self.unop_precedence['not'])
        self.put(u"not ")
        self.visit(node.operand)
        self.operator_exit()

    def visit_UnopNode(self, node):
        symbol = node.operator
        self.operator_enter(self.unop_precedence[symbol])
        self.put(u"%s" % node.operator)
        self.visit(node.operand)
        self.operator_exit()

    def visit_BinopNode(self, node):
        """Write ``operand1 <op> operand2``, parenthesized when precedence requires."""
        symbol = node.operator
        # Operators missing from the table get precedence 0.
        self.operator_enter(self.binop_precedence.get(symbol, 0))
        self.visit(node.operand1)
        # Operator names such as 'not_in' are written with a space.
        self.put(u" %s " % symbol.replace('_', ' '))
        self.visit(node.operand2)
        self.operator_exit()

    def visit_BoolBinopNode(self, node):
        self.visit_BinopNode(node)

    def visit_PrimaryCmpNode(self, node):
        self.visit_BinopNode(node)

    def visit_IndexNode(self, node):
        """Write ``base[index]``; a tuple index is written without parentheses."""
        self.visit(node.base)
        self.put(u"[")
        subscript = node.index
        if isinstance(subscript, TupleNode):
            self.emit_sequence(subscript)
        else:
            self.visit(subscript)
        self.put(u"]")

    def visit_SliceIndexNode(self, node):
        """Write ``base[start:stop[:slice]]``, omitting parts that are unset."""
        self.visit(node.base)
        self.put(u"[")
        if node.start:
            self.visit(node.start)
        self.put(u":")
        if node.stop:
            self.visit(node.stop)
        if node.slice:
            self.put(u":")
            self.visit(node.slice)
        self.put(u"]")

    def visit_SliceNode(self, node):
        """Write ``start:stop[:step]``, omitting parts flagged ``is_none``."""
        lower, upper, stride = node.start, node.stop, node.step
        if not lower.is_none:
            self.visit(lower)
        self.put(u":")
        if not upper.is_none:
            self.visit(upper)
        if not stride.is_none:
            self.put(u":")
            self.visit(stride)

    def visit_CondExprNode(self, node):
        self.visit(node.true_val)
        self.put(u" if ")
        self.visit(node.test)
        self.put(u" else ")
        self.visit(node.false_val)

    def visit_AttributeNode(self, node):
        self.visit(node.obj)
        self.put(u".%s" % node.attribute)

    def visit_SimpleCallNode(self, node):
        self.visit(node.function)
        self.put(u"(")
        self.comma_separated_list(node.args)
        self.put(")")

    def emit_pos_args(self, node):
        """Write positional call arguments, each followed by ``", "``.

        ``AddNode`` operands are handled recursively, a ``TupleNode``
        contributes its elements, an ``AsTupleNode`` is written as
        ``*arg``, and anything else is written as a single argument.
        """
        if node is None:
            return
        if isinstance(node, AddNode):
            for operand in (node.operand1, node.operand2):
                self.emit_pos_args(operand)
            return
        if isinstance(node, TupleNode):
            for element in node.subexpr_nodes():
                self.visit(element)
                self.put(u", ")
            return
        if isinstance(node, AsTupleNode):
            self.put("*")
            self.visit(node.arg)
            self.put(u", ")
            return
        self.visit(node)
        self.put(u", ")

    def emit_kwd_args(self, node):
        """Write keyword call arguments, each followed by ``", "``.

        A ``MergedDictNode`` is handled recursively, a ``DictNode`` is
        written as ``key=value`` pairs, and anything else as ``**node``.
        """
        if node is None:
            return
        if isinstance(node, MergedDictNode):
            for part in node.subexpr_nodes():
                self.emit_kwd_args(part)
            return
        if isinstance(node, DictNode):
            for pair in node.subexpr_nodes():
                self.put(u"%s=" % pair.key.value)
                self.visit(pair.value)
                self.put(u", ")
            return
        self.put(u"**")
        self.visit(node)
        self.put(u", ")

    def visit_GeneralCallNode(self, node):
        """Write a call with positional and keyword arguments."""
        self.visit(node.function)
        self.put(u"(")
        self.emit_pos_args(node.positional_args)
        self.emit_kwd_args(node.keyword_args)
        # Every emitted argument ends with ", "; drop the last separator.
        self.remove(u", ")
        self.put(")")

    def emit_comprehension(self, body, target,
                           sequence, condition,
                           parens=(u"", u"")):
        """Write ``<body> for <target> in <sequence> [if <condition>]`` in *parens*."""
        opener, closer = parens
        self.put(opener)
        self.visit(body)
        self.put(u" for ")
        self.visit(target)
        self.put(u" in ")
        self.visit(sequence)
        if condition:
            self.put(u" if ")
            self.visit(condition)
        self.put(closer)

    def visit_ComprehensionAppendNode(self, node):
        self.visit(node.expr)

    def visit_DictComprehensionAppendNode(self, node):
        self.visit(node.key_expr)
        self.put(u": ")
        self.visit(node.value_expr)

    def visit_ComprehensionNode(self, node):
        """Write a list, dict or set comprehension."""
        brackets_by_type = {'list': u"[]", 'dict': u"{}", 'set': u"{}"}
        brackets = brackets_by_type[node.type.py_type_name()]
        loop = node.loop
        element = loop.body
        target = loop.target
        sequence = loop.iterator.sequence
        condition = None
        if hasattr(element, 'if_clauses'):
            # type(body) is Nodes.IfStatNode
            first_clause = element.if_clauses[0]
            condition = first_clause.condition
            element = first_clause.body
        self.emit_comprehension(element, target, sequence, condition, brackets)

    def visit_GeneratorExpressionNode(self, node):
        """Write a generator expression in parentheses."""
        loop = node.loop
        element = loop.body
        target = loop.target
        sequence = loop.iterator.sequence
        condition = None
        if hasattr(element, 'if_clauses'):
            # type(body) is Nodes.IfStatNode
            first_clause = element.if_clauses[0]
            condition = first_clause.condition
            element = first_clause.body.expr.arg
        elif hasattr(element, 'expr'):
            # type(body) is Nodes.ExprStatNode
            element = element.expr.arg
        self.emit_comprehension(element, target, sequence, condition, u"()")
817