1# TODO: replace ExprInner by passes that replaces the objects by what's needed
2# for arybo
3
4import functools
5import operator
6import six
7import collections
8
9from six.moves import range, reduce
10
11from arybo.lib import MBA, simplify_inplace
12from pytanque import imm, expand_esf_inplace, simplify_inplace, Vector, esf
13
14class Expr(object):
15    @property
16    def nbits(self):
17        raise NotImplementedError()
18
19    @property
20    def args(self):
21        return None
22
23    def init_ctx(self):
24        return None
25
26    def eval(self, vec, i, ctx, args, use_esf):
27        raise NotImplementedError()
28
29    def __parse_arg(self,v):
30        if isinstance(v, six.integer_types):
31            return ExprCst(v, self.nbits)
32        if not isinstance(v, Expr):
33            raise ValueError("argument must be an integer or an Expr")
34        return v
35
36    def __check_arg_int(self,v):
37        if not isinstance(v, six.integer_types):
38            raise ValueError("argument must be an integer")
39
40    def __add__(self, o):
41        o = self.__parse_arg(o)
42        return ExprAdd(self, o)
43
44    def __radd__(self, o):
45        o = self.__parse_arg(o)
46        return ExprAdd(o, self)
47
48    def __sub__(self, o):
49        o = self.__parse_arg(o)
50        return ExprSub(self, o)
51
52    def __rsub__(self, o):
53        o = self.__parse_arg(o)
54        return ExprSub(o,self)
55
56    def __mul__(self, o):
57        o = self.__parse_arg(o)
58        return ExprMul(self, o)
59
60    def __rmul__(self, o):
61        o = self.__parse_arg(o)
62        return ExprMul(o, self)
63
64    def __xor__(self, o):
65        o = self.__parse_arg(o)
66        return ExprXor(self, o)
67
68    def __rxor__(self, o):
69        o = self.__parse_arg(o)
70        return ExprXor(o, self)
71
72    def __and__(self, o):
73        o = self.__parse_arg(o)
74        return ExprAnd(self, o)
75
76    def __rand__(self, o):
77        o = self.__parse_arg(o)
78        return ExprAnd(o, self)
79
80    def __or__(self, o):
81        o = self.__parse_arg(o)
82        return ExprOr(self, o)
83
84    def __ror__(self, o):
85        o = self.__parse_arg(o)
86        return ExprOr(o, self)
87
88    def __lshift__(self, o):
89        o = self.__parse_arg(o)
90        return ExprShl(self, o)
91
92    def __rshift__(self, o):
93        return self.lshr(o)
94
95    def lshr(self, o):
96        o = self.__parse_arg(o)
97        return ExprLShr(self, o)
98
99    def ashr(self, o):
100        o = self.__parse_arg(o)
101        return ExprAShr(self, o)
102
103    def __neg__(self):
104        return ExprSub(ExprCst(0, self.nbits), self)
105
106    def __invert__(self):
107        return ExprNot(self)
108
109    def zext(self,n):
110        self.__check_arg_int(n)
111        return ExprZX(self,n)
112
113    def sext(self,n):
114        self.__check_arg_int(n)
115        return ExprSX(self,n)
116
117    def rol(self,o):
118        o = self.__parse_arg(o)
119        return ExprRol(self,o)
120
121    def ror(self,o):
122        o = self.__parse_arg(o)
123        return ExprRor(self,o)
124
125    def udiv(self,o):
126        o = self.__parse_arg(o)
127        return ExprDiv(self, o, is_signed=False)
128
129    def sdiv(self,o):
130        o = self.__parse_arg(o)
131        return ExprDiv(self, o, is_signed=True)
132
133    def urem(self, o):
134        o = self.__parse_arg(o)
135        return ExprRem(self, o, is_signed=False)
136
137    def srem(self, o):
138        o = self.__parse_arg(o)
139        return ExprRem(self, o, is_signed=True)
140
141    def __getitem__(self,s):
142        if not isinstance(s, slice):
143            raise ValueError("can only get slices")
144        return ExprSlice(self, s)
145
146# Leaves
147class ExprCst(Expr):
148    def __init__(self, n, nbits):
149        assert(n >= 0)
150        self.n = n & ((1<<nbits)-1)
151        self.__nbits = nbits
152
153    @property
154    def nbits(self):
155        return self.__nbits
156
157    def eval(self, vec, i, ctx, args, use_esf):
158        return imm((self.n>>i)&1)
159
160    def to_cst(self):
161        return self.n
162
163    @staticmethod
164    def get_cst(obj,nbits=None):
165        if isinstance(obj, ExprCst):
166            ret = obj.n
167        elif isinstance(obj, six.integer_types):
168            ret = obj
169        else:
170            raise ValueError("obj must be an ExprCst or an integer")
171        if nbits is None:
172            return ret
173        return ExprCst(ret,nbits).n
174
175class ExprBV(Expr):
176    def __init__(self, v):
177        self.v = v
178
179    @property
180    def nbits(self):
181        return self.v.nbits
182
183    def eval(self, vec, i, ctx, args, use_esf):
184        return self.v.vec[i]
185
186    def to_cst(self):
187        return self.v.to_cst()
188
189# Unary ops
190class ExprUnaryOp(Expr):
191    def __init__(self, arg):
192        self.arg = arg
193
194    @property
195    def args(self):
196        return [self.arg]
197
198    @args.setter
199    def args(self, args):
200        self.arg = args[0]
201
202    @property
203    def nbits(self):
204        return self.arg.nbits
205
206class ExprNot(ExprUnaryOp):
207    def eval(self, vec, i, ctx, args, use_esf):
208        return args[0].eval(vec, i, use_esf) + imm(True)
209
210# Nary ops
211class ExprNaryOp(Expr):
212    def __init__(self, *args):
213        self._args = args
214
215    @property
216    def args(self):
217        return self._args
218
219    @args.setter
220    def args(self, args):
221        self._args = args
222
223    @property
224    def nbits(self):
225        # TODO assert every args has the same size
226        return self.args[0].nbits
227
228    def compute(self, vec, i, args, ctx, use_esf):
229        raise NotImplementedError()
230
231    def eval(self, vec, i, ctx, args, use_esf):
232        args = (a.eval(vec, i, use_esf) for a in args)
233        return self.compute(vec, i, args, ctx, use_esf)
234
235# Binary ops
236# We can't implement this as an NaryOp, because we need one context per binary
237# operation (and in this case, they would share the same context, leading to
238# incorrect results).
239class ExprBinaryOp(Expr):
240    def __init__(self, X, Y):
241        self._nbits = X.nbits
242        if (self._nbits != Y.nbits):
243            raise ValueError("X and Y must have the same number of bits!")
244        self.X = X
245        self.Y = Y
246
247    @property
248    def args(self):
249        return [self.X,self.Y]
250
251    @args.setter
252    def args(self, args):
253        self.X,self.Y = args
254
255    @property
256    def nbits(self):
257        return self._nbits
258
259    def eval(self, vec, i, ctx, args, use_esf):
260        X,Y = args
261        X = X.eval(vec, i, use_esf)
262        Y = Y.eval(vec, i, use_esf)
263        return self.compute_binop(vec, i, X, Y, ctx, use_esf)
264
265    @staticmethod
266    def compute_binop(vec, i, X, Y, ctx, use_esf):
267        raise NotImplementedError()
268
269# Nary ops
270class ExprXor(ExprNaryOp):
271    def compute(self, vec, i, args, ctx, use_esf):
272        return sum(args, imm(0))
273
274class ExprAnd(ExprNaryOp):
275    def __init__(self, *args):
276        super(ExprAnd,self).__init__(*args)
277        self.mask = (1<<self.nbits)-1
278        self._rem_args = []
279        for a in args:
280            if isinstance(a, ExprCst):
281                self.mask &= a.n
282    def compute(self, vec, i, args, ctx, use_esf):
283        if ((self.mask >> i) & 1) == 0:
284            return imm(0)
285        return reduce(lambda x,y: x*y, args)
286
287class ExprOr(ExprNaryOp):
288    def __init__(self, *args):
289        super(ExprOr,self).__init__(*args)
290        self.mask = 0
291        for a in args:
292            if isinstance(a, ExprCst):
293                self.mask |= a.n
294
295    def compute(self, vec, i, args, ctx, use_esf):
296        if ((self.mask >> i) & 1) == 1:
297            return imm(1)
298        args = list(args)
299        ret = esf(1, args)
300        for i in range(2, len(args)+1):
301            ret += esf(i, args)
302        if not use_esf:
303            expand_esf_inplace(ret)
304            simplify_inplace(ret)
305        return ret
306
307# Binary shifts
308class ExprShl(ExprBinaryOp):
309    @property
310    def n(self):
311        return ExprCst.get_cst(self.Y, self.nbits)
312
313    def eval(self, vec, i, ctx, args, use_esf):
314        if i < self.n:
315            return imm(False)
316        return args[0].eval(vec, i-self.n, use_esf)
317
318class ExprLShr(ExprBinaryOp):
319    @property
320    def n(self):
321        return ExprCst.get_cst(self.Y, self.nbits)
322
323    def eval(self, vec, i, ctx, args, use_esf):
324        if i >= self.nbits-self.n:
325            return imm(False)
326        return args[0].eval(vec, i+self.n, use_esf)
327
328class ExprAShr(ExprBinaryOp):
329    @property
330    def n(self):
331        return ExprCst.get_cst(self.Y, self.nbits)
332
333    def init_ctx(self):
334        return CtxUninitialized
335
336    def eval(self, vec, i, ctx, args, use_esf):
337        a = args[0]
338        n = self.n
339        # Let's cache the last bit we need to propagate
340        if i >= self.nbits-n:
341            last_bit = ctx.get()
342            if last_bit is CtxUninitialized:
343                last_bit = a.eval(vec, self.nbits-1, use_esf)
344                ctx.set(last_bit)
345            return last_bit
346        return a.eval(vec, i+n, use_esf)
347
348class ExprRol(ExprBinaryOp):
349    @property
350    def n(self):
351        return ExprCst.get_cst(self.Y, self.nbits)
352
353    def eval(self, vec, i, ctx, args, use_esf):
354        return args[0].eval(vec, (i-self.n)%self.nbits, use_esf)
355
356class ExprRor(ExprBinaryOp):
357    @property
358    def n(self):
359        return ExprCst.get_cst(self.Y, self.nbits)
360
361    def eval(self, vec, i, ctx, args, use_esf):
362        return args[0].eval(vec, (i+self.n)%self.nbits, use_esf)
363
364# Concat/slice/{z,s}ext/broadcast
365
366class ExprExtend(ExprUnaryOp):
367    def __init__(self, arg, n):
368        super(ExprExtend, self).__init__(arg)
369        self.n = ExprCst.get_cst(n)
370        self.arg_nbits = self.arg.nbits
371        assert(n >= self.nbits)
372
373    @property
374    def nbits(self):
375        return self.n
376
377class ExprSX(ExprExtend):
378    def init_ctx(self):
379        return CtxUninitialized
380
381    def eval(self, vec, i, ctx, args, use_esf):
382        arg = args[0]
383        if (i >= (self.arg_nbits-1)):
384            last_bit = ctx.get()
385            if last_bit is CtxUninitialized:
386                last_bit = arg.eval(vec, self.arg_nbits-1, use_esf)
387                ctx.set(last_bit)
388            return last_bit
389        return arg.eval(vec, i, use_esf)
390
391class ExprZX(ExprExtend):
392    def eval(self, vec, i, ctx, args, use_esf):
393        if (i >= self.arg_nbits):
394            return imm(0)
395        return args[0].eval(vec, i, use_esf)
396
397class ExprSlice(ExprUnaryOp):
398    def __init__(self, arg, slice_):
399        super(ExprSlice, self).__init__(arg)
400        if not isinstance(slice_, slice):
401            raise ValueError("slice_ must a slice object")
402        if (not slice_.step is None) and (slice_.step != 1):
403            raise ValueError("only slices with a step of 1 are supported!")
404        self.idxes = list(range(*slice_.indices(self.arg.nbits)))
405
406    @property
407    def nbits(self):
408        return len(self.idxes)
409
410    def eval(self, vec, i, ctx, args, use_esf):
411        return args[0].eval(vec, self.idxes[i], use_esf)
412
413class ExprConcat(ExprNaryOp):
414    @property
415    def nbits(self):
416        return sum((a.nbits for a in self.args))
417
418    def eval(self, vec, i, ctx, args, use_esf):
419        it = iter(args)
420        cur_arg = next(it)
421        cur_len = cur_arg.nbits
422        org_i = i
423        while i >= cur_len:
424            i -= cur_len
425            cur_arg = next(it)
426            cur_len = cur_arg.nbits
427        return cur_arg.eval(vec, i, use_esf)
428
429class ExprBroadcast(ExprUnaryOp):
430    def __init__(self, arg, idx, nbits):
431        super(ExprBroadcast, self).__init__(arg)
432        assert(idx >= 0)
433        self._nbits = ExprCst.get_cst(nbits)
434        self.idx = ExprCst.get_cst(idx,arg.nbits)
435
436    def init_ctx(self):
437        return CtxUninitialized
438
439    @property
440    def nbits(self):
441        return self._nbits
442
443    def eval(self, vec, i, ctx, args, use_esf):
444        ret = ctx.get()
445        if ret is CtxUninitialized:
446            ret = args[0].eval(vec, self.idx, use_esf)
447            ctx.set(ret)
448        return ret
449
450# Arithmetic ops
451class ExprBinopCarry(ExprBinaryOp):
452    class CtxCache:
453        def __init__(self, nbits):
454            self.cache = [CtxUninitialized]*nbits
455            self.last_bit = -1
456            self.carry = imm(0)
457
458    def init_ctx(self):
459        return ExprBinopCarry.CtxCache(self.nbits)
460
461    def eval(self, vec, i, ctx, args, use_esf):
462        CC = ctx.get()
463        ret = CC.cache[i]
464        if not ret is CtxUninitialized:
465            return ret
466        if i < CC.last_bit:
467            raise ValueError("asking for a bit before the last computed bit. This should not happen!")
468        X,Y = args
469        for j in range(CC.last_bit+1, i+1):
470            a = X.eval(vec, j, use_esf)
471            b = Y.eval(vec, j, use_esf)
472            CC.cache[j] = self.compute_binop_(vec, j, a, b, CC, use_esf)
473        CC.last_bit = i
474        return CC.cache[i]
475
476    @staticmethod
477    def compute_binop_(vec, i, X, Y, ctx, use_esf):
478        raise NotImplementedError()
479
480class ExprAdd(ExprBinopCarry):
481    @staticmethod
482    def compute_binop_(vec, i, X, Y, CC, use_esf):
483        carry = CC.carry
484
485        sum_args = simplify_inplace(X+Y)
486        ret = simplify_inplace(sum_args + carry)
487        # TODO: optimize this like in mba_if
488        carry = esf(2, [X, Y, carry])
489        if not use_esf:
490            expand_esf_inplace(carry)
491            simplify_inplace(carry)
492
493        CC.carry = carry
494        return ret
495
496class ExprSub(ExprBinopCarry):
497    @staticmethod
498    def compute_binop_(vec, i, X, Y, CC, use_esf):
499        carry = CC.carry
500
501        sum_args = simplify_inplace(X+Y)
502        ret = simplify_inplace(sum_args + carry)
503        carry = esf(2, [X+imm(1), Y, carry])
504        if not use_esf:
505            expand_esf_inplace(carry)
506            carry = simplify_inplace(carry)
507
508        CC.carry = carry
509        return ret
510
511class ExprInner(object):
512    def __init__(self, e):
513        self.inner_expr = e
514
515    def eval(self, vec, i, ctx, args, use_esf):
516        return ctx.eval(vec, i, use_esf)
517
518# x*y = x*(y0+y1<<1+y2<<2+...)
519class ExprMul(ExprInner, ExprBinaryOp):
520    def __init__(self, X, Y):
521        ExprBinaryOp.__init__(self,X,Y)
522        nbits = X.nbits
523        e = ExprAnd(X, ExprBroadcast(Y, 0, nbits))
524        for i in range(1, nbits):
525            e = ExprAdd(
526                e,
527                ExprAnd(ExprShl(X, ExprCst(i,nbits)), ExprBroadcast(Y, i, nbits)))
528        ExprInner.__init__(self,e)
529
530class ExprDiv(ExprInner, ExprBinaryOp):
531    def __init__(self, X, n, is_signed=False):
532        ExprBinaryOp.__init__(self,X,n)
533        nbits = X.nbits
534        self._is_signed = is_signed
535
536        # Arybo specific
537        if isinstance(n, ExprCst) and not self._is_signed:
538            n = ExprCst.get_cst(n,self.nbits)
539            nc = ((2**nbits)/n)*n - 1
540            for p in range(nbits, 2*self.nbits+1):
541                if(2**p > nc*(n - 1 - ((2**p - 1) % n))):
542                    break
543            else:
544                raise RuntimeError("division: unable to find the shifting count")
545            m = (2**p + n - 1 - ((2**p - 1) % n))//n
546
547            mul_nbits = 2*nbits+1
548            e = ExprSlice(
549                    ExprLShr(
550                        ExprMul(
551                            ExprZX(X, mul_nbits),
552                            ExprCst(m, mul_nbits)),
553                        ExprCst(p, mul_nbits)),
554                    slice(0, nbits, 1))
555            ExprInner.__init__(self,e)
556
557    @property
558    def is_signed(self):
559        return self._is_signed
560
561class ExprRem(ExprBinaryOp):
562    def __init__(self, X, Y, is_signed=False):
563        ExprBinaryOp.__init__(self,X,Y)
564        self._is_signed = is_signed
565
566    @property
567    def is_signed(self):
568        return self._is_signed
569
570# Logical expressions (1-bit)
571class ExprLogical(ExprBinaryOp):
572    @property
573    def nbits(self):
574        return 1
575
576    def eval(self, vec, i, ctx, args, use_esf):
577        assert(i == 0)
578        X,Y = args
579        return self.compute_logical(vec, X, Y, ctx, use_esf)
580
581    @staticmethod
582    def compute_logical(vec, X, Y, ctx, use_esf):
583        raise NotImplementedError()
584
585class ExprCmp(ExprLogical):
586    OpEq, OpNeq, OpLt, OpLte, OpGt, OpGte = list(range(6))
587
588    def __init__(self, op, X, Y, is_signed=False):
589        super(ExprCmp,self).__init__(X,Y)
590        self.op = op
591        self._is_signed = is_signed
592
593    @property
594    def is_signed(self):
595        return self._is_signed
596
597class ExprCmpEq(ExprCmp):
598    def __init__(self, X, Y):
599        super(ExprCmpEq,self).__init__(ExprCmp.OpEq, X, Y)
600
601    @staticmethod
602    def compute_logical(vec, X, Y, ctx, use_esf):
603        nbits = X.nbits
604        e = imm(1)
605        for i in range(nbits):
606            e *= X.eval(vec, i, use_esf) + Y.eval(vec, i, use_esf) + imm(1)
607            simplify_inplace(e)
608        return e
609
610class ExprCmpNeq(ExprCmp):
611    def __init__(self, X, Y):
612        super(ExprCmpNeq,self).__init__(ExprCmp.OpNeq, X, Y)
613
614    @staticmethod
615    def compute_logical(vec, X, Y, ctx, use_esf):
616        return simplify_inplace(ExprCmpEq.compute_logical(vec,X,Y,ctx,use_esf)+imm(1))
617
618class ExprCmpLt(ExprCmp):
619    def __init__(self, *args, **kwargs):
620        super(ExprCmpLt,self).__init__(ExprCmp.OpLt, *args, **kwargs)
621
622class ExprCmpLte(ExprCmp):
623    def __init__(self, *args, **kwargs):
624        super(ExprCmpLte,self).__init__(ExprCmp.OpLte, *args, **kwargs)
625
626class ExprCmpGt(ExprCmp):
627    def __init__(self, *args, **kwargs):
628        super(ExprCmpGt,self).__init__(ExprCmp.OpGt, *args, **kwargs)
629
630class ExprCmpGte(ExprCmp):
631    def __init__(self, *args, **kwargs):
632        super(ExprCmpGte,self).__init__(ExprCmp.OpGte, *args, **kwargs)
633
634# Condition operator
635# res = (cond) ? a:b
636
637class ExprCond(ExprInner, Expr):
638    def __init__(self, cond, a, b):
639        if cond.nbits != 1:
640            raise ValueError("condition must be a one-bit expression")
641        self.cond = cond
642        self.a = a
643        self.b = b
644        self._nbits = a.nbits
645        if self._nbits != self.b.nbits:
646            raise ValueError("a and b must have the same number of bits!")
647
648        cond_broadcast = ExprBroadcast(cond, 0, self._nbits)
649        e = ExprXor(
650            ExprAnd(cond_broadcast, self.a),
651            ExprAnd(ExprNot(cond_broadcast), self.b)
652        )
653        ExprInner.__init__(self,e)
654
655    @property
656    def args(self):
657        return [self.cond,self.a,self.b]
658    @args.setter
659    def args(self, args):
660        self.cond,self.a,self.b = args
661
662    @property
663    def nbits(self):
664        return self._nbits
665
666# Generic visitors
667def visit(e, visitor):
668    def visit_type(e):
669        a = "visit_%s" % e.__name__[4:]
670        return a
671    e_try = collections.deque()
672    e_try.append(e.__class__)
673    cb = None
674    while len(e_try) > 0:
675        cur_ty = e_try.pop()
676        try:
677            cb = getattr(visitor, visit_type(cur_ty))
678            break
679        except AttributeError:
680            e_try.extend((B for B in cur_ty.__bases__ if not B in (object,Expr,ExprInner)))
681    if cb is None:
682        cb = getattr(visitor, "visit_Expr")
683    if hasattr(visitor, "visit_wrapper"):
684        return visitor.visit_wrapper(e, cb)
685    return cb(e)
686
687# Evaluator
688class ExprWithCtx(object):
689    def __init__(self, e, ctx):
690        self.e = e
691        self.ctx = ctx
692        self.args = None
693
694    @property
695    def nbits(self):
696        return self.e.nbits
697
698    def eval(self, vec, i, use_esf):
699        return simplify_inplace(self.e.eval(vec, i, self.ctx, self.args, use_esf))
700
701class CtxWrapper:
702    def __init__(self, v):
703        self.__v = v
704    def get(self): return self.__v
705    def set(self, v): self.__v = v
706
707class _CtxUninitialized:
708    pass
709CtxUninitialized = _CtxUninitialized()
710
711def init_ctxes(E):
712    all_ctxs = dict()
713    def init_ctx(e_):
714        ectx = all_ctxs.get(id(e_), None)
715        if not ectx is None:
716            return ectx
717        if isinstance(e_, ExprInner):
718            einn = e_.inner_expr
719            ectx = init_ctx(einn)
720        else:
721            ectx = ExprWithCtx(e_, CtxWrapper(e_.init_ctx()))
722            args = e_.args
723            if not args is None:
724                ectx.args = [init_ctx(a) for a in args]
725        all_ctxs[id(e_)] = ectx
726        return ectx
727    return init_ctx(E)
728
729def eval_expr(e,use_esf=False):
730    ectx = init_ctxes(e)
731
732    ret = Vector(e.nbits)
733    for i in range(e.nbits):
734        ret[i] = ectx.eval(ret, i, use_esf)
735    mba = MBA(len(ret))
736    return mba.from_vec(ret)
737
738# Prettyprinter
739class PrettyPrinter(object):
740    def visit(self, e):
741        return visit(e, self)
742    def visit_Cst(self, e):
743        return hex(e.n)
744    def visit_BV(self, e):
745        e = e.v
746        if not e.name is None:
747            return e.name
748        estr = ", ".join((str(a) for a in e.vec))
749        return "BV(%s)" % estr
750    def visit_Not(self, e):
751        return "~"+self.visit(e.args[0])
752    def visit_Shl(self, e):
753        return "(%s << %d)" % (self.visit(e.args[0]), e.n)
754    def visit_LShr(self, e):
755        return "(%s l>> %d)" % (self.visit(e.args[0]), e.n)
756    def visit_AShr(self, e):
757        return "(%s a>> %d)" % (self.visit(e.args[0]), e.n)
758    def visit_Rol(self, e):
759        return "rol(%s,%d)" % (self.visit(e.args[0]), e.n)
760    def visit_Ror(self, e):
761        return "rol(%s,%d)" % (self.visit(e.args[0]), e.n)
762    def visit_SX(self, e):
763        return "sx(%d, %s)" % (e.n, self.visit(e.args[0]))
764    def visit_ZX(self, e):
765        return "zx(%d, %s)" % (e.n, self.visit(e.args[0]))
766    def visit_Slice(self, e):
767        idxes = sorted(e.idxes)
768        return "%s[%d:%d]" % (self.visit(e.arg), idxes[0], idxes[-1])
769    def visit_Concat(self, e):
770        return "concat(%s)" % (",".join((self.visit(a) for a in e.args)))
771    def visit_Broadcast(self, e):
772        return "broadcast(%d, %s)" % (e.idx, self.visit(e.arg))
773    def visit_nary_args(self, e, ops):
774        op = ops[type(e)]
775        return "("+(" %s " % op).join(self.visit(a) for a in e.args)+")"
776    def visit_BinaryOp(self, e):
777        ops = {ExprAdd: '+', ExprMul: '*', ExprSub: '-', ExprDiv: '/', ExprRem: '%'}
778        return self.visit_nary_args(e, ops)
779    def visit_NaryOp(self, e):
780        ops = {ExprXor: '^', ExprAnd: '&', ExprOr: '|'}
781        return self.visit_nary_args(e, ops)
782    def visit_Cond(self, e):
783        cond = self.visit(e.cond)
784        a = self.visit(e.a)
785        b = self.visit(e.b)
786        return "(%s) ? (%s) : (%s)" % (cond,a,b)
787    def visit_Cmp(self, e):
788        ops = {
789          ExprCmpEq: '==',
790          ExprCmpNeq: '!=',
791          ExprCmpLt: '<',
792          ExprCmpLte: '<=',
793          ExprCmpGt: '>',
794          ExprCmpGte: '>='
795        }
796        op = ops[type(e)]
797        X = self.visit(e.X)
798        Y = self.visit(e.Y)
799        return "(%s) %s (%s)" % (X, op, Y)
800    def visit_Expr(self, e):
801        raise ValueError("unsupported type %s" % str(type(e)))
802
803def prettyprint(e):
804    ret = PrettyPrinter().visit(e)
805    if ret[0] == '(' and ret[-1] == ')':
806        ret = ret[1:-1]
807    return ret
808