1# pylint:disable=too-many-boolean-expressions
2import logging
3from typing import Optional, TYPE_CHECKING
4
5from ailment.statement import Assignment, ConditionalJump, Call
6from ailment.expression import Expression, Convert, Tmp, Register, Load, BinaryOp, UnaryOp, Const, ITE
7
8from ...engines.light.data import SpOffset
9from ...knowledge_plugins.key_definitions.constants import OP_AFTER
10from ...knowledge_plugins.key_definitions import atoms
11from ...analyses.reaching_definitions.external_codeloc import ExternalCodeLocation
12
13from .. import Analysis, register_analysis
14
15if TYPE_CHECKING:
16    from ailment.block import Block
17
18
# Module-level logger, named after this module.
_l = logging.getLogger(name=__name__)
20
21
22class BlockSimplifier(Analysis):
23    """
24    Simplify an AIL block.
25    """
26    def __init__(self, block: Optional['Block'], remove_dead_memdefs=False, stack_pointer_tracker=None):
27        """
28        :param block:   The AIL block to simplify. Setting it to None to skip calling self._analyze(), which is useful
29                        in test cases.
30        """
31
32        self.block = block
33
34        self._remove_dead_memdefs = remove_dead_memdefs
35        self._stack_pointer_tracker = stack_pointer_tracker
36
37        self.result_block = None
38
39        if self.block is not None:
40            self._analyze()
41
42    def _analyze(self):
43
44        block = self.block
45        ctr = 0
46        max_ctr = 30
47
48        block = self._eliminate_self_assignments(block)
49        block = self._eliminate_dead_assignments(block)
50
51        while True:
52            ctr += 1
53            # print(str(block))
54            new_block = self._simplify_block_once(block)
55            # print()
56            # print(str(new_block))
57            if new_block == block:
58                break
59            block = new_block
60            if ctr >= max_ctr:
61                _l.error("Simplification does not reach a fixed point after %d iterations. "
62                         "Block comparison is probably incorrect.", max_ctr)
63                break
64
65        self.result_block = block
66
67    def _simplify_block_once(self, block):
68
69        # propagator
70        propagator = self.project.analyses.Propagator(block=block, stack_pointer_tracker=self._stack_pointer_tracker)
71        replacements = list(propagator._states.values())[0]._replacements
72        if not replacements:
73            return block
74        new_block = self._replace_and_build(block, replacements)
75        new_block = self._eliminate_dead_assignments(new_block)
76        new_block = self._peephole_optimize(new_block)
77        return new_block
78
79    @staticmethod
80    def _replace_and_build(block, replacements):
81
82        new_statements = block.statements[::]
83
84        for codeloc, repls in replacements.items():
85            for old, new in repls.items():
86                if isinstance(old, Load):
87                    # skip memory-based replacement
88                    continue
89                stmt = new_statements[codeloc.stmt_idx]
90                if stmt == old:
91                    # replace this statement
92                    r = True
93                    new_stmt = new
94                else:
95                    # replace the expressions involved in this statement
96                    if isinstance(stmt, Call) and isinstance(new, Call) and old == stmt.ret_expr:
97                        # special case: do not replace the ret_expr of a call statement to another call statement
98                        r = False
99                        new_stmt = None
100                    else:
101                        r, new_stmt = stmt.replace(old, new)
102
103                if r:
104                    new_statements[codeloc.stmt_idx] = new_stmt
105
106        new_block = block.copy()
107        new_block.statements = new_statements
108        return new_block
109
110    @staticmethod
111    def _eliminate_self_assignments(block):
112
113        new_statements = [ ]
114
115        for stmt in block.statements:
116            if type(stmt) is Assignment:
117                if stmt.dst.likes(stmt.src):
118                    continue
119            new_statements.append(stmt)
120
121        new_block = block.copy(statements=new_statements)
122        return new_block
123
124    def _eliminate_dead_assignments(self, block):
125
126        new_statements = [ ]
127        if not block.statements:
128            return block
129
130        rd = self.project.analyses.ReachingDefinitions(subject=block,
131                                                       track_tmps=True,
132                                                       observation_points=[('node', block.addr, OP_AFTER)]
133                                                       )
134
135        used_tmp_indices = set(rd.one_result.tmp_uses.keys())
136        live_defs = rd.one_result
137
138        # Find dead assignments
139        dead_defs_stmt_idx = set()
140        all_defs = rd.all_definitions
141        for d in all_defs:
142            if isinstance(d.codeloc, ExternalCodeLocation) or d.dummy:
143                continue
144            if not self._remove_dead_memdefs and isinstance(d.atom, (atoms.MemoryLocation, SpOffset)):
145                continue
146
147            if isinstance(d.atom, atoms.Tmp):
148                uses = live_defs.tmp_uses[d.atom.tmp_idx]
149                if not uses:
150                    dead_defs_stmt_idx.add(d.codeloc.stmt_idx)
151            else:
152                uses = rd.all_uses.get_uses(d)
153                if not uses:
154                    # is entirely possible that at the end of the block, a register definition is not used.
155                    # however, it might be used in future blocks.
156                    # so we only remove a definition if the definition is not alive anymore at the end of the block
157                    if isinstance(d.atom, atoms.Register):
158                        if d not in live_defs.register_definitions.get_variables_by_offset(d.atom.reg_offset):
159                            dead_defs_stmt_idx.add(d.codeloc.stmt_idx)
160                    if isinstance(d.atom, atoms.MemoryLocation) and isinstance(d.atom.addr, SpOffset):
161                        if d not in live_defs.stack_definitions.get_variables_by_offset(d.atom.addr.offset):
162                            dead_defs_stmt_idx.add(d.codeloc.stmt_idx)
163
164        # Remove dead assignments
165        for idx, stmt in enumerate(block.statements):
166            if type(stmt) is Assignment:
167                if type(stmt.dst) is Tmp:
168                    if stmt.dst.tmp_idx not in used_tmp_indices:
169                        continue
170
171                # is it a dead virgin?
172                if idx in dead_defs_stmt_idx:
173                    continue
174
175                # is it an assignment to an artificial register?
176                if type(stmt.dst) is Register and self.project.arch.is_artificial_register(stmt.dst.reg_offset, stmt.dst.size):
177                    continue
178
179                if stmt.src == stmt.dst:
180                    continue
181
182            new_statements.append(stmt)
183
184        new_block = block.copy(statements=new_statements)
185        return new_block
186
187    #
188    # Peephole optimization
189    #
190
191    def _peephole_optimize(self, block):
192
193        statements = [ ]
194        any_update = False
195        for stmt in block.statements:
196            new_stmt = None
197            if isinstance(stmt, Assignment):
198                new_stmt = self._peephole_optimize_ConstantDereference(stmt)
199
200            elif isinstance(stmt, ConditionalJump):
201                new_stmt = self._peephole_optimize_ConditionalJump(stmt)
202
203            if new_stmt is not None and new_stmt is not stmt:
204                statements.append(new_stmt)
205                any_update = True
206                continue
207
208            statements.append(stmt)
209
210        if not any_update:
211            return block
212        new_block = block.copy(statements=statements)
213        return new_block
214
215    def _peephole_optimize_ConstantDereference(self, stmt: Assignment):
216        if isinstance(stmt.src, Load) and isinstance(stmt.src.addr, Const):
217            # is it loading from a read-only section?
218            sec = self.project.loader.find_section_containing(stmt.src.addr.value)
219            if sec is not None and sec.is_readable and not sec.is_writable:
220                # do we know the value that it's reading?
221                try:
222                    val = self.project.loader.memory.unpack_word(stmt.src.addr.value, size=self.project.arch.bytes)
223                except KeyError:
224                    return stmt
225
226                return Assignment(stmt.idx, stmt.dst,
227                                  Const(None, None, val, stmt.src.size * self.project.arch.byte_width),
228                                  **stmt.tags,
229                                  )
230
231        return stmt
232
233    def _peephole_optimize_ConditionalJump(self, stmt: ConditionalJump):
234
235        new_condition = self._peephole_optimize_Expr(stmt.condition)
236
237        # if (!cond) {} else { ITE(cond, true_branch, false_branch } ==> if (cond) { ITE(...) } else {}
238        if isinstance(stmt.false_target, ITE) and \
239                isinstance(new_condition, UnaryOp) and \
240                new_condition.op == "Not":
241            new_true_target = stmt.false_target
242            new_false_target = stmt.true_target
243            new_condition = new_condition.operand
244        else:
245            new_true_target = stmt.true_target
246            new_false_target = stmt.false_target
247
248        if new_condition is not stmt.condition or \
249                new_true_target is not stmt.true_target or \
250                new_false_target is not stmt.false_target:
251            # it's updated
252            return self._peephole_optimize_ConditionalJump(
253                ConditionalJump(stmt.idx, new_condition, new_true_target, new_false_target, **stmt.tags)
254            )
255
256        # if (cond) {ITE(cond, true_branch, false_branch)} else {} ==> if (cond) {true_branch} else {}
257        if isinstance(stmt.true_target, ITE) and new_condition == stmt.true_target.cond:
258            new_true_target = stmt.true_target.iftrue
259        else:
260            new_true_target = stmt.true_target
261
262        if new_condition is not stmt.condition or new_true_target is not stmt.true_target:
263            # it's updated
264            return self._peephole_optimize_ConditionalJump(
265                ConditionalJump(stmt.idx, new_condition, new_true_target, stmt.false_target, **stmt.tags)
266            )
267
268        return stmt
269
270    def _peephole_optimize_Expr(self, expr: Expression):
271
272        # Convert(N->1, (Convert(1->N, t_x) ^ 0x1<N>) ==> Not(t_x)
273        if isinstance(expr, Convert) and \
274                isinstance(expr.operand, BinaryOp) and \
275                expr.operand.op == "Xor" and \
276                isinstance(expr.operand.operands[0], Convert) and \
277                isinstance(expr.operand.operands[1], Const) and \
278                expr.operand.operands[1].value == 1:
279            new_expr = UnaryOp(None, "Not", expr.operand.operands[0].operand)
280            return self._peephole_optimize_Expr(new_expr)
281
282        return expr
283
284
# Register BlockSimplifier under the analysis name 'AILBlockSimplifier'.
register_analysis(BlockSimplifier, 'AILBlockSimplifier')
286