1#!/usr/local/bin/python3.8
2#
3# PLASMA : Generate an indented asm code (pseudo-C) with colored syntax.
4# Copyright (C) 2015    Joel
5#
6# This program is free software: you can redistribute it and/or modify
7# it under the terms of the GNU General Public License as published by
8# the Free Software Foundation, either version 3 of the License, or
9# (at your option) any later version.
10#
11# This program is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.    See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with this program.    If not, see <http://www.gnu.org/licenses/>.
18#
19
20from time import time
21
22from plasma.lib.utils import BRANCH_NEXT, BRANCH_NEXT_JUMP, debug__, list_starts_with
23
24# For the loop's analysis
25MAX_NODES = 800
26
27
# This class is used only for the decompilation mode. The analyzer also
# creates a graph, but only on the fly.
30
31class Graph:
32    def __init__(self, dis, entry_point_addr):
33        # Each node contains a block (list) of instructions.
34        self.nodes = {} # ad -> [instruction, (prefetch)]
35
36        # For each address block, we store a list of next blocks.
37        # If there are 2 elements it means that the precedent instruction
38        # was a conditional jump :
39        # 1st : direct next instruction
40        # 2nd : for conditional jump : address of the jump
41        self.link_out = {} # ad -> [nxt1, nxt2]
42
43        self.link_in = {} # ad -> [prev, ...]
44
45        self.entry_point_addr = entry_point_addr
46        self.dis = dis
47
48        # For one loop : contains all address of the loop only
49        self.loops_set = {}
50
51        # For one loop : contains all address of the loop and sub-loops
52        self.loops_all = {}
53
54        # Rest of all address which are not in a loop
55        self.not_in_loop = set()
56
57        self.loops_start = set()
58
59        # Optimization
60        self.cond_jumps_set = set()
61        self.uncond_jumps_set = set()
62
63        self.equiv = {}
64        self.false_loops = set()
65
66        # Loop dependencies
67        self.deps = {}
68        self.rev_deps = {}
69
70        self.cache_path_exists = {}
71
72        # For each loop we search the last node that if we enter in it,
73        # we are sure to return to the loop.
74        self.last_node_loop = {}
75
76        self.all_deep_equiv = set()
77
78        self.skipped_loops_analysis = False
79
80        self.exit_or_ret = set()
81
82        # Could be modified by AddrContext.decompile
83        self.debug = False
84
85
86    # A jump is normally alone in a block, but for some architectures
87    # we save the prefetched instruction after.
88
89    def new_node(self, curr, prefetch, nxt):
90        ad = curr.address
91        self.nodes[ad] = [curr]
92
93        if nxt is not None:
94            self.link_out[ad] = nxt
95
96        if nxt is not None:
97            for n in nxt:
98                if n not in self.link_in:
99                    self.link_in[n] = [ad]
100                else:
101                    self.link_in[n].append(ad)
102
103        if prefetch is not None:
104            self.nodes[ad].append(prefetch)
105
106
107    def exists(self, inst):
108        return inst.address in self.nodes
109
110
111    # Concat instructions in single block
112    # jumps are in separated blocks
113    def simplify(self):
114        nodes = list(self.nodes.keys())
115        start = time()
116
117        for ad in nodes:
118            if ad in self.uncond_jumps_set or ad in self.cond_jumps_set:
119                continue
120
121            if ad not in self.link_in or len(self.link_in[ad]) != 1 or \
122                    ad == self.entry_point_addr:
123                continue
124
125            pred = self.link_in[ad][0]
126
127            # don't fuse with jumps
128            if pred in self.uncond_jumps_set or pred in self.cond_jumps_set:
129                continue
130
131            if pred not in self.link_out or len(self.link_out[pred]) != 1:
132                continue
133
134            if ad in self.link_out:
135                self.link_out[pred] = self.link_out[ad]
136            else:
137                del self.link_out[pred]
138
139            self.nodes[pred] += self.nodes[ad]
140
141            if ad in self.link_out:
142                del self.link_out[ad]
143
144            del self.link_in[ad]
145            del self.nodes[ad]
146
147            # replace all addr wich refers to ad
148            for k, lst_i in self.link_in.items():
149                if ad in lst_i:
150                    lst_i[lst_i.index(ad)] = pred
151
152        elapsed = time()
153        elapsed = elapsed - start
154        debug__("Graph simplified in %fs (%d nodes)" % (elapsed, len(self.nodes)))
155
156
157    def dot_loop_deps(self):
158        output = open("graph_loop_deps.dot", "w+")
159        output.write('digraph {\n')
160        output.write('node [fontname="liberation mono" style=filled fillcolor=white shape=box];\n')
161
162        for k, dp in self.deps.items():
163            output.write('node_%x_%x [label="(%x, %x)"' % (k[0], k[1], k[0], k[1]))
164
165            if k in self.false_loops:
166                output.write(' fillcolor="#B6FFDD"')
167
168            if k in self.all_deep_equiv:
169                output.write(' color="#ff0000"')
170
171            output.write('];\n')
172
173            for sub in dp:
174                output.write('node_%x_%x -> node_%x_%x;\n'
175                        % (k[0], k[1], sub[0], sub[1]))
176
177        output.write('}\n')
178        output.close()
179
180
181    def dot_graph(self, jmptables):
182        output = open("graph.dot", "w+")
183        output.write('digraph {\n')
184        # output.write('graph [bgcolor="#aaaaaa" pad=20];\n')
185        # output.write('node [fontname="liberation mono" style=filled fillcolor="#333333" fontcolor="#d3d3d3" shape=box];\n')
186        output.write('node [fontname="liberation mono" style=filled fillcolor=white shape=box];\n')
187
188        keys = list(self.nodes.keys())
189        keys.sort()
190
191        for k in keys:
192            lst_i = self.nodes[k]
193
194            output.write('node_%x [label="' % k)
195
196            for i in lst_i:
197                output.write('0x%x: %s %s\\l' % (i.address, i.mnemonic, i.op_str))
198
199            output.write('"')
200
201            if k in self.loops_start:
202                output.write(' fillcolor="#FFFCC4"')
203            elif k not in self.link_out:
204                output.write(' fillcolor="#ff7777"')
205            elif k not in self.link_in:
206                output.write(' fillcolor="#B6FFDD"')
207
208            output.write('];\n')
209
210        for k, i in self.link_out.items():
211            if k in jmptables:
212                for ad in jmptables[k].table:
213                    output.write('node_%x -> node_%x;\n' % (k, ad))
214            elif len(i) == 2:
215                # true green branch (jump is taken)
216                output.write('node_%x -> node_%x [color="#58DA9C"];\n'
217                        % (k, i[BRANCH_NEXT_JUMP]))
218
219                # false red branch (jump is not taken)
220                output.write('node_%x -> node_%x [color="#ff7777"];\n'
221                        % (k, i[BRANCH_NEXT]))
222            else:
223                output.write('node_%x -> node_%x;\n' % (k, i[BRANCH_NEXT]))
224
225        output.write('}')
226        output.close()
227
228
229    def __search_last_node_loop(self, l_prev_loop, l_start, l_set):
230        def __rec_branch_go_out(ad):
231            stack = [ad]
232            visited = set()
233            while stack:
234                ad = stack.pop(-1)
235                if ad not in l_set:
236                    return True
237                if ad == l_start or ad in visited:
238                    continue
239                visited.add(ad)
240                for n in self.link_out[ad]:
241                    stack.append(n)
242            return False
243
244        # Start from the end of the loop
245
246        stack = []
247        visited = {l_start}
248        for prev in self.link_in[l_start]:
249            if prev in l_set:
250                stack.append(prev)
251
252        res = []
253
254        while stack:
255            ad = stack.pop(-1)
256            if ad in visited or ad not in l_set:
257                continue
258            visited.add(ad)
259
260            for prev in self.link_in[ad]:
261                if prev == l_start:
262                    continue
263
264                go_out = False
265
266                for n in self.link_out[prev]:
267                    if n not in l_set:
268                        res.append(ad)
269                        go_out = True
270
271                if not go_out:
272                    stack.append(prev)
273
274        for ad in res:
275            if ad not in self.last_node_loop:
276                self.last_node_loop[ad] = set()
277            self.last_node_loop[ad].add((l_prev_loop, l_start))
278
279
280    def __is_inf_loop(self, l_set):
281        for ad in l_set:
282            if ad in self.link_out:
283                for nxt in self.link_out[ad]:
284                    if nxt not in l_set:
285                        return False
286        return True
287
288
289    def path_exists(self, from_addr, to_addr, loop_start):
290        def __path_exists(curr, local_visited):
291            stack = []
292            for n in self.link_out[from_addr]:
293                stack.append(n)
294            while stack:
295                curr = stack.pop(-1)
296                if curr == to_addr:
297                    return True
298                if curr in local_visited:
299                    continue
300                local_visited.add(curr)
301                if curr not in self.link_out or curr == loop_start:
302                    continue
303                for n in self.link_out[curr]:
304                    stack.append(n)
305            return False
306
307        if from_addr == to_addr:
308            return True
309
310        if from_addr not in self.link_out:
311            return False
312
313        if (from_addr, to_addr) in self.cache_path_exists:
314            return self.cache_path_exists[(from_addr, to_addr)]
315
316        local_visited = set()
317        res = __path_exists(from_addr, local_visited)
318        self.cache_path_exists[(from_addr, to_addr)] = res
319        return res
320
321
    # Returns a set containing every address that lies on a path from
    # 'from_addr' to 'to_addr'.
324    def find_paths(self, from_addr, to_addr, global_visited):
325        def __rec_find_paths(curr, local_visited, path_set):
326            nonlocal isfirst
327            if curr == to_addr and not isfirst:
328                path_set.add(curr)
329                return
330            isfirst = False
331            if curr in local_visited:
332                return
333            local_visited.add(curr)
334            if curr in global_visited or curr not in self.link_out:
335                return
336            for n in self.link_out[curr]:
337                __rec_find_paths(n, local_visited, path_set)
338
339                if n in path_set:
340                    path_set.add(curr)
341
342        isfirst = True
343        path_set = set()
344        local_visited = set()
345        __rec_find_paths(from_addr, local_visited, path_set)
346        return path_set
347
348
349    def __try_find_loops(self, entry, waiting, par_loops, l_set, is_sub_loop):
350        detected_loops = {}
351        keys = set(waiting.keys())
352
353        for ad in keys:
354            if l_set is not None and ad not in l_set:
355                continue
356
357            if (entry, ad) in self.loops_set:
358                continue
359
360            # search a path from ad to ad, but don't return in pareent loops
361            l = self.find_paths(ad, ad, par_loops)
362
363            # If the set is empty, it's not a loop
364            if l:
365                self.loops_set[(entry, ad)] = l
366                is_sub_loop.add(ad)
367                detected_loops[ad] = (entry, ad)
368
369        return detected_loops
370
371
372    # This function removes entries in waiting list if we have seen
373    # all previous nodes for one node.
374    def __manage_waiting(self, stack, visited, waiting, l_set, done):
375        keys = set(waiting.keys())
376        for ad in keys:
377            if l_set is not None and ad not in l_set:
378                continue
379            if len(waiting[ad]) == 0:
380                del waiting[ad]
381                done.add(ad)
382                stack.append((-1, ad))
383
384
385    # This function reads all nodes in the current stack.
386    # Each node is added to the waiting list if all previous
387    # nodes were not seen.
388    # Then if there is an out-link, this new node is added to
389    # the stack, and it returns True.
390    def __until_stack_empty(self, stack, waiting, visited,
391                            par_loops, l_set, is_sub_loop, done):
392        has_moved = False
393
394        while stack:
395            prev, ad = stack.pop(-1)
396
397            # if ad has parent nodes, check if we have seen all parents
398            if ad in self.link_in and ad not in done:
399                l_in = self.link_in[ad]
400
401                if len(l_in) > 1 or l_set is not None and ad not in l_set:
402                    if ad in waiting:
403                        if prev in waiting[ad]:
404                            waiting[ad].remove(prev)
405                    else:
406                        unseen = set(l_in)
407                        unseen.remove(prev)
408                        waiting[ad] = unseen
409                    continue
410
411            if ad in visited:
412                continue
413
414            visited.add(ad)
415
416            if ad in self.link_out:
417                for n in self.link_out[ad]:
418                    if n in par_loops:
419                        continue
420                    stack.append((ad, n))
421                    has_moved = True
422
423        return has_moved
424
425
426    def __get_new_loops(self, waiting, detected_loops, l_set, is_sub_loop):
427        new_loops = set()
428
429        # Remove internal links to the beginning of the loop
430        # If later we enter in the loop it means that len(waiting[ad]) == 0
431        for ad, k in detected_loops.items():
432            loop = self.loops_set[k]
433
434            was_removed = False
435
436            for rest in set(waiting[ad]):
437                if rest in loop:
438                    waiting[ad].remove(rest)
439                    was_removed = True
440
441            if was_removed:
442                if len(waiting[ad]) == 0:
443                    new_loops.add(ad)
444                    del waiting[ad]
445
446        # Remove external jumps which are outside the current loop
447        for ad, unseen in waiting.items():
448            if l_set is not None and ad not in l_set:
449                continue
450            for i in set(unseen):
451                if l_set is not None and i not in l_set:
452                    unseen.remove(i)
453
454        return new_loops
455
456
457    # It explores the graph and detects loops
    def __explore(self, entry, par_loops, visited, waiting, l_set, done):
        """Explore the graph from `entry` and record the loops met on the way.

        Loops are stored under the key (entry, ad) where `ad` is the start of
        the loop. The method fills self.loops_set, self.loops_all, self.deps
        and self.rev_deps, and calls itself for each loop it enters.

        entry:     address where this exploration starts.
        par_loops: set of addresses that are never followed; the recursive
                   call receives a copy extended with the loop start.
        visited:   set of addresses already expanded (mutated).
        waiting:   dict ad -> set of predecessors not seen yet; the same dict
                   is shared with the recursive calls.
        l_set:     set of addresses of the loop being explored, or None for
                   the outermost call.
        done:      set of addresses released from `waiting` (mutated).
        """
        stack = []

        # Check if the first address (entry point of the function) is the
        # beginning of a loop.
        if not visited and entry in self.link_in:
            for p in self.link_in[entry]:
                stack.append((p, entry))
        else:
            if entry in self.link_out:
                for n in self.link_out[entry]:
                    stack.append((entry, n))
            visited.add(entry)

        # Starts of the loops detected by this call (keys are (entry, ad)).
        is_sub_loop = set()

        while 1:
            # first: manage pending nodes
            # As long as the walk made progress, release the nodes whose
            # predecessors have all been seen and walk again.
            if self.__until_stack_empty(
                    stack, waiting, visited, par_loops, l_set, is_sub_loop, done):
                self.__manage_waiting(stack, visited, waiting, l_set, done)
                continue

            # No progress: the remaining waiting nodes may be loop starts.
            detected_loops = self.__try_find_loops(
                    entry, waiting, par_loops, l_set, is_sub_loop)

            new_loops = self.__get_new_loops(
                    waiting, detected_loops, l_set, is_sub_loop)

            while new_loops:
                # Follow loops
                for ad in new_loops:
                    # TODO : optimize
                    # The sub-exploration works on copies of visited,
                    # par_loops and done, but shares `waiting`.
                    v = set(visited)
                    v.add(ad)
                    pl = set(par_loops)
                    pl.add(ad)

                    l = self.loops_set[(entry, ad)]
                    self.__explore(ad, pl, v, waiting, l, set(done))

                detected_loops = self.__try_find_loops(
                        entry, waiting, par_loops, l_set, is_sub_loop)

                new_loops = self.__get_new_loops(
                        waiting, detected_loops, l_set, is_sub_loop)


            self.__manage_waiting(stack, visited, waiting, l_set, done)

            # Nothing was released: the exploration is finished.
            if not stack:
                break

        # Now for each current loop, we add the content of each sub-loops.
        # It means that a loop contains all sub-loops (which is not the case
        # in loops_set : it contains only the current loop).
        for ad in is_sub_loop:
            loop = set(self.loops_set[(entry, ad)])
            self.loops_all[(entry, ad)] = loop

            self.deps[(entry, ad)] = set()

            for (prev, start), l in self.loops_set.items():
                # Skip current loop
                if (prev, start) == (entry, ad):
                    continue

                # Is it a sub loop ?
                # i.e. a loop found while exploring from `ad`, whose start
                # belongs to the current loop and is not `entry`.
                if prev == ad and start != entry and start in loop:
                    k1 = (entry, ad)
                    k2 = (prev, start)
                    if k2 not in self.rev_deps:
                        self.rev_deps[k2] = set()
                    self.rev_deps[k2].add(k1)
                    self.deps[k1].add(k2)
                    self.loops_all[(entry, ad)].update(self.loops_all[(prev, start)])
534
535
536    def all_false(self, loops_key):
537        for k in loops_key:
538            if k not in self.false_loops:
539                return False
540        return True
541
542
543    # Mark recursively parent loops, all children of a loop must
544    # have been set to false
545    def __rec_mark_parent_false(self, k):
546        self.false_loops.add(k)
547        if k not in self.rev_deps:
548            return
549
550        for par in self.rev_deps[k]:
551            if par in self.false_loops:
552                continue
553
554            if self.all_false(self.deps[par]):
555                self.__rec_mark_parent_false(par)
556
557
558    def __rec_mark_children(self, k, myset):
559        myset.add(k)
560        for sub in self.deps[k]:
561            if sub not in self.false_loops:
562                self.__rec_mark_children(sub, myset)
563
564
565    def __yield_cmp_loops(self, keys1, not_in_false=True):
566        # optim: don't compare twice two loops
567        keys2 = set(keys1)
568        for k1 in keys1:
569            keys2.remove(k1)
570            if not_in_false and k1 in self.false_loops:
571                continue
572            for k2 in keys2:
573                if not_in_false and k2 in self.false_loops:
574                    continue
575                yield k1, k2
576
577
578    # Some heuristics to detect false loops
579    def __search_false_loops(self):
580        # If all previous link of a loop start is inside the loop, this
581        # is a false loop.
582        for (prev, start), l_set in self.loops_all.items():
583            if prev == start:
584                # special case: see tests/entryloop1
585                continue
586
587            l_set_copy = set(l_set)
588            l_set_copy.remove(start)
589            lin = set(self.link_in[start])
590
591            if lin.issubset(l_set):
592                self.false_loops.add((prev, start))
593
594
595        # Find loops at the last level (no loop inside)
596        loops_to_check = set()
597        for k in self.loops_all:
598            if len(self.deps[k]) == 0:
599                loops_to_check.add(k)
600
601
602        # Now try to find other false loops...
603        for (prev1, start1), (prev2, start2) in \
604                    self.__yield_cmp_loops(self.loops_all.keys()):
605
606            if (prev1, start1) not in loops_to_check:
607                continue
608
609            l1 = self.loops_set[(prev1, start1)]
610            l2 = self.loops_set[(prev2, start2)]
611
612            if start1 == start2:
613                continue
614
615            if prev1 in l2 and \
616                 start1 in l2 and \
617                 start2 in l1:
618                if l1.issubset(l2):
619                    self.__rec_mark_parent_false((prev1, start1))
620
621
622        # Make a diff to keep only real loops (false loops will be
623        # deleted by __update_loops)
624
625        correct_loops = set()
626
627        for k in self.roots:
628            if k in self.false_loops:
629                continue
630            self.__rec_mark_children(k, correct_loops)
631
632        self.false_loops = self.loops_all.keys() - correct_loops
633
634
635    def __search_same_deep_equiv_loops(self):
636        #
637        # Search equivalent loops at the same deep, but compare with
638        # 'loops_all' -> each item contains all sub-loops instead of
639        # 'loops_set' wich contains only the loop.
640        #
641        # example:
642        #
643        #      loop1
644        #     /     \
645        #  loop2   loop3
646        #
647        # If loops_all[loop2] == loops_all[loop3], and if loop2 or loop3 is
648        # in false_loops, we remove these loops.
649        #
650
651        def do_add(k1, k2):
652            nonlocal idx_count, set_index, deep_equiv
653            l1 = self.loops_all[k1]
654            l2 = self.loops_all[k2]
655            if l1 == l2:
656                if k1 in set_index:
657                    i = set_index[k1]
658                    deep_equiv[i].add(k2)
659                    self.all_deep_equiv.add(k2)
660                    set_index[k2] = i
661                elif k2 in set_index:
662                    i = set_index[k2]
663                    deep_equiv[i].add(k1)
664                    self.all_deep_equiv.add(k1)
665                    set_index[k1] = i
666                else:
667                    i = idx_count
668                    idx_count += 1
669                    deep_equiv[i] = {k1, k2}
670                    set_index[k1] = i
671                    set_index[k2] = i
672                    self.all_deep_equiv.add(k1)
673                    self.all_deep_equiv.add(k2)
674
675        set_index = {}
676        deep_equiv = {}
677        idx_count = 0
678
679        for k in self.deps:
680            for k1, k2 in self.__yield_cmp_loops(self.deps[k], False):
681                do_add(k1, k2)
682
683        for k1, k2 in self.__yield_cmp_loops(self.roots, False):
684            do_add(k1, k2)
685
686        if not deep_equiv:
687            return
688
689        last_length = 0
690        while last_length != len(self.false_loops):
691            last_length = len(self.false_loops)
692
693            for i, keys in deep_equiv.items():
694                nb_false = 0
695                for k in keys:
696                    if k in self.false_loops:
697                        nb_false += 1
698
699                if nb_false > 0:
700                    for k in set(keys):
701                        if k in self.false_loops:
702                            continue
703                        subs = self.deps[k]
704                        if len(subs) == 0 or self.all_false(subs):
705                            keys.remove(k)
706                            self.__rec_mark_parent_false(k)
707
708
709    def __update_loops(self):
710        def rec_remove(k):
711            if k not in self.loops_all:
712                return
713            del self.loops_all[k]
714            del self.loops_set[k]
715            for sub in self.deps[k]:
716                if sub in self.false_loops:
717                    rec_remove(sub)
718        for k in self.false_loops:
719            if k not in self.rev_deps or k in self.all_deep_equiv:
720                rec_remove(k)
721
722
723    def loop_detection(self, entry, bypass_false_search=False):
724        start = time()
725
726        # Equivalent loops at a same deep in the loops dependencies tree
727        self.deep_equiv = set()
728        # For one loop : contains all address of the loop only
729        self.loops_set = {}
730        # For one loop : contains all address of the loop and sub-loops
731        self.loops_all = {}
732        # Loop dependencies
733        self.deps = {}
734        self.rev_deps = {}
735        # Loops marked as "False"
736        self.false_loops = set()
737
738        if len(self.nodes) > MAX_NODES:
739            self.skipped_loops_analysis = True
740            return
741
742        # Detect loops and compute loop dependencies on the fly
743        self.__explore(entry, set(), set(), {}, None, set())
744
745        # Keep only loops at the first level
746        self.roots = self.loops_set.keys() - self.rev_deps.keys()
747
748        # debug__(self.deps)
749        # debug__(self.roots)
750        # debug__(self.loops_all)
751
752        # Detect 'strange loops'
753        if not bypass_false_search:
754            self.__search_false_loops()
755            self.__search_same_deep_equiv_loops()
756
757        # Remove recursively marked false loops
758        self.__update_loops()
759
760        # debug__(self.loops_all)
761
762        # Compute all addresses which are not in a loop
763        in_loop = set()
764        for l in self.loops_set.items():
765            in_loop.update(l[1])
766
767        self.not_in_loop = self.nodes.keys() - in_loop
768
769        # Search inifinite loops
770        self.infinite_loop = set()
771        for l_curr_loop, l_set in self.loops_all.items():
772            if self.__is_inf_loop(l_set):
773                self.infinite_loop.add(l_curr_loop)
774
775        # Save first address of loops
776        self.loops_start = set()
777        for _, l_start in self.loops_all:
778            self.loops_start.add(l_start)
779
780        # For each loop we search the last node that if we enter in it,
781        # we are sure to return to the loop.
782        self.last_node_loop = {}
783        for (l_prev_loop, l_start), l_set in self.loops_all.items():
784            if (l_prev_loop, l_start) not in self.infinite_loop:
785                self.__search_last_node_loop(l_prev_loop, l_start, l_set)
786
787        if self.debug:
788            self.dot_loop_deps()
789
790        elapsed = time()
791        elapsed = elapsed - start
792        debug__("Exploration: found %d loop(s) in %fs" %
793                (len(self.loops_all), elapsed))
794