1# -*- coding: utf-8 -*-
2#
3# Copyright (C) 2008-2009 Edgewall Software
4# All rights reserved.
5#
6# This software is licensed as described in the file COPYING, which
7# you should have received as part of this distribution. The terms
8# are also available at http://genshi.edgewall.org/wiki/License.
9#
10# This software consists of voluntary contributions made by many
11# individuals. For the exact contribution history, see the revision
12# history and logs, available at http://genshi.edgewall.org/log/.
13
"""Emulation of the proper abstract syntax tree API for Python 2.4.

Source code is parsed with ``compiler.parse`` and the resulting
``compiler.ast`` tree is rewritten by ``ASTUpgrader`` into node classes
taken from ``genshi.template._ast24`` (re-exported here as ``_ast``), so
that callers can work with the ``_ast``-style tree shape used by Python
2.5 and later even where no builtin ``_ast`` module is available.
"""

import compiler
import compiler.ast

from genshi.template import _ast24 as _ast

# Public API: the emulated ``_ast`` node module and the ``parse`` entry point.
__all__ = ['_ast', 'parse']
__docformat__ = 'restructuredtext en'
23
24
def _new(cls, *args, **kwargs):
    """Instantiate the node class ``cls`` and populate its fields.

    Positional ``args`` are assigned, in order, to the names listed in the
    instance's ``_fields``; positional values beyond the declared fields are
    ignored.  A keyword argument is assigned only when its name appears in
    ``_fields`` or ``_attributes``; any other keyword is silently dropped.
    Classes lacking ``_fields`` or ``_attributes`` are treated as if the
    missing tuple were empty.

    :param cls: the node class to instantiate (called with no arguments)
    :return: the new, populated instance
    :raises ValueError: if a field is given both positionally and by keyword
    """
    ret = cls()
    # Look both tuples up defensively once, instead of mixing direct
    # attribute access (which raised AttributeError) with getattr guards.
    fields = getattr(ret, '_fields', None) or ()
    attributes = getattr(ret, '_attributes', None) or ()
    for attr, value in zip(fields, args):
        if attr in kwargs:
            raise ValueError('Field set both in args and kwargs')
        setattr(ret, attr, value)
    for attr in kwargs:
        # Unknown keywords are ignored so a keyword such as ``lineno`` can
        # be passed uniformly to every node class.
        if attr in fields or attr in attributes:
            setattr(ret, attr, kwargs[attr])
    return ret
38
39
class ASTUpgrader(object):
    """Transformer changing structure of Python 2.4 ASTs to
    Python 2.5 ones.

    Transforms ``compiler.ast`` Abstract Syntax Tree to builtin ``_ast``.
    It can use fake ``_ast`` classes and this way allow ``_ast`` emulation
    in Python 2.4.

    Each ``visit_<Name>`` method handles the ``compiler.ast`` node class
    called ``<Name>`` and returns the corresponding ``_ast`` node; ``Stmt``
    is the exception and yields a plain list of statement nodes.
    """

    def __init__(self):
        # Flags ('OP_ASSIGN', 'OP_DELETE' or 'OP_APPLY') of the assignment
        # target visited last; AssList/AssTuple read it back to choose the
        # context of the collection they build.
        self.out_flags = None
        # Stack of line numbers of the nodes currently being visited; the
        # -1 sentinel is used when no enclosing node carries a line number.
        self.lines = [-1]

    def _new(self, *args, **kwargs):
        """Create an ``_ast`` node tagged with the current line number."""
        return _new(lineno=self.lines[-1], *args, **kwargs)

    def visit(self, node):
        """Translate ``node`` (or a tuple of nodes) into ``_ast`` form.

        ``None`` is passed through unchanged.

        :raises Exception: for node types without a ``visit_*`` method
        """
        if node is None:
            return None
        if type(node) is tuple:
            return tuple([self.visit(n) for n in node])
        visitor = getattr(self, 'visit_%s' % node.__class__.__name__, None)
        if visitor is None:
            raise Exception('Unhandled node type %r' % type(node))
        lno = getattr(node, 'lineno', None)
        if lno is None:
            return visitor(node)
        # Nodes created while visiting this one (and its children) inherit
        # its line number; pop it again even if the visitor raises.
        self.lines.append(lno)
        try:
            return visitor(node)
        finally:
            self.lines.pop()

    def _with_docstring(self, doc, body):
        """Return ``body`` with ``doc`` (if any) prepended as a statement.

        ``compiler.ast`` keeps docstrings in a separate ``doc`` attribute;
        in ``_ast`` they are the leading ``Expr(Str(doc))`` of the body.
        """
        if doc:
            body = [self._new(_ast.Expr, self._new(_ast.Str, doc))] + body
        return body

    def visit_Module(self, node):
        body = self._with_docstring(node.doc, self.visit(node.node))
        return self._new(_ast.Module, body)

    def visit_Expression(self, node):
        return self._new(_ast.Expression, self.visit(node.node))

    def _extract_args(self, node):
        """Build an ``_ast.arguments`` node for a Function or Lambda node.

        The ``*args`` / ``**kwargs`` names are the trailing entries of
        ``argnames``, flagged through ``node.flags``; they are split off
        here.  Tuple-unpacking parameters arrive as (nested) tuples.

        :raises NotImplementedError: for parameter entries that are neither
                                     strings nor tuples
        """
        names = list(node.argnames)
        if node.flags & compiler.ast.CO_VARKEYWORDS:
            kwarg = names.pop()
        else:
            kwarg = None

        if node.flags & compiler.ast.CO_VARARGS:
            vararg = names.pop()
        else:
            vararg = None

        def _tup(t):
            # Names nested inside a tuple parameter are stored into.
            if isinstance(t, str):
                return self._new(_ast.Name, t, _ast.Store())
            elif isinstance(t, tuple):
                elts = [_tup(x) for x in t]
                return self._new(_ast.Tuple, elts, _ast.Store())
            # ``raise NotImplemented`` would only raise a TypeError.
            raise NotImplementedError('Unsupported parameter %r' % (t,))

        args = []
        for arg in names:
            if isinstance(arg, str):
                args.append(self._new(_ast.Name, arg, _ast.Param()))
            elif isinstance(arg, tuple):
                args.append(_tup(arg))
            else:
                raise NotImplementedError('Unsupported parameter %r in %s'
                                          % (arg, node.__class__.__name__))

        defaults = [self.visit(d) for d in node.defaults]
        return self._new(_ast.arguments, args, vararg, kwarg, defaults)

    def visit_Function(self, node):
        if getattr(node, 'decorators', ()):
            decorators = [self.visit(d) for d in node.decorators.nodes]
        else:
            decorators = []

        args = self._extract_args(node)
        body = self._with_docstring(node.doc, self.visit(node.code))
        return self._new(_ast.FunctionDef, node.name, args, body, decorators)

    def visit_Class(self, node):
        bases = [self.visit(b) for b in node.bases]
        body = self._with_docstring(node.doc, self.visit(node.code))
        return self._new(_ast.ClassDef, node.name, bases, body)

    def visit_Return(self, node):
        return self._new(_ast.Return, self.visit(node.value))

    def visit_Assign(self, node):
        targets = [self.visit(t) for t in node.nodes]
        return self._new(_ast.Assign, targets, self.visit(node.expr))

    # Augmented-assignment operator strings -> ``_ast`` operator classes.
    aug_operators = {
        '+=': _ast.Add,
        '&=': _ast.BitAnd,
        '|=': _ast.BitOr,
        '^=': _ast.BitXor,
        '/=': _ast.Div,
        '//=': _ast.FloorDiv,
        '<<=': _ast.LShift,
        '%=': _ast.Mod,
        '*=': _ast.Mult,
        '**=': _ast.Pow,
        '>>=': _ast.RShift,
        '-=': _ast.Sub,
    }

    def visit_AugAssign(self, node):
        target = self.visit(node.node)

        # Because it's AugAssign target can't be list nor tuple
        # so we only have to change context of one node
        target.ctx = _ast.Store()
        op = self.aug_operators[node.op]()
        return self._new(_ast.AugAssign, target, op, self.visit(node.expr))

    def _visit_Print(nl):
        # ``nl`` distinguishes ``Printnl`` (trailing newline) from ``Print``.
        def _visit(self, node):
            values = [self.visit(v) for v in node.nodes]
            return self._new(_ast.Print, self.visit(node.dest), values, nl)
        return _visit

    visit_Print = _visit_Print(False)
    visit_Printnl = _visit_Print(True)
    del _visit_Print

    def visit_For(self, node):
        return self._new(_ast.For, self.visit(node.assign),
                         self.visit(node.list), self.visit(node.body),
                         self.visit(node.else_))

    def visit_While(self, node):
        return self._new(_ast.While, self.visit(node.test),
                         self.visit(node.body), self.visit(node.else_))

    def visit_If(self, node):
        # compiler.ast flattens ``elif`` chains into one list of
        # (test, body) pairs; _ast nests them as If nodes in ``orelse``.
        def _level(tests, else_):
            test = self.visit(tests[0][0])
            body = self.visit(tests[0][1])
            if len(tests) == 1:
                orelse = self.visit(else_)
            else:
                orelse = [_level(tests[1:], else_)]
            return self._new(_ast.If, test, body, orelse)
        return _level(node.tests, node.else_)

    def visit_With(self, node):
        return self._new(_ast.With, self.visit(node.expr),
                         self.visit(node.vars), self.visit(node.body))

    def visit_Raise(self, node):
        return self._new(_ast.Raise, self.visit(node.expr1),
                         self.visit(node.expr2), self.visit(node.expr3))

    def visit_TryExcept(self, node):
        handlers = []
        for exc_type, name, body in node.handlers:
            handlers.append(self._new(_ast.excepthandler,
                                      self.visit(exc_type), self.visit(name),
                                      self.visit(body)))
        return self._new(_ast.TryExcept, self.visit(node.body),
                         handlers, self.visit(node.else_))

    def visit_TryFinally(self, node):
        body = self.visit(node.body)
        # NOTE(review): newer ``compiler`` versions appear to represent
        # try/except/finally as a TryFinally whose body is a bare TryExcept
        # node rather than a Stmt; _ast wants a list of statements.
        if not isinstance(body, list):
            body = [body]
        return self._new(_ast.TryFinally, body, self.visit(node.final))

    def visit_Assert(self, node):
        return self._new(_ast.Assert, self.visit(node.test),
                         self.visit(node.fail))

    def visit_Import(self, node):
        names = [self._new(_ast.alias, n[0], n[1]) for n in node.names]
        return self._new(_ast.Import, names)

    def visit_From(self, node):
        names = [self._new(_ast.alias, n[0], n[1]) for n in node.names]
        # Only newer compiler.ast From nodes carry a relative-import
        # ``level``; fall back to 0 (absolute import) when it is absent.
        level = getattr(node, 'level', 0)
        return self._new(_ast.ImportFrom, node.modname, names, level)

    def visit_Exec(self, node):
        # NOTE(review): compiler.ast.Exec appears to store the first
        # expression after ``in`` (the globals mapping) as ``locals`` and
        # the second as ``globals``; _ast.Exec is (body, globals, locals),
        # so this argument order is intentional.
        return self._new(_ast.Exec, self.visit(node.expr),
                         self.visit(node.locals), self.visit(node.globals))

    def visit_Global(self, node):
        return self._new(_ast.Global, node.names[:])

    def visit_Discard(self, node):
        return self._new(_ast.Expr, self.visit(node.expr))

    def _map_class(to):
        # Visitor for field-less statements that map one-to-one.
        def _visit(self, node):
            return self._new(to)
        return _visit

    visit_Pass = _map_class(_ast.Pass)
    visit_Break = _map_class(_ast.Break)
    visit_Continue = _map_class(_ast.Continue)
    del _map_class

    def _visit_BinOperator(opcls):
        def _visit(self, node):
            return self._new(_ast.BinOp, self.visit(node.left),
                             opcls(), self.visit(node.right))
        return _visit
    visit_Add = _visit_BinOperator(_ast.Add)
    visit_Div = _visit_BinOperator(_ast.Div)
    visit_FloorDiv = _visit_BinOperator(_ast.FloorDiv)
    visit_LeftShift = _visit_BinOperator(_ast.LShift)
    visit_Mod = _visit_BinOperator(_ast.Mod)
    visit_Mul = _visit_BinOperator(_ast.Mult)
    visit_Power = _visit_BinOperator(_ast.Pow)
    visit_RightShift = _visit_BinOperator(_ast.RShift)
    visit_Sub = _visit_BinOperator(_ast.Sub)
    del _visit_BinOperator

    def _visit_BitOperator(opcls):
        # compiler.ast keeps ``a & b & c`` as one n-ary node; fold it into
        # left-associative binary BinOp nodes.
        def _visit(self, node):
            def _make(nodes):
                if len(nodes) == 1:
                    return self.visit(nodes[0])
                left = _make(nodes[:-1])
                right = self.visit(nodes[-1])
                return self._new(_ast.BinOp, left, opcls(), right)
            return _make(node.nodes)
        return _visit
    visit_Bitand = _visit_BitOperator(_ast.BitAnd)
    visit_Bitor = _visit_BitOperator(_ast.BitOr)
    visit_Bitxor = _visit_BitOperator(_ast.BitXor)
    del _visit_BitOperator

    def _visit_UnaryOperator(opcls):
        def _visit(self, node):
            return self._new(_ast.UnaryOp, opcls(), self.visit(node.expr))
        return _visit

    visit_Invert = _visit_UnaryOperator(_ast.Invert)
    visit_Not = _visit_UnaryOperator(_ast.Not)
    visit_UnaryAdd = _visit_UnaryOperator(_ast.UAdd)
    visit_UnarySub = _visit_UnaryOperator(_ast.USub)
    del _visit_UnaryOperator

    def _visit_BoolOperator(opcls):
        def _visit(self, node):
            values = [self.visit(n) for n in node.nodes]
            return self._new(_ast.BoolOp, opcls(), values)
        return _visit
    visit_And = _visit_BoolOperator(_ast.And)
    visit_Or = _visit_BoolOperator(_ast.Or)
    del _visit_BoolOperator

    # Comparison operator strings -> ``_ast`` comparison classes.
    cmp_operators = {
        '==': _ast.Eq,
        '!=': _ast.NotEq,
        '<': _ast.Lt,
        '<=': _ast.LtE,
        '>': _ast.Gt,
        '>=': _ast.GtE,
        'is': _ast.Is,
        'is not': _ast.IsNot,
        'in': _ast.In,
        'not in': _ast.NotIn,
    }

    def visit_Compare(self, node):
        left = self.visit(node.expr)
        ops = []
        comparators = []
        for optype, expr in node.ops:
            ops.append(self.cmp_operators[optype]())
            comparators.append(self.visit(expr))
        return self._new(_ast.Compare, left, ops, comparators)

    def visit_Lambda(self, node):
        args = self._extract_args(node)
        body = self.visit(node.code)
        return self._new(_ast.Lambda, args, body)

    def visit_IfExp(self, node):
        return self._new(_ast.IfExp, self.visit(node.test),
                         self.visit(node.then), self.visit(node.else_))

    def visit_Dict(self, node):
        keys = [self.visit(x[0]) for x in node.items]
        values = [self.visit(x[1]) for x in node.items]
        return self._new(_ast.Dict, keys, values)

    def visit_ListComp(self, node):
        generators = [self.visit(q) for q in node.quals]
        return self._new(_ast.ListComp, self.visit(node.expr), generators)

    def visit_GenExprInner(self, node):
        generators = [self.visit(q) for q in node.quals]
        return self._new(_ast.GeneratorExp, self.visit(node.expr), generators)

    def visit_GenExpr(self, node):
        return self.visit(node.code)

    def visit_GenExprFor(self, node):
        ifs = [self.visit(i) for i in node.ifs]
        return self._new(_ast.comprehension, self.visit(node.assign),
                         self.visit(node.iter), ifs)

    def visit_ListCompFor(self, node):
        ifs = [self.visit(i) for i in node.ifs]
        return self._new(_ast.comprehension, self.visit(node.assign),
                         self.visit(node.list), ifs)

    def visit_GenExprIf(self, node):
        # _ast.comprehension stores its conditions as bare expressions.
        return self.visit(node.test)
    visit_ListCompIf = visit_GenExprIf

    def visit_Yield(self, node):
        return self._new(_ast.Yield, self.visit(node.value))

    def visit_CallFunc(self, node):
        # compiler.ast mixes keyword arguments into ``args`` as Keyword
        # nodes; _ast keeps them in a separate ``keywords`` list.
        args = []
        keywords = []
        for arg in node.args:
            if isinstance(arg, compiler.ast.Keyword):
                keywords.append(self._new(_ast.keyword, arg.name,
                                          self.visit(arg.expr)))
            else:
                args.append(self.visit(arg))
        return self._new(_ast.Call, self.visit(node.node), args, keywords,
                         self.visit(node.star_args),
                         self.visit(node.dstar_args))

    def visit_Backquote(self, node):
        return self._new(_ast.Repr, self.visit(node.expr))

    def visit_Const(self, node):
        if node.value is None: # appears in slices
            return None
        elif isinstance(node.value, basestring):
            return self._new(_ast.Str, node.value)
        else:
            return self._new(_ast.Num, node.value)

    def visit_Name(self, node):
        return self._new(_ast.Name, node.name, _ast.Load())

    def visit_Getattr(self, node):
        return self._new(_ast.Attribute, self.visit(node.expr), node.attrname,
                         _ast.Load())

    def visit_Tuple(self, node):
        nodes = [self.visit(n) for n in node.nodes]
        return self._new(_ast.Tuple, nodes, _ast.Load())

    def visit_List(self, node):
        nodes = [self.visit(n) for n in node.nodes]
        return self._new(_ast.List, nodes, _ast.Load())

    def get_ctx(self, flags):
        """Map a ``compiler.ast`` assignment flag to an ``_ast`` context.

        :raises ValueError: for flags other than ``'OP_DELETE'``,
                            ``'OP_APPLY'`` and ``'OP_ASSIGN'``
        """
        if flags == 'OP_DELETE':
            return _ast.Del()
        elif flags == 'OP_APPLY':
            return _ast.Load()
        elif flags == 'OP_ASSIGN':
            return _ast.Store()
        # An ``assert`` here would vanish under ``python -O`` and silently
        # return None.
        raise ValueError('Unexpected assignment flags %r' % (flags,))

    # The assignment-target visitors below record ``node.flags`` in
    # ``self.out_flags`` as their LAST action: visiting a child such as a
    # nested Subscript (``a[0].b``) also writes ``out_flags`` ('OP_APPLY'),
    # which would otherwise leak to the enclosing AssList/AssTuple.

    def visit_AssName(self, node):
        self.out_flags = node.flags
        ctx = self.get_ctx(node.flags)
        return self._new(_ast.Name, node.name, ctx)

    def visit_AssAttr(self, node):
        ctx = self.get_ctx(node.flags)
        value = self.visit(node.expr)
        self.out_flags = node.flags
        return self._new(_ast.Attribute, value, node.attrname, ctx)

    def _visit_AssCollection(cls):
        # All elements of an assignment collection must share one context,
        # which becomes the context of the collection itself.
        def _visit(self, node):
            flags = None
            elts = []
            for n in node.nodes:
                elts.append(self.visit(n))
                if flags is None:
                    flags = self.out_flags
                else:
                    assert flags == self.out_flags
            self.out_flags = flags
            ctx = self.get_ctx(flags)
            return self._new(cls, elts, ctx)
        return _visit

    visit_AssList = _visit_AssCollection(_ast.List)
    visit_AssTuple = _visit_AssCollection(_ast.Tuple)
    del _visit_AssCollection

    def visit_Slice(self, node):
        ctx = self.get_ctx(node.flags)
        lower = self.visit(node.lower)
        upper = self.visit(node.upper)
        value = self.visit(node.expr)
        self.out_flags = node.flags
        return self._new(_ast.Subscript, value,
                         self._new(_ast.Slice, lower, upper, None), ctx)

    def visit_Subscript(self, node):
        ctx = self.get_ctx(node.flags)
        subs = [self.visit(s) for s in node.subs]
        advanced = (_ast.Slice, _ast.Ellipsis)

        def _dim(sub):
            # Plain expressions must be wrapped in Index; Slice and
            # Ellipsis are already valid slice nodes.
            if isinstance(sub, advanced):
                return sub
            return self._new(_ast.Index, sub)

        if len(subs) == 1:
            slice_ = _dim(subs[0])
        elif [s for s in subs if isinstance(s, advanced)]:
            slice_ = self._new(_ast.ExtSlice, [_dim(s) for s in subs])
        else:
            # ``a[x, y]`` indexes with the tuple ``(x, y)``: _ast models it
            # as Index(Tuple([x, y])), not as a Tuple of Index nodes.
            slice_ = self._new(_ast.Index,
                               self._new(_ast.Tuple, subs, _ast.Load()))

        value = self.visit(node.expr)
        self.out_flags = node.flags
        return self._new(_ast.Subscript, value, slice_, ctx)

    def visit_Sliceobj(self, node):
        # Pad missing start/stop/step entries with None.
        a = [self.visit(n) for n in node.nodes + [None]*(3 - len(node.nodes))]
        return self._new(_ast.Slice, a[0], a[1], a[2])

    def visit_Ellipsis(self, node):
        return self._new(_ast.Ellipsis)

    def visit_Stmt(self, node):
        def _check_del(n):
            # del x is just AssName('x', 'OP_DELETE')
            # we want to transform it to Delete([Name('x', Del())])
            dcls = (_ast.Name, _ast.List, _ast.Subscript, _ast.Attribute)
            if isinstance(n, dcls) and isinstance(n.ctx, _ast.Del):
                return self._new(_ast.Delete, [n])
            elif isinstance(n, _ast.Tuple) and isinstance(n.ctx, _ast.Del):
                # unpack last tuple to avoid making del (x, y, z,);
                # out of del x, y, z; (there's no difference between
                # this two in compiler.ast)
                return self._new(_ast.Delete, n.elts)
            else:
                return n
        def _keep(n):
            # Drop expression statements whose value translated to None
            # (a ``None`` constant).
            if isinstance(n, _ast.Expr) and n.value is None:
                return False
            else:
                return True
        return [s for s in [_check_del(self.visit(n)) for n in node.nodes]
                if _keep(s)]
501
502
def parse(source, mode='exec'):
    """Parse Python source code into an ``_ast``-style tree.

    The source is parsed with ``compiler.parse`` and the resulting
    ``compiler.ast`` tree is converted by a fresh ``ASTUpgrader``.

    :param source: the Python source code to parse
    :param mode: parsing mode passed straight to ``compiler.parse`` (for
                 example ``'exec'`` for a module or ``'eval'`` for a single
                 expression); defaults to ``'exec'``, like
                 ``compiler.parse`` itself
    :return: the root ``_ast`` node of the converted tree
    """
    node = compiler.parse(source, mode)
    return ASTUpgrader().visit(node)
506