1import itertools
2import math
3from array import array
4from collections import defaultdict, Counter
5
6import numpy as np
7
8from urh.awre.CommonRange import CommonRange
9from urh.awre.engines.Engine import Engine
10from urh.cythonext import awre_util
11from urh.util.Logger import logger
12
13
14class AddressEngine(Engine):
15    def __init__(self, msg_vectors, participant_indices, known_participant_addresses: dict = None,
16                 already_labeled: list = None, src_field_present=False):
17        """
18
19        :param msg_vectors: Message data behind synchronization
20        :type msg_vectors: list of np.ndarray
21        :param participant_indices: list of participant indices
22                                    where ith position holds participants index for ith messages
23        :type participant_indices: list of int
24        """
25        assert len(msg_vectors) == len(participant_indices)
26
27        self.minimum_score = 0.1
28
29        self.msg_vectors = msg_vectors
30        self.participant_indices = participant_indices
31        self.already_labeled = []
32
33        self.src_field_present = src_field_present
34
35        if already_labeled is not None:
36            for start, end in already_labeled:
37                # convert it to hex
38                self.already_labeled.append((int(math.ceil(start / 4)), int(math.ceil(end / 4))))
39
40        self.message_indices_by_participant = defaultdict(list)
41        for i, participant_index in enumerate(self.participant_indices):
42            self.message_indices_by_participant[participant_index].append(i)
43
44        if known_participant_addresses is None:
45            self.known_addresses_by_participant = dict()  # type: dict[int, np.ndarray]
46        else:
47            self.known_addresses_by_participant = known_participant_addresses  # type: dict[int, np.ndarray]
48
49    @staticmethod
50    def cross_swap_check(rng1: CommonRange, rng2: CommonRange):
51        return (rng1.start == rng2.start + rng1.length or rng1.start == rng2.start - rng1.length) \
52               and rng1.value.tobytes() == rng2.value.tobytes()
53
54    @staticmethod
55    def ack_check(rng1: CommonRange, rng2: CommonRange):
56        return rng1.start == rng2.start and rng1.length == rng2.length and rng1.value.tobytes() != rng2.value.tobytes()
57
58    def find(self):
59        addresses_by_participant = {p: [addr.tostring()] for p, addr in self.known_addresses_by_participant.items()}
60        addresses_by_participant.update(self.find_addresses())
61        self._debug("Addresses by participant", addresses_by_participant)
62
63        # Find the address candidates by participant in messages
64        ranges_by_participant = defaultdict(list)  # type: dict[int, list[CommonRange]]
65
66        addresses = [np.array(np.frombuffer(a, dtype=np.uint8))
67                     for address_list in addresses_by_participant.values()
68                     for a in address_list]
69
70        already_labeled_cols = array("L", [e for rng in self.already_labeled for e in range(*rng)])
71
72        # Find occurrences of address candidates in messages and create common ranges over matching positions
73        for i, msg_vector in enumerate(self.msg_vectors):
74            participant = self.participant_indices[i]
75            for address in addresses:
76                for index in awre_util.find_occurrences(msg_vector, address, already_labeled_cols):
77                    common_ranges = ranges_by_participant[participant]
78                    rng = next((cr for cr in common_ranges if cr.matches(index, address)), None)  # type: CommonRange
79                    if rng is not None:
80                        rng.message_indices.add(i)
81                    else:
82                        common_ranges.append(CommonRange(index, len(address), address,
83                                                         message_indices={i},
84                                                         range_type="hex"))
85
86        num_messages_by_participant = defaultdict(int)
87        for participant in self.participant_indices:
88            num_messages_by_participant[participant] += 1
89
90        # Look for cross swapped values between participant clusters
91        for p1, p2 in itertools.combinations(ranges_by_participant, 2):
92            ranges1_set, ranges2_set = set(ranges_by_participant[p1]), set(ranges_by_participant[p2])
93
94            for rng1, rng2 in itertools.product(ranges_by_participant[p1], ranges_by_participant[p2]):
95                if rng1 in ranges2_set and rng2 in ranges1_set:
96                    if self.cross_swap_check(rng1, rng2):
97                        rng1.score += len(rng2.message_indices) / num_messages_by_participant[p2]
98                        rng2.score += len(rng1.message_indices) / num_messages_by_participant[p1]
99                    elif self.ack_check(rng1, rng2):
100                        # Add previous score in divisor to add bonus to ranges that apply to all messages
101                        rng1.score += len(rng2.message_indices) / (num_messages_by_participant[p2] + rng1.score)
102                        rng2.score += len(rng1.message_indices) / (num_messages_by_participant[p1] + rng2.score)
103
104        if len(ranges_by_participant) == 1 and not self.src_field_present:
105            for p, ranges in ranges_by_participant.items():
106                for rng in sorted(ranges):
107                    try:
108                        if np.array_equal(rng.value, self.known_addresses_by_participant[p]):
109                            # Only one participant in this iteration and address already known -> Highscore
110                            rng.score = 1
111                            break  # Take only the first (leftmost) range
112                    except KeyError:
113                        pass
114
115        high_scored_ranges_by_participant = defaultdict(list)
116
117        address_length = self.__estimate_address_length(ranges_by_participant)
118
119        # Get highscored ranges by participant
120        for participant, common_ranges in ranges_by_participant.items():
121            # Sort by negative score so ranges with highest score appear first
122            # Secondary sort by tuple to ensure order when ranges have same score
123            sorted_ranges = sorted(filter(lambda cr: cr.score > self.minimum_score, common_ranges),
124                                   key=lambda cr: (-cr.score, cr))
125            if len(sorted_ranges) == 0:
126                addresses_by_participant[participant] = dict()
127                continue
128
129            addresses_by_participant[participant] = {a for a in addresses_by_participant.get(participant, [])
130                                                     if len(a) == address_length}
131
132            for rng in filter(lambda r: r.length == address_length, sorted_ranges):
133                rng.score = min(rng.score, 1.0)
134                high_scored_ranges_by_participant[participant].append(rng)
135
136        # Now we find the most probable address for all participants
137        self.__assign_participant_addresses(addresses_by_participant, high_scored_ranges_by_participant)
138
139        # Eliminate participants for which we could not assign an address
140        for participant, address in addresses_by_participant.copy().items():
141            if address is None:
142                del addresses_by_participant[participant]
143
144        # Now we can separate SRC and DST
145        for participant, ranges in high_scored_ranges_by_participant.items():
146            try:
147                address = addresses_by_participant[participant]
148            except KeyError:
149                high_scored_ranges_by_participant[participant] = []
150                continue
151
152            result = []
153
154            for rng in sorted(ranges, key=lambda r: r.score, reverse=True):
155                rng.field_type = "source address" if rng.value.tostring() == address else "destination address"
156                if len(result) == 0:
157                    result.append(rng)
158                else:
159                    subset = next((r for r in result if rng.message_indices.issubset(r.message_indices)), None)
160                    if subset is not None:
161                        if rng.field_type == subset.field_type:
162                            # Avoid adding same address type twice
163                            continue
164
165                        if rng.length != subset.length or (rng.start != subset.end + 1 and rng.end + 1 != subset.start):
166                            # Ensure addresses are next to each other
167                            continue
168
169                    result.append(rng)
170
171            high_scored_ranges_by_participant[participant] = result
172
173        self.__find_broadcast_fields(high_scored_ranges_by_participant, addresses_by_participant)
174
175        result = [rng for ranges in high_scored_ranges_by_participant.values() for rng in ranges]
176        # If we did not find a SRC address, lower the score a bit,
177        # so DST fields do not win later e.g. again length fields in case of tie
178        if not any(rng.field_type == "source address" for rng in result):
179            for rng in result:
180                rng.score *= 0.95
181
182        return result
183
184    def __estimate_address_length(self, ranges_by_participant: dict):
185        """
186        Estimate the address length which is assumed to be the same for all participants
187
188        :param ranges_by_participant:
189        :return:
190        """
191        address_lengths = []
192        for participant, common_ranges in ranges_by_participant.items():
193            sorted_ranges = sorted(filter(lambda cr: cr.score > self.minimum_score, common_ranges),
194                                   key=lambda cr: (-cr.score, cr))
195
196            max_scored = [r for r in sorted_ranges if r.score == sorted_ranges[0].score]
197
198            # Prevent overestimation of address length by looking for substrings
199            for rng in max_scored[:]:
200                same_message_rng = [r for r in sorted_ranges
201                                    if r not in max_scored and r.score > 0 and r.message_indices == rng.message_indices]
202
203                if len(same_message_rng) > 1 and all(
204                        r.value.tobytes() in rng.value.tobytes() for r in same_message_rng):
205                    # remove the longer range and add the smaller ones
206                    max_scored.remove(rng)
207                    max_scored.extend(same_message_rng)
208
209            possible_address_lengths = [r.length for r in max_scored]
210
211            # Count possible address lengths.
212            frequencies = Counter(possible_address_lengths)
213            # Take the most common one. On tie, take the shorter one
214            try:
215                addr_len = max(frequencies, key=lambda x: (frequencies[x], -x))
216                address_lengths.append(addr_len)
217            except ValueError:  # max() arg is an empty sequence
218                pass
219
220        # Take most common address length of participants, to ensure they all have same address length
221        counted = Counter(address_lengths)
222        try:
223            address_length = max(counted, key=lambda x: (counted[x], -x))
224            return address_length
225        except ValueError:  # max() arg is an empty sequence
226            return 0
227
228    def __assign_participant_addresses(self, addresses_by_participant, high_scored_ranges_by_participant):
229        scored_participants_addresses = dict()
230        for participant in addresses_by_participant:
231            scored_participants_addresses[participant] = defaultdict(int)
232
233        for participant, addresses in addresses_by_participant.items():
234            if participant in self.known_addresses_by_participant:
235                address = self.known_addresses_by_participant[participant].tostring()
236                scored_participants_addresses[participant][address] = 9999999999
237                continue
238
239            for i in self.message_indices_by_participant[participant]:
240                matching = [rng for rng in high_scored_ranges_by_participant[participant]
241                            if i in rng.message_indices and rng.value.tostring() in addresses]
242
243                if len(matching) == 1:
244                    address = matching[0].value.tostring()
245                    # only one address, so probably a destination and not a source
246                    scored_participants_addresses[participant][address] *= 0.9
247
248                    # Since this is probably an ACK, the address is probably SRC of participant of previous message
249                    if i > 0 and self.participant_indices[i - 1] != participant:
250                        prev_participant = self.participant_indices[i - 1]
251                        prev_matching = [rng for rng in high_scored_ranges_by_participant[prev_participant]
252                                         if i - 1 in rng.message_indices and rng.value.tostring() in addresses]
253                        if len(prev_matching) > 1:
254                            for prev_rng in filter(lambda r: r.value.tostring() == address, prev_matching):
255                                scored_participants_addresses[prev_participant][address] += prev_rng.score
256
257                elif len(matching) > 1:
258                    # more than one address, so there must be a source address included
259                    for rng in matching:
260                        scored_participants_addresses[participant][rng.value.tostring()] += rng.score
261
262        minimum_score = 0.5
263        taken_addresses = set()
264        self._debug("Scored addresses", scored_participants_addresses)
265
266        # If all participants have exactly one possible address and they all differ, we can assign them right away
267        if all(len(addresses) == 1 for addresses in scored_participants_addresses.values()):
268            all_addresses = [list(addresses)[0] for addresses in scored_participants_addresses.values()]
269            if len(all_addresses) == len(set(all_addresses)):  # ensure all addresses are different
270                for p, addresses in scored_participants_addresses.items():
271                    addresses_by_participant[p] = list(addresses)[0]
272                return
273
274        for participant, addresses in sorted(scored_participants_addresses.items()):
275            try:
276                # sort filtered results to prevent randomness for equal scores
277                found_address = max(sorted(
278                    filter(lambda a: a not in taken_addresses and addresses[a] >= minimum_score, addresses),
279                    reverse=True
280                ), key=addresses.get)
281            except ValueError:
282                # Could not assign address for this participant
283                addresses_by_participant[participant] = None
284                continue
285
286            addresses_by_participant[participant] = found_address
287            taken_addresses.add(found_address)
288
289    def __find_broadcast_fields(self, high_scored_ranges_by_participant, addresses_by_participant: dict):
290        """
291        Last we check for messages that were sent to broadcast
292          1. we search for messages that have a SRC address but no DST address
293          2. we look at other messages that have this SRC field and find the corresponding DST position
294          3. we evaluate the value of message without DST from 1 and compare these values with each other.
295             if they match, we found the broadcast address
296        :param high_scored_ranges_by_participant:
297        :return:
298        """
299        if -1 in addresses_by_participant:
300            # broadcast address is already known
301            return
302
303        broadcast_bag = defaultdict(list)  # type: dict[CommonRange, list[int]]
304        for common_ranges in high_scored_ranges_by_participant.values():
305            src_address_fields = sorted(filter(lambda r: r.field_type == "source address", common_ranges))
306            dst_address_fields = sorted(filter(lambda r: r.field_type == "destination address", common_ranges))
307            msg_with_dst = {i for dst_address_field in dst_address_fields for i in dst_address_field.message_indices}
308
309            for src_address_field in src_address_fields:  # type: CommonRange
310                msg_without_dst = {i for i in src_address_field.message_indices if i not in msg_with_dst}
311                if len(msg_without_dst) == 0:
312                    continue
313                try:
314                    matching_dst = next(dst for dst in dst_address_fields
315                                        if all(i in dst.message_indices
316                                               for i in src_address_field.message_indices - msg_without_dst))
317                except StopIteration:
318                    continue
319                for msg in msg_without_dst:
320                    broadcast_bag[matching_dst].append(msg)
321
322        if len(broadcast_bag) == 0:
323            return
324
325        broadcast_address = None
326        for dst, messages in broadcast_bag.items():
327            for msg_index in messages:
328                value = self.msg_vectors[msg_index][dst.start:dst.end + 1]
329                if broadcast_address is None:
330                    broadcast_address = value
331                elif value.tobytes() != broadcast_address.tobytes():
332                    # Address is not common across messages so it can't be a broadcast address
333                    return
334
335        addresses_by_participant[-1] = broadcast_address.tobytes()
336        for dst, messages in broadcast_bag.items():
337            dst.values.append(broadcast_address)
338            dst.message_indices.update(messages)
339
340    def find_addresses(self) -> dict:
341        already_assigned = list(self.known_addresses_by_participant.keys())
342        if len(already_assigned) == len(self.message_indices_by_participant):
343            self._debug("Skipping find addresses as already known.")
344            return dict()
345
346        common_ranges_by_participant = dict()
347        for participant, message_indices in self.message_indices_by_participant.items():
348            # Cluster by length
349            length_clusters = defaultdict(list)
350            for i in message_indices:
351                length_clusters[len(self.msg_vectors[i])].append(i)
352
353            common_ranges_by_length = self.find_common_ranges_by_cluster(self.msg_vectors, length_clusters, range_type="hex")
354            common_ranges_by_participant[participant] = []
355            for ranges in common_ranges_by_length.values():
356                common_ranges_by_participant[participant].extend(self.ignore_already_labeled(ranges,
357                                                                                             self.already_labeled))
358
359        self._debug("Common ranges by participant:", common_ranges_by_participant)
360
361        result = defaultdict(set)
362        participants = sorted(common_ranges_by_participant)  # type: list[int]
363
364        if len(participants) < 2:
365            return result
366
367        # If we already know the address length we do not need to bother with other candidates
368        if len(already_assigned) > 0:
369            addr_len = len(self.known_addresses_by_participant[already_assigned[0]])
370            if any(len(self.known_addresses_by_participant[i]) != addr_len for i in already_assigned):
371                logger.warning("Addresses do not have a common length. Assuming length of {}".format(addr_len))
372        else:
373            addr_len = None
374
375        for p1, p2 in itertools.combinations(participants, 2):
376            p1_already_assigned = p1 in already_assigned
377            p2_already_assigned = p2 in already_assigned
378
379            if p1_already_assigned and p2_already_assigned:
380                continue
381
382            # common ranges are not merged yet, so there is only one element in values
383            values1 = [cr.value for cr in common_ranges_by_participant[p1]]
384            values2 = [cr.value for cr in common_ranges_by_participant[p2]]
385            for seq1, seq2 in itertools.product(values1, values2):
386                lcs = self.find_longest_common_sub_sequences(seq1, seq2)
387                vals = lcs if len(lcs) > 0 else [seq1, seq2]
388                # Address candidate must be at least 2 values long
389                for val in filter(lambda v: len(v) >= 2, vals):
390                    if addr_len is not None and len(val) != addr_len:
391                        continue
392                    if not p1_already_assigned and not p2_already_assigned:
393                        result[p1].add(val.tostring())
394                        result[p2].add(val.tostring())
395                    elif p1_already_assigned and val.tostring() != self.known_addresses_by_participant[p1].tostring():
396                        result[p2].add(val.tostring())
397                    elif p2_already_assigned and val.tostring() != self.known_addresses_by_participant[p2].tostring():
398                        result[p1].add(val.tostring())
399        return result
400