###############################################################################
#
# coverageWindows.py - calculate coverage of windows within sequences
#
###############################################################################
#                                                                             #
#    This program is free software: you can redistribute it and/or modify     #
#    it under the terms of the GNU General Public License as published by     #
#    the Free Software Foundation, either version 3 of the License, or        #
#    (at your option) any later version.                                      #
#                                                                             #
#    This program is distributed in the hope that it will be useful,          #
#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the             #
#    GNU General Public License for more details.                             #
#                                                                             #
#    You should have received a copy of the GNU General Public License        #
#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
#                                                                             #
###############################################################################

import sys
import os
import multiprocessing as mp
import logging

import pysam

from numpy import zeros


# Number of fields in each result tuple sent from a worker to the writer. The
# end-of-results sentinel is a tuple of this many None values so the writer can
# unpack it exactly like a real result.
_NUM_RESULT_FIELDS = 12


def _percentage(count, total):
    """Return count as a percentage of total, or 0.0 when total is zero."""
    if not total:
        return 0.0

    return float(count) * 100 / total


def _windowCoverages(coverage, seqLen, windowSize):
    """Return the mean coverage of consecutive, non-overlapping windows.

    coverage   -- NumPy array holding the per-position coverage
    seqLen     -- length of the sequence described by coverage
    windowSize -- number of positions per window; must be positive

    Raises ValueError if windowSize is not positive, since the loop below
    would otherwise never terminate.
    """
    if windowSize <= 0:
        raise ValueError('windowSize must be positive: %s' % windowSize)

    windows = []
    start = 0
    end = windowSize

    # NOTE(review): the strict '<' means a trailing partial window is dropped,
    # and so is a final full window that ends exactly at seqLen. This matches
    # the historical output; confirm whether '<=' was intended.
    while end < seqLen:
        windows.append(coverage[start:end].sum() / windowSize)
        start = end
        end += windowSize

    return windows


class ReadLoader:
    """Callback for counting aligned reads with pysam.fetch"""

    def __init__(self, refLength, bAllReads, minAlignPer, maxEditDistPer):
        """Set up the filter thresholds, counters and coverage array.

        refLength      -- length of the reference; size of the coverage array
        bAllReads      -- if false, reads that are not in a proper pair are
                          rejected
        minAlignPer    -- minimum alignment length as a fraction of read length
        maxEditDistPer -- maximum value of the 'NM' tag as a fraction of read
                          length
        """
        self.bAllReads = bAllReads
        self.minAlignPer = minAlignPer
        self.maxEditDistPer = maxEditDistPer

        # every read seen lands in numReads and in at most one other counter
        # (unmapped reads are only counted in numReads)
        self.numReads = 0
        self.numMappedReads = 0
        self.numDuplicates = 0
        self.numSecondary = 0
        self.numFailedQC = 0
        self.numFailedAlignLen = 0
        self.numFailedEditDist = 0
        self.numFailedProperPair = 0

        # per-position count of accepted reads covering each reference base
        self.coverage = zeros(refLength)

    def __call__(self, read):
        """Classify a read and, if it passes all filters, add it to coverage.

        The filters are applied in order and the first one that rejects the
        read determines which counter is incremented.
        """
        self.numReads += 1

        if read.is_unmapped:
            pass
        elif read.is_duplicate:
            self.numDuplicates += 1
        elif read.is_secondary:
            self.numSecondary += 1
        elif read.is_qcfail:
            self.numFailedQC += 1
        elif read.alen < self.minAlignPer * read.rlen:
            self.numFailedAlignLen += 1
        elif read.opt('NM') > self.maxEditDistPer * read.rlen:
            self.numFailedEditDist += 1
        elif not self.bAllReads and not read.is_proper_pair:
            self.numFailedProperPair += 1
        else:
            self.numMappedReads += 1

            # Note: the alignment length (alen) is used instead of the
            # read length (rlen) as this brings the calculated coverage
            # in line with 'samtools depth' (at least when the min
            # alignment length and edit distance thresholds are zero).
            self.coverage[read.pos:read.pos + read.alen] += 1.0


class CoverageStruct():
    """Plain container for the coverage information of one sequence."""

    def __init__(self, seqLen, mappedReads, coverage):
        """Store the sequence length, mapped read count and coverage."""
        self.seqLen = seqLen
        self.mappedReads = mappedReads
        self.coverage = coverage


class CoverageWindows():
    """Calculate coverage of all sequences."""

    def __init__(self, threads):
        """Remember the number of worker processes to use."""
        self.logger = logging.getLogger('timestamp')

        self.totalThreads = threads

    def run(self, binFiles, bamFile, bAllReads, minAlignPer, maxEditDistPer, windowSize):
        """Calculate coverage of full sequences and windows.

        binFiles       -- not used by this method
        bamFile        -- path to the BAM file; '<bamFile>.bai' must exist
        bAllReads      -- if false, only reads in proper pairs are counted
        minAlignPer    -- minimum alignment length as a fraction of read length
        maxEditDistPer -- maximum 'NM' value as a fraction of read length
        windowSize     -- number of positions per coverage window

        Returns a multiprocessing manager dict mapping each reference sequence
        id to [coverage, windowCoverages]. Exits the program with status 1 if
        the BAM index is missing or windowSize is not positive.
        """

        # make sure BAM file is sorted
        # (the check is for the '.bai' index file that accompanies the BAM)
        if not os.path.exists(bamFile + '.bai'):
            self.logger.error('BAM file is not sorted: ' + bamFile + '\n')
            sys.exit(1)

        # a non-positive window size can never advance through a sequence
        if windowSize <= 0:
            self.logger.error('Window size must be positive: ' + str(windowSize) + '\n')
            sys.exit(1)

        # calculate coverage of each BAM file
        self.logger.info('Calculating coverage of windows.')
        coverageInfo = mp.Manager().dict()
        coverageInfo = self.__processBam(bamFile, bAllReads, minAlignPer, maxEditDistPer, windowSize, coverageInfo)

        return coverageInfo

    def __processBam(self, bamFile, bAllReads, minAlignPer, maxEditDistPer, windowSize, coverageInfo):
        """Calculate coverage of sequences in BAM file.

        Splits the reference sequences between self.totalThreads worker
        processes, collects their results in a single writer process and
        returns coverageInfo once all of them have finished. If anything goes
        wrong while the processes are running, they are terminated and the
        error is re-raised.
        """

        # determine coverage for each reference sequence
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        # only the header is needed here, so release the file handle at once
        bamfile = pysam.Samfile(bamFile, 'rb')
        try:
            refSeqIds = bamfile.references
            refSeqLens = bamfile.lengths
        finally:
            bamfile.close()

        # populate each thread with reference sequence to process
        # Note: reference sequences are sorted by number of mapped reads
        # so it is important to distribute reads in a sensible way to each
        # of the threads
        refSeqLists = [[] for _ in range(self.totalThreads)]
        refLenLists = [[] for _ in range(self.totalThreads)]

        # deal the sequences out back and forth across the threads
        # (0, 1, ..., n-1, n-1, ..., 1, 0, 0, 1, ...)
        threadIndex = 0
        incDir = 1
        for refSeqId, refLen in zip(refSeqIds, refSeqLens):
            refSeqLists[threadIndex].append(refSeqId)
            refLenLists[threadIndex].append(refLen)

            threadIndex += incDir
            if threadIndex == self.totalThreads:
                threadIndex = self.totalThreads - 1
                incDir = -1
            elif threadIndex == -1:
                threadIndex = 0
                incDir = 1

        for seqIds, seqLens in zip(refSeqLists, refLenLists):
            workerQueue.put((seqIds, seqLens))

        # one end-of-work marker per worker process
        for _ in range(self.totalThreads):
            workerQueue.put((None, None))

        # create the processes before entering the try block so the cleanup
        # code below can always refer to them
        workerArgs = (bamFile, bAllReads, minAlignPer, maxEditDistPer, windowSize, workerQueue, writerQueue)
        workerProc = [mp.Process(target=self.__workerThread, args=workerArgs) for _ in range(self.totalThreads)]
        writeProc = mp.Process(target=self.__writerThread, args=(coverageInfo, len(refSeqIds), writerQueue))

        try:
            writeProc.start()

            for p in workerProc:
                p.start()

            for p in workerProc:
                p.join()

            # all workers are done, so tell the writer to finish up
            writerQueue.put((None,) * _NUM_RESULT_FIELDS)
            writeProc.join()
        except BaseException:
            # make sure all processes are terminated; processes that were
            # never started (or already finished) must not be terminated
            for p in workerProc + [writeProc]:
                if p.is_alive():
                    p.terminate()

            # do not hide the failure (or a Ctrl-C) behind partial results
            raise

        return coverageInfo

    def __workerThread(self, bamFile, bAllReads, minAlignPer, maxEditDistPer, windowSize, queueIn, queueOut):
        """Process each data item in parallel.

        Repeatedly takes a (seqIds, seqLens) pair from queueIn until a pair
        whose first element is None is received. For every sequence, one tuple
        of _NUM_RESULT_FIELDS values (ids, coverage, window coverages and the
        read counters) is put on queueOut.
        """
        while True:
            seqIds, seqLens = queueIn.get(block=True, timeout=None)
            if seqIds is None:
                break

            bamfile = pysam.Samfile(bamFile, 'rb')
            try:
                for seqId, seqLen in zip(seqIds, seqLens):
                    readLoader = ReadLoader(seqLen, bAllReads, minAlignPer, maxEditDistPer)
                    bamfile.fetch(seqId, 0, seqLen, callback=readLoader)

                    windowCoverages = _windowCoverages(readLoader.coverage, seqLen, windowSize)

                    # mean coverage over the whole sequence; a zero-length
                    # sequence has no positions to average over
                    if seqLen:
                        coverage = float(readLoader.coverage.sum()) / seqLen
                    else:
                        coverage = 0.0

                    queueOut.put((seqId, seqLen, coverage, windowCoverages, readLoader.numReads,
                                  readLoader.numDuplicates, readLoader.numSecondary, readLoader.numFailedQC,
                                  readLoader.numFailedAlignLen, readLoader.numFailedEditDist,
                                  readLoader.numFailedProperPair, readLoader.numMappedReads))
            finally:
                # release the file handle even if a sequence fails
                bamfile.close()

    def __writerThread(self, coverageInfo, numRefSeqs, writerQueue):
        """Store or write results of worker threads in a single thread.

        Reads result tuples from writerQueue until one whose first element is
        None arrives, stores [coverage, windowCoverages] in coverageInfo under
        each sequence id, and finally prints a summary of the read counters.
        Progress is written to stderr when the logger level is INFO or lower.
        """
        totalReads = 0
        totalDuplicates = 0
        totalSecondary = 0
        totalFailedQC = 0
        totalFailedAlignLen = 0
        totalFailedEditDist = 0
        totalFailedProperPair = 0
        totalMappedReads = 0

        processedRefSeqs = 0
        while True:
            (seqId, seqLen, coverage, windowCoverages, numReads, numDuplicates,
             numSecondary, numFailedQC, numFailedAlignLen, numFailedEditDist,
             numFailedProperPair, numMappedReads) = writerQueue.get(block=True, timeout=None)
            if seqId is None:
                break

            if self.logger.getEffectiveLevel() <= logging.INFO:
                processedRefSeqs += 1
                statusStr = ' Finished processing %d of %d (%.2f%%) reference sequences.' % (processedRefSeqs, numRefSeqs, float(processedRefSeqs) * 100 / numRefSeqs)
                # '\r' keeps the progress report on a single terminal line
                sys.stderr.write('%s\r' % statusStr)
                sys.stderr.flush()

            totalReads += numReads
            totalDuplicates += numDuplicates
            totalSecondary += numSecondary
            totalFailedQC += numFailedQC
            totalFailedAlignLen += numFailedAlignLen
            totalFailedEditDist += numFailedEditDist
            totalFailedProperPair += numFailedProperPair
            totalMappedReads += numMappedReads

            coverageInfo[seqId] = [coverage, windowCoverages]

        if self.logger.getEffectiveLevel() <= logging.INFO:
            sys.stderr.write('\n')

        summary = [(' # properly mapped reads', totalMappedReads),
                   (' # duplicate reads', totalDuplicates),
                   (' # secondary reads', totalSecondary),
                   (' # reads failing QC', totalFailedQC),
                   (' # reads failing alignment length', totalFailedAlignLen),
                   (' # reads failing edit distance', totalFailedEditDist),
                   (' # reads not properly paired', totalFailedProperPair)]

        # single-argument print(...) produces the same output whether it is
        # parsed as a print statement or as a call to the print function
        print('')
        print(' # total reads: %d' % totalReads)
        for label, count in summary:
            # _percentage reports 0.0 instead of failing when no reads were seen
            print('%s: %d (%.1f%%)' % (label, count, _percentage(count, totalReads)))
        print('')