import itertools
import math
from array import array
from collections import defaultdict, Counter

import numpy as np

from urh.awre.CommonRange import CommonRange
from urh.awre.engines.Engine import Engine
from urh.cythonext import awre_util
from urh.util.Logger import logger


class AddressEngine(Engine):
    """
    Search for source and destination address fields in messages.

    Address candidates are values that are common between the messages of different
    participants (see :meth:`find_addresses`). :meth:`find` locates these candidates in
    every message, scores the resulting ranges and labels them as source or destination
    address.

    Note: byte representations of addresses are produced with ``ndarray.tobytes()``;
    the former ``tostring()`` alias is deprecated since NumPy 1.19 and removed in
    recent NumPy releases.
    """

    def __init__(self, msg_vectors, participant_indices, known_participant_addresses: dict = None,
                 already_labeled: list = None, src_field_present=False):
        """
        :param msg_vectors: Message data behind synchronization
        :type msg_vectors: list of np.ndarray
        :param participant_indices: list of participant indices
                                    where ith position holds participants index for ith messages
        :type participant_indices: list of int
        :param known_participant_addresses: optional mapping participant index -> address array
                                            for participants whose address is already known
        :param already_labeled: optional list of (start, end) ranges that are already labeled;
                                they are divided by 4 (rounded up) to convert them to hex positions
        :param src_field_present: if True, :meth:`find` skips the shortcut that gives a known
                                  address the highest score when there is only one participant
        """
        assert len(msg_vectors) == len(participant_indices)

        self.minimum_score = 0.1

        self.msg_vectors = msg_vectors
        self.participant_indices = participant_indices
        self.already_labeled = []

        self.src_field_present = src_field_present

        if already_labeled is not None:
            for start, end in already_labeled:
                # convert it to hex
                self.already_labeled.append((int(math.ceil(start / 4)), int(math.ceil(end / 4))))

        self.message_indices_by_participant = defaultdict(list)
        for i, participant_index in enumerate(self.participant_indices):
            self.message_indices_by_participant[participant_index].append(i)

        if known_participant_addresses is None:
            self.known_addresses_by_participant = dict()  # type: dict[int, np.ndarray]
        else:
            self.known_addresses_by_participant = known_participant_addresses  # type: dict[int, np.ndarray]

    @staticmethod
    def cross_swap_check(rng1: CommonRange, rng2: CommonRange):
        """
        Return True if rng1 starts exactly one rng1.length before or after rng2 and both
        ranges hold the same value, i.e. the same value appears at swapped positions.
        """
        return (rng1.start == rng2.start + rng1.length or rng1.start == rng2.start - rng1.length) \
            and rng1.value.tobytes() == rng2.value.tobytes()

    @staticmethod
    def ack_check(rng1: CommonRange, rng2: CommonRange):
        """
        Return True if both ranges cover the same start and length but hold different values
        (presumably the pattern of a message and its acknowledgement - judging by the name).
        """
        return rng1.start == rng2.start and rng1.length == rng2.length and rng1.value.tobytes() != rng2.value.tobytes()

    def find(self):
        """
        Find address fields in the messages.

        The method finds occurrences of all address candidates and scores them by
        cross-swap/ack patterns between participants. It keeps the ranges with the
        estimated address length, assigns an address to every participant and labels
        each kept range as "source address" or "destination address".

        :return: list of CommonRange objects with field_type set
        """
        addresses_by_participant = {p: [addr.tobytes()] for p, addr in self.known_addresses_by_participant.items()}
        addresses_by_participant.update(self.find_addresses())
        self._debug("Addresses by participant", addresses_by_participant)

        # Find the address candidates by participant in messages
        ranges_by_participant = defaultdict(list)  # type: dict[int, list[CommonRange]]

        addresses = [np.array(np.frombuffer(a, dtype=np.uint8))
                     for address_list in addresses_by_participant.values()
                     for a in address_list]

        already_labeled_cols = array("L", [e for rng in self.already_labeled for e in range(*rng)])

        # Find occurrences of address candidates in messages and create common ranges over matching positions
        for i, msg_vector in enumerate(self.msg_vectors):
            participant = self.participant_indices[i]
            for address in addresses:
                for index in awre_util.find_occurrences(msg_vector, address, already_labeled_cols):
                    # Only touch the defaultdict on an actual occurrence, so participants without
                    # any occurrence do not become keys (checked by len(ranges_by_participant) below)
                    common_ranges = ranges_by_participant[participant]
                    rng = next((cr for cr in common_ranges if cr.matches(index, address)), None)  # type: CommonRange
                    if rng is not None:
                        rng.message_indices.add(i)
                    else:
                        common_ranges.append(CommonRange(index, len(address), address,
                                                         message_indices={i},
                                                         range_type="hex"))

        num_messages_by_participant = defaultdict(int)
        for participant in self.participant_indices:
            num_messages_by_participant[participant] += 1

        # Look for cross swapped values between participant clusters
        for p1, p2 in itertools.combinations(ranges_by_participant, 2):
            ranges1_set, ranges2_set = set(ranges_by_participant[p1]), set(ranges_by_participant[p2])

            for rng1, rng2 in itertools.product(ranges_by_participant[p1], ranges_by_participant[p2]):
                if rng1 in ranges2_set and rng2 in ranges1_set:
                    if self.cross_swap_check(rng1, rng2):
                        rng1.score += len(rng2.message_indices) / num_messages_by_participant[p2]
                        rng2.score += len(rng1.message_indices) / num_messages_by_participant[p1]
                    elif self.ack_check(rng1, rng2):
                        # Add previous score in divisor to add bonus to ranges that apply to all messages
                        rng1.score += len(rng2.message_indices) / (num_messages_by_participant[p2] + rng1.score)
                        rng2.score += len(rng1.message_indices) / (num_messages_by_participant[p1] + rng2.score)

        if len(ranges_by_participant) == 1 and not self.src_field_present:
            for p, ranges in ranges_by_participant.items():
                known_address = self.known_addresses_by_participant.get(p)
                if known_address is None:
                    continue
                for rng in sorted(ranges):
                    if np.array_equal(rng.value, known_address):
                        # Only one participant in this iteration and address already known -> Highscore
                        rng.score = 1
                        break  # Take only the first (leftmost) range

        high_scored_ranges_by_participant = defaultdict(list)

        address_length = self.__estimate_address_length(ranges_by_participant)

        # Get highscored ranges by participant
        for participant, common_ranges in ranges_by_participant.items():
            # Sort by negative score so ranges with highest score appear first
            # Secondary sort by tuple to ensure order when ranges have same score
            sorted_ranges = sorted(filter(lambda cr: cr.score > self.minimum_score, common_ranges),
                                   key=lambda cr: (-cr.score, cr))
            if len(sorted_ranges) == 0:
                # No usable range: no address candidates (empty set, consistent with the sets below)
                addresses_by_participant[participant] = set()
                continue

            addresses_by_participant[participant] = {a for a in addresses_by_participant.get(participant, [])
                                                     if len(a) == address_length}

            for rng in filter(lambda r: r.length == address_length, sorted_ranges):
                rng.score = min(rng.score, 1.0)
                high_scored_ranges_by_participant[participant].append(rng)

        # Now we find the most probable address for all participants
        self.__assign_participant_addresses(addresses_by_participant, high_scored_ranges_by_participant)

        # Eliminate participants for which we could not assign an address
        for participant, address in list(addresses_by_participant.items()):
            if address is None:
                del addresses_by_participant[participant]

        # Now we can separate SRC and DST
        for participant, ranges in high_scored_ranges_by_participant.items():
            try:
                address = addresses_by_participant[participant]
            except KeyError:
                high_scored_ranges_by_participant[participant] = []
                continue

            result = []

            for rng in sorted(ranges, key=lambda r: r.score, reverse=True):
                rng.field_type = "source address" if rng.value.tobytes() == address else "destination address"
                if len(result) == 0:
                    result.append(rng)
                else:
                    subset = next((r for r in result if rng.message_indices.issubset(r.message_indices)), None)
                    if subset is not None:
                        if rng.field_type == subset.field_type:
                            # Avoid adding same address type twice
                            continue

                        if rng.length != subset.length or (rng.start != subset.end + 1
                                                           and rng.end + 1 != subset.start):
                            # Ensure addresses are next to each other
                            continue

                    result.append(rng)

            high_scored_ranges_by_participant[participant] = result

        self.__find_broadcast_fields(high_scored_ranges_by_participant, addresses_by_participant)

        result = [rng for ranges in high_scored_ranges_by_participant.values() for rng in ranges]
        # If we did not find a SRC address, lower the score a bit,
        # so DST fields do not win later e.g. against length fields in case of tie
        if not any(rng.field_type == "source address" for rng in result):
            for rng in result:
                rng.score *= 0.95

        return result

    def __estimate_address_length(self, ranges_by_participant: dict):
        """
        Estimate the address length which is assumed to be the same for all participants

        For every participant the lengths of its highest-scored ranges are counted and the
        most common length is taken. On a tie, the shorter length wins. The most common
        of these per-participant lengths is returned, again preferring the shorter length
        on a tie.

        :param ranges_by_participant: mapping participant index -> list of CommonRange
        :return: estimated address length, or 0 if no range scored above self.minimum_score
        """
        address_lengths = []
        for participant, common_ranges in ranges_by_participant.items():
            sorted_ranges = sorted(filter(lambda cr: cr.score > self.minimum_score, common_ranges),
                                   key=lambda cr: (-cr.score, cr))

            max_scored = [r for r in sorted_ranges if r.score == sorted_ranges[0].score]

            # Prevent overestimation of address length by looking for substrings
            for rng in max_scored[:]:
                same_message_rng = [r for r in sorted_ranges
                                    if r not in max_scored and r.score > 0 and r.message_indices == rng.message_indices]

                if len(same_message_rng) > 1 and all(
                        r.value.tobytes() in rng.value.tobytes() for r in same_message_rng):
                    # remove the longer range and add the smaller ones
                    max_scored.remove(rng)
                    max_scored.extend(same_message_rng)

            # Count possible address lengths.
            frequencies = Counter(r.length for r in max_scored)
            # Take the most common one. On tie, take the shorter one
            if frequencies:
                address_lengths.append(max(frequencies, key=lambda x: (frequencies[x], -x)))

        # Take most common address length of participants, to ensure they all have same address length
        counted = Counter(address_lengths)
        if not counted:
            return 0
        return max(counted, key=lambda x: (counted[x], -x))

    def __assign_participant_addresses(self, addresses_by_participant, high_scored_ranges_by_participant):
        """
        Choose one address per participant and write it into addresses_by_participant (in place).

        Known addresses get a very high score. Otherwise candidates are scored per message:

        - A message with exactly one matching range scales that address's score by 0.9.
          If the previous message is from another participant, it also credits the
          previous participant's matching ranges that hold this address.
        - A message with several matching ranges adds the ranges' scores.

        :param addresses_by_participant: mapping participant -> collection of candidate address bytes;
                                         every entry is overwritten with the chosen address bytes or None
        :param high_scored_ranges_by_participant: mapping participant -> list of CommonRange
        """
        scored_participants_addresses = dict()
        for participant in addresses_by_participant:
            scored_participants_addresses[participant] = defaultdict(int)

        for participant, addresses in addresses_by_participant.items():
            if participant in self.known_addresses_by_participant:
                address = self.known_addresses_by_participant[participant].tobytes()
                scored_participants_addresses[participant][address] = 9999999999
                continue

            for i in self.message_indices_by_participant[participant]:
                matching = [rng for rng in high_scored_ranges_by_participant[participant]
                            if i in rng.message_indices and rng.value.tobytes() in addresses]

                if len(matching) == 1:
                    address = matching[0].value.tobytes()
                    # only one address, so probably a destination and not a source
                    scored_participants_addresses[participant][address] *= 0.9

                    # Since this is probably an ACK, the address is probably SRC of participant of previous message
                    if i > 0 and self.participant_indices[i - 1] != participant:
                        prev_participant = self.participant_indices[i - 1]
                        prev_matching = [rng for rng in high_scored_ranges_by_participant[prev_participant]
                                         if i - 1 in rng.message_indices and rng.value.tobytes() in addresses]
                        if len(prev_matching) > 1:
                            for prev_rng in filter(lambda r: r.value.tobytes() == address, prev_matching):
                                scored_participants_addresses[prev_participant][address] += prev_rng.score

                elif len(matching) > 1:
                    # more than one address, so there must be a source address included
                    for rng in matching:
                        scored_participants_addresses[participant][rng.value.tobytes()] += rng.score

        minimum_score = 0.5
        taken_addresses = set()
        self._debug("Scored addresses", scored_participants_addresses)

        # If all participants have exactly one possible address and they all differ, we can assign them right away
        if all(len(addresses) == 1 for addresses in scored_participants_addresses.values()):
            all_addresses = [list(addresses)[0] for addresses in scored_participants_addresses.values()]
            if len(all_addresses) == len(set(all_addresses)):  # ensure all addresses are different
                for p, addresses in scored_participants_addresses.items():
                    addresses_by_participant[p] = list(addresses)[0]
                return

        for participant, addresses in sorted(scored_participants_addresses.items()):
            try:
                # sort filtered results to prevent randomness for equal scores
                found_address = max(sorted(
                    filter(lambda a: a not in taken_addresses and addresses[a] >= minimum_score, addresses),
                    reverse=True
                ), key=addresses.get)
            except ValueError:
                # Could not assign address for this participant
                addresses_by_participant[participant] = None
                continue

            addresses_by_participant[participant] = found_address
            taken_addresses.add(found_address)

    def __find_broadcast_fields(self, high_scored_ranges_by_participant, addresses_by_participant: dict):
        """
        Last we check for messages that were sent to broadcast
        1. we search for messages that have a SRC address but no DST address
        2. we look at other messages that have this SRC field and find the corresponding DST position
        3. we evaluate the value of message without DST from 1 and compare these values with each other.
           if they match, we found the broadcast address

        :param high_scored_ranges_by_participant: mapping participant -> list of labeled CommonRange;
                                                  matching DST ranges are extended in place
        :param addresses_by_participant: mapping participant -> address bytes; key -1 holds the
                                         broadcast address (set here if found, skip if already present)
        :return: None
        """
        if -1 in addresses_by_participant:
            # broadcast address is already known
            return

        broadcast_bag = defaultdict(list)  # type: dict[CommonRange, list[int]]
        for common_ranges in high_scored_ranges_by_participant.values():
            src_address_fields = sorted(filter(lambda r: r.field_type == "source address", common_ranges))
            dst_address_fields = sorted(filter(lambda r: r.field_type == "destination address", common_ranges))
            msg_with_dst = {i for dst_address_field in dst_address_fields for i in dst_address_field.message_indices}

            for src_address_field in src_address_fields:  # type: CommonRange
                msg_without_dst = {i for i in src_address_field.message_indices if i not in msg_with_dst}
                if len(msg_without_dst) == 0:
                    continue
                try:
                    matching_dst = next(dst for dst in dst_address_fields
                                        if all(i in dst.message_indices
                                               for i in src_address_field.message_indices - msg_without_dst))
                except StopIteration:
                    continue
                for msg in msg_without_dst:
                    broadcast_bag[matching_dst].append(msg)

        if len(broadcast_bag) == 0:
            return

        broadcast_address = None
        for dst, messages in broadcast_bag.items():
            for msg_index in messages:
                value = self.msg_vectors[msg_index][dst.start:dst.end + 1]
                if broadcast_address is None:
                    broadcast_address = value
                elif value.tobytes() != broadcast_address.tobytes():
                    # Address is not common across messages so it can't be a broadcast address
                    return

        addresses_by_participant[-1] = broadcast_address.tobytes()
        for dst, messages in broadcast_bag.items():
            dst.values.append(broadcast_address)
            dst.message_indices.update(messages)

    def find_addresses(self) -> dict:
        """
        Collect address candidates for participants whose address is not known yet.

        For each participant, messages are clustered by length and common ranges per
        cluster are computed (ranges overlapping already labeled positions are dropped).
        For every pair of participants, the longest common subsequences of their common
        range values are candidates. If none exist, the two values themselves are
        candidates. Only candidates of at least 2 values are kept, and only those
        matching the known address length if any address is known.

        :return: mapping participant index -> set of candidate address bytes;
                 empty if all participants already have a known address or fewer than two participants exist
        """
        already_assigned = list(self.known_addresses_by_participant.keys())
        if len(already_assigned) == len(self.message_indices_by_participant):
            self._debug("Skipping find addresses as already known.")
            return dict()

        common_ranges_by_participant = dict()
        for participant, message_indices in self.message_indices_by_participant.items():
            # Cluster by length
            length_clusters = defaultdict(list)
            for i in message_indices:
                length_clusters[len(self.msg_vectors[i])].append(i)

            common_ranges_by_length = self.find_common_ranges_by_cluster(self.msg_vectors, length_clusters,
                                                                         range_type="hex")
            common_ranges_by_participant[participant] = []
            for ranges in common_ranges_by_length.values():
                common_ranges_by_participant[participant].extend(self.ignore_already_labeled(ranges,
                                                                                             self.already_labeled))

        self._debug("Common ranges by participant:", common_ranges_by_participant)

        result = defaultdict(set)
        participants = sorted(common_ranges_by_participant)  # type: list[int]

        if len(participants) < 2:
            return result

        # If we already know the address length we do not need to bother with other candidates
        if len(already_assigned) > 0:
            addr_len = len(self.known_addresses_by_participant[already_assigned[0]])
            if any(len(self.known_addresses_by_participant[i]) != addr_len for i in already_assigned):
                logger.warning("Addresses do not have a common length. Assuming length of {}".format(addr_len))
        else:
            addr_len = None

        for p1, p2 in itertools.combinations(participants, 2):
            p1_already_assigned = p1 in already_assigned
            p2_already_assigned = p2 in already_assigned

            if p1_already_assigned and p2_already_assigned:
                continue

            # common ranges are not merged yet, so there is only one element in values
            values1 = [cr.value for cr in common_ranges_by_participant[p1]]
            values2 = [cr.value for cr in common_ranges_by_participant[p2]]
            for seq1, seq2 in itertools.product(values1, values2):
                lcs = self.find_longest_common_sub_sequences(seq1, seq2)
                vals = lcs if len(lcs) > 0 else [seq1, seq2]
                # Address candidate must be at least 2 values long
                for val in filter(lambda v: len(v) >= 2, vals):
                    if addr_len is not None and len(val) != addr_len:
                        continue
                    val_bytes = val.tobytes()
                    if not p1_already_assigned and not p2_already_assigned:
                        result[p1].add(val_bytes)
                        result[p2].add(val_bytes)
                    elif p1_already_assigned and val_bytes != self.known_addresses_by_participant[p1].tobytes():
                        result[p2].add(val_bytes)
                    elif p2_already_assigned and val_bytes != self.known_addresses_by_participant[p2].tobytes():
                        result[p1].add(val_bytes)
        return result