1from typing import Set, Optional
2from collections import defaultdict
3import logging
4
5import ailment
6
7from ... import sim_options
8from ...engines.light import SpOffset
9from ...keyed_region import KeyedRegion
10from ...code_location import CodeLocation  # pylint:disable=unused-import
11from .. import register_analysis
12from ..analysis import Analysis
13from ..forward_analysis import ForwardAnalysis, FunctionGraphVisitor, SingleNodeGraphVisitor
14from .values import Top
15from .engine_vex import SimEnginePropagatorVEX
16from .engine_ail import SimEnginePropagatorAIL
17
# Module-level logger; used below to warn about unsupported variable/address types in the AIL state.
_l = logging.getLogger(name=__name__)
19
20
21# The base state
22
class PropagatorState:
    """
    The base abstract state shared by the VEX and the AIL propagator states.

    It tracks the replacements discovered so far (code location -> {old expression: replacement}), how many times
    each expression has been propagated, and a set of recorded equivalences. Subclasses must implement `copy()`.
    """

    def __init__(self, arch, replacements=None, only_consts=False, prop_count=None, equivalence=None):
        self.arch = arch
        self.gpr_size = arch.bits // arch.byte_width  # size of the general-purpose registers

        # propagation count of each expression
        self._prop_count = defaultdict(int) if prop_count is None else prop_count
        self._only_consts = only_consts
        # code location -> {old expression: replacement}
        self._replacements = defaultdict(dict) if replacements is None else replacements
        self._equivalence: Set[Equivalence] = equivalence if equivalence is not None else set()

    def __repr__(self):
        return "<PropagatorState>"

    def copy(self) -> 'PropagatorState':
        raise NotImplementedError()

    def merge(self, *others):
        """
        Merge this state with other states and return the merged state.

        Replacements that exist in only one state are kept as-is. Replacements of the same expression at the same
        code location that disagree are merged to Top. Equivalence sets are unioned. The per-location replacement
        dicts of `self` are never mutated.

        :param others:  Other states of the same kind.
        :return:        A new, merged state.
        """
        state = self.copy()

        for o in others:
            for loc, vars_ in o._replacements.items():
                if loc not in state._replacements:
                    state._replacements[loc] = vars_.copy()
                    continue
                # copy() only shallow-copies the outer mapping, so this per-location dict may still be shared with
                # `self` (and with whatever else references it). Merge into a private copy instead of in place.
                merged = state._replacements[loc].copy()
                for var, repl in vars_.items():
                    if var not in merged:
                        merged[var] = repl
                    elif merged[var] != repl:
                        merged[var] = Top(1)
                state._replacements[loc] = merged
            state._equivalence |= o._equivalence

        return state

    def add_replacement(self, codeloc, old, new):
        """
        Add a replacement record: Replacing expression `old` with `new` at program location `codeloc`.
        If the self._only_consts flag is set to true, only constant values will be set.

        :param CodeLocation codeloc:    The code location.
        :param old:                     The expression to be replaced.
        :param new:                     The expression to replace with.
        :return:                        None
        """
        if self._only_consts:
            # in constants-only mode, keep only ints and Top
            if isinstance(new, int) or type(new) is Top:
                self._replacements[codeloc][old] = new
        else:
            self._replacements[codeloc][old] = new

    def filter_replacements(self):
        """
        Hook for subclasses to drop replacements that should not be propagated. The base state keeps everything.
        """
        pass
77
78
79# VEX state
80
class PropagatorVEXState(PropagatorState):
    """
    Propagator state used when analyzing VEX IR blocks.

    Register values are keyed by register offset, and only values written with exactly `gpr_size` bytes are
    tracked. Local variables are keyed by offset.
    """

    def __init__(self, arch, registers=None, local_variables=None, replacements=None, only_consts=False,
                 prop_count=None):
        super().__init__(arch, replacements=replacements, only_consts=only_consts, prop_count=prop_count)
        self.registers = {} if registers is None else registers  # offset to values
        self.local_variables = {} if local_variables is None else local_variables  # offset to values

    def __repr__(self):
        return "<PropagatorVEXState>"

    def copy(self):
        cp = PropagatorVEXState(
            self.arch,
            registers=self.registers.copy(),
            local_variables=self.local_variables.copy(),
            replacements=self._replacements.copy(),
            prop_count=self._prop_count.copy(),
            only_consts=self._only_consts
        )

        return cp

    def merge(self, *others):
        """
        Merge register and local-variable values of `others` into a copy of this state.

        Values that exist in only one state are taken as-is. Values that differ become Top. Unlike the base class,
        this does not merge replacements or equivalences of `others`.

        :param others:  Other PropagatorVEXState instances.
        :return:        A new, merged state.
        """
        state = self.copy()
        for other in others:  # type: PropagatorVEXState
            for offset, value in other.registers.items():
                if offset not in state.registers:
                    state.registers[offset] = value
                else:
                    if state.registers[offset] != value:
                        state.registers[offset] = Top(self.arch.bytes)

            for offset, value in other.local_variables.items():
                if offset not in state.local_variables:
                    state.local_variables[offset] = value
                else:
                    if state.local_variables[offset] != value:
                        state.local_variables[offset] = Top(self.arch.bytes)

        return state

    def store_local_variable(self, offset, size, value):  # pylint:disable=unused-argument
        # TODO: Handle size
        self.local_variables[offset] = value

    def load_local_variable(self, offset, size):  # pylint:disable=unused-argument
        """
        Return the value stored at local-variable `offset`, or Top(size) if nothing is known.
        """
        # TODO: Handle size
        try:
            return self.local_variables[offset]
        except KeyError:
            return Top(size)

    def store_register(self, offset, size, value):
        """
        Store `value` to the register at `offset`.

        Only writes of exactly `gpr_size` bytes are recorded. A write of any other size cannot be tracked, so every
        tracked register whose bytes overlap the written range is invalidated (set to Top) instead of silently
        keeping a stale value.

        :param int offset:  Register offset.
        :param int size:    Size of the write, in bytes.
        :param value:       The value being written.
        """
        if size == self.gpr_size:
            self.registers[offset] = value
            return

        # Partial (e.g. sub-register) or oversized write: without invalidation, a later full-width load of an
        # overlapping register would return the value from before this write.
        end = offset + size
        for reg_offset in list(self.registers):
            if reg_offset < end and offset < reg_offset + self.gpr_size:
                self.registers[reg_offset] = Top(self.gpr_size)

    def load_register(self, offset, size):
        """
        Return the tracked value of the register at `offset`, or Top(size) if it is unknown or `size` is not
        `gpr_size`.
        """
        # TODO: Fix me
        if size != self.gpr_size:
            return Top(size)

        try:
            return self.registers[offset]
        except KeyError:
            return Top(size)
149
150
151# AIL state
152
153
class Equivalence:
    """
    A record stating that `atom0` and `atom1` are equivalent at code location `codeloc`.

    Two records are equal, and hash identically, when their code locations and both atoms are equal.
    """
    __slots__ = ('codeloc', 'atom0', 'atom1',)

    def __init__(self, codeloc, atom0, atom1):
        self.codeloc, self.atom0, self.atom1 = codeloc, atom0, atom1

    def __repr__(self):
        return "<Eq@%r: %r==%r>" % (self.codeloc, self.atom0, self.atom1)

    def __eq__(self, other):
        # only another Equivalence (not a subclass) can be equal
        if type(other) is not Equivalence:
            return False
        return other.codeloc == self.codeloc and other.atom0 == self.atom0 and other.atom1 == self.atom1

    def __hash__(self):
        key = (Equivalence, self.codeloc, self.atom0, self.atom1)
        return hash(key)
173
174
class PropagatorAILState(PropagatorState):
    """
    Propagator state used when analyzing AIL blocks.

    Stack variables and registers are kept in KeyedRegions (keyed by stack offset and register offset,
    respectively). Temporaries are kept in a dict keyed by tmp index and are not carried over by copy().
    """

    def __init__(self, arch, replacements=None, only_consts=False, prop_count=None, equivalence=None):
        super().__init__(arch, replacements=replacements, only_consts=only_consts, prop_count=prop_count,
                         equivalence=equivalence)

        self._stack_variables = KeyedRegion()
        self._registers = KeyedRegion()
        self._tmps = {}

    def __repr__(self):
        return "<PropagatorAILState>"

    def copy(self):
        rd = PropagatorAILState(
            self.arch,
            replacements=self._replacements.copy(),
            prop_count=self._prop_count.copy(),
            only_consts=self._only_consts,
            equivalence=self._equivalence.copy(),
        )

        rd._stack_variables = self._stack_variables.copy()
        rd._registers = self._registers.copy()
        # drop tmps (presumably because tmps are block-local)

        return rd

    def merge(self, *others) -> 'PropagatorAILState':
        """
        Merge replacements/equivalences (via the base class), then stack variables and registers of `others`.
        Conflicting stack-variable/register contents are merged to Top.
        """
        # TODO:
        state = super().merge(*others)

        for o in others:
            state._stack_variables.merge_to_top(o._stack_variables, top=Top(1))
            state._registers.merge_to_top(o._registers, top=Top(1))

        return state

    def store_variable(self, old, new):
        """
        Record that variable `old` (a Tmp or a Register) now holds `new`.

        Nothing is stored if either side is None, or if `new` itself contains `old` (a self-referential value).
        """
        if old is None or new is None:
            return
        if type(new) is not Top and new.has_atom(old, identity=False):
            return

        if isinstance(old, ailment.Expr.Tmp):
            self._tmps[old.tmp_idx] = new
        elif isinstance(old, ailment.Expr.Register):
            self._registers.set_object(old.reg_offset, new, old.size)
        else:
            _l.warning("Unsupported old variable type %s.", type(old))

    def store_stack_variable(self, addr, size, new, endness=None):  # pylint:disable=unused-argument
        """
        Store `new` to the stack variable at `addr`, which must be a StackBaseOffset. A `None` offset is treated as 0.
        """
        if isinstance(addr, ailment.Expr.StackBaseOffset):
            if addr.offset is None:
                offset = 0
            else:
                offset = addr.offset
            self._stack_variables.set_object(offset, new, size)
        else:
            _l.warning("Unsupported addr type %s.", type(addr))

    def get_variable(self, variable):
        """
        Return the known value of a Tmp or Register, or None if unknown.

        For registers, the returned value is converted (or replaced by a Top) when its bit-width differs from the
        requested register's bit-width.
        """
        if isinstance(variable, ailment.Expr.Tmp):
            return self._tmps.get(variable.tmp_idx, None)
        elif isinstance(variable, ailment.Expr.Register):
            objs = self._registers.get_objects_by_offset(variable.reg_offset)
            if not objs:
                return None
            # FIXME: Right now we are always getting one item - we should, in fact, work on a multi-value domain
            first_obj = next(iter(objs))
            if type(first_obj) is Top:
                # return a Top
                if first_obj.bits != variable.bits:
                    return Top(variable.bits // 8)
                return first_obj
            if first_obj.bits != variable.bits:
                # conversion is needed
                if isinstance(first_obj, ailment.Expr.Convert):
                    if variable.bits == first_obj.operand.bits:
                        first_obj = first_obj.operand
                    else:
                        first_obj = ailment.Expr.Convert(first_obj.idx, first_obj.operand.bits, variable.bits,
                                                         first_obj.is_signed, first_obj.operand)
                else:
                    first_obj = ailment.Expr.Convert(first_obj.idx, first_obj.bits, variable.bits, False, first_obj)
            return first_obj
        return None

    def get_stack_variable(self, addr, size, endness=None):  # pylint:disable=unused-argument
        """
        Return the value stored at stack address `addr` (a StackBaseOffset), or None if unknown.
        """
        if isinstance(addr, ailment.Expr.StackBaseOffset):
            # Normalize a missing offset to 0 exactly like store_stack_variable() does, so loads find what stores
            # wrote.
            offset = 0 if addr.offset is None else addr.offset
            objs = self._stack_variables.get_objects_by_offset(offset)
            if not objs:
                return None
            return next(iter(objs))
        return None

    def add_replacement(self, codeloc, old, new):
        """
        Record a replacement of `old` with `new` at `codeloc`, unless the non-constant expression `new` has already
        been propagated (to a non-Tmp) before. In that case, existing replacements of `old` at every code location
        are dropped instead.
        """
        prop_count = 0
        if not isinstance(old, ailment.Expr.Tmp) and isinstance(new, ailment.Expr.Expression) \
                and not isinstance(new, ailment.Expr.Const):
            self._prop_count[new] += 1
            prop_count = self._prop_count[new]

        if prop_count <= 1:
            # we can propagate this expression
            super().add_replacement(codeloc, old, new)
        else:
            # eliminate the past propagation of this expression
            for codeloc_ in self._replacements:
                if old in self._replacements[codeloc_]:
                    del self._replacements[codeloc_][old]

    def filter_replacements(self):
        """
        Drop every replacement whose new value is a non-constant expression that was propagated more than once.
        """
        for codeloc, repls in list(self._replacements.items()):
            to_remove = {
                old for old, new in repls.items()
                if isinstance(new, ailment.Expr.Expression) and not isinstance(new, ailment.Expr.Const)
                and self._prop_count.get(new, 0) > 1  # do not propagate this expression
            }
            if to_remove:
                # Rebuild rather than delete in place: per-location dicts may be shared with other states through
                # shallow copies.
                self._replacements[codeloc] = {old: new for old, new in repls.items() if old not in to_remove}

    def add_equivalence(self, codeloc, old, new):
        """
        Record that `old` and `new` are equivalent at `codeloc`.
        """
        eq = Equivalence(codeloc, old, new)
        self._equivalence.add(eq)
303
304
class PropagatorAnalysis(ForwardAnalysis, Analysis):  # pylint:disable=abstract-method
    """
    PropagatorAnalysis propagates values (either constant values or variables) and expressions inside a block or across
    a function. It supports both VEX and AIL. It performs certain arithmetic operations between constants, including
    but not limited to:

    - addition
    - subtraction
    - multiplication
    - division
    - xor

    It also performs the following memory operations:

    - Loading values from a known address
    - Writing values to a stack variable

    Results are exposed through `replacements` (code location -> {old expression: replacement}) and `equivalence`.
    """

    def __init__(self, func=None, block=None, func_graph=None, base_state=None, max_iterations=3,
                 load_callback=None, stack_pointer_tracker=None, only_consts=False, completed_funcs=None):
        """
        :param func:                    The function to analyze. Mutually exclusive with `block`.
        :param block:                   A single block to analyze. Mutually exclusive with `func`.
        :param func_graph:              Graph handed to FunctionGraphVisitor together with `func`.
        :param base_state:              Optional state handed to the engine. Symbolic-fill options are added to it.
        :param max_iterations:          Maximum number of times each node is processed.
        :param load_callback:           Callback handed to the engine.
        :param stack_pointer_tracker:   Only used when analyzing AIL functions.
        :param only_consts:             If True, only constant (int) or Top replacements are recorded.
        :param completed_funcs:         Addresses of functions whose results may be cached in the KB.
        :raises ValueError:             If both or neither of `func` and `block` are given.
        """
        if func is not None:
            if block is not None:
                raise ValueError('You cannot specify both "func" and "block".')
            # traversing a function
            graph_visitor = FunctionGraphVisitor(func, func_graph)
        elif block is not None:
            # traversing a block
            graph_visitor = SingleNodeGraphVisitor(block)
        else:
            raise ValueError('Unsupported analysis target.')

        ForwardAnalysis.__init__(self, order_jobs=True, allow_merging=True, allow_widening=False,
                                 graph_visitor=graph_visitor)

        self._base_state = base_state
        self._function = func
        self._max_iterations = max_iterations
        self._load_callback = load_callback
        self._stack_pointer_tracker = stack_pointer_tracker  # only used when analyzing AIL functions
        self._only_consts = only_consts
        self._completed_funcs = completed_funcs

        self._node_iterations = defaultdict(int)
        self._states = {}
        self.replacements: Optional[defaultdict] = None
        self.equivalence: Set[Equivalence] = set()

        self._engine_vex = SimEnginePropagatorVEX(project=self.project)
        self._engine_ail = SimEnginePropagatorAIL(stack_pointer_tracker=self._stack_pointer_tracker)

        self._analyze()

    #
    # Main analysis routines
    #

    def _pre_analysis(self):
        pass

    def _pre_job_handling(self, job):
        pass

    def _initial_abstract_state(self, node):
        if isinstance(node, ailment.Block):
            # AIL
            state = PropagatorAILState(arch=self.project.arch, only_consts=self._only_consts)
        else:
            # VEX
            state = PropagatorVEXState(arch=self.project.arch, only_consts=self._only_consts)
            # the stack pointer starts at offset 0 of the stack frame
            state.store_register(self.project.arch.sp_offset,
                                 self.project.arch.bytes,
                                 SpOffset(self.project.arch.bits, 0)
                                 )
        return state

    def _merge_states(self, node, *states):
        return states[0].merge(*states[1:])

    def _run_on_node(self, node, state):
        """
        Process one node (an AIL block or a VEX block at `node.addr`) and collect its replacements.

        :return:    (whether the node may be processed again, the output state)
        """
        block_key = node.addr
        if isinstance(node, ailment.Block):
            block = node
            engine = self._engine_ail
        else:
            block = self.project.factory.block(node.addr, node.size, opt_level=1, cross_insn_opt=False)
            engine = self._engine_vex
            if block.size == 0:
                # maybe the block is not decodeable
                return False, state

        state = state.copy()
        # Suppress spurious output
        if self._base_state is not None:
            self._base_state.options.add(sim_options.SYMBOL_FILL_UNCONSTRAINED_REGISTERS)
            self._base_state.options.add(sim_options.SYMBOL_FILL_UNCONSTRAINED_MEMORY)
        state = engine.process(state, block=block, project=self.project, base_state=self._base_state,
                               load_callback=self._load_callback, fail_fast=self._fail_fast)
        state.filter_replacements()

        self._node_iterations[block_key] += 1
        self._states[block_key] = state

        if self.replacements is None:
            # Copy the outer mapping: `self.replacements` accumulates the replacements of every node, and aliasing it
            # to this state's dict would leak later nodes' replacements back into this node's stored/output state.
            self.replacements = state._replacements.copy()
        else:
            self.replacements.update(state._replacements)

        self.equivalence |= state._equivalence

        # TODO: Clear registers according to calling conventions

        return self._node_iterations[block_key] < self._max_iterations, state

    def _intra_analysis(self):
        pass

    def _check_func_complete(self, func):
        """
        Checks if a function is completely created by the CFG. Completed
        functions are passed to the Propagator at initialization. Defaults to
        being empty if no pass is initiated.

        :param func:    Function to check (knowledge_plugins.functions.function.Function)
        :return:        Bool
        """
        return self._completed_funcs is not None and func.addr in self._completed_funcs

    def _post_analysis(self):
        """
        Post Analysis of Propagation().
        We add the current propagation replacements result to the kb if the
        function has already been completed in cfg creation.
        """
        if self._function is None or self.replacements is None:
            # nothing to cache
            return
        if self._check_func_complete(self._function):
            func_loc = CodeLocation(self._function.addr, None)
            self.kb.propagations.update(func_loc, self.replacements)

    def _check_prop_kb(self):
        """
        Checks, and gets, stored propagations from the KB for the current
        Propagation state.

        :return:    None or Dict of replacements
        """
        replacements = None
        if self._function is not None:
            func_loc = CodeLocation(self._function.addr, None)
            replacements = self.kb.propagations.get(func_loc)

        return replacements

    def _analyze(self):
        """
        The main analysis for Propagator. Overwritten to include an optimization to stop
        analysis if we have already analyzed the entire function once.
        """
        self._pre_analysis()

        # optimization check
        stored_replacements = self._check_prop_kb()
        if stored_replacements is not None:
            if self.replacements is not None:
                self.replacements.update(stored_replacements)
            else:
                self.replacements = stored_replacements

        # normal analysis execution
        elif self._graph_visitor is None:
            # There is no base graph that we can rely on. The analysis itself should generate successors for the
            # current job.
            # An example is the CFG recovery.

            self._analysis_core_baremetal()

        else:
            # We have a base graph to follow. Just handle the current job.

            self._analysis_core_graph()

        self._post_analysis()
498
499
# Register the analysis under the name "Propagator" (presumably exposed as `project.analyses.Propagator` — confirm).
register_analysis(PropagatorAnalysis, "Propagator")
501