1"""
2Routines for printing a report.
3"""
4import sys
5from io import StringIO
6import textwrap
7from collections import Counter
8from typing import Any, Optional, List, Dict
9from .adapters import (
10    EndStatistics, AdapterStatistics, FrontAdapter, NonInternalFrontAdapter, PrefixAdapter,
11    BackAdapter, NonInternalBackAdapter, SuffixAdapter, AnywhereAdapter, LinkedAdapter,
12)
13from .modifiers import (QualityTrimmer, NextseqQualityTrimmer,
14    AdapterCutter, PairedAdapterCutter, ReverseComplementer, PairedEndModifierWrapper)
15from .filters import (WithStatistics, TooShortReadFilter, TooLongReadFilter, NContentFilter,
16    CasavaFilter, MaximumExpectedErrorsFilter)
17
18
def safe_divide(numerator: Optional[int], denominator: int) -> float:
    """
    Divide numerator by denominator without ever raising.

    Return 0.0 when numerator is None or when denominator is falsy (for
    example zero); otherwise return the true-division result.
    """
    if numerator is not None and denominator:
        return numerator / denominator
    return 0.0
24
25
def add_if_not_none(a: Optional[int], b: Optional[int]) -> Optional[int]:
    """
    Add two optional counts.

    If both values are given, return their sum. If only one is given, return
    that one. If both are None, return None.
    """
    if a is not None and b is not None:
        return a + b
    return b if a is None else a
32
33
class Statistics:
    """
    Container for the counts that the report functions in this module format.

    Values that exist once per read of a pair are stored as two-element lists
    in which index 0 belongs to the first read and index 1 to the second read.
    The filter counters (too_short, too_long, too_many_n,
    too_many_expected_errors, casava_filtered) and reverse_complemented stay
    None until collect() has seen a matching filter/modifier or until a value
    has been merged in with +=.
    """

    def __init__(self) -> None:
        """Create an empty object: all counters are zero, all optional values None."""
        self.paired: Optional[bool] = None
        self.did_quality_trimming: Optional[bool] = None
        self.too_short: Optional[int] = None
        self.too_long: Optional[int] = None
        self.too_many_n: Optional[int] = None
        self.too_many_expected_errors: Optional[int] = None
        self.casava_filtered: Optional[int] = None
        self.reverse_complemented: Optional[int] = None
        self.n = 0
        self.written = 0
        self.total_bp = [0, 0]
        self.written_bp = [0, 0]
        self.written_lengths: List[Counter] = [Counter(), Counter()]
        self.with_adapters = [0, 0]
        self.quality_trimmed_bp = [0, 0]
        self.adapter_stats: List[List[AdapterStatistics]] = [[], []]

    def __iadd__(self, other: Any):
        """
        Add the counts of another Statistics object to this one (self += other).

        All compatibility checks run before anything is modified, so self is
        left untouched when ValueError is raised.

        Raises ValueError if paired or did_quality_trimming is already set on
        self and differs from the value on other, or if both objects hold
        adapter statistics for the same read but not the same number of them.
        """
        # Validate first: a failed merge must not leave self half-updated.
        if self.paired is not None and self.paired != other.paired:
            raise ValueError('Incompatible Statistics: paired is not equal')
        if (
            self.did_quality_trimming is not None
            and self.did_quality_trimming != other.did_quality_trimming
        ):
            raise ValueError('Incompatible Statistics: did_quality_trimming is not equal')
        for i in (0, 1):
            mine, theirs = self.adapter_stats[i], other.adapter_stats[i]
            if mine and theirs and len(mine) != len(theirs):
                raise ValueError('Incompatible Statistics objects (adapter_stats length)')

        self.n += other.n
        self.written += other.written

        # An unset (None) flag adopts the value of the other object
        if self.paired is None:
            self.paired = other.paired
        if self.did_quality_trimming is None:
            self.did_quality_trimming = other.did_quality_trimming

        self.reverse_complemented = add_if_not_none(
            self.reverse_complemented, other.reverse_complemented)
        self.too_short = add_if_not_none(self.too_short, other.too_short)
        self.too_long = add_if_not_none(self.too_long, other.too_long)
        self.too_many_n = add_if_not_none(self.too_many_n, other.too_many_n)
        self.too_many_expected_errors = add_if_not_none(
            self.too_many_expected_errors, other.too_many_expected_errors)
        self.casava_filtered = add_if_not_none(self.casava_filtered, other.casava_filtered)
        for i in (0, 1):
            self.total_bp[i] += other.total_bp[i]
            self.written_bp[i] += other.written_bp[i]
            self.written_lengths[i] += other.written_lengths[i]
            self.with_adapters[i] += other.with_adapters[i]
            self.quality_trimmed_bp[i] += other.quality_trimmed_bp[i]
            if self.adapter_stats[i] and other.adapter_stats[i]:
                # Same length was verified above; merge position by position
                for j, theirs in enumerate(other.adapter_stats[i]):
                    self.adapter_stats[i][j] += theirs
            elif other.adapter_stats[i]:
                assert self.adapter_stats[i] == []
                # The list object of other is adopted as-is (not copied)
                self.adapter_stats[i] = other.adapter_stats[i]
        return self

    def collect(self, n: int, total_bp1: int, total_bp2: Optional[int], modifiers, writers):
        """
        Fill in this object from the given counts, modifiers and writers.

        n -- total number of reads
        total_bp1 -- number of bases in first reads
        total_bp2 -- number of bases in second reads. None for single-end data.
        modifiers -- iterable of modifier objects to take statistics from
        writers -- iterable of writer/filter objects to take statistics from

        Sets paired to False if total_bp2 is None and to True otherwise.
        Returns self so that calls can be chained.
        """
        self.n = n
        self.total_bp[0] = total_bp1
        if total_bp2 is None:
            self.paired = False
        else:
            self.paired = True
            self.total_bp[1] = total_bp2

        for writer in writers:
            self._collect_writer(writer)
        for modifier in modifiers:
            self._collect_modifier(modifier)

        # For chaining
        return self

    def _collect_writer(self, w) -> None:
        """Take written-read counts and the filtered-read count from one writer."""
        if isinstance(w, WithStatistics):
            self.written += w.written_reads()
            written_bp = w.written_bp()
            written_lengths = w.written_lengths()
            for i in 0, 1:
                self.written_bp[i] += written_bp[i]
                self.written_lengths[i] += written_lengths[i]
        if hasattr(w, "filter"):
            # The type of the wrapped filter decides which counter receives
            # the number of filtered reads
            if isinstance(w.filter, TooShortReadFilter):
                self.too_short = w.filtered
            elif isinstance(w.filter, TooLongReadFilter):
                self.too_long = w.filtered
            elif isinstance(w.filter, NContentFilter):
                self.too_many_n = w.filtered
            elif isinstance(w.filter, MaximumExpectedErrorsFilter):
                self.too_many_expected_errors = w.filtered
            elif isinstance(w.filter, CasavaFilter):
                self.casava_filtered = w.filtered

    def _collect_modifier(self, m) -> None:
        """Take trimming and adapter statistics from one modifier."""
        if isinstance(m, PairedAdapterCutter):
            # A single with_adapters value of the paired cutter is added for both reads
            for i in 0, 1:
                self.with_adapters[i] += m.with_adapters
                self.adapter_stats[i] = list(m.adapter_statistics[i].values())
            return
        if isinstance(m, PairedEndModifierWrapper):
            # Unwrap into the modifiers applied to the first and second read
            modifiers_list = [(0, m._modifier1), (1, m._modifier2)]
        else:
            modifiers_list = [(0, m)]
        for i, modifier in modifiers_list:
            if isinstance(modifier, (QualityTrimmer, NextseqQualityTrimmer)):
                self.quality_trimmed_bp[i] = modifier.trimmed_bases
                self.did_quality_trimming = True
            elif isinstance(modifier, AdapterCutter):
                self.with_adapters[i] += modifier.with_adapters
                self.adapter_stats[i] = list(modifier.adapter_statistics.values())
            elif isinstance(modifier, ReverseComplementer):
                self.with_adapters[i] += modifier.adapter_cutter.with_adapters
                self.adapter_stats[i] = list(modifier.adapter_cutter.adapter_statistics.values())
                self.reverse_complemented = modifier.reverse_complemented

    @property
    def total(self) -> int:
        """Sum of total_bp over both reads."""
        return sum(self.total_bp)

    @property
    def quality_trimmed(self) -> int:
        """Sum of quality_trimmed_bp over both reads."""
        return sum(self.quality_trimmed_bp)

    @property
    def total_written_bp(self) -> int:
        """Sum of written_bp over both reads."""
        return sum(self.written_bp)

    # All *_fraction properties use safe_divide(), so they are 0.0 when the
    # count is None or the denominator is zero.

    @property
    def written_fraction(self) -> float:
        """written relative to n."""
        return safe_divide(self.written, self.n)

    @property
    def with_adapters_fraction(self) -> List[float]:
        """with_adapters relative to n, one value per read."""
        return [safe_divide(v, self.n) for v in self.with_adapters]

    @property
    def quality_trimmed_fraction(self) -> float:
        """quality_trimmed relative to total."""
        return safe_divide(self.quality_trimmed, self.total)

    @property
    def total_written_bp_fraction(self) -> float:
        """total_written_bp relative to total."""
        return safe_divide(self.total_written_bp, self.total)

    @property
    def reverse_complemented_fraction(self) -> float:
        """reverse_complemented relative to n."""
        return safe_divide(self.reverse_complemented, self.n)

    @property
    def too_short_fraction(self) -> float:
        """too_short relative to n."""
        return safe_divide(self.too_short, self.n)

    @property
    def too_long_fraction(self) -> float:
        """too_long relative to n."""
        return safe_divide(self.too_long, self.n)

    @property
    def too_many_n_fraction(self) -> float:
        """too_many_n relative to n."""
        return safe_divide(self.too_many_n, self.n)

    @property
    def too_many_expected_errors_fraction(self) -> float:
        """too_many_expected_errors relative to n."""
        return safe_divide(self.too_many_expected_errors, self.n)

    @property
    def casava_filtered_fraction(self) -> float:
        """casava_filtered relative to n."""
        return safe_divide(self.casava_filtered, self.n)
208
209
def error_ranges(adapter_statistics: EndStatistics) -> str:
    """
    Describe how many errors are allowed for which match lengths.

    The result starts with "No. of allowed errors:" and ends with a newline.
    If partial matches are allowed, a list of "FROM-TO bp: ERRORS" ranges
    follows on the next line; otherwise only the number of errors allowed at
    the effective length is given.
    """
    effective_length = adapter_statistics.effective_length
    max_rate = adapter_statistics.max_error_rate
    max_errors = int(max_rate * effective_length)
    heading = "No. of allowed errors:"

    if not adapter_statistics.allows_partial_matches:
        return heading + f" {max_errors}" + "\n"

    pieces = []
    start = 1
    for allowed in range(max_errors):
        # First length at which one more error than 'allowed' is permitted
        boundary = int((allowed + 1) / max_rate)
        pieces.append("{}-{} bp: {}; ".format(start, boundary - 1, allowed))
        start = boundary
    if start == effective_length:
        pieces.append("{} bp: {}".format(effective_length, max_errors))
    else:
        pieces.append("{}-{} bp: {}".format(start, effective_length, max_errors))

    return heading + "\n" + "".join(pieces) + "\n"
228
229
def histogram(end_statistics: EndStatistics, n: int, gc_content: float) -> str:
    """
    Return a formatted, tab-separated histogram of removed sequence lengths.

    There is one row per length with the observed count, the count expected
    from the random match probabilities, the maximum number of errors for that
    length and the observed count for each number of errors.

    end_statistics -- EndStatistics object
    n -- total no. of reads.
    gc_content -- passed to end_statistics.random_match_probabilities()
    """
    length_counts = end_statistics.lengths
    error_counts = end_statistics.errors
    probabilities = end_statistics.random_match_probabilities(gc_content=gc_content)

    rows = ["\t".join(("length", "count", "expect", "max.err", "error counts"))]
    for length in sorted(length_counts):
        # Beyond the adapter length, neither the match probability nor the
        # number of allowed errors increases anymore
        capped_length = min(len(end_statistics.sequence), length)
        expected = n * probabilities[capped_length]
        observed = length_counts[length]
        per_error = error_counts[length]
        highest = max(per_error.keys())
        error_column = ' '.join(str(per_error[e]) for e in range(highest + 1))
        allowed = int(end_statistics.max_error_rate * capped_length)
        rows.append("\t".join((
            str(length),
            str(observed),
            "{:.1F}".format(expected),
            str(allowed),
            error_column,
        )))
    # Every row is newline-terminated and one empty line follows the table
    return "\n".join(rows) + "\n" + "\n"
262
263
class AdjacentBaseStatistics:
    """
    Fractions of the bases found next to removed adapters.

    A warning is flagged if one of A, C, G, T accounts for more than 80% of
    at least 20 observations.
    """

    def __init__(self, bases: Dict[str, int]):
        """
        bases -- dict that maps 'A', 'C', 'G', 'T' and '' (none/other) to counts
        """
        self.bases: Dict[str, int] = bases
        self._warnbase: Optional[str] = None
        self._fractions = None
        observations = sum(bases.values())
        if observations == 0:
            return

        self._fractions = []
        for base in ('A', 'C', 'G', 'T', ''):
            label = base or 'none/other'
            share = 1.0 * bases[base] / observations
            self._fractions.append((label, share))
            # The none/other category never triggers the warning
            if base and share > 0.8:
                self._warnbase = label
        # Too few observations to draw a conclusion
        if observations < 20:
            self._warnbase = None

    def __repr__(self):
        return 'AdjacentBaseStatistics(bases={})'.format(self.bases)

    @property
    def should_warn(self) -> bool:
        """True if a single base dominates (see the class description)."""
        return self._warnbase is not None

    def __str__(self) -> str:
        """Return the multi-line text for the report, or "" without observations."""
        if not self._fractions:
            return ""
        lines = ['Bases preceding removed adapters:']
        lines.extend('  {}: {:.1%}'.format(label, share) for label, share in self._fractions)
        if self.should_warn:
            lines.append('WARNING:')
            lines.append('    The adapter is preceded by "{}" extremely often.'.format(self._warnbase))
            lines.append("    The provided adapter sequence could be incomplete at its 5' end.")
        return "\n".join(lines) + "\n"
303
304
def full_report(stats: Statistics, time: float, gc_content: float) -> str:  # noqa: C901
    """
    Return the full, human-readable report as a string (nothing is printed
    to standard output).

    stats -- Statistics object with the collected counts
    time -- elapsed time in seconds; 0 is replaced by 1E-6
    gc_content -- passed on to histogram()

    If no reads were processed (stats.n == 0), only the text
    "No reads processed!" is returned. Trailing whitespace is stripped from
    the result.
    """
    if stats.n == 0:
        return "No reads processed!"
    if time == 0:
        # Avoids a division by zero in the reads/minute computation below
        time = 1E-6
    sio = StringIO()

    # print() variant that always writes to the in-memory buffer
    def print_s(*args, **kwargs):
        kwargs['file'] = sio
        print(*args, **kwargs)

    # On Python 3.6 and older, ASCII "u" is used in place of the micro sign.
    # NOTE(review): presumably to avoid output encoding problems - confirm
    if sys.version_info[:2] <= (3, 6):
        micro = "u"
    else:
        micro = "µ"
    print_s("Finished in {:.2F} s ({:.0F} {}s/read; {:.2F} M reads/minute).".format(
        time, 1E6 * time / stats.n, micro, stats.n / time * 60 / 1E6))

    # The summary is first assembled as a str.format() template. It is filled
    # in only once, further below, with o=stats and pairs_or_reads.
    report = "\n=== Summary ===\n\n"
    if stats.paired:
        report += textwrap.dedent("""\
        Total read pairs processed:      {o.n:13,d}
          Read 1 with adapter:           {o.with_adapters[0]:13,d} ({o.with_adapters_fraction[0]:.1%})
          Read 2 with adapter:           {o.with_adapters[1]:13,d} ({o.with_adapters_fraction[1]:.1%})
        """)
    else:
        report += textwrap.dedent("""\
        Total reads processed:           {o.n:13,d}
        Reads with adapters:             {o.with_adapters[0]:13,d} ({o.with_adapters_fraction[0]:.1%})
        """)
    if stats.reverse_complemented is not None:
        report += "Reverse-complemented:            " \
                  "{o.reverse_complemented:13,d} ({o.reverse_complemented_fraction:.1%})\n"

    # A line for a filter is included only if its counter is not None
    if stats.too_short is not None:
        report += "{pairs_or_reads} that were too short:       {o.too_short:13,d} ({o.too_short_fraction:.1%})\n"
    if stats.too_long is not None:
        report += "{pairs_or_reads} that were too long:        {o.too_long:13,d} ({o.too_long_fraction:.1%})\n"
    if stats.too_many_n is not None:
        report += "{pairs_or_reads} with too many N:           {o.too_many_n:13,d} ({o.too_many_n_fraction:.1%})\n"
    if stats.too_many_expected_errors is not None:
        report += "{pairs_or_reads} with too many exp. errors: " \
                  "{o.too_many_expected_errors:13,d} ({o.too_many_expected_errors_fraction:.1%})\n"
    if stats.casava_filtered is not None:
        report += "{pairs_or_reads} failed CASAVA filter:      " \
                  "{o.casava_filtered:13,d} ({o.casava_filtered_fraction:.1%})\n"

    report += textwrap.dedent("""\
    {pairs_or_reads} written (passing filters): {o.written:13,d} ({o.written_fraction:.1%})

    Total basepairs processed: {o.total:13,d} bp
    """)
    if stats.paired:
        report += "  Read 1: {o.total_bp[0]:13,d} bp\n"
        report += "  Read 2: {o.total_bp[1]:13,d} bp\n"

    if stats.did_quality_trimming:
        report += "Quality-trimmed:           {o.quality_trimmed:13,d} bp ({o.quality_trimmed_fraction:.1%})\n"
        if stats.paired:
            report += "  Read 1: {o.quality_trimmed_bp[0]:13,d} bp\n"
            report += "  Read 2: {o.quality_trimmed_bp[1]:13,d} bp\n"

    report += "Total written (filtered):  {o.total_written_bp:13,d} bp ({o.total_written_bp_fraction:.1%})\n"
    if stats.paired:
        report += "  Read 1: {o.written_bp[0]:13,d} bp\n"
        report += "  Read 2: {o.written_bp[1]:13,d} bp\n"
    pairs_or_reads = "Pairs" if stats.paired else "Reads"
    report = report.format(o=stats, pairs_or_reads=pairs_or_reads)
    print_s(report)

    # Set to True as soon as the adjacent-base statistics of any adapter
    # handled in the last branch below suggest that it is incomplete
    warning = False
    # One section per adapter; which_in_pair is 0 for the first and 1 for the second read
    for which_in_pair in (0, 1):
        for adapter_statistics in stats.adapter_stats[which_in_pair]:
            total_front = sum(adapter_statistics.front.lengths.values())
            total_back = sum(adapter_statistics.back.lengths.values())
            total = total_front + total_back
            reverse_complemented = adapter_statistics.reverse_complemented
            adapter = adapter_statistics.adapter
            # Back-type adapters must not have front matches and vice versa
            if isinstance(adapter, (BackAdapter, NonInternalBackAdapter, SuffixAdapter)):
                assert total_front == 0
            if isinstance(adapter, (FrontAdapter, NonInternalFrontAdapter, PrefixAdapter)):
                assert total_back == 0

            if stats.paired:
                extra = 'First read: ' if which_in_pair == 0 else 'Second read: '
            else:
                extra = ''

            print_s("=" * 3, extra + "Adapter", adapter_statistics.name, "=" * 3)
            print_s()

            # NOTE(review): the linked-adapter line is printed without end="",
            # so it is already newline-terminated when the reverse-complement
            # text (or the empty line) is printed right after it; the
            # non-linked line is continued on the same line instead.
            if isinstance(adapter, LinkedAdapter):
                print_s("Sequence: {}...{}; Type: linked; Length: {}+{}; "
                    "5' trimmed: {} times; 3' trimmed: {} times".format(
                        adapter_statistics.front.sequence,
                        adapter_statistics.back.sequence,
                        len(adapter_statistics.front.sequence),
                        len(adapter_statistics.back.sequence),
                        total_front, total_back))
            else:
                print_s("Sequence: {}; Type: {}; Length: {}; Trimmed: {} times".
                    format(adapter_statistics.front.sequence, adapter.description,
                        len(adapter_statistics.front.sequence), total), end="")
            if stats.reverse_complemented is not None:
                print_s("; Reverse-complemented: {} times".format(reverse_complemented))
            else:
                print_s()
            # Nothing was trimmed: no error ranges or histograms for this adapter
            if total == 0:
                print_s()
                continue
            if isinstance(adapter, AnywhereAdapter):
                print_s(total_front, "times, it overlapped the 5' end of a read")
                print_s(total_back, "times, it overlapped the 3' end or was within the read")
                print_s()
                print_s(error_ranges(adapter_statistics.front))
                print_s("Overview of removed sequences (5')")
                print_s(histogram(adapter_statistics.front, stats.n, gc_content))
                print_s()
                print_s("Overview of removed sequences (3' or within)")
                print_s(histogram(adapter_statistics.back, stats.n, gc_content))
            elif isinstance(adapter, LinkedAdapter):
                print_s()
                print_s(error_ranges(adapter_statistics.front))
                print_s(error_ranges(adapter_statistics.back))
                print_s("Overview of removed sequences at 5' end")
                print_s(histogram(adapter_statistics.front, stats.n, gc_content))
                print_s()
                print_s("Overview of removed sequences at 3' end")
                print_s(histogram(adapter_statistics.back, stats.n, gc_content))
            elif isinstance(adapter, (FrontAdapter, NonInternalFrontAdapter, PrefixAdapter)):
                print_s()
                print_s(error_ranges(adapter_statistics.front))
                print_s("Overview of removed sequences")
                print_s(histogram(adapter_statistics.front, stats.n, gc_content))
            else:
                assert isinstance(adapter, (BackAdapter, NonInternalBackAdapter, SuffixAdapter))
                print_s()
                print_s(error_ranges(adapter_statistics.back))
                # Only this branch reports the bases adjacent to the adapter
                base_stats = AdjacentBaseStatistics(adapter_statistics.back.adjacent_bases)
                warning = warning or base_stats.should_warn
                print_s(base_stats)
                print_s("Overview of removed sequences")
                print_s(histogram(adapter_statistics.back, stats.n, gc_content))

    # Repeat the warning at the very end of the report
    if warning:
        print_s('WARNING:')
        print_s('    One or more of your adapter sequences may be incomplete.')
        print_s('    Please see the detailed output above.')

    return sio.getvalue().rstrip()
456
457
def minimal_report(stats: Statistics, time: float, gc_content: float) -> str:
    """
    Create a minimal tabular report suitable for concatenation.

    The result consists of a tab-separated header line and one tab-separated
    line of values, without a trailing newline. The status column is "WARN"
    if the adjacent-base statistics of any back-type adapter suggest an
    incomplete adapter sequence, and "OK" otherwise. Filter counts that are
    None are reported as 0.

    time and gc_content are accepted but not used.
    """
    def zero_if_none(value):
        return value if value is not None else 0

    def incomplete_adapter_suspected(adapter_statistics) -> bool:
        if not isinstance(
            adapter_statistics.adapter, (BackAdapter, NonInternalBackAdapter, SuffixAdapter)
        ):
            return False
        return AdjacentBaseStatistics(adapter_statistics.back.adjacent_bases).should_warn

    # (column name, value) pairs keep header and values next to each other
    columns = [
        ('status', "OK"),
        ('in_reads', stats.n),  # reads/pairs in
        ('in_bp', stats.total),  # bases in
        ('too_short', zero_if_none(stats.too_short)),  # reads/pairs
        ('too_long', zero_if_none(stats.too_long)),  # reads/pairs
        ('too_many_n', zero_if_none(stats.too_many_n)),  # reads/pairs
        ('out_reads', stats.written),  # reads/pairs out
        ('w/adapters', stats.with_adapters[0]),  # reads
        ('qualtrim_bp', stats.quality_trimmed_bp[0]),  # bases
        ('out_bp', stats.written_bp[0]),  # bases out
    ]
    if stats.paired:
        columns.extend([
            ('w/adapters2', stats.with_adapters[1]),  # reads/pairs
            ('qualtrim2_bp', stats.quality_trimmed_bp[1]),  # bases
            ('out2_bp', stats.written_bp[1]),  # bases
        ])

    suspicious = False
    for which_in_pair in range(2):
        if any(
            incomplete_adapter_suspected(adapter_statistics)
            for adapter_statistics in stats.adapter_stats[which_in_pair]
        ):
            suspicious = True
    if suspicious:
        columns[0] = ('status', "WARN")

    header_line = "\t".join(name for name, _ in columns)
    value_line = "\t".join(str(value) for _, value in columns)
    return header_line + "\n" + value_line
500