###############################################################################
#
# coverage.py - calculate coverage of all sequences
#
###############################################################################
#                                                                             #
#    This program is free software: you can redistribute it and/or modify    #
#    it under the terms of the GNU General Public License as published by    #
#    the Free Software Foundation, either version 3 of the License, or       #
#    (at your option) any later version.                                     #
#                                                                             #
#    This program is distributed in the hope that it will be useful,         #
#    but WITHOUT ANY WARRANTY; without even the implied warranty of          #
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the           #
#    GNU General Public License for more details.                            #
#                                                                             #
#    You should have received a copy of the GNU General Public License       #
#    along with this program. If not, see <http://www.gnu.org/licenses/>.    #
#                                                                             #
###############################################################################

import sys
import os
import multiprocessing as mp
import logging
import ntpath
import traceback
from collections import defaultdict

import pysam

from checkm.defaultValues import DefaultValues
from checkm.common import reassignStdOut, restoreStdOut, binIdFromFilename
from checkm.util.seqUtils import readFasta

from numpy import mean, sqrt


class CoverageStruct():
    """Coverage statistics of a single reference sequence within one BAM file."""

    def __init__(self, seqLen, mappedReads, coverage):
        self.seqLen = seqLen            # reference sequence length (bp), from the BAM header
        self.mappedReads = mappedReads  # number of reads that passed every filter
        self.coverage = coverage        # summed aligned bases of passing reads / seqLen


class Coverage():
    """Calculate coverage of all sequences."""

    def __init__(self, threads):
        """
        Parameters
        ----------
        threads : int
            Number of worker processes started for each BAM file.
        """
        self.logger = logging.getLogger('timestamp')

        self.totalThreads = threads

    def run(self, binFiles, bamFiles, outFile, bAllReads, minAlignPer, maxEditDistPer, minQC):
        """Calculate coverage of sequences for each BAM file and write a coverage table.

        Parameters
        ----------
        binFiles : list of str
            FASTA files, one per bin; the bin id is derived from the file name.
        bamFiles : list of str
            BAM files to process; each must have a matching '<bam>.bai' file.
        outFile : str
            Destination handed to reassignStdOut (presumably a file path).
        bAllReads : bool
            If False, reads that are not properly paired are excluded.
        minAlignPer : float
            Minimum aligned length as a fraction of the read length.
        maxEditDistPer : float
            Maximum NM edit distance as a fraction of the read length.
        minQC : int
            Minimum mapping quality.

        The output has one row per sequence found in a bin or in any BAM header,
        followed by a (Bam Id, Coverage, Mapped reads) triple per BAM file.
        Exits the program if a BAM index file is missing.
        """

        # determine bin assignment of each sequence
        self.logger.info('Determining bin assignment of each sequence.')

        seqIdToBinId = {}
        seqIdToSeqLen = {}
        for binFile in binFiles:
            binId = binIdFromFilename(binFile)

            seqs = readFasta(binFile)
            for seqId, seq in seqs.items():
                seqIdToBinId[seqId] = binId
                seqIdToSeqLen[seqId] = len(seq)

        # process each BAM file
        self.logger.info("Processing %d file(s) with %d threads.\n" % (len(bamFiles), self.totalThreads))

        # make sure all BAM files are indexed (only the presence of the .bai file is checked)
        self.numFiles = len(bamFiles)
        for bamFile in bamFiles:
            if not os.path.exists(bamFile + '.bai'):
                self.logger.error(' [Error] BAM file is either unsorted or not indexed: ' + bamFile + '\n')
                sys.exit(1)

        # calculate coverage of each BAM file. A single manager process backs the
        # shared dictionary filled by the writer process; each result is copied into a
        # plain local dict so output writing needs no IPC and the manager can be shut down.
        coverageInfo = {}
        manager = mp.Manager()
        try:
            for fileNum, bamFile in enumerate(bamFiles, 1):
                self.logger.info('Processing %s (%d of %d):' % (ntpath.basename(bamFile), fileNum, len(bamFiles)))

                sharedInfo = self.__processBam(bamFile, bAllReads, minAlignPer, maxEditDistPer, minQC, manager.dict())
                coverageInfo[bamFile] = sharedInfo.copy()
        finally:
            manager.shutdown()

        # sequence lengths reported in the BAM headers take precedence over FASTA lengths
        for seqInfo in coverageInfo.values():
            for seqId, covStruct in seqInfo.items():
                seqIdToSeqLen[seqId] = covStruct.seqLen

        # redirect output
        self.logger.info('Writing coverage information to file.')
        oldStdOut = reassignStdOut(outFile)
        try:
            header = 'Sequence Id\tBin Id\tSequence length (bp)'
            header += '\tBam Id\tCoverage\tMapped reads' * len(bamFiles)
            print(header)

            bamIds = [binIdFromFilename(bamFile) for bamFile in bamFiles]

            # write coverage stats for all scaffolds to file
            for seqId, seqLen in seqIdToSeqLen.items():
                rowStr = seqId + '\t' + seqIdToBinId.get(seqId, DefaultValues.UNBINNED) + '\t' + str(seqLen)
                for bamFile, bamId in zip(bamFiles, bamIds):
                    covStruct = coverageInfo[bamFile].get(seqId)
                    if covStruct is not None:
                        rowStr += '\t%s\t%f\t%d' % (bamId, covStruct.coverage, covStruct.mappedReads)
                    else:
                        rowStr += '\t%s\t%f\t%d' % (bamId, 0, 0)

                print(rowStr)
        finally:
            # restore stdout even if writing fails
            restoreStdOut(outFile, oldStdOut)

    def __processBam(self, bamFile, bAllReads, minAlignPer, maxEditDistPer, minQC, coverageInfo):
        """Calculate coverage of sequences in BAM file.

        Reference sequences are split across self.totalThreads worker processes.
        A single writer process stores one CoverageStruct per reference in
        coverageInfo, a dictionary shared with that process.

        Returns
        -------
        coverageInfo, filled with seqId -> CoverageStruct. On an error, the
        traceback is printed, running child processes are terminated, and the
        (possibly partial) dictionary is returned. A KeyboardInterrupt stops the
        children and is re-raised.
        """
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        # read reference names/lengths from the header and release the handle;
        # every worker opens its own handle
        bamfile = pysam.Samfile(bamFile, 'rb')
        refSeqIds = bamfile.references
        refSeqLens = bamfile.lengths
        bamfile.close()

        # assign references to workers in serpentine order (0..n-1, n-1..0, 0..n-1, ...)
        # to spread the load if references are ordered by size or read count
        # NOTE(review): BAM header order is not guaranteed to be sorted that way -- confirm
        refSeqLists = [[] for _ in range(self.totalThreads)]
        refLenLists = [[] for _ in range(self.totalThreads)]

        threadIndex = 0
        incDir = 1
        for refSeqId, refLen in zip(refSeqIds, refSeqLens):
            refSeqLists[threadIndex].append(refSeqId)
            refLenLists[threadIndex].append(refLen)

            threadIndex += incDir
            if threadIndex == self.totalThreads:
                threadIndex = self.totalThreads - 1
                incDir = -1
            elif threadIndex == -1:
                threadIndex = 0
                incDir = 1

        for i in range(self.totalThreads):
            workerQueue.put((refSeqLists[i], refLenLists[i]))

        # one stop sentinel per worker
        for _ in range(self.totalThreads):
            workerQueue.put((None, None))

        # bound before the try so the cleanup handler can always reference them
        workerProc = []
        writeProc = None

        def terminateChildren():
            """Terminate every child process that is still running."""
            for p in workerProc:
                if p.is_alive():
                    p.terminate()
            if writeProc is not None and writeProc.is_alive():
                writeProc.terminate()

        try:
            workerProc = [mp.Process(target=self.__workerThread, args=(bamFile, bAllReads, minAlignPer, maxEditDistPer, minQC, workerQueue, writerQueue)) for _ in range(self.totalThreads)]
            writeProc = mp.Process(target=self.__writerThread, args=(coverageInfo, len(refSeqIds), writerQueue))

            writeProc.start()

            for p in workerProc:
                p.start()

            for p in workerProc:
                p.join()

            # all workers are done: send the writer its 11-field stop sentinel
            writerQueue.put((None,) * 11)
            writeProc.join()
        except KeyboardInterrupt:
            # stop the children, then let the interrupt reach the caller
            terminateChildren()
            raise
        except Exception:
            # best effort: report the error, stop children and return partial results
            print(traceback.format_exc())
            terminateChildren()

        return coverageInfo

    def __workerThread(self, bamFile, bAllReads, minAlignPer, maxEditDistPer, minQC, queueIn, queueOut):
        """Process each data item in parallel (runs in a worker process).

        Reads (seqIds, seqLens) work items from queueIn until a (None, None)
        sentinel arrives. For each reference, every read returned by fetch() is
        counted in numReads. Unless the read is unmapped, it is then assigned to
        the first matching category: duplicate, secondary/supplementary,
        failed QC (qcfail flag or mapping quality < minQC), short alignment,
        high edit distance, not properly paired (only when bAllReads is False),
        or mapped. Coverage is the summed query_alignment_length of mapped
        reads divided by seqLen. One 11-tuple per reference is put on queueOut.
        """
        while True:
            seqIds, seqLens = queueIn.get(block=True, timeout=None)
            if seqIds is None:
                break

            bamfile = pysam.Samfile(bamFile, 'rb')
            try:
                for seqId, seqLen in zip(seqIds, seqLens):
                    numReads = 0
                    numMappedReads = 0
                    numDuplicates = 0
                    numSecondary = 0
                    numFailedQC = 0
                    numFailedAlignLen = 0
                    numFailedEditDist = 0
                    numFailedProperPair = 0
                    coverage = 0

                    for read in bamfile.fetch(seqId, 0, seqLen):
                        numReads += 1

                        if read.is_unmapped:
                            pass
                        elif read.is_duplicate:
                            numDuplicates += 1
                        elif read.is_secondary or read.is_supplementary:
                            numSecondary += 1
                        elif read.is_qcfail or read.mapping_quality < minQC:
                            numFailedQC += 1
                        elif read.query_alignment_length < minAlignPer * read.query_length:
                            numFailedAlignLen += 1
                        # NOTE(review): get_tag raises KeyError for reads without an NM tag,
                        # which aborts this worker -- confirm inputs always carry NM
                        elif read.get_tag('NM') > maxEditDistPer * read.query_length:
                            numFailedEditDist += 1
                        elif not bAllReads and not read.is_proper_pair:
                            numFailedProperPair += 1
                        else:
                            numMappedReads += 1

                            # Note: the alignment length (query_alignment_length) is used instead of the
                            # read length (query_length) as this brings the calculated coverage
                            # in line with 'samtools depth' (at least when the min
                            # alignment length and edit distance thresholds are zero).
                            coverage += read.query_alignment_length

                    coverage = float(coverage) / seqLen

                    queueOut.put((seqId, seqLen, coverage, numReads,
                                  numDuplicates, numSecondary, numFailedQC,
                                  numFailedAlignLen, numFailedEditDist,
                                  numFailedProperPair, numMappedReads))
            finally:
                bamfile.close()

    def __writerThread(self, coverageInfo, numRefSeqs, writerQueue):
        """Store results of worker threads in a single process and print read statistics.

        Consumes 11-tuples from writerQueue until one whose first field is None
        arrives. It stores a CoverageStruct per reference in coverageInfo and
        accumulates read totals, which are printed as a summary at the end.
        Percentages are reported as 0.0 when no reads were seen.
        """
        totalReads = 0
        totalDuplicates = 0
        totalSecondary = 0
        totalFailedQC = 0
        totalFailedAlignLen = 0
        totalFailedEditDist = 0
        totalFailedProperPair = 0
        totalMappedReads = 0

        showProgress = self.logger.getEffectiveLevel() <= logging.INFO

        processedRefSeqs = 0
        while True:
            (seqId, seqLen, coverage, numReads, numDuplicates, numSecondary, numFailedQC,
             numFailedAlignLen, numFailedEditDist, numFailedProperPair,
             numMappedReads) = writerQueue.get(block=True, timeout=None)
            if seqId is None:
                break

            if showProgress:
                processedRefSeqs += 1
                statusStr = ' Finished processing %d of %d (%.2f%%) reference sequences.' % (processedRefSeqs, numRefSeqs, float(processedRefSeqs) * 100 / numRefSeqs)
                sys.stderr.write('%s\r' % statusStr)
                sys.stderr.flush()

            totalReads += numReads
            totalDuplicates += numDuplicates
            totalSecondary += numSecondary
            totalFailedQC += numFailedQC
            totalFailedAlignLen += numFailedAlignLen
            totalFailedEditDist += numFailedEditDist
            totalFailedProperPair += numFailedProperPair
            totalMappedReads += numMappedReads

            coverageInfo[seqId] = CoverageStruct(seqLen=seqLen, mappedReads=numMappedReads, coverage=coverage)

        if showProgress:
            sys.stderr.write('\n')

        def percent(count):
            """Percentage of all reads; 0.0 when no reads were seen (avoids ZeroDivisionError)."""
            if totalReads == 0:
                return 0.0
            return float(count) * 100 / totalReads

        print('')
        print(' # total reads: %d' % totalReads)
        print(' # properly mapped reads: %d (%.1f%%)' % (totalMappedReads, percent(totalMappedReads)))
        print(' # duplicate reads: %d (%.1f%%)' % (totalDuplicates, percent(totalDuplicates)))
        print(' # secondary reads: %d (%.1f%%)' % (totalSecondary, percent(totalSecondary)))
        print(' # reads failing QC: %d (%.1f%%)' % (totalFailedQC, percent(totalFailedQC)))
        print(' # reads failing alignment length: %d (%.1f%%)' % (totalFailedAlignLen, percent(totalFailedAlignLen)))
        print(' # reads failing edit distance: %d (%.1f%%)' % (totalFailedEditDist, percent(totalFailedEditDist)))
        print(' # reads not properly paired: %d (%.1f%%)' % (totalFailedProperPair, percent(totalFailedProperPair)))
        print('')

    def parseCoverage(self, coverageFile):
        """Read coverage information from a table in the format written by run().

        The first line is a header. Each following row holds the sequence id,
        bin id and length, then repeating (Bam Id, Coverage, Mapped reads) columns.
        Blank lines are ignored.

        Returns
        -------
        dict
            binId -> seqId -> bamId -> coverage (float).
        """
        coverageStats = {}
        with open(coverageFile) as f:
            next(f, None)  # skip header line

            for line in f:
                if not line.strip():
                    continue

                lineSplit = line.split('\t')
                seqId = lineSplit[0]
                binId = lineSplit[1]

                seqStats = coverageStats.setdefault(binId, {}).setdefault(seqId, {})

                for i in range(3, len(lineSplit), 3):
                    bamId = lineSplit[i]
                    seqStats[bamId] = float(lineSplit[i + 1])

        return coverageStats

    def binProfiles(self, coverageFile):
        """Read coverage information for each bin from a table written by run().

        Returns
        -------
        dict
            binId -> bamId -> [meanCoverage, stdCoverage]. The mean is weighted
            by sequence length. The standard deviation is the unweighted
            root-mean-square deviation of per-sequence coverages from that
            weighted mean, and 0 for a bin with a single sequence.
        """
        binCoverages = defaultdict(lambda: defaultdict(list))
        binStats = defaultdict(dict)

        with open(coverageFile) as f:
            next(f, None)  # skip header line

            for line in f:
                if not line.strip():
                    continue

                lineSplit = line.split('\t')
                binId = lineSplit[1]
                seqLen = int(lineSplit[2])

                # calculate mean coverage (weighted by scaffold length)
                # for each bin under each BAM file as a running weighted average
                for i in range(3, len(lineSplit), 3):
                    bamId = lineSplit[i]
                    coverage = float(lineSplit[i + 1])
                    binCoverages[binId][bamId].append(coverage)

                    if bamId not in binStats[binId]:
                        binStats[binId][bamId] = [0, 0]

                    binLength = binStats[binId][bamId][0] + seqLen
                    # zero-length sequences contribute no weight (avoids ZeroDivisionError)
                    weight = float(seqLen) / binLength if binLength else 0.0
                    meanBinCoverage = coverage * weight + binStats[binId][bamId][1] * (1 - weight)

                    binStats[binId][bamId] = [binLength, meanBinCoverage]

        profiles = defaultdict(dict)
        for binId in binStats:
            for bamId, stats in binStats[binId].items():
                binLength, meanBinCoverage = stats
                coverages = binCoverages[binId][bamId]

                varCoverage = 0
                if len(coverages) > 1:
                    varCoverage = mean([(x - meanBinCoverage) ** 2 for x in coverages])

                profiles[binId][bamId] = [meanBinCoverage, sqrt(varCoverage)]

        return profiles