1import logging
2import weakref
3from typing import Union, TYPE_CHECKING
4
5from sortedcontainers import SortedDict
6
7if TYPE_CHECKING:
8    from .knowledge_plugins.key_definitions.unknown_size import UnknownSize
9
10
# Module-level logger, named after this module.
l = logging.getLogger(name=__name__)
12
13
class StoredObject:
    """
    An object bundled with the offset it is stored at and its size.

    Two StoredObjects are equal when their wrapped objects, start offsets and sizes are all equal; the hash is
    computed from the same three fields, so the wrapped object must be hashable. Instances can be weakly referenced
    (`__weakref__` is listed in `__slots__`).
    """

    __slots__ = ('__weakref__', 'start', 'obj', 'size')

    def __init__(self, start, obj, size):
        """
        :param start:   Offset at which the object is stored.
        :param obj:     The wrapped object.
        :param size:    Size of the object, either an integer or an UnknownSize instance.
        """
        self.start = start
        self.obj = obj
        self.size: Union['UnknownSize', int] = size

    def __eq__(self, other):
        # Defer to the other operand instead of asserting: an assert is stripped under -O and otherwise turns a
        # harmless comparison against a foreign type into an AssertionError.
        if not isinstance(other, StoredObject):
            return NotImplemented

        return self.obj == other.obj and self.start == other.start and self.size == other.size

    def __hash__(self):
        return hash((self.start, self.size, self.obj))

    def __repr__(self):
        return "<SO %s@%#x, %s bytes>" % (repr(self.obj), self.start, self.size)

    @property
    def obj_id(self):
        """
        The identity (`id()`) of the wrapped object.
        """
        return id(self.obj)
37
38
class RegionObject:
    """
    A run of bytes inside a KeyedRegion together with every stored object that occupies it.

    `stored_objects` holds the StoredObject wrappers; `_internal_objects` mirrors it with the wrapped objects
    themselves (the `.obj` attribute of each wrapper).
    """

    __slots__ = ('start', 'size', 'stored_objects', '_internal_objects')

    def __init__(self, start, size, objects=None):
        """
        :param start:   First offset covered by this region item.
        :param size:    Number of bytes covered.
        :param objects: Optional set of stored objects. It is kept by reference, not copied.
        """
        self.start = start
        self.size = size
        self.stored_objects = objects if objects is not None else set()

        # keep the unwrapped objects at hand so lookups do not have to unwrap every time
        self._internal_objects = {wrapper.obj for wrapper in self.stored_objects}

    def __eq__(self, other):
        if type(other) is not RegionObject:
            return False
        same_span = self.start == other.start and self.size == other.size
        return same_span and self.stored_objects == other.stored_objects

    def __ne__(self, other):
        return not (self == other)

    @property
    def is_empty(self):
        """
        True when no stored object occupies this region item.
        """
        return not self.stored_objects

    @property
    def end(self):
        """
        The first offset past this region item.
        """
        return self.start + self.size

    @property
    def internal_objects(self):
        """
        The set of wrapped objects (not a copy).
        """
        return self._internal_objects

    def includes(self, offset):
        """
        Test whether `offset` lies in [start, end).
        """
        return self.start <= offset < self.end

    def split(self, split_at):
        """
        Cut this region item at `split_at` and return the two halves. Each half gets its own copy of the set of
        stored objects; this instance is left untouched.

        :param split_at:    The offset to split at. Must be covered by this region item.
        :return:            A tuple of (lower half, upper half).
        """
        assert self.includes(split_at)

        lower = RegionObject(self.start, split_at - self.start, self.stored_objects.copy())
        upper = RegionObject(split_at, self.end - split_at, self.stored_objects.copy())
        return lower, upper

    def add_object(self, obj):
        """
        Add a stored object to this region item.

        :param obj: The stored object to add.
        """
        # An equal stored object may already be present without being the same instance. Drop it first so that the
        # new instance is the one that is kept.
        self.stored_objects.discard(obj)
        self.stored_objects.add(obj)

        self._internal_objects.add(obj.obj)

    def set_object(self, obj):
        """
        Make `obj` the only stored object of this region item.

        :param obj: The stored object to keep.
        """
        self._internal_objects.clear()
        self.stored_objects.clear()

        self.add_object(obj)

    def copy(self):
        """
        Return a new RegionObject covering the same bytes, with a shallow copy of the set of stored objects.
        """
        return RegionObject(self.start, self.size, objects=self.stored_objects.copy())
103
104
class KeyedRegion:
    """
    KeyedRegion keeps a mapping between stack offsets and all objects covering that offset. It assumes no variable in
    this region overlaps with another variable in this region.

    Registers and function frames can all be viewed as a keyed region.

    Internally, `_storage` maps the start offset of each RegionObject to that RegionObject, and `_object_mapping`
    weakly maps `id(obj)` to the StoredObject that was most recently created for `obj`.
    """

    __slots__ = ('_storage', '_object_mapping', '_phi_node_contains', '_canonical_size', )

    def __init__(self, tree=None, phi_node_contains=None, canonical_size=8):
        """
        :param tree:                An existing storage mapping to use. A new SortedDict is created when omitted.
        :param phi_node_contains:   Optional callable taking two objects; it is consulted during weak updates to
                                    decide whether the first object is a phi node containing the second one.
        :param int canonical_size:  The size to assume for objects whose size is an UnknownSize.
        """
        self._storage = SortedDict() if tree is None else tree
        self._object_mapping = weakref.WeakValueDictionary()
        self._phi_node_contains = phi_node_contains
        self._canonical_size: int = canonical_size

    def __getstate__(self):
        # The weak dictionary itself cannot be pickled, so it is flattened into a plain dict. The canonical size is
        # part of the state: without it, an unpickled instance would have an unset `_canonical_size` slot.
        return self._storage, dict(self._object_mapping), self._phi_node_contains, self._canonical_size

    def __setstate__(self, s):
        if len(s) == 3:
            # state produced before the canonical size was serialized; fall back to the constructor default
            self._storage, om, self._phi_node_contains = s
            self._canonical_size = 8
        else:
            self._storage, om, self._phi_node_contains, self._canonical_size = s
        self._object_mapping = weakref.WeakValueDictionary(om)

    def _get_container(self, offset):
        """
        Find the region item covering the given offset.

        :param int offset:  The offset to look up.
        :return:            A tuple of (base offset, region item) if an item covers the offset, or (offset, None)
                            otherwise.
        """
        try:
            # the closest region item starting at or before the offset
            base_offset = next(self._storage.irange(maximum=offset, reverse=True))
        except StopIteration:
            return offset, None
        else:
            container = self._storage[base_offset]
            if container.includes(offset):
                return base_offset, container
            return offset, None

    def __contains__(self, offset):
        """
        Test if there is at least one variable covering the given offset.

        :param int offset:  The offset to test.
        :return:            True if a region item covers the offset, False otherwise.
        :raises TypeError:  If the offset is not an integer.
        """

        if not isinstance(offset, int):
            raise TypeError("KeyedRegion only accepts concrete offsets.")

        return self._get_container(offset)[1] is not None

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        return iter(self._storage.values())

    def __eq__(self, other):
        if not isinstance(other, KeyedRegion):
            return NotImplemented

        if set(self._storage.keys()) != set(other._storage.keys()):
            return False

        return not any(ro != other._storage[key] for key, ro in self._storage.items())

    def copy(self):
        """
        Make a copy of this keyed region. Every region item is copied; the stored objects themselves are shared.

        :return:    The new KeyedRegion.
        """
        kr = KeyedRegion(phi_node_contains=self._phi_node_contains, canonical_size=self._canonical_size)
        if self._storage:
            for key, ro in self._storage.items():
                kr._storage[key] = ro.copy()
            kr._object_mapping = self._object_mapping.copy()
        return kr

    def _merge_objects(self, other, replacements, **store_kwargs):
        """
        Weakly store every stored object of `other` into this region, applying `replacements` on the way.

        :param KeyedRegion other:   The region to take stored objects from.
        :param replacements:        Optional mapping from objects to the objects that should be stored instead.
        :param store_kwargs:        Extra keyword arguments forwarded to the low-level store.
        :return:                    None
        """
        for item in other._storage.values():
            for so in item.stored_objects:
                if replacements and so.obj in replacements:
                    so = StoredObject(so.start, replacements[so.obj], so.size)
                self._object_mapping[so.obj_id] = so
                self.__store(so, overwrite=False, **store_kwargs)

    def merge(self, other, replacements=None):
        """
        Merge another KeyedRegion into this KeyedRegion.

        :param KeyedRegion other:   The other instance to merge with.
        :param replacements:        Optional mapping from objects to the objects that should be stored instead.
        :return:                    self
        :raises ValueError:         If the two regions do not use the same canonical size.
        """

        if self._canonical_size != other._canonical_size:
            raise ValueError("The canonical sizes of two KeyedRegion objects must equal.")

        # TODO: is the current solution not optimal enough?
        self._merge_objects(other, replacements)

        return self

    def merge_to_top(self, other, replacements=None, top=None):
        """
        Merge another KeyedRegion into this KeyedRegion, but mark all variables with different values as TOP.

        :param KeyedRegion other:   The other instance to merge with.
        :param replacements:        Optional mapping from objects to the objects that should be stored instead.
        :param top:                 The object that is stored wherever the two regions disagree.
        :return:                    self
        """

        self._merge_objects(other, replacements, merge_to_top=True, top=top)

        return self

    def replace(self, replacements):
        """
        Replace variables with other variables.

        :param dict replacements:   A dict of variable replacements.
        :return:                    self
        """

        for old_var, new_var in replacements.items():
            # a single .get() instead of "in" followed by indexing: entries of a weak dictionary may vanish between
            # the two operations
            old_so = self._object_mapping.get(id(old_var))
            if old_so is not None:
                # FIXME: we need to check if old_var still exists in the storage
                self._store(old_so.start, new_var, old_so.size, overwrite=True)

        return self

    def dbg_repr(self):
        """
        Get a debugging representation of this keyed region.

        :return: A string of debugging output, one line per region item in ascending offset order.
        """
        return "\n".join(
            "Offset %#x: %s" % (offset, [so.obj for so in self._storage[offset].stored_objects])
            for offset in sorted(self._storage.keys())
        )

    def add_variable(self, start, variable):
        """
        Add a variable to this region at the given offset.

        :param int start:               The offset to add the variable at.
        :param SimVariable variable:    The variable to add. A size of None is treated as one byte.
        :return: None
        """

        size = variable.size if variable.size is not None else 1

        self.add_object(start, variable, size)

    def add_object(self, start, obj, object_size):
        """
        Add/Store an object to this region at the given offset.

        :param start:           The offset to store the object at.
        :param obj:             The object to store.
        :param int object_size: Size of the object
        :return:                None
        """

        self._store(start, obj, object_size, overwrite=False)

    def set_variable(self, start, variable):
        """
        Add a variable to this region at the given offset, and remove all other variables that are fully covered by
        this variable.

        :param int start:               The offset to store the variable at.
        :param SimVariable variable:    The variable to store. A size of None is treated as one byte.
        :return: None
        """

        size = variable.size if variable.size is not None else 1

        self.set_object(start, variable, size)

    def set_object(self, start, obj, object_size):
        """
        Add an object to this region at the given offset, and remove all other objects that are fully covered by this
        object.

        :param start:       The offset to store the object at.
        :param obj:         The object to store.
        :param object_size: Size of the object.
        :return:            None
        """

        self._store(start, obj, object_size, overwrite=True)

    def get_base_addr(self, addr):
        """
        Get the base offset (the key we are using to index objects covering the given offset) of a specific offset.

        :param int addr:    The offset to look up.
        :return:            The base offset, or None if nothing covers the given offset.
        :rtype:  int or None
        """

        base_addr, container = self._get_container(addr)
        return None if container is None else base_addr

    def get_variables_by_offset(self, start):
        """
        Find variables covering the given region offset.

        :param int start:   The offset to look up.
        :return: A set of variables. When a region item covers the offset, this is that item's own set, not a copy.
        :rtype:  set
        """

        _, container = self._get_container(start)
        if container is None:
            return set()
        return container.internal_objects

    def get_objects_by_offset(self, start):
        """
        Find objects covering the given region offset.

        :param start:   The offset to look up.
        :return: A set of objects. When a region item covers the offset, this is that item's own set, not a copy.
        """

        _, container = self._get_container(start)
        if container is None:
            return set()
        return container.internal_objects

    def get_all_variables(self):
        """
        Get all variables covering the current region.

        :return:    A set of all variables.
        """
        variables = set()
        for ro in self._storage.values():
            variables |= ro.internal_objects
        return variables

    #
    # Private methods
    #

    def _canonicalize_size(self, size: Union[int,'UnknownSize']) -> int:
        """
        Convert a size into an integer, substituting the canonical size of this region for an UnknownSize.

        :param size:    An integer size or an UnknownSize instance.
        :return:        The canonical size of this region if `size` is an UnknownSize, `size` itself otherwise.
        """

        # delayed import
        from .knowledge_plugins.key_definitions.unknown_size import UnknownSize  # pylint:disable=import-outside-toplevel

        if isinstance(size, UnknownSize):
            return self._canonical_size
        return size

    def _store(self, start, obj, size, overwrite=False):
        """
        Store a variable into the storage.

        :param int start: The beginning address of the variable.
        :param obj: The object to store.
        :param int size: Size of the object to store.
        :param bool overwrite: Whether existing objects should be overwritten or not.
        :return: None
        """

        stored_object = StoredObject(start, obj, size)
        self._object_mapping[stored_object.obj_id] = stored_object
        self.__store(stored_object, overwrite=overwrite)

    def __store(self, stored_object, overwrite=False, merge_to_top=False, top=None):
        """
        Store a variable into the storage.

        :param StoredObject stored_object: The descriptor describing start address and the variable.
        :param bool overwrite:  Whether existing objects should be overwritten or not. True to make a strong update,
                                False to make a weak update.
        :param bool merge_to_top:   Only used for weak updates. Forwarded to `_add_object_with_check`, which then
                                    replaces conflicting region items with `top`.
        :param top:             The object to store in conflicting region items when `merge_to_top` is True.
        :return: None
        """

        start = stored_object.start
        object_size = self._canonicalize_size(stored_object.size)
        end: int = start + object_size

        # region items in the middle
        overlapping_items = list(self._storage.irange(start, end-1))

        # is there a region item that begins before the start and overlaps with this variable?
        floor_key, floor_item = self._get_container(start)
        if floor_item is not None and floor_key not in overlapping_items:
            # insert it into the beginning
            overlapping_items.insert(0, floor_key)

        # scan through the entire list of region items, split existing regions and insert new regions as needed
        # all changes are collected in to_update and applied to the storage in one go at the end
        to_update = {start: RegionObject(start, object_size, {stored_object})}
        last_end: int = start

        for floor_key in overlapping_items:
            item: RegionObject = self._storage[floor_key]
            item_end: int = item.start + self._canonicalize_size(item.size)
            if item.start < start:
                # we need to break this item into two
                # NOTE(review): b runs from `start` to the end of `item` and is not split again at `end`, so the
                # object is also recorded for the bytes of `item` past `end`, if any - confirm this is intended
                a, b = item.split(start)
                if overwrite:
                    b.set_object(stored_object)
                else:
                    self._add_object_with_check(b, stored_object, merge_to_top=merge_to_top, top=top)
                to_update[a.start] = a
                to_update[b.start] = b
                last_end = b.start + self._canonicalize_size(b.size)
            elif item.start > last_end:
                # there is a gap between the last item and the current item
                # fill in the gap
                # NOTE(review): only the gap is handled in this iteration; `item` itself is neither updated with the
                # object nor split - confirm this is intended
                new_item = RegionObject(last_end, item.start - last_end, {stored_object})
                to_update[new_item.start] = new_item
                last_end = new_item.end
            elif item_end > end:
                # we need to split this item into two
                a, b = item.split(end)
                if overwrite:
                    a.set_object(stored_object)
                else:
                    self._add_object_with_check(a, stored_object, merge_to_top=merge_to_top, top=top)
                to_update[a.start] = a
                to_update[b.start] = b
                last_end = b.start + self._canonicalize_size(b.size)
            else:
                # the item is fully covered by the object and is updated in place
                # NOTE(review): last_end is not advanced here, so the next item is compared against a stale value
                # and may be treated as if it followed a gap - confirm this is intended
                if overwrite:
                    item.set_object(stored_object)
                else:
                    self._add_object_with_check(item, stored_object, merge_to_top=merge_to_top, top=top)
                to_update[item.start] = item

        self._storage.update(to_update)

    def _is_overlapping(self, start, variable):
        """
        Test whether a variable placed at the given offset would overlap with an existing region item.

        :param int start:   The offset the variable would be placed at.
        :param variable:    The variable to test. A size of None is treated as one byte.
        :return:            True if an existing region item shares at least one byte with the variable.
        :rtype:             bool
        """

        size = variable.size if variable.size is not None else 1
        end = start + size

        # Only the closest region item starting before `end` has to be inspected: items starting at or after `end`
        # cannot overlap, and this relies on region items not overlapping each other.
        try:
            prev_offset = next(self._storage.irange(maximum=end-1, reverse=True))
        except StopIteration:
            return False

        # the item starts before `end`, so it overlaps as soon as it reaches past `start`; this covers items starting
        # inside the range as well as earlier items that extend into or across it
        prev_item = self._storage[prev_offset]
        return prev_item.end > start

    def _add_object_with_check(self, item, stored_object, merge_to_top=False, top=None):
        """
        Weakly add a stored object to a region item, resolving conflicts with the objects that are already there.

        :param RegionObject item:           The region item to update.
        :param StoredObject stored_object:  The stored object to add.
        :param bool merge_to_top:           If True, a conflicting region item is reset to hold only `top`.
        :param top:                         The object to store when `merge_to_top` applies.
        :return:                            None
        """

        # a conflict exists when the item holds at least one object other than the one being stored
        if len({stored_object.obj} | item.internal_objects) > 1:
            if merge_to_top:
                item.set_object(StoredObject(stored_object.start, top, stored_object.size))
                return

            if self._phi_node_contains is not None:
                # check if `item` is a phi node that contains stored_object.obj
                for so in item.internal_objects:
                    if self._phi_node_contains(so, stored_object.obj):
                        # yes! so we want to skip this object
                        return
                # check if `stored_object.obj` is a phi node that contains item.internal_objects
                if all(self._phi_node_contains(stored_object.obj, o) for o in item.internal_objects):
                    # yes!
                    item.set_object(stored_object)
                    return

        item.add_object(stored_object)
503