1#!/usr/bin/env python
2# coding: utf-8
3"""Test AST conversion."""
4
5import operator
6import random
7import unittest
8import utils
9
10from triton import *
11
12
13
class TestAstConversion(unittest.TestCase):

    """Testing the AST conversion Triton <-> z3."""

    def setUp(self):
        """Define the arch and two 8-bit symbolic variables with their AST nodes."""
        self.Triton = TritonContext()
        self.Triton.setArchitecture(ARCH.X86_64)

        self.astCtxt = self.Triton.getAstContext()

        self.sv1 = self.Triton.newSymbolicVariable(8)
        self.sv2 = self.Triton.newSymbolicVariable(8)

        self.v1 = self.astCtxt.variable(self.sv1)
        self.v2 = self.astCtxt.variable(self.sv2)

    def test_binop(self):
        """
        Check python binary operation.

        Fuzz uint8 operand values and check that the Triton evaluation, the
        z3 evaluation and the z3-simplified evaluation all equal the Python
        result reduced modulo 2**8.
        """
        # No simplification available: this is only going to test the
        # Triton <-> z3 AST conversions.
        binop = [
            # Overloaded operators
            operator.and_,
            operator.add,
            operator.sub,
            operator.xor,
            operator.or_,
            operator.mul,
            operator.lshift,
            operator.rshift,
            operator.eq,
            operator.ne,
            operator.le,
            operator.ge,
            operator.lt,
            operator.gt,
            operator.floordiv,
            operator.mod,
        ]
        operator_div = operator.floordiv
        if hasattr(operator, "div"):
            # Only present on Python 2, where ``/`` on ints is operator.div.
            operator_div = operator.div
            binop.append(operator_div)

        for _ in range(100):
            cv1 = random.randint(0, 255)
            cv2 = random.randint(0, 255)
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            self.Triton.setConcreteVariableValue(self.sv2, cv2)
            for op in binop:
                n = op(self.v1, self.v2)
                if op in (operator.floordiv, operator_div) and cv2 == 0:
                    # Division by zero is expected to yield all ones (255 on
                    # 8 bits), the SMT-LIB convention for unsigned division.
                    ref = 255
                elif op == operator.mod and cv2 == 0:
                    # Remainder by zero is expected to yield the dividend.
                    ref = cv1
                else:
                    ref = op(cv1, cv2) % (2 ** 8)
                triton_value = n.evaluate()
                self.assertEqual(
                    ref,
                    triton_value,
                    "ref = {} and triton value = {} with operator {} operands were {} and {}".format(ref, triton_value, op, cv1, cv2)
                )
                self.assertEqual(ref, self.Triton.evaluateAstViaZ3(n))
                self.assertEqual(ref, self.Triton.simplify(n, True).evaluate())

    def test_unop(self):
        """
        Check python unary operation.

        Exhaustively try every uint8 operand value and check that the
        Triton, z3 and z3-simplified evaluations equal the Python result
        reduced modulo 2**8.
        """
        # No simplification available: this is only going to test the
        # Triton <-> z3 AST conversions.
        unop = [
            operator.invert,
            operator.neg,
        ]

        for cv1 in range(0, 256):
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            for op in unop:
                n = op(self.v1)
                ref = op(cv1) % (2 ** 8)
                triton_value = n.evaluate()
                self.assertEqual(ref, triton_value,
                                 "ref = {} and triton value = {} with operator "
                                 "{} operands was {}".format(ref,
                                                             triton_value,
                                                             op,
                                                             cv1))
                self.assertEqual(ref, self.Triton.evaluateAstViaZ3(n))
                self.assertEqual(ref, self.Triton.simplify(n, True).evaluate())

    def test_smtbinop(self):
        """
        Check smt binary operation.

        Fuzz uint8 operand values and check that the Triton evaluation agrees
        with both the z3 evaluation and the z3-simplified evaluation.
        """
        # No simplification available: this is only going to test the
        # Triton <-> z3 AST conversions.
        smtbinop = [
            # AST API
            self.astCtxt.bvadd,
            self.astCtxt.bvand,
            self.astCtxt.bvlshr,
            self.astCtxt.bvashr,
            self.astCtxt.bvmul,
            self.astCtxt.bvnand,
            self.astCtxt.bvnor,
            self.astCtxt.bvor,
            self.astCtxt.bvsdiv,
            self.astCtxt.bvsge,
            self.astCtxt.bvsgt,
            self.astCtxt.bvshl,
            self.astCtxt.bvsle,
            self.astCtxt.bvslt,
            self.astCtxt.bvsmod,
            self.astCtxt.bvsrem,
            self.astCtxt.bvsub,
            self.astCtxt.bvudiv,
            self.astCtxt.bvuge,
            self.astCtxt.bvugt,
            self.astCtxt.bvule,
            self.astCtxt.bvult,
            self.astCtxt.bvurem,
            self.astCtxt.bvxnor,
            self.astCtxt.bvxor,
            self.astCtxt.concat,
            self.astCtxt.distinct,
            self.astCtxt.equal,
            self.astCtxt.iff,
            self.astCtxt.land,
            self.astCtxt.lor,
            self.astCtxt.lxor,
        ]

        for _ in range(100):
            cv1 = random.randint(0, 255)
            cv2 = random.randint(0, 255)
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            self.Triton.setConcreteVariableValue(self.sv2, cv2)
            for op in smtbinop:
                if op == self.astCtxt.concat:
                    # concat takes a list of nodes.
                    n = op([self.v1, self.v2])
                elif op in (self.astCtxt.land, self.astCtxt.lor, self.astCtxt.lxor):
                    # Logical connectives take a list of boolean nodes.
                    n = op([self.v1 != cv1, self.v2 != cv2])
                elif op == self.astCtxt.iff:
                    # iff takes two boolean nodes.
                    n = op(self.v1 > cv1, self.v2 < cv2)
                else:
                    n = op(self.v1, self.v2)
                triton_value = n.evaluate()
                z3_value = self.Triton.evaluateAstViaZ3(n)
                self.assertEqual(
                    triton_value,
                    z3_value,
                    "triton = {} and z3 = {} with operator {} operands were {} and {}".format(triton_value, z3_value, op, cv1, cv2)
                )
                simplified_value = self.Triton.simplify(n, True).evaluate()
                self.assertEqual(
                    triton_value,
                    simplified_value,
                    "triton = {} and simplified = {} with operator {} operands were {} and {}".format(triton_value, simplified_value, op, cv1, cv2)
                )

    def test_smt_unop(self):
        """
        Check smt unary operation.

        Exhaustively try every uint8 operand value and check that the Triton
        evaluation agrees with the z3 and z3-simplified evaluations.
        """
        # No simplification available: this is only going to test the
        # Triton <-> z3 AST conversions.
        smtunop = [
            self.astCtxt.bvneg,
            self.astCtxt.bvnot,
            self.astCtxt.lnot,
            lambda x: self.astCtxt.bvrol(x, self.astCtxt.bv(2, x.getBitvectorSize())),
            lambda x: self.astCtxt.bvror(x, self.astCtxt.bv(3, x.getBitvectorSize())),
            lambda x: self.astCtxt.sx(16, x),
            lambda x: self.astCtxt.zx(16, x),
        ]

        for cv1 in range(0, 256):
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            for op in smtunop:
                if op == self.astCtxt.lnot:
                    # lnot needs a boolean operand.
                    n = op(self.v1 != 0)
                else:
                    n = op(self.v1)
                self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))
                self.assertEqual(n.evaluate(), self.Triton.simplify(n, True).evaluate())

    def test_bvnode(self):
        """Check python bit vector declaration."""
        for _ in range(100):
            cv1 = random.randint(-127, 255)
            n = self.astCtxt.bv(cv1, 8)
            self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))
            self.assertEqual(n.evaluate(), self.Triton.simplify(n, True).evaluate())

    def test_extract(self):
        """Check bit extraction from bitvector against a shift-based reference."""
        for _ in range(100):
            cv1 = random.randint(0, 255)
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            for lo in range(0, 8):
                for hi in range(lo, 8):
                    n = self.astCtxt.extract(hi, lo, self.v1)
                    # Drop the bits above ``hi`` by shifting them out of the
                    # byte, then shift the remaining field down to bit 0.
                    ref = ((cv1 << (7 - hi)) % 256) >> (7 - hi + lo)
                    triton_value = n.evaluate()
                    self.assertEqual(ref, triton_value,
                                     "ref = {} and triton value = {} with operator "
                                     "'extract' operands was {} low was : {} and "
                                     "hi was : {}".format(ref, triton_value, cv1, lo, hi))
                    self.assertEqual(ref, self.Triton.evaluateAstViaZ3(n))
                    self.assertEqual(ref, self.Triton.simplify(n, True).evaluate())

    def test_ite(self):
        """Check ite node."""
        for _ in range(100):
            cv1 = random.randint(0, 255)
            cv2 = random.randint(0, 255)
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            self.Triton.setConcreteVariableValue(self.sv2, cv2)
            n = self.astCtxt.ite(self.v1 < self.v2, self.v1, self.v2)
            self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))
            self.assertEqual(n.evaluate(), self.Triton.simplify(n, True).evaluate())

    @utils.xfail
    def test_integer(self):
        """Check integer node (expected to fail)."""
        # Decimal node is not exported in the python interface
        for cv1 in range(0, 256):
            n = self.astCtxt.integer(cv1)
            self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))
            self.assertEqual(n.evaluate(), self.Triton.simplify(n, True).evaluate())

    @utils.xfail
    def test_let(self):
        """Check let node (expected to fail)."""
        # Let node didn't take the variable in its computation
        for _ in range(100):
            cv1 = random.randint(0, 255)
            cv2 = random.randint(0, 255)
            self.Triton.setConcreteVariableValue(self.sv1, cv1)
            self.Triton.setConcreteVariableValue(self.sv2, cv2)
            n = self.astCtxt.let("b", self.astCtxt.bvadd(self.v1, self.v2), self.astCtxt.bvadd(self.astCtxt.string("b"), self.v1))
            self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))
            self.assertEqual(n.evaluate(), self.Triton.simplify(n, True).evaluate())

    def test_fuzz(self):
        """
        Fuzz test an ast evaluation.

        Build random ASTs with ``new_node`` (bit-vector nesting bounded at
        depth 10) and check that Triton and z3 evaluate each of them to the
        same value for several random variable assignments.
        """
        # Each entry is (operator, arity). ``new_node`` reads these three
        # lists from ``self``.
        self.in_bool = [
            (self.astCtxt.lnot, 1),
            (self.astCtxt.land, 2),
            (self.astCtxt.lor, 2),
            (self.astCtxt.lxor, 2),
            (self.astCtxt.iff, 2),
        ]
        self.to_bool = [
            (self.astCtxt.bvsge, 2),
            (self.astCtxt.bvsgt, 2),
            (self.astCtxt.bvsle, 2),
            (self.astCtxt.bvslt, 2),
            (self.astCtxt.bvuge, 2),
            (self.astCtxt.bvugt, 2),
            (self.astCtxt.bvule, 2),
            (self.astCtxt.bvult, 2),
            (self.astCtxt.equal, 2),
        ] + self.in_bool
        self.bvop = [
            (self.astCtxt.bvneg, 1),
            (self.astCtxt.bvnot, 1),
            (lambda x: self.astCtxt.bvrol(x, self.astCtxt.bv(3, x.getBitvectorSize())), 1),
            (lambda x: self.astCtxt.bvror(x, self.astCtxt.bv(2, x.getBitvectorSize())), 1),
            (lambda x: self.astCtxt.extract(11, 4, self.astCtxt.sx(16, x)), 1),
            (lambda x: self.astCtxt.extract(11, 4, self.astCtxt.zx(16, x)), 1),

            # BinOp
            (self.astCtxt.bvadd, 2),
            (self.astCtxt.bvand, 2),
            (self.astCtxt.bvlshr, 2),
            (self.astCtxt.bvashr, 2),
            (self.astCtxt.bvmul, 2),
            (self.astCtxt.bvnand, 2),
            (self.astCtxt.bvnor, 2),
            (self.astCtxt.bvor, 2),
            (self.astCtxt.bvsdiv, 2),
            (self.astCtxt.bvshl, 2),
            (self.astCtxt.bvsmod, 2),
            (self.astCtxt.bvsrem, 2),
            (self.astCtxt.bvsub, 2),
            (self.astCtxt.bvudiv, 2),
            (self.astCtxt.bvurem, 2),
            (self.astCtxt.bvxnor, 2),
            (self.astCtxt.bvxor, 2),
            (lambda x, y: self.astCtxt.concat([self.astCtxt.extract(3, 0, x), self.astCtxt.extract(7, 4, y)]), 2),

            # Special-cased by new_node (condition + two branches).
            (self.astCtxt.ite, -1),

            # value
            (self.v1, 0),
            (self.v2, 0),
        ]
        for _ in range(10):
            n = self.new_node(0, self.bvop)
            for _ in range(10):
                cv1 = random.randint(0, 255)
                cv2 = random.randint(0, 255)
                self.Triton.setConcreteVariableValue(self.sv1, cv1)
                self.Triton.setConcreteVariableValue(self.sv2, cv2)
                self.assertEqual(n.evaluate(), self.Triton.evaluateAstViaZ3(n))

    def new_node(self, depth, possible):
        """
        Recursively build a random AST.

        ``possible`` is a list of ``(op, nargs)`` pairs. ``nargs == 0`` marks a
        leaf node returned as-is, ``-1`` marks ``ite``, and any other value is
        the arity of ``op``. Relies on ``self.in_bool``, ``self.to_bool`` and
        ``self.bvop``, which ``test_fuzz`` sets up.

        Operands of bit-vector operators and the branches of ``ite`` are
        built at ``depth + 1``. Boolean connectives and the ``ite`` condition
        are built at the same depth. From depth 10 on, only leaves are picked
        when the list offers any.
        """
        if depth >= 10:
            # Shortcut if the tree is deep enough: keep only the leaves.
            # Select them explicitly rather than assuming they are the last
            # entries. In ``self.to_bool`` the last entries are connectives
            # that recurse at the same depth and would never terminate. When
            # there are no leaves at all, keep the whole list.
            leaves = [entry for entry in possible if entry[1] == 0]
            if leaves:
                possible = leaves

        op, nargs = random.choice(possible)
        if op == self.astCtxt.ite:
            return op(self.new_node(depth, self.to_bool),
                      self.new_node(depth + 1, self.bvop),
                      self.new_node(depth + 1, self.bvop))
        elif any(op == ibo for ibo, _ in self.in_bool):
            args = [self.new_node(depth, self.to_bool) for _ in range(nargs)]
            if op in (self.astCtxt.land, self.astCtxt.lor, self.astCtxt.lxor):
                # These connectives take their operands as a single list.
                return op(args)
            else:
                return op(*args)
        elif nargs == 0:
            return op
        else:
            return op(*[self.new_node(depth + 1, self.bvop) for _ in range(nargs)])
352
353
class TestUnrollAst(unittest.TestCase):

    """Testing unroll AST."""

    def setUp(self):
        """Create an x86-64 Triton context and keep a handle on its AST context."""
        self.ctx = TritonContext()
        self.ctx.setArchitecture(ARCH.X86_64)
        self.ast = self.ctx.getAstContext()

    def _process_all(self, *opcodes):
        """Symbolically execute each raw opcode, in the given order, on ``self.ctx``."""
        for opcode in opcodes:
            self.ctx.processing(Instruction(opcode))

    def test_1(self):
        """A chain of register moves unrolls down to the initial constant."""
        self._process_all(
            b"\x48\xc7\xc0\x01\x00\x00\x00",  # mov rax, 1
            b"\x48\x89\xc3",                  # mov rbx, rax
            b"\x48\x89\xd9",                  # mov rcx, rbx
            b"\x48\x89\xca",                  # mov rdx, rcx
        )
        rdx_ast = self.ctx.getRegisterAst(self.ctx.registers.rdx)
        self.assertEqual(str(rdx_ast), "ref!6")
        unrolled = self.ast.unroll(rdx_ast)
        self.assertEqual(str(unrolled), "(_ bv1 64)")

    def test_2(self):
        """Unrolling ``xor rax, rax`` keeps the xor of the constant with itself."""
        self._process_all(
            b"\x48\xc7\xc0\x01\x00\x00\x00",  # mov rax, 1
            b"\x48\x31\xc0",                  # xor rax, rax
        )
        rax_ast = self.ctx.getRegisterAst(self.ctx.registers.rax)
        self.assertEqual(str(rax_ast), "ref!2")
        unrolled = self.ast.unroll(rax_ast)
        self.assertEqual(str(unrolled), "(bvxor (_ bv1 64) (_ bv1 64))")

    def test_3(self):
        """Unrolling resolves nested references, while the stored expression keeps them."""
        self._process_all(
            b"\x48\xc7\xc0\x01\x00\x00\x00",  # mov rax, 1
            b"\x48\xc7\xc3\x02\x00\x00\x00",  # mov rbx, 2
            b"\x48\x31\xd8",                  # xor rax, rbx
            b"\x48\xff\xc0",                  # inc rax
            b"\x48\x89\xc2",                  # mov rdx, rax
        )
        rdx_ast = self.ctx.getRegisterAst(self.ctx.registers.rdx)
        self.assertEqual(str(rdx_ast), "ref!18")
        unrolled = self.ast.unroll(rdx_ast)
        self.assertEqual(str(unrolled), "(bvadd (bvxor (_ bv1 64) (_ bv2 64)) (_ bv1 64))")
        # Symbolic expression #4 (the xor) is still expressed through references.
        xor_expr = self.ctx.getSymbolicExpression(4)
        self.assertEqual(str(xor_expr.getAst()), "(bvxor ref!0 ref!2)")
394
395
class TestAstTraversal(unittest.TestCase):

    """Testing AST traversal."""

    def setUp(self):
        """Create an x86-64 Triton context and keep a handle on its AST context."""
        self.ctx = TritonContext()
        self.ctx.setArchitecture(ARCH.X86_64)
        self.ast = self.ctx.getAstContext()

    def test_1(self):
        """Triton and z3 agree on an AST with shared sub-nodes behind references."""
        one = self.ast.bv(1, 8)
        two = self.ast.bv(2, 8)
        shared = ((one ^ two) + one) + two
        # ``shared`` is used as both operands of the next node, so the
        # resulting AST is a DAG (presumably what this traversal test targets).
        total = (shared + shared) + two
        ref_total = self.ast.reference(self.ctx.newSymbolicExpression(total))
        ref_one = self.ast.reference(self.ctx.newSymbolicExpression(one))
        root = ref_total + ref_one
        self.assertEqual(root.evaluate(), self.ctx.evaluateAstViaZ3(root))
417