1"""
2This module implements all the read modifications that cutadapt supports.
3A modifier must be callable and typically implemented as a class with a
4__call__ method.
5"""
6import re
7from types import SimpleNamespace
8from typing import Sequence, List, Tuple, Optional, Set
9from abc import ABC, abstractmethod
10from collections import OrderedDict
11
12from dnaio import record_names_match, Sequence as DnaSequence
13
14from .qualtrim import quality_trim_index, nextseq_trim_index
15from .adapters import MultipleAdapters, SingleAdapter, IndexedPrefixAdapters, IndexedSuffixAdapters, \
16    Match, remainder, Adapter
17from .tokenizer import tokenize_braces, TokenizeError, Token, BraceToken
18from .utils import reverse_complemented_sequence
19
20
# Index threshold for AdapterCutter (only consulted when it is created with index=True):
# when there are strictly more anchored prefix (5') or suffix (3') adapters than this,
# they are regrouped into IndexedPrefixAdapters/IndexedSuffixAdapters. At or below
# the threshold, the adapters are used in their original order.
INDEXING_THRESHOLD: int = 5
23
24
class ModificationInfo:
    """
    Per-read scratch space shared along the modification pipeline.

    One instance is created for each read passing through the pipeline. Anything
    other than the read itself that a modifier needs to hand to a later modifier,
    or to the filters, is stored here.

    Attributes:
      matches -- adapter matches collected so far
      original_read -- the read object that was passed to the constructor
      cut_prefix / cut_suffix -- bases removed by UnconditionalCutter, or None
      is_rc -- True/False once ReverseComplementer has decided on an orientation,
               None before that
    """
    __slots__ = ["matches", "original_read", "cut_prefix", "cut_suffix", "is_rc"]

    def __init__(self, read):
        self.original_read = read
        self.matches = []  # type: List[Match]
        self.cut_prefix = self.cut_suffix = None
        self.is_rc = None
39
40
class SingleEndModifier(ABC):
    """
    Interface for modifiers that process one read at a time.

    Implementations receive the read and its ModificationInfo and return the
    read that should continue through the pipeline.
    """

    @abstractmethod
    def __call__(self, read, info: ModificationInfo):
        ...
45
46
class PairedEndModifier(ABC):
    """
    Interface for modifiers that process both reads of a pair together.

    Implementations receive R1, R2 and the ModificationInfo of each, and return
    the pair of reads that should continue through the pipeline.
    """

    @abstractmethod
    def __call__(
        self, read1, read2, info1: ModificationInfo, info2: ModificationInfo
    ) -> Tuple[DnaSequence, DnaSequence]:
        ...
53
54
class PairedEndModifierWrapper(PairedEndModifier):
    """
    Wrapper for modifiers that work on both reads in a paired-end read

    modifier1 is applied to R1 and modifier2 to R2. A read whose modifier is
    None is passed through unchanged.
    """
    paired = True

    def __init__(self, modifier1: Optional[SingleEndModifier], modifier2: Optional[SingleEndModifier]):
        """Set one of the modifiers to None to work on R1 or R2 only

        Raises ValueError if both modifiers are None.
        """
        if modifier1 is None and modifier2 is None:
            raise ValueError("Not both modifiers may be None")
        self._modifier1 = modifier1
        self._modifier2 = modifier2

    def __repr__(self):
        template = "PairedEndModifierWrapper(modifier1={!r}, modifier2={!r})"
        return template.format(self._modifier1, self._modifier2)

    def __call__(self, read1, read2, info1: ModificationInfo, info2: ModificationInfo):
        first, second = self._modifier1, self._modifier2
        if first is not None and second is not None:
            return first(read1, info1), second(read2, info2)
        if first is None:
            # Only R2 is modified
            return read1, second(read2, info2)  # type: ignore
        # Only R1 is modified
        return first(read1, info1), read2
78
79
class AdapterCutter(SingleEndModifier):
    """
    Repeatedly find one of multiple adapters in reads.
    The number of times the search is repeated is specified by the
    times parameter.
    """

    def __init__(
        self,
        adapters: List[Adapter],
        times: int = 1,
        action: Optional[str] = "trim",
        index: bool = True,
    ):
        """
        adapters -- the adapters to search for; one statistics object is created per adapter
        times -- maximum number of adapter search/removal rounds per read

        action -- What to do with a found adapter:
          None: Do nothing, only update the ModificationInfo appropriately
          "trim": Remove the adapter and down- or upstream sequence depending on adapter type
          "mask": Replace the part of the sequence that would have been removed with "N" bases
          "lowercase": Convert the part of the sequence that would have been removed to lowercase
          "retain": Like "trim", but leave the adapter sequence itself in the read

        index -- if True, an adapter index (for multiple adapters) is created if possible

        Raises ValueError if action is not one of the values above, or if action
        is "retain" and times > 1.
        """
        # Validate explicitly: an assert would be skipped under "python -O".
        if action not in ("trim", "mask", "lowercase", "retain", None):
            raise ValueError("Invalid action: {!r}".format(action))
        if action == "retain" and times > 1:
            raise ValueError("'retain' cannot be combined with times > 1")
        self.times = times
        self.action = action
        self.with_adapters = 0  # number of reads in which at least one adapter was found
        # Keyed by adapter object; __call__ looks statistics up via match.adapter
        self.adapter_statistics = OrderedDict((a, a.create_statistics()) for a in adapters)
        if index:
            self.adapters = MultipleAdapters(self._regroup_into_indexed_adapters(adapters))
        else:
            self.adapters = MultipleAdapters(adapters)

    def __repr__(self):
        return 'AdapterCutter(adapters={!r}, times={}, action={!r})'.format(
            self.adapters, self.times, self.action)

    def _regroup_into_indexed_adapters(self, adapters):
        """
        Return the adapters to search, replacing groups of indexable prefix/suffix
        adapters by IndexedPrefixAdapters/IndexedSuffixAdapters when either group
        is larger than INDEXING_THRESHOLD. Otherwise, return *adapters* unchanged.
        """
        prefix, suffix, single = self._split_adapters(adapters)
        # For somewhat better backwards compatibility, avoid re-ordering
        # the adapters when we don’t need to
        if len(prefix) > INDEXING_THRESHOLD or len(suffix) > INDEXING_THRESHOLD:
            result = single
            if len(prefix) > 1:
                result.append(IndexedPrefixAdapters(prefix))
            else:
                result.extend(prefix)
            if len(suffix) > 1:
                result.append(IndexedSuffixAdapters(suffix))
            else:
                result.extend(suffix)
            return result
        else:
            return adapters

    @staticmethod
    def _split_adapters(
        adapters: Sequence[SingleAdapter]
    ) -> Tuple[Sequence[SingleAdapter], Sequence[SingleAdapter], Sequence[SingleAdapter]]:
        """
        Split adapters into three different categories so that they can possibly be used
        with IndexedPrefixAdapters or IndexedSuffixAdapters. Return a tuple
        (prefix, suffix, other), where
        - prefix is a list of all adapters that IndexedPrefixAdapters would accept
        - suffix is a list of all adapters that IndexedSuffixAdapters would accept
        - other is a list of all remaining adapters.
        Each returned list is newly created.
        """
        prefix = []  # type: List[SingleAdapter]
        suffix = []  # type: List[SingleAdapter]
        other = []  # type: List[SingleAdapter]
        for a in adapters:
            if IndexedPrefixAdapters.is_acceptable(a):
                prefix.append(a)
            elif IndexedSuffixAdapters.is_acceptable(a):
                suffix.append(a)
            else:
                other.append(a)
        return prefix, suffix, other

    @staticmethod
    def trim_but_retain_adapter(read, matches: Sequence[Match]):
        """Return the slice of *read* given by the last match's retained_adapter_interval()"""
        start, stop = matches[-1].retained_adapter_interval()
        return read[start:stop]

    @staticmethod
    def masked_read(read, matches: Sequence[Match]):
        """Return a copy of *read* in which everything outside remainder(matches) is 'N'"""
        start, stop = remainder(matches)
        result = read[:]
        result.sequence = (
            'N' * start
            + read.sequence[start:stop]
            + 'N' * (len(read) - stop))
        return result

    @staticmethod
    def lowercased_read(read, matches: Sequence[Match]):
        """Return a copy of *read*: remainder(matches) uppercased, the rest lowercased"""
        start, stop = remainder(matches)
        result = read[:]
        result.sequence = (
            read.sequence[:start].lower()
            + read.sequence[start:stop].upper()
            + read.sequence[stop:].lower()
        )
        return result

    def __call__(self, read, info: ModificationInfo):
        trimmed_read, matches = self.match_and_trim(read)
        if matches:
            self.with_adapters += 1
            for match in matches:
                match.update_statistics(self.adapter_statistics[match.adapter])
        info.matches.extend(matches)  # TODO extend or overwrite?
        return trimmed_read

    def match_and_trim(self, read):
        """
        Search for the best-matching adapter in a read, perform the requested action
        ('trim', 'mask' etc. as determined by self.action) and return the
        (possibly) modified read.

        *self.times* adapter removal rounds are done. During each round,
        only the best-matching adapter is trimmed. If no adapter was found in a round,
        no further rounds are attempted.

        The *read* object passed in is not modified.

        Return a pair (trimmed_read, matches), where matches is a list of Match instances.
        """
        matches = []
        if self.action == 'lowercase':
            # Work on an uppercased copy so that only the parts lowercased below end
            # up in lowercase. Copying avoids mutating the caller's read object
            # (which may be referenced elsewhere, e.g. as info.original_read).
            read = read[:]
            read.sequence = read.sequence.upper()
        trimmed_read = read
        for _ in range(self.times):
            match = self.adapters.match_to(trimmed_read.sequence)
            if match is None:
                # if nothing found, attempt no further rounds
                break
            matches.append(match)
            trimmed_read = match.trimmed(trimmed_read)

        if not matches:
            return trimmed_read, []

        if self.action == 'trim':
            # read is already trimmed, nothing to do
            pass
        elif self.action == 'retain':
            trimmed_read = self.trim_but_retain_adapter(read, matches)
        elif self.action == 'mask':
            trimmed_read = self.masked_read(read, matches)
        elif self.action == 'lowercase':
            trimmed_read = self.lowercased_read(read, matches)
            assert len(trimmed_read.sequence) == len(read)
        elif self.action is None:
            trimmed_read = read[:]

        return trimmed_read, matches
237
238
class ReverseComplementer(SingleEndModifier):
    """Trim adapters from a read and its reverse complement"""

    def __init__(self, adapter_cutter: AdapterCutter, rc_suffix: Optional[str] = " rc"):
        """
        adapter_cutter -- cutter used to search both orientations; its with_adapters
          counter and adapter_statistics are updated by this modifier
        rc_suffix -- suffix to add to the read name if sequence was reverse-complemented
          (nothing is added when it is None or empty)
        """
        self.adapter_cutter = adapter_cutter
        self._suffix = rc_suffix
        self.reverse_complemented = 0  # reads for which the reverse complement was chosen

    def __call__(self, read, info: ModificationInfo):
        """
        Trim both the read and its reverse complement and keep whichever orientation
        has the larger total number of matching bases (ties keep the forward read).
        """
        cutter = self.adapter_cutter
        # The reverse complement is taken from the read before it is handed to the cutter
        rc_read = reverse_complemented_sequence(read)

        fwd_read, fwd_matches = cutter.match_and_trim(read)
        rev_read, rev_matches = cutter.match_and_trim(rc_read)

        fwd_score = sum(m.matches for m in fwd_matches)
        rev_score = sum(m.matches for m in rev_matches)
        use_rc = rev_score > fwd_score

        if not use_rc:
            info.is_rc = False
            chosen_read, chosen_matches = fwd_read, fwd_matches
        else:
            self.reverse_complemented += 1
            assert rev_matches
            chosen_read, chosen_matches = rev_read, rev_matches
            info.is_rc = True
            if self._suffix:
                chosen_read.name += self._suffix

        if not chosen_matches:
            return chosen_read

        cutter.with_adapters += 1
        for match in chosen_matches:
            stats = cutter.adapter_statistics[match.adapter]
            match.update_statistics(stats)
            stats.reverse_complemented += bool(use_rc)
        info.matches.extend(chosen_matches)  # TODO extend or overwrite?
        return chosen_read
279
280
class PairedAdapterCutterError(Exception):
    """Raised by PairedAdapterCutter when its adapter lists are empty or differ in length."""
283
284
class PairedAdapterCutter(PairedEndModifier):
    """
    A Modifier that trims adapter pairs from R1 and R2.
    """

    def __init__(self, adapters1, adapters2, action='trim'):
        """
        adapters1 -- list of Adapters to be removed from R1
        adapters2 -- list of Adapters to be removed from R2

        Both lists must have the same, non-zero length.
        A read pair is trimmed if adapters1[i] is found in R1 and adapters2[i] in R2.

        action -- What to do with a found adapter: None, 'trim', 'lowercase', 'mask'
          or 'retain' (the same actions that AdapterCutter supports)

        Raises PairedAdapterCutterError if the lists differ in length or are empty.
        """
        super().__init__()
        if len(adapters1) != len(adapters2):
            raise PairedAdapterCutterError(
                "The number of adapters to trim from R1 and R2 must be the same. "
                "Given: {} for R1, {} for R2".format(len(adapters1), len(adapters2)))
        if not adapters1:
            raise PairedAdapterCutterError("No adapters given")
        self._adapters1 = MultipleAdapters(adapters1)
        # Maps each R1 adapter to its position so that its R2 partner can be found
        self._adapter_indices = {a: i for i, a in enumerate(adapters1)}
        self._adapters2 = MultipleAdapters(adapters2)
        self.action = action
        self.with_adapters = 0  # number of read pairs in which both adapters were found
        # adapter_statistics[0] is for R1 adapters, adapter_statistics[1] for R2 adapters
        self.adapter_statistics = [
            OrderedDict((a, a.create_statistics()) for a in adapters1),
            OrderedDict((a, a.create_statistics()) for a in adapters2),
        ]

    def __repr__(self):
        return 'PairedAdapterCutter(adapters1={!r}, adapters2={!r})'.format(
            self._adapters1, self._adapters2)

    def __call__(self, read1, read2, info1, info2):
        """
        Find the best adapter in R1 and, if found, search R2 for the adapter at the
        same index in adapters2. Only when both are found is the action applied to
        both reads and the matches recorded in info1/info2. The input read objects
        are not modified.

        Return a tuple (read1, read2) of the (possibly) modified reads.
        """
        match1 = self._adapters1.match_to(read1.sequence)
        if match1 is None:
            return read1, read2
        adapter2 = self._adapters2[self._adapter_indices[match1.adapter]]
        match2 = adapter2.match_to(read2.sequence)
        if match2 is None:
            return read1, read2

        self.with_adapters += 1
        result = []
        for i, match, read in zip((0, 1), (match1, match2), (read1, read2)):
            if self.action == 'lowercase':
                # Uppercase a copy instead of mutating the caller's read
                read = read[:]
                read.sequence = read.sequence.upper()

            trimmed_read = match.trimmed(read)
            match.update_statistics(self.adapter_statistics[i][match.adapter])

            if self.action == 'trim':
                # read is already trimmed, nothing to do
                pass
            elif self.action == 'mask':
                trimmed_read = AdapterCutter.masked_read(read, [match])
            elif self.action == 'lowercase':
                trimmed_read = AdapterCutter.lowercased_read(read, [match])
                assert len(trimmed_read.sequence) == len(read)
            elif self.action == 'retain':
                trimmed_read = AdapterCutter.trim_but_retain_adapter(read, [match])
            elif self.action is None:  # --no-trim
                trimmed_read = read[:]
            result.append(trimmed_read)
        info1.matches.append(match1)
        info2.matches.append(match2)
        # Return a tuple like the other code paths and PairedEndModifier's contract
        return tuple(result)
358
359
class UnconditionalCutter(SingleEndModifier):
    """
    A modifier that unconditionally removes the first n or the last n bases from a read.

    If the length is positive, the bases are removed from the beginning of the read.
    If the length is negative, the bases are removed from the end of the read.
    If the length is zero, the read is returned unchanged.

    The removed bases are stored in info.cut_prefix or info.cut_suffix.
    """
    def __init__(self, length: int):
        self.length = length

    def __call__(self, read, info: ModificationInfo):
        if self.length > 0:
            info.cut_prefix = read.sequence[:self.length]
            return read[self.length:]
        if self.length < 0:
            info.cut_suffix = read.sequence[self.length:]
            return read[:self.length]
        # length == 0: nothing to cut. Without this return, the method would
        # return None instead of a read.
        return read
377
378
class LengthTagModifier(SingleEndModifier):
    """
    Replace "length=..." strings in read names.

    Every occurrence of the length tag followed by optional digits is replaced
    by the tag followed by the current length of the read sequence.
    """
    def __init__(self, length_tag):
        # Escape the tag: it is literal text, not a regular expression.
        self.regex = re.compile(r"\b" + re.escape(length_tag) + r"[0-9]*\b")
        self.length_tag = length_tag

    def __call__(self, read, info: ModificationInfo):
        read = read[:]
        if self.length_tag in read.name:
            replacement = self.length_tag + str(len(read.sequence))
            # A function replacement keeps re.sub from interpreting backslashes
            # or group references that the tag may contain.
            read.name = self.regex.sub(lambda _match: replacement, read.name)
        return read
392
393
class SuffixRemover(SingleEndModifier):
    """
    Remove a given suffix from read names.

    Names that do not end with the suffix are left unchanged. An empty suffix
    leaves every name unchanged.
    """
    def __init__(self, suffix):
        self.suffix = suffix

    def __call__(self, read, info: ModificationInfo):
        read = read[:]
        # Guard against an empty suffix: endswith("") is always True and
        # name[:-0] would then be the empty string, erasing the whole name.
        if self.suffix and read.name.endswith(self.suffix):
            read.name = read.name[:-len(self.suffix)]
        return read
406
407
class PrefixSuffixAdder(SingleEndModifier):
    """
    Add a suffix and a prefix to read names

    The placeholder '{name}' in prefix and suffix is replaced by the name of the
    adapter of the last recorded match, or by 'no_adapter' when there is none.
    """
    def __init__(self, prefix, suffix):
        self.prefix = prefix
        self.suffix = suffix

    def __call__(self, read, info):
        renamed = read[:]
        if info.matches:
            adapter_name = info.matches[-1].adapter.name
        else:
            adapter_name = 'no_adapter'
        head = self.prefix.replace('{name}', adapter_name)
        tail = self.suffix.replace('{name}', adapter_name)
        renamed.name = head + renamed.name + tail
        return renamed
422
423
class InvalidTemplate(Exception):
    """
    Raised for a bad renaming template: it cannot be tokenized, it uses an unknown
    variable, or (paired-end) renaming makes the R1 and R2 IDs differ.
    """
426
427
class Renamer(SingleEndModifier):
    """
    Rename reads using a template

    The template string can contain the following placeholders:

    - {header} -- full, unchanged header
    - {id} -- the part of the header before the first whitespace
    - {comment} -- the part of the header after the ID, excluding initial whitespace
    - {cut_prefix} -- prefix removed by UnconditionalCutter (with positive length argument)
    - {cut_suffix} -- suffix removed by UnconditionalCutter (with negative length argument)
    - {adapter_name} -- name of the *last* adapter match or no_adapter if there was none
    - {rc} -- the string 'rc' if the read was reverse complemented (with --revcomp) or '' otherwise
    """
    variables = {
        "header",
        "id",
        "comment",
        "cut_prefix",
        "cut_suffix",
        "adapter_name",
        "rc",
    }

    def __init__(self, template: str):
        """
        Raises InvalidTemplate if the template cannot be tokenized or uses a
        variable that is not in Renamer.variables.
        """
        try:
            self._tokens = list(tokenize_braces(template))
        except TokenizeError as e:
            raise InvalidTemplate("Error in template '{}': {}".format(template, e)) from e
        self.raise_if_invalid_variable(self._tokens, self.variables)
        self._template = template

    def __repr__(self):
        return f"Renamer('{self._template}')"

    @staticmethod
    def raise_if_invalid_variable(tokens: List[Token], allowed: Set[str]) -> None:
        """Raise InvalidTemplate if any brace token's value is not in *allowed*"""
        for token in tokens:
            if not isinstance(token, BraceToken):
                continue
            value = token.value
            if value not in allowed:
                raise InvalidTemplate(
                    "Error in template: Variable '{}' not recognized".format(value)
                )

    @staticmethod
    def parse_name(read_name: str) -> Tuple[str, str]:
        """
        Parse read header and return (id, comment) tuple

        The header is split at the first run of whitespace. Surrounding whitespace is
        never part of the ID, also when there is no comment (e.g. "r1 " gives
        ("r1", "")). A header without a comment gets an empty comment.
        """
        fields = read_name.split(maxsplit=1)
        if len(fields) == 2:
            return (fields[0], fields[1])
        # Zero or one field: use the stripped field rather than the raw header so
        # that leading/trailing whitespace does not end up in the ID.
        return (fields[0] if fields else "", "")

    def __call__(self, read: DnaSequence, info: ModificationInfo) -> DnaSequence:
        # Rename a copy, like the other name modifiers, so the caller's read
        # object (possibly info.original_read) keeps its original name.
        read = read[:]
        id_, comment = self.parse_name(read.name)
        read.name = self._template.format(
            header=read.name,
            id=id_,
            comment=comment,
            cut_prefix=info.cut_prefix if info.cut_prefix else "",
            cut_suffix=info.cut_suffix if info.cut_suffix else "",
            adapter_name=info.matches[-1].adapter.name if info.matches else "no_adapter",
            rc="rc" if info.is_rc else "",
        )
        return read
495
496
class PairedEndRenamer(PairedEndModifier):
    """
    Rename paired-end reads using a template. The template is applied to both
    R1 and R2, and the same template variables as in the (single-end) renamer
    are allowed, except for `{rc}`. Additionally, `{rn}` is replaced with 1 in
    the R1 header and with 2 in the R2 header. However,
    these variables are evaluated separately for each read. For example, if `{comment}`
    is used, it gets replaced with the R1 comment in the R1 header, and with the R2
    comment in the R2 header.

    Additionally, all template variables except `id` (and `rc`) can be used in the
    read-specific forms `{r1.variablename}` and `{r2.variablename}`. For example,
    `{r1.comment}` always gets replaced with the R1 comment, even in R2.
    """

    def __init__(self, template: str):
        """
        Raises InvalidTemplate if the template cannot be tokenized or uses a
        variable that is not allowed (see _get_allowed_variables).
        """
        try:
            self._tokens = list(tokenize_braces(template))
        except TokenizeError as e:
            raise InvalidTemplate("Error in template '{}': {}".format(template, e)) from e
        Renamer.raise_if_invalid_variable(self._tokens, self._get_allowed_variables())
        self._template = template

    def __repr__(self):
        return f"PairedEndRenamer('{self._template}')"

    @staticmethod
    def _get_allowed_variables() -> Set[str]:
        """Return the single-end variables minus 'rc', plus 'rn' and the r1./r2. forms"""
        allowed = (Renamer.variables - {"rc"}) | {"rn"}
        for v in Renamer.variables - {"id", "rc"}:
            allowed.add("r1." + v)
            allowed.add("r2." + v)
        return allowed

    def __call__(
        self, read1: DnaSequence, read2: DnaSequence, info1: ModificationInfo, info2: ModificationInfo
    ) -> Tuple[DnaSequence, DnaSequence]:
        """
        Return renamed copies of read1 and read2.

        Raises ValueError if the input names do not match (record_names_match) and
        InvalidTemplate if the renamed names no longer match.
        """
        id1, comment1 = Renamer.parse_name(read1.name)
        id2, comment2 = Renamer.parse_name(read2.name)
        if not record_names_match(read1.name, read2.name):
            raise ValueError("Input read IDs not identical: '{}' != '{}'".format(id1, id2))
        name1, name2 = self.get_new_headers(
            id1=id1,
            id2=id2,
            comment1=comment1,
            comment2=comment2,
            header1=read1.name,
            header2=read2.name,
            info1=info1,
            info2=info2,
        )
        new_id1 = Renamer.parse_name(name1)[0]
        new_id2 = Renamer.parse_name(name2)[0]
        if not record_names_match(name1, name2):
            raise InvalidTemplate(
                "After renaming R1 and R2, their IDs are no longer identical: "
                "'{}' != '{}'. Original read ID: '{}'. ".format(new_id1, new_id2, id1)
            )
        # Rename copies, like the other name modifiers, so the caller's read
        # objects keep their original names.
        read1 = read1[:]
        read2 = read2[:]
        read1.name = name1
        read2.name = name2
        return read1, read2

    def get_new_headers(
        self,
        id1: str,
        id2: str,
        comment1: str,
        comment2: str,
        header1: str,
        header2: str,
        info1: ModificationInfo,
        info2: ModificationInfo,
    ) -> Tuple[str, str]:
        """Expand the template once for R1 and once for R2 and return both headers"""
        d = []
        for id_, comment, header, info in (
            (id1, comment1, header1, info1), (id2, comment2, header2, info2)
        ):
            d.append(
                dict(
                    comment=comment,
                    header=header,
                    cut_prefix=info.cut_prefix if info.cut_prefix else "",
                    cut_suffix=info.cut_suffix if info.cut_suffix else "",
                    adapter_name=info.matches[-1].adapter.name if info.matches else "no_adapter",
                )
            )
        name1 = self._template.format(
            id=id1,
            rn=1,
            **d[0],
            r1=SimpleNamespace(**d[0]),
            r2=SimpleNamespace(**d[1]),
        )
        name2 = self._template.format(
            id=id2,
            rn=2,
            **d[1],
            r1=SimpleNamespace(**d[0]),
            r2=SimpleNamespace(**d[1]),
        )
        return name1, name2
595
596
class ZeroCapper(SingleEndModifier):
    """
    Change negative quality values of a read to zero

    Every quality character whose code is below quality_base is replaced by
    chr(quality_base).
    """
    def __init__(self, quality_base=33):
        below_base = ''.join(chr(code) for code in range(quality_base))
        zero_char = chr(quality_base)
        self.zero_cap_trans = str.maketrans(below_base, zero_char * quality_base)

    def __call__(self, read, info: ModificationInfo):
        capped = read[:]
        capped.qualities = capped.qualities.translate(self.zero_cap_trans)
        return capped
609
610
class NextseqQualityTrimmer(SingleEndModifier):
    """
    Quality-trim the 3' end of reads using nextseq_trim_index.

    trimmed_bases accumulates the number of bases removed over all reads.
    """
    def __init__(self, cutoff, base):
        self.cutoff = cutoff
        self.base = base
        self.trimmed_bases = 0

    def __call__(self, read, info: ModificationInfo):
        keep_until = nextseq_trim_index(read, self.cutoff, self.base)
        removed = len(read) - keep_until
        self.trimmed_bases += removed
        return read[:keep_until]
621
622
class QualityTrimmer(SingleEndModifier):
    """
    Quality-trim both ends of reads using quality_trim_index.

    trimmed_bases accumulates the number of bases removed over all reads.
    """
    def __init__(self, cutoff_front, cutoff_back, base):
        self.cutoff_front = cutoff_front
        self.cutoff_back = cutoff_back
        self.base = base
        self.trimmed_bases = 0

    def __call__(self, read, info: ModificationInfo):
        first, last = quality_trim_index(
            read.qualities, self.cutoff_front, self.cutoff_back, self.base)
        kept = last - first
        self.trimmed_bases += len(read) - kept
        return read[first:last]
634
635
class Shortener(SingleEndModifier):
    """Unconditionally shorten a read to the given length

    If the length is positive, the bases are removed from the end of the read.
    If the length is negative, the bases are removed from the beginning of the read.
    A length of zero yields an empty read.
    """
    def __init__(self, length):
        self.length = length

    def __call__(self, read, info: ModificationInfo):
        n = self.length
        return read[:n] if n >= 0 else read[n:]
650
651
class NEndTrimmer(SingleEndModifier):
    """Trims Ns from the 3' and 5' end of reads

    Only uppercase 'N' is removed. A read consisting only of Ns becomes empty.
    """
    def __init__(self):
        self.start_trim = re.compile(r'^N+')
        self.end_trim = re.compile(r'N+$')

    def __call__(self, read, info: ModificationInfo):
        seq = read.sequence
        leading = self.start_trim.match(seq)
        trailing = self.end_trim.search(seq)
        first_kept = 0 if leading is None else leading.end()
        stop = len(read) if trailing is None else trailing.start()
        return read[first_kept:stop]
665