1import logging
2
3import claripy
4
5from . import Analysis
6
# Module-level logger for this analysis. Set it to DEBUG (l.setLevel(logging.DEBUG)) to trace the
# sync-stepping and validation decisions made below.
l = logging.getLogger(__name__)
9
10
class CongruencyCheck(Analysis):
    """
    This is an analysis to ensure that angr executes things identically with different execution backends
    (i.e., unicorn vs vex).

    Configure the two sides with :meth:`set_state_options`, :meth:`set_states` or :meth:`set_simgr`, then call
    :meth:`run`. The simulation manager is expected to hold a "left" and a "right" stash, plus "stashed_left" and
    "stashed_right" stashes used to park pending successors.
    """

    def __init__(self, throw=False):
        """
        Initializes a CongruencyCheck analysis.

        :param throw: whether to raise an exception if an incongruency is found.
        """
        self._throw = throw
        self.simgr = None
        self.prev_pg = None

    def set_state_options(self, left_add_options=None, left_remove_options=None, right_add_options=None,
                          right_remove_options=None):
        """
        Creates a fresh full-init state (with no program arguments) for each side, applying per-side option
        changes, and sets them up for comparison via :meth:`set_states`.

        :param left_add_options:     options to add to the left state.
        :param left_remove_options:  options to remove from the left state.
        :param right_add_options:    options to add to the right state.
        :param right_remove_options: options to remove from the right state.
        :return:                     self, so calls can be chained.
        """
        s_right = self.project.factory.full_init_state(
            add_options=right_add_options, remove_options=right_remove_options,
            args=[],
        )
        s_left = self.project.factory.full_init_state(
            add_options=left_add_options, remove_options=left_remove_options,
            args=[],
        )

        return self.set_states(s_left, s_right)

    def set_states(self, left_state, right_state):
        """
        Builds a simulation manager with `left_state` in the "left" stash and `right_state` in the "right" stash
        (plus empty "stashed_left" / "stashed_right" stashes) and sets it up for comparison.

        :param left_state:  the state to place in the "left" stash.
        :param right_state: the state to place in the "right" stash.
        :return:            self, so calls can be chained.
        """
        simgr = self.project.factory.simulation_manager(right_state)
        simgr.stash(to_stash='right')
        simgr.active.append(left_state)
        simgr.stash(to_stash='left')
        # the active stash is empty now; these calls just create the (empty) parking stashes
        simgr.stash(to_stash='stashed_left')
        simgr.stash(to_stash='stashed_right')

        return self.set_simgr(simgr)

    def set_simgr(self, simgr):
        """
        Uses an existing simulation manager (with "left" and "right" stashes) for the comparison.

        :param simgr: the simulation manager to check.
        :return:      self, so calls can be chained.
        """
        self.simgr = simgr
        return self

    @staticmethod
    def _sync_steps(simgr, max_steps=None):
        """
        Steps whichever of the "left" / "right" stashes is behind until both have executed the same number of
        blocks, a path has errored while one side is empty, or both sides are exhausted.

        :param simgr:     the simulation manager holding the "left" and "right" stashes.
        :param max_steps: the maximum number of steps taken by each catch-up ``run`` call; None means unbounded.
        :return:          the simulation manager.
        """
        # iterative rather than recursive so that a small max_steps cannot blow the recursion limit
        while True:
            l.debug("Sync-stepping pathgroup...")
            l.debug(
                "... left width: %s, right width: %s",
                simgr.left[0].history.block_count if len(simgr.left) > 0 else None,
                simgr.right[0].history.block_count if len(simgr.right) > 0 else None,
            )

            if len(simgr.errored) != 0 and (len(simgr.left) == 0 or len(simgr.right) == 0):
                l.debug("... looks like a path errored")
                return simgr
            if len(simgr.left) == 0 and len(simgr.right) == 0:
                l.debug("... both deadended.")
                return simgr

            if len(simgr.left) == 0:
                l.debug("... left is deadended; stepping right %s times", max_steps)
                simgr = simgr.run(stash='right', n=max_steps)
                continue
            if len(simgr.right) == 0:
                l.debug("... right is deadended; stepping left %s times", max_steps)
                simgr = simgr.run(stash='left', n=max_steps)
                continue

            left_count = simgr.left[0].history.block_count
            right_count = simgr.right[0].history.block_count
            if left_count == right_count:
                l.debug("... synced")
                return simgr

            # The side that is not being stepped does not change, so its block count is a fixed target.
            # The `not sm.<stash>` guard stops the run cleanly if the stepped side deadends before catching up
            # (previously this raised IndexError).
            if left_count < right_count:
                l.debug("... right is ahead; stepping left %s times", right_count - left_count)
                simgr = simgr.run(
                    stash='left',
                    until=lambda sm, target=right_count: not sm.left or sm.left[0].history.block_count >= target,
                    n=max_steps
                )
            else:
                l.debug("... left is ahead; stepping right %s times", left_count - right_count)
                simgr = simgr.run(
                    stash='right',
                    until=lambda sm, target=left_count: not sm.right or sm.right[0].history.block_count >= target,
                    n=max_steps
                )

    def _validate_incongruency(self):
        """
        Checks that a detected incongruency is not caused by translation backends having a different
        idea of what constitutes a basic block.

        When exactly one side uses the UNICORN option, the lagging side (unicorn on X86/AMD64, VEX on MIPS32) is
        explored for up to 200 steps until it reaches the other side's address. The other side's block count is
        then adjusted by the difference, and the two paths are re-compared. If they agree, the re-synced paths
        replace the originals in their stashes.

        Incongruency reports made while validating never raise: ``self._throw`` is disabled for the duration of
        the call and restored afterwards.

        :return: True if the incongruency is real (or could not be explained), False if it is accounted for by
                 the backend difference.
        """
        ot = self._throw

        try:
            self._throw = False
            l.debug("Validating incongruency.")

            right_unicorn = "UNICORN" in self.simgr.right[0].options
            left_unicorn = "UNICORN" in self.simgr.left[0].options
            if right_unicorn == left_unicorn:
                # no idea
                l.warning("Divergence unaccounted for.")
                return True

            if right_unicorn:
                unicorn_stash = 'right'
                normal_stash = 'left'
            else:
                unicorn_stash = 'left'
                normal_stash = 'right'

            unicorn_path = self.simgr.stashes[unicorn_stash][0]
            normal_path = self.simgr.stashes[normal_stash][0]

            if unicorn_path.arch.name in ("X86", "AMD64"):
                # unicorn "falls behind" on loop and rep instructions, since
                # it sees them as ending a basic block. Here, we will
                # step the unicorn until it's caught up
                npg = self.project.factory.simulation_manager(unicorn_path)
                npg.explore(find=lambda p: p.addr == normal_path.addr, n=200)
                if len(npg.found) == 0:
                    l.debug("Validator failed to sync paths.")
                    return True

                new_unicorn = npg.found[0]
                delta = new_unicorn.history.block_count - normal_path.history.block_count
                normal_path.history.recent_block_count += delta
                new_normal = normal_path
            elif unicorn_path.arch.name == "MIPS32":
                # unicorn gets ahead here, because VEX falls behind for unknown reasons
                # for example, this block:
                #
                # 0x1016f20:      lui     $gp, 0x17
                # 0x1016f24:      addiu   $gp, $gp, -0x35c0
                # 0x1016f28:      addu    $gp, $gp, $t9
                # 0x1016f2c:      addiu   $sp, $sp, -0x28
                # 0x1016f30:      sw      $ra, 0x24($sp)
                # 0x1016f34:      sw      $s0, 0x20($sp)
                # 0x1016f38:      sw      $gp, 0x10($sp)
                # 0x1016f3c:      lw      $v0, -0x6cf0($gp)
                # 0x1016f40:      move    $at, $at
                npg = self.project.factory.simulation_manager(normal_path)
                npg.explore(find=lambda p: p.addr == unicorn_path.addr, n=200)
                if len(npg.found) == 0:
                    l.debug("Validator failed to sync paths.")
                    return True

                new_normal = npg.found[0]
                delta = new_normal.history.block_count - unicorn_path.history.block_count
                unicorn_path.history.recent_block_count += delta
                new_unicorn = unicorn_path
            else:
                l.debug("Dunno!")
                return True

            if self.compare_paths(new_unicorn, new_normal):
                l.debug("Divergence accounted for by unicorn.")
                self.simgr.stashes[unicorn_stash][0] = new_unicorn
                self.simgr.stashes[normal_stash][0] = new_normal
                return False

            l.warning("Divergence unaccounted for by unicorn.")
            return True
        finally:
            self._throw = ot

    def _report_incongruency(self, *args):
        """
        Logs an incongruency as a warning. If the analysis was created with ``throw=True``, it also raises
        AngrIncongruencyError with the same arguments.

        :param args: a logging-style format string followed by its arguments.
        """
        l.warning(*args)
        if self._throw:
            raise AngrIncongruencyError(*args)

    def run(self, depth=None):
        """
        Checks that the paths in the configured simulation manager stay the same over the next `depth` basic
        blocks.

        The simulation manager should have a "left" and a "right" stash, each with a single path.

        :param depth: paths are dropped once their ``history.block_count`` reaches this value; None means keep
                      going until the paths terminate.
        :return:      True if both sides stayed congruent until they were exhausted, False if an incongruency was
                      found. With ``throw=True``, an incongruency raises AngrIncongruencyError instead.
        """
        if len(self.simgr.right) != 1 or len(self.simgr.left) != 1:
            self._report_incongruency("Single path in pg.left and pg.right required.")
            return False

        if depth is not None:
            # bound the unicorn engine to the same depth on whichever side uses it (right first, then left)
            for state in (self.simgr.one_right, self.simgr.one_left):
                if "UNICORN" in state.options:
                    state.unicorn.max_steps = depth

        l.debug("Performing initial path comparison.")
        if not self.compare_paths(self.simgr.left[0], self.simgr.right[0]):
            self._report_incongruency("Initial path comparison check failed.")
            return False

        while len(self.simgr.left) > 0 and len(self.simgr.right) > 0:
            if depth:  # skip for None and for 0 (avoids a ZeroDivisionError)
                self._update_progress(100. * float(self.simgr.one_left.history.block_count) / depth)

            if len(self.simgr.deadended) != 0:
                self._report_incongruency("Unexpected deadended paths before step.")
                return False
            if len(self.simgr.right) != 1 or len(self.simgr.left) != 1:
                self._report_incongruency("Different numbers of paths in left and right stash..")
                return False

            # do a step (%s for depth, which may be None)
            l.debug(
                "Stepping right path with weighted length %d/%s",
                self.simgr.right[0].history.block_count,
                depth
            )
            self.prev_pg = self.simgr.copy() #pylint:disable=unused-variable
            self.simgr.step(stash='right')
            CongruencyCheck._sync_steps(self.simgr)

            if len(self.simgr.errored) != 0:
                self._report_incongruency("Unexpected errored paths.")
                return False

            try:
                if not self.compare_path_group(self.simgr) and self._validate_incongruency():
                    self._report_incongruency("Path group comparison failed.")
                    return False
            except AngrIncongruencyError:
                # only propagate if the divergence cannot be explained by backend block-boundary differences
                if self._validate_incongruency():
                    raise

            if depth is not None:
                self.simgr.drop(stash='left', filter_func=lambda p: p.history.block_count >= depth)
                self.simgr.drop(stash='right', filter_func=lambda p: p.history.block_count >= depth)

            # pull parked successors back in, then keep only one path per side active again
            self.simgr.right.sort(key=lambda p: p.addr)
            self.simgr.left.sort(key=lambda p: p.addr)
            self.simgr.stashed_right[:] = self.simgr.stashed_right[::-1]
            self.simgr.stashed_left[:] = self.simgr.stashed_left[::-1]
            self.simgr.move('stashed_right', 'right')
            self.simgr.move('stashed_left', 'left')

            if len(self.simgr.left) > 1:
                self.simgr.split(from_stash='left', limit=1, to_stash='stashed_left')
                self.simgr.split(from_stash='right', limit=1, to_stash='stashed_right')

        # The loop ends once a side has no paths left. Both sides running out together is success;
        # only one side running out is an incongruency (previously this fell through and returned None).
        if len(self.simgr.left) != len(self.simgr.right):
            self._report_incongruency("Different numbers of paths in left and right stash..")
            return False

        l.debug("All done!")
        return True

    def compare_path_group(self, pg):
        """
        Compares the "left" and "right" stashes of `pg` pairwise, after sorting each by address.

        Side effect: the "deadended" stash of `pg` is dropped.

        :param pg: the simulation manager to check.
        :return:   True if the stashes match, False otherwise.
        """
        if len(pg.left) != len(pg.right):
            self._report_incongruency("Number of left and right paths differ.")
            return False
        if len(pg.deadended) % 2 != 0:
            self._report_incongruency("Odd number of deadended paths after step.")
            return False
        pg.drop(stash='deadended')

        if len(pg.left) == 0 and len(pg.right) == 0:
            return True

        # make sure the paths are the same
        for pl, pr in zip(sorted(pg.left, key=lambda p: p.addr), sorted(pg.right, key=lambda p: p.addr)):
            if not self.compare_paths(pl, pr):
                self._report_incongruency("Differing paths.")
                return False

        return True

    def compare_states(self, sl, sr):
        """
        Compares two states for similarity.

        The states are considered similar when all of the following hold:
          * their canonicalized constraints are equivalent;
          * every changed register/memory byte holds the same AST up to variable renaming;
          * on AMD64/X86/ARM*/AARCH64, their flags agree.

        :param sl: the left state.
        :param sr: the right state.
        :return:   True if the states are congruent, False otherwise.
        """
        # make sure the canonicalized constraints are the same
        n_map, n_counter, n_canon_constraint = claripy.And(*sr.solver.constraints).canonicalize() #pylint:disable=no-member
        u_map, u_counter, u_canon_constraint = claripy.And(*sl.solver.constraints).canonicalize() #pylint:disable=no-member
        if n_canon_constraint is not u_canon_constraint:
            # https://github.com/Z3Prover/z3/issues/2359
            # don't try to simplify unless we really need to, as it can introduce serious nondeterminism
            n_canoner_constraint = sr.solver.simplify(n_canon_constraint)
            u_canoner_constraint = sl.solver.simplify(u_canon_constraint)
        else:
            n_canoner_constraint = u_canoner_constraint = n_canon_constraint
        if n_canoner_constraint is not u_canoner_constraint:
            # extra check: are these two constraints equivalent?
            # (equal is satisfiable AND not-equal is unsatisfiable)
            tmp_solver = claripy.Solver()
            a = tmp_solver.satisfiable(extra_constraints=(n_canoner_constraint == u_canoner_constraint,))
            b = tmp_solver.satisfiable(extra_constraints=(n_canoner_constraint != u_canoner_constraint,))

            if not (a is True and b is False):
                self._report_incongruency("Different constraints!")
                return False

        # get the differences in registers and memory
        mem_diff = sr.memory.changed_bytes(sl.memory)
        reg_diff = sr.registers.changed_bytes(sl.registers)

        # this is only for unicorn
        if "UNICORN" in sl.options or "UNICORN" in sr.options:
            if sl.arch.name == "X86":
                reg_diff -= set(range(40, 52)) #ignore cc pseudo-registers
                reg_diff -= set(range(320, 324)) #some other VEX weirdness
                reg_diff -= set(range(340, 344)) #ip_at_syscall
            elif sl.arch.name == "AMD64":
                reg_diff -= set(range(144, 168)) #ignore cc pseudo-registers

        # make sure the differences in registers and memory are actually just renamed
        # versions of the same ASTs
        for diffs, (um, nm) in (
            (reg_diff, (sl.registers, sr.registers)),
            (mem_diff, (sl.memory, sr.memory)),
        ):
            for i in diffs:
                bn = nm.load(i, 1)
                bu = um.load(i, 1)

                bnc = bn.canonicalize(var_map=n_map, counter=n_counter)[-1]
                buc = bu.canonicalize(var_map=u_map, counter=u_counter)[-1]

                if bnc is not buc:
                    self._report_incongruency("Different memory or registers (index %d, values %r and %r)!", i, bn, bu)
                    return False

        # make sure the flags are the same
        if sl.arch.name in ("AMD64", "X86", "ARM", "ARMEL", "ARMHF", "AARCH64"):
            # pylint: disable=unused-variable
            # NOTE(review): unused, but these register reads are kept in case they have side effects on the states
            n_bkp = sr.regs.cc_op, sr.regs.cc_dep1, sr.regs.cc_dep2, sr.regs.cc_ndep
            u_bkp = sl.regs.cc_op, sl.regs.cc_dep1, sl.regs.cc_dep2, sl.regs.cc_ndep
            if sl.arch.name in ('AMD64', 'X86'):
                n_flags = sr.regs.eflags.canonicalize(var_map=n_map, counter=n_counter)[-1]
                u_flags = sl.regs.eflags.canonicalize(var_map=u_map, counter=u_counter)[-1]
            else:
                n_flags = sr.regs.flags.canonicalize(var_map=n_map, counter=n_counter)[-1]
                u_flags = sl.regs.flags.canonicalize(var_map=u_map, counter=u_counter)[-1]
            # simplify each side with its own state's solver (previously the solvers were swapped)
            if n_flags is not u_flags and sr.solver.simplify(n_flags) is not sl.solver.simplify(u_flags):
                self._report_incongruency("Different flags!")
                return False

        return True

    def compare_paths(self, pl, pr):
        """
        Compares two paths.

        The paths must have congruent states (see :meth:`compare_states`), equal block counts and equal
        addresses.

        :param pl: the left path (state).
        :param pr: the right path (state).
        :return:   True if the paths match, False otherwise.
        """
        l.debug("Comparing paths...")
        if not self.compare_states(pl, pr):
            self._report_incongruency("Failed state similarity check!")
            return False

        if pr.history.block_count != pl.history.block_count:
            self._report_incongruency("Different weights!")
            return False

        if pl.addr != pr.addr:
            self._report_incongruency("Different addresses!")
            return False

        return True
370
# These imports sit at the bottom of the module, presumably to avoid a circular import with the
# angr.errors / angr.analyses packages -- NOTE(review): confirm before moving them to the top.
# AngrIncongruencyError is referenced by the class methods above at call time (raise/except), so it only
# has to exist by the time those methods run.
from ..errors import AngrIncongruencyError
from angr.analyses import AnalysesHub
# Register the analysis with the default hub under the name 'CongruencyCheck'.
AnalysesHub.register_default('CongruencyCheck', CongruencyCheck)
374