1# pylint:disable=arguments-differ
2import logging
3
4from ailment import Stmt, Expr
5
6from ...utils.constants import is_alignment_mask
7from ...engines.light import SimEngineLightAILMixin
8from ...sim_variable import SimStackVariable
9from .engine_base import SimEnginePropagatorBase
10from .values import Top
11
# Logger used by this propagator engine; named after the module.
l = logging.getLogger(__name__)
13
14
15class SimEnginePropagatorAIL(
16    SimEngineLightAILMixin,
17    SimEnginePropagatorBase,
18):
19
20    #
21    # AIL statement handlers
22    #
23
24    def _ail_handle_Assignment(self, stmt):
25        """
26
27        :param Stmt.Assignment stmt:
28        :return:
29        """
30
31        src = self._expr(stmt.src)
32        dst = stmt.dst
33
34        if type(dst) is Expr.Tmp:
35            new_src = self.state.get_variable(src)
36            if new_src is not None:
37                l.debug("%s = %s, replace %s with %s.", dst, src, src, new_src)
38                self.state.store_variable(dst, new_src)
39
40            else:
41                l.debug("Replacing %s with %s.", dst, src)
42                self.state.store_variable(dst, src)
43
44        elif type(dst) is Expr.Register:
45            self.state.store_variable(dst, src)
46            if isinstance(stmt.src, (Expr.Register, Stmt.Call)):
47                # set equivalence
48                self.state.add_equivalence(self._codeloc(), dst, stmt.src)
49        else:
50            l.warning('Unsupported type of Assignment dst %s.', type(dst).__name__)
51
52    def _ail_handle_Store(self, stmt):
53        addr = self._expr(stmt.addr)
54        data = self._expr(stmt.data)
55
56        if isinstance(addr, Expr.StackBaseOffset):
57            if data is not None:
58                # Storing data to a stack variable
59                self.state.store_stack_variable(addr, data.bits // 8, data, endness=stmt.endness)
60                # set equivalence
61                var = SimStackVariable(addr.offset, data.bits // 8)
62                self.state.add_equivalence(self._codeloc(), var, stmt.data)
63
64    def _ail_handle_Jump(self, stmt):
65        target = self._expr(stmt.target)
66        if target == stmt.target:
67            return
68
69        new_jump_stmt = Stmt.Jump(stmt.idx, target, **stmt.tags)
70        self.state.add_replacement(self._codeloc(),
71                                   stmt,
72                                   new_jump_stmt,
73                                   )
74
75    def _ail_handle_Call(self, expr_stmt: Stmt.Call):
76        _ = self._expr(expr_stmt.target)
77
78        if expr_stmt.args:
79            for arg in expr_stmt.args:
80                _ = self._expr(arg)
81
82        if expr_stmt.ret_expr:
83            # it has a return expression. awesome - treat it as an assignment
84            # set equivalence
85            self.state.add_equivalence(self._codeloc(), expr_stmt.ret_expr, expr_stmt)
86
87    def _ail_handle_ConditionalJump(self, stmt):
88        _ = self._expr(stmt.condition)
89        _ = self._expr(stmt.true_target)
90        _ = self._expr(stmt.false_target)
91
92    def _ail_handle_Return(self, stmt: Stmt.Return):
93        if stmt.ret_exprs:
94            for ret_expr in stmt.ret_exprs:
95                self._expr(ret_expr)
96
97    #
98    # AIL expression handlers
99    #
100
101    def _ail_handle_Tmp(self, expr):
102        new_expr = self.state.get_variable(expr)
103
104        if new_expr is not None:
105            # check if this new_expr uses any expression that has been overwritten
106            new_value = self._expr(new_expr)
107            if new_value != new_expr:
108                return expr
109
110            l.debug("Add a replacement: %s with %s", expr, new_expr)
111            self.state.add_replacement(self._codeloc(), expr, new_expr)
112            if type(new_expr) in [Expr.Register, Expr.Const, Expr.Convert, Expr.StackBaseOffset, Expr.BasePointerOffset]:
113                expr = new_expr
114
115        return expr
116
117    def _ail_handle_Register(self, expr):
118        # Special handling for SP and BP
119        if self._stack_pointer_tracker is not None:
120            if expr.reg_offset == self.arch.sp_offset:
121                sb_offset = self._stack_pointer_tracker.offset_before(self.ins_addr, self.arch.sp_offset)
122                if sb_offset is not None:
123                    new_expr = Expr.StackBaseOffset(None, self.arch.bits, sb_offset)
124                    self.state.add_replacement(self._codeloc(), expr, new_expr)
125                    return new_expr
126            elif expr.reg_offset == self.arch.bp_offset:
127                sb_offset = self._stack_pointer_tracker.offset_before(self.ins_addr, self.arch.bp_offset)
128                if sb_offset is not None:
129                    new_expr = Expr.StackBaseOffset(None, self.arch.bits, sb_offset)
130                    self.state.add_replacement(self._codeloc(), expr, new_expr)
131                    return new_expr
132
133        new_expr = self.state.get_variable(expr)
134        if new_expr is not None:
135            l.debug("Add a replacement: %s with %s", expr, new_expr)
136            self.state.add_replacement(self._codeloc(), expr, new_expr)
137            expr = new_expr
138        return expr
139
140    def _ail_handle_Load(self, expr):
141        addr = self._expr(expr.addr)
142
143        if type(addr) is Top:
144            return Top(expr.size)
145
146        if isinstance(addr, Expr.StackBaseOffset):
147            var = self.state.get_stack_variable(addr, expr.size, endness=expr.endness)
148            if var is not None:
149                return var
150
151        if addr != expr.addr:
152            return Expr.Load(expr.idx, addr, expr.size, expr.endness, **expr.tags)
153        return expr
154
155    def _ail_handle_Convert(self, expr):
156        operand_expr = self._expr(expr.operand)
157
158        if type(operand_expr) is Top:
159            return Top(expr.to_bits // 8)
160
161        if type(operand_expr) is Expr.Convert:
162            if expr.from_bits == operand_expr.to_bits and expr.to_bits == operand_expr.from_bits:
163                # eliminate the redundant Convert
164                return operand_expr.operand
165            else:
166                return Expr.Convert(expr.idx, operand_expr.from_bits, expr.to_bits, expr.is_signed, operand_expr.operand)
167        elif type(operand_expr) is Expr.Const:
168            # do the conversion right away
169            value = operand_expr.value
170            mask = (2 ** expr.to_bits) - 1
171            value &= mask
172            return Expr.Const(expr.idx, operand_expr.variable, value, expr.to_bits)
173
174        converted = Expr.Convert(expr.idx, expr.from_bits, expr.to_bits, expr.is_signed, operand_expr, **expr.tags)
175        return converted
176
177    def _ail_handle_Const(self, expr):
178        return expr
179
180    def _ail_handle_DirtyExpression(self, expr):  # pylint:disable=no-self-use
181        return expr
182
183    def _ail_handle_ITE(self, expr: Expr.ITE):
184        cond = self._expr(expr.cond)  # pylint:disable=unused-variable
185        iftrue = self._expr(expr.iftrue)  # pylint:disable=unused-variable
186        iffalse = self._expr(expr.iffalse)  # pylint:disable=unused-variable
187
188        return expr
189
190    def _ail_handle_CallExpr(self, expr_stmt: Stmt.Call):  # pylint:disable=useless-return
191        _ = self._expr(expr_stmt.target)
192
193        if expr_stmt.args:
194            for arg in expr_stmt.args:
195                _ = self._expr(arg)
196
197        # ignore ret_expr
198        return expr_stmt
199
200    def _ail_handle_CmpLE(self, expr):
201        operand_0 = self._expr(expr.operands[0])
202        operand_1 = self._expr(expr.operands[1])
203
204        if type(operand_0) is Top or type(operand_1) is Top:
205            return Top(1)
206
207        return Expr.BinaryOp(expr.idx, 'CmpLE', [ operand_0, operand_1 ], expr.signed, **expr.tags)
208
209    def _ail_handle_CmpLEs(self, expr):
210        operand_0 = self._expr(expr.operands[0])
211        operand_1 = self._expr(expr.operands[1])
212
213        if type(operand_0) is Top or type(operand_1) is Top:
214            return Top(1)
215
216        return Expr.BinaryOp(expr.idx, 'CmpLEs', [ operand_0, operand_1 ], expr.signed, **expr.tags)
217
218    def _ail_handle_CmpLT(self, expr):
219        operand_0 = self._expr(expr.operands[0])
220        operand_1 = self._expr(expr.operands[1])
221
222        if type(operand_0) is Top or type(operand_1) is Top:
223            return Top(1)
224
225        return Expr.BinaryOp(expr.idx, 'CmpLT', [ operand_0, operand_1 ], expr.signed, **expr.tags)
226
227    def _ail_handle_CmpLTs(self, expr):
228        operand_0 = self._expr(expr.operands[0])
229        operand_1 = self._expr(expr.operands[1])
230
231        if type(operand_0) is Top or type(operand_1) is Top:
232            return Top(1)
233
234        return Expr.BinaryOp(expr.idx, 'CmpLTs', [ operand_0, operand_1 ], expr.signed, **expr.tags)
235
236    def _ail_handle_CmpGE(self, expr):
237        operand_0 = self._expr(expr.operands[0])
238        operand_1 = self._expr(expr.operands[1])
239
240        if type(operand_0) is Top or type(operand_1) is Top:
241            return Top(1)
242
243        return Expr.BinaryOp(expr.idx, 'CmpGE', [ operand_0, operand_1 ], expr.signed, **expr.tags)
244
245    def _ail_handle_CmpGEs(self, expr):
246        operand_0 = self._expr(expr.operands[0])
247        operand_1 = self._expr(expr.operands[1])
248
249        if type(operand_0) is Top or type(operand_1) is Top:
250            return Top(1)
251
252        return Expr.BinaryOp(expr.idx, 'CmpGEs', [ operand_0, operand_1 ], expr.signed, **expr.tags)
253
254    def _ail_handle_CmpGT(self, expr):
255        operand_0 = self._expr(expr.operands[0])
256        operand_1 = self._expr(expr.operands[1])
257
258        if type(operand_0) is Top or type(operand_1) is Top:
259            return Top(1)
260
261        return Expr.BinaryOp(expr.idx, 'CmpGT', [ operand_0, operand_1 ], expr.signed, **expr.tags)
262
263    def _ail_handle_CmpGTs(self, expr):
264        operand_0 = self._expr(expr.operands[0])
265        operand_1 = self._expr(expr.operands[1])
266
267        if type(operand_0) is Top or type(operand_1) is Top:
268            return Top(1)
269
270        return Expr.BinaryOp(expr.idx, 'CmpGTs', [ operand_0, operand_1 ], expr.signed, **expr.tags)
271
272    def _ail_handle_CmpEQ(self, expr):
273        operand_0 = self._expr(expr.operands[0])
274        operand_1 = self._expr(expr.operands[1])
275
276        if type(operand_0) is Top or type(operand_1) is Top:
277            return Top(1)
278
279        return Expr.BinaryOp(expr.idx, 'CmpEQ', [ operand_0, operand_1 ], expr.signed, **expr.tags)
280
281    def _ail_handle_CmpNE(self, expr):
282        operand_0 = self._expr(expr.operands[0])
283        operand_1 = self._expr(expr.operands[1])
284
285        if type(operand_0) is Top or type(operand_1) is Top:
286            return Top(1)
287
288        return Expr.BinaryOp(expr.idx, 'CmpNE', [ operand_0, operand_1 ], expr.signed, **expr.tags)
289
290    def _ail_handle_Add(self, expr):
291        operand_0 = self._expr(expr.operands[0])
292        operand_1 = self._expr(expr.operands[1])
293
294        if type(operand_0) is Top or type(operand_1) is Top:
295            return Top(operand_0.size)
296
297        if isinstance(operand_0, Expr.Const) and isinstance(operand_1, Expr.Const):
298            return Expr.Const(expr.idx, None, operand_0.value + operand_1.value, expr.bits)
299        elif isinstance(operand_0, Expr.BasePointerOffset) and isinstance(operand_1, Expr.Const):
300            r = operand_0.copy()
301            r.offset += operand_1.value
302            return r
303        return Expr.BinaryOp(expr.idx, 'Add', [operand_0 if operand_0 is not None else expr.operands[0],
304                                               operand_1 if operand_1 is not None else expr.operands[1]
305                                               ],
306                             expr.signed)
307
308    def _ail_handle_Sub(self, expr):
309        operand_0 = self._expr(expr.operands[0])
310        operand_1 = self._expr(expr.operands[1])
311
312        if type(operand_0) is Top or type(operand_1) is Top:
313            return Top(operand_0.size)
314
315        if isinstance(operand_0, Expr.Const) and isinstance(operand_1, Expr.Const):
316            return Expr.Const(expr.idx, None, operand_0.value - operand_1.value, expr.bits)
317        elif isinstance(operand_0, Expr.BasePointerOffset) and isinstance(operand_1, Expr.Const):
318            r = operand_0.copy()
319            r.offset -= operand_1.value
320            return r
321        if type(operand_0) is Top or type(operand_1) is Top:
322            return Top(expr.bits // 8)
323        return Expr.BinaryOp(expr.idx, 'Sub', [ operand_0 if operand_0 is not None else expr.operands[0],
324                                                operand_1 if operand_1 is not None else expr.operands[1]
325                                                ],
326                             expr.signed,
327                             **expr.tags)
328
329    def _ail_handle_StackBaseOffset(self, expr):  # pylint:disable=no-self-use
330        return expr
331
332    def _ail_handle_And(self, expr):
333        operand_0 = self._expr(expr.operands[0])
334        operand_1 = self._expr(expr.operands[1])
335
336        if type(operand_0) is Top or type(operand_1) is Top:
337            return Top(operand_0.size)
338
339        # Special logic for SP alignment
340        if type(operand_0) is Expr.StackBaseOffset and \
341                type(operand_1) is Expr.Const and is_alignment_mask(operand_1.value):
342            return operand_0
343
344        return Expr.BinaryOp(expr.idx, 'And', [ operand_0, operand_1 ], expr.signed, **expr.tags)
345
346    def _ail_handle_Xor(self, expr):
347        operand_0 = self._expr(expr.operands[0])
348        operand_1 = self._expr(expr.operands[1])
349
350        if type(operand_0) is Top or type(operand_1) is Top:
351            return Top(operand_0.size)
352
353        return Expr.BinaryOp(expr.idx, 'Xor', [ operand_0, operand_1 ], expr.signed, **expr.tags)
354