1###############################################################################
2#
3# coverageWindows.py - calculate coverage of windows within sequences
4#
5###############################################################################
6#                                                                             #
7#    This program is free software: you can redistribute it and/or modify     #
8#    it under the terms of the GNU General Public License as published by     #
9#    the Free Software Foundation, either version 3 of the License, or        #
10#    (at your option) any later version.                                      #
11#                                                                             #
12#    This program is distributed in the hope that it will be useful,          #
13#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
14#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            #
15#    GNU General Public License for more details.                             #
16#                                                                             #
17#    You should have received a copy of the GNU General Public License        #
18#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
19#                                                                             #
20###############################################################################
21
22import sys
23import os
24import multiprocessing as mp
25import logging
26
27import pysam
28
29from numpy import zeros
30
31
class ReadLoader:
    """Callback for counting aligned reads with pysam.fetch.

    Every call receives one read and increments `numReads`. The read is then
    run through a sequence of filters, checked in this order: unmapped,
    duplicate, secondary, QC failure, alignment length, edit distance and
    proper pairing. A read is tallied only under the first filter it fails
    (unmapped reads get no counter of their own). Reads that pass every
    filter increment `numMappedReads` and add 1.0 to each position of
    `coverage` they span.
    """

    def __init__(self, refLength, bAllReads, minAlignPer, maxEditDistPer):
        # filter settings
        self.bAllReads = bAllReads            # if true, skip the proper-pair filter
        self.minAlignPer = minAlignPer        # min alignment length as a fraction of rlen
        self.maxEditDistPer = maxEditDistPer  # max NM value as a fraction of rlen

        # one counter per possible outcome, plus the total
        self.numReads = self.numMappedReads = 0
        self.numDuplicates = self.numSecondary = self.numFailedQC = 0
        self.numFailedAlignLen = self.numFailedEditDist = self.numFailedProperPair = 0

        # per-position coverage, one float entry per reference position
        self.coverage = zeros(refLength)

    def __call__(self, read):
        self.numReads += 1

        if read.is_unmapped:
            # counted only in numReads
            return

        if read.is_duplicate:
            self.numDuplicates += 1
            return

        if read.is_secondary:
            self.numSecondary += 1
            return

        if read.is_qcfail:
            self.numFailedQC += 1
            return

        if read.alen < self.minAlignPer * read.rlen:
            self.numFailedAlignLen += 1
            return

        if read.opt('NM') > self.maxEditDistPer * read.rlen:
            self.numFailedEditDist += 1
            return

        if not (self.bAllReads or read.is_proper_pair):
            self.numFailedProperPair += 1
            return

        self.numMappedReads += 1

        # Note: the alignment length (alen) is used instead of the
        # read length (rlen) as this brings the calculated coverage
        # in line with 'samtools depth' (at least when the min
        # alignment length and edit distance thresholds are zero).
        firstBase = read.pos
        self.coverage[firstBase:firstBase + read.alen] += 1.0
76
77
class CoverageStruct():
    """Simple record bundling three values: seqLen, mappedReads and coverage.

    The values are stored exactly as given; no validation or conversion is
    performed.
    """

    def __init__(self, seqLen, mappedReads, coverage):
        self.seqLen, self.mappedReads, self.coverage = seqLen, mappedReads, coverage
83
84
class CoverageWindows():
    """Calculate coverage of whole sequences and of fixed-size windows.

    Reference sequences listed in the BAM header are distributed over
    `threads` worker processes. Each worker opens the BAM file itself,
    computes per-position coverage of its sequences and sends the results to
    a single writer process, which stores them in a multiprocessing Manager
    dictionary.
    """

    def __init__(self, threads):
        """Initialize.

        Parameters
        ----------
        threads : int
            Number of worker processes to spawn.
        """
        self.logger = logging.getLogger('timestamp')

        self.totalThreads = threads

    def run(self, binFiles, bamFile, bAllReads, minAlignPer, maxEditDistPer, windowSize):
        """Calculate coverage of full sequences and windows.

        Parameters
        ----------
        binFiles : list
            Not used by this method.
        bamFile : str
            Path to a BAM file; the index file `bamFile + '.bai'` must exist.
        bAllReads : bool
            If False, reads that are not properly paired are ignored.
        minAlignPer : float
            Minimum alignment length as a fraction of read length.
        maxEditDistPer : float
            Maximum edit distance (NM tag) as a fraction of read length.
        windowSize : int
            Window length in bases; must be positive.

        Returns
        -------
        Manager dict proxy
            Maps reference sequence id -> [coverage, windowCoverages].

        Raises
        ------
        ValueError
            If windowSize is not positive.
        """
        # validate up front: a non-positive window size would otherwise crash
        # (0) or loop forever (negative) inside a worker process
        if windowSize <= 0:
            raise ValueError('windowSize must be positive: %s' % windowSize)

        # make sure BAM file is sorted (checked via the presence of its index)
        if not os.path.exists(bamFile + '.bai'):
            self.logger.error('BAM file is not sorted: ' + bamFile + '\n')
            sys.exit(1)

        # calculate coverage of each BAM file
        self.logger.info('Calculating coverage of windows.')
        coverageInfo = mp.Manager().dict()
        return self.__processBam(bamFile, bAllReads, minAlignPer, maxEditDistPer, windowSize, coverageInfo)

    @staticmethod
    def _calcWindowCoverages(coverage, seqLen, windowSize):
        """Return the mean coverage of each complete, non-overlapping window.

        Windows start at position 0 and are `windowSize` positions long. Every
        window that fits entirely within `seqLen` is reported (including one
        ending exactly at `seqLen`); a trailing partial window is ignored.

        Raises
        ------
        ValueError
            If windowSize is not positive.
        """
        if windowSize <= 0:
            raise ValueError('windowSize must be positive: %s' % windowSize)

        windowCoverages = []
        start = 0
        end = windowSize
        while end <= seqLen:
            windowCoverages.append(float(coverage[start:end].sum()) / windowSize)
            start = end
            end += windowSize

        return windowCoverages

    @staticmethod
    def _percentage(count, total):
        """Return count as a percentage of total, or 0.0 when total is zero."""
        if not total:
            return 0.0
        return float(count) * 100 / total

    def __processBam(self, bamFile, bAllReads, minAlignPer, maxEditDistPer, windowSize, coverageInfo):
        """Calculate coverage of sequences in BAM file.

        Results are written into `coverageInfo` by a writer process. If an
        error or interrupt occurs, all started processes are terminated and
        the exception is re-raised.
        """
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        # only the header is needed here; workers open their own handles
        bamfile = pysam.Samfile(bamFile, 'rb')
        try:
            refSeqIds = bamfile.references
            refSeqLens = bamfile.lengths
        finally:
            bamfile.close()

        # Populate each thread with reference sequences to process.
        # Note (original author): reference sequences are sorted by number of
        # mapped reads, so they are dealt out in a back-and-forth order to
        # balance work across threads.
        refSeqLists = [[] for _ in range(self.totalThreads)]
        refLenLists = [[] for _ in range(self.totalThreads)]

        threadIndex = 0
        incDir = 1
        for refSeqId, refLen in zip(refSeqIds, refSeqLens):
            refSeqLists[threadIndex].append(refSeqId)
            refLenLists[threadIndex].append(refLen)

            # bounce off either end: 0, 1, ..., T-1, T-1, ..., 0, 0, 1, ...
            threadIndex += incDir
            if threadIndex == self.totalThreads:
                threadIndex = self.totalThreads - 1
                incDir = -1
            elif threadIndex == -1:
                threadIndex = 0
                incDir = 1

        for seqIds, seqLens in zip(refSeqLists, refLenLists):
            workerQueue.put((seqIds, seqLens))

        # one stop sentinel per worker
        for _ in range(self.totalThreads):
            workerQueue.put((None, None))

        workerProc = []
        writeProc = None
        try:
            workerProc = [mp.Process(target=self.__workerThread,
                                     args=(bamFile, bAllReads, minAlignPer, maxEditDistPer,
                                           windowSize, workerQueue, writerQueue))
                          for _ in range(self.totalThreads)]
            writeProc = mp.Process(target=self.__writerThread,
                                   args=(coverageInfo, len(refSeqIds), writerQueue))

            writeProc.start()

            for p in workerProc:
                p.start()

            for p in workerProc:
                p.join()

            # stop sentinel for the writer; must match the 12-field result tuple
            writerQueue.put((None,) * 12)
            writeProc.join()
        except BaseException:
            # make sure all processes are terminated, then propagate the error
            # rather than silently returning partial results
            for p in workerProc:
                if p.is_alive():
                    p.terminate()

            if writeProc is not None and writeProc.is_alive():
                writeProc.terminate()

            raise

        return coverageInfo

    def __workerThread(self, bamFile, bAllReads, minAlignPer, maxEditDistPer, windowSize, queueIn, queueOut):
        """Process lists of reference sequences taken from queueIn.

        For each sequence, counts reads with ReadLoader and puts a 12-field
        result tuple on queueOut. Stops when a (None, None) sentinel is read.
        """
        while True:
            seqIds, seqLens = queueIn.get(block=True, timeout=None)
            if seqIds is None:
                break

            bamfile = pysam.Samfile(bamFile, 'rb')
            try:
                for seqId, seqLen in zip(seqIds, seqLens):
                    readLoader = ReadLoader(seqLen, bAllReads, minAlignPer, maxEditDistPer)
                    bamfile.fetch(seqId, 0, seqLen, callback=readLoader)

                    windowCoverages = self._calcWindowCoverages(readLoader.coverage, seqLen, windowSize)
                    coverage = float(readLoader.coverage.sum()) / seqLen

                    queueOut.put((seqId, seqLen, coverage, windowCoverages, readLoader.numReads,
                                  readLoader.numDuplicates, readLoader.numSecondary, readLoader.numFailedQC,
                                  readLoader.numFailedAlignLen, readLoader.numFailedEditDist,
                                  readLoader.numFailedProperPair, readLoader.numMappedReads))
            finally:
                bamfile.close()

    def __writerThread(self, coverageInfo, numRefSeqs, writerQueue):
        """Store results of worker threads in a single process.

        Each result is stored as coverageInfo[seqId] = [coverage, windowCoverages].
        Progress and read totals are only tracked and reported when the
        logger's effective level is INFO or lower.
        """
        totalReads = 0
        totalDuplicates = 0
        totalSecondary = 0
        totalFailedQC = 0
        totalFailedAlignLen = 0
        totalFailedEditDist = 0
        totalFailedProperPair = 0
        totalMappedReads = 0

        processedRefSeqs = 0
        while True:
            (seqId, seqLen, coverage, windowCoverages, numReads, numDuplicates,
             numSecondary, numFailedQC, numFailedAlignLen, numFailedEditDist,
             numFailedProperPair, numMappedReads) = writerQueue.get(block=True, timeout=None)
            if seqId is None:
                break

            if self.logger.getEffectiveLevel() <= logging.INFO:
                processedRefSeqs += 1
                statusStr = '    Finished processing %d of %d (%.2f%%) reference sequences.' % (
                    processedRefSeqs, numRefSeqs, self._percentage(processedRefSeqs, numRefSeqs))
                sys.stderr.write('%s\r' % statusStr)
                sys.stderr.flush()

                totalReads += numReads
                totalDuplicates += numDuplicates
                totalSecondary += numSecondary
                totalFailedQC += numFailedQC
                totalFailedAlignLen += numFailedAlignLen
                totalFailedEditDist += numFailedEditDist
                totalFailedProperPair += numFailedProperPair
                totalMappedReads += numMappedReads

            coverageInfo[seqId] = [coverage, windowCoverages]

        if self.logger.getEffectiveLevel() <= logging.INFO:
            sys.stderr.write('\n')

            # percentages are guarded so an empty BAM does not divide by zero
            pct = self._percentage
            print('')
            print('    # total reads: %d' % totalReads)
            print('      # properly mapped reads: %d (%.1f%%)' % (totalMappedReads, pct(totalMappedReads, totalReads)))
            print('      # duplicate reads: %d (%.1f%%)' % (totalDuplicates, pct(totalDuplicates, totalReads)))
            print('      # secondary reads: %d (%.1f%%)' % (totalSecondary, pct(totalSecondary, totalReads)))
            print('      # reads failing QC: %d (%.1f%%)' % (totalFailedQC, pct(totalFailedQC, totalReads)))
            print('      # reads failing alignment length: %d (%.1f%%)' % (totalFailedAlignLen, pct(totalFailedAlignLen, totalReads)))
            print('      # reads failing edit distance: %d (%.1f%%)' % (totalFailedEditDist, pct(totalFailedEditDist, totalReads)))
            print('      # reads not properly paired: %d (%.1f%%)' % (totalFailedProperPair, pct(totalFailedProperPair, totalReads)))
            print('')
252