# TODO: replace ExprInner by passes that replace the objects by what's needed
# for arybo

import functools
import operator
import six
import collections

from six.moves import range, reduce

from arybo.lib import MBA, simplify_inplace
from pytanque import imm, expand_esf_inplace, simplify_inplace, Vector, esf

class Expr(object):
    """Base class of symbolic bit-vector expressions.

    A node exposes its width (``nbits``), its sub-expressions (``args``) and
    an ``eval`` method computing one bit of its value. The Python operators
    build new expression nodes; integer operands are converted to ``ExprCst``
    objects of the same width as ``self``.
    """

    @property
    def nbits(self):
        """Width of the expression in bits (implemented by subclasses)."""
        raise NotImplementedError()

    @property
    def args(self):
        """Sub-expressions of this node, or None for a leaf."""
        return None

    def init_ctx(self):
        """Return the initial value of this node's evaluation context.

        The default is None; nodes caching intermediate results override it.
        """
        return None

    def eval(self, vec, i, ctx, args, use_esf):
        """Compute bit ``i`` of this expression.

        :param vec: the result vector being filled by the evaluator
        :param i: index of the requested bit
        :param ctx: holder (``get``/``set``) of this node's context
        :param args: evaluation wrappers of the sub-expressions, each
            exposing ``eval(vec, i, use_esf)``
        :param use_esf: if False, elementary symmetric functions are expanded
        """
        raise NotImplementedError()

    def __parse_arg(self,v):
        # Integers are promoted to constants of the same width as self.
        if isinstance(v, six.integer_types):
            return ExprCst(v, self.nbits)
        if not isinstance(v, Expr):
            raise ValueError("argument must be an integer or an Expr")
        return v

    def __check_arg_int(self,v):
        if not isinstance(v, six.integer_types):
            raise ValueError("argument must be an integer")

    def __add__(self, o):
        o = self.__parse_arg(o)
        return ExprAdd(self, o)

    def __radd__(self, o):
        o = self.__parse_arg(o)
        return ExprAdd(o, self)

    def __sub__(self, o):
        o = self.__parse_arg(o)
        return ExprSub(self, o)

    def __rsub__(self, o):
        o = self.__parse_arg(o)
        return ExprSub(o,self)

    def __mul__(self, o):
        o = self.__parse_arg(o)
        return ExprMul(self, o)

    def __rmul__(self, o):
        o = self.__parse_arg(o)
        return ExprMul(o, self)

    def __xor__(self, o):
        o = self.__parse_arg(o)
        return ExprXor(self, o)

    def __rxor__(self, o):
        o = self.__parse_arg(o)
        return ExprXor(o, self)

    def __and__(self, o):
        o = self.__parse_arg(o)
        return ExprAnd(self, o)

    def __rand__(self, o):
        o = self.__parse_arg(o)
        return ExprAnd(o, self)

    def __or__(self, o):
        o = self.__parse_arg(o)
        return ExprOr(self, o)

    def __ror__(self, o):
        o = self.__parse_arg(o)
        return ExprOr(o, self)

    def __lshift__(self, o):
        o = self.__parse_arg(o)
        return ExprShl(self, o)

    def __rshift__(self, o):
        # ">>" is a logical shift; use ashr() for an arithmetic one.
        return self.lshr(o)

    def lshr(self, o):
        """Logical right shift."""
        o = self.__parse_arg(o)
        return ExprLShr(self, o)

    def ashr(self, o):
        """Arithmetic right shift (the top bit is replicated)."""
        o = self.__parse_arg(o)
        return ExprAShr(self, o)

    def __neg__(self):
        # Two's complement negation: 0 - self.
        return ExprSub(ExprCst(0, self.nbits), self)

    def __invert__(self):
        return ExprNot(self)

    def zext(self,n):
        """Zero-extend to ``n`` bits (``n`` must be an integer)."""
        self.__check_arg_int(n)
        return ExprZX(self,n)

    def sext(self,n):
        """Sign-extend to ``n`` bits (``n`` must be an integer)."""
        self.__check_arg_int(n)
        return ExprSX(self,n)

    def rol(self,o):
        """Rotate left."""
        o = self.__parse_arg(o)
        return ExprRol(self,o)

    def ror(self,o):
        """Rotate right."""
        o = self.__parse_arg(o)
        return ExprRor(self,o)

    def udiv(self,o):
        """Unsigned division."""
        o = self.__parse_arg(o)
        return ExprDiv(self, o, is_signed=False)

    def sdiv(self,o):
        """Signed division."""
        o = self.__parse_arg(o)
        return ExprDiv(self, o, is_signed=True)

    def urem(self, o):
        """Unsigned remainder."""
        o = self.__parse_arg(o)
        return ExprRem(self, o, is_signed=False)

    def srem(self, o):
        """Signed remainder."""
        o = self.__parse_arg(o)
        return ExprRem(self, o, is_signed=True)

    def __getitem__(self,s):
        # Only slices are accepted (ExprSlice further requires a step of 1).
        if not isinstance(s, slice):
            raise ValueError("can only get slices")
        return ExprSlice(self, s)

# Leaves
class ExprCst(Expr):
    """Constant leaf of ``nbits`` bits.

    ``n`` is reduced modulo ``2**nbits``; negative integers are therefore
    stored in two's complement (e.g. ``ExprCst(-1, 8).n == 0xff``).
    """
    def __init__(self, n, nbits):
        # Python's "&" on a negative int yields its two's complement bits,
        # which is exactly the modular reduction wanted here.
        self.n = n & ((1<<nbits)-1)
        self.__nbits = nbits

    @property
    def nbits(self):
        return self.__nbits

    def eval(self, vec, i, ctx, args, use_esf):
        return imm((self.n>>i)&1)

    def to_cst(self):
        return self.n

    @staticmethod
    def get_cst(obj,nbits=None):
        """Return the integer value of ``obj`` (an ExprCst or an integer).

        If ``nbits`` is given, the value is reduced modulo ``2**nbits``.

        :raises ValueError: if ``obj`` is neither an ExprCst nor an integer
        """
        if isinstance(obj, ExprCst):
            ret = obj.n
        elif isinstance(obj, six.integer_types):
            ret = obj
        else:
            raise ValueError("obj must be an ExprCst or an integer")
        if nbits is None:
            return ret
        return ExprCst(ret,nbits).n

class ExprBV(Expr):
    """Leaf wrapping an existing bit-vector ``v``; bit i is ``v.vec[i]``."""
    def __init__(self, v):
        self.v = v

    @property
    def nbits(self):
        return self.v.nbits

    def eval(self, vec, i, ctx, args, use_esf):
        return self.v.vec[i]

    def to_cst(self):
        return self.v.to_cst()

# Unary ops
class ExprUnaryOp(Expr):
    """Base class of single-argument operations (same width as ``arg``)."""
    def __init__(self, arg):
        self.arg = arg

    @property
    def args(self):
        return [self.arg]

    @args.setter
    def args(self, args):
        self.arg = args[0]

    @property
    def nbits(self):
        return self.arg.nbits

class ExprNot(ExprUnaryOp):
    """Bitwise NOT."""
    def eval(self, vec, i, ctx, args, use_esf):
        # Adding 1 flips the bit ("+" is XOR on these bits, cf. ExprXor).
        return args[0].eval(vec, i, use_esf) + imm(True)

# Nary ops
class ExprNaryOp(Expr):
    """Base class of operations taking any number of arguments.

    Subclasses implement ``compute``, which receives the bits ``i`` of the
    arguments as a lazy iterable.
    """
    def __init__(self, *args):
        self._args = args

    @property
    def args(self):
        return self._args

    @args.setter
    def args(self, args):
        self._args = args

    @property
    def nbits(self):
        # TODO assert every args has the same size
        return self.args[0].nbits

    def compute(self, vec, i, args, ctx, use_esf):
        raise NotImplementedError()

    def eval(self, vec, i, ctx, args, use_esf):
        # A generator: compute() may return early without evaluating them.
        args = (a.eval(vec, i, use_esf) for a in args)
        return self.compute(vec, i, args, ctx, use_esf)

# Binary ops
# We can't implement this as an NaryOp, because we need one context per binary
# operation (and in this case, they would share the same context, leading to
# incorrect results).
class ExprBinaryOp(Expr):
    """Base class of operations on two operands ``X`` and ``Y`` of equal width.

    :raises ValueError: if ``X`` and ``Y`` do not have the same width
    """
    def __init__(self, X, Y):
        self._nbits = X.nbits
        if (self._nbits != Y.nbits):
            raise ValueError("X and Y must have the same number of bits!")
        self.X = X
        self.Y = Y

    @property
    def args(self):
        return [self.X,self.Y]

    @args.setter
    def args(self, args):
        self.X,self.Y = args

    @property
    def nbits(self):
        return self._nbits

    def eval(self, vec, i, ctx, args, use_esf):
        # Evaluate bit i of both operands, then combine them.
        X,Y = args
        X = X.eval(vec, i, use_esf)
        Y = Y.eval(vec, i, use_esf)
        return self.compute_binop(vec, i, X, Y, ctx, use_esf)

    @staticmethod
    def compute_binop(vec, i, X, Y, ctx, use_esf):
        raise NotImplementedError()

# Nary ops
class ExprXor(ExprNaryOp):
    """Bitwise XOR of all the arguments."""
    def compute(self, vec, i, args, ctx, use_esf):
        return sum(args, imm(0))

class ExprAnd(ExprNaryOp):
    """Bitwise AND of all the arguments.

    Constant arguments are folded into ``mask`` at construction time: a bit
    cleared in the mask evaluates to 0 without consuming the arguments.
    """
    def __init__(self, *args):
        super(ExprAnd,self).__init__(*args)
        self.mask = (1<<self.nbits)-1
        # NOTE(review): not read anywhere in this file; kept in case external
        # code relies on the attribute.
        self._rem_args = []
        for a in args:
            if isinstance(a, ExprCst):
                self.mask &= a.n

    def compute(self, vec, i, args, ctx, use_esf):
        if ((self.mask >> i) & 1) == 0:
            return imm(0)
        # The AND of bits is their product.
        return reduce(operator.mul, args)

class ExprOr(ExprNaryOp):
    """Bitwise OR of all the arguments.

    Constant arguments are folded into ``mask``: a bit set in the mask
    evaluates to 1 directly. Otherwise the OR of k bits is computed as the
    sum of their elementary symmetric functions of degree 1 to k.
    """
    def __init__(self, *args):
        super(ExprOr,self).__init__(*args)
        self.mask = 0
        for a in args:
            if isinstance(a, ExprCst):
                self.mask |= a.n

    def compute(self, vec, i, args, ctx, use_esf):
        if ((self.mask >> i) & 1) == 1:
            return imm(1)
        args = list(args)
        ret = esf(1, args)
        # Dedicated loop name so that the bit index "i" is not shadowed.
        for degree in range(2, len(args)+1):
            ret += esf(degree, args)
        if not use_esf:
            expand_esf_inplace(ret)
            simplify_inplace(ret)
        return ret

# Binary shifts
class ExprShl(ExprBinaryOp):
    """Left shift of ``X`` by the constant ``Y``."""
    @property
    def n(self):
        """Shift amount; get_cst raises ValueError if ``Y`` is not constant."""
        return ExprCst.get_cst(self.Y, self.nbits)

    def eval(self, vec, i, ctx, args, use_esf):
        n = self.n
        if i < n:
            return imm(False)
        return args[0].eval(vec, i-n, use_esf)

class ExprLShr(ExprBinaryOp):
    """Logical right shift of ``X`` by the constant ``Y``."""
    @property
    def n(self):
        return ExprCst.get_cst(self.Y, self.nbits)

    def eval(self, vec, i, ctx, args, use_esf):
        n = self.n
        if i >= self.nbits-n:
            return imm(False)
        return args[0].eval(vec, i+n, use_esf)

class ExprAShr(ExprBinaryOp):
    """Arithmetic right shift of ``X`` by the constant ``Y``."""
    @property
    def n(self):
        return ExprCst.get_cst(self.Y, self.nbits)

    def init_ctx(self):
        # The context caches the top bit of X once it has been evaluated.
        return CtxUninitialized

    def eval(self, vec, i, ctx, args, use_esf):
        a = args[0]
        n = self.n
        # Let's cache the last bit we need to propagate
        if i >= self.nbits-n:
            last_bit = ctx.get()
            if last_bit is CtxUninitialized:
                last_bit = a.eval(vec, self.nbits-1, use_esf)
                ctx.set(last_bit)
            return last_bit
        return a.eval(vec, i+n, use_esf)

class ExprRol(ExprBinaryOp):
    """Left rotation of ``X`` by the constant ``Y``."""
    @property
    def n(self):
        return ExprCst.get_cst(self.Y, self.nbits)

    def eval(self, vec, i, ctx, args, use_esf):
        return args[0].eval(vec, (i-self.n)%self.nbits, use_esf)

class ExprRor(ExprBinaryOp):
    """Right rotation of ``X`` by the constant ``Y``."""
    @property
    def n(self):
        return ExprCst.get_cst(self.Y, self.nbits)

    def eval(self, vec, i, ctx, args, use_esf):
        return args[0].eval(vec, (i+self.n)%self.nbits, use_esf)

# Concat/slice/{z,s}ext/broadcast

class ExprExtend(ExprUnaryOp):
    """Base class of extensions of ``arg`` to ``n`` bits.

    :raises ValueError: if ``n`` is smaller than the width of ``arg``
    """
    def __init__(self, arg, n):
        super(ExprExtend, self).__init__(arg)
        self.n = ExprCst.get_cst(n)
        self.arg_nbits = self.arg.nbits
        # Compare with the argument's width: self.nbits is self.n itself.
        if self.n < self.arg_nbits:
            raise ValueError("cannot extend a %d-bit expression to %d bits"
                             % (self.arg_nbits, self.n))

    @property
    def nbits(self):
        return self.n

class ExprSX(ExprExtend):
    """Sign extension: bits from the argument's top bit upwards repeat it."""
    def init_ctx(self):
        # The context caches the argument's top bit once evaluated.
        return CtxUninitialized

    def eval(self, vec, i, ctx, args, use_esf):
        arg = args[0]
        if (i >= (self.arg_nbits-1)):
            last_bit = ctx.get()
            if last_bit is CtxUninitialized:
                last_bit = arg.eval(vec, self.arg_nbits-1, use_esf)
                ctx.set(last_bit)
            return last_bit
        return arg.eval(vec, i, use_esf)

class ExprZX(ExprExtend):
    """Zero extension: bits at or above the argument's width are 0."""
    def eval(self, vec, i, ctx, args, use_esf):
        if (i >= self.arg_nbits):
            return imm(0)
        return args[0].eval(vec, i, use_esf)

class ExprSlice(ExprUnaryOp):
    """Slice of ``arg`` following Python slice semantics (step 1 only).

    :raises ValueError: if ``slice_`` is not a slice or has a step != 1
    """
    def __init__(self, arg, slice_):
        super(ExprSlice, self).__init__(arg)
        if not isinstance(slice_, slice):
            raise ValueError("slice_ must a slice object")
        if (not slice_.step is None) and (slice_.step != 1):
            raise ValueError("only slices with a step of 1 are supported!")
        self.idxes = list(range(*slice_.indices(self.arg.nbits)))

    @property
    def nbits(self):
        return len(self.idxes)

    def eval(self, vec, i, ctx, args, use_esf):
        return args[0].eval(vec, self.idxes[i], use_esf)

class ExprConcat(ExprNaryOp):
    """Concatenation of the arguments; the first one holds the low bits."""
    @property
    def nbits(self):
        return sum((a.nbits for a in self.args))

    def eval(self, vec, i, ctx, args, use_esf):
        # Skip whole arguments until bit i falls inside the current one.
        it = iter(args)
        cur_arg = next(it)
        cur_len = cur_arg.nbits
        while i >= cur_len:
            i -= cur_len
            cur_arg = next(it)
            cur_len = cur_arg.nbits
        return cur_arg.eval(vec, i, use_esf)

class ExprBroadcast(ExprUnaryOp):
    """Replicate bit ``idx`` of ``arg`` over ``nbits`` bits."""
    def __init__(self, arg, idx, nbits):
        super(ExprBroadcast, self).__init__(arg)
        assert(idx >= 0)
        self._nbits = ExprCst.get_cst(nbits)
        self.idx = ExprCst.get_cst(idx,arg.nbits)

    def init_ctx(self):
        # The context caches the broadcast bit once evaluated.
        return CtxUninitialized

    @property
    def nbits(self):
        return self._nbits

    def eval(self, vec, i, ctx, args, use_esf):
        ret = ctx.get()
        if ret is CtxUninitialized:
            ret = args[0].eval(vec, self.idx, use_esf)
            ctx.set(ret)
        return ret

# Arithmetic ops
class ExprBinopCarry(ExprBinaryOp):
    """Base class of operations propagating a carry from bit 0 upwards.

    Bits are computed in increasing order and cached in a ``CtxCache``:
    requesting bit i also computes every lower bit not computed yet.
    Subclasses implement ``compute_binop_``, which reads and updates
    ``CC.carry``.
    """
    class CtxCache:
        # cache: result bits computed so far; last_bit: highest computed
        # index (-1 initially); carry: carry flowing into bit last_bit+1.
        def __init__(self, nbits):
            self.cache = [CtxUninitialized]*nbits
            self.last_bit = -1
            self.carry = imm(0)

    def init_ctx(self):
        return ExprBinopCarry.CtxCache(self.nbits)

    def eval(self, vec, i, ctx, args, use_esf):
        CC = ctx.get()
        ret = CC.cache[i]
        if not ret is CtxUninitialized:
            return ret
        if i < CC.last_bit:
            raise ValueError("asking for a bit before the last computed bit. This should not happen!")
        X,Y = args
        for j in range(CC.last_bit+1, i+1):
            a = X.eval(vec, j, use_esf)
            b = Y.eval(vec, j, use_esf)
            CC.cache[j] = self.compute_binop_(vec, j, a, b, CC, use_esf)
        CC.last_bit = i
        return CC.cache[i]

    @staticmethod
    def compute_binop_(vec, i, X, Y, ctx, use_esf):
        raise NotImplementedError()

class ExprAdd(ExprBinopCarry):
    """Addition modulo 2**nbits (ripple-carry)."""
    @staticmethod
    def compute_binop_(vec, i, X, Y, CC, use_esf):
        carry = CC.carry

        sum_args = simplify_inplace(X+Y)
        ret = simplify_inplace(sum_args + carry)
        # TODO: optimize this like in mba_if
        # Carry out = majority(X, Y, carry) = esf of degree 2.
        carry = esf(2, [X, Y, carry])
        if not use_esf:
            expand_esf_inplace(carry)
            simplify_inplace(carry)

        CC.carry = carry
        return ret

class ExprSub(ExprBinopCarry):
    """Subtraction modulo 2**nbits (ripple-borrow)."""
    @staticmethod
    def compute_binop_(vec, i, X, Y, CC, use_esf):
        carry = CC.carry

        sum_args = simplify_inplace(X+Y)
        ret = simplify_inplace(sum_args + carry)
        # Borrow out = majority(~X, Y, borrow in).
        carry = esf(2, [X+imm(1), Y, carry])
        if not use_esf:
            expand_esf_inplace(carry)
            carry = simplify_inplace(carry)

        CC.carry = carry
        return ret

class ExprInner(object):
    """Mixin for nodes defined by an equivalent inner expression.

    Evaluation is delegated to the context built for ``inner_expr``.
    """
    def __init__(self, e):
        self.inner_expr = e

    def eval(self, vec, i, ctx, args, use_esf):
        return ctx.eval(vec, i, use_esf)

# x*y = x*(y0+y1<<1+y2<<2+...)
class ExprMul(ExprInner, ExprBinaryOp):
    """Multiplication modulo 2**nbits, expanded into shifts, ANDs and adds."""
    def __init__(self, X, Y):
        ExprBinaryOp.__init__(self,X,Y)
        nbits = X.nbits
        # Partial product i is (X << i) masked by bit i of Y broadcast over
        # every position; the partial products are then summed.
        e = ExprAnd(X, ExprBroadcast(Y, 0, nbits))
        for i in range(1, nbits):
            e = ExprAdd(
                e,
                ExprAnd(ExprShl(X, ExprCst(i,nbits)), ExprBroadcast(Y, i, nbits)))
        ExprInner.__init__(self,e)

class ExprDiv(ExprInner, ExprBinaryOp):
    """Division of ``X`` by ``n``.

    Only unsigned division by a constant is supported: it is expanded into a
    multiplication by a "magic" number followed by a right shift, computed
    in ``2*nbits+1`` bits.

    :raises NotImplementedError: for signed or non-constant divisors
    :raises ZeroDivisionError: if the divisor is 0
    :raises RuntimeError: if no suitable shift count is found
    """
    def __init__(self, X, n, is_signed=False):
        ExprBinaryOp.__init__(self,X,n)
        nbits = X.nbits
        self._is_signed = is_signed

        # Arybo specific
        if self._is_signed or not isinstance(n, ExprCst):
            # Previously this path failed later with an UnboundLocalError.
            raise NotImplementedError(
                "only unsigned division by a constant is supported")
        d = ExprCst.get_cst(n,nbits)
        if d == 0:
            raise ZeroDivisionError("division by zero")
        # Integer floor division: "/" would give a float under Python 3,
        # dropping the floor and losing precision on wide operands.
        nc = ((2**nbits)//d)*d - 1
        for p in range(nbits, 2*nbits+1):
            if(2**p > nc*(d - 1 - ((2**p - 1) % d))):
                break
        else:
            raise RuntimeError("division: unable to find the shifting count")
        m = (2**p + d - 1 - ((2**p - 1) % d))//d

        # X*m needs up to 2*nbits+1 bits; the quotient is the low nbits bits
        # of (X*m) >> p.
        mul_nbits = 2*nbits+1
        e = ExprSlice(
            ExprLShr(
                ExprMul(
                    ExprZX(X, mul_nbits),
                    ExprCst(m, mul_nbits)),
                ExprCst(p, mul_nbits)),
            slice(0, nbits, 1))
        ExprInner.__init__(self,e)

    @property
    def is_signed(self):
        return self._is_signed

class ExprRem(ExprBinaryOp):
    """Remainder of ``X`` by ``Y``.

    No evaluation is implemented here: the inherited compute_binop raises
    NotImplementedError.
    """
    def __init__(self, X, Y, is_signed=False):
        ExprBinaryOp.__init__(self,X,Y)
        self._is_signed = is_signed

    @property
    def is_signed(self):
        return self._is_signed

# Logical expressions (1-bit)
class ExprLogical(ExprBinaryOp):
    """Base class of one-bit results computed from two operands.

    Subclasses implement ``compute_logical``, which receives the operands'
    evaluation wrappers (not single bits).
    """
    @property
    def nbits(self):
        return 1

    def eval(self, vec, i, ctx, args, use_esf):
        assert(i == 0)
        X,Y = args
        return self.compute_logical(vec, X, Y, ctx, use_esf)

    @staticmethod
    def compute_logical(vec, X, Y, ctx, use_esf):
        raise NotImplementedError()

class ExprCmp(ExprLogical):
    """Comparison of ``X`` and ``Y``; ``op`` is one of the Op* constants."""
    OpEq, OpNeq, OpLt, OpLte, OpGt, OpGte = list(range(6))

    def __init__(self, op, X, Y, is_signed=False):
        super(ExprCmp,self).__init__(X,Y)
        self.op = op
        self._is_signed = is_signed

    @property
    def is_signed(self):
        return self._is_signed

class ExprCmpEq(ExprCmp):
    """1 iff every bit of X equals the matching bit of Y."""
    def __init__(self, X, Y):
        super(ExprCmpEq,self).__init__(ExprCmp.OpEq, X, Y)

    @staticmethod
    def compute_logical(vec, X, Y, ctx, use_esf):
        nbits = X.nbits
        e = imm(1)
        for i in range(nbits):
            # X_i + Y_i + 1 is 1 exactly when the two bits are equal.
            e *= X.eval(vec, i, use_esf) + Y.eval(vec, i, use_esf) + imm(1)
            simplify_inplace(e)
        return e

class ExprCmpNeq(ExprCmp):
    """Negation of ExprCmpEq."""
    def __init__(self, X, Y):
        super(ExprCmpNeq,self).__init__(ExprCmp.OpNeq, X, Y)

    @staticmethod
    def compute_logical(vec, X, Y, ctx, use_esf):
        return simplify_inplace(ExprCmpEq.compute_logical(vec,X,Y,ctx,use_esf)+imm(1))

class ExprCmpLt(ExprCmp):
    def __init__(self, *args, **kwargs):
        super(ExprCmpLt,self).__init__(ExprCmp.OpLt, *args, **kwargs)

class ExprCmpLte(ExprCmp):
    def __init__(self, *args, **kwargs):
        super(ExprCmpLte,self).__init__(ExprCmp.OpLte, *args, **kwargs)

class ExprCmpGt(ExprCmp):
    def __init__(self, *args, **kwargs):
        super(ExprCmpGt,self).__init__(ExprCmp.OpGt, *args, **kwargs)

class ExprCmpGte(ExprCmp):
    def __init__(self, *args, **kwargs):
        super(ExprCmpGte,self).__init__(ExprCmp.OpGte, *args, **kwargs)
# Condition operator
# res = (cond) ? a:b

class ExprCond(ExprInner, Expr):
    """Ternary operator ``cond ? a : b``.

    ``cond`` must be a one-bit expression, and ``a`` and ``b`` must have the
    same width. The node is expanded into
    ``(broadcast(cond) & a) ^ (~broadcast(cond) & b)``.

    :raises ValueError: on a wrong condition width or mismatched a/b widths
    """
    def __init__(self, cond, a, b):
        if cond.nbits != 1:
            raise ValueError("condition must be a one-bit expression")
        self.cond = cond
        self.a = a
        self.b = b
        self._nbits = a.nbits
        if self._nbits != self.b.nbits:
            raise ValueError("a and b must have the same number of bits!")

        # Reusing one broadcast node means both uses share a single
        # evaluation context (init_ctxes memoizes by object identity).
        cond_broadcast = ExprBroadcast(cond, 0, self._nbits)
        e = ExprXor(
            ExprAnd(cond_broadcast, self.a),
            ExprAnd(ExprNot(cond_broadcast), self.b)
        )
        ExprInner.__init__(self,e)

    @property
    def args(self):
        return [self.cond,self.a,self.b]
    @args.setter
    def args(self, args):
        self.cond,self.a,self.b = args

    @property
    def nbits(self):
        return self._nbits

# Generic visitors
def visit(e, visitor):
    """Dispatch ``e`` to the most specific ``visit_<Name>`` method of visitor.

    ``<Name>`` is the class name without its 4-character ``Expr`` prefix.
    When no method matches the class of ``e``, its base classes are tried
    (``object``, ``Expr`` and ``ExprInner`` excluded), falling back to
    ``visit_Expr``. If the visitor defines ``visit_wrapper``, it is called as
    ``visit_wrapper(e, callback)`` instead of the callback itself.
    """
    def visit_type(ty):
        return "visit_%s" % ty.__name__[4:]
    e_try = collections.deque()
    e_try.append(e.__class__)
    cb = None
    while len(e_try) > 0:
        cur_ty = e_try.pop()
        try:
            cb = getattr(visitor, visit_type(cur_ty))
            break
        except AttributeError:
            e_try.extend((B for B in cur_ty.__bases__ if not B in (object,Expr,ExprInner)))
    if cb is None:
        cb = getattr(visitor, "visit_Expr")
    if hasattr(visitor, "visit_wrapper"):
        return visitor.visit_wrapper(e, cb)
    return cb(e)

# Evaluator
class ExprWithCtx(object):
    """Evaluation wrapper: an expression, its context holder and the
    wrappers of its arguments (``args``, filled by init_ctxes)."""
    def __init__(self, e, ctx):
        self.e = e
        self.ctx = ctx
        self.args = None

    @property
    def nbits(self):
        return self.e.nbits

    def eval(self, vec, i, use_esf):
        return simplify_inplace(self.e.eval(vec, i, self.ctx, self.args, use_esf))

class CtxWrapper:
    """Mutable holder of a node's evaluation context."""
    def __init__(self, v):
        self.__v = v
    def get(self): return self.__v
    def set(self, v): self.__v = v

class _CtxUninitialized:
    """Type of the sentinel marking a context not computed yet."""
    pass
CtxUninitialized = _CtxUninitialized()

def init_ctxes(E):
    """Build the ExprWithCtx tree used to evaluate ``E`` and return its root.

    Wrappers are memoized by object identity, so a sub-expression shared by
    several parents gets one context and its cached results are reused.
    An ExprInner node is mapped to the wrapper of its inner expression.
    """
    all_ctxs = dict()
    def init_ctx(e_):
        ectx = all_ctxs.get(id(e_), None)
        if not ectx is None:
            return ectx
        if isinstance(e_, ExprInner):
            einn = e_.inner_expr
            ectx = init_ctx(einn)
        else:
            ectx = ExprWithCtx(e_, CtxWrapper(e_.init_ctx()))
            args = e_.args
            if not args is None:
                ectx.args = [init_ctx(a) for a in args]
        all_ctxs[id(e_)] = ectx
        return ectx
    return init_ctx(E)

def eval_expr(e,use_esf=False):
    """Evaluate ``e`` bit by bit, from bit 0 upwards.

    :param use_esf: forwarded to every node; if False, elementary symmetric
        functions are expanded
    :returns: ``MBA(e.nbits).from_vec(...)`` applied to the computed bits
    """
    ectx = init_ctxes(e)

    ret = Vector(e.nbits)
    for i in range(e.nbits):
        ret[i] = ectx.eval(ret, i, use_esf)
    mba = MBA(len(ret))
    return mba.from_vec(ret)

# Prettyprinter
class PrettyPrinter(object):
    """Visitor turning an expression into a human-readable string."""
    def visit(self, e):
        return visit(e, self)
    def visit_Cst(self, e):
        return hex(e.n)
    def visit_BV(self, e):
        e = e.v
        if not e.name is None:
            return e.name
        estr = ", ".join((str(a) for a in e.vec))
        return "BV(%s)" % estr
    def visit_Not(self, e):
        return "~"+self.visit(e.args[0])
    def visit_Shl(self, e):
        return "(%s << %d)" % (self.visit(e.args[0]), e.n)
    def visit_LShr(self, e):
        return "(%s l>> %d)" % (self.visit(e.args[0]), e.n)
    def visit_AShr(self, e):
        return "(%s a>> %d)" % (self.visit(e.args[0]), e.n)
    def visit_Rol(self, e):
        return "rol(%s,%d)" % (self.visit(e.args[0]), e.n)
    def visit_Ror(self, e):
        return "ror(%s,%d)" % (self.visit(e.args[0]), e.n)
    def visit_SX(self, e):
        return "sx(%d, %s)" % (e.n, self.visit(e.args[0]))
    def visit_ZX(self, e):
        return "zx(%d, %s)" % (e.n, self.visit(e.args[0]))
    def visit_Slice(self, e):
        # Both bounds are printed inclusively.
        idxes = sorted(e.idxes)
        return "%s[%d:%d]" % (self.visit(e.arg), idxes[0], idxes[-1])
    def visit_Concat(self, e):
        return "concat(%s)" % (",".join((self.visit(a) for a in e.args)))
    def visit_Broadcast(self, e):
        return "broadcast(%d, %s)" % (e.idx, self.visit(e.arg))
    def visit_nary_args(self, e, ops):
        op = ops[type(e)]
        return "("+(" %s " % op).join(self.visit(a) for a in e.args)+")"
    def visit_BinaryOp(self, e):
        ops = {ExprAdd: '+', ExprMul: '*', ExprSub: '-', ExprDiv: '/', ExprRem: '%'}
        return self.visit_nary_args(e, ops)
    def visit_NaryOp(self, e):
        ops = {ExprXor: '^', ExprAnd: '&', ExprOr: '|'}
        return self.visit_nary_args(e, ops)
    def visit_Cond(self, e):
        cond = self.visit(e.cond)
        a = self.visit(e.a)
        b = self.visit(e.b)
        return "(%s) ? (%s) : (%s)" % (cond,a,b)
    def visit_Cmp(self, e):
        ops = {
            ExprCmpEq: '==',
            ExprCmpNeq: '!=',
            ExprCmpLt: '<',
            ExprCmpLte: '<=',
            ExprCmpGt: '>',
            ExprCmpGte: '>='
        }
        op = ops[type(e)]
        X = self.visit(e.X)
        Y = self.visit(e.Y)
        return "(%s) %s (%s)" % (X, op, Y)
    def visit_Expr(self, e):
        raise ValueError("unsupported type %s" % str(type(e)))

def _is_wrapped_in_parens(s):
    """Return True if the opening parenthesis at s[0] is closed at s[-1]."""
    if len(s) < 2 or s[0] != '(' or s[-1] != ')':
        return False
    depth = 0
    for pos, c in enumerate(s):
        if c == '(':
            depth += 1
        elif c == ')':
            depth -= 1
            if depth == 0:
                # The first group closes here; it must be the last character.
                return pos == len(s) - 1
    return False

def prettyprint(e):
    """Return a readable string for ``e``, without redundant outer parens."""
    ret = PrettyPrinter().visit(e)
    # Only strip when the outer parentheses match each other: "(a) == (b)"
    # must not become "a) == (b".
    if _is_wrapped_in_parens(ret):
        ret = ret[1:-1]
    return ret