from collections import defaultdict
from numbers import Number

from . import sam


def are_disjoint(first, second):
    """Return True if *first* and *second* share no position.

    Empty operands are disjoint from everything.  Only the overall spans
    (``start``/``end``) are compared, so for a gapped DisjointIntervals this
    is a span test rather than an exact overlap test.
    """
    if first.is_empty or second.is_empty:
        return True
    return first.start > second.end or second.start > first.end


def are_adjacent(first, second):
    """Return True if one interval starts exactly one past the other's end.

    Endpoints are inclusive, so [1, 3] and [4, 6] are adjacent.  Emptiness is
    not checked here; callers should rule out empty intervals first.
    """
    return first.start == second.end + 1 or second.start == first.end + 1


class Interval(object):
    """Closed interval [start, end] with both endpoints inclusive.

    The interval is empty whenever ``end < start``; otherwise its length is
    ``end - start + 1``.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.is_empty = (end < start)

    @classmethod
    def from_feature(cls, feature):
        """Build an interval from any object with ``start``/``end`` attributes.

        Mainly useful because it computes ``is_empty`` for the feature's span.
        """
        return cls(feature.start, feature.end)

    @classmethod
    def empty(cls):
        """Return the canonical empty interval, [-1, -2]."""
        return cls(-1, -2)

    @classmethod
    def from_slice(cls, sl):
        """Convert a slice into an interval.

        A missing start means 0 and a missing stop means unbounded (+inf).
        The slice step is ignored.
        """
        start = 0 if sl.start is None else sl.start
        # Slice stops are exclusive while interval ends are inclusive.
        end = float('inf') if sl.stop is None else sl.stop - 1
        return cls(start, end)

    def __or__(self, other):
        """Return the union with *other* as a DisjointIntervals.

        Overlapping or adjacent intervals are merged into a single interval.
        A DisjointIntervals operand is handled by its own ``__or__``.
        """
        if isinstance(other, DisjointIntervals):
            # A span-level disjointness test would wrongly fill the gaps of a
            # multi-interval set, so defer to the set-aware implementation.
            return other | self

        if self.is_empty or other.is_empty:
            # The constructor drops empty members, leaving only the non-empty
            # operand (if any).  Doing this first keeps an empty interval's
            # arbitrary endpoints out of the adjacency/merge logic below.
            return DisjointIntervals([self, other])

        if are_disjoint(self, other) and not are_adjacent(self, other):
            intervals = [self, other]
        else:
            # Overlapping or touching: merge into one covering interval.
            intervals = [Interval(min(self.start, other.start),
                                  max(self.end, other.end))]

        return DisjointIntervals(intervals)

    def __and__(self, other):
        """Return the intersection with *other*.

        Returns an Interval (empty when there is no overlap), or a
        DisjointIntervals when *other* is a DisjointIntervals.
        """
        if isinstance(other, DisjointIntervals):
            # Defer to definition in DisjointIntervals
            return other & self
        if are_disjoint(self, other):
            # An empty Interval is falsy (length 0) but, unlike a bare list,
            # still supports further interval operations.
            return Interval.empty()
        return Interval(max(self.start, other.start), min(self.end, other.end))

    def __contains__(self, other):
        """Membership test.

        For an Interval, True only if *other* is a strict sub-interval of
        this one (contained and not equal).  For a number, True if it lies in
        [start, end].  Any other type raises ValueError.
        """
        if isinstance(other, Interval):
            # is a strict sub-interval of
            return (other.start >= self.start and other.end <= self.end) and (self != other)
        elif isinstance(other, Number):
            return self.start <= other <= self.end
        else:
            raise ValueError(other)

    @property
    def comparison_key(self):
        """Sort key: ``(start, end)``."""
        return self.start, self.end

    @property
    def total_length(self):
        """Number of positions covered; same as ``len(self)``."""
        return len(self)

    def __lt__(self, other):
        return self.comparison_key < other.comparison_key

    def __repr__(self):
        return '[{0:,} - {1:,}]'.format(self.start, self.end)

    def __key(self):
        return (self.start, self.end)

    def __eq__(self, other):
        if not isinstance(other, Interval):
            # Let Python fall back instead of raising AttributeError.
            return NotImplemented
        return self.__key() == other.__key()

    def __hash__(self):
        return hash(self.__key())

    def __ne__(self, other):
        return not self == other

    def __len__(self):
        if self.is_empty:
            return 0
        return self.end - self.start + 1

    def __sub__(self, other):
        """Return the positions of this interval not covered by *other*.

        Returns an Interval when the result is empty or contiguous, or a
        DisjointIntervals when *other* splits this interval or is itself a
        DisjointIntervals.
        """
        if isinstance(other, DisjointIntervals):
            # self - (A | B | ...) == (self - A) & (self - B) & ...
            survived_all = DisjointIntervals([self])
            for other_interval in other:
                survived_all = survived_all & (self - other_interval)
            return survived_all

        if self.is_empty:
            return Interval.empty()
        if other.is_empty:
            # Nothing to remove.  An empty interval's endpoints are arbitrary
            # and must not be used to split this one.
            return Interval(self.start, self.end)

        left = Interval(self.start, min(self.end, other.start - 1))
        right = Interval(max(self.start, other.end + 1), self.end)
        disjoint = DisjointIntervals([left, right])

        if disjoint.total_length == 0:
            return Interval.empty()
        elif len(disjoint.intervals) == 1:
            return disjoint.intervals[0]
        else:
            return disjoint


class DisjointIntervals(object):
    """Sorted collection of non-empty Interval objects.

    Empty intervals passed to the constructor are discarded.  The
    constructor does not merge overlapping or adjacent members; ``|``
    merges them when building its result.
    """

    def __init__(self, intervals):
        self.intervals = sorted([i for i in intervals if i.end >= i.start])
        self.is_empty = (len(self.intervals) == 0)

    def __len__(self):
        """Number of member intervals (not number of positions)."""
        return len(self.intervals)

    @property
    def start(self):
        """Smallest member start, or None when there are no members."""
        if len(self.intervals) == 0:
            return None
        return min(interval.start for interval in self.intervals)

    @property
    def end(self):
        """Largest member end, or None when there are no members."""
        if len(self.intervals) == 0:
            return None
        return max(interval.end for interval in self.intervals)

    def __repr__(self):
        return '{{{}}}'.format(', '.join(map(str, self.intervals)))

    def __getitem__(self, sl):
        return self.intervals[sl]

    def __or__(self, other):
        """Return the union with an Interval or another DisjointIntervals."""
        if isinstance(other, DisjointIntervals):
            everything = self
            for other_interval in other:
                everything = everything | other_interval
            return everything

        # Grow *merged* by absorbing every member it overlaps or touches;
        # members that stay separate are kept as they are.
        disjoints = []
        merged = other
        for interval in self.intervals:
            union = interval | merged
            if len(union) > 1:
                disjoints.append(interval)
            else:
                merged = union[0]
        disjoints.append(merged)

        return DisjointIntervals(disjoints)

    def __and__(self, other):
        """Return the intersection with an Interval or another DisjointIntervals."""
        if isinstance(other, DisjointIntervals):
            survived_some = DisjointIntervals([])
            for other_interval in other:
                survived_some = survived_some | (self & other_interval)
            return survived_some

        intersections = [interval & other for interval in self.intervals]
        # Empty intersections are falsy; drop them.
        return DisjointIntervals([i for i in intersections if i])

    def __eq__(self, other):
        if not isinstance(other, DisjointIntervals):
            return NotImplemented
        return self.intervals == other.intervals

    def __iter__(self):
        return iter(self.intervals)

    def __hash__(self):
        # Lists are unhashable; hash an immutable snapshot of the members.
        return hash(tuple(self.intervals))

    def __ne__(self, other):
        return not self == other

    def __contains__(self, other):
        """True if *other* (an Interval or a number) is fully covered.

        Coverage is tested by checking that adding *other* leaves the set
        unchanged.
        """
        if isinstance(other, Number):
            other = Interval(other, other)

        return (self | other) == self

    @property
    def total_length(self):
        """Sum of member lengths."""
        return sum(len(i) for i in self.intervals)

    def __sub__(self, other):
        """Return the positions in this set not covered by *other*.

        *other* may be an Interval or another DisjointIntervals.
        """
        if isinstance(other, DisjointIntervals):
            # Remove each member of *other* in turn.
            remaining = DisjointIntervals(self.intervals)
            for other_interval in other:
                remaining = remaining - other_interval
            return remaining

        pieces = []
        for interval in self.intervals:
            piece = interval - other
            if isinstance(piece, Interval):
                pieces.append(piece)
            elif isinstance(piece, DisjointIntervals):
                pieces.extend(piece)

        return DisjointIntervals(pieces)


def get_covered(alignment):
    """Return the query interval covered by *alignment*.

    None and unmapped alignments yield an empty interval.  Otherwise the
    endpoints come from ``sam.query_interval`` (assumed to return a
    ``(start, end)`` pair -- confirm against the sam module).
    """
    if alignment is None or alignment.is_unmapped:
        return Interval.empty()
    return Interval(*sam.query_interval(alignment))


def make_disjoint(intervals):
    """Return the union of *intervals* as a DisjointIntervals."""
    disjoint = DisjointIntervals([])
    for interval in intervals:
        disjoint = disjoint | interval
    return disjoint


def get_disjoint_covered(alignments):
    """Return the combined query coverage of *alignments*, skipping None."""
    intervals = [get_covered(al) for al in alignments if al is not None]
    return make_disjoint(intervals)


def remove_nested(alignments):
    """Drop alignments whose covered interval is strictly inside another's.

    Alignments with identical covered intervals are all kept, because the
    sub-interval test is strict.  Input order is preserved.
    """
    covered_list = [get_covered(al) for al in alignments]
    unnecessary = set()
    for i, left in enumerate(covered_list):
        if any(left in right for j, right in enumerate(covered_list) if j != i):
            unnecessary.add(i)
    return [al for i, al in enumerate(alignments) if i not in unnecessary]


def make_parsimonious(alignments):
    """Return a subset of *alignments* with the same combined query coverage.

    Alignments whose covered interval is strictly nested inside another's
    are dropped first.  The remaining distinct intervals are then visited
    longest first, and each is discarded if the intervals still kept cover
    exactly the same positions without it.  Every alignment sharing a kept
    interval is returned.

    Raises ValueError if the result's coverage differs from the input's.
    """
    initial_covered = get_disjoint_covered(alignments)

    interval_to_als = defaultdict(list)
    for al in remove_nested(alignments):
        interval_to_als[get_covered(al)].append(al)

    unique_intervals = sorted(interval_to_als, key=len, reverse=True)

    remaining = unique_intervals
    contributes = []
    for possibly_exclude in unique_intervals:
        exclude_one = [intvl for intvl in remaining if intvl != possibly_exclude]
        if make_disjoint(exclude_one) != initial_covered:
            # Dropping this interval would lose coverage, so it must stay.
            contributes.append(possibly_exclude)
        else:
            remaining = exclude_one

    parsimonious = []
    for interval in contributes:
        parsimonious.extend(interval_to_als[interval])

    if get_disjoint_covered(parsimonious) != initial_covered:
        raise ValueError('parsimonious alignments do not reproduce the original coverage')

    return parsimonious