#!/usr/local/bin/python3.8
#
# PLASMA : Generate an indented asm code (pseudo-C) with colored syntax.
# Copyright (C) 2015 Joel
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

from time import time

from plasma.lib.utils import BRANCH_NEXT, BRANCH_NEXT_JUMP, debug__, list_starts_with

# For the loop's analysis: above this number of nodes, loop detection
# is skipped (see Graph.loop_detection).
MAX_NODES = 800


# This class is used only for the decompilation mode. The analyzer also
# creates a graph, but only on the fly.

class Graph:
    """Control-flow graph of one function, used for decompilation.

    Nodes are blocks of instructions keyed by the address of their first
    instruction. ``link_out`` / ``link_in`` store successor / predecessor
    addresses. ``loop_detection`` fills the loop-related attributes
    (loops_set, loops_all, deps, false_loops, infinite_loop, ...).
    """

    def __init__(self, dis, entry_point_addr):
        # Each node contains a block (list) of instructions.
        self.nodes = {}  # ad -> [instruction, (prefetch)]

        # For each address block, we store a list of next blocks.
        # If there are 2 elements it means that the preceding instruction
        # was a conditional jump:
        # 1st : direct next instruction
        # 2nd : for conditional jump : address of the jump
        self.link_out = {}  # ad -> [nxt1, nxt2]

        self.link_in = {}  # ad -> [prev, ...]

        self.entry_point_addr = entry_point_addr
        self.dis = dis

        # For one loop : contains all addresses of the loop only
        self.loops_set = {}

        # For one loop : contains all addresses of the loop and sub-loops
        self.loops_all = {}

        # Rest of all addresses which are not in a loop
        self.not_in_loop = set()

        self.loops_start = set()

        # Optimization
        self.cond_jumps_set = set()
        self.uncond_jumps_set = set()

        self.equiv = {}
        self.false_loops = set()

        # Loop dependencies
        self.deps = {}
        self.rev_deps = {}

        # (from_addr, to_addr, loop_start) -> bool, see path_exists
        self.cache_path_exists = {}

        # For each loop we search the last node that if we enter in it,
        # we are sure to return to the loop.
        self.last_node_loop = {}

        self.all_deep_equiv = set()

        # Loops at the first level of the dependency tree and loops that
        # never exit; both are (re)computed by loop_detection. Initialized
        # here so they exist even when loop detection is skipped.
        self.roots = set()
        self.infinite_loop = set()
        self.deep_equiv = set()

        self.skipped_loops_analysis = False

        self.exit_or_ret = set()

        # Could be modified by AddrContext.decompile
        self.debug = False


    # A jump is normally alone in a block, but for some architectures
    # we save the prefetched instruction after.

    def new_node(self, curr, prefetch, nxt):
        """Add a node starting with instruction ``curr``.

        curr: instruction whose ``address`` becomes the node key.
        prefetch: optional instruction stored right after ``curr`` in the
                  block (None when there is none).
        nxt: list of successor addresses (layout described on link_out),
             or None when the node has no successor.
        """
        ad = curr.address
        self.nodes[ad] = [curr]

        if nxt is not None:
            self.link_out[ad] = nxt
            for n in nxt:
                self.link_in.setdefault(n, []).append(ad)

        if prefetch is not None:
            self.nodes[ad].append(prefetch)


    def exists(self, inst):
        """Return True if a node starts at ``inst.address``."""
        return inst.address in self.nodes


    # Concat instructions in single block
    # jumps are in separated blocks
    def simplify(self):
        """Merge straight-line chains of nodes into single blocks.

        A node ``ad`` is appended to its predecessor ``pred`` when ``ad`` is
        not the entry point, ``ad`` has exactly one predecessor, ``pred``
        has exactly one successor, and neither of them is in
        cond_jumps_set / uncond_jumps_set.
        """
        nodes = list(self.nodes.keys())
        start = time()

        for ad in nodes:
            if ad in self.uncond_jumps_set or ad in self.cond_jumps_set:
                continue

            if ad not in self.link_in or len(self.link_in[ad]) != 1 or \
                    ad == self.entry_point_addr:
                continue

            pred = self.link_in[ad][0]

            # don't fuse with jumps
            if pred in self.uncond_jumps_set or pred in self.cond_jumps_set:
                continue

            if pred not in self.link_out or len(self.link_out[pred]) != 1:
                continue

            # pred now inherits the successors of ad (if any)
            if ad in self.link_out:
                self.link_out[pred] = self.link_out[ad]
            else:
                del self.link_out[pred]

            self.nodes[pred] += self.nodes[ad]

            if ad in self.link_out:
                del self.link_out[ad]

            del self.link_in[ad]
            del self.nodes[ad]

            # replace all addresses which refer to ad
            for lst_i in self.link_in.values():
                if ad in lst_i:
                    lst_i[lst_i.index(ad)] = pred

        elapsed = time()
        elapsed = elapsed - start
        debug__("Graph simplified in %fs (%d nodes)" % (elapsed, len(self.nodes)))


    def dot_loop_deps(self):
        """Write the loop dependency tree to graph_loop_deps.dot (debug).

        Loops in false_loops are filled with #B6FFDD and loops in
        all_deep_equiv get a #ff0000 border.
        """
        with open("graph_loop_deps.dot", "w+") as output:
            output.write('digraph {\n')
            output.write('node [fontname="liberation mono" style=filled fillcolor=white shape=box];\n')

            for k, dp in self.deps.items():
                output.write('node_%x_%x [label="(%x, %x)"' % (k[0], k[1], k[0], k[1]))

                if k in self.false_loops:
                    output.write(' fillcolor="#B6FFDD"')

                if k in self.all_deep_equiv:
                    output.write(' color="#ff0000"')

                output.write('];\n')

                for sub in dp:
                    output.write('node_%x_%x -> node_%x_%x;\n'
                                 % (k[0], k[1], sub[0], sub[1]))

            output.write('}\n')


    def dot_graph(self, jmptables):
        """Write the control-flow graph to graph.dot (debug).

        jmptables: mapping node address -> object whose ``table`` attribute
                   lists the jump targets drawn from that node (assumed to
                   be the jump tables found by the disassembler; confirm
                   against callers).
        """
        with open("graph.dot", "w+") as output:
            output.write('digraph {\n')
            # output.write('graph [bgcolor="#aaaaaa" pad=20];\n')
            # output.write('node [fontname="liberation mono" style=filled fillcolor="#333333" fontcolor="#d3d3d3" shape=box];\n')
            output.write('node [fontname="liberation mono" style=filled fillcolor=white shape=box];\n')

            keys = list(self.nodes.keys())
            keys.sort()

            for k in keys:
                lst_i = self.nodes[k]

                output.write('node_%x [label="' % k)

                for i in lst_i:
                    output.write('0x%x: %s %s\\l' % (i.address, i.mnemonic, i.op_str))

                output.write('"')

                if k in self.loops_start:
                    output.write(' fillcolor="#FFFCC4"')
                elif k not in self.link_out:
                    output.write(' fillcolor="#ff7777"')
                elif k not in self.link_in:
                    output.write(' fillcolor="#B6FFDD"')

                output.write('];\n')

            for k, i in self.link_out.items():
                if k in jmptables:
                    for ad in jmptables[k].table:
                        output.write('node_%x -> node_%x;\n' % (k, ad))
                elif len(i) == 2:
                    # true green branch (jump is taken)
                    output.write('node_%x -> node_%x [color="#58DA9C"];\n'
                                 % (k, i[BRANCH_NEXT_JUMP]))

                    # false red branch (jump is not taken)
                    output.write('node_%x -> node_%x [color="#ff7777"];\n'
                                 % (k, i[BRANCH_NEXT]))
                else:
                    output.write('node_%x -> node_%x;\n' % (k, i[BRANCH_NEXT]))

            output.write('}')


    def __search_last_node_loop(self, l_prev_loop, l_start, l_set):
        """Fill last_node_loop for the loop (l_prev_loop, l_start).

        Walks backwards from the in-loop predecessors of l_start. A node
        ``ad`` is recorded when one of its predecessors has a successor
        outside l_set; otherwise the walk continues through that
        predecessor.
        """
        # Start from the end of the loop

        stack = []
        visited = {l_start}
        for prev in self.link_in[l_start]:
            if prev in l_set:
                stack.append(prev)

        res = []

        while stack:
            ad = stack.pop(-1)
            if ad in visited or ad not in l_set:
                continue
            visited.add(ad)

            for prev in self.link_in[ad]:
                if prev == l_start:
                    continue

                go_out = False

                for n in self.link_out[prev]:
                    if n not in l_set:
                        res.append(ad)
                        go_out = True

                if not go_out:
                    stack.append(prev)

        for ad in res:
            self.last_node_loop.setdefault(ad, set()).add((l_prev_loop, l_start))


    def __is_inf_loop(self, l_set):
        """Return True if no node of l_set has a successor outside l_set."""
        for ad in l_set:
            if ad in self.link_out:
                for nxt in self.link_out[ad]:
                    if nxt not in l_set:
                        return False
        return True


    def path_exists(self, from_addr, to_addr, loop_start):
        """Return True if to_addr is reachable from from_addr.

        The search is not continued through loop_start nor through nodes
        without successors. Results are memoized in cache_path_exists.
        """
        if from_addr == to_addr:
            return True

        if from_addr not in self.link_out:
            return False

        # loop_start bounds the search, so it must be part of the key:
        # otherwise a result computed for one loop would be reused for
        # another one.
        key = (from_addr, to_addr, loop_start)
        if key in self.cache_path_exists:
            return self.cache_path_exists[key]

        res = False
        visited = set()
        stack = list(self.link_out[from_addr])
        while stack:
            curr = stack.pop(-1)
            if curr == to_addr:
                res = True
                break
            if curr in visited:
                continue
            visited.add(curr)
            if curr not in self.link_out or curr == loop_start:
                continue
            stack.extend(self.link_out[curr])

        self.cache_path_exists[key] = res
        return res


    # Returns a set containing every address which is in paths from
    # 'from_addr' to 'to_addr'.
    def find_paths(self, from_addr, to_addr, global_visited):
        """Return the set of addresses lying on paths from from_addr to to_addr.

        Nodes in global_visited are not expanded. A node already visited
        during this search is not expanded again, so the result may leave
        out some nodes (e.g. nodes of inner cycles). Recursive: depth grows
        with the length of the explored paths.
        """
        def __rec_find_paths(curr, local_visited, path_set):
            nonlocal isfirst
            # The very first call has curr == from_addr; when searching a
            # cycle (from_addr == to_addr) it must not stop immediately.
            if curr == to_addr and not isfirst:
                path_set.add(curr)
                return
            isfirst = False
            if curr in local_visited:
                return
            local_visited.add(curr)
            if curr in global_visited or curr not in self.link_out:
                return
            for n in self.link_out[curr]:
                __rec_find_paths(n, local_visited, path_set)

                if n in path_set:
                    path_set.add(curr)

        isfirst = True
        path_set = set()
        local_visited = set()
        __rec_find_paths(from_addr, local_visited, path_set)
        return path_set


    def __try_find_loops(self, entry, waiting, par_loops, l_set, is_sub_loop):
        """Record every waiting node that starts a cycle as a loop of entry.

        For each waiting address ``ad`` (restricted to l_set when given)
        not already in loops_set, the paths from ad back to ad (not
        entering par_loops) are searched. A non-empty result is stored in
        loops_set[(entry, ad)] and ad is added to is_sub_loop.

        Returns: {ad: (entry, ad)} for the loops found by this call.
        """
        detected_loops = {}
        keys = set(waiting.keys())

        for ad in keys:
            if l_set is not None and ad not in l_set:
                continue

            if (entry, ad) in self.loops_set:
                continue

            # search a path from ad to ad, but don't return in parent loops
            l = self.find_paths(ad, ad, par_loops)

            # If the set is empty, it's not a loop
            if l:
                self.loops_set[(entry, ad)] = l
                is_sub_loop.add(ad)
                detected_loops[ad] = (entry, ad)

        return detected_loops


    # This function removes entries in waiting list if we have seen
    # all previous nodes for one node.
    def __manage_waiting(self, stack, visited, waiting, l_set, done):
        """Release waiting nodes whose predecessors have all been seen.

        Each such node (restricted to l_set when given) is removed from
        waiting, added to done and pushed on the stack as (-1, ad).
        ``visited`` is unused.
        """
        keys = set(waiting.keys())
        for ad in keys:
            if l_set is not None and ad not in l_set:
                continue
            if len(waiting[ad]) == 0:
                del waiting[ad]
                done.add(ad)
                stack.append((-1, ad))


    # This function reads all nodes in the current stack.
    # Each node is added to the waiting list if all previous
    # nodes were not seen.
    # Then if there is an out-link, this new node is added to
    # the stack, and it returns True.
    def __until_stack_empty(self, stack, waiting, visited,
                            par_loops, l_set, is_sub_loop, done):
        """Consume the (prev, ad) stack; return True if anything was pushed.

        waiting maps a node to the set of its predecessors not seen yet.
        Successors in par_loops are never pushed. ``is_sub_loop`` is unused.
        """
        has_moved = False

        while stack:
            prev, ad = stack.pop(-1)

            # if ad has parent nodes, check if we have seen all parents
            if ad in self.link_in and ad not in done:
                l_in = self.link_in[ad]

                if len(l_in) > 1 or (l_set is not None and ad not in l_set):
                    if ad in waiting:
                        if prev in waiting[ad]:
                            waiting[ad].remove(prev)
                    else:
                        unseen = set(l_in)
                        unseen.remove(prev)
                        waiting[ad] = unseen
                    continue

            if ad in visited:
                continue

            visited.add(ad)

            if ad in self.link_out:
                for n in self.link_out[ad]:
                    if n in par_loops:
                        continue
                    stack.append((ad, n))
                    has_moved = True

        return has_moved


    def __get_new_loops(self, waiting, detected_loops, l_set, is_sub_loop):
        """Return the detected loop starts that can be entered now.

        For each detected loop start, predecessors that belong to the loop
        itself are dropped from its waiting set; if nothing else is
        pending, the start is removed from waiting and returned. Also drops
        pending predecessors outside l_set. ``is_sub_loop`` is unused.
        """
        new_loops = set()

        # Remove internal links to the beginning of the loop
        # If later we enter in the loop it means that len(waiting[ad]) == 0
        for ad, k in detected_loops.items():
            loop = self.loops_set[k]

            was_removed = False

            for rest in set(waiting[ad]):
                if rest in loop:
                    waiting[ad].remove(rest)
                    was_removed = True

            if was_removed:
                if len(waiting[ad]) == 0:
                    new_loops.add(ad)
                    del waiting[ad]

        # Remove external jumps which are outside the current loop
        for ad, unseen in waiting.items():
            if l_set is not None and ad not in l_set:
                continue
            for i in set(unseen):
                if l_set is not None and i not in l_set:
                    unseen.remove(i)

        return new_loops


    # It explores the graph and detects loops
    def __explore(self, entry, par_loops, visited, waiting, l_set, done):
        """Explore the graph from entry, detecting loops recursively.

        par_loops: addresses the exploration must not step into (the
                   starts of the enclosing loops).
        visited: nodes already explored.
        waiting: node -> set of predecessors not seen yet.
        l_set: when not None, the nodes of the loop being explored.
        done: nodes released from waiting.

        Fills loops_set, and for the loops found at this level loops_all,
        deps and rev_deps.
        """
        stack = []

        # Check if the first address (entry point of the function) is the
        # beginning of a loop.
        if not visited and entry in self.link_in:
            for p in self.link_in[entry]:
                stack.append((p, entry))
        else:
            if entry in self.link_out:
                for n in self.link_out[entry]:
                    stack.append((entry, n))
            visited.add(entry)

        is_sub_loop = set()

        while True:
            # first: manage pending nodes
            if self.__until_stack_empty(
                    stack, waiting, visited, par_loops, l_set, is_sub_loop, done):
                self.__manage_waiting(stack, visited, waiting, l_set, done)
                continue

            detected_loops = self.__try_find_loops(
                    entry, waiting, par_loops, l_set, is_sub_loop)

            new_loops = self.__get_new_loops(
                    waiting, detected_loops, l_set, is_sub_loop)

            while new_loops:
                # Follow loops
                for ad in new_loops:
                    # TODO : optimize
                    v = set(visited)
                    v.add(ad)
                    pl = set(par_loops)
                    pl.add(ad)

                    l = self.loops_set[(entry, ad)]
                    self.__explore(ad, pl, v, waiting, l, set(done))

                detected_loops = self.__try_find_loops(
                        entry, waiting, par_loops, l_set, is_sub_loop)

                new_loops = self.__get_new_loops(
                        waiting, detected_loops, l_set, is_sub_loop)

            self.__manage_waiting(stack, visited, waiting, l_set, done)

            if not stack:
                break

        # Now for each current loop, we add the content of each sub-loops.
        # It means that a loop contains all sub-loops (which is not the case
        # in loops_set : it contains only the current loop).
        for ad in is_sub_loop:
            loop = set(self.loops_set[(entry, ad)])
            self.loops_all[(entry, ad)] = loop

            self.deps[(entry, ad)] = set()

            for (prev, start), l in self.loops_set.items():
                # Skip current loop
                if (prev, start) == (entry, ad):
                    continue

                # Is it a sub loop ?
                if prev == ad and start != entry and start in loop:
                    k1 = (entry, ad)
                    k2 = (prev, start)
                    if k2 not in self.rev_deps:
                        self.rev_deps[k2] = set()
                    self.rev_deps[k2].add(k1)
                    self.deps[k1].add(k2)
                    self.loops_all[(entry, ad)].update(self.loops_all[(prev, start)])


    def all_false(self, loops_key):
        """Return True if every loop key in loops_key is in false_loops."""
        return all(k in self.false_loops for k in loops_key)


    # Mark recursively parent loops, all children of a loop must
    # have been set to false
    def __rec_mark_parent_false(self, k):
        """Mark k as false, then every parent whose children are all false."""
        self.false_loops.add(k)
        if k not in self.rev_deps:
            return

        for par in self.rev_deps[k]:
            if par in self.false_loops:
                continue

            if self.all_false(self.deps[par]):
                self.__rec_mark_parent_false(par)


    def __rec_mark_children(self, k, myset):
        """Add k and, recursively, its sub-loops not in false_loops to myset."""
        myset.add(k)
        for sub in self.deps[k]:
            if sub not in self.false_loops:
                self.__rec_mark_children(sub, myset)


    def __yield_cmp_loops(self, keys1, not_in_false=True):
        """Yield each unordered pair of keys from keys1 once.

        When not_in_false is True, keys in false_loops are skipped; the
        check is done lazily, so keys marked false while iterating are
        skipped too.
        """
        # optim: don't compare twice two loops
        keys2 = set(keys1)
        for k1 in keys1:
            keys2.remove(k1)
            if not_in_false and k1 in self.false_loops:
                continue
            for k2 in keys2:
                if not_in_false and k2 in self.false_loops:
                    continue
                yield k1, k2


    # Some heuristics to detect false loops
    def __search_false_loops(self):
        """Fill false_loops using heuristics (see comments below)."""
        # If all previous links of a loop start are inside the loop, this
        # is a false loop. start itself belongs to l_set, so a self-edge
        # on start counts as "inside".
        for (prev, start), l_set in self.loops_all.items():
            if prev == start:
                # special case: see tests/entryloop1
                continue

            lin = set(self.link_in[start])

            if lin.issubset(l_set):
                self.false_loops.add((prev, start))

        # Find loops at the last level (no loop inside)
        loops_to_check = {k for k in self.loops_all if len(self.deps[k]) == 0}

        # Now try to find other false loops...
        for (prev1, start1), (prev2, start2) in \
                self.__yield_cmp_loops(self.loops_all.keys()):

            if (prev1, start1) not in loops_to_check:
                continue

            l1 = self.loops_set[(prev1, start1)]
            l2 = self.loops_set[(prev2, start2)]

            if start1 == start2:
                continue

            if prev1 in l2 and \
                    start1 in l2 and \
                    start2 in l1:
                if l1.issubset(l2):
                    self.__rec_mark_parent_false((prev1, start1))

        # Make a diff to keep only real loops (false loops will be
        # deleted by __update_loops)

        correct_loops = set()

        for k in self.roots:
            if k in self.false_loops:
                continue
            self.__rec_mark_children(k, correct_loops)

        self.false_loops = self.loops_all.keys() - correct_loops


    def __search_same_deep_equiv_loops(self):
        #
        # Search equivalent loops at the same depth, but compare with
        # 'loops_all' -> each item contains all sub-loops instead of
        # 'loops_set' which contains only the loop.
        #
        # example:
        #
        #          loop1
        #         /     \
        #      loop2   loop3
        #
        # If loops_all[loop2] == loops_all[loop3], and if loop2 or loop3 is
        # in false_loops, we remove these loops.
        #

        def do_add(k1, k2):
            # Group k1 and k2 in the same equivalence class when their
            # loops_all sets are equal.
            nonlocal idx_count
            l1 = self.loops_all[k1]
            l2 = self.loops_all[k2]
            if l1 == l2:
                if k1 in set_index:
                    i = set_index[k1]
                    deep_equiv[i].add(k2)
                    self.all_deep_equiv.add(k2)
                    set_index[k2] = i
                elif k2 in set_index:
                    i = set_index[k2]
                    deep_equiv[i].add(k1)
                    self.all_deep_equiv.add(k1)
                    set_index[k1] = i
                else:
                    i = idx_count
                    idx_count += 1
                    deep_equiv[i] = {k1, k2}
                    set_index[k1] = i
                    set_index[k2] = i
                    self.all_deep_equiv.add(k1)
                    self.all_deep_equiv.add(k2)

        set_index = {}   # loop key -> index of its class in deep_equiv
        deep_equiv = {}  # class index -> set of equivalent loop keys
        idx_count = 0

        for k in self.deps:
            for k1, k2 in self.__yield_cmp_loops(self.deps[k], False):
                do_add(k1, k2)

        for k1, k2 in self.__yield_cmp_loops(self.roots, False):
            do_add(k1, k2)

        if not deep_equiv:
            return

        # Iterate until no new loop is marked false
        last_length = 0
        while last_length != len(self.false_loops):
            last_length = len(self.false_loops)

            for i, keys in deep_equiv.items():
                nb_false = 0
                for k in keys:
                    if k in self.false_loops:
                        nb_false += 1

                if nb_false > 0:
                    for k in set(keys):
                        if k in self.false_loops:
                            continue
                        subs = self.deps[k]
                        if len(subs) == 0 or self.all_false(subs):
                            keys.remove(k)
                            self.__rec_mark_parent_false(k)


    def __update_loops(self):
        """Remove false loops (and their false sub-loops) from loops_all/loops_set.

        Only false loops that have no parent, or that are in
        all_deep_equiv, are removed as a starting point.
        """
        def rec_remove(k):
            if k not in self.loops_all:
                return
            del self.loops_all[k]
            del self.loops_set[k]
            for sub in self.deps[k]:
                if sub in self.false_loops:
                    rec_remove(sub)

        for k in self.false_loops:
            if k not in self.rev_deps or k in self.all_deep_equiv:
                rec_remove(k)


    def loop_detection(self, entry, bypass_false_search=False):
        """Detect the loops of the graph starting from address entry.

        Skipped (skipped_loops_analysis set to True) when the graph has
        more than MAX_NODES nodes. When bypass_false_search is True, the
        false-loop heuristics are not applied. Fills loops_set, loops_all,
        deps, rev_deps, roots, false_loops, not_in_loop, infinite_loop,
        loops_start and last_node_loop.
        """
        start = time()

        # Equivalent loops at a same depth in the loops dependencies tree
        self.deep_equiv = set()
        # For one loop : contains all addresses of the loop only
        self.loops_set = {}
        # For one loop : contains all addresses of the loop and sub-loops
        self.loops_all = {}
        # Loop dependencies
        self.deps = {}
        self.rev_deps = {}
        # Loops marked as "False"
        self.false_loops = set()
        # Reset results a previous call may have filled, in case
        # loop_detection runs more than once on the same graph.
        self.all_deep_equiv = set()
        self.roots = set()
        self.infinite_loop = set()

        if len(self.nodes) > MAX_NODES:
            self.skipped_loops_analysis = True
            return

        # Detect loops and compute loop dependencies on the fly
        self.__explore(entry, set(), set(), {}, None, set())

        # Keep only loops at the first level
        self.roots = self.loops_set.keys() - self.rev_deps.keys()

        # debug__(self.deps)
        # debug__(self.roots)
        # debug__(self.loops_all)

        # Detect 'strange loops'
        if not bypass_false_search:
            self.__search_false_loops()
            self.__search_same_deep_equiv_loops()

        # Remove recursively marked false loops
        self.__update_loops()

        # debug__(self.loops_all)

        # Compute all addresses which are not in a loop
        in_loop = set()
        for l_set in self.loops_set.values():
            in_loop.update(l_set)

        self.not_in_loop = self.nodes.keys() - in_loop

        # Search infinite loops
        for l_curr_loop, l_set in self.loops_all.items():
            if self.__is_inf_loop(l_set):
                self.infinite_loop.add(l_curr_loop)

        # Save first address of loops
        self.loops_start = set()
        for _, l_start in self.loops_all:
            self.loops_start.add(l_start)

        # For each loop we search the last node that if we enter in it,
        # we are sure to return to the loop.
        self.last_node_loop = {}
        for (l_prev_loop, l_start), l_set in self.loops_all.items():
            if (l_prev_loop, l_start) not in self.infinite_loop:
                self.__search_last_node_loop(l_prev_loop, l_start, l_set)

        if self.debug:
            self.dot_loop_deps()

        elapsed = time()
        elapsed = elapsed - start
        debug__("Exploration: found %d loop(s) in %fs" %
                (len(self.loops_all), elapsed))