# pylint:disable=too-many-boolean-expressions
import logging
from typing import Optional, TYPE_CHECKING

from ailment.statement import Assignment, ConditionalJump, Call
from ailment.expression import Expression, Convert, Tmp, Register, Load, BinaryOp, UnaryOp, Const, ITE

from ...engines.light.data import SpOffset
from ...knowledge_plugins.key_definitions.constants import OP_AFTER
from ...knowledge_plugins.key_definitions import atoms
from ...analyses.reaching_definitions.external_codeloc import ExternalCodeLocation

from .. import Analysis, register_analysis

if TYPE_CHECKING:
    from ailment.block import Block


_l = logging.getLogger(name=__name__)


class BlockSimplifier(Analysis):
    """
    Simplify an AIL block.

    Self-assignments and dead assignments are removed first. After that, rounds of propagation, dead-assignment
    elimination and peephole optimization are repeated until the block stops changing or an iteration cap is hit.
    The final block is stored in ``self.result_block``.
    """

    # Load sizes (in bytes) that can be unpacked from memory as a single integer and folded into a constant.
    _FOLDABLE_LOAD_SIZES = (1, 2, 4, 8)

    def __init__(self, block: Optional['Block'], remove_dead_memdefs=False, stack_pointer_tracker=None):
        """
        :param block:                   The AIL block to simplify. Setting it to None to skip calling self._analyze(),
                                        which is useful in test cases.
        :param remove_dead_memdefs:     Whether definitions of memory locations (and SpOffset atoms) may be removed
                                        when they are dead. When False, such definitions are always kept.
        :param stack_pointer_tracker:   Forwarded to the Propagator analysis in every simplification round.
        """

        self.block = block

        self._remove_dead_memdefs = remove_dead_memdefs
        self._stack_pointer_tracker = stack_pointer_tracker

        # The simplified block; stays None when no block was given.
        self.result_block = None

        if self.block is not None:
            self._analyze()

    def _analyze(self):
        """
        Simplify ``self.block`` until a fixed point is reached and store the result in ``self.result_block``.

        After ``max_ctr`` rounds without convergence this gives up and logs an error. Failing to converge is
        taken to mean that block comparison is broken.
        """

        block = self.block
        ctr = 0
        max_ctr = 30

        block = self._eliminate_self_assignments(block)
        block = self._eliminate_dead_assignments(block)

        while True:
            ctr += 1
            new_block = self._simplify_block_once(block)
            if new_block == block:
                break
            block = new_block
            if ctr >= max_ctr:
                _l.error("Simplification does not reach a fixed point after %d iterations. "
                         "Block comparison is probably incorrect.", max_ctr)
                break

        self.result_block = block

    def _simplify_block_once(self, block):
        """
        Run one simplification round: propagation, dead-assignment elimination, then peephole optimization.

        :param block:   The AIL block to simplify.
        :return:        ``block`` itself when the propagator found nothing to replace, otherwise a new block.
        """

        propagator = self.project.analyses.Propagator(block=block, stack_pointer_tracker=self._stack_pointer_tracker)
        # Only the first propagator state is consulted. Guard against a propagator that produced no state at all
        # instead of raising IndexError.
        state = next(iter(propagator._states.values()), None)
        if state is None:
            return block
        replacements = state._replacements
        if not replacements:
            return block
        new_block = self._replace_and_build(block, replacements)
        new_block = self._eliminate_dead_assignments(new_block)
        new_block = self._peephole_optimize(new_block)
        return new_block

    @staticmethod
    def _replace_and_build(block, replacements):
        """
        Apply propagator replacements to the statements of a block.

        :param block:           The AIL block to rewrite.
        :param replacements:    Maps each code location to a dict of {old expression: new expression}. The code
                                location's ``stmt_idx`` selects the statement to rewrite.
        :return:                A copy of ``block`` holding the rewritten statement list.
        """

        new_statements = block.statements[::]

        for codeloc, repls in replacements.items():
            for old, new in repls.items():
                if isinstance(old, Load):
                    # skip memory-based replacement
                    continue
                stmt = new_statements[codeloc.stmt_idx]
                if stmt == old:
                    # replace this statement
                    r = True
                    new_stmt = new
                elif isinstance(stmt, Call) and isinstance(new, Call) and old == stmt.ret_expr:
                    # special case: do not replace the ret_expr of a call statement to another call statement
                    r = False
                    new_stmt = None
                else:
                    # replace the expressions involved in this statement
                    r, new_stmt = stmt.replace(old, new)

                if r:
                    new_statements[codeloc.stmt_idx] = new_stmt

        return block.copy(statements=new_statements)

    @staticmethod
    def _eliminate_self_assignments(block):
        """
        Drop every Assignment whose destination ``likes()`` its source.

        ``likes`` is presumably a structural comparison that ignores statement indices and tags (TODO confirm
        against ailment).

        :param block:   The AIL block to clean up.
        :return:        A copy of ``block`` without those assignments.
        """

        new_statements = [ ]

        for stmt in block.statements:
            if type(stmt) is Assignment:
                if stmt.dst.likes(stmt.src):
                    continue
            new_statements.append(stmt)

        new_block = block.copy(statements=new_statements)
        return new_block

    def _eliminate_dead_assignments(self, block):
        """
        Remove assignments whose results are never used.

        Reaching definitions are computed over the block, tracking tmps, with an observation point after the
        block. An Assignment statement is dropped when any of the following holds:

          - it writes a tmp that is never used;
          - its definition is dead. That means an unused tmp, or an unused register/stack definition that is
            also no longer live at the end of the block. Memory definitions are only considered when
            ``remove_dead_memdefs`` is set;
          - it writes an artificial register of the architecture;
          - its source equals its destination.

        Non-Assignment statements are always kept.

        :param block:   The AIL block to clean up.
        :return:        ``block`` itself if it has no statements, otherwise a copy with the surviving statements.
        """

        new_statements = [ ]
        if not block.statements:
            return block

        rd = self.project.analyses.ReachingDefinitions(subject=block,
                                                       track_tmps=True,
                                                       observation_points=[('node', block.addr, OP_AFTER)]
                                                       )

        used_tmp_indices = set(rd.one_result.tmp_uses.keys())
        # The state observed after the block; it tells which definitions are still live at the block's end.
        live_defs = rd.one_result

        # Find dead assignments
        dead_defs_stmt_idx = set()
        all_defs = rd.all_definitions
        for d in all_defs:
            if isinstance(d.codeloc, ExternalCodeLocation) or d.dummy:
                continue
            if not self._remove_dead_memdefs and isinstance(d.atom, (atoms.MemoryLocation, SpOffset)):
                continue

            if isinstance(d.atom, atoms.Tmp):
                uses = live_defs.tmp_uses[d.atom.tmp_idx]
                if not uses:
                    dead_defs_stmt_idx.add(d.codeloc.stmt_idx)
            else:
                uses = rd.all_uses.get_uses(d)
                if not uses:
                    # It is entirely possible that, at the end of the block, a register definition is not used.
                    # However, it might be used in later blocks, so a definition is only removed if it is no
                    # longer alive at the end of the block.
                    if isinstance(d.atom, atoms.Register):
                        if d not in live_defs.register_definitions.get_variables_by_offset(d.atom.reg_offset):
                            dead_defs_stmt_idx.add(d.codeloc.stmt_idx)
                    if isinstance(d.atom, atoms.MemoryLocation) and isinstance(d.atom.addr, SpOffset):
                        if d not in live_defs.stack_definitions.get_variables_by_offset(d.atom.addr.offset):
                            dead_defs_stmt_idx.add(d.codeloc.stmt_idx)

        # Remove dead assignments
        for idx, stmt in enumerate(block.statements):
            if type(stmt) is Assignment:
                if type(stmt.dst) is Tmp:
                    if stmt.dst.tmp_idx not in used_tmp_indices:
                        continue

                # is it a dead virgin?
                if idx in dead_defs_stmt_idx:
                    continue

                # is it an assignment to an artificial register?
                if type(stmt.dst) is Register and self.project.arch.is_artificial_register(stmt.dst.reg_offset,
                                                                                            stmt.dst.size):
                    continue

                if stmt.src == stmt.dst:
                    continue

            new_statements.append(stmt)

        new_block = block.copy(statements=new_statements)
        return new_block

    #
    # Peephole optimization
    #

    def _peephole_optimize(self, block):
        """
        Apply statement-level peephole rewrites to every statement of a block.

        Assignments go through constant-dereference folding. Conditional jumps go through condition and target
        simplification.

        :param block:   The AIL block to optimize.
        :return:        ``block`` itself when no statement was rewritten, otherwise a copy with the new statements.
        """

        statements = [ ]
        any_update = False
        for stmt in block.statements:
            new_stmt = None
            if isinstance(stmt, Assignment):
                new_stmt = self._peephole_optimize_ConstantDereference(stmt)

            elif isinstance(stmt, ConditionalJump):
                new_stmt = self._peephole_optimize_ConditionalJump(stmt)

            if new_stmt is not None and new_stmt is not stmt:
                statements.append(new_stmt)
                any_update = True
                continue

            statements.append(stmt)

        if not any_update:
            return block
        new_block = block.copy(statements=statements)
        return new_block

    def _peephole_optimize_ConstantDereference(self, stmt: Assignment):
        """
        Fold ``dst = Load(Const addr)`` into ``dst = Const(value)``.

        The fold only happens when the address lies in a readable, non-writable section and the loader can
        provide the bytes.

        :param stmt:    The assignment to inspect.
        :return:        A new Assignment with the same idx and tags if folded, otherwise ``stmt`` itself.
        """
        if isinstance(stmt.src, Load) and isinstance(stmt.src.addr, Const):
            load = stmt.src
            # is it loading from a read-only section?
            sec = self.project.loader.find_section_containing(load.addr.value)
            if sec is not None and sec.is_readable and not sec.is_writable:
                # Only sizes that unpack to a single integer can be folded; leave anything else untouched.
                if load.size not in self._FOLDABLE_LOAD_SIZES:
                    return stmt
                # do we know the value that it's reading?
                try:
                    # Read exactly as many bytes as the load does. Reading a full machine word would mix
                    # neighbouring bytes into the value (and overflow the constant) for narrower loads.
                    # NOTE(review): the load's own endness is not passed here; this assumes it matches the
                    # default used by unpack_word — confirm for big-endian loads.
                    val = self.project.loader.memory.unpack_word(load.addr.value, size=load.size)
                except KeyError:
                    return stmt

                return Assignment(stmt.idx, stmt.dst,
                                  Const(None, None, val, load.size * self.project.arch.byte_width),
                                  **stmt.tags,
                                  )

        return stmt

    def _peephole_optimize_ConditionalJump(self, stmt: ConditionalJump):
        """
        Simplify the condition and targets of a conditional jump.

        Rewrites are applied recursively until none of them changes the statement any more.

        :param stmt:    The conditional jump to simplify.
        :return:        ``stmt`` itself if nothing changed, otherwise a new ConditionalJump with the same idx and tags.
        """

        new_condition = self._peephole_optimize_Expr(stmt.condition)

        # if (!cond) {} else { ITE(cond, true_branch, false_branch) } ==> if (cond) { ITE(...) } else {}
        if isinstance(stmt.false_target, ITE) and \
                isinstance(new_condition, UnaryOp) and \
                new_condition.op == "Not":
            new_true_target = stmt.false_target
            new_false_target = stmt.true_target
            new_condition = new_condition.operand
        else:
            new_true_target = stmt.true_target
            new_false_target = stmt.false_target

        if new_condition is not stmt.condition or \
                new_true_target is not stmt.true_target or \
                new_false_target is not stmt.false_target:
            # it's updated
            return self._peephole_optimize_ConditionalJump(
                ConditionalJump(stmt.idx, new_condition, new_true_target, new_false_target, **stmt.tags)
            )

        # if (cond) {ITE(cond, true_branch, false_branch)} else {} ==> if (cond) {true_branch} else {}
        if isinstance(stmt.true_target, ITE) and new_condition == stmt.true_target.cond:
            new_true_target = stmt.true_target.iftrue
        else:
            new_true_target = stmt.true_target

        if new_condition is not stmt.condition or new_true_target is not stmt.true_target:
            # it's updated
            return self._peephole_optimize_ConditionalJump(
                ConditionalJump(stmt.idx, new_condition, new_true_target, stmt.false_target, **stmt.tags)
            )

        return stmt

    def _peephole_optimize_Expr(self, expr: Expression):
        """
        Apply expression-level peephole rewrites.

        :param expr:    The expression to simplify.
        :return:        The rewritten expression, or ``expr`` itself when no rule applies.
        """

        # Convert(N->1, Convert(1->N, t_x) ^ 0x1<N>) ==> Not(t_x)
        # This is only a logical negation when both conversions really go through a single bit. Otherwise the Xor
        # flips one bit of a wider value, or the outer result is wider than one bit.
        if isinstance(expr, Convert) and \
                expr.to_bits == 1 and \
                isinstance(expr.operand, BinaryOp) and \
                expr.operand.op == "Xor" and \
                isinstance(expr.operand.operands[0], Convert) and \
                expr.operand.operands[0].from_bits == 1 and \
                isinstance(expr.operand.operands[1], Const) and \
                expr.operand.operands[1].value == 1:
            new_expr = UnaryOp(None, "Not", expr.operand.operands[0].operand)
            return self._peephole_optimize_Expr(new_expr)

        return expr


register_analysis(BlockSimplifier, 'AILBlockSimplifier')