1from collections import defaultdict
2import logging
3
4import networkx
5
6from .utils.constants import DEFAULT_STATEMENT
7from .errors import AngrAnnotatedCFGError, AngrExitError
8from .knowledge_plugins.cfg import CFGNode
9
# Module-level logger shared by everything in this file; named after the module.
l = logging.getLogger(__name__)
11
class AnnotatedCFG:
    """
    AnnotatedCFG is a control flow graph with statement whitelists and exit whitelists to describe a slice of the
    program.

    Internally it keeps:

    - a per-block statement whitelist (a sorted list of statement IDs, or a bool where True means "every statement"),
    - a per-block list of whitelisted exit targets,
    - an optional per-block "last statement" index,
    - a list of loops (tuples of block addresses).
    """
    def __init__(self, project, cfg=None, detect_loops=False):
        """
        Constructor.

        :param project: The angr Project instance
        :param cfg: Control flow graph. When given, every node of `cfg.model` is indexed by its address.
        :param detect_loops: Accepted for backward compatibility. The constructor currently does not act on it.
        """
        self._project = project

        self._cfg = cfg
        self._target = None

        # block address -> sorted list of statement IDs, or a bool (True: all statements are whitelisted)
        self._run_statement_whitelist = defaultdict(list)
        # source block address -> list of whitelisted target block addresses
        self._exit_taken = defaultdict(list)
        # block address -> CFG node (or plain int) as returned by the CFG model
        self._addr_to_run = {}
        # block address -> index of the last statement to execute in that block
        self._addr_to_last_stmt_id = {}
        self._loops = []
        self._path_merge_points = [ ]

        if self._cfg is not None:
            for run in self._cfg.model.nodes():
                self._addr_to_run[self.get_addr(run)] = run

    #
    # Public methods
    #

    def from_digraph(self, digraph):
        """
        Initialize this AnnotatedCFG object with a networkx.DiGraph consisting of the following
        form of nodes:

        Tuples like (block address, statement ID)

        Those nodes are connected by edges indicating the execution flow.

        :param networkx.DiGraph digraph: A networkx.DiGraph object
        """

        for n1 in digraph.nodes():
            addr1, stmt_idx1 = n1
            self.add_statements_to_whitelist(addr1, (stmt_idx1,))

            successors = digraph[n1]
            for n2 in successors:
                addr2, stmt_idx2 = n2

                if addr1 != addr2:
                    # There is a control flow transition from block `addr1` to block `addr2`
                    self.add_exit_to_whitelist(addr1, addr2)

                self.add_statements_to_whitelist(addr2, (stmt_idx2,))

    def get_addr(self, run):
        """
        Normalize a block reference to its address.

        :param run: Either an integer address or a CFGNode.
        :return:    The address of the block.
        :raises AngrAnnotatedCFGError: If `run` is neither an integer nor a CFGNode.
        """
        # Integers are checked first: they need no attribute access.
        if isinstance(run, int):
            return run
        if isinstance(run, CFGNode):
            return run.addr
        raise AngrAnnotatedCFGError("Unknown type '%s' of the 'run' argument" % type(run))

    def add_block_to_whitelist(self, block):
        """
        Whitelist every statement of a block.

        :param block: Integer address or CFGNode of the block.
        """
        addr = self.get_addr(block)
        self._run_statement_whitelist[addr] = True

    def add_statements_to_whitelist(self, block, stmt_ids):
        """
        Add statements of a block to the statement whitelist.

        :param block:    Integer address or CFGNode of the block.
        :param stmt_ids: Either a bool, which replaces the block's entry as a whole (True means all statements), or
                         an iterable of statement IDs to merge into the existing entry. If the iterable contains
                         -1, the whole block is whitelisted.
        :raises Exception: If a bool is given while the block already has a non-empty list of whitelisted
                           statements.
        """
        addr = self.get_addr(block)
        # `.get()` avoids creating an empty defaultdict entry as a side effect of merely looking
        current = self._run_statement_whitelist.get(addr)

        if isinstance(stmt_ids, bool):
            if isinstance(current, list) and current:
                raise Exception(
                    "Cannot replace the non-empty statement whitelist of block %#x with the boolean %s"
                    % (addr, stmt_ids)
                )
            self._run_statement_whitelist[addr] = stmt_ids
            return

        if current is True:
            # Every statement of this block is already whitelisted; nothing to add.
            return

        if -1 in stmt_ids:
            self._run_statement_whitelist[addr] = True
            return

        # `current` may be missing or False (a bool entry); both are treated as "no statements yet".
        merged = set(current) if isinstance(current, list) else set()
        merged.update(stmt_ids)
        # Non-integer statement IDs are ordered after all integer ones.
        self._run_statement_whitelist[addr] = \
            sorted(merged, key=lambda v: v if type(v) is int else float('inf'))

    def add_exit_to_whitelist(self, run_from, run_to):
        """
        Whitelist the control flow transition from one block to another.

        :param run_from: Integer address or CFGNode of the source block.
        :param run_to:   Integer address or CFGNode of the target block.
        """
        addr_from = self.get_addr(run_from)
        addr_to = self.get_addr(run_to)
        targets = self._exit_taken[addr_from]
        # The same transition is commonly reported several times (once per statement-level edge in
        # from_digraph()), so only record it once.
        if addr_to not in targets:
            targets.append(addr_to)

    def set_last_statement(self, block_addr, stmt_id):
        """
        Record the index of the last statement to execute in a block.

        :param block_addr: Address of the block.
        :param stmt_id:    Statement index.
        """
        self._addr_to_last_stmt_id[block_addr] = stmt_id

    def add_loop(self, loop_tuple):
        """
        A loop tuple contains a series of IRSB addresses that form a loop. Ideally
        it always starts with the first IRSB that we meet during the execution.
        """
        self._loops.append(loop_tuple)

    def should_take_exit(self, addr_from, addr_to):
        """
        Check whether the transition `addr_from` -> `addr_to` is whitelisted.

        :param addr_from: Address of the source block.
        :param addr_to:   Address of the target block.
        :return:          True if the exit is whitelisted, False otherwise.
        :rtype:           bool
        """
        # `.get()` keeps the defaultdict from growing on lookups of unknown sources
        return addr_to in self._exit_taken.get(addr_from, ())

    def should_execute_statement(self, addr, stmt_id):
        """
        Check whether a statement of a block is whitelisted.

        :param addr:    Address of the block.
        :param stmt_id: Statement ID.
        :return:        True if the statement should be executed, False otherwise. If no whitelist exists at all,
                        everything is executed.
        :rtype:         bool
        """
        if self._run_statement_whitelist is None:
            return True
        if addr not in self._run_statement_whitelist:
            return False

        allowed = self._run_statement_whitelist[addr]
        if isinstance(allowed, bool):
            return allowed
        return stmt_id in allowed

    def get_run(self, addr):
        """
        Get the CFG node registered for an address.

        :param addr: Address of the block.
        :return:     The node, or None if the address is unknown or the address map is unavailable (it is reset
                     to None by __getstate__()).
        """
        if not self._addr_to_run:
            return None
        return self._addr_to_run.get(addr)

    def get_whitelisted_statements(self, addr):
        """
        Get the statement whitelist of a block.

        :param addr: Address of the block.
        :returns:    None if all statements of the block are whitelisted, an empty list if the block is unknown,
                     and the stored whitelist entry otherwise.
        """
        if addr not in self._run_statement_whitelist:
            return []

        stmts = self._run_statement_whitelist[addr]
        if stmts is True:
            # None is the value used to say "execute all statements in this basic block". A little weird...
            return None
        return stmts

    def get_last_statement_index(self, addr):
        """
        Get the statement index of the last statement to execute in the basic block specified by `addr`.

        :param int addr:    Address of the basic block.
        :return:            None if any exit of this block is whitelisted. Otherwise the index recorded via
                            set_last_statement(), if any. Otherwise DEFAULT_STATEMENT if it is whitelisted for this
                            block, or else the largest whitelisted statement index. None is also returned if the
                            block is not in the slice, is whitelisted as a whole, or has no whitelisted statement.
        :rtype:             int or None
        """

        if addr in self._exit_taken:
            return None
        if addr in self._addr_to_last_stmt_id:
            return self._addr_to_last_stmt_id[addr]
        if addr in self._run_statement_whitelist:
            stmts = self._run_statement_whitelist[addr]
            # A bool entry (whole block) or an empty list has no specific last statement to report
            if isinstance(stmts, bool) or not stmts:
                return None
            # is the default exit there? it equals to a negative number (-2 by default) so `max()` won't work.
            if DEFAULT_STATEMENT in stmts:
                return DEFAULT_STATEMENT
            return max(stmts, key=lambda v: v if type(v) is int else float('inf'))
        return None

    def get_loops(self):
        """
        :return: The list of loop tuples registered via add_loop().
        """
        return self._loops

    def get_targets(self, source_addr):
        """
        Get the whitelisted exit targets of a block.

        :param source_addr: Address of the source block.
        :return:            The list of target addresses, or None if the block has no whitelisted exit.
        """
        return self._exit_taken.get(source_addr)

    #
    # Debugging helpers
    #

    def dbg_repr(self):
        """
        Build a multi-line, human-readable dump of the known IRSBs, the statement whitelists and the loops.

        :return: The textual representation. Every line is terminated with a newline.
        :rtype:  str
        """
        lines = ["IRSBs:"]
        # `_addr_to_run` is None on instances restored from __getstate__()
        for addr, run in (self._addr_to_run or {}).items():
            if addr is None:
                continue
            lines.append("%#x => %s" % (addr, run))

        lines.append("Statements:")
        for addr, stmts in self._run_statement_whitelist.items():
            if addr is None:
                continue
            lines.append("Address 0x%08x:" % addr)
            lines.append("    %s" % (stmts,))

        lines.append("Loops:")
        for loop in self._loops:
            lines.append("".join("0x%08x -> " % addr for addr in loop))

        return "\n".join(lines) + "\n"

    def dbg_print_irsb(self, irsb_addr, project=None):
        """
        Pretty-print an IRSB with whitelist information

        Each statement is printed with a leading "+" if it is whitelisted and "-" otherwise.

        :param irsb_addr: Address of the IRSB to print.
        :param project:   Project used to lift the block. Defaults to the project given to the constructor.
        :raises Exception: If no project is available.
        """

        if project is None:
            project = self._project

        if project is None:
            raise Exception("Dict addr_to_run is empty. " + \
                            "Give me a project, and I'll recreate the IRSBs for you.")

        vex_block = project.factory.block(irsb_addr).vex
        statements = vex_block.statements
        whitelist = self.get_whitelisted_statements(irsb_addr)

        # get_whitelisted_statements() reports "all statements" as None; a stored False means "no statement"
        all_whitelisted = whitelist is None or whitelist is True
        listed = () if whitelist is None or isinstance(whitelist, bool) else whitelist

        for i, stmt in enumerate(statements):
            line = "+" if all_whitelisted or i in listed else "-"
            line += "[% 3d] " % i
            # pp() prints by itself and does not return its text, so the prefix is emitted without a newline
            print(line, end='')
            stmt.pp()

    #
    # Helper methods for path priorization
    #

    def keep_path(self, path):
        """
        Given a path, returns True if the path should be kept, False if it should be cut.

        Paths with fewer than two entries in their address trace are always kept; otherwise the last transition
        of the trace must be a whitelisted exit.
        """
        if len(path.addr_trace) < 2:
            return True

        return self.should_take_exit(path.addr_trace[-2], path.addr_trace[-1])

    def merge_points(self, path):
        """
        Get the merge points that apply to the current address of a path.

        :param path: An object with an `addr` attribute.
        :return:     An empty set if `path.addr` is not a registered merge point. Otherwise a one-element set: the
                     address itself when merge points are kept in a list (the default), or the mapped value when
                     they are kept in a dict.
        :rtype:      set
        """
        addr = path.addr
        points = self._path_merge_points
        if addr not in points:
            return set()
        if isinstance(points, dict):
            return {points[addr]}
        # A list of addresses cannot be indexed by address; membership itself is the answer.
        return {addr}

    def successor_func(self, path):
        """
        Callback routine that takes in a path, and returns all feasible successors to path group. This callback routine
        should be passed to the keyword argument "successor_func" of PathGroup.step().

        :param path: A Path instance.
        :return: A list of all feasible Path successors.
        """

        whitelist = self.get_whitelisted_statements(path.addr)

        # NOTE(review): the last-statement index is not forwarded (`last_stmt=None`), which matches the existing
        # behavior. Confirm whether get_last_statement_index(path.addr) should be passed here instead.
        successors = path.step(
            stmt_whitelist=whitelist,
            last_stmt=None
        )

        # further filter successors based on the annotated CFG
        taken_successors = [ ]
        for suc in successors:
            try:
                taken = self.should_take_exit(path.addr, suc.addr)
            except AngrExitError:
                l.debug("Got an unknown exit that AnnotatedCFG does not know about: %#x -> %#x", path.addr, suc.addr)
                continue

            if taken:
                taken_successors.append(suc)

        return taken_successors

    #
    # Overridden methods
    #

    def __getstate__(self):
        """
        Build the state used for pickling.

        The whitelists, last-statement indices, loops and merge points are kept. The CFG, the project, the
        address-to-run map and the target are replaced with None.

        :return: The state dictionary.
        :rtype:  dict
        """
        return {
            '_run_statement_whitelist': self._run_statement_whitelist,
            '_exit_taken': self._exit_taken,
            '_addr_to_last_stmt_id': self._addr_to_last_stmt_id,
            '_loops': self._loops,
            '_path_merge_points': self._path_merge_points,
            '_cfg': None,
            '_project': None,
            '_addr_to_run': None,
            # Present so that a restored instance has every attribute that __init__ defines
            '_target': None,
        }

    #
    # Private methods
    #

    def _detect_loops(self):
        """
        Find all simple cycles among the edges of `self._cfg._edge_map` and register each one via add_loop().

        Requires a CFG that exposes an `_edge_map` attribute mapping a source node to an iterable of target nodes.
        """
        temp_graph = networkx.DiGraph()
        for source, target_list in self._cfg._edge_map.items():
            for target in target_list:
                temp_graph.add_edge(source, target)

        for ctr, loop_lst in enumerate(networkx.simple_cycles(temp_graph)):
            l.debug("A loop is found. %d", ctr)
            # The last element of each node is used as its address.
            # NOTE(review): presumably nodes are tuples ending with the block address - confirm against _edge_map.
            loop = tuple(x[-1] for x in loop_lst)
            l.debug("%s", " => ".join("0x%08x" % x for x in loop))
            self.add_loop(loop)
319