1"""
2Classes for writing and filtering of processed reads.
3
4A Filter is a callable that has the read as its only argument. If it is called,
5it returns True if the read should be filtered (discarded), and False if not.
6
7To be used, a filter needs to be wrapped in one of the redirector classes.
8They are called so because they can redirect filtered reads to a file if so
9desired. They also keep statistics.
10
11To determine what happens to a read, a list of redirectors with different
12filters is created and each redirector is called in turn until one returns True.
13The read is then assumed to have been "consumed", that is, either written
14somewhere or filtered (should be discarded).
15"""
16from collections import defaultdict, Counter
17from abc import ABC, abstractmethod
18from typing import Tuple, Optional, Dict, Any, DefaultDict
19
20from .qualtrim import expected_errors
21from .modifiers import ModificationInfo
22
23
# Named return values for a filter's __call__ method. They exist purely for
# readability: it is unintuitive that "return True" means "discard the read".
DISCARD = True  # the read is filtered (consumed by the filter)
KEEP = False  # the read is not filtered
28
29
class SingleEndFilter(ABC):
    """
    Interface for filters that look at one read at a time.

    Concrete filters implement ``__call__``. A true return value (DISCARD)
    means that the read is filtered, a false one (KEEP) that it is not.
    """

    @abstractmethod
    def __call__(self, read, info: ModificationInfo) -> bool:
        """
        Decide what happens to a single read.

        read -- the read to inspect
        info -- the ModificationInfo object belonging to the read
        """
36
37
class PairedEndFilter(ABC):
    """
    Interface for filters that look at a read pair at a time.

    Concrete filters implement ``__call__``. A true return value (DISCARD)
    means that the pair is filtered, a false one (KEEP) that it is not.
    """

    @abstractmethod
    def __call__(self, read1, read2, info1: ModificationInfo, info2: ModificationInfo) -> bool:
        """
        Decide what happens to the read pair (read1, read2).

        info1 and info2 are the ModificationInfo objects belonging to
        read1 and read2, respectively.
        """
44
45
class WithStatistics(ABC):
    """
    Mixin that records how many reads of which length were written.

    Two histograms (read length -> number of reads) are kept: one for the
    first read and one for the second read of a pair.
    """

    def __init__(self) -> None:
        # Plain defaultdicts are used for counting because they are much
        # faster than a Counter; conversion happens in written_lengths().
        self._written_lengths1: DefaultDict[int, int] = defaultdict(int)
        self._written_lengths2: DefaultDict[int, int] = defaultdict(int)

    def written_reads(self) -> int:
        """Return number of written reads or read pairs"""
        n_reads = 0
        for n in self._written_lengths1.values():
            n_reads += n
        return n_reads

    def written_bp(self) -> Tuple[int, int]:
        """Return the total number of written basepairs, separately for both histograms"""
        bp1 = self._compute_total_bp(self._written_lengths1)
        bp2 = self._compute_total_bp(self._written_lengths2)
        return bp1, bp2

    def written_lengths(self) -> Tuple[Counter, Counter]:
        """Return copies of both length histograms as Counter objects"""
        histogram1 = Counter(self._written_lengths1)
        histogram2 = Counter(self._written_lengths2)
        return histogram1, histogram2

    @staticmethod
    def _compute_total_bp(counts: DefaultDict[int, int]) -> int:
        """Sum up length times count over all entries of a length histogram"""
        total = 0
        for length, count in counts.items():
            total += length * count
        return total
68
69
class SingleEndFilterWithStatistics(SingleEndFilter, WithStatistics, ABC):
    """
    Single-end filter that also records the lengths of the reads it writes.
    """
    def __init__(self):
        super().__init__()

    def update_statistics(self, read) -> None:
        """Count one more written read of length len(read)"""
        read_length = len(read)
        self._written_lengths1[read_length] += 1
76
77
class PairedEndFilterWithStatistics(PairedEndFilter, WithStatistics, ABC):
    """
    Paired-end filter that also records the lengths of the reads it writes.
    """
    def __init__(self):
        super().__init__()

    def update_statistics(self, read1, read2):
        """Count one more written pair: read1 in the first, read2 in the second histogram"""
        histogram1, histogram2 = self._written_lengths1, self._written_lengths2
        histogram1[len(read1)] += 1
        histogram2[len(read2)] += 1
85
86
class NoFilter(SingleEndFilterWithStatistics):
    """
    A "filter" that lets nothing pass: every read is written to the given
    writer and thereby consumed.
    """
    def __init__(self, writer):
        super().__init__()
        self.writer = writer

    def __repr__(self):
        template = "NoFilter({})"
        return template.format(self.writer)

    def __call__(self, read, info: ModificationInfo):
        """Write the read, count it and report it as consumed"""
        out = self.writer
        out.write(read)
        self.update_statistics(read)
        # Returning DISCARD tells the caller that the read has been dealt with
        return DISCARD
102
103
class PairedNoFilter(PairedEndFilterWithStatistics):
    """
    A "filter" that lets nothing pass: every read pair is written to the
    given writer and thereby consumed.
    """
    def __init__(self, writer):
        super().__init__()
        self.writer = writer

    def __repr__(self):
        template = "PairedNoFilter({})"
        return template.format(self.writer)

    def __call__(self, read1, read2, info1: ModificationInfo, info2: ModificationInfo):
        """Write the pair, count it and report it as consumed"""
        out = self.writer
        out.write(read1, read2)
        self.update_statistics(read1, read2)
        # Returning DISCARD tells the caller that the pair has been dealt with
        return DISCARD
119
120
class Redirector(SingleEndFilter):
    """
    Wrap a single-end filter: count the reads it filters and, if a writer
    was given, send those reads to it.
    """
    def __init__(self, writer, filter: SingleEndFilter, filter2=None):
        # TODO filter2 should really not be here
        # (it is accepted, but not used)
        super().__init__()
        self.writer = writer
        self.filter = filter
        self.filtered = 0  # number of reads for which the filter returned true

    def __repr__(self):
        template = "Redirector(writer={}, filter={})"
        return template.format(self.writer, self.filter)

    def __call__(self, read, info: ModificationInfo):
        """
        Apply the wrapped filter to the read. Return DISCARD if the filter
        matched (the read is then counted and possibly written), else KEEP.
        """
        if not self.filter(read, info):
            return KEEP
        self.filtered += 1
        if self.writer is not None:
            self.writer.write(read)
        return DISCARD
142
143
class PairedRedirector(PairedEndFilter):
    """
    Redirect paired-end reads matching a filtering criterion to a writer.
    Different filtering styles are supported, differing by which of the
    two reads in a pair have to fulfill the filtering criterion.
    """
    def __init__(self, writer, filter, filter2, pair_filter_mode='any'):
        """
        writer -- receives the filtered pairs through write(read1, read2);
            may be None, in which case filtered pairs are only counted
        filter -- filter applied to the first read of a pair, or None
        filter2 -- filter applied to the second read of a pair, or None

        pair_filter_mode -- these values are allowed:
            'any': The pair is discarded if any read matches.
            'both': The pair is discarded if both reads match.
            'first': The pair is discarded if the first read matches.

        If only one of filter and filter2 is given, that filter alone decides
        and pair_filter_mode has no influence on the result.

        A ValueError is raised if pair_filter_mode is not one of the allowed
        values or if both filter and filter2 are None.
        """
        super().__init__()
        if pair_filter_mode not in ('any', 'both', 'first'):
            raise ValueError("pair_filter_mode must be 'any', 'both' or 'first'")
        if filter is None and filter2 is None:
            # Fail early: otherwise the problem would only surface as
            # "'NoneType' object is not callable" on the first processed pair.
            raise ValueError("At least one of filter and filter2 must not be None")
        self._pair_filter_mode = pair_filter_mode
        self.filtered = 0  # number of pairs that matched the criterion
        self.writer = writer
        self.filter = filter
        self.filter2 = filter2
        # Pick the decision function once, so that __call__ does not need to
        # re-evaluate the configuration for every pair.
        if filter2 is None:
            self._is_filtered = self._is_filtered_first
        elif filter is None:
            self._is_filtered = self._is_filtered_second
        elif pair_filter_mode == 'any':
            self._is_filtered = self._is_filtered_any
        elif pair_filter_mode == 'both':
            self._is_filtered = self._is_filtered_both
        else:
            self._is_filtered = self._is_filtered_first

    def __repr__(self):
        return "PairedRedirector(writer={}, filter={}, filter2={}, pair_filter_mode='{}')".format(
            self.writer, self.filter, self.filter2, self._pair_filter_mode)

    def _is_filtered_any(self, read1, read2, info1: ModificationInfo, info2: ModificationInfo):
        """True if at least one read matches (filter2 is skipped if filter matches)"""
        return self.filter(read1, info1) or self.filter2(read2, info2)

    def _is_filtered_both(self, read1, read2, info1: ModificationInfo, info2: ModificationInfo):
        """True if both reads match (filter2 is skipped if filter does not match)"""
        return self.filter(read1, info1) and self.filter2(read2, info2)

    def _is_filtered_first(self, read1, read2, info1: ModificationInfo, info2: ModificationInfo):
        """True if the first read matches; the second read is not looked at"""
        return self.filter(read1, info1)

    def _is_filtered_second(self, read1, read2, info1: ModificationInfo, info2: ModificationInfo):
        """True if the second read matches; the first read is not looked at"""
        return self.filter2(read2, info2)

    def __call__(self, read1, read2, info1: ModificationInfo, info2: ModificationInfo):
        """
        Return DISCARD if the pair matches the filtering criterion (it is
        then counted and, if a writer was given, written to it), else KEEP.
        """
        if self._is_filtered(read1, read2, info1, info2):
            self.filtered += 1
            if self.writer is not None:
                self.writer.write(read1, read2)
            return DISCARD
        return KEEP
199
200
class TooShortReadFilter(SingleEndFilter):
    """
    Discard reads that are shorter than a minimum length.
    """
    def __init__(self, minimum_length):
        self.minimum_length = minimum_length

    def __repr__(self):
        template = "TooShortReadFilter(minimum_length={})"
        return template.format(self.minimum_length)

    def __call__(self, read, info: ModificationInfo):
        """Return True when the read should be discarded"""
        read_length = len(read)
        # A read of exactly the minimum length is not filtered
        return read_length < self.minimum_length
210
211
class TooLongReadFilter(SingleEndFilter):
    """
    Discard reads that are longer than a maximum length.
    """
    def __init__(self, maximum_length):
        self.maximum_length = maximum_length

    def __repr__(self):
        template = "TooLongReadFilter(maximum_length={})"
        return template.format(self.maximum_length)

    def __call__(self, read, info: ModificationInfo):
        """Return True when the read should be discarded"""
        read_length = len(read)
        # A read of exactly the maximum length is not filtered
        return read_length > self.maximum_length
221
222
class MaximumExpectedErrorsFilter(SingleEndFilter):
    """
    Discard reads whose expected number of errors, according to the quality
    values, exceeds a threshold.

    The idea comes from usearch's -fastq_maxee parameter
    (http://drive5.com/usearch/).
    """
    def __init__(self, max_errors):
        self.max_errors = max_errors

    def __repr__(self):
        template = "MaximumExpectedErrorsFilter(max_errors={})"
        return template.format(self.max_errors)

    def __call__(self, read, info: ModificationInfo):
        """Return True when the read should be discarded"""
        n_expected = expected_errors(read.qualities)
        # Reads that hit the threshold exactly are not filtered
        return n_expected > self.max_errors
240
241
class NContentFilter(SingleEndFilter):
    """
    Discard a read if it has too many 'N' bases. The cutoff is either an
    absolute number of Ns or a proportion of the read length. In both cases
    the comparison is 'greater than', so with a cutoff of 1, reads with a
    single N in them are kept.
    """
    def __init__(self, count):
        """
        count -- a value below 1.0 is interpreted as a proportion of the
        read length; a value of 1 or more is interpreted as a number of Ns.
        Reads with more Ns than the cutoff are discarded.
        """
        assert count >= 0
        self.cutoff = count
        self.is_proportion = count < 1.0

    def __repr__(self):
        template = "NContentFilter(cutoff={}, is_proportion={})"
        return template.format(self.cutoff, self.is_proportion)

    def __call__(self, read, info: ModificationInfo):
        """Return True when the read should be discarded"""
        # Lower-case and upper-case Ns are both counted
        n_count = read.sequence.lower().count('n')
        if not self.is_proportion:
            return n_count > self.cutoff

        read_length = len(read)
        if read_length == 0:
            # No proportion can be computed for an empty read; it is kept
            return False
        return n_count / read_length > self.cutoff
270
271
class DiscardUntrimmedFilter(SingleEndFilter):
    """
    Discard reads that are untrimmed, that is, reads without any match.
    """
    def __repr__(self):
        return "DiscardUntrimmedFilter()"

    def __call__(self, read, info: ModificationInfo):
        """Return True when the read should be discarded"""
        if info.matches:
            return KEEP
        return DISCARD
281
282
class DiscardTrimmedFilter(SingleEndFilter):
    """
    Discard reads that are trimmed, that is, reads with at least one match.
    """
    def __repr__(self):
        return "DiscardTrimmedFilter()"

    def __call__(self, read, info: ModificationInfo):
        """Return True when the read should be discarded"""
        if info.matches:
            return DISCARD
        return KEEP
292
293
class CasavaFilter(SingleEndFilter):
    """
    Remove reads that fail the CASAVA filter. These have header lines that
    look like ``xxxx x:Y:x:x`` (with a ``Y``). Reads that pass the filter
    have an ``N`` instead of ``Y``.

    Reads with unrecognized headers are kept.
    """
    def __repr__(self):
        return "CasavaFilter()"

    def __call__(self, read, info: ModificationInfo):
        """Return True when the read should be discarded"""
        # Everything after the first space of the header; empty if there is none
        comment = read.name.partition(' ')[2]
        # The character at index 0 is skipped, ':Y:' must follow right after it
        flag = comment[1:4]
        return flag == ':Y:'
308
309
class Demultiplexer(SingleEndFilterWithStatistics):
    """
    Demultiplex trimmed reads. Reads are written to different output files
    depending on which adapter matches. Files are created when the first read
    is written to them.

    Untrimmed reads are sent to writers[None] if that key exists.
    """
    def __init__(self, writers: Dict[Optional[str], Any]):
        """
        writers is a dictionary that maps an adapter name to a writer. The
        writer stored under the key None, if any, receives untrimmed reads.
        """
        super().__init__()
        self._writers = writers
        self._untrimmed_writer = self._writers.get(None)

    def __call__(self, read, info):
        """
        Write the read to the proper output file according to the most recent
        match. The read is always reported as consumed, even if it was not
        written anywhere.
        """
        matches = info.matches
        if matches:
            # The last match decides which output the read goes to
            adapter_name = matches[-1].adapter.name
            self.update_statistics(read)
            self._writers[adapter_name].write(read)
            return DISCARD

        untrimmed_writer = self._untrimmed_writer
        if untrimmed_writer is not None:
            self.update_statistics(read)
            untrimmed_writer.write(read)
        return DISCARD
338
339
class PairedDemultiplexer(PairedEndFilterWithStatistics):
    """
    Demultiplex trimmed paired-end reads. Reads are written to different output files
    depending on which adapter (in read 1) matches.
    """
    def __init__(self, writers: Dict[Optional[str], Any]):
        """
        writers is a dictionary that maps an adapter name to a writer. The
        writer stored under the key None, if any, receives pairs in which
        read 1 has no match.
        """
        super().__init__()
        self._writers = writers
        self._untrimmed_writer = self._writers.get(None)

    def __call__(self, read1, read2, info1: ModificationInfo, info2: ModificationInfo):
        """
        Write the pair to the output that belongs to the most recent match on
        read 1. The pair is always reported as consumed.
        """
        assert read2 is not None
        if info1.matches:
            # Only read 1 is consulted; matches on read 2 play no role here
            adapter_name = info1.matches[-1].adapter.name  # type: ignore
            self.update_statistics(read1, read2)
            self._writers[adapter_name].write(read1, read2)
            return DISCARD

        untrimmed_writer = self._untrimmed_writer
        if untrimmed_writer is not None:
            self.update_statistics(read1, read2)
            untrimmed_writer.write(read1, read2)
        return DISCARD
360
361
class CombinatorialDemultiplexer(PairedEndFilterWithStatistics):
    """
    Demultiplex paired-end reads depending on which adapter matches, taking into account
    matches on R1 and R2.
    """
    def __init__(self, writers: Dict[Tuple[Optional[str], Optional[str]], Any]):
        """
        writers maps tuples (adapter name on R1, adapter name on R2) to
        writers. For a read without a match, None takes the place of the
        adapter name. Pairs whose tuple is not a key of the dictionary are
        not written anywhere, which can be used to discard them.
        """
        super().__init__()
        self._writers = writers

    def __call__(self, read1, read2, info1, info2):
        """
        Write the read to the proper output file according to the most recent matches both on
        R1 and R2. The pair is always reported as consumed.
        """
        assert read2 is not None
        key = tuple(
            info.matches[-1].adapter.name if info.matches else None
            for info in (info1, info2)
        )
        if key not in self._writers:
            # No output is configured for this combination
            return DISCARD
        self.update_statistics(read1, read2)
        self._writers[key].write(read1, read2)
        return DISCARD
390
391
class RestFileWriter(SingleEndFilter):
    """
    Print the "rest" of the most recent match of each read, followed by the
    read name, to a file. This never filters a read.
    """
    def __init__(self, file):
        self.file = file

    def __call__(self, read, info):
        """Print one line if the read has a match with a non-empty rest; always return KEEP"""
        # TODO this fails with linked adapters
        if not info.matches:
            return KEEP
        rest = info.matches[-1].rest()
        if len(rest) > 0:
            print(rest, read.name, file=self.file)
        return KEEP
403
404
class WildcardFileWriter(SingleEndFilter):
    """
    Print the wildcards of the most recent match of each read, followed by
    the read name, to a file. This never filters a read.
    """
    def __init__(self, file):
        self.file = file

    def __call__(self, read, info):
        """Print one line if the read has a match; always return KEEP"""
        # TODO this fails with linked adapters
        if not info.matches:
            return KEEP
        last_match = info.matches[-1]
        print(last_match.wildcards(), read.name, file=self.file)
        return KEEP
414
415
class InfoFileWriter(SingleEndFilter):
    """
    Print tab-separated information about each read and its matches to a
    file. This never filters a read.
    """
    def __init__(self, file):
        self.file = file

    def __call__(self, read, info: ModificationInfo):
        """
        Print the info records of all matches of the read, or a single line
        with -1 in the second column if the read has no match. Always return
        KEEP.
        """
        out = self.file
        current_read = info.original_read
        if not info.matches:
            sequence = read.sequence
            qualities = read.qualities
            if qualities is None:
                qualities = ''
            print(read.name, -1, sequence, qualities, sep='\t', file=out)
            return KEEP

        for match in info.matches:
            for record in match.get_info_records(current_read):
                # record[0] is the read name suffix, the rest are the columns
                print(read.name + record[0], *record[1:], sep='\t', file=out)
            # The next match's records are requested for the read as trimmed by this match
            current_read = match.trimmed(current_read)
        return KEEP
434