#!/usr/bin/env python
# coding: utf-8
"""Test AST conversion."""

import operator
import random
import unittest
import utils

from triton import *


class TestAstConversion(unittest.TestCase):

    """Testing the AST conversion Triton <-> z3."""

    def setUp(self):
        """Define the arch."""
        self.Triton = TritonContext()
        self.Triton.setArchitecture(ARCH.X86_64)

        self.astCtxt = self.Triton.getAstContext()

        # Two 8-bit symbolic variables and the AST variable nodes built on
        # them; every test below drives them via setConcreteVariableValue.
        self.sv1 = self.Triton.newSymbolicVariable(8)
        self.sv2 = self.Triton.newSymbolicVariable(8)

        self.v1 = self.astCtxt.variable(self.sv1)
        self.v2 = self.astCtxt.variable(self.sv2)

    def test_binop(self):
        """
        Check python binary operation.

        Fuzz uint8 operand values, apply each overloaded python operator to
        the Triton AST variables and compare the Triton evaluation, the z3
        evaluation and the evaluation of the simplified AST against the
        python reference result reduced modulo 2**8.
        """
        # No simplification available
        # This only going to test Triton <-> z3 AST conversions.
        binop = [
            # Overloaded operators
            operator.and_,
            operator.add,
            operator.sub,
            operator.xor,
            operator.or_,
            operator.mul,
            operator.lshift,
            operator.rshift,
            operator.eq,
            operator.ne,
            operator.le,
            operator.ge,
            operator.lt,
            operator.gt,
            operator.floordiv,
            operator.mod,
        ]
        # Python 2 also provides the classic `operator.div`; exercise it
        # there. On Python 3 `div` does not exist and floordiv is already in
        # the list above, so there is nothing to add.
        operator_div = operator.floordiv
        if hasattr(operator, "div"):
            operator_div = operator.div
            binop.append(operator_div)

        for _ in range(100):
            cv1 = random.randint(0, 255)
            cv2 = random.randint(0, 255)
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            self.Triton.setConcreteVariableValue(self.sv2, cv2)
            for op in binop:
                n = op(self.v1, self.v2)
                # Python raises on division by zero, so the expected values
                # are spelled out: the test expects 255 for x / 0 and x
                # itself for x % 0.
                if op in (operator.floordiv, operator_div) and cv2 == 0:
                    ref = 255
                elif op == operator.mod and cv2 == 0:
                    ref = cv1
                else:
                    ref = op(cv1, cv2) % (2 ** 8)
                self.assertEqual(
                    ref,
                    n.evaluate(),
                    "ref = {} and triton value = {} with operator {} operands were {} and {}".format(ref, n.evaluate(), op, cv1, cv2)
                )
                self.assertEqual(ref, self.Triton.evaluateAstViaZ3(n))
                self.assertEqual(ref, self.Triton.simplify(n, True).evaluate())

    def test_unop(self):
        """
        Check python unary operation.

        Exhaustively try every uint8 value with each overloaded unary
        operator and check triton/z3/simplified results against python.
        """
        # No simplification available
        # This only going to test Triton <-> z3 AST conversions.
        unop = [
            operator.invert,
            operator.neg,
        ]

        for cv1 in range(0, 256):
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            for op in unop:
                n = op(self.v1)
                ref = op(cv1) % (2 ** 8)
                self.assertEqual(ref, n.evaluate(),
                                 "ref = {} and triton value = {} with operator "
                                 "{} operands was {}".format(ref,
                                                             n.evaluate(),
                                                             op,
                                                             cv1))
                self.assertEqual(ref, self.Triton.evaluateAstViaZ3(n))
                self.assertEqual(ref, self.Triton.simplify(n, True).evaluate())

    def test_smtbinop(self):
        """
        Check smt binary operation.

        Fuzz uint8 operand values and check that Triton evaluation agrees
        with z3 evaluation and with evaluation of the simplified AST.
        """
        # No simplification available
        # This only going to test Triton <-> z3 AST conversions.
        smtbinop = [
            # AST API
            self.astCtxt.bvadd,
            self.astCtxt.bvand,
            self.astCtxt.bvlshr,
            self.astCtxt.bvashr,
            self.astCtxt.bvmul,
            self.astCtxt.bvnand,
            self.astCtxt.bvnor,
            self.astCtxt.bvor,
            self.astCtxt.bvsdiv,
            self.astCtxt.bvsge,
            self.astCtxt.bvsgt,
            self.astCtxt.bvshl,
            self.astCtxt.bvsle,
            self.astCtxt.bvslt,
            self.astCtxt.bvsmod,
            self.astCtxt.bvsrem,
            self.astCtxt.bvsub,
            self.astCtxt.bvudiv,
            self.astCtxt.bvuge,
            self.astCtxt.bvugt,
            self.astCtxt.bvule,
            self.astCtxt.bvult,
            self.astCtxt.bvurem,
            self.astCtxt.bvxnor,
            self.astCtxt.bvxor,
            self.astCtxt.concat,
            self.astCtxt.distinct,
            self.astCtxt.equal,
            self.astCtxt.iff,
            self.astCtxt.land,
            self.astCtxt.lor,
            self.astCtxt.lxor,
        ]

        for _ in range(100):
            cv1 = random.randint(0, 255)
            cv2 = random.randint(0, 255)
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            self.Triton.setConcreteVariableValue(self.sv2, cv2)
            for op in smtbinop:
                # concat and the n-ary logical operators take a list of
                # operands; the logical ones (and iff) need boolean operands.
                if op == self.astCtxt.concat:
                    n = op([self.v1, self.v2])
                elif op in (self.astCtxt.land, self.astCtxt.lor, self.astCtxt.lxor):
                    n = op([self.v1 != cv1, self.v2 != cv2])
                elif op == self.astCtxt.iff:
                    n = op(self.v1 > cv1, self.v2 < cv2)
                else:
                    n = op(self.v1, self.v2)
                self.assertEqual(
                    n.evaluate(),
                    self.Triton.evaluateAstViaZ3(n),
                    "triton = {} and z3 = {} with operator {} operands were {} and {}".format(n.evaluate(), self.Triton.evaluateAstViaZ3(n), op, cv1, cv2)
                )
                self.assertEqual(
                    n.evaluate(),
                    self.Triton.simplify(n, True).evaluate(),
                    "triton = {} and simplified = {} with operator {} operands were {} and {}".format(n.evaluate(), self.Triton.simplify(n, True).evaluate(), op, cv1, cv2)
                )

    def test_smt_unop(self):
        """
        Check smt unary operation.

        Exhaustively try every uint8 value with each unary AST operator and
        check that Triton, z3 and simplified evaluations agree.
        """
        # No simplification available
        # This only going to test Triton <-> z3 AST conversions.
        smtunop = [
            self.astCtxt.bvneg,
            self.astCtxt.bvnot,
            self.astCtxt.lnot,
            lambda x: self.astCtxt.bvrol(x, self.astCtxt.bv(2, x.getBitvectorSize())),
            lambda x: self.astCtxt.bvror(x, self.astCtxt.bv(3, x.getBitvectorSize())),
            lambda x: self.astCtxt.sx(16, x),
            lambda x: self.astCtxt.zx(16, x),
        ]

        for cv1 in range(0, 256):
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            for op in smtunop:
                # lnot needs a boolean operand.
                if op == self.astCtxt.lnot:
                    n = op(self.v1 != 0)
                else:
                    n = op(self.v1)
                self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))
                self.assertEqual(n.evaluate(), self.Triton.simplify(n, True).evaluate())

    def test_bvnode(self):
        """Check python bit vector declaration."""
        for _ in range(100):
            cv1 = random.randint(-127, 255)
            n = self.astCtxt.bv(cv1, 8)
            self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))
            self.assertEqual(n.evaluate(), self.Triton.simplify(n, True).evaluate())

    def test_extract(self):
        """Check bit extraction from bitvector."""
        for _ in range(100):
            cv1 = random.randint(0, 255)
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            for lo in range(0, 8):
                for hi in range(lo, 8):
                    n = self.astCtxt.extract(hi, lo, self.v1)
                    # Shift left to drop bits above `hi` (mod 256 keeps 8
                    # bits), then shift right to drop bits below `lo`.
                    ref = ((cv1 << (7 - hi)) % 256) >> (7 - hi + lo)
                    self.assertEqual(ref, n.evaluate(),
                                     "ref = {} and triton value = {} with operator "
                                     "'extract' operands was {} low was : {} and "
                                     "hi was : {}".format(ref, n.evaluate(), cv1, lo, hi))
                    self.assertEqual(ref, self.Triton.evaluateAstViaZ3(n))
                    self.assertEqual(ref, self.Triton.simplify(n, True).evaluate())

    def test_ite(self):
        """Check ite node."""
        for _ in range(100):
            cv1 = random.randint(0, 255)
            cv2 = random.randint(0, 255)
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            self.Triton.setConcreteVariableValue(self.sv2, cv2)
            n = self.astCtxt.ite(self.v1 < self.v2, self.v1, self.v2)
            self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))
            self.assertEqual(n.evaluate(), self.Triton.simplify(n, True).evaluate())

    @utils.xfail
    def test_integer(self):
        # Decimal node is not exported in the python interface
        for cv1 in range(0, 256):
            n = self.astCtxt.integer(cv1)
            self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))
            self.assertEqual(n.evaluate(), self.Triton.simplify(n, True).evaluate())

    @utils.xfail
    def test_let(self):
        # Let node didn't take the variable in its computation
        for _ in range(100):
            cv1 = random.randint(0, 255)
            cv2 = random.randint(0, 255)
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            self.Triton.setConcreteVariableValue(self.sv2, cv2)
            n = self.astCtxt.let("b", self.astCtxt.bvadd(self.v1, self.v2), self.astCtxt.bvadd(self.astCtxt.string("b"), self.v1))
            self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))
            self.assertEqual(n.evaluate(), self.Triton.simplify(n, True).evaluate())

    def test_fuzz(self):
        """
        Fuzz test an ast evaluation.

        It builds random ASTs (bit-vector nesting capped at depth 10, see
        new_node), evaluates each one with triton and z3 for several random
        operand values and compares the results.
        """
        # Operators taking boolean operands (and producing a boolean).
        self.in_bool = [
            (self.astCtxt.lnot, 1),
            (self.astCtxt.land, 2),
            (self.astCtxt.lor, 2),
            (self.astCtxt.lxor, 2),
            (self.astCtxt.iff, 2),
        ]
        # Operators producing a boolean (comparisons + boolean operators).
        self.to_bool = [
            (self.astCtxt.bvsge, 2),
            (self.astCtxt.bvsgt, 2),
            (self.astCtxt.bvsle, 2),
            (self.astCtxt.bvslt, 2),
            (self.astCtxt.bvuge, 2),
            (self.astCtxt.bvugt, 2),
            (self.astCtxt.bvule, 2),
            (self.astCtxt.bvult, 2),
            (self.astCtxt.equal, 2),
        ] + self.in_bool
        # Operators producing an 8-bit bit-vector. The last two entries must
        # stay the leaves: new_node keeps only possible[-2:] once deep enough.
        self.bvop = [
            (self.astCtxt.bvneg, 1),
            (self.astCtxt.bvnot, 1),
            (lambda x: self.astCtxt.bvrol(x, self.astCtxt.bv(3, x.getBitvectorSize())), 1),
            (lambda x: self.astCtxt.bvror(x, self.astCtxt.bv(2, x.getBitvectorSize())), 1),
            (lambda x: self.astCtxt.extract(11, 4, self.astCtxt.sx(16, x)), 1),
            (lambda x: self.astCtxt.extract(11, 4, self.astCtxt.zx(16, x)), 1),

            # BinOp
            (self.astCtxt.bvadd, 2),
            (self.astCtxt.bvand, 2),
            (self.astCtxt.bvlshr, 2),
            (self.astCtxt.bvashr, 2),
            (self.astCtxt.bvmul, 2),
            (self.astCtxt.bvnand, 2),
            (self.astCtxt.bvnor, 2),
            (self.astCtxt.bvor, 2),
            (self.astCtxt.bvsdiv, 2),
            (self.astCtxt.bvshl, 2),
            (self.astCtxt.bvsmod, 2),
            (self.astCtxt.bvsrem, 2),
            (self.astCtxt.bvsub, 2),
            (self.astCtxt.bvudiv, 2),
            (self.astCtxt.bvurem, 2),
            (self.astCtxt.bvxnor, 2),
            (self.astCtxt.bvxor, 2),
            (lambda x, y: self.astCtxt.concat([self.astCtxt.extract(3, 0, x), self.astCtxt.extract(7, 4, y)]), 2),

            (self.astCtxt.ite, -1),

            # value
            (self.v1, 0),
            (self.v2, 0),
        ]
        for _ in range(10):
            n = self.new_node(0, self.bvop)
            for _ in range(10):
                cv1 = random.randint(0, 255)
                cv2 = random.randint(0, 255)
                self.Triton.setConcreteVariableValue(self.sv1, cv1)
                self.Triton.setConcreteVariableValue(self.sv2, cv2)
                self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))

    def new_node(self, depth, possible):
        """
        Recursive function to create a random ast.

        :param depth: current bit-vector nesting depth. Only bit-vector
            operands increase it; boolean sub-expressions reuse the caller's
            depth.
        :param possible: list of (constructor, nargs) pairs to choose from.
            nargs == 0 means the entry is already a node (a leaf) and is
            returned as-is; ite is recognised by identity (its nargs is -1).
        :return: the generated AST node.

        Reads self.in_bool, self.to_bool and self.bvop, which test_fuzz sets.
        """
        if depth >= 10:
            # shortcut if the tree is deep enough: for self.bvop the last two
            # entries are the leaves self.v1 / self.v2.
            # NOTE: self.to_bool is only ever requested at depth < 10 (ite,
            # the only bvop entry that asks for it, is unreachable once bvop
            # is truncated), so its last two entries (lxor, iff), which would
            # recurse without bound, are never the sole choices.
            possible = possible[-2:]

        op, nargs = random.choice(possible)
        if op == self.astCtxt.ite:
            return op(self.new_node(depth, self.to_bool),
                      self.new_node(depth + 1, self.bvop),
                      self.new_node(depth + 1, self.bvop))
        elif any(op == ibo for ibo, _ in self.in_bool):
            args = [self.new_node(depth, self.to_bool) for _ in range(nargs)]
            # land / lor / lxor take their operands as a single list.
            if op in (self.astCtxt.land, self.astCtxt.lor, self.astCtxt.lxor):
                return op(args)
            else:
                return op(*args)
        elif nargs == 0:
            return op
        else:
            return op(*[self.new_node(depth + 1, self.bvop) for _ in range(nargs)])


class TestUnrollAst(unittest.TestCase):

    """Testing unroll AST."""

    def setUp(self):
        """Define the arch."""
        self.ctx = TritonContext()
        self.ctx.setArchitecture(ARCH.X86_64)
        self.ast = self.ctx.getAstContext()

    def test_1(self):
        self.ctx.processing(Instruction(b"\x48\xc7\xc0\x01\x00\x00\x00"))  # mov rax, 1
        self.ctx.processing(Instruction(b"\x48\x89\xc3"))  # mov rbx, rax
        self.ctx.processing(Instruction(b"\x48\x89\xd9"))  # mov rcx, rbx
        self.ctx.processing(Instruction(b"\x48\x89\xca"))  # mov rdx, rcx
        rdx = self.ctx.getRegisterAst(self.ctx.registers.rdx)
        self.assertEqual(str(rdx), "ref!6")
        self.assertEqual(str(self.ast.unroll(rdx)), "(_ bv1 64)")

    def test_2(self):
        self.ctx.processing(Instruction(b"\x48\xc7\xc0\x01\x00\x00\x00"))  # mov rax, 1
        self.ctx.processing(Instruction(b"\x48\x31\xc0"))  # xor rax, rax
        rax = self.ctx.getRegisterAst(self.ctx.registers.rax)
        self.assertEqual(str(rax), "ref!2")
        self.assertEqual(str(self.ast.unroll(rax)), "(bvxor (_ bv1 64) (_ bv1 64))")

    def test_3(self):
        self.ctx.processing(Instruction(b"\x48\xc7\xc0\x01\x00\x00\x00"))  # mov rax, 1
        self.ctx.processing(Instruction(b"\x48\xc7\xc3\x02\x00\x00\x00"))  # mov rbx, 2
        self.ctx.processing(Instruction(b"\x48\x31\xd8"))  # xor rax, rbx
        self.ctx.processing(Instruction(b"\x48\xff\xc0"))  # inc rax
        self.ctx.processing(Instruction(b"\x48\x89\xc2"))  # mov rdx, rax
        rdx = self.ctx.getRegisterAst(self.ctx.registers.rdx)
        self.assertEqual(str(rdx), "ref!18")
        self.assertEqual(str(self.ast.unroll(rdx)), "(bvadd (bvxor (_ bv1 64) (_ bv2 64)) (_ bv1 64))")
        ref4 = self.ctx.getSymbolicExpression(4)
        self.assertEqual(str(ref4.getAst()), "(bvxor ref!0 ref!2)")


class TestAstTraversal(unittest.TestCase):

    """Testing AST traversal."""

    def setUp(self):
        """Define the arch."""
        self.ctx = TritonContext()
        self.ctx.setArchitecture(ARCH.X86_64)
        self.ast = self.ctx.getAstContext()

    def test_1(self):
        a = self.ast.bv(1, 8)
        b = self.ast.bv(2, 8)
        c = a ^ b
        d = c + a
        e = d + b
        f = e + e
        g = f + b
        ref1 = self.ast.reference(self.ctx.newSymbolicExpression(g))
        ref2 = self.ast.reference(self.ctx.newSymbolicExpression(a))
        k = ref1 + ref2
        self.assertEqual(k.evaluate(), self.ctx.evaluateAstViaZ3(k))