1from typing import Set, List, Tuple, Dict, Union, TYPE_CHECKING
2import logging
3from collections import defaultdict
4from itertools import count
5
6from claripy.utils.orderedset import OrderedSet
7
8from ...sim_variable import SimVariable, SimStackVariable, SimMemoryVariable, SimRegisterVariable
9from ...keyed_region import KeyedRegion
10from ..plugin import KnowledgeBasePlugin
11from .variable_access import VariableAccess
12
13if TYPE_CHECKING:
14    from ...knowledge_base import KnowledgeBase
15
16
# Module-level logger, named after this module.
l = logging.getLogger(__name__)
18
19
class VariableType:
    """
    Integer tags for the two coarse kinds of variables.

    `find_variables_by_insn()` accepts these values as alternatives to the strings 'register' and 'memory'.
    """
    REGISTER, MEMORY = 0, 1
23
24
class LiveVariables:
    """
    The variables that are live at one program point, kept as a register region and a stack region.
    """
    def __init__(self, register_region, stack_region):
        """
        :param register_region: The region holding live register variables.
        :param stack_region:    The region holding live stack variables.
        """
        self.register_region, self.stack_region = register_region, stack_region
32
33
def _defaultdict_set():
    """
    Build an empty `defaultdict` whose missing keys default to empty sets.

    Used as the default factory of a nested defaultdict.
    NOTE(review): presumably a named function (rather than a lambda) so the outer defaultdict stays picklable - confirm.

    :return: A new `defaultdict(set)`.
    """
    return defaultdict(set)
36
37
class VariableManagerInternal:
    """
    Manage variables for a function. It is meant to be used internally by VariableManager.

    Besides the per-region variable storage (global, stack and register), this class indexes every recorded
    variable access by instruction address, block address, statement and atom, and keeps track of phi variables.
    """
    def __init__(self, manager, func_addr=None):
        """
        :param manager:     The object that owns this instance (VariableManager passes itself).
        :param func_addr:   Address of the function this instance belongs to, or None when it is not bound to a
                            function (VariableManager creates its global manager this way).
        """
        self.manager = manager

        self.func_addr = func_addr

        self._variables = OrderedSet()  # all variables that have at least one recorded access
        self._global_region = KeyedRegion()
        self._stack_region = KeyedRegion()
        self._register_region = KeyedRegion()
        self._live_variables = { }  # a mapping between addresses of program points and live variable collections

        # variable -> set of VariableAccess objects
        self._variable_accesses = defaultdict(set)
        # The following indices all map to sets of (variable, offset) tuples.
        self._insn_to_variable = defaultdict(set)    # key: instruction address
        self._block_to_variable = defaultdict(set)   # key: block address
        self._stmt_to_variable = defaultdict(set)    # key: (block address, statement index)
        # (block address, statement index) -> atom -> set of (variable, offset) tuples
        self._atom_to_variable = defaultdict(_defaultdict_set)
        # One independent counter per sort of variable identifier.
        self._variable_counters = {
            'register': count(),
            'stack': count(),
            'argument': count(),
            'phi': count(),
            'global': count(),
        }

        self._phi_variables = { }  # phi variable -> set of the variables it represents
        self._phi_variables_by_block = defaultdict(set)  # block address -> set of phi variables created there

        self.types = { }  # variable -> type, see get_variable_type()

    #
    # Private helpers
    #

    def _region_for_sort(self, sort, caller):
        """
        Return the region that stores variables of the given sort.

        :param str sort:    'stack', 'register' or 'global'.
        :param str caller:  Name of the public method, used in the error message only.
        :return:            The matching KeyedRegion.
        :raises ValueError: If `sort` is none of the supported values.
        """
        if sort == 'stack':
            return self._stack_region
        if sort == 'register':
            return self._register_region
        if sort == 'global':
            return self._global_region
        raise ValueError('Unsupported sort %s in %s().' % (sort, caller))

    #
    # Public methods
    #

    def next_variable_ident(self, sort):
        """
        Generate the next unused variable identifier for the given sort.

        :param str sort:    One of the keys of the internal counters: 'register', 'stack', 'argument', 'phi' or
                            'global'.
        :return:            An identifier of the form "i<prefix>_<n>", where <n> comes from the counter of `sort`.
        :rtype:             str
        :raises ValueError: If there is no counter for `sort`.
        """
        if sort not in self._variable_counters:
            raise ValueError('Unsupported variable sort %s' % sort)

        # Any sort that has a counter but no dedicated prefix (currently only 'phi') falls back to "m".
        prefixes = {'register': 'r', 'stack': 's', 'argument': 'arg', 'global': 'g'}
        prefix = prefixes.get(sort, 'm')

        return "i%s_%d" % (prefix, next(self._variable_counters[sort]))

    def add_variable(self, sort, start, variable):
        """
        Add a variable to the region selected by `sort`.

        :param str sort:    'stack', 'register' or 'global'.
        :param start:       Start position passed on to the region's `add_variable()`.
        :param variable:    The variable to add.
        :raises ValueError: If `sort` is not supported.
        """
        self._region_for_sort(sort, 'add_variable').add_variable(start, variable)

    def set_variable(self, sort, start, variable):
        """
        Set a variable in the region selected by `sort`.

        :param str sort:    'stack', 'register' or 'global'.
        :param start:       Start position passed on to the region's `set_variable()`.
        :param variable:    The variable to set.
        :raises ValueError: If `sort` is not supported.
        """
        self._region_for_sort(sort, 'set_variable').set_variable(start, variable)

    def write_to(self, variable, offset, location, overwrite=False, atom=None):
        """
        Record a 'write' access to `variable`. See `_record_variable_access()` for the parameters.
        """
        self._record_variable_access('write', variable, offset, location, overwrite=overwrite, atom=atom)

    def read_from(self, variable, offset, location, overwrite=False, atom=None):
        """
        Record a 'read' access to `variable`. See `_record_variable_access()` for the parameters.
        """
        self._record_variable_access('read', variable, offset, location, overwrite=overwrite, atom=atom)

    def reference_at(self, variable, offset, location, overwrite=False, atom=None):
        """
        Record a 'reference' access to `variable`. See `_record_variable_access()` for the parameters.
        """
        self._record_variable_access('reference', variable, offset, location, overwrite=overwrite, atom=atom)

    def _record_variable_access(self, sort, variable, offset, location, overwrite=False, atom=None):
        """
        Record one access to a variable and update every lookup index.

        :param str sort:        Kind of access: 'write', 'read' or 'reference'.
        :param variable:        The variable being accessed.
        :param offset:          Offset that is stored together with the variable in the lookup indices.
        :param location:        Where the access happens. Its `ins_addr`, `block_addr` and `stmt_idx` attributes are
                                used as index keys.
        :param bool overwrite:  If True, replace whatever was recorded under the affected keys instead of adding
                                to it.
        :param atom:            Optional atom to associate with the variable at this statement.
        :return:                None
        """
        self._variables.add(variable)
        var_and_offset = variable, offset
        access = VariableAccess(variable, sort, location)
        stmt_key = location.block_addr, location.stmt_idx

        if overwrite:
            self._variable_accesses[variable] = {access}
            self._insn_to_variable[location.ins_addr] = {var_and_offset}
            self._block_to_variable[location.block_addr] = {var_and_offset}
            self._stmt_to_variable[stmt_key] = {var_and_offset}
            if atom is not None:
                # Must be a set like in the non-overwriting branch: find_variables_by_atom() returns this value
                # directly and later non-overwriting accesses call .add() on it.
                self._atom_to_variable[stmt_key][atom] = {var_and_offset}
        else:
            self._variable_accesses[variable].add(access)
            self._insn_to_variable[location.ins_addr].add(var_and_offset)
            self._block_to_variable[location.block_addr].add(var_and_offset)
            self._stmt_to_variable[stmt_key].add(var_and_offset)
            if atom is not None:
                self._atom_to_variable[stmt_key][atom].add(var_and_offset)

    def make_phi_node(self, block_addr, *variables):
        """
        Create a phi variable for variables at block `block_addr`.

        If exactly one of `variables` is already a phi variable, that phi variable is reused: the remaining
        variables are merged into its set of sub-variables and no new variable is created.

        :param int block_addr:  The address of the current block.
        :param variables:       Variables that the phi variable represents.
        :return:                The created (or reused) phi variable.
        :raises TypeError:      If the first variable is not exactly a SimRegisterVariable, SimMemoryVariable or
                                SimStackVariable.
        """

        existing_phis = set()
        non_phis = set()
        for var in variables:
            if self.is_phi_variable(var):
                existing_phis.add(var)
            else:
                non_phis.add(var)

        if len(existing_phis) == 1:
            existing_phi = next(iter(existing_phis))
            # In-place union is a no-op when all non-phi variables are already covered.
            self._phi_variables[existing_phi] |= non_phis
            return existing_phi

        # The new phi variable mirrors the first variable in location and size. Exact type checks are used on
        # purpose so that subclasses are not mistaken for one of the three supported types.
        repre = next(iter(variables))
        repre_type = type(repre)
        if repre_type is SimRegisterVariable:
            ident_sort = 'register'
            a = SimRegisterVariable(repre.reg, repre.size, ident=self.next_variable_ident(ident_sort))
        elif repre_type is SimMemoryVariable:
            # There is no dedicated 'memory' counter in _variable_counters; memory variables draw their
            # identifiers from the 'global' one.
            ident_sort = 'global'
            a = SimMemoryVariable(repre.addr, repre.size, ident=self.next_variable_ident(ident_sort))
        elif repre_type is SimStackVariable:
            ident_sort = 'stack'
            a = SimStackVariable(repre.offset, repre.size, ident=self.next_variable_ident(ident_sort))
        else:
            raise TypeError('make_phi_node(): Unsupported variable type "%s".' % type(repre))

        # Keep a record of all phi variables
        self._phi_variables[a] = set(variables)
        self._phi_variables_by_block[block_addr].add(a)

        return a

    def set_live_variables(self, addr, register_region, stack_region):
        """
        Store the live variables at a program point, replacing any previous record for `addr`.

        :param addr:            Address of the program point.
        :param register_region: Region of live register variables.
        :param stack_region:    Region of live stack variables.
        :return:                None
        """
        self._live_variables[addr] = LiveVariables(register_region, stack_region)

    def find_variables_by_insn(self, ins_addr, sort):
        """
        Find the variables accessed by the instruction at `ins_addr`.

        :param int ins_addr:    Address of the instruction.
        :param sort:            'memory' or VariableType.MEMORY for stack and memory variables, 'register' or
                                VariableType.REGISTER for register variables.
        :return:                A list of (variable, offset) tuples; None if no access has been recorded for the
                                instruction; an empty list (after logging an error) if `sort` is unsupported.
        """
        if ins_addr not in self._insn_to_variable:
            return None

        if sort in (VariableType.MEMORY, 'memory'):
            wanted_types = (SimStackVariable, SimMemoryVariable)
        elif sort in (VariableType.REGISTER, 'register'):
            wanted_types = SimRegisterVariable
        else:
            l.error('find_variables_by_insn(): Unsupported variable sort "%s".', sort)
            return [ ]

        return [(var, offset) for var, offset in self._insn_to_variable[ins_addr]
                if isinstance(var, wanted_types)]

    def find_variable_by_stmt(self, block_addr, stmt_idx, sort):
        """
        Like `find_variables_by_stmt()`, but return only one (variable, offset) tuple, or None if there is none.
        """
        return next(iter(self.find_variables_by_stmt(block_addr, stmt_idx, sort)), None)

    def find_variables_by_stmt(self, block_addr: int, stmt_idx: int, sort: str) -> List[Tuple[SimVariable,int]]:
        """
        Find the variables accessed by a statement.

        :param block_addr:  Address of the block that contains the statement.
        :param stmt_idx:    Index of the statement inside the block.
        :param sort:        'memory' for stack and memory variables, 'register' for register variables.
        :return:            A list of (variable, offset) tuples. The list is empty if nothing has been recorded for
                            the statement or (after logging an error) if `sort` is unsupported.
        """

        key = block_addr, stmt_idx

        # Membership test first so that the lookup does not create an entry in the defaultdict.
        if key not in self._stmt_to_variable:
            return [ ]

        recorded = self._stmt_to_variable[key]
        if not recorded:
            return [ ]

        if sort == 'memory':
            wanted_types = (SimStackVariable, SimMemoryVariable)
        elif sort == 'register':
            wanted_types = SimRegisterVariable
        else:
            l.error('find_variables_by_stmt(): Unsupported variable sort "%s".', sort)
            return [ ]

        return [(var, offset) for var, offset in recorded if isinstance(var, wanted_types)]

    def find_variable_by_atom(self, block_addr, stmt_idx, atom):
        """
        Like `find_variables_by_atom()`, but return only one (variable, offset) tuple, or None if there is none.
        """
        return next(iter(self.find_variables_by_atom(block_addr, stmt_idx, atom)), None)

    def find_variables_by_atom(self, block_addr, stmt_idx, atom) -> Set[Tuple[SimVariable, int]]:
        """
        Find the variables associated with an atom of a statement.

        :param block_addr:  Address of the block that contains the statement.
        :param stmt_idx:    Index of the statement inside the block.
        :param atom:        The atom to look up.
        :return:            The set of (variable, offset) tuples recorded for the atom, or a new empty set if
                            nothing has been recorded.
        """

        # .get() and the membership test avoid creating entries in the (nested) defaultdicts.
        atoms = self._atom_to_variable.get((block_addr, stmt_idx))
        if atoms is None or atom not in atoms:
            return set()

        return atoms[atom]

    def get_variable_accesses(self, variable: SimVariable, same_name: bool=False) -> List[VariableAccess]:
        """
        Get all recorded accesses of a variable.

        :param variable:    The variable.
        :param same_name:   If True, collect the accesses of every variable whose `name` equals `variable.name`
                            instead of the accesses of `variable` only.
        :return:            A list of accesses, empty if there are none.
        """

        if not same_name:
            # Membership test first so that the lookup does not create an entry in the defaultdict.
            if variable in self._variable_accesses:
                return list(self._variable_accesses[variable])

            return [ ]

        # gather the accesses of all variables with the same variable name
        accesses = [ ]
        for var, var_accesses in self._variable_accesses.items():
            if variable.name == var.name:
                accesses.extend(var_accesses)

        return accesses

    def get_variables(self, sort=None, collapse_same_ident=False) -> List[Union[SimStackVariable,SimRegisterVariable]]:
        """
        Get a list of variables.

        :param str or None sort:    Sort of the variable to get: 'stack' for stack variables, 'reg' or 'register'
                                    for register variables. Any other value (including None) selects all
                                    variables.
        :param collapse_same_ident: Whether variables of the same identifier should be collapsed or not. Not
                                    implemented; passing a true value raises NotImplementedError.
        :return:                    A list of variables.
        :rtype:                     list
        """

        if collapse_same_ident:
            raise NotImplementedError()

        if sort == 'stack':
            wanted_type = SimStackVariable
        elif sort in ('reg', 'register'):
            # 'register' is the spelling used by the other methods of this class; 'reg' is kept for existing
            # callers.
            wanted_type = SimRegisterVariable
        else:
            return list(self._variables)

        return [var for var in self._variables if isinstance(var, wanted_type)]

    def get_global_variables(self, addr):
        """
        Get global variable by the address of the variable.

        :param int addr:    Address of the variable.
        :return:            A set of variables or an empty set if no variable exists.
        """
        return self._global_region.get_variables_by_offset(addr)

    def is_phi_variable(self, var):
        """
        Test if `var` is a phi variable.

        :param SimVariable var: The variable instance.
        :return:                True if `var` is a phi variable, False otherwise.
        :rtype:                 bool
        """

        return var in self._phi_variables

    def get_phi_subvariables(self, var):
        """
        Get sub-variables that phi variable `var` represents.

        :param SimVariable var: The variable instance.
        :return:                A set of sub-variables, or an empty set if `var` is not a phi variable.
        :rtype:                 set
        """

        if not self.is_phi_variable(var):
            return set()
        return self._phi_variables[var]

    def get_phi_variables(self, block_addr):
        """
        Get a dict of phi variables and their corresponding variables.

        :param int block_addr:  Address of the block.
        :return:                A dict of phi variables or an empty dict if there are no phi variables at the block.
        :rtype:                 dict
        """

        # Membership test first so that the lookup does not create an entry in the defaultdict.
        if block_addr not in self._phi_variables_by_block:
            return { }
        return {phi: self._phi_variables[phi] for phi in self._phi_variables_by_block[block_addr]}

    def input_variables(self, exclude_specials=True):
        """
        Get all variables that are read but have never been written to.

        Phi variables are never reported.

        :param bool exclude_specials:   If True, leave out variables whose `category` attribute is truthy.
        :return:                        A list of variables that are read and never written to.
        """

        def has_access(accesses, access_type):
            return any(acc.access_type == access_type for acc in accesses)

        input_variables = [ ]

        for variable, accesses in self._variable_accesses.items():
            if variable in self._phi_variables:
                # a phi variable is definitely not an input variable
                continue
            if has_access(accesses, 'write') or not has_access(accesses, 'read'):
                continue
            if not exclude_specials or not variable.category:
                input_variables.append(variable)

        return input_variables

    def assign_variable_names(self, labels=None):
        """
        Assign default names to all variables that do not have a name yet.

        Stack variables are named after their offset ("arg_<offset>" if the identifier starts with 'iarg',
        "s_<negated offset>" otherwise), register variables after their identifier, and memory variables after
        their label in `labels` if there is one, or their identifier otherwise.

        :param labels:  Optional mapping from addresses to names for memory variables.
        :return:        None
        """

        for var in self._variables:
            # The stack check must stay ahead of the memory check (NOTE(review): presumably because
            # SimStackVariable derives from SimMemoryVariable - confirm).
            if isinstance(var, SimStackVariable):
                if var.name is not None:
                    continue
                if var.ident.startswith('iarg'):
                    var.name = 'arg_%x' % var.offset
                else:
                    var.name = 's_%x' % (-var.offset)
            elif isinstance(var, SimRegisterVariable):
                if var.name is not None:
                    continue
                var.name = var.ident
            elif isinstance(var, SimMemoryVariable):
                if var.name is not None:
                    continue
                if labels is not None and var.addr in labels:
                    var.name = labels[var.addr]
                else:
                    var.name = var.ident

    def get_variable_type(self, var):
        """
        Get the type recorded for a variable.

        :param var: The variable.
        :return:    The entry of `self.types` for `var`, or None if there is none.
        """
        return self.types.get(var)

    def remove_types(self):
        """
        Forget all recorded variable types.

        :return: None
        """
        self.types.clear()
399
400
class VariableManager(KnowledgeBasePlugin):
    """
    Knowledge-base plugin that manages variables: one VariableManagerInternal for the global region plus one per
    function address, the latter created on demand.
    """
    def __init__(self, kb):
        """
        :param kb:  The knowledge base this plugin belongs to.
        """
        super().__init__()
        self._kb: 'KnowledgeBase' = kb
        self.global_manager = VariableManagerInternal(self)
        self.function_managers: Dict[int,VariableManagerInternal] = { }

    def __getitem__(self, key) -> VariableManagerInternal:
        """
        Look up the VariableManagerInternal object of a function or a region.

        :param str or int key: "global" selects the global region; anything else is treated as a function address
                               and handed to `get_function_manager()`.
        :return:               The VariableManagerInternal object.
        """

        return self.global_manager if key == 'global' else self.get_function_manager(key)

    def __delitem__(self, key) -> None:
        """
        Discard the VariableManagerInternal object of a function or a region.

        :param Union[str,int] key:  "global" resets the global region to a fresh, empty manager; any other key is
                                    deleted from the per-function managers.
        :return:                    None
        """

        if key == 'global':
            self.global_manager = VariableManagerInternal(self)
            return

        del self.function_managers[key]

    def has_function_manager(self, key: int) -> bool:
        """
        Tell whether a manager already exists for the function at address `key`.
        """
        return key in self.function_managers

    def get_function_manager(self, func_addr) -> VariableManagerInternal:
        """
        Return the manager of the function at `func_addr`, creating it first if it does not exist yet.

        :param int func_addr:   Address of the function.
        :return:                The VariableManagerInternal object of the function.
        :raises TypeError:      If `func_addr` is not an int.
        """
        if not isinstance(func_addr, int):
            raise TypeError('Argument "func_addr" must be an int.')

        managers = self.function_managers
        if func_addr not in managers:
            managers[func_addr] = VariableManagerInternal(self, func_addr=func_addr)

        return managers[func_addr]

    def initialize_variable_names(self) -> None:
        """
        Assign default variable names in the global manager first, then in every function manager.
        """
        self.global_manager.assign_variable_names()
        for func_manager in self.function_managers.values():
            func_manager.assign_variable_names()

    def get_variable_accesses(self, variable: SimVariable, same_name: bool=False) -> List[VariableAccess]:
        """
        Get a list of all references to the given variable.

        The manager to ask is selected through `variable.region`.

        :param variable:        The variable.
        :param same_name:       Whether to include all variables with the same variable name, or just based on the
                                variable identifier.
        :return:                All references to the variable, or an empty list (after logging a warning) if the
                                region of the variable is unknown.
        """

        region = variable.region

        if region == 'global':
            manager = self.global_manager
        elif region in self.function_managers:
            manager = self.function_managers[region]
        else:
            l.warning('get_variable_accesses(): Region %s is not found.', region)
            return [ ]

        return manager.get_variable_accesses(variable, same_name=same_name)

    def copy(self):
        """
        Copying is not supported.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError()
479
480
# Register VariableManager with KnowledgeBasePlugin as the default plugin for the name 'variables'.
KnowledgeBasePlugin.register_default('variables', VariableManager)
482