###############################################################################
#
# coverage.py - calculate coverage of all sequences
#
###############################################################################
#                                                                             #
#    This program is free software: you can redistribute it and/or modify     #
#    it under the terms of the GNU General Public License as published by     #
#    the Free Software Foundation, either version 3 of the License, or        #
#    (at your option) any later version.                                      #
#                                                                             #
#    This program is distributed in the hope that it will be useful,          #
#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            #
#    GNU General Public License for more details.                             #
#                                                                             #
#    You should have received a copy of the GNU General Public License        #
#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
#                                                                             #
###############################################################################
21
import logging
import multiprocessing as mp
import ntpath
import os
import sys
import traceback
from collections import defaultdict

import pysam
from numpy import mean, sqrt

from checkm.common import binIdFromFilename, reassignStdOut, restoreStdOut
from checkm.defaultValues import DefaultValues
from checkm.util.seqUtils import readFasta
37
38
class CoverageStruct():
    """Per-sequence coverage statistics container."""

    def __init__(self, seqLen, mappedReads, coverage):
        """
        Parameters
        ----------
        seqLen : int
            Length of the sequence in base pairs.
        mappedReads : int
            Number of reads mapped to the sequence after filtering.
        coverage : float
            Mean coverage of the sequence.
        """
        self.coverage = coverage
        self.mappedReads = mappedReads
        self.seqLen = seqLen
44
45
class Coverage():
    """Calculate coverage of all sequences across a set of BAM files.

    For each BAM file, reads are filtered on duplication, secondary/
    supplementary status, mapping quality, aligned fraction, edit distance,
    and (optionally) proper pairing. Coverage of a sequence is the summed
    aligned length of passing reads divided by the sequence length.
    Each BAM file is processed with multiple worker processes feeding a
    single writer process.
    """

    def __init__(self, threads):
        """
        Parameters
        ----------
        threads : int
            Number of worker processes to use per BAM file.
        """
        self.logger = logging.getLogger('timestamp')

        self.totalThreads = threads

    def run(self, binFiles, bamFiles, outFile, bAllReads, minAlignPer, maxEditDistPer, minQC):
        """Calculate coverage of sequences for each BAM file.

        Writes one tab-separated row per sequence to outFile with bin
        assignment, sequence length, and per-BAM coverage statistics.

        Parameters
        ----------
        binFiles : iterable of str
            FASTA files defining the bin membership of sequences.
        bamFiles : iterable of str
            Sorted and indexed BAM files.
        outFile : str
            File to write coverage information to.
        bAllReads : bool
            If True, count reads regardless of proper-pairing status.
        minAlignPer : float
            Minimum aligned fraction of a read for it to be counted.
        maxEditDistPer : float
            Maximum edit distance as a fraction of read length.
        minQC : int
            Minimum mapping quality for a read to be counted.
        """

        # determine bin assignment of each sequence
        self.logger.info('Determining bin assignment of each sequence.')

        seqIdToBinId = {}
        seqIdToSeqLen = {}
        for binFile in binFiles:
            binId = binIdFromFilename(binFile)

            seqs = readFasta(binFile)
            # Py3 fix: dict.iteritems() no longer exists; use items()
            for seqId, seq in seqs.items():
                seqIdToBinId[seqId] = binId
                seqIdToSeqLen[seqId] = len(seq)

        # process each BAM file
        self.logger.info("Processing %d file(s) with %d threads.\n" % (len(bamFiles), self.totalThreads))

        # make sure all BAM files are sorted and indexed; pysam requires
        # an index for random access via fetch()
        self.numFiles = len(bamFiles)
        for bamFile in bamFiles:
            if not os.path.exists(bamFile + '.bai'):
                self.logger.error('  [Error] BAM file is either unsorted or not indexed: ' + bamFile + '\n')
                sys.exit(1)

        # calculate coverage of each BAM file
        coverageInfo = {}
        numFilesStarted = 0
        for bamFile in bamFiles:
            numFilesStarted += 1
            self.logger.info('Processing %s (%d of %d):' % (ntpath.basename(bamFile), numFilesStarted, len(bamFiles)))

            # manager dict lets the writer process share results with us
            coverageInfo[bamFile] = mp.Manager().dict()
            coverageInfo[bamFile] = self.__processBam(bamFile, bAllReads, minAlignPer, maxEditDistPer, minQC, coverageInfo[bamFile])

        # redirect output
        self.logger.info('Writing coverage information to file.')
        oldStdOut = reassignStdOut(outFile)

        header = 'Sequence Id\tBin Id\tSequence length (bp)'
        for _ in bamFiles:
            header += '\tBam Id\tCoverage\tMapped reads'

        print(header)

        # get length of all seqs, including those not assigned to any bin
        for bamFile, seqIds in coverageInfo.items():
            for seqId in seqIds.keys():
                seqIdToSeqLen[seqId] = seqIds[seqId].seqLen

        # write coverage stats for all scaffolds to file
        for seqId, seqLen in seqIdToSeqLen.items():
            rowStr = seqId + '\t' + seqIdToBinId.get(seqId, DefaultValues.UNBINNED) + '\t' + str(seqLen)
            for bamFile in bamFiles:
                bamId = binIdFromFilename(bamFile)

                if seqId in coverageInfo[bamFile]:
                    rowStr += '\t%s\t%f\t%d' % (bamId, coverageInfo[bamFile][seqId].coverage, coverageInfo[bamFile][seqId].mappedReads)
                else:
                    # sequence has no reads mapped to it in this BAM file
                    rowStr += '\t%s\t%f\t%d' % (bamId, 0, 0)

            print(rowStr)

        # restore stdout
        restoreStdOut(outFile, oldStdOut)

    def __processBam(self, bamFile, bAllReads, minAlignPer, maxEditDistPer, minQC, coverageInfo):
        """Calculate coverage of sequences in a single BAM file.

        Distributes reference sequences across worker processes and collates
        per-sequence results through a dedicated writer process.
        """

        # determine coverage for each reference sequence
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        bamfile = pysam.Samfile(bamFile, 'rb')
        refSeqIds = bamfile.references
        refSeqLens = bamfile.lengths
        bamfile.close()  # workers open their own handles; don't leak this one

        # populate each thread with reference sequences to process
        # Note: reference sequences are sorted by number of mapped reads,
        # so a back-and-forth (boustrophedon) assignment balances the total
        # work handed to each thread
        refSeqLists = [[] for _ in range(self.totalThreads)]
        refLenLists = [[] for _ in range(self.totalThreads)]

        threadIndex = 0
        incDir = 1
        for refSeqId, refLen in zip(refSeqIds, refSeqLens):
            refSeqLists[threadIndex].append(refSeqId)
            refLenLists[threadIndex].append(refLen)

            threadIndex += incDir
            if threadIndex == self.totalThreads:
                threadIndex = self.totalThreads - 1
                incDir = -1
            elif threadIndex == -1:
                threadIndex = 0
                incDir = 1

        for i in range(self.totalThreads):
            workerQueue.put((refSeqLists[i], refLenLists[i]))

        # one (None, None) sentinel per worker marks the end of the queue
        for _ in range(self.totalThreads):
            workerQueue.put((None, None))

        try:
            workerProc = [mp.Process(target=self.__workerThread, args=(bamFile, bAllReads, minAlignPer, maxEditDistPer, minQC, workerQueue, writerQueue)) for _ in range(self.totalThreads)]
            writeProc = mp.Process(target=self.__writerThread, args=(coverageInfo, len(refSeqIds), writerQueue))

            writeProc.start()

            for p in workerProc:
                p.start()

            for p in workerProc:
                p.join()

            # sentinel tuple matches the 11 fields produced by workers
            writerQueue.put((None,) * 11)
            writeProc.join()
        except BaseException:
            # explicit broad catch (incl. KeyboardInterrupt) so child
            # processes are always terminated; Py3 fix: print is a function
            print(traceback.format_exc())

            # make sure all processes are terminated
            for p in workerProc:
                p.terminate()

            writeProc.terminate()

        return coverageInfo

    def __workerThread(self, bamFile, bAllReads, minAlignPer, maxEditDistPer, minQC, queueIn, queueOut):
        """Process each data item in parallel.

        Pulls (seqIds, seqLens) work items from queueIn until a (None, None)
        sentinel arrives; computes read-filtering counts and coverage for
        each sequence and pushes the results onto queueOut.
        """
        while True:
            seqIds, seqLens = queueIn.get(block=True, timeout=None)
            if seqIds is None:  # sentinel: no more work
                break

            bamfile = pysam.Samfile(bamFile, 'rb')

            for seqId, seqLen in zip(seqIds, seqLens):
                numReads = 0
                numMappedReads = 0
                numDuplicates = 0
                numSecondary = 0
                numFailedQC = 0
                numFailedAlignLen = 0
                numFailedEditDist = 0
                numFailedProperPair = 0
                coverage = 0

                for read in bamfile.fetch(seqId, 0, seqLen):
                    numReads += 1

                    # classify each read into exactly one filter category;
                    # order matters: a read is counted by the first test it hits
                    if read.is_unmapped:
                        pass
                    elif read.is_duplicate:
                        numDuplicates += 1
                    elif read.is_secondary or read.is_supplementary:
                        numSecondary += 1
                    elif read.is_qcfail or read.mapping_quality < minQC:
                        numFailedQC += 1
                    elif read.query_alignment_length < minAlignPer * read.query_length:
                        numFailedAlignLen += 1
                    elif read.get_tag('NM') > maxEditDistPer * read.query_length:
                        numFailedEditDist += 1
                    elif not bAllReads and not read.is_proper_pair:
                        numFailedProperPair += 1
                    else:
                        numMappedReads += 1

                        # Note: the alignment length (query_alignment_length) is used instead of the
                        # read length (query_length) as this bring the calculated coverage
                        # in line with 'samtools depth' (at least when the min
                        # alignment length and edit distance thresholds are zero).
                        coverage += read.query_alignment_length

                # guard against zero-length reference sequences
                coverage = float(coverage) / seqLen if seqLen else 0.0

                queueOut.put((seqId, seqLen, coverage, numReads,
                                numDuplicates, numSecondary, numFailedQC,
                                numFailedAlignLen, numFailedEditDist,
                                numFailedProperPair, numMappedReads))

            bamfile.close()

    def __writerThread(self, coverageInfo, numRefSeqs, writerQueue):
        """Store or write results of worker threads in a single thread.

        Consumes per-sequence results from writerQueue until a sentinel with
        seqId == None arrives, storing a CoverageStruct per sequence and
        reporting progress plus a final read-filtering summary.
        """
        totalReads = 0
        totalDuplicates = 0
        totalSecondary = 0
        totalFailedQC = 0
        totalFailedAlignLen = 0
        totalFailedEditDist = 0
        totalFailedProperPair = 0
        totalMappedReads = 0

        processedRefSeqs = 0
        while True:
            seqId, seqLen, coverage, numReads, numDuplicates, numSecondary, numFailedQC, numFailedAlignLen, numFailedEditDist, numFailedProperPair, numMappedReads = writerQueue.get(block=True, timeout=None)
            if seqId is None:  # sentinel: all workers finished
                break

            # accumulate totals unconditionally (previously this was gated
            # on the logging level, leaving totals incomplete)
            totalReads += numReads
            totalDuplicates += numDuplicates
            totalSecondary += numSecondary
            totalFailedQC += numFailedQC
            totalFailedAlignLen += numFailedAlignLen
            totalFailedEditDist += numFailedEditDist
            totalFailedProperPair += numFailedProperPair
            totalMappedReads += numMappedReads

            coverageInfo[seqId] = CoverageStruct(seqLen=seqLen, mappedReads=numMappedReads, coverage=coverage)

            if self.logger.getEffectiveLevel() <= logging.INFO:
                processedRefSeqs += 1
                statusStr = '    Finished processing %d of %d (%.2f%%) reference sequences.' % (processedRefSeqs, numRefSeqs, float(processedRefSeqs) * 100 / numRefSeqs)
                sys.stderr.write('%s\r' % statusStr)
                sys.stderr.flush()

        if self.logger.getEffectiveLevel() <= logging.INFO:
            sys.stderr.write('\n')

            # avoid division by zero when the BAM file contains no reads
            denom = max(totalReads, 1)
            print('')
            print('    # total reads: %d' % totalReads)
            print('      # properly mapped reads: %d (%.1f%%)' % (totalMappedReads, float(totalMappedReads) * 100 / denom))
            print('      # duplicate reads: %d (%.1f%%)' % (totalDuplicates, float(totalDuplicates) * 100 / denom))
            print('      # secondary reads: %d (%.1f%%)' % (totalSecondary, float(totalSecondary) * 100 / denom))
            print('      # reads failing QC: %d (%.1f%%)' % (totalFailedQC, float(totalFailedQC) * 100 / denom))
            print('      # reads failing alignment length: %d (%.1f%%)' % (totalFailedAlignLen, float(totalFailedAlignLen) * 100 / denom))
            print('      # reads failing edit distance: %d (%.1f%%)' % (totalFailedEditDist, float(totalFailedEditDist) * 100 / denom))
            print('      # reads not properly paired: %d (%.1f%%)' % (totalFailedProperPair, float(totalFailedProperPair) * 100 / denom))
            print('')

    def parseCoverage(self, coverageFile):
        """Read coverage information from file.

        Returns
        -------
        dict
            coverageStats[binId][seqId][bamId] -> coverage (float).
        """
        coverageStats = {}
        with open(coverageFile) as f:
            bHeader = True
            for line in f:
                if bHeader:
                    bHeader = False
                    continue

                lineSplit = line.split('\t')
                seqId = lineSplit[0]
                binId = lineSplit[1]

                if binId not in coverageStats:
                    coverageStats[binId] = {}

                if seqId not in coverageStats[binId]:
                    coverageStats[binId][seqId] = {}

                # columns come in (bamId, coverage, mappedReads) triples
                # starting at index 3; Py3 fix: xrange -> range
                for i in range(3, len(lineSplit), 3):
                    bamId = lineSplit[i]
                    coverage = float(lineSplit[i + 1])
                    coverageStats[binId][seqId][bamId] = coverage

        return coverageStats

    def binProfiles(self, coverageFile):
        """Read coverage information for each bin.

        Returns
        -------
        dict
            profiles[binId][bamId] -> [mean coverage, std. dev. of coverage],
            where the mean is weighted by scaffold length.
        """
        binCoverages = defaultdict(lambda: defaultdict(list))
        binStats = defaultdict(dict)

        with open(coverageFile) as f:
            bHeader = True
            for line in f:
                if bHeader:
                    bHeader = False
                    continue

                lineSplit = line.split('\t')
                binId = lineSplit[1]
                seqLen = int(lineSplit[2])

                # calculate mean coverage (weighted by scaffold length)
                # for each bin under each BAM file
                for i in range(3, len(lineSplit), 3):
                    bamId = lineSplit[i]
                    coverage = float(lineSplit[i + 1])
                    binCoverages[binId][bamId].append(coverage)

                    if bamId not in binStats[binId]:
                        binStats[binId][bamId] = [0, 0]

                    # incremental length-weighted running mean
                    binLength = binStats[binId][bamId][0] + seqLen
                    weight = float(seqLen) / binLength
                    meanBinCoverage = coverage * weight + binStats[binId][bamId][1] * (1 - weight)

                    binStats[binId][bamId] = [binLength, meanBinCoverage]

        profiles = defaultdict(dict)
        for binId in binStats:
            for bamId, stats in binStats[binId].items():
                binLength, meanBinCoverage = stats
                coverages = binCoverages[binId][bamId]

                varCoverage = 0
                if len(coverages) > 1:
                    # Py3 fix: map() returns an iterator which numpy.mean
                    # cannot reduce; materialize via a list comprehension
                    varCoverage = mean([(x - meanBinCoverage) ** 2 for x in coverages])

                profiles[binId][bamId] = [meanBinCoverage, sqrt(varCoverage)]

        return profiles
355