1# built-in
2from functools import reduce
3from itertools import islice, permutations, repeat
4from math import log
5
6# app
7from .base import Base as _Base, BaseSimilarity as _BaseSimilarity
8from .edit_based import DamerauLevenshtein
9
10
# Public API re-exported by ``from ... import *``: the classes first, then the
# ready-to-use default instances (including every alias defined at the bottom
# of this module).
__all__ = [
    'Jaccard', 'Sorensen', 'Tversky',
    'Overlap', 'Cosine', 'Tanimoto', 'MongeElkan', 'Bag',

    'jaccard', 'sorensen', 'tversky', 'sorensen_dice', 'dice',
    'overlap', 'cosine', 'tanimoto', 'monge_elkan', 'bag',
]
18
19
class Jaccard(_BaseSimilarity):
    """
    Compute the Jaccard similarity between the two sequences.
    They should contain hashable items.
    The return value is a float between 0 and 1, where 1 means equal,
    and 0 totally different:

        |X & Y| / |X | Y|

    The sequences are first turned into token collections by the base-class
    helpers, driven by ``qval`` (and ``as_set``, which the base-class counting
    helper presumably uses to count each distinct token once — confirm there).
    If no sequence yields any token (e.g. presumably inputs shorter than
    ``qval``), the union is empty. In that case 1 is returned, following the
    usual convention J(empty, empty) = 1, instead of raising ZeroDivisionError.

    https://en.wikipedia.org/wiki/Jaccard_index
    https://github.com/Yomguithereal/talisman/blob/master/src/metrics/jaccard.js
    """
    def __init__(self, qval=1, as_set=False, external=True):
        self.qval = qval
        self.as_set = as_set
        self.external = external

    def maximum(self, *sequences):
        return 1

    def __call__(self, *sequences):
        # Trivial inputs are answered by the base class (non-None short-cut).
        result = self.quick_answer(*sequences)
        if result is not None:
            return result

        sequences = self._get_counters(*sequences)               # token collections
        intersection = self._intersect_counters(*sequences)      # shared tokens
        intersection = self._count_counters(intersection)        # int
        union = self._union_counters(*sequences)                 # all tokens
        union = self._count_counters(union)                      # int
        if not union:
            # An empty union means no sequence produced any token: the token
            # sets are identical (both empty) and the ratio would be 0 / 0.
            return 1
        return intersection / union
49
50
class Sorensen(_BaseSimilarity):
    """
    Compute the Sorensen-Dice similarity coefficient between the two sequences.
    They should contain hashable items.
    The return value is a float between 0 and 1, where 1 means equal,
    and 0 totally different:

        2 * |X & Y| / (|X| + |Y|)

    If no sequence yields any token (e.g. presumably inputs shorter than
    ``qval``), the token sets are identical (both empty) and 1 is returned
    instead of dividing 0 by 0.

    https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
    https://github.com/Yomguithereal/talisman/blob/master/src/metrics/dice.js
    """
    def __init__(self, qval=1, as_set=False, external=True):
        self.qval = qval
        self.as_set = as_set
        self.external = external

    def maximum(self, *sequences):
        return 1

    def __call__(self, *sequences):
        # Trivial inputs are answered by the base class (non-None short-cut).
        result = self.quick_answer(*sequences)
        if result is not None:
            return result

        sequences = self._get_counters(*sequences)               # token collections
        count = sum(self._count_counters(s) for s in sequences)  # total tokens
        if not count:
            # No tokens anywhere: identical (empty) token sets, avoid 0 / 0.
            return 1
        intersection = self._intersect_counters(*sequences)      # shared tokens
        intersection = self._count_counters(intersection)        # int
        return 2.0 * intersection / count
79
80
class Tversky(_BaseSimilarity):
    """Tversky index

    Asymmetric form, used when ``bias`` is None or the number of sequences is
    not exactly 2 (``k_i`` are the weights taken from ``ks``)::

        |I| / (|I| + k_1 * (|X_1| - |I|) + k_2 * (|X_2| - |I|) + ...)

    where ``I`` is the intersection of all token sets.

    Symmetric variant, used for exactly two sequences when ``bias`` is given.
    Here ``alpha, beta = ks``, ``a = min(|X - Y|, |Y - X|)`` and
    ``b = max(|X - Y|, |Y - X|)``::

        (|I| + bias) / (|I| + bias + beta * (alpha * a + (1 - alpha) * b))

    ``ks`` falls back to an endless stream of 1s when falsy. It is re-read on
    every call with ``islice``, so pass a list/tuple: a one-shot iterator
    would be consumed across calls.

    If no sequence yields any token (e.g. presumably inputs shorter than
    ``qval``), the token sets are identical (both empty) and 1 is returned.

    https://en.wikipedia.org/wiki/Tversky_index
    https://github.com/Yomguithereal/talisman/blob/master/src/metrics/tversky.js
    """
    def __init__(self, qval=1, ks=None, bias=None, as_set=False, external=True):
        self.qval = qval
        self.ks = ks or repeat(1)
        self.bias = bias
        self.as_set = as_set
        self.external = external

    def maximum(self, *sequences):
        return 1

    def __call__(self, *sequences):
        # Trivial inputs are answered by the base class (non-None short-cut).
        result = self.quick_answer(*sequences)
        if result is not None:
            return result

        sequences = self._get_counters(*sequences)                # token collections
        intersection = self._intersect_counters(*sequences)       # shared tokens
        intersection = self._count_counters(intersection)         # int
        sequences = [self._count_counters(s) for s in sequences]  # ints
        if not any(sequences):
            # No tokens anywhere: identical (empty) token sets; the formulas
            # below would otherwise divide 0 by 0.
            return 1
        ks = list(islice(self.ks, len(sequences)))

        # The symmetric variant needs exactly two sequences and a bias;
        # everything else uses the asymmetric index.
        if len(sequences) != 2 or self.bias is None:
            result = intersection
            for k, s in zip(ks, sequences):
                result += k * (s - intersection)
            return intersection / result

        alpha, beta = ks
        # Sizes of the two set differences |X - Y| and |Y - X|.
        diff1, diff2 = (s - intersection for s in sequences)
        a_val = min(diff1, diff2)
        b_val = max(diff1, diff2)
        c_val = intersection + self.bias
        result = beta * (alpha * a_val + (1 - alpha) * b_val)
        return c_val / (result + c_val)
121
122
class Overlap(_BaseSimilarity):
    """overlap coefficient

    Number of shared tokens divided by the token count of the smallest input::

        |X & Y| / min(|X|, |Y|)

    Inputs that yield no tokens (e.g. presumably inputs shorter than ``qval``)
    would make this 0 / 0:
    - If every input is token-less, the token sets are identical and 1 is
      returned.
    - Otherwise nothing can be shared and 0 is returned, which is also what
      Jaccard and Sorensen give for the same inputs.

    https://en.wikipedia.org/wiki/Overlap_coefficient
    https://github.com/Yomguithereal/talisman/blob/master/src/metrics/overlap.js
    """
    def __init__(self, qval=1, as_set=False, external=True):
        self.qval = qval
        self.as_set = as_set
        self.external = external

    def maximum(self, *sequences):
        return 1

    def __call__(self, *sequences):
        # Trivial inputs are answered by the base class (non-None short-cut).
        result = self.quick_answer(*sequences)
        if result is not None:
            return result

        sequences = self._get_counters(*sequences)                  # token collections
        intersection = self._intersect_counters(*sequences)         # shared tokens
        intersection = self._count_counters(intersection)           # int
        sequences = [self._count_counters(s) for s in sequences]    # ints

        smallest = min(sequences)
        if not smallest:
            # Some input has no tokens, so the intersection is empty as well.
            return 0 if any(sequences) else 1
        return intersection / smallest
148
149
class Cosine(_BaseSimilarity):
    """cosine similarity (Ochiai coefficient)

    Number of shared tokens divided by the geometric mean of the inputs' token
    counts. For two inputs this is ``|X & Y| / sqrt(|X| * |Y|)``.

    Inputs that yield no tokens (e.g. presumably inputs shorter than ``qval``)
    would make the denominator 0:
    - If every input is token-less, the token sets are identical and 1 is
      returned.
    - Otherwise nothing can be shared and 0 is returned.

    https://en.wikipedia.org/wiki/Cosine_similarity
    https://github.com/Yomguithereal/talisman/blob/master/src/metrics/cosine.js
    """
    def __init__(self, qval=1, as_set=False, external=True):
        self.qval = qval
        self.as_set = as_set
        self.external = external

    def maximum(self, *sequences):
        return 1

    def __call__(self, *sequences):
        # Trivial inputs are answered by the base class (non-None short-cut).
        result = self.quick_answer(*sequences)
        if result is not None:
            return result

        sequences = self._get_counters(*sequences)                  # token collections
        intersection = self._intersect_counters(*sequences)         # shared tokens
        intersection = self._count_counters(intersection)           # int
        sequences = [self._count_counters(s) for s in sequences]    # ints
        prod = reduce(lambda x, y: x * y, sequences)
        if not prod:
            # Some input has no tokens, so the formula below would be 0 / 0.
            return 0 if any(sequences) else 1

        # n-th root of the product = geometric mean of the n token counts.
        return intersection / pow(prod, 1.0 / len(sequences))
176
177
class Tanimoto(Jaccard):
    """Base-2 logarithm of the Jaccard similarity.

    The Jaccard similarity is computed for the same inputs (same ``qval`` /
    ``as_set`` options) and ``log2`` of it is returned. Jaccard values lie in
    [0, 1], so results are <= 0, with 0 for equal inputs. A Jaccard value of
    exactly 0 maps to ``float('-inf')`` rather than a math domain error.

    NOTE(review): ``maximum`` is inherited from Jaccard and still returns 1,
    although this ``__call__`` never exceeds 0. Anything derived from
    ``maximum`` in the base class may therefore be offset by 1 — confirm.
    """
    def __call__(self, *sequences):
        jaccard_value = super().__call__(*sequences)
        if jaccard_value != 0:
            return log(jaccard_value, 2)
        # log(0) is undefined; use the limit instead.
        return float('-inf')
189
190
class MongeElkan(_BaseSimilarity):
    """Monge-Elkan similarity.

    For every item of the first sequence, take the best ``algorithm``
    similarity it reaches against the items of each other sequence. The
    result is the mean of those best scores::

        sim(A, B) = 1/|A| * sum_i max_j algorithm.similarity(A_i, B_j)

    With ``symmetric=True`` the result is averaged over every ordering of the
    input sequences. An empty comparison sequence contributes 0 for each item.

    https://www.academia.edu/200314/Generalized_Monge-Elkan_Method_for_Approximate_Text_String_Comparison
    http://www.cs.cmu.edu/~wcohen/postscript/kdd-2003-match-ws.pdf
    https://github.com/Yomguithereal/talisman/blob/master/src/metrics/monge-elkan.js
    """
    # Shared default inner algorithm (created once, at class definition time).
    _damerau_levenshtein = DamerauLevenshtein()

    def __init__(self, algorithm=_damerau_levenshtein, symmetric=False, qval=1, external=True):
        self.algorithm = algorithm
        self.symmetric = symmetric
        self.qval = qval
        self.external = external

    def maximum(self, *sequences):
        # NOTE(review): ``sequences`` is passed here as ONE tuple argument (not
        # unpacked), so the seed is the inner algorithm's maximum for that
        # tuple. The loop then raises it to the inner maximum over each
        # non-empty sequence's items. Confirm the seed is intended.
        result = self.algorithm.maximum(sequences)
        for seq in sequences:
            if seq:
                result = max(result, self.algorithm.maximum(*seq))
        return result

    def _calc(self, seq, *sequences):
        """Mean best-match similarity of the items of ``seq`` against ``sequences``.

        Returns 0 for an empty ``seq``.
        """
        if not seq:
            return 0
        similarity = self.algorithm.similarity
        # One best score per (item of seq, other sequence) pair; an empty
        # other sequence scores 0 instead of leaking -inf into the mean.
        maxes = [
            max((similarity(c1, c2) for c2 in s), default=0)
            for c1 in seq
            for s in sequences
        ]
        # len(maxes) == len(seq) * len(sequences), so this already is the mean;
        # dividing by len(seq) again would shrink the score by len(seq).
        return sum(maxes) / len(maxes)

    def __call__(self, *sequences):
        # Trivial inputs are answered by the base class (non-None short-cut).
        result = self.quick_answer(*sequences)
        if result is not None:
            return result
        sequences = self._get_sequences(*sequences)

        if self.symmetric:
            # Average the directional score over every ordering of the inputs.
            result = []
            for seqs in permutations(sequences):
                result.append(self._calc(*seqs))
            return sum(result) / len(result)
        else:
            return self._calc(*sequences)
237
238
class Bag(_Base):
    """Bag distance.

    Each input is turned into a token collection, and the tokens common to all
    inputs are subtracted from each one. The distance is the largest number of
    tokens left over in any single input.

    https://github.com/Yomguithereal/talisman/blob/master/src/metrics/bag.js
    """
    def __call__(self, *sequences):
        counters = self._get_counters(*sequences)
        common = self._intersect_counters(*counters)
        # Token count remaining in each input once the shared part is removed.
        leftovers = [self._count_counters(counter - common) for counter in counters]
        return max(leftovers)
249
250
# Ready-to-call instances with default settings. Aliases are separate objects,
# so changing attributes (e.g. ``qval``) on one alias does not affect another.
bag = Bag()
cosine = Cosine()
dice = Sorensen()
jaccard = Jaccard()
monge_elkan = MongeElkan()
overlap = Overlap()
sorensen = Sorensen()
# Sorensen-Dice equals the Tversky index with ks=[.5, .5]:
#   |I| / (|I| + .5 * (|X| - |I|) + .5 * (|Y| - |I|)) == 2 * |I| / (|X| + |Y|)
# The dedicated Sorensen class computes it directly.
sorensen_dice = Sorensen()
tanimoto = Tanimoto()
tversky = Tversky()
262