1"""
2Script for binning fast5 reads into separate directories based on column value in summary file
3Inteded for demultiplexing reads using barcoding summary file.
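
Example invocation (script and file names are illustrative):

    python demux_fast5.py -i fast5_pass/ -s demultiplexed/ -l barcoding_summary.txt --recursive

One subdirectory is created per value found in the demultiplex column
(by default 'barcode_arrangement'), and batched multi-read fast5 files
are written into each of them.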
4"""
5from pathlib import Path
6from typing import Union, Dict, Set, List
7from multiprocessing import Pool
8import logging
9from csv import reader
10from collections import defaultdict
11from time import sleep
12from math import ceil
13from argparse import ArgumentParser
14
15from ont_fast5_api.compression_settings import COMPRESSION_MAP
16from ont_fast5_api.conversion_tools.conversion_utils import (
17    get_fast5_file_list,
18    get_progress_bar,
19    Fast5FilterWorker,
20    READS_PER_FILE,
21    FILENAME_BASE,
22    ProgressBar,
23)
24
25DEMULTIPLEX_COLUMN = "barcode_arrangement"
26READ_ID_COLUMN = "read_id"
27
28
class Fast5Demux:
    """
    Bin reads from a directory of fast5 files according to the demultiplex_column of a summary file
    :param input_dir: Path to input Fast5 file or directory of Fast5 files
    :param output_dir: Path to output directory
    :param summary_file: Path to TSV summary file
    :param demultiplex_column: str name of column with demultiplex values
    :param read_id_column: str name of column with read ids
    :param filename_base: str prefix for output Fast5 files
    :param batch_size: int maximum number of reads per output file
    :param threads: int maximum number of worker processes
    :param recursive: bool flag to search recursively through input_dir for Fast5 files
    :param follow_symlinks: bool flag to follow symlinks in input_dir
    :param target_compression: str compression type in output Fast5 files
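
    Example (paths are illustrative):

        demux = Fast5Demux(
            input_dir=Path("fast5_pass"),
            output_dir=Path("demultiplexed"),
            summary_file=Path("barcoding_summary.txt"),
            demultiplex_column=DEMULTIPLEX_COLUMN,
        )
        demux.run_batch()
        demux.report()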
    """

    def __init__(
        self,
        input_dir: Path,
        output_dir: Path,
        summary_file: Path,
        demultiplex_column: str,
        read_id_column: str = READ_ID_COLUMN,
        filename_base: str = FILENAME_BASE,
        batch_size: int = READS_PER_FILE,
        threads: int = 1,
        recursive: bool = False,
        follow_symlinks: bool = True,
        target_compression: Union[str, None] = None,
    ):
        self.input_dir = input_dir
        self.output_dir = output_dir
        self.summary = summary_file
        self.demultiplex_column = demultiplex_column
        self.read_id_column = read_id_column
        self.filename_base = filename_base
        self.batch_size = batch_size
        self.threads = threads
        self.recursive = recursive
        self.follow_symlinks = follow_symlinks
        self.target_compression = target_compression

        self.read_sets: Dict[str, Set[str]] = {}
        self.input_fast5s: List[Path] = []
        self.max_threads: int = 0
        self.workers: List = []
        self.progressbar: Union[ProgressBar, None] = None
        self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)

    def create_output_dirs(self) -> None:
        """
        In the output directory, create a subdirectory per demux category
        :return:
        """
        self.output_dir.mkdir(parents=True, exist_ok=True)
        for demux in self.read_sets:
            out_dir = self.output_dir / demux
            out_dir.mkdir(exist_ok=True)

    def run_batch(self) -> None:
        """
        Run workers in a multiprocessing pool if max_threads allows it,
        otherwise run them sequentially
        :return:
        """
        self.workers_setup()

        if self.max_threads > 1:
            with Pool(self.max_threads) as pool:
                for worker in self.workers:
                    worker.run_batch(pool=pool)
                # wait until every worker has drained its queue of async tasks
                while any(worker.tasks for worker in self.workers):
                    sleep(1)
                pool.close()
                pool.join()
        else:
            for worker in self.workers:
                worker.run_batch(pool=None)

        self.progressbar.finish()

    def workers_setup(self) -> None:
        """
        Parse the input summary and input file list to determine the amount of work
        Create output directories and initialise workers
        :return:
        """
        self.read_sets = self.parse_summary_demultiplex()
        self.input_fast5s = get_fast5_file_list(
            input_path=self.input_dir,
            recursive=self.recursive,
            follow_symlinks=self.follow_symlinks,
        )
        self.max_threads = self.calculate_max_threads()
        # progressbar length is the total number of reads to be extracted plus
        # the total number of file reads (each worker scans every input file)
        total_progress = sum(len(item) for item in self.read_sets.values()) + (
            len(self.input_fast5s) * len(self.read_sets)
        )
        self.progressbar = get_progress_bar(num_reads=total_progress)
        self.create_output_dirs()
        for demux in sorted(self.read_sets):
            self.workers.append(
                Fast5FilterWorker(
                    input_file_list=self.input_fast5s,
                    output_dir=self.output_dir / demux,
                    read_set=self.read_sets[demux],
                    progressbar=self.progressbar,
                    logger=self.logger,
                    filename_base=self.filename_base,
                    batch_size=self.batch_size,
                    target_compression=self.target_compression,
                )
            )

    def report(self) -> None:
        """
        Log summary of work done
        :return:
        """
        total_reads = 0
        for worker in self.workers:
            for reads in worker.out_files.values():
                total_reads += len(reads)

        self.logger.info("{} reads extracted".format(total_reads))

        # report reads not found
        reads_to_extract = sum(len(item) for item in self.read_sets.values())
        if reads_to_extract > total_reads:
            self.logger.warning(
                "{} reads not found!".format(reads_to_extract - total_reads)
            )

    def calculate_max_threads(self) -> int:
        """
        Calculate the maximum number of worker processes based on the number of
        output files, the number of input files and the threads argument
        :return: int
        """
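        # Worked example (hypothetical numbers): with 10 input files,
        # batch_size=4000 and two demux categories of 6000 and 1000 reads,
        # the categories need ceil(6000/4000)=2 and ceil(1000/4000)=1 output
        # files respectively, each capped at 10 (the input file count), so
        # total_outputs = 3 and at most min(threads, 3) workers are useful.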
        max_inputs_per_worker = len(self.input_fast5s)
        total_outputs = 0
        for read_set in self.read_sets.values():
            outputs = int(ceil(len(read_set) / float(self.batch_size)))
            total_outputs += min(outputs, max_inputs_per_worker)

        return min(self.threads, total_outputs)

    def parse_summary_demultiplex(self) -> Dict[str, Set[str]]:
        """
        Open the summary TSV file and parse the read_id and demultiplex columns into a dict {demultiplex: read_id_set}
        :return:
        """
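        # The summary file is expected to look roughly like this (column order
        # and values are illustrative; only the read_id and demultiplex
        # columns are used):
        #
        #   read_id     barcode_arrangement   ...
        #   read-0001   barcode01             ...
        #   read-0002   barcode02             ...
        #
        # which would yield {"barcode01": {"read-0001"}, "barcode02": {"read-0002"}}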
        read_sets = defaultdict(set)
        with open(str(self.summary), "r") as fh:
            read_list_tsv = reader(fh, delimiter="\t")
            header = next(read_list_tsv)

            if self.read_id_column in header:
                read_id_col_idx = header.index(self.read_id_column)
            else:
                raise ValueError(
                    "No '{}' read_id column in header: {}".format(
                        self.read_id_column, header
                    )
                )

            if self.demultiplex_column in header:
                demultiplex_col_idx = header.index(self.demultiplex_column)
            else:
                raise ValueError(
                    "No '{}' demultiplex column in header: {}".format(
                        self.demultiplex_column, header
                    )
                )

            for line in read_list_tsv:
                read_id = line[read_id_col_idx]
                demux = line[demultiplex_col_idx]
                read_sets[demux].add(read_id)

        return read_sets


def create_arg_parser():
    parser = ArgumentParser(
        description="Tool for binning reads from a multi_read_fast5_file by a column value in a summary file"
    )
    parser.add_argument(
        "-i",
        "--input",
        required=True,
        type=Path,
        help="Path to Fast5 file or directory of Fast5 files",
    )
    parser.add_argument(
        "-s",
        "--save_path",
        required=True,
        type=Path,
        help="Directory to output MultiRead subset to",
    )
    parser.add_argument(
        "-l",
        "--summary_file",
        required=True,
        type=Path,
        help="TSV file containing read_id column (sequencing_summary.txt file)",
    )
    parser.add_argument(
        "-f",
        "--filename_base",
        default=FILENAME_BASE,
        required=False,
        help="Root of output filename, default='{}' -> '{}0.fast5'".format(
            FILENAME_BASE, FILENAME_BASE
        ),
    )
    parser.add_argument(
        "-n",
        "--batch_size",
        type=int,
        default=READS_PER_FILE,
        required=False,
        help="Number of reads per multi-read file (default {})".format(READS_PER_FILE),
    )
    parser.add_argument(
        "-t",
        "--threads",
        type=int,
        default=1,
        required=False,
        help="Maximum number of parallel processes to use (default 1)",
    )
    parser.add_argument(
        "-r",
        "--recursive",
        action="store_true",
        required=False,
        default=False,
        help="Flag to search recursively through input directory for MultiRead fast5 files",
    )
    parser.add_argument(
        "--ignore_symlinks",
        action="store_true",
        help="Ignore symlinks when searching recursively for fast5 files",
    )
    parser.add_argument(
        "-c",
        "--compression",
        required=False,
        default=None,
        choices=list(COMPRESSION_MAP.keys()) + [None],
        help="Target output compression type. If omitted, the compression type is not changed",
    )
    parser.add_argument(
        "--demultiplex_column",
        type=str,
        default=DEMULTIPLEX_COLUMN,
        required=False,
        help="Name of the column used for demultiplexing in the summary file (default '{}')".format(
            DEMULTIPLEX_COLUMN
        ),
    )
    parser.add_argument(
        "--read_id_column",
        type=str,
        default=READ_ID_COLUMN,
        required=False,
        help="Name of the read_id column in the summary file (default '{}')".format(
            READ_ID_COLUMN
        ),
    )
    return parser


def main():
    parser = create_arg_parser()
    args = parser.parse_args()
    if args.compression is not None:
        args.compression = COMPRESSION_MAP[args.compression]

    demux = Fast5Demux(
        input_dir=args.input,
        output_dir=args.save_path,
        summary_file=args.summary_file,
        demultiplex_column=args.demultiplex_column,
        read_id_column=args.read_id_column,
        filename_base=args.filename_base,
        batch_size=args.batch_size,
        threads=args.threads,
        recursive=args.recursive,
        follow_symlinks=not args.ignore_symlinks,
        target_compression=args.compression,
    )
    demux.run_batch()
    demux.report()


if __name__ == "__main__":
    main()