import io
import os
import sys
import copy
import logging
import functools
from typing import List, Optional, BinaryIO, TextIO, Any, Tuple, Dict
from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe, Queue
from pathlib import Path
import multiprocessing.connection
from multiprocessing.connection import Connection
import traceback

import dnaio

from .utils import Progress, FileOpener
from .modifiers import SingleEndModifier, PairedEndModifier, PairedEndModifierWrapper, ModificationInfo
from .report import Statistics
from .filters import (Redirector, PairedRedirector, NoFilter, PairedNoFilter, InfoFileWriter,
    RestFileWriter, WildcardFileWriter, TooShortReadFilter, TooLongReadFilter, NContentFilter,
    MaximumExpectedErrorsFilter,
    CasavaFilter, DiscardTrimmedFilter, DiscardUntrimmedFilter, Demultiplexer,
    PairedDemultiplexer, CombinatorialDemultiplexer)

logger = logging.getLogger()


class InputFiles:
    def __init__(self, file1: BinaryIO, file2: Optional[BinaryIO] = None, interleaved: bool = False):
        self.file1 = file1
        self.file2 = file2
        self.interleaved = interleaved

    def open(self):
        return dnaio.open(self.file1, file2=self.file2, interleaved=self.interleaved, mode="r")

    def close(self) -> None:
        self.file1.close()
        if self.file2 is not None:
            self.file2.close()


class InputPaths:
    def __init__(self, path1: str, path2: Optional[str] = None, interleaved: bool = False):
        self.path1 = path1
        self.path2 = path2
        self.interleaved = interleaved

    def open(self, file_opener: FileOpener) -> InputFiles:
        file1, file2 = file_opener.xopen_pair(self.path1, self.path2, "rb")
        return InputFiles(file1, file2, self.interleaved)


class OutputFiles:
    """
    The attributes are either None or open file-like objects, except for demultiplex_out,
    demultiplex_out2, combinatorial_out and combinatorial_out2, which are dictionaries
    that map an adapter name (or a pair of adapter names) to file-like objects.
    """
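    # Illustrative shapes of the dictionary-valued attributes (the adapter names
    # here are hypothetical, not taken from a real run):
    #   demultiplex_out = {"adapter1": <binary file>, "adapter2": <binary file>}
    #   combinatorial_out = {("adapter1", "adapter2"): <binary file>}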
    def __init__(
        self,
        out: Optional[BinaryIO] = None,
        out2: Optional[BinaryIO] = None,
        untrimmed: Optional[BinaryIO] = None,
        untrimmed2: Optional[BinaryIO] = None,
        too_short: Optional[BinaryIO] = None,
        too_short2: Optional[BinaryIO] = None,
        too_long: Optional[BinaryIO] = None,
        too_long2: Optional[BinaryIO] = None,
        info: Optional[BinaryIO] = None,
        rest: Optional[BinaryIO] = None,
        wildcard: Optional[BinaryIO] = None,
        demultiplex_out: Optional[Dict[str, BinaryIO]] = None,
        demultiplex_out2: Optional[Dict[str, BinaryIO]] = None,
        combinatorial_out: Optional[Dict[Tuple[str, str], BinaryIO]] = None,
        combinatorial_out2: Optional[Dict[Tuple[str, str], BinaryIO]] = None,
        force_fasta: Optional[bool] = None,
    ):
        self.out = out
        self.out2 = out2
        self.untrimmed = untrimmed
        self.untrimmed2 = untrimmed2
        self.too_short = too_short
        self.too_short2 = too_short2
        self.too_long = too_long
        self.too_long2 = too_long2
        self.info = info
        self.rest = rest
        self.wildcard = wildcard
        self.demultiplex_out = demultiplex_out
        self.demultiplex_out2 = demultiplex_out2
        self.combinatorial_out = combinatorial_out
        self.combinatorial_out2 = combinatorial_out2
        self.force_fasta = force_fasta

    def __iter__(self):
        for f in [
            self.out,
            self.out2,
            self.untrimmed,
            self.untrimmed2,
            self.too_short,
            self.too_short2,
            self.too_long,
            self.too_long2,
            self.info,
            self.rest,
            self.wildcard,
        ]:
            if f is not None:
                yield f
        for outs in (
            self.demultiplex_out, self.demultiplex_out2,
            self.combinatorial_out, self.combinatorial_out2,
        ):
            if outs is not None:
                for f in outs.values():
                    assert f is not None
                    yield f

    def as_bytesio(self) -> "OutputFiles":
        """
        Create a new OutputFiles instance that has BytesIO instances for each non-None output file
        """
        result = OutputFiles(force_fasta=self.force_fasta)
        for attr in (
            "out", "out2", "untrimmed", "untrimmed2", "too_short", "too_short2", "too_long",
            "too_long2", "info", "rest", "wildcard"
        ):
            if getattr(self, attr) is not None:
                setattr(result, attr, io.BytesIO())
        for attr in "demultiplex_out", "demultiplex_out2", "combinatorial_out", "combinatorial_out2":
            if getattr(self, attr) is not None:
                setattr(result, attr, dict())
                for k, v in getattr(self, attr).items():
                    getattr(result, attr)[k] = io.BytesIO()
        return result

    def close(self) -> None:
        """Close all output files that are not stdout"""
        for f in self:
            if f is sys.stdout or f is sys.stdout.buffer:
                continue
            f.close()


class Pipeline(ABC):
    """
    Processing pipeline that loops over reads and applies modifiers and filters
    """
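    # Typical lifecycle, as driven by the runner classes further below:
    # connect_io() opens the reader and builds the filter/writer chain,
    # process_reads() runs the main loop, flush() flushes all output files,
    # and close() closes input and output files.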
    n_adapters = 0
    paired = False

    def __init__(self, file_opener: FileOpener):
        self._reader = None  # type: Any
        self._filters = []  # type: List[Any]
        self._infiles = None  # type: Optional[InputFiles]
        self._outfiles = None  # type: Optional[OutputFiles]
        self._demultiplexer = None
        self._textiowrappers = []  # type: List[TextIO]

        # Filter settings
        self._minimum_length = None
        self._maximum_length = None
        self.max_n = None
        self.max_expected_errors = None
        self.discard_casava = False
        self.discard_trimmed = False
        self.discard_untrimmed = False
        self.file_opener = file_opener

    def connect_io(self, infiles: InputFiles, outfiles: OutputFiles) -> None:
        self._infiles = infiles
        self._reader = infiles.open()
        self._set_output(outfiles)

    @abstractmethod
    def _open_writer(
        self,
        file: BinaryIO,
        file2: Optional[BinaryIO] = None,
        force_fasta: Optional[bool] = None,
    ):
        pass

    def _set_output(self, outfiles: OutputFiles) -> None:
        self._filters = []
        self._outfiles = outfiles
        filter_wrapper = self._filter_wrapper()

        for filter_class, outfile in (
            (RestFileWriter, outfiles.rest),
            (InfoFileWriter, outfiles.info),
            (WildcardFileWriter, outfiles.wildcard),
        ):
            if outfile:
                textiowrapper = io.TextIOWrapper(outfile)
                self._textiowrappers.append(textiowrapper)
                self._filters.append(filter_wrapper(None, filter_class(textiowrapper), None))

        # minimum length and maximum length
        for lengths, file1, file2, filter_class in (
                (self._minimum_length, outfiles.too_short, outfiles.too_short2, TooShortReadFilter),
                (self._maximum_length, outfiles.too_long, outfiles.too_long2, TooLongReadFilter)
        ):
            if lengths is None:
                continue
            writer = self._open_writer(file1, file2) if file1 else None
            f1 = filter_class(lengths[0]) if lengths[0] is not None else None
            if len(lengths) == 2 and lengths[1] is not None:
                f2 = filter_class(lengths[1])
            else:
                f2 = None
            self._filters.append(filter_wrapper(writer, filter=f1, filter2=f2))

        if self.max_n is not None:
            f1 = f2 = NContentFilter(self.max_n)
            self._filters.append(filter_wrapper(None, f1, f2))

        if self.max_expected_errors is not None:
            if not self._reader.delivers_qualities:
                logger.warning("Ignoring option --max-ee as input does not contain quality values")
            else:
                f1 = f2 = MaximumExpectedErrorsFilter(self.max_expected_errors)
                self._filters.append(filter_wrapper(None, f1, f2))

        if self.discard_casava:
            f1 = f2 = CasavaFilter()
            self._filters.append(filter_wrapper(None, f1, f2))

        if int(self.discard_trimmed) + int(self.discard_untrimmed) + int(outfiles.untrimmed is not None) > 1:
            raise ValueError('discard_trimmed, discard_untrimmed and outfiles.untrimmed must not '
                'be set simultaneously')

        if outfiles.demultiplex_out is not None or outfiles.combinatorial_out is not None:
            self._demultiplexer = self._create_demultiplexer(outfiles)
            self._filters.append(self._demultiplexer)
        else:
            # Allow overriding the wrapper for --discard-untrimmed/--untrimmed-(paired-)output
            untrimmed_filter_wrapper = self._untrimmed_filter_wrapper()

            # Set up the remaining filters to deal with --discard-trimmed,
            # --discard-untrimmed and --untrimmed-output. These options
            # are mutually exclusive in order to avoid brain damage.
            if self.discard_trimmed:
                self._filters.append(
                    filter_wrapper(None, DiscardTrimmedFilter(), DiscardTrimmedFilter()))
            elif self.discard_untrimmed:
                self._filters.append(
                    untrimmed_filter_wrapper(None, DiscardUntrimmedFilter(), DiscardUntrimmedFilter()))
            elif outfiles.untrimmed:
                untrimmed_writer = self._open_writer(outfiles.untrimmed, outfiles.untrimmed2)
                self._filters.append(
                    untrimmed_filter_wrapper(untrimmed_writer, DiscardUntrimmedFilter(), DiscardUntrimmedFilter()))
            self._filters.append(self._final_filter(outfiles))
        logger.debug("Filters: %s", self._filters)

    def flush(self) -> None:
        for f in self._textiowrappers:
            f.flush()
        assert self._outfiles is not None
        for f in self._outfiles:
            f.flush()

    def close(self) -> None:
        self._close_input()
        self._close_output()

    def _close_input(self) -> None:
        self._reader.close()
        if self._infiles is not None:
            self._infiles.close()

    def _close_output(self) -> None:
        for f in self._textiowrappers:
            f.close()
        # Closing a TextIOWrapper also closes the underlying file, so
        # this closes some files a second time.
        assert self._outfiles is not None
        self._outfiles.close()

    @property
    def uses_qualities(self) -> bool:
        assert self._reader is not None
        return self._reader.delivers_qualities

    @abstractmethod
    def process_reads(self, progress: Optional[Progress] = None) -> Tuple[int, int, Optional[int]]:
        pass

    @abstractmethod
    def _filter_wrapper(self):
        pass

    @abstractmethod
    def _untrimmed_filter_wrapper(self):
        pass

    @abstractmethod
    def _final_filter(self, outfiles: OutputFiles):
        pass

    @abstractmethod
    def _create_demultiplexer(self, outfiles: OutputFiles):
        pass


class SingleEndPipeline(Pipeline):
    """
    Processing pipeline for single-end reads
    """
    def __init__(self, file_opener: FileOpener):
        super().__init__(file_opener)
        self._modifiers = []  # type: List[SingleEndModifier]

    def add(self, modifier: SingleEndModifier):
        if modifier is None:
            raise ValueError("Modifier must not be None")
        self._modifiers.append(modifier)

    def process_reads(self, progress: Optional[Progress] = None) -> Tuple[int, int, Optional[int]]:
        """Run the pipeline. Return a tuple (number of processed reads, total bp, None)."""
        n = 0  # no. of processed reads
        total_bp = 0
        for read in self._reader:
            n += 1
            if n % 10000 == 0 and progress:
                progress.update(n)
            total_bp += len(read)
            info = ModificationInfo(read)
            for modifier in self._modifiers:
                read = modifier(read, info)
            for filter_ in self._filters:
                if filter_(read, info):
                    break
        return (n, total_bp, None)

    def _open_writer(
        self,
        file: BinaryIO,
        file2: Optional[BinaryIO] = None,
        force_fasta: Optional[bool] = None,
    ):
        assert file2 is None
        assert not isinstance(file, (str, bytes, Path))
        return self.file_opener.dnaio_open_raise_limit(
            file, mode="w", qualities=self.uses_qualities, fileformat="fasta" if force_fasta else None)

    def _filter_wrapper(self):
        return Redirector

    def _untrimmed_filter_wrapper(self):
        return Redirector

    def _final_filter(self, outfiles: OutputFiles):
        assert outfiles.out2 is None and outfiles.out is not None
        writer = self._open_writer(outfiles.out, force_fasta=outfiles.force_fasta)
        return NoFilter(writer)

    def _create_demultiplexer(self, outfiles: OutputFiles):
        writers = dict()  # type: Dict[Optional[str], Any]
        if outfiles.untrimmed is not None:
            writers[None] = self._open_writer(outfiles.untrimmed, force_fasta=outfiles.force_fasta)
        assert outfiles.demultiplex_out is not None
        for name, file in outfiles.demultiplex_out.items():
            writers[name] = self._open_writer(file, force_fasta=outfiles.force_fasta)
        return Demultiplexer(writers)

    @property
    def minimum_length(self):
        return self._minimum_length

    @minimum_length.setter
    def minimum_length(self, value):
        assert value is None or len(value) == 1
        self._minimum_length = value

    @property
    def maximum_length(self):
        return self._maximum_length

    @maximum_length.setter
    def maximum_length(self, value):
        assert value is None or len(value) == 1
        self._maximum_length = value


class PairedEndPipeline(Pipeline):
    """
    Processing pipeline for paired-end reads.
    """
    paired = True

    def __init__(self, pair_filter_mode, file_opener: FileOpener):
        super().__init__(file_opener)
        self._modifiers = []  # type: List[PairedEndModifier]
        self._pair_filter_mode = pair_filter_mode
        self._reader = None
        # Whether to ignore pair_filter mode for discard-untrimmed filter
        self.override_untrimmed_pair_filter = False

    def add(self, modifier1: Optional[SingleEndModifier], modifier2: Optional[SingleEndModifier]) -> None:
        """
        Add a modifier for R1 and R2. One of the two can be None, in which case the
        modifier is applied only to the other read.
        """
        if modifier1 is None and modifier2 is None:
            raise ValueError("Cannot use None for both modifiers")
        self._modifiers.append(PairedEndModifierWrapper(modifier1, modifier2))

    def add_both(self, modifier: SingleEndModifier) -> None:
        """
        Add one modifier for both R1 and R2
        """
        assert modifier is not None
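        # A copy of the modifier is used for R2 so that R1 and R2 are handled by
        # separate instances; this presumably keeps any per-instance state of the
        # modifier (such as collected statistics) separate for the two reads.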
        self._modifiers.append(PairedEndModifierWrapper(modifier, copy.copy(modifier)))

    def add_paired_modifier(self, modifier: PairedEndModifier) -> None:
        """Add a Modifier (without wrapping it in a PairedEndModifierWrapper)"""
        self._modifiers.append(modifier)

    def process_reads(self, progress: Optional[Progress] = None) -> Tuple[int, int, Optional[int]]:
        n = 0  # no. of processed reads
        total1_bp = 0
        total2_bp = 0
        assert self._reader is not None
        for read1, read2 in self._reader:
            n += 1
            if n % 10000 == 0 and progress:
                progress.update(n)
            total1_bp += len(read1)
            total2_bp += len(read2)
            info1 = ModificationInfo(read1)
            info2 = ModificationInfo(read2)
            for modifier in self._modifiers:
                read1, read2 = modifier(read1, read2, info1, info2)
            for filter_ in self._filters:
                # Stop writing as soon as one of the filters was successful.
                if filter_(read1, read2, info1, info2):
                    break
        return (n, total1_bp, total2_bp)

    def _open_writer(
        self,
        file: BinaryIO,
        file2: Optional[BinaryIO] = None,
        force_fasta: Optional[bool] = None,
    ):
        # file and file2 must already be file-like objects because we don’t want to
        # take care of threads and compression levels here.
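        # Note that when only a single output file is given (file2 is None), the
        # paired-end pipeline writes both reads to it interleaved (see the
        # interleaved= argument below).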
        for f in (file, file2):
            assert not isinstance(f, (str, bytes, Path))
        return self.file_opener.dnaio_open_raise_limit(
            file,
            file2=file2,
            mode="w",
            qualities=self.uses_qualities,
            fileformat="fasta" if force_fasta else None,
            interleaved=file2 is None,
        )

    def _filter_wrapper(self, pair_filter_mode=None):
        if pair_filter_mode is None:
            pair_filter_mode = self._pair_filter_mode
        return functools.partial(PairedRedirector, pair_filter_mode=pair_filter_mode)

    def _untrimmed_filter_wrapper(self):
        """
        Return a different filter wrapper when adapters were given only for R1
        or only for R2 (then override_untrimmed_pair_filter will be set)
        """
        if self.override_untrimmed_pair_filter:
            return self._filter_wrapper(pair_filter_mode='both')
        else:
            return self._filter_wrapper()

    def _final_filter(self, outfiles):
        writer = self._open_writer(outfiles.out, outfiles.out2, force_fasta=outfiles.force_fasta)
        return PairedNoFilter(writer)

    def _create_demultiplexer(self, outfiles):
        def open_writer(file, file2):
            return self._open_writer(file, file2, force_fasta=outfiles.force_fasta)

        if outfiles.combinatorial_out is not None:
            assert outfiles.untrimmed is None and outfiles.untrimmed2 is None
            writers = dict()
            for key, out in outfiles.combinatorial_out.items():
                writers[key] = open_writer(out, outfiles.combinatorial_out2[key])
            return CombinatorialDemultiplexer(writers)
        else:
            writers = dict()
            if outfiles.untrimmed is not None:
                writers[None] = open_writer(outfiles.untrimmed, outfiles.untrimmed2)
            for name, file in outfiles.demultiplex_out.items():
                writers[name] = open_writer(file, outfiles.demultiplex_out2[name])
            return PairedDemultiplexer(writers)

    @property
    def minimum_length(self):
        return self._minimum_length

    @minimum_length.setter
    def minimum_length(self, value):
        assert value is None or len(value) == 2
        self._minimum_length = value

    @property
    def maximum_length(self):
        return self._maximum_length

    @maximum_length.setter
    def maximum_length(self, value):
        assert value is None or len(value) == 2
        self._maximum_length = value


class ReaderProcess(Process):
    """
    Read chunks of FASTA or FASTQ data (single-end or paired) and send them to worker processes.

    The reader repeatedly

    - reads a chunk from the file(s)
    - reads a worker index from the Queue
    - sends the chunk to connections[index]

    and finally sends the stop token -1 (the "poison pill") to all connections.
    """

    def __init__(self, path: str, path2: Optional[str], opener: FileOpener, connections, queue, buffer_size, stdin_fd):
        """
        queue -- a Queue of worker indices. A worker writes its own index into this
            queue to notify the reader that it is ready to receive more data.
        connections -- a list of Connection objects, one for each worker.
        """
        super().__init__()
        self.path = path
        self.path2 = path2
        self.connections = connections
        self.queue = queue
        self.buffer_size = buffer_size
        self.stdin_fd = stdin_fd
        self._opener = opener

    def run(self):
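        # If the parent process handed over a real stdin file descriptor, re-open
        # sys.stdin from it. This is presumably necessary because a multiprocessing
        # child does not always share the parent's sys.stdin (for example with the
        # "spawn" start method), and reading from "-" would otherwise fail.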
        if self.stdin_fd != -1:
            sys.stdin.close()
            sys.stdin = os.fdopen(self.stdin_fd)
        try:
            with self._opener.xopen(self.path, 'rb') as f:
                if self.path2:
                    with self._opener.xopen(self.path2, 'rb') as f2:
                        for chunk_index, (chunk1, chunk2) in enumerate(
                                dnaio.read_paired_chunks(f, f2, self.buffer_size)):
                            self.send_to_worker(chunk_index, chunk1, chunk2)
                else:
                    for chunk_index, chunk in enumerate(dnaio.read_chunks(f, self.buffer_size)):
                        self.send_to_worker(chunk_index, chunk)

            # Send poison pills to all workers
            for _ in range(len(self.connections)):
                worker_index = self.queue.get()
                self.connections[worker_index].send(-1)
        except Exception as e:
            # TODO better send this to a common "something went wrong" Queue
            for connection in self.connections:
                connection.send(-2)
                connection.send((e, traceback.format_exc()))

    def send_to_worker(self, chunk_index, chunk1, chunk2=None):
        worker_index = self.queue.get()
        connection = self.connections[worker_index]
        connection.send(chunk_index)
        connection.send_bytes(chunk1)
        if chunk2 is not None:
            connection.send_bytes(chunk2)


class WorkerProcess(Process):
    """
    The worker repeatedly reads chunks of data from the read_pipe, runs the pipeline on it
    and sends the processed chunks to the write_pipe.

    To notify the reader process that it wants data, it puts its own identifier into the
    need_work_queue before attempting to read data from the read_pipe.
    """
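    # Message protocol on write_pipe (see run() and _send_outfiles()): for every
    # processed chunk, the worker sends the chunk index, then the number of reads
    # in the chunk, then one send_bytes() per output file. When all chunks are
    # done, it sends -1 followed by a Statistics object; on error, it sends -2
    # followed by a tuple (exception, formatted traceback).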
    def __init__(
        self,
        id_: int,
        pipeline: Pipeline,
        two_input_files: bool,
        interleaved_input: bool,
        orig_outfiles: OutputFiles,
        read_pipe: Connection,
        write_pipe: Connection,
        need_work_queue: Queue,
    ):
        super().__init__()
        self._id = id_
        self._pipeline = pipeline
        self._two_input_files = two_input_files
        self._interleaved_input = interleaved_input
        self._read_pipe = read_pipe
        self._write_pipe = write_pipe
        self._need_work_queue = need_work_queue
        # Do not store orig_outfiles directly because it contains
        # _io.BufferedWriter attributes, which cannot be pickled.
        self._original_outfiles = orig_outfiles.as_bytesio()

    def run(self):
        try:
            stats = Statistics()
            while True:
                # Notify reader that we need data
                self._need_work_queue.put(self._id)
                chunk_index = self._read_pipe.recv()
                if chunk_index == -1:
                    # reader is done
                    break
                elif chunk_index == -2:
                    # An exception has occurred in the reader
                    e, tb_str = self._read_pipe.recv()
                    logger.error('%s', tb_str)
                    raise e

                infiles = self._make_input_files()
                outfiles = self._original_outfiles.as_bytesio()
                self._pipeline.connect_io(infiles, outfiles)
                (n, bp1, bp2) = self._pipeline.process_reads()
                self._pipeline.flush()
                cur_stats = Statistics().collect(n, bp1, bp2, [], self._pipeline._filters)
                stats += cur_stats
                self._send_outfiles(outfiles, chunk_index, n)

            m = self._pipeline._modifiers
            modifier_stats = Statistics().collect(0, 0, 0 if self._pipeline.paired else None, m, [])
            stats += modifier_stats
            self._write_pipe.send(-1)
            self._write_pipe.send(stats)
        except Exception as e:
            self._write_pipe.send(-2)
            self._write_pipe.send((e, traceback.format_exc()))

    def _make_input_files(self) -> InputFiles:
        data = self._read_pipe.recv_bytes()
        input = io.BytesIO(data)

        if self._two_input_files:
            data = self._read_pipe.recv_bytes()
            input2 = io.BytesIO(data)  # type: Optional[BinaryIO]
        else:
            input2 = None
        return InputFiles(input, input2, interleaved=self._interleaved_input)

    def _send_outfiles(self, outfiles: OutputFiles, chunk_index: int, n_reads: int):
        self._write_pipe.send(chunk_index)
        self._write_pipe.send(n_reads)

        for f in outfiles:
            f.flush()
            assert isinstance(f, io.BytesIO)
            processed_chunk = f.getvalue()
            self._write_pipe.send_bytes(processed_chunk)


class OrderedChunkWriter:
    """
    We may receive chunks of processed data from worker processes
    in any order. This class writes them to an output file in
    the correct order.
    """
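    # Hypothetical usage sketch (file name and data are made up). Chunks may be
    # written in any order but end up in the output file ordered by index:
    #
    #     writer = OrderedChunkWriter(open("out.fastq", "wb"))
    #     writer.write(b"@read2\n...", 1)  # buffered, nothing written yet
    #     writer.write(b"@read1\n...", 0)  # now chunks 0 and 1 are written
    #     assert writer.wrote_everything()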
    def __init__(self, outfile):
        self._chunks = dict()
        self._current_index = 0
        self._outfile = outfile

    def write(self, data, index):
        """
        Buffer the chunk with the given index and write out all chunks that
        can now be written in the correct order.
        """
        self._chunks[index] = data
        while self._current_index in self._chunks:
            self._outfile.write(self._chunks[self._current_index])
            del self._chunks[self._current_index]
            self._current_index += 1

    def wrote_everything(self):
        return not self._chunks


class PipelineRunner(ABC):
    """
    Abstract base class for running a Pipeline, either serially or in parallel
    """
    def __init__(self, pipeline: Pipeline, progress: Progress):
        self._pipeline = pipeline
        self._progress = progress

    @abstractmethod
    def run(self):
        pass

    @abstractmethod
    def close(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
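    # Usage sketch (the pipeline, file and progress objects are assumed to be set
    # up elsewhere); both concrete runners below are used as context managers:
    #
    #     with SerialPipelineRunner(pipeline, infiles, outfiles, progress) as runner:
    #         stats = runner.run()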


class ParallelPipelineRunner(PipelineRunner):
    """
    Run a Pipeline in parallel

    - When the runner is constructed, a reader process is spawned.
    - When run() is called, as many worker processes as requested are spawned.
    - In the main process, results are written to the output files in the correct
      order, and statistics are aggregated.

    If a worker needs work, it puts its own index into a Queue() (_need_work_queue).
    The reader process listens on this queue and sends the raw data to the
    worker that has requested work. For sending the data from reader to worker,
    a Connection() is used. There is one such connection for each worker (self._connections).

    For sending the processed data from the worker to the main process, there
    is a second set of connections, again one for each worker.

    When the reader is finished, it sends 'poison pills' to all workers.
    When a worker receives one, it sends a poison pill to the main process,
    followed by a Statistics object that contains statistics about all the reads
    processed by that worker.
    """
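    # Construction sketch (argument values are hypothetical): the reader process
    # is started in __init__, the workers are started in run().
    #
    #     runner = ParallelPipelineRunner(
    #         pipeline, InputPaths("reads.1.fastq.gz", "reads.2.fastq.gz"),
    #         outfiles, file_opener, progress, n_workers=4,
    #     )
    #     with runner:
    #         stats = runner.run()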

    def __init__(
        self,
        pipeline: Pipeline,
        infiles: InputPaths,
        outfiles: OutputFiles,
        opener: FileOpener,
        progress: Progress,
        n_workers: int,
        buffer_size: int = 4 * 1024**2,
    ):
        super().__init__(pipeline, progress)
        self._n_workers = n_workers
        self._need_work_queue = Queue()  # type: Queue
        self._buffer_size = buffer_size
        self._outfiles = outfiles
        self._opener = opener
        self._assign_input(infiles.path1, infiles.path2, infiles.interleaved)

    def _assign_input(
        self,
        path1: str,
        path2: Optional[str] = None,
        interleaved: bool = False,
    ) -> None:
        self._two_input_files = path2 is not None
        self._interleaved_input = interleaved
        # the workers read from these connections
        connections = [Pipe(duplex=False) for _ in range(self._n_workers)]
        self._connections, connw = zip(*connections)
        try:
            fileno = sys.stdin.fileno()
        except io.UnsupportedOperation:
            # This happens during tests: pytest sets sys.stdin to an object
            # that does not have a file descriptor.
            fileno = -1
        self._reader_process = ReaderProcess(path1, path2, self._opener, connw,
            self._need_work_queue, self._buffer_size, fileno)
        self._reader_process.daemon = True
        self._reader_process.start()

    def _start_workers(self) -> Tuple[List[WorkerProcess], List[Connection]]:
        workers = []
        connections = []
        for index in range(self._n_workers):
            conn_r, conn_w = Pipe(duplex=False)
            connections.append(conn_r)
            worker = WorkerProcess(
                index, self._pipeline,
                self._two_input_files,
                self._interleaved_input, self._outfiles,
                self._connections[index], conn_w, self._need_work_queue)
            worker.daemon = True
            worker.start()
            workers.append(worker)
        return workers, connections

    def run(self) -> Statistics:
        workers, connections = self._start_workers()
        writers = []
        for f in self._outfiles:
            writers.append(OrderedChunkWriter(f))
        stats = Statistics()
        n = 0  # A running total of the number of processed reads (for progress indicator)
        while connections:
            ready_connections = multiprocessing.connection.wait(connections)
            for connection in ready_connections:  # type: Any
                chunk_index = connection.recv()
                if chunk_index == -1:
                    # the worker is done
                    cur_stats = connection.recv()
                    if cur_stats == -2:
                        # An exception has occurred in the worker (see below,
                        # this happens only when there is an exception sending
                        # the statistics)
                        e, tb_str = connection.recv()
                        logger.error('%s', tb_str)
                        raise e
                    stats += cur_stats
                    connections.remove(connection)
                    continue
                elif chunk_index == -2:
                    # An exception has occurred in the worker
                    e, tb_str = connection.recv()

                    # We should use the worker's actual traceback object
                    # here, but traceback objects are not picklable.
                    logger.error('%s', tb_str)
                    raise e

                # No. of reads processed in this chunk
                chunk_n = connection.recv()
                if chunk_n == -2:
                    e, tb_str = connection.recv()
                    logger.error('%s', tb_str)
                    raise e
                n += chunk_n
                self._progress.update(n)
                for writer in writers:
                    data = connection.recv_bytes()
                    writer.write(data, chunk_index)
        for writer in writers:
            assert writer.wrote_everything()
        for w in workers:
            w.join()
        self._reader_process.join()
        self._progress.stop(n)
        return stats

    def close(self) -> None:
        self._outfiles.close()


class SerialPipelineRunner(PipelineRunner):
    """
    Run a Pipeline on a single core
    """

    def __init__(
        self,
        pipeline: Pipeline,
        infiles: InputFiles,
        outfiles: OutputFiles,
        progress: Progress,
    ):
        super().__init__(pipeline, progress)
        self._pipeline.connect_io(infiles, outfiles)

    def run(self) -> Statistics:
        (n, total1_bp, total2_bp) = self._pipeline.process_reads(progress=self._progress)
        if self._progress:
            self._progress.stop(n)
        # TODO
        modifiers = getattr(self._pipeline, "_modifiers", None)
        assert modifiers is not None
        return Statistics().collect(n, total1_bp, total2_bp, modifiers, self._pipeline._filters)

    def close(self) -> None:
        self._pipeline.close()