
import logging
from collections import defaultdict

from ...serializable import Serializable
from ...protos import xrefs_pb2
from ..plugin import KnowledgeBasePlugin
from .xref import XRef, XRefType


l = logging.getLogger(name=__name__)


class XRefManager(KnowledgeBasePlugin, Serializable):
    """
    Knowledge-base plugin that stores XRef objects in two indices:

    - ``xrefs_by_ins_addr``: instruction address -> set of XRefs originating there
    - ``xrefs_by_dst``: destination -> set of XRefs pointing there

    Every XRef added through :meth:`add_xref` is recorded in both indices.
    """

    def __init__(self, kb):
        super().__init__()
        self._kb = kb

        self.xrefs_by_ins_addr = defaultdict(set)
        self.xrefs_by_dst = defaultdict(set)

    def copy(self):
        """
        Return a new XRefManager whose indices can be modified independently of this one.

        Both the index dictionaries and the per-key sets are copied; the XRef objects themselves are
        shared. (A plain ``defaultdict.copy()`` is shallow and would share the inner sets, so adding a
        reference to the copy would silently add it to the original as well.)
        """
        xm = XRefManager(self._kb)
        xm.xrefs_by_ins_addr = defaultdict(set, {k: set(v) for k, v in self.xrefs_by_ins_addr.items()})
        xm.xrefs_by_dst = defaultdict(set, {k: set(v) for k, v in self.xrefs_by_dst.items()})
        return xm

    def add_xref(self, xref):
        """
        Record ``xref`` in both indices.

        If ``xref`` is not an Offset reference, any existing Offset reference at the same instruction
        address with the same destination is removed, so the more specific reference replaces it.

        :param xref: The XRef to add.
        """
        by_ins = self.xrefs_by_ins_addr[xref.ins_addr]
        by_dst = self.xrefs_by_dst[xref.dst]

        # Overwrite existing "offset" refs
        if xref.type != XRefType.Offset:
            # Collect first so that ``by_ins`` is not mutated while it is being iterated.
            stale = [ex for ex in by_ins if ex.dst == xref.dst and ex.type == XRefType.Offset]
            for ex in stale:
                by_ins.discard(ex)
                by_dst.discard(ex)

        # Add only after the stale entries are gone, so the replacement step can never drop the new ref.
        by_ins.add(xref)
        by_dst.add(xref)

    def add_xrefs(self, xrefs):
        """
        Add every XRef in the iterable ``xrefs`` via :meth:`add_xref`.
        """
        for xref in xrefs:
            self.add_xref(xref)

    def get_xrefs_by_ins_addr(self, ins_addr):
        """
        Return the XRefs originating at instruction address ``ins_addr``.

        When references exist, the returned set is the manager's internal set (not a copy); do not mutate
        it. When none exist, a fresh empty set is returned and no index entry is created.
        """
        return self.xrefs_by_ins_addr.get(ins_addr, set())

    def get_xrefs_by_dst(self, dst):
        """
        Return the XRefs pointing to ``dst``.

        When references exist, the returned set is the manager's internal set (not a copy); do not mutate
        it. When none exist, a fresh empty set is returned and no index entry is created.
        """
        return self.xrefs_by_dst.get(dst, set())

    @staticmethod
    def _collect_region(index, start, end):
        """
        Union the XRef sets of every integer key ``k`` in ``index`` with ``start <= k <= end``.

        Non-integer keys are skipped.
        """
        refs = set()
        for key, ref_set in index.items():
            if isinstance(key, int) and start <= key <= end:
                # In-place update avoids allocating a new set per matching key.
                refs.update(ref_set)
        return refs

    def get_xrefs_by_dst_region(self, start, end):
        """
        Get a set of XRef objects that point to a given address region
        bounded by start and end (both inclusive).
        Will only return absolute xrefs, not relative ones (like SP offsets):
        only integer destinations are considered.
        """
        return self._collect_region(self.xrefs_by_dst, start, end)

    def get_xrefs_by_ins_addr_region(self, start, end):
        """
        Get a set of XRef objects that originate at a given address region
        bounded by start and end (both inclusive). Useful for finding references from a basic block or function.
        """
        return self._collect_region(self.xrefs_by_ins_addr, start, end)

    # TODO: Maybe add some helpers that accept Function or Block objects for the sake of clean analyses.

    @classmethod
    def _get_cmsg(cls):
        return xrefs_pb2.XRefs()

    def serialize_to_cmessage(self):
        """
        Serialize every stored XRef into an ``xrefs_pb2.XRefs`` message.

        References are taken from the instruction-address index; each one is serialized once per set in
        which it appears there.
        """
        # pylint:disable=no-member
        cmsg = self._get_cmsg()
        # references
        refs = [ref.serialize_to_cmessage() for ref_set in self.xrefs_by_ins_addr.values() for ref in ref_set]
        cmsg.xrefs.extend(refs)
        return cmsg

    @classmethod
    def parse_from_cmessage(cls, cmsg, cfg_model=None, kb=None, **kwargs):  # pylint:disable=arguments-differ
        """
        Rebuild an XRefManager from an ``xrefs_pb2.XRefs`` message.

        :param cmsg:      The serialized message.
        :param cfg_model: Optional CFG model; when given, integer destinations are linked to the matching
                          entry of ``cfg_model.memory_data`` (or None when there is none).
        :param kb:        The knowledge base. Although it defaults to None, it is required:
                          ``kb._project.arch.bits`` is read unconditionally, so passing None raises
                          AttributeError.
        :return:          The reconstructed XRefManager.
        """
        model = XRefManager(kb)
        bits = kb._project.arch.bits

        # references
        for xref_pb2 in cmsg.xrefs:
            if xref_pb2.data_ea == -1:
                # -1 marks an unknown referenced address; such entries cannot be indexed by destination.
                l.warning("Unknown address of the referenced data item. Ignore the reference at %#x.", xref_pb2.ea)
                continue
            xref = XRef.parse_from_cmessage(xref_pb2, bits=bits)
            if cfg_model is not None and isinstance(xref.dst, int):
                xref.memory_data = cfg_model.memory_data.get(xref.dst, None)
            model.add_xref(xref)

        return model


KnowledgeBasePlugin.register_default('xrefs', XRefManager)