1# built-in
2from itertools import takewhile
3
4# app
5from .base import Base as _Base, BaseSimilarity as _BaseSimilarity
6
7
__all__ = [
    # algorithm classes
    'Prefix',
    'Postfix',
    'Length',
    'Identity',
    'Matrix',
    # default instances created at the bottom of this module
    'prefix',
    'postfix',
    'length',
    'identity',
    'matrix',
]
12
13
class Prefix(_BaseSimilarity):
    """Prefix similarity.

    Calling an instance returns the leading run of elements shared by all
    given sequences: columns of elements are taken in lockstep from the
    start for as long as ``sim_test`` accepts them. ``similarity`` is the
    length of that common prefix.
    """

    def __init__(self, qval=1, sim_test=None):
        # presumably consumed by the base class's ``_get_sequences``
        # (e.g. to split into chars / words / n-grams) -- defined outside
        # this module, confirm there.
        self.qval = qval
        # predicate called with one element from each sequence; falls back
        # to the base class's ``_ident`` check when not supplied
        self.sim_test = sim_test or self._ident

    def __call__(self, *sequences):
        """Return the prefix common to all *sequences*.

        Returns ``0`` when called with no sequences (kept for backward
        compatibility). Otherwise the result is a ``str`` or ``bytes`` when
        the first pre-processed sequence is of that type, and a list of
        elements in every other case.
        """
        if not sequences:
            return 0
        sequences = self._get_sequences(*sequences)
        result = [
            column[0]
            for column in takewhile(
                lambda column: self.sim_test(*column),
                zip(*sequences),
            )
        ]

        first = sequences[0]
        if isinstance(first, str):
            return ''.join(result)
        if isinstance(first, bytes):
            # Iterating bytes yields ints, which b''.join() rejects;
            # bytes() rebuilds a byte string from an iterable of ints.
            return bytes(result)
        return result

    def similarity(self, *sequences):
        """Return the length of the common prefix (0 for no sequences)."""
        if not sequences:
            # __call__ returns the int 0 here, which has no len()
            return 0
        return len(self(*sequences))
37
38
class Postfix(Prefix):
    """Postfix (common suffix) similarity.

    Reverses every sequence, lets ``Prefix`` find the common prefix of the
    reversed sequences, then reverses that result back into a suffix.
    """

    def __call__(self, *sequences):
        """Return the suffix common to all *sequences*.

        Returns ``0`` when called with no sequences, mirroring ``Prefix``.
        The result is a ``str`` or ``bytes`` when the first input sequence
        is of that type, otherwise a list of elements.
        """
        if not sequences:
            return 0
        first = sequences[0]
        reversed_sequences = [reversed(seq) for seq in sequences]
        suffix = list(reversed(super().__call__(*reversed_sequences)))
        if isinstance(first, str):
            return ''.join(suffix)
        if isinstance(first, bytes):
            # elements taken from bytes are ints; b''.join() would reject them
            return bytes(suffix)
        return suffix
51
52
class Length(_Base):
    """Length distance.

    The distance is the spread of the sequence lengths: the longest length
    minus the shortest. With no sequences the distance is 0.
    """

    def __call__(self, *sequences):
        if not sequences:
            # max()/min() raise ValueError on an empty iterable
            return 0
        lengths = [len(seq) for seq in sequences]
        return max(lengths) - min(lengths)
59
60
class Identity(_BaseSimilarity):
    """Identity similarity.

    The score is ``int(self._ident(*sequences))`` -- presumably 1 when all
    sequences are identical and 0 otherwise (assumes the base class's
    ``_ident`` returns a bool; confirm in the base module).
    """

    def __call__(self, *sequences):
        are_identical = self._ident(*sequences)
        return int(are_identical)

    def maximum(self, *sequences):
        # the upper bound of the score does not depend on the inputs
        return 1
70
71
class Matrix(_BaseSimilarity):
    """Matrix similarity.

    Scores a group of elements (for example two characters) by looking the
    tuple of them up in the user-supplied mapping ``mat``. If no entry is
    found, or no mapping was given, the score falls back to ``match_cost``
    when the base class's ``_ident`` check passes and ``mismatch_cost``
    otherwise.
    """

    def __init__(self, mat=None, mismatch_cost=0, match_cost=1, symmetric=True, external=True):
        """
        :param mat: mapping from tuples of elements to scores; ``None`` or
            an empty mapping disables the lookup entirely.
        :param mismatch_cost: score for non-identical elements not in ``mat``.
        :param match_cost: score for identical elements not in ``mat``.
        :param symmetric: when the key is missing, also try the reversed tuple.
        :param external: stored as ``self.external`` so the argument is not
            silently dropped; ``__call__`` itself does not consult it.
        """
        self.mat = mat
        self.mismatch_cost = mismatch_cost
        self.match_cost = match_cost
        self.symmetric = symmetric
        self.external = external

    def maximum(self, *sequences):
        # NOTE: values stored in ``mat`` are returned as-is by __call__ and
        # are not bounded by this value.
        return self.match_cost

    def __call__(self, *sequences):
        if not self.mat:
            return self.match_cost if self._ident(*sequences) else self.mismatch_cost

        # exact lookup; `in` is checked first so mappings with a default
        # factory are not consulted for missing keys
        if sequences in self.mat:
            return self.mat[sequences]
        # symmetric matrix: (a, b) also matches an entry stored as (b, a)
        if self.symmetric:
            flipped = tuple(reversed(sequences))
            if flipped in self.mat:
                return self.mat[flipped]
        # not in the matrix: fall back to identity-based scoring
        if self._ident(*sequences):
            return self.match_cost
        return self.mismatch_cost
105
106
# Ready-to-use instances, each built with its class's default arguments.
prefix, postfix, length, identity, matrix = (
    Prefix(),
    Postfix(),
    Length(),
    Identity(),
    Matrix(),
)
112