''' Utilities for dealing with SAM files. '''
2
3import re
4import subprocess
5import os
6import sys
7import shutil
8import logging
9import heapq
10import contextlib
11import copy
12import functools
13
14from collections import Counter
15from itertools import chain
16from pathlib import Path
17
18import pysam
19import numpy as np
20
21from . import utilities
22from . import external_sort
23from . import fastq
24from . import fasta
25from . import mapping_tools
26from . import interval
27
28BAM_CMATCH = 0     # M
29BAM_CINS = 1       # I
30BAM_CDEL = 2       # D
31BAM_CREF_SKIP = 3  # N
32BAM_CSOFT_CLIP = 4 # S
33BAM_CHARD_CLIP = 5 # H
34BAM_CPAD = 6       # P
35BAM_CEQUAL = 7     # =
36BAM_CDIFF = 8      # X
37
38op_to_char = {
39    BAM_CMATCH:     'M',
40    BAM_CINS:       'I',
41    BAM_CDEL:       'D',
42    BAM_CREF_SKIP:  'N',
43    BAM_CSOFT_CLIP: 'S',
44    BAM_CHARD_CLIP: 'H',
45    BAM_CPAD:       'P',
46    BAM_CEQUAL:     '=',
47    BAM_CDIFF:      'X',
48}
# We want to be able to look blocks up with either int or char keys, so make
# every relevant char map to itself.
for v in list(op_to_char.values()):
    op_to_char[v] = v
53
54read_consuming_ops = {
55    BAM_CMATCH,
56    BAM_CINS,
57    BAM_CSOFT_CLIP,
58    BAM_CEQUAL,
59    BAM_CDIFF,
60}
61
62ref_consuming_ops = {
63    BAM_CMATCH,
64    BAM_CDEL,
65    BAM_CEQUAL,
66    BAM_CDIFF,
67    BAM_CREF_SKIP,
68}
69
70_unmapped_template = '{0}\t4\t*\t0\t0\t*\t*\t0\t0\t*\t*\n'.format
71
72def get_strand(mapping):
73    if mapping.is_reverse:
74        strand = '-'
75    else:
76        strand = '+'
77    return strand
78
79def get_original_seq(mapping):
80    if mapping.is_reverse:
81        original_seq = utilities.reverse_complement(mapping.query_sequence)
82    else:
83        original_seq = mapping.query_sequence
84    return original_seq
85
86def get_original_qual(mapping):
87    if mapping.is_reverse:
88        original_qual = mapping.query_qualities[::-1]
89    else:
90        original_qual = mapping.query_qualities
91    return original_qual
92
93def unmapped_aligned_read(qname):
94    aligned_read = pysam.AlignedRead()
95    aligned_read.qname = qname
96    aligned_read.flag = 0x4
97    aligned_read.rname = -1
98    aligned_read.pos = -1
99    aligned_read.mapq = 0
100    aligned_read.cigar = None
101    aligned_read.rnext = -1
102    aligned_read.pnext = -1
103    aligned_read.tlen = 0
104    aligned_read.seq = '*'
105    aligned_read.qual = '*'
106    return aligned_read
107
108def splice_in_name(line, new_name):
109    return '\t'.join([new_name] + line.split('\t')[1:])
110
111cigar_block = re.compile(r'(\d+)([MIDNSHP=X])')
112
113def cigar_string_to_blocks(cigar_string):
    """ Decomposes a CIGAR string into a list of (length, operation character) tuples. """
115    return [(int(l), k) for l, k in cigar_block.findall(cigar_string)]
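# A minimal usage sketch (CIGAR value chosen for illustration):
#   cigar_string_to_blocks('3S10M2D5M')  ->  [(3, 'S'), (10, 'M'), (2, 'D'), (5, 'M')]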
116
117def total_reference_nucs(cigar):
118    return sum(length for op, length in cigar if op in ref_consuming_ops)
119
120def total_reference_nucs_except_splicing(cigar):
121    return sum(length for op, length in cigar if op in ref_consuming_ops and op != BAM_CREF_SKIP)
122
123def total_read_nucs(cigar):
124    return sum(length for op, length in cigar if op in read_consuming_ops)
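# A sketch with pysam-style (op, length) blocks (lengths chosen for illustration):
#   blocks = [(BAM_CSOFT_CLIP, 3), (BAM_CMATCH, 10), (BAM_CDEL, 2), (BAM_CMATCH, 5)]
#   total_reference_nucs(blocks)  ->  17   (10 M + 2 D + 5 M)
#   total_read_nucs(blocks)       ->  18   (3 S + 10 M + 5 M)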
125
126def contains_indel(parsed_line):
127    cigar_blocks = cigar_string_to_blocks(parsed_line['CIGAR'])
128    kinds = [k for l, k in cigar_blocks]
129    return ('I' in kinds or 'D' in kinds)
130
131def contains_indel_pysam(read):
132    kinds = [k for k, l in read.cigar]
133    return (BAM_CINS in kinds or BAM_CDEL in kinds)
134
135def indel_distance_from_edge(cigar):
136    indel_indices = [i for i, (k, l) in enumerate(cigar) if k == BAM_CINS or k == BAM_CDEL]
137    first_indel_index = min(indel_indices)
138    ref_nucs_before = total_reference_nucs(cigar[:first_indel_index])
139    last_indel_index = max(indel_indices)
140    ref_nucs_after = total_reference_nucs(cigar[last_indel_index + 1:])
141    return min(ref_nucs_before, ref_nucs_after)
142
143def contains_splicing(read):
144    return any(k == BAM_CREF_SKIP for k, l in read.cigar)
145
146def contains_soft_clipping(parsed_line):
147    cigar_blocks = cigar_string_to_blocks(parsed_line['CIGAR'])
    kinds = [k for l, k in cigar_blocks]
149    return ('S' in kinds)
150
151def contains_soft_clipping_pysam(read):
152    kinds = [k for k, l in read.cigar]
153    return (BAM_CSOFT_CLIP in kinds)
154
155def get_soft_clipped_block(alignment, edge):
156    strand = get_strand(alignment)
157
158    if (edge == 5 and strand == '+') or (edge == 3 and strand == '-'):
159        op, length = alignment.cigar[0]
160        if op == BAM_CSOFT_CLIP:
161            sl = slice(None, length)
162        else:
163            sl = slice(0)
164    elif (edge == 5 and strand == '-') or (edge == 3 and strand == '+'):
165        op, length = alignment.cigar[-1]
166        if op == BAM_CSOFT_CLIP:
167            sl = slice(-length, None)
168        else:
169            sl = slice(0)
170
171    seq = alignment.seq[sl]
172    qual = alignment.query_qualities[sl]
173
174    return seq, qual
175
176def get_max_soft_clipped_length(alignment):
177    max_soft_clipped = 0
178    for edge in [5, 3]:
179        seq, qual = get_soft_clipped_block(alignment, edge)
180        max_soft_clipped = max(max_soft_clipped, len(seq))
181    return max_soft_clipped
182
183def cigar_blocks_to_string(cigar_blocks):
184    ''' Builds a CIGAR string out of a corresponding list of operations. '''
185    string = ['{0}{1}'.format(length, op_to_char[op])
186              for op, length in cigar_blocks
187             ]
188    return ''.join(string)
189
190def alignment_to_cigar_blocks(ref_aligned, read_aligned):
191    """ Builds a list of CIGAR operations from an alignment. """
192    expanded_sequence = []
193    for ref_char, read_char in zip(ref_aligned, read_aligned):
194        if ref_char == '-':
195            expanded_sequence.append('I')
196        elif read_char == '-':
197            expanded_sequence.append('D')
198        elif ref_char == read_char:
199            #expanded_sequence.append('=')
200            expanded_sequence.append('M')
201        else:
202            #expanded_sequence.append('X')
203            expanded_sequence.append('M')
204    sequence, counts = utilities.decompose_homopolymer_sequence(expanded_sequence)
    # Return pysam-style (op, length) blocks so cigar_blocks_to_string can consume them.
    return [(char, count) for char, count in zip(sequence, counts)]
206
207def aligned_pairs_to_cigar(aligned_pairs, guide=None):
208    op_sequence = []
209    for read, ref in aligned_pairs:
210        if read == None or read == '-':
211            op_sequence.append(BAM_CDEL)
212        elif read == 's':
213            op_sequence.append(BAM_CREF_SKIP)
214        elif ref == None or ref == '-':
215            op_sequence.append(BAM_CINS)
216        elif ref == 'S':
217            op_sequence.append(BAM_CSOFT_CLIP)
218        else:
219            op_sequence.append(BAM_CMATCH)
220
221    cigar = [(op, len(times)) for op, times in utilities.group_by(op_sequence)]
222
223    if guide:
224        guide_cigar, from_side = guide
225
226        if from_side == 'right':
227            cigar = cigar[::-1]
228            guide_cigar = guide_cigar[::-1]
229
230        for i in range(min(len(cigar), len(guide_cigar))):
231            op, length = cigar[i]
232            guide_op, guide_length = guide_cigar[i]
233            cigar[i] = (guide_op, length)
234
235        if from_side == 'right':
236            cigar = cigar[::-1]
237            guide_cigar = guide_cigar[::-1]
238
239    return cigar
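# A minimal sketch (positions chosen for illustration): a (None, ref) pair becomes a deletion.
#   pairs = [(0, 100), (1, 101), (None, 102), (2, 103)]
#   aligned_pairs_to_cigar(pairs)  ->  [(BAM_CMATCH, 2), (BAM_CDEL, 1), (BAM_CMATCH, 1)]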
240
241def cigar_to_aligned_pairs(cigar, start):
242    aligned_pairs = []
243
244    ref_pos = start
245    read_pos = 0
246    for op, length in cigar:
247        if op == BAM_CMATCH or op == BAM_CEQUAL or op == BAM_CDIFF:
248            for i in range(length):
249                aligned_pairs.append((read_pos, ref_pos))
250
251                ref_pos += 1
252                read_pos += 1
253
254        elif op == BAM_CDEL:
255            # Deletion results in gap in read
256            for i in range(length):
257                aligned_pairs.append((None, ref_pos))
258
259                ref_pos += 1
260
261        elif op == BAM_CREF_SKIP:
262            # Skip results in gap in read
263            for i in range(length):
264                aligned_pairs.append(('s', ref_pos))
265
266                ref_pos += 1
267
268        elif op == BAM_CINS:
269            # Insertion results in gap in ref
270            for i in range(length):
271                aligned_pairs.append((read_pos, None))
272
273                read_pos += 1
274
        elif op == BAM_CSOFT_CLIP:
276            # Soft-clipping results in gap in ref
277            for i in range(length):
278                aligned_pairs.append((read_pos, 'S'))
279
280                read_pos += 1
281
282        else:
283            raise ValueError('Unsupported op', cigar)
284
285    return aligned_pairs
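# A minimal sketch (start position chosen for illustration): soft-clipped query bases pair
# with 'S', inserted query bases pair with None.
#   cigar = [(BAM_CSOFT_CLIP, 1), (BAM_CMATCH, 2), (BAM_CINS, 1), (BAM_CMATCH, 1)]
#   cigar_to_aligned_pairs(cigar, 10)  ->  [(0, 'S'), (1, 10), (2, 11), (3, None), (4, 12)]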
286
287def cigar_to_aligned_pairs_backwards(cigar, end, read_length):
288    aligned_pairs = []
289
290    ref_pos = end
291    read_pos = read_length - 1
292    for op, length in cigar[::-1]:
293        if op == BAM_CMATCH or op == BAM_CEQUAL or op == BAM_CDIFF:
294            for i in range(length):
295                aligned_pairs.append((read_pos, ref_pos))
296
297                ref_pos -= 1
298                read_pos -= 1
299
300        elif op == BAM_CDEL or op == BAM_CREF_SKIP:
301            # Deletion results in gap in read
302            for i in range(length):
303                aligned_pairs.append((None, ref_pos))
304
305                ref_pos -= 1
306
307        elif op == BAM_CINS:
308            # Insertion results in gap in ref
309            for i in range(length):
310                aligned_pairs.append((read_pos, None))
311
312                read_pos -= 1
313
314        elif op == BAM_CSOFT_CLIP:
315            # Soft-clipping results in gap in ref
316            for i in range(length):
317                aligned_pairs.append((read_pos, 'S'))
318
319                read_pos -= 1
320
321        else:
322            raise ValueError('Unsupported op', cigar)
323
324    return aligned_pairs
325
326def truncate_cigar_blocks_up_to(cigar_blocks, truncated_length):
    ''' Given pysam-style cigar_blocks, truncates the blocks so that they
        account for truncated_length read bases.
    '''
330    bases_so_far = 0
331    truncated_blocks = []
332
333    for operation, length in cigar_blocks:
        # Once truncated_length read bases are accounted for, stop before the next
        # read-consuming block, but still include any non-consuming blocks.
335        if bases_so_far == truncated_length and operation in read_consuming_ops:
336            break
337
338        if operation in read_consuming_ops:
339            length_to_use = min(truncated_length - bases_so_far, length)
340            bases_so_far += length_to_use
341        else:
342            length_to_use = length
343
344        truncated_blocks.append((operation, length_to_use))
345
346        # If we didn't use the whole block, need to break because the next loop
347        # will try to use the next block if it is non-consuming.
348        if length_to_use < length:
349            break
350
351    return truncated_blocks
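# A sketch (lengths chosen for illustration): non-read-consuming blocks that follow the
# truncation point are kept.
#   blocks = [(BAM_CMATCH, 5), (BAM_CDEL, 2), (BAM_CMATCH, 3)]
#   truncate_cigar_blocks_up_to(blocks, 5)  ->  [(BAM_CMATCH, 5), (BAM_CDEL, 2)]
#   truncate_cigar_blocks_up_to(blocks, 7)  ->  [(BAM_CMATCH, 5), (BAM_CDEL, 2), (BAM_CMATCH, 2)]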
352
353def truncate_cigar_blocks_from_beginning(cigar_blocks, truncated_length):
354    ''' Removes cigar operations from the beginning of cigar_blocks so that
355    truncated_length total read bases remain accounted for.
356    '''
357    flipped_truncated_blocks = truncate_cigar_blocks_up_to(cigar_blocks[::-1],
358                                                           truncated_length,
359                                                          )
360    truncated_blocks = flipped_truncated_blocks[::-1]
361    return truncated_blocks
362
363def collapse_cigar_blocks(cigar_blocks):
364    collapsed = []
365    for kind, blocks in utilities.group_by(cigar_blocks, lambda t: t[0]):
366        collapsed.append((kind, sum(t[1] for t in blocks)))
367    return collapsed
368
369def alignment_to_cigar_string(ref_aligned, read_aligned):
370    """ Builds a CIGAR string from an alignment. """
371    cigar_blocks = alignment_to_cigar_blocks(ref_aligned, read_aligned)
372    return cigar_blocks_to_string(cigar_blocks)
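# A minimal sketch (sequences chosen for illustration): a gap in the ref is an insertion
# in the read, and a gap in the read is a deletion.
#   alignment_to_cigar_string('ACGT-TT', 'ACGTATT')  ->  '4M1I2M'
#   alignment_to_cigar_string('ACGTATT', 'ACGT-TT')  ->  '4M1D2M'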
373
374md_number = re.compile(r'[0-9]+')
375md_text = re.compile(r'[A-Z]+')
376
377def md_string_to_ops_string(md_string):
378    ''' Converts an MD string into a list of operations for supplying reference
379        characters, either '=' if equal to the read, or any other char if equal
380        to that char.
381    '''
    # In the presence of a CIGAR string, the '^' characters are redundant, so strip
    # them out. (94 is the code point of '^'.)
    md_string = md_string.translate({94: None})
385
386    match_lengths = [int(s) for s in re.findall(md_number, md_string)]
387    text_blocks = re.findall(md_text, md_string)
388
389    # The standard calls for a number to start and end, zero if necessary,
390    # so after removing the initial number, there must be the same number of
391    # match_lengths and text_blocks.
392    if len(text_blocks) != len(match_lengths) - 1:
393        raise ValueError(md_string)
394
395    ops_string = '='*match_lengths[0]
396    for text_block, match_length in zip(text_blocks, match_lengths[1:]):
397        ops_string += text_block
398        ops_string += '='*match_length
399
400    return ops_string
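# A sketch (MD value chosen for illustration): '^' markers are dropped, so deleted
# reference bases are emitted the same way as mismatched ones.
#   md_string_to_ops_string('10A5^AC6')  ->  '=' * 10 + 'A' + '=' * 5 + 'AC' + '=' * 6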
401
402md_item_pattern = re.compile(r'[0-9]+|[TCAGN^]+')
403
404def int_if_possible(string):
405    try:
406        int_value = int(string)
407        return int_value
408    except ValueError:
409        return string
410
411def md_string_to_items(md_string):
412    items = [int_if_possible(item) for item in md_item_pattern.findall(md_string)]
413    return items
414
415def md_items_to_md_string(items):
416    string = ''.join(str(item) for item in items)
417    return string
418
419def reverse_md_items(items):
420    reversed_items = []
421    for item in items[::-1]:
422        if isinstance(item, int):
423            reversed_items.append(item)
424        else:
425            if item.startswith('^'):
426                reversed_items.append('^' + item[:0:-1])
427            else:
428                reversed_items.append(item[::-1])
429    return reversed_items
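# A minimal sketch (MD value chosen for illustration):
#   md_string_to_items('10A5^AC6')            ->  [10, 'A', 5, '^AC', 6]
#   reverse_md_items([10, 'A', 5, '^AC', 6])  ->  [6, '^CA', 5, 'A', 10]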
430
431def truncate_md_items(md_items, truncated_length):
432    if truncated_length == 0:
433        truncated_items = [0]
434    else:
435        bases_so_far = 0
436        truncated_items = []
437
438        for item in md_items:
439            if bases_so_far == truncated_length:
440                break
441
442            if isinstance(item, int):
443                length_to_use = min(truncated_length - bases_so_far, item)
444                bases_so_far += length_to_use
445                truncated_items.append(length_to_use)
446            else:
447                if item.startswith('^'):
448                    truncated_item = ['^']
449                    item = item[1:]
450                else:
451                    truncated_item = []
452
453                for c in item:
454                    if c == '^':
455                        raise ValueError
456
457                    if bases_so_far == truncated_length:
458                        break
459
460                    truncated_item.append(c)
461                    bases_so_far += 1
462
463                truncated_items.append(''.join(truncated_item))
464
465        # Ensure that it starts and ends with a number
466        if not isinstance(truncated_items[0], int):
467            truncated_items = [0] + truncated_items
468        if not isinstance(truncated_items[-1], int):
469            truncated_items = truncated_items + [0]
470
471    return truncated_items
472
473def combine_md_strings(first_string, second_string):
474    if first_string == '':
475        return second_string
476    if second_string == '':
477        return first_string
478
479    first_items = md_string_to_items(first_string)
480    second_items = md_string_to_items(second_string)
481    before = first_items[:-1]
482    after = second_items[1:]
483
484    if isinstance(first_items[-1], int):
485        if isinstance(second_items[0], int):
486            interface = [first_items[-1] + second_items[0]]
487        else:
488            interface = [first_items[-1], second_items[0]]
489    else:
490        if isinstance(second_items[0], int):
491            interface = [first_items[-1], second_items[0]]
492        else:
493            if first_items[-1].startswith('^'):
494                if second_items[0].startswith('^'):
495                    interface = [first_items[-1] + second_items[0][1:]]
496                else:
497                    interface = [first_items[-1], 0, second_items[0]]
498            else:
499                if second_items[0].startswith('^'):
500                    interface = [first_items[-1], 0, second_items[0]]
501                else:
502                    interface = [first_items[-1] + second_items[0]]
503
504    combined_items = before + interface + after
505
506    combined_string = md_items_to_md_string(combined_items)
507
508    return combined_string
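# A minimal sketch (MD values chosen for illustration): adjacent match lengths at the
# junction are summed.
#   combine_md_strings('10A5', '3T7')  ->  '10A8T7'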
509
510def truncate_md_string_up_to(md_string, truncated_length):
511    ''' Truncates from the end of md_string so that the result only consumes
512    truncated_length ref characters.
513    '''
514    md_items = md_string_to_items(md_string)
515    truncated_items = truncate_md_items(md_items, truncated_length)
516    return md_items_to_md_string(truncated_items)
517
518def truncate_md_string_from_beginning(md_string, truncated_length):
519    ''' Truncates from the beginning of md_string so that the result only
520    consumes truncated_length ref characters.
521    '''
522    md_items = md_string_to_items(md_string)
523    reversed_items = reverse_md_items(md_items)
524    reversed_truncated_items = truncate_md_items(reversed_items, truncated_length)
525    truncated_items = reverse_md_items(reversed_truncated_items)
526    return md_items_to_md_string(truncated_items)
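# A sketch (values chosen for illustration): truncation counts reference-consuming
# characters only, and results are padded so they start and end with a number.
#   truncate_md_string_up_to('10A5', 11)          ->  '10A0'
#   truncate_md_string_from_beginning('10A5', 5)  ->  '5'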
527
528def produce_alignment(mapping):
    ''' Returns a list of (ref_char, read_char, qual, ref_pos, read_pos)
        tuples.
    '''
    read_seq = mapping.seq
    if read_seq is None:
        read_seq = ''

    read_quals = mapping.query_qualities
    if read_quals is None:
        read_quals = []
539
540    MD_string = dict(mapping.tags)['MD']
541
542    ref_ops = iter(md_string_to_ops_string(MD_string))
543
544    columns = []
545
546    ref_pos = mapping.pos
547    read_pos = 0
548    for op, length in mapping.cigar:
549        if op == BAM_CMATCH or op == BAM_CEQUAL or op == BAM_CDIFF:
550            for i in range(length):
551                read_char = read_seq[read_pos]
552
553                ref_op = next(ref_ops)
554                if ref_op == '=':
555                    ref_char = read_char
556                else:
557                    ref_char = ref_op
558
559                qual = read_quals[read_pos]
560
561                column = (ref_char, read_char, qual, ref_pos, read_pos)
562                columns.append(column)
563
564                ref_pos += 1
565                read_pos += 1
566
567        elif op == BAM_CDEL:
568            # Deletion results in gap in read
569            for i in range(length):
570                read_char = '-'
571                ref_char = next(ref_ops)
572                qual = 0
573
574                column = (ref_char, read_char, qual, ref_pos, read_pos)
575                columns.append(column)
576
577                ref_pos += 1
578
579        elif op == BAM_CINS:
580            # Insertion results in gap in ref
581            for i in range(length):
582                read_char = read_seq[read_pos]
583                ref_char = '-'
584                qual = read_quals[read_pos]
585                column = (ref_char, read_char, qual, ref_pos, read_pos)
586                columns.append(column)
587
588                read_pos += 1
589
590        elif op == BAM_CREF_SKIP:
591            ref_pos += length
592
593        elif op == BAM_CSOFT_CLIP:
594            read_pos += length
595
596    return columns
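# A minimal sketch of building an input for produce_alignment, assuming a hypothetical
# reference named 'chr_test' (all values chosen for illustration):
#   header = pysam.AlignmentHeader.from_references(['chr_test'], [1000])
#   al = pysam.AlignedSegment(header)
#   al.query_name = 'read1'
#   al.query_sequence = 'ACGTACGT'
#   al.query_qualities = pysam.qualitystring_to_array('IIIIIIII')
#   al.reference_id = 0
#   al.reference_start = 100
#   al.cigar = [(BAM_CMATCH, 8)]
#   al.set_tag('MD', '4G3')
#   produce_alignment(al)[4]  ->  ('G', 'A', 40, 104, 4)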
597
598def ref_dict_from_mapping(mapping):
599    ''' Build a dictionary mapping reference positions to base identities from
600    the cigar and MD tag of a mapping.
601    '''
602    alignment = produce_alignment(mapping)
603    ref_dict = {}
604    for ref_char, _, _, ref_position, _ in alignment:
605        if ref_char == '-':
606            continue
607
608        if ref_position in ref_dict:
609            # A ref_position shouldn't appear more than once
610            raise ValueError(mapping)
611
612        ref_dict[ref_position] = ref_char
613
614    return ref_dict
615
616def merge_ref_dicts(first_dict, second_dict):
617    ''' Merge dictionaries mapping reference positions to base identities. '''
618    merged_dict = {}
619    merged_dict.update(first_dict)
620    for position, base in second_dict.items():
621        if position in merged_dict:
622            if merged_dict[position] != base:
623                # contradiction
624                raise ValueError(first_dict, second_dict)
625        else:
626            merged_dict[position] = base
627
628    return merged_dict
629
630def alignment_to_MD_string(ref_aligned, read_aligned):
631    ''' Produce an MD string from an alignment. '''
    # Mark each run of matches with its length, each deleted ref base with '^<base>'
    # followed by a 0, and each mismatched ref base with the base itself.
633    MD_list = []
634    current_match_length = 0
635    for ref_char, read_char in zip(ref_aligned, read_aligned):
636        if ref_char == read_char:
637            current_match_length += 1
638        elif ref_char != '-':
639            if current_match_length > 0:
640                MD_list.append(current_match_length)
641                current_match_length = 0
642
            if read_char == '-':
                MD_list.append('^' + ref_char)
                MD_list.append(0)
645            else:
646                MD_list.append(ref_char)
647
648    if current_match_length > 0:
649        MD_list.append(current_match_length)
650
    # Remove all zeros except those separating a deletion from a following mismatch.
652    reduced_MD_list = []
653    for i in range(len(MD_list)):
654        if isinstance(MD_list[i], int):
655            if MD_list[i] > 0:
656                reduced_MD_list.append(MD_list[i])
657            elif 0 < i < len(MD_list) - 1:
658                if isinstance(MD_list[i - 1], str) and isinstance(MD_list[i + 1], str) and MD_list[i - 1][0] == '^' and MD_list[i + 1][0] != '^':
659                    reduced_MD_list.append(MD_list[i])
660        else:
661            reduced_MD_list.append(MD_list[i])
662
663    # Collapse all deletions.
664    collapsed_MD_list = [reduced_MD_list[0]]
665    for i in range(1, len(reduced_MD_list)):
666        if isinstance(collapsed_MD_list[-1], str) and collapsed_MD_list[-1][0] == '^' and \
667           isinstance(reduced_MD_list[i], str) and reduced_MD_list[i][0] == '^':
668
669            collapsed_MD_list[-1] += reduced_MD_list[i][1]
670        else:
671            collapsed_MD_list.append(reduced_MD_list[i])
672
673    # The standard calls for a number to start and to end, zero if necessary.
674    if isinstance(collapsed_MD_list[0], str):
675        collapsed_MD_list.insert(0, 0)
676    if isinstance(collapsed_MD_list[-1], str):
677        collapsed_MD_list.append(0)
678
679    MD_string = ''.join(map(str, collapsed_MD_list))
680    return MD_string
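# Sketches of expected output (sequences chosen for illustration):
#   alignment_to_MD_string('ACGTA', 'ACCTA')  ->  '2G2'
#   alignment_to_MD_string('ACGT', 'AC-T')    ->  '2^G1'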
681
682def sort_bam(input_file_name, output_file_name, by_name=False, num_threads=1):
683    output_file_name = Path(output_file_name)
684
685    samtools_command = ['samtools', 'sort']
686    if by_name:
687        samtools_command.append('-n')
688
689    # For unambiguously marking temp files from this function.
690    tail = 'SORT_TEMP'
691
692    # Clean up any temporary files left behind by previous attempts to sort.
    pattern = rf'{output_file_name.name}\.{tail}\.\d\d\d\d\.bam'
694    for fn in output_file_name.parent.iterdir():
695        if re.match(pattern, fn.name):
696            fn.unlink()
697
698    samtools_command.extend(['-@', str(num_threads),
699                             '-T', str(output_file_name) + f'.{tail}',
700                             '-o', str(output_file_name),
701                             str(input_file_name),
702                            ])
703
704    try:
705        subprocess.run(samtools_command, check=True, stderr=subprocess.PIPE)
706    except subprocess.CalledProcessError as e:
707        print(e.stderr)
708        raise
709
710    if not by_name:
711        pysam.index(str(output_file_name))
712
713def merge_sorted_bam_files(input_file_names, merged_file_name, by_name=False, make_index=True):
714    # To avoid running into max open file limits, split into groups of 500.
715    if len(input_file_names) > 500:
716        chunks = utilities.list_chunks(input_file_names, 500)
717        merged_chunk_fns = []
718        for i, chunk in enumerate(chunks):
719            merged_chunk_fn = str(merged_file_name) + '.{:04d}'.format(i)
720            merged_chunk_fns.append(merged_chunk_fn)
721            merge_sorted_bam_files(chunk, merged_chunk_fn, by_name=by_name, make_index=False)
722
        merge_sorted_bam_files(merged_chunk_fns, merged_file_name, by_name=by_name, make_index=make_index)
724
725        for merged_chunk_fn in merged_chunk_fns:
726            os.remove(merged_chunk_fn)
727
728    else:
729        input_file_names = [str(fn) for fn in input_file_names]
730        merged_file_name = str(merged_file_name)
731
732        if len(input_file_names) == 1:
733            shutil.copy(input_file_names[0], merged_file_name)
734        else:
735            merge_command = ['samtools', 'merge', '-f']
736
737            if by_name:
738                merge_command.append('-n')
739
740            merge_command.extend([merged_file_name] + input_file_names)
741
742            try:
743                subprocess.run(merge_command, check=True, stderr=subprocess.PIPE)
744            except subprocess.CalledProcessError as e:
745                print(e.stderr)
746                raise
747
748        if make_index and not by_name:
749            try:
750                pysam.index(merged_file_name)
751            except pysam.utils.SamtoolsError:
752                # Need to sort the merged file because at least one input file was missing a target.
753                temp_sorted_name = merged_file_name + '.temp_sorted'
754                sort_bam(merged_file_name, temp_sorted_name)
755                os.rename(temp_sorted_name, merged_file_name)
756                os.rename(temp_sorted_name + '.bai', merged_file_name + '.bai')
757
758def bam_to_sam(bam_file_name, sam_file_name):
759    view_command = ['samtools', 'view', '-h', '-o', sam_file_name, bam_file_name]
760    subprocess.check_call(view_command)
761
762def get_length_counts(bam_file_name, only_primary=True, only_unique=False):
763    bam_file = pysam.AlignmentFile(bam_file_name)
764    if only_unique:
765        qlen_counts = Counter(ar.qlen for ar in bam_file if ar.mapping_quality == 50)
766    elif only_primary:
767        qlen_counts = Counter(ar.qlen for ar in bam_file if not ar.is_unmapped and not ar.is_secondary)
768    else:
769        qlen_counts = Counter(ar.qlen for ar in bam_file)
770
771    return qlen_counts
772
773def get_tlen_counts(bam_file_name, only_primary=True, only_unique=False):
774    bam_file = pysam.AlignmentFile(bam_file_name)
775    if only_unique:
776        tlen_counts = Counter(ar.tlen for ar in bam_file if ar.mapping_quality == 50)
777    elif only_primary:
778        tlen_counts = Counter(ar.tlen for ar in bam_file if not ar.is_unmapped and not ar.is_secondary)
779    else:
780        tlen_counts = Counter(ar.tlen for ar in bam_file)
781
782    return tlen_counts
783
784def get_mapq_counts(bam_file_name):
785    bam_file = pysam.AlignmentFile(bam_file_name)
786    mapq_counts = Counter(ar.mapq for ar in bam_file)
787    return mapq_counts
788
789def mapping_to_Read(mapping):
790    seq = mapping.get_forward_sequence()
791    qual = fastq.encode_sanger(mapping.get_forward_qualities())
792
793    read = fastq.Read(mapping.query_name, seq, qual)
794    return read
795
796def sam_to_fastq(sam_file_name):
797    sam_file = pysam.AlignmentFile(str(sam_file_name))
798    for mapping in sam_file:
799        yield mapping_to_Read(mapping)
800
801bam_to_fastq = sam_to_fastq
802
803class AlignmentSorter(object):
804    ''' Context manager that handles writing AlignedSegments into a samtools
805    sort process.
806    '''
807    temp_prefix_tail = 'ALIGNMENTSORTER_TEMP'
808
809    def __init__(self, output_file_name, header, by_name=False):
810        self.header = header
811        self.output_file_name = Path(output_file_name)
812        self.by_name = by_name
813        self.fifo = mapping_tools.TemporaryFifo(name='unsorted_fifo.bam')
814
815    def remove_temporary_files(self):
        ''' Remove any temporary files that might have been left behind by a
        previous call with the same output_file_name.
        '''
        pattern = rf'{self.output_file_name.name}\.{AlignmentSorter.temp_prefix_tail}\.\d\d\d\d\.bam'
820        for fn in self.output_file_name.parent.iterdir():
821            if re.match(pattern, fn.name):
822                fn.unlink()
823
824    def __enter__(self):
825        self.fifo.__enter__()
826
827        self.remove_temporary_files()
828
829        sort_command = ['samtools', 'sort']
830        if self.by_name:
831            sort_command.append('-n')
832
833        self.dev_null = open(os.devnull, 'w')
834        sort_command.extend(['-T', str(self.output_file_name) + f'.{AlignmentSorter.temp_prefix_tail}',
835                             '-o', str(self.output_file_name),
836                             self.fifo.file_name,
837                            ])
838        self.sort_process = subprocess.Popen(sort_command,
839                                             stderr=subprocess.PIPE,
840                                            )
841
842        self.sam_file = pysam.AlignmentFile(self.fifo.file_name, 'wbu', header=self.header)
843
844        return self
845
846    def __exit__(self, exception_type, exception_value, exception_traceback):
847        self.sam_file.close()
848        _, err_output = self.sort_process.communicate()
849        self.dev_null.close()
850        self.fifo.__exit__(exception_type, exception_value, exception_traceback)
851
852        self.remove_temporary_files()
853
854        if self.sort_process.returncode:
855            raise RuntimeError(err_output)
856
857        if not self.by_name:
858            pysam.index(str(self.output_file_name))
859
860    def write(self, alignment):
861        self.sam_file.write(alignment)
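# A minimal usage sketch (file names hypothetical):
#   header = get_header('unsorted.bam')
#   with AlignmentSorter('by_position.bam', header) as sorter:
#       for al in pysam.AlignmentFile('unsorted.bam'):
#           sorter.write(al)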
862
863class multiple_AlignmentSorters(contextlib.ExitStack):
864    def __init__(self, header=None, by_name=False):
865        super().__init__()
866        self.sorters = {}
867        self.header = header
868        self.by_name = by_name
869
870    def __enter__(self):
871        super().__enter__()
872        for name in self.sorters:
873            self.enter_context(self.sorters[name])
874
875        return self
876
877    def __getitem__(self, key):
878        return self.sorters[key]
879
880    def __setitem__(self, name, fn_and_possibly_header):
881        if isinstance(fn_and_possibly_header, tuple):
882            fn, header = fn_and_possibly_header
883        else:
884            fn = fn_and_possibly_header
885            header = self.header
886
887        self.sorters[name] = AlignmentSorter(fn, header, self.by_name)
888
889class AlignedSegmentByName(object):
890    def __init__(self, aligned_segment):
891        self.aligned_segment = aligned_segment
892
893    def __lt__(self, other):
894        return self.aligned_segment.query_name < other.aligned_segment.query_name
895
896def merge_by_name(*mapping_iterators):
897    ''' Merges iterators over mappings that are sorted by name.
898    '''
899    wrapped_iterators = [(AlignedSegmentByName(m) for m in mappings) for mappings in mapping_iterators]
900    merged_wrapped = heapq.merge(*wrapped_iterators)
901    last_qname = None
902    for al_by_name in merged_wrapped:
903        qname = al_by_name.aligned_segment.query_name
904        if last_qname is not None and qname < last_qname:
905            print(last_qname, qname)
906            raise ValueError('Attempted to merge unsorted mapping iterators')
907
908        last_qname = qname
909        yield al_by_name.aligned_segment
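# A minimal usage sketch (file names hypothetical); both inputs must already be name-sorted:
#   merged = merge_by_name(pysam.AlignmentFile('first.by_name.bam'),
#                          pysam.AlignmentFile('second.by_name.bam'))
#   for al in merged:
#       ...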
910
911def aligned_pairs_exclude_soft_clipping(mapping):
912    cigar = mapping.cigartuples
913    aligned_pairs = mapping.aligned_pairs
914
915    first_op, first_length = cigar[0]
916
917    if first_op == BAM_CSOFT_CLIP:
918        aligned_pairs = aligned_pairs[first_length:]
919
920    if len(cigar) > 1:
921        last_op, last_length = cigar[-1]
922        if last_op == BAM_CSOFT_CLIP and last_length != 0:
923            aligned_pairs = aligned_pairs[:-last_length]
924
925    return aligned_pairs
926
927def parse_idxstats(bam_fn):
928    lines = pysam.idxstats(str(bam_fn)).splitlines()
929    fields = [line.split('\t') for line in lines]
930    parsed = {rname: int(count) for rname, _, count, _ in fields}
931    return parsed
932
933def get_num_alignments(bam_fn):
934    return sum(parse_idxstats(bam_fn).values())
935
936def collapse_soft_clip_blocks(cigar_blocks):
937    ''' If there are multiple consecutive soft clip blocks on either end,
938    collapse them into a single block.
939    '''
940    first_non_cigar_index = 0
941    while cigar_blocks[first_non_cigar_index][0] == BAM_CSOFT_CLIP:
942        first_non_cigar_index += 1
943
944    if first_non_cigar_index > 0:
945        total_length = sum(length for kind, length in cigar_blocks[:first_non_cigar_index])
946        if total_length > 0:
947            to_add = [(BAM_CSOFT_CLIP, total_length)]
948        else:
949            to_add = []
950        cigar_blocks = to_add + cigar_blocks[first_non_cigar_index:]
951
952    last_non_cigar_index = len(cigar_blocks) - 1
953    while cigar_blocks[last_non_cigar_index][0] == BAM_CSOFT_CLIP:
954        last_non_cigar_index -= 1
955
956    if last_non_cigar_index < len(cigar_blocks) - 1:
957        total_length = sum(length for kind, length in cigar_blocks[last_non_cigar_index + 1:])
958        if total_length > 0:
959            to_add = [(BAM_CSOFT_CLIP, total_length)]
960        else:
961            to_add = []
962        cigar_blocks = cigar_blocks[:last_non_cigar_index + 1] + to_add
963
964    return cigar_blocks
965
966def crop_al_to_query_int(alignment, start, end):
    ''' Replaces any parts of alignment that involve query bases outside the
    interval [start, end] with soft clipping. Query coordinates are given
    relative to the original read (and are therefore transformed if the
    alignment is reversed).
    '''
972    alignment = copy.deepcopy(alignment)
973
974    if alignment is None or alignment.is_unmapped:
975        return alignment
976
977    if end < start:
978        # query interval is empty
979        return None
980    else:
981        overlap = interval.Interval(start, end) & interval.get_covered(alignment)
982        if len(overlap) == 0:
983            return None
984
985    if alignment.is_reverse:
986        start, end = alignment.query_length - 1 - end, alignment.query_length - 1 - start
987
988    aligned_pairs = cigar_to_aligned_pairs(alignment.cigar, alignment.reference_start)
989
990    start_i = 0
991    read, ref = aligned_pairs[start_i]
992    while read is None or read == 's' or read < start:
993        start_i += 1
994        read, ref = aligned_pairs[start_i]
995
996    end_i = len(aligned_pairs) - 1
997    read, ref = aligned_pairs[end_i]
998    while read is None or read == 's' or read > end:
999        end_i -= 1
1000        read, ref = aligned_pairs[end_i]
1001
1002    if alignment.has_tag('MD'):
1003        MD = alignment.get_tag('MD')
1004
1005        total_ref_nucs = total_reference_nucs_except_splicing(alignment.cigar)
1006        removed_from_start = total_reference_nucs_except_splicing(aligned_pairs_to_cigar(aligned_pairs[:start_i]))
1007        removed_from_end = total_reference_nucs_except_splicing(aligned_pairs_to_cigar(aligned_pairs[end_i + 1:]))
1008
1009        MD = truncate_md_string_up_to(MD, total_ref_nucs - removed_from_end)
1010        MD = truncate_md_string_from_beginning(MD, total_ref_nucs - removed_from_end - removed_from_start)
1011
1012        alignment.set_tag('MD', MD)
1013
1014    restricted_pairs = aligned_pairs[start_i:end_i + 1]
1015    cigar = aligned_pairs_to_cigar(restricted_pairs)
1016
1017    before_soft = start
1018    if before_soft > 0:
1019        cigar = [(BAM_CSOFT_CLIP, before_soft)] + cigar
1020
1021    after_soft = alignment.query_length - end - 1
1022    if after_soft > 0:
1023        cigar = cigar + [(BAM_CSOFT_CLIP, after_soft)]
1024
1025    cigar = collapse_soft_clip_blocks(cigar)
1026
1027    restricted_rs = [r for q, r in restricted_pairs if r != None and r != 'S']
1028    if not restricted_rs:
1029        alignment.is_unmapped = True
1030        alignment.cigar = []
1031    else:
1032        alignment.reference_start = min(restricted_rs)
1033        alignment.cigar = cigar
1034
1035    return alignment
1036
1037def crop_al_to_ref_int(alignment, start, end):
1038    ''' Returns a copy of alignment in which any query bases that align
1039    outside the interval [start, end] are soft-clipped. If no bases are left,
1040    sets alignment.is_unmapped to true.
1041    '''
1042    alignment = copy.deepcopy(alignment)
1043
1044    if alignment.reference_start > end or alignment.reference_end - 1 < start:
1045        # alignment doesn't overlap the ref interval at all
1046        return None
1047
1048    if alignment.reference_start >= start and alignment.reference_end - 1 <= end:
1049        # alignment is entirely contained in the ref_interval
1050        return alignment
1051
1052    query_length = alignment.query_length
1053    aligned_pairs = cigar_to_aligned_pairs(alignment.cigar, alignment.reference_start)
1054
1055    start_i = 0
1056    while (
1057        aligned_pairs[start_i][1] == 'S' or
1058        aligned_pairs[start_i][0] is None or
1059        aligned_pairs[start_i][1] is None or
1060        aligned_pairs[start_i][1] < start
1061    ):
1062        start_i += 1
1063
1064    end_i = len(aligned_pairs) - 1
1065    while (
1066        aligned_pairs[end_i][1] == 'S' or
1067        aligned_pairs[end_i][0] is None or
1068        aligned_pairs[end_i][1] is None or
1069        aligned_pairs[end_i][1] > end
1070    ):
1071        end_i -= 1
1072
1073    remaining = aligned_pairs[start_i:end_i + 1]
1074    if remaining:
1075        cigar = aligned_pairs_to_cigar(remaining)
1076        before_soft = remaining[0][0]
1077        if before_soft > 0:
1078            cigar = [(BAM_CSOFT_CLIP, before_soft)] + cigar
1079
1080        after_soft = query_length - remaining[-1][0] - 1
1081        if after_soft > 0:
1082            cigar = cigar + [(BAM_CSOFT_CLIP, after_soft)]
1083
1084        alignment.cigar = cigar
1085
1086        alignment.reference_start = aligned_pairs[start_i][1]
1087    else:
1088        alignment.is_unmapped = True
1089        alignment.cigar = []
1090
1091    return alignment
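# A minimal usage sketch ('al' is a mapped pysam.AlignedSegment; coordinates chosen for
# illustration):
#   query_cropped = crop_al_to_query_int(al, 10, 59)   # soft-clip query bases outside 10..59
#   ref_cropped = crop_al_to_ref_int(al, 1000, 1999)   # soft-clip bases aligned outside ref 1000..1999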
1092
1093def disallow_query_positions_from_other(alignment, other):
1094    start, end = query_interval(alignment)
1095    other_start, other_end = query_interval(other)
1096    if other_start <= end or other_end >= start:
1097        if other_start > start and other_end < end:
1098            raise ValueError
1099        elif other_start <= start:
1100            alignment = crop_al_to_query_int(alignment, other_end + 1, alignment.query_length - 1)
1101        elif other_end >= end:
1102            alignment = crop_al_to_query_int(alignment, 0, other_start - 1)
1103
1104    return alignment
1105
1106def query_interval(alignment):
1107    start = alignment.query_alignment_start
1108    end = alignment.query_alignment_end - 1
1109
1110    if alignment.is_reverse:
1111        start, end = true_query_position(end, alignment), true_query_position(start, alignment)
1112
1113    return start, end
1114
1115def merge_multiple_adjacent_alignments(als, ref_seqs):
1116    merger = functools.partial(merge_adjacent_alignments, ref_seqs=ref_seqs)
1117    als = sorted(als, key=query_interval)
1118    return functools.reduce(merger, als)
1119
1120def merge_adjacent_alignments(first, second, ref_seqs):
    ''' If first and second are alignments to the same reference name and strand
    that are adjacent or partially overlapping on the query, returns a single
    merged alignment with an appropriately sized deletion that minimizes edit
    distance; otherwise, returns None.
    '''
1126    if first is None or second is None:
1127        return None
1128
1129    if first == second:
1130        return first
1131
1132    if first.reference_name != second.reference_name:
1133        return None
1134    else:
1135        ref_seq = ref_seqs[first.reference_name]
1136
1137    if get_strand(first) != get_strand(second):
1138        return None
1139
1140    left_query, right_query = sorted([first, second], key=query_interval)
1141    left_covered = interval.get_covered(left_query)
1142    right_covered = interval.get_covered(right_query)
1143
1144    # Ensure that the alignments point towards each other.
1145    strand = get_strand(first)
1146    if strand == '+':
1147        left_cropped = crop_al_to_query_int(left_query, 0, right_covered.start - 1)
1148        if left_cropped is None:
1149            # left alignment doesn't cover any query not covered by right
1150            return None
1151
1152        if left_cropped.reference_end > right_query.reference_start:
1153            return None
1154
1155    elif strand == '-':
1156        right_cropped = crop_al_to_query_int(right_query, left_covered.end + 1, np.inf)
1157        if right_cropped is None:
            # right alignment doesn't cover any query not covered by left
            return None
1160
1161        if right_cropped.reference_end > left_query.reference_start:
1162            return None
1163
1164    if interval.are_adjacent(left_covered, right_covered):
1165        left_cropped, right_cropped = left_query, right_query
1166
1167    elif interval.are_disjoint(left_covered, right_covered):
1168        return None
1169
1170    else:
1171        overlap = left_covered & right_covered
1172        left_ceds = cumulative_edit_distances(left_query, overlap, False, ref_seq=ref_seq)
1173        right_ceds = cumulative_edit_distances(right_query, overlap, True, ref_seq=ref_seq)
1174
1175        switch_after_edits = {
1176            overlap.start - 1 : right_ceds[overlap.start],
1177            overlap.end: left_ceds[overlap.end],
1178        }
1179
1180        for q in range(overlap.start, overlap.end):
1181            switch_after_edits[q] = left_ceds[q] + right_ceds[q + 1]
1182
1183        min_edits = min(switch_after_edits.values())
1184        best_switch_points = [s for s, d in switch_after_edits.items() if d == min_edits]
1185        switch_after = best_switch_points[0]
1186
1187        left_cropped = crop_al_to_query_int(left_query, 0, switch_after)
1188        right_cropped = crop_al_to_query_int(right_query, switch_after + 1, right_query.query_length)
1189
1190    if left_cropped is None or left_cropped.is_unmapped or right_cropped is None or right_cropped.is_unmapped:
1191        # this may not be appropriate in all circumstances
1192        return None
1193
1194    left_ref, right_ref = sorted([left_cropped, right_cropped], key=lambda al: al.reference_start)
1195
1196    if left_ref.reference_end >= right_ref.reference_start:
1197        return None
1198
1199    deletion_length = right_ref.reference_start - left_ref.reference_end
1200
1201    merged = copy.deepcopy(left_ref)
1202    merged.cigar = left_ref.cigar[:-1] + [(BAM_CDEL, deletion_length)] + right_ref.cigar[1:]
1203
1204    return merged
1205
1206def cumulative_edit_distances(mapping, query_interval, from_end, ref_seq=None):
    ''' Returns a dictionary of how many cumulative edits are involved in mapping
    from the beginning (or end, if from_end is True) of query_interval
    to each query position in query_interval.
    '''
1211    tuples = aligned_tuples(mapping, ref_seq=ref_seq)
1212
1213    if get_strand(mapping) == '-':
1214        tuples = tuples[::-1]
1215
1216    beginning = [i for i, (q, read, r, ref, qual) in enumerate(tuples) if q == query_interval.start][0]
1217    end = [i for i, (q, read, r, ref, qual) in enumerate(tuples) if q == query_interval.end][-1]
1218
1219    relevant = tuples[beginning:end + 1]
1220    if from_end:
1221        relevant = relevant[::-1]
1222
1223    c_e_ds = {}
1224
1225    total = 0
1226
1227    for q, read_b, r, ref_b, qual in relevant:
1228        if read_b != ref_b:
1229            total += 1
1230
1231        if q is not None:
1232            c_e_ds[q] = total
1233
1234    return c_e_ds
1235
1236def find_best_query_switch_after(left_al, right_al, left_ref_seq, right_ref_seq, tie_break):
1237    ''' If left_al and right_al overlap on the query, find the query position such that switching from
1238    left_al to right_al after that position minimizes the total number of edits.
1239    '''
1240    left_covered = interval.get_covered(left_al)
1241    right_covered = interval.get_covered(right_al)
1242    overlap = left_covered & right_covered
1243
1244    if left_al is None:
1245        if right_al is None:
1246            raise ValueError
1247        gap_interval = interval.Interval(0, right_covered.start - 1)
1248    elif right_al is None:
1249        if left_al is None:
1250            raise ValueError
1251        gap_interval = interval.Interval(left_covered.end + 1, len(left_al.seq) - 1)
1252    else:
1253        gap_interval = interval.Interval(left_covered.end + 1, right_covered.start - 1)
1254
1255    if overlap:
1256        left_ceds = cumulative_edit_distances(left_al, overlap, False, ref_seq=left_ref_seq)
1257        right_ceds = cumulative_edit_distances(right_al, overlap, True, ref_seq=right_ref_seq)
1258
1259        switch_after_edits = {
1260            overlap.start - 1 : right_ceds[overlap.start],
1261            overlap.end: left_ceds[overlap.end],
1262        }
1263
1264        for q in range(overlap.start, overlap.end):
1265            switch_after_edits[q] = left_ceds[q] + right_ceds[q + 1]
1266
1267        min_edits = min(switch_after_edits.values())
1268        best_switch_points = [s for s, d in switch_after_edits.items() if d == min_edits]
1269        switch_after = tie_break(best_switch_points)
1270    else:
1271        min_edits = 0
1272        switch_after = left_covered.end
1273
1274    if gap_interval.is_empty:
1275        gap_length = 0
1276        gap_interval = None
1277    else:
1278        gap_length = len(gap_interval)
1279
1280    results = {
1281        'switch_after': switch_after,
1282        'min_edits': min_edits,
1283        'gap_interval': gap_interval,
1284        'gap_length': gap_length,
1285    }
1286
1287    return results
1288
1289def true_query_position(p, alignment):
1290    if alignment.is_reverse:
1291        p = alignment.query_length - 1 - p
1292    return p
1293
1294def closest_query_position(r, alignment, which_side='either'):
1295    ''' Return query position paired with r (or with the closest to r) '''
1296    r_to_q = {r: true_query_position(q, alignment)
1297              for q, r in alignment.aligned_pairs
1298              if r is not None and q is not None
1299             }
1300    if r in r_to_q:
1301        q = r_to_q[r]
1302    else:
1303        if which_side == 'either':
1304            rs = r_to_q
1305        elif which_side == 'before':
1306            rs = (other_r for other_r in r_to_q if other_r < r)
1307        elif which_side == 'after':
1308            rs = (other_r for other_r in r_to_q if other_r > r)
1309
1310        closest_r = min(rs, key=lambda other_r: abs(other_r - r))
1311        q = r_to_q[closest_r]
1312
1313    return q
1314
1315def closest_ref_position(q, alignment, which_side='either'):
1316    ''' Return ref position paired with q (or with the closest to q) '''
1317    q_to_r = {true_query_position(q, alignment): r
1318              for q, r in alignment.aligned_pairs
1319              if r is not None and q is not None
1320             }
1321
1322    if q in q_to_r:
1323        r = q_to_r[q]
1324    else:
1325        if which_side == 'either':
1326            qs = q_to_r
1327        elif which_side == 'before':
1328            qs = [other_q for other_q in q_to_r if other_q < q]
1329        elif which_side == 'after':
1330            qs = [other_q for other_q in q_to_r if other_q > q]
1331
1332        if len(qs) == 0:
1333            r = None
1334        else:
1335            closest_q = min(qs, key=lambda other_q: abs(other_q - q))
1336            r = q_to_r[closest_q]
1337
1338    return r
1339
1340def max_block_length(alignment, block_types):
1341    if alignment is None or alignment.is_unmapped:
1342        return 0
1343    else:
1344        block_lengths = [l for k, l in alignment.cigar if k in block_types]
1345        if len(block_lengths) == 0:
1346            return 0
1347        else:
1348            return max(block_lengths)
1349
1350def total_indel_lengths(alignment):
1351    if alignment.is_unmapped:
1352        return 0
1353    else:
1354        return sum(l for k, l in alignment.cigar if k == BAM_CDEL or k == BAM_CINS)
1355
1356def get_ref_pos_to_block(alignment):
1357    ref_pos_to_block = {}
1358    ref_pos = alignment.reference_start
1359    for kind, length in alignment.cigar:
1360        if kind in ref_consuming_ops:
1361            starts_at = ref_pos
1362
1363            for r in range(ref_pos, ref_pos + length):
1364                ref_pos_to_block[r] = (kind, length, starts_at)
1365
1366            ref_pos += length
1367
1368    return ref_pos_to_block
1369
1370def split_at_first_large_insertion(alignment, min_length):
1371    q = 0
1372
1373    # Using cigar blocks, march from beginning of the read to the (possible)
1374    # insertion point to determine the query interval to crop to.
1375    # If the alignment is reversed, alignment.cigar is reversed relative to
1376    # true query positions.
1377    cigar = alignment.cigar
1378    if alignment.is_reverse:
1379        cigar = cigar[::-1]
1380
1381    for kind, length in cigar:
1382        if kind == BAM_CINS and length >= min_length:
1383            before = crop_al_to_query_int(alignment, 0, q - 1)
1384            after = crop_al_to_query_int(alignment, q + length, alignment.query_length)
1385            return [before, after]
1386        else:
1387            if kind in read_consuming_ops:
1388                q += length
1389
1390    return [alignment]
1391
1392def split_at_large_insertions(alignment, min_length):
    ''' Split alignment at every insertion of at least min_length.
    Note: O(n^2) behavior can be slow for long (e.g. PacBio) alignments.
    '''
1394    all_split = []
1395
1396    to_split = [alignment]
1397
1398    while len(to_split) > 0:
1399        candidate = to_split.pop()
1400        split = split_at_first_large_insertion(candidate, min_length)
1401        if len(split) > 1:
1402            to_split.extend(split)
1403        else:
1404            all_split.extend(split)
1405
1406    return all_split
1407
1408def split_at_deletions(alignment, min_length, exempt_if_overlaps=None):
    ''' Split alignment at deletions of length at least min_length that don't overlap exempt_if_overlaps. '''
1410
1411    ref_start = alignment.reference_start
1412    query_bases_before = 0
1413    query_bases_after = alignment.query_length
1414
1415    tuples = []
1416
1417    split_at = []
1418
1419    for i, (kind, length) in enumerate(alignment.cigar):
1420        if kind == BAM_CDEL:
1421            del_interval = interval.Interval(ref_start, ref_start + length - 1)
1422
1423            if exempt_if_overlaps is None:
1424                overlaps = False
1425            else:
1426                overlaps = len(del_interval & exempt_if_overlaps) > 0
1427
1428            if length >= min_length and not overlaps:
1429                if i != 0 and i != len(alignment.cigar) - 1:
1430                    split_at.append(i)
1431
1432        if kind in read_consuming_ops:
1433            read_consumed = length
1434        else:
1435            read_consumed = 0
1436
1437        if kind in ref_consuming_ops:
1438            ref_consumed = length
1439        else:
1440            ref_consumed = 0
1441
1442        query_bases_after -= read_consumed
1443        tuples.append((query_bases_before, query_bases_after, ref_start))
1444
1445        query_bases_before += read_consumed
1446        ref_start += ref_consumed
1447
1448    split_alignments = []
1449
1450    if len(split_at) == 0:
1451        split_alignments = [alignment]
1452    else:
1453        split_at = [-1] + split_at + [len(alignment.cigar)]
1454
1455        for i in range(len(split_at) - 1):
1456            query_bases_before, _, ref_start = tuples[split_at[i] + 1]
1457            if split_at[i + 1] == len(alignment.cigar):
1458                query_bases_after = 0
1459            else:
1460                _, query_bases_after, _ = tuples[split_at[i + 1]]
1461
1462            new_cigar = alignment.cigar[split_at[i] + 1:split_at[i + 1]]
1463
1464            if query_bases_before > 0:
1465                new_cigar = [(BAM_CSOFT_CLIP, query_bases_before)] + new_cigar
1466
1467            if query_bases_after > 0:
1468                new_cigar = new_cigar + [(BAM_CSOFT_CLIP, query_bases_after)]
1469
1470            split_al = copy.deepcopy(alignment)
1471            split_al.cigar = new_cigar
1472            split_al.reference_start = ref_start
1473
1474            split_alignments.append(split_al)
1475
1476    split_alignments = [soft_clip_terminal_insertions(al) for al in split_alignments]
1477
1478    return split_alignments
1479
1480def soft_clip_terminal_insertions(al):
1481    ''' If al starts or ends with insertions, convert the relevant bases into soft-clipping. '''
1482    initial_cigar = al.cigar
1483    clip_lengths = {}
1484
1485    if initial_cigar[0][0] == BAM_CSOFT_CLIP:
1486        clip_lengths['beginning'] = initial_cigar[0][1]
1487    else:
1488        clip_lengths['beginning'] = 0
1489
1490    if len(initial_cigar) > 1 and initial_cigar[-1][0] == BAM_CSOFT_CLIP:
1491        clip_lengths['end'] = initial_cigar[-1][1]
1492    else:
1493        clip_lengths['end'] = 0
1494
1495    non_soft_clipped_blocks = [(kind, length) for kind, length in al.cigar if kind != BAM_CSOFT_CLIP]
1496
1497    if len(non_soft_clipped_blocks) == 0:
1498        return al
1499
1500    had_terminal_insertion = False
1501
1502    first_kind, first_length = non_soft_clipped_blocks[0]
1503
1504    if first_kind == BAM_CINS:
1505        had_terminal_insertion = True
1506        clip_lengths['beginning'] += first_length
1507
1508        non_soft_clipped_blocks = non_soft_clipped_blocks[1:]
1509
1510    if len(non_soft_clipped_blocks) > 0:
1511        last_kind, last_length = non_soft_clipped_blocks[-1]
1512
1513        if last_kind == BAM_CINS:
1514            had_terminal_insertion = True
1515            clip_lengths['end'] += last_length
1516
1517            non_soft_clipped_blocks = non_soft_clipped_blocks[:-1]
1518
1519    if had_terminal_insertion:
1520        new_cigar = non_soft_clipped_blocks
1521
1522        if clip_lengths['beginning'] > 0:
1523            new_cigar = [(BAM_CSOFT_CLIP, clip_lengths['beginning'])] + new_cigar
1524
1525        if clip_lengths['end'] > 0:
1526            new_cigar = new_cigar + [(BAM_CSOFT_CLIP, clip_lengths['end'])]
1527
1528        new_al = copy.deepcopy(al)
1529        new_al.cigar = new_cigar
1530    else:
1531        new_al = al
1532
1533    return new_al
1534
1535def grouped_by_name(als):
1536    if isinstance(als, (str, Path)):
1537        als = pysam.AlignmentFile(als)
1538
1539    grouped = utilities.group_by(als, lambda al: al.query_name)
1540
1541    return grouped
1542
1543def header_from_STAR_index(index):
1544    index = Path(index)
1545    names = [l.strip() for l in (index / 'chrName.txt').open()]
1546    lengths = [int(l.strip()) for l in (index / 'chrLength.txt').open()]
1547    header = pysam.AlignmentHeader.from_references(names, lengths)
1548    return header
1549
1550def header_from_fasta(fasta_fn):
1551    fai = fasta.load_fai(fasta_fn).sort_index()
1552
1553    names = [name for name, row in fai.iterrows()]
1554    lengths = [row['LENGTH'] for name, row in fai.iterrows()]
1555
1556    header = pysam.AlignmentHeader.from_references(names, lengths)
1557
1558    return header
1559
1560def overlaps_feature(alignment, feature, require_same_strand=True):
1561    if alignment is None or alignment.is_unmapped:
1562        return False
1563
1564    same_reference = alignment.reference_name == feature.seqname
1565    num_overlapping_bases = alignment.get_overlap(feature.start, feature.end)
1566
1567    if require_same_strand:
1568        same_strand = (get_strand(alignment) == feature.strand)
1569    else:
1570        same_strand = True
1571
1572    return same_reference and same_strand and (num_overlapping_bases > 0)
1573
1574def reference_edges(alignment):
1575    if alignment is None or alignment.is_unmapped:
1576        return {5: None, 3: None}
1577
1578    strand = get_strand(alignment)
1579
1580    if strand == '+':
1581        edges = {
1582            5: alignment.reference_start,
1583            3: alignment.reference_end - 1,
1584        }
1585    elif strand == '-':
1586        edges = {
1587            5: alignment.reference_end - 1,
1588            3: alignment.reference_start,
1589        }
1590
1591    return edges
1592
1593def reference_interval(alignment):
1594    return interval.Interval(alignment.reference_start, alignment.reference_end - 1)
1595
1596def aligned_tuples(alignment, ref_seq=None):
1597    tuples = []
1598
1599    if ref_seq is None:
1600        aligned_pairs = alignment.get_aligned_pairs(with_seq=True)
1601
1602        # Remove soft-clipping
1603        min_i = min(i for i, (q, _, ref_b) in enumerate(aligned_pairs) if ref_b is not None)
1604        max_i = max(i for i, (q, _, ref_b) in enumerate(aligned_pairs) if ref_b is not None)
1605        aligned_pairs = aligned_pairs[min_i:max_i + 1]
1606
1607        for read_i, ref_i, ref_b in aligned_pairs:
1608            if read_i is None:
1609                true_read_i = None
1610                read_b = '-'
1611                qual = -1
1612            else:
1613                true_read_i = true_query_position(read_i, alignment)
1614                read_b = alignment.query_sequence[read_i].upper()
1615                qual = alignment.query_qualities[read_i]
1616
1617            if ref_i is not None and ref_b is None:
1618                # don't include spliced reference positions. (Is this an appropriate check for this situation?)
1619                continue
1620
1621            if ref_i is None:
1622                ref_b = '-'
1623            else:
1624                ref_b = ref_b.upper()
1625
1626            tuples.append((true_read_i, read_b, ref_i, ref_b, qual))
1627
1628    else:
1629        aligned_pairs = aligned_pairs_exclude_soft_clipping(alignment)
1630
1631        for read_i, ref_i in aligned_pairs:
1632            if read_i is None:
1633                read_b = '-'
1634                true_read_i = None
1635                qual = -1
1636            else:
1637                true_read_i = true_query_position(read_i, alignment)
1638                read_b = alignment.query_sequence[read_i].upper()
1639                qual = alignment.query_qualities[read_i]
1640
1641            if ref_i is None:
1642                ref_b = '-'
1643            else:
1644                ref_b = ref_seq[ref_i].upper()
1645
1646            tuples.append((true_read_i, read_b, ref_i, ref_b, qual))
1647
1648    return tuples
1649
1650def total_edit_distance(alignment, ref_seq=None):
1651    return edit_distance_in_query_interval(alignment, ref_seq=ref_seq)
1652
1653def edit_distance_in_query_interval(alignment, query_interval=None, ref_seq=None):
1654    if query_interval is None:
1655        query_interval = interval.Interval(0, np.inf)
1656
1657    if query_interval.is_empty or alignment is None or alignment.is_unmapped:
1658        return 0
1659
1660    distance = 0
1661
1662    start = query_interval.start
1663    end = query_interval.end
1664
1665    tuples = aligned_tuples(alignment, ref_seq)
1666    if alignment.is_reverse:
1667        tuples = tuples[::-1]
1668
1669    first_i = min(i for i, (q, _, _, _, _) in enumerate(tuples) if q is not None and q >= start)
1670    last_i = max(i for i, (q, _, _, _, _) in enumerate(tuples) if q is not None and q <= end)
1671
1672    for q, q_base, r, r_base, qual in tuples[first_i:last_i + 1]:
1673        if q_base != r_base:
1674            distance += 1
1675
1676    return distance
1677
1678def get_header(bam_fn):
1679    with pysam.AlignmentFile(bam_fn) as bam_file:
1680        header = bam_file.header
1681    return header
1682
1683def flip_alignment(alignment):
1684    flipped_alignment = copy.deepcopy(alignment)
1685    flipped_alignment.is_reverse = not alignment.is_reverse
1686    return flipped_alignment
1687
1688def make_nonredundant(alignments):
    ''' Two alignments of the same read are redundant if they pair the same read bases with the same
    reference bases. Given alignments of the same read, return a list containing a single representative
    of each equivalence class of redundant alignments.
    '''
1693    def fingerprint(al):
1694        return tuple(al.cigar), al.reference_start, al.reference_name, al.is_reverse
1695
1696    nonredundant = {fingerprint(al): al for al in alignments}
1697
1698    return list(nonredundant.values())
1699