1from collections import defaultdict
2from numbers import Number
3
4from . import sam
5
def are_disjoint(first, second):
    ''' Return True if first and second share no position.

    Both arguments must expose is_empty, start and end. An empty operand
    is disjoint from everything; otherwise the endpoints are compared as
    inclusive bounds.
    '''
    if first.is_empty or second.is_empty:
        return True

    # Two closed ranges overlap exactly when each one starts no later than
    # the other one ends.
    overlaps = first.start <= second.end and second.start <= first.end
    return not overlaps
11
def are_adjacent(first, second):
    ''' Return True if one argument starts exactly one past the other's end.

    Only start and end are read; emptiness is not checked here.
    '''
    if first.start == second.end + 1:
        # first begins immediately after second
        return True
    # otherwise adjacent only if second begins immediately after first
    return second.start == first.end + 1
14
class Interval(object):
    ''' A closed interval [start, end]; both endpoints are included.

    An interval with end < start is empty (is_empty is True and its length
    is 0). Interval.empty() returns the canonical empty interval (-1, -2).
    Equality, hashing and ordering are all based on (start, end).
    '''

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.is_empty = (end < start)

    @classmethod
    def from_feature(cls, feature):
        ''' only really necessary to set is_empty '''
        return cls(feature.start, feature.end)

    @classmethod
    def empty(cls):
        ''' Return the canonical empty interval. '''
        return cls(-1, -2)

    @classmethod
    def from_slice(cls, sl):
        ''' Build the interval of indices selected by slice sl.

        A missing start means 0 and a missing stop means an unbounded
        (infinite) end. sl.step is not consulted.
        '''
        start = 0 if sl.start is None else sl.start

        if sl.stop is None:
            # float('inf') needs no third-party import and compares
            # correctly against integer coordinates.
            end = float('inf')
        else:
            end = sl.stop - 1 # Note the -1: slice stops are exclusive

        return cls(start, end)

    def __or__(self, other):
        ''' Return the union of self and other as a DisjointIntervals. '''
        if isinstance(other, DisjointIntervals):
            # Defer to definition in DisjointIntervals (union is symmetric).
            return other | self

        if are_disjoint(self, other):
            if self.is_empty or other.is_empty or not are_adjacent(self, other):
                # Never merge with an empty interval: its endpoints carry no
                # meaning and could spuriously look adjacent.
                # DisjointIntervals sorts the pieces and drops empty ones.
                intervals = [self, other]
            else:
                left, right = sorted([self, other])
                intervals = [Interval(left.start, right.end)]
        else:
            intervals = [Interval(min(self.start, other.start), max(self.end, other.end))]

        return DisjointIntervals(intervals)

    def __and__(self, other):
        ''' Return the intersection of self and other.

        Returns an Interval when the operands overlap, a DisjointIntervals
        when other is one, and an empty list when they are disjoint (kept
        for backward compatibility; it is falsy like an empty interval).
        '''
        if isinstance(other, DisjointIntervals):
            # Defer to definition in DisjointIntervals
            return other & self
        elif are_disjoint(self, other):
            return []
        else:
            return Interval(max(self.start, other.start), min(self.end, other.end))

    def __contains__(self, other):
        ''' Test membership of a number or strict nesting of an Interval.

        Raises ValueError for any other type.
        '''
        if isinstance(other, Interval):
            # is a strict sub-interval of
            return (other.start >= self.start and other.end <= self.end) and (self != other)
        elif isinstance(other, Number):
            return self.start <= other <= self.end
        else:
            raise ValueError(other)

    @property
    def comparison_key(self):
        return self.start, self.end

    @property
    def total_length(self):
        ''' Number of positions covered (0 if empty).

        Computed directly rather than through len() so that it also works
        for an unbounded interval, whose length is float('inf').
        '''
        if self.is_empty:
            return 0
        else:
            return self.end - self.start + 1

    def __lt__(self, other):
        return self.comparison_key < other.comparison_key

    def __repr__(self):
        return '[{0:,} - {1:,}]'.format(self.start, self.end)

    def __key(self):
        return (self.start, self.end)

    def __eq__(self, other):
        if not isinstance(other, Interval):
            # Let Python fall back to its default handling instead of
            # raising AttributeError on unrelated types (None, [], ...).
            return NotImplemented
        return self.__key() == other.__key()

    def __hash__(self):
        return hash(self.__key())

    def __ne__(self, other):
        return not self == other

    def __len__(self):
        # len() itself rejects non-integer results, so this raises TypeError
        # for an unbounded interval; use total_length in that case.
        return self.total_length

    def __sub__(self, other):
        ''' Return the part of self not covered by other.

        The result is an Interval when zero or one piece survives and a
        DisjointIntervals when other splits self into two pieces (or when
        other is itself a DisjointIntervals).
        '''
        if isinstance(other, DisjointIntervals):
            # A position survives only if it survives every subtraction.
            survived_all = DisjointIntervals([self])
            for other_interval in other:
                survived_this = self - other_interval
                survived_all = survived_all & survived_this
            return survived_all

        if self.is_empty:
            return Interval.empty()

        if other.is_empty:
            # Removing nothing leaves self unchanged. Handled explicitly
            # because the endpoint arithmetic below is only meaningful for a
            # non-empty other.
            return Interval(self.start, self.end)

        left = Interval(self.start, min(self.end, other.start - 1))
        right = Interval(max(self.start, other.end + 1), self.end)
        disjoint = DisjointIntervals([left, right])

        if disjoint.is_empty:
            return Interval.empty()
        elif len(disjoint.intervals) == 1:
            return disjoint.intervals[0]
        else:
            return disjoint
125
class DisjointIntervals(object):
    ''' A sorted collection of non-empty intervals.

    Empty intervals (end < start) are dropped on construction and the rest
    are kept sorted. The constructor does not merge overlapping or adjacent
    intervals itself; the | operator does. len() is the number of stored
    intervals, while total_length is the number of positions they cover.
    '''

    def __init__(self, intervals):
        self.intervals = sorted(i for i in intervals if i.end >= i.start)
        self.is_empty = (len(self.intervals) == 0)

    def __len__(self):
        return len(self.intervals)

    @property
    def start(self):
        ''' Smallest start among the stored intervals, or None if empty. '''
        if len(self.intervals) == 0:
            return None
        else:
            return min(interval.start for interval in self.intervals)

    @property
    def end(self):
        ''' Largest end among the stored intervals, or None if empty. '''
        if len(self.intervals) == 0:
            return None
        else:
            return max(interval.end for interval in self.intervals)

    def __repr__(self):
        return '{{{}}}'.format(', '.join(map(str, self.intervals)))

    def __getitem__(self, sl):
        return self.intervals[sl]

    def __or__(self, other):
        ''' Return the union of self with an interval or a DisjointIntervals. '''
        if isinstance(other, DisjointIntervals):
            everything = self
            for other_interval in other:
                everything = everything | other_interval
            return everything
        else:
            disjoints = []

            for interval in self.intervals:
                union = interval | other
                if len(union) > 1:
                    # interval neither touches nor overlaps other: keep as is
                    disjoints.append(interval)
                else:
                    # interval was absorbed; keep growing the merged interval
                    other = union[0]

            disjoints.append(other)

            # The constructor sorts and drops empties, so no extra sort here.
            return DisjointIntervals(disjoints)

    def __and__(self, other):
        ''' Return the intersection of self with an interval or a DisjointIntervals. '''
        if isinstance(other, DisjointIntervals):
            survived_some = DisjointIntervals([])
            for other_interval in other:
                survived_this = self & other_interval
                survived_some = survived_some | survived_this
            return survived_some
        else:
            intersections = []
            for interval in self.intervals:
                intersection = interval & other
                # A disjoint pair yields a falsy result, which is skipped.
                if intersection:
                    intersections.append(intersection)

            return DisjointIntervals(intersections)

    def __eq__(self, other):
        if not isinstance(other, DisjointIntervals):
            # Let Python fall back to its default handling instead of
            # raising AttributeError on unrelated types.
            return NotImplemented
        return self.intervals == other.intervals

    def __iter__(self):
        return iter(self.intervals)

    def __hash__(self):
        # Lists are unhashable; hash an immutable snapshot instead.
        return hash(tuple(self.intervals))

    def __ne__(self, other):
        return not self == other

    def __contains__(self, other):
        ''' True if adding other (a number or interval) would change nothing. '''
        if isinstance(other, Number):
            other = Interval(other, other)

        return (self | other) == self

    @property
    def total_length(self):
        ''' Total number of positions covered by the stored intervals. '''
        # total_length rather than len(): len() cannot return the float
        # length of an unbounded interval.
        return sum(i.total_length for i in self.intervals)

    def __sub__(self, other):
        ''' Return the part of self not covered by other. '''
        if isinstance(other, DisjointIntervals):
            # Removing a union is the same as removing each member in turn.
            remaining = self
            for other_interval in other:
                remaining = remaining - other_interval
            return remaining
        else:
            pieces = []
            for interval in self.intervals:
                # The difference may be a single interval or a pair of them.
                piece = interval - other
                if isinstance(piece, Interval):
                    pieces.append(piece)
                elif isinstance(piece, DisjointIntervals):
                    pieces.extend(piece)

            return DisjointIntervals(pieces)
225
def get_covered(alignment):
    ''' Return the Interval reported by sam.query_interval for alignment.

    A missing (None) or unmapped alignment yields the empty interval (-1, -2).
    '''
    has_coverage = alignment is not None and not alignment.is_unmapped
    if not has_coverage:
        return Interval(-1, -2)

    bounds = sam.query_interval(alignment)
    return Interval(*bounds)
231
def make_disjoint(intervals):
    ''' Fold intervals into a single DisjointIntervals using the | operator. '''
    merged = DisjointIntervals([])
    for piece in intervals:
        # DisjointIntervals defines no in-place operator, so this rebinds
        # merged to the new union returned by __or__.
        merged |= piece
    return merged
237
def get_disjoint_covered(alignments):
    ''' Return the DisjointIntervals covered by all non-None alignments. '''
    covered_pieces = []
    for al in alignments:
        if al is None:
            continue
        covered_pieces.append(get_covered(al))
    return make_disjoint(covered_pieces)
242
def remove_nested(alignments):
    ''' Drop alignments whose covered interval is strictly inside another's.

    Alignments covering identical intervals are all kept, because Interval
    containment is strict. The relative order of the survivors is preserved.
    '''
    spans = [get_covered(al) for al in alignments]

    kept = []
    for idx, al in enumerate(alignments):
        inner = spans[idx]
        is_nested = any(
            inner in outer
            for pos, outer in enumerate(spans)
            if pos != idx
        )
        if not is_nested:
            kept.append(al)
    return kept
254
def make_parsimonious(alignments):
    ''' Return a subset of alignments that covers the same region as all of them.

    Nested alignments are removed first. The remaining distinct covered
    intervals are then visited from longest to shortest, and each one is
    discarded if the intervals still retained cover the full region without
    it. Every alignment sharing a retained interval is returned.

    Raises ValueError if the result does not cover the original region.
    '''
    target = get_disjoint_covered(alignments)

    # Group the non-nested alignments by the interval they cover.
    by_interval = defaultdict(list)
    for al in remove_nested(alignments):
        by_interval[get_covered(al)].append(al)

    candidates = sorted(by_interval, key=len, reverse=True)

    still_kept = candidates
    essential = []
    for candidate in candidates:
        without_candidate = [iv for iv in still_kept if iv != candidate]
        if target == make_disjoint(without_candidate):
            # Coverage is unchanged without it, so drop it for good.
            still_kept = without_candidate
        else:
            essential.append(candidate)

    parsimonious = [al for iv in essential for al in by_interval[iv]]

    # Sanity check: the reduced set must cover exactly what the input did.
    if get_disjoint_covered(parsimonious) != target:
        raise ValueError

    return parsimonious
283