###############################################################################
#
# hmmerAlign.py - runs HMMER and provides functions for parsing output
#
###############################################################################
#                                                                             #
#    This program is free software: you can redistribute it and/or modify     #
#    it under the terms of the GNU General Public License as published by     #
#    the Free Software Foundation, either version 3 of the License, or        #
#    (at your option) any later version.                                      #
#                                                                             #
#    This program is distributed in the hope that it will be useful,          #
#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the             #
#    GNU General Public License for more details.                             #
#                                                                             #
#    You should have received a copy of the GNU General Public License        #
#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
#                                                                             #
###############################################################################

import os
import sys
import uuid
import logging
import tempfile
import shutil
import multiprocessing as mp
from collections import defaultdict

from checkm.defaultValues import DefaultValues
from checkm.common import makeSurePathExists
from checkm.util.seqUtils import readFasta
from checkm.hmmer import HMMERRunner
from checkm.resultsParser import ResultsParser


class HmmerAligner:
    """Create per-marker sequence alignments from HMMER hits.

    Marker sequences are extracted from each bin's gene file, written to
    FASTA, aligned against the matching HMM through HMMERRunner and then
    masked. The work is spread over `threads` worker processes.
    """

    def __init__(self, threads):
        """Set up the aligner.

        threads -- number of worker processes started for each parallel stage
        """
        self.logger = logging.getLogger('timestamp')
        self.totalThreads = threads

        # passed to HMMERRunner.align() as its outputFormat argument
        self.outputFormat = 'Pfam'

    def makeAlignmentTopHit(self,
                            outDir,
                            hmmModelFile,
                            hmmTableFile,
                            binIdToModels,
                            bIgnoreThresholds,
                            evalueThreshold,
                            lengthThreshold,
                            bReportHitStats,
                            alignOutputDir,
                            bKeepUnmaskedAlign=False
                            ):
        """Align top hits in each bin.

        Assumes all bins are using the same marker genes: the models of a
        single bin are used to build the alignment HMMs.

        Returns the ResultsParser holding the parsed hits.
        """

        return self.__makeAlignment(self.__extractMarkerSeqsTopHits,
                                    outDir,
                                    hmmModelFile,
                                    hmmTableFile,
                                    binIdToModels,
                                    bIgnoreThresholds,
                                    evalueThreshold,
                                    lengthThreshold,
                                    bReportHitStats,
                                    alignOutputDir,
                                    bKeepUnmaskedAlign)

    def makeAlignmentToPhyloMarkers(self,
                                    outDir,
                                    hmmModelFile,
                                    hmmTableFile,
                                    binIdToModels,
                                    bIgnoreThresholds,
                                    evalueThreshold,
                                    lengthThreshold,
                                    bReportHitStats,
                                    alignOutputDir,
                                    bKeepUnmaskedAlign=False
                                    ):
        """Align hits to a set of common marker genes.

        Only markers with exactly one hit in a bin contribute a sequence for
        that bin.

        Returns the ResultsParser holding the parsed hits.
        """

        return self.__makeAlignment(self.__extractMarkerSeqsUnique,
                                    outDir,
                                    hmmModelFile,
                                    hmmTableFile,
                                    binIdToModels,
                                    bIgnoreThresholds,
                                    evalueThreshold,
                                    lengthThreshold,
                                    bReportHitStats,
                                    alignOutputDir,
                                    bKeepUnmaskedAlign)

    def __makeAlignment(self,
                        extractMarkerSeqs,
                        outDir,
                        hmmModelFile,
                        hmmTableFile,
                        binIdToModels,
                        bIgnoreThresholds,
                        evalueThreshold,
                        lengthThreshold,
                        bReportHitStats,
                        alignOutputDir,
                        bKeepUnmaskedAlign):
        """Shared workflow of makeAlignmentTopHit() and makeAlignmentToPhyloMarkers().

        extractMarkerSeqs -- callable(outDir, resultsParser) returning the
                             (markerSeqs, markerStats) pair to align
        """

        self.logger.info("Extracting marker genes to align.")

        # parse HMM information
        resultsParser = ResultsParser(binIdToModels)

        # get HMM hits to each bin
        resultsParser.parseBinHits(outDir, hmmTableFile, False, bIgnoreThresholds, evalueThreshold, lengthThreshold)

        # extract the ORFs to align
        markerSeqs, markerStats = extractMarkerSeqs(outDir, resultsParser)

        # generate individual HMMs required to create multiple sequence alignments
        # (list(...)[0] works for both Python 2 lists and Python 3 key views and
        # still raises IndexError when no bins are given)
        binId = list(binIdToModels.keys())[0]
        hmmModelFiles = {}
        try:
            self.__makeAlignmentModels(hmmModelFile, binIdToModels[binId], hmmModelFiles)

            # align each of the marker genes
            makeSurePathExists(alignOutputDir)
            self.__alignMarkerGenes(markerSeqs, markerStats, bReportHitStats, hmmModelFiles, alignOutputDir, bKeepUnmaskedAlign)
        finally:
            # remove the temporary HMM files, even when a stage above failed;
            # a file may be missing if its extraction never completed
            for modelFile in hmmModelFiles.values():
                if os.path.exists(modelFile):
                    os.remove(modelFile)

        return resultsParser

    def makeAlignmentsOfMultipleHits(self,
                                     outDir,
                                     markerFile,
                                     hmmTableFile,
                                     binIdToModels,
                                     binIdToBinMarkerSets,
                                     bIgnoreThresholds,
                                     evalueThreshold,
                                     lengthThreshold,
                                     alignOutputDir,
                                     ):
        """Align markers with multiple hits within a bin.

        Alignments for a bin are written to a sub-directory of alignOutputDir
        named after the bin.
        """

        makeSurePathExists(alignOutputDir)

        # parse HMM information
        resultsParser = ResultsParser(binIdToModels)

        # get HMM hits to each bin
        resultsParser.parseBinHits(outDir, hmmTableFile, False, bIgnoreThresholds, evalueThreshold, lengthThreshold)

        # align any markers with multiple hits in a bin
        self.logger.info('Aligning marker genes with multiple hits in a single bin:')

        # process each bin in parallel
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for binId in binIdToModels:
            workerQueue.put(binId)

        # one sentinel per worker so every worker leaves its loop
        for _ in range(self.totalThreads):
            workerQueue.put(None)

        calcProc = [mp.Process(target=self.__createMSA, args=(resultsParser, binIdToBinMarkerSets, markerFile, outDir, alignOutputDir, workerQueue, writerQueue)) for _ in range(self.totalThreads)]
        writeProc = mp.Process(target=self.__reportBinProgress, args=(len(binIdToModels), writerQueue))

        self.__runProcesses(calcProc, writeProc, writerQueue)

    def __runProcesses(self, calcProc, writeProc, writerQueue):
        """Run worker processes and a progress-reporting process to completion.

        calcProc    -- list of unstarted worker processes
        writeProc   -- unstarted reporter process reading from writerQueue
        writerQueue -- queue on which None tells the reporter to stop

        If anything goes wrong (including a keyboard interrupt) all running
        processes are terminated and the error is re-raised so the caller
        does not carry on with partial results.
        """

        try:
            writeProc.start()

            for p in calcProc:
                p.start()

            for p in calcProc:
                p.join()

            writerQueue.put(None)
            writeProc.join()
        except BaseException:
            # make sure all processes are terminated; is_alive() is False for
            # processes that were never started, which cannot be terminated
            for p in calcProc + [writeProc]:
                if p.is_alive():
                    p.terminate()

            raise

    def __createMSA(self, resultsParser, binIdToBinMarkerSets, hmmModelFile, outDir, alignOutputDir, queueIn, queueOut):
        """Create multiple sequence alignment for markers with multiple hits in a bin.

        Worker loop: reads bin ids from queueIn until None is received and
        puts each processed bin id on queueOut.
        """

        HF = HMMERRunner(mode='fetch')

        while True:
            binId = queueIn.get(block=True, timeout=None)
            if binId is None:
                break

            markersWithMultipleHits = self.__extractMarkersWithMultipleHits(outDir, binId, resultsParser, binIdToBinMarkerSets[binId])

            if markersWithMultipleHits:
                # create multiple sequence alignments for markers with multiple hits
                binAlignOutputDir = os.path.join(alignOutputDir, binId)
                makeSurePathExists(binAlignOutputDir)
                for markerId, binSeqs in markersWithMultipleHits.items():
                    tempModelFile = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
                    try:
                        HF.fetch(hmmModelFile, markerId, tempModelFile)

                        self.__alignMarker(markerId, binSeqs, None, False, binAlignOutputDir, tempModelFile, bKeepUnmaskedAlign=False)
                    finally:
                        # never leave the temporary model behind
                        if os.path.exists(tempModelFile):
                            os.remove(tempModelFile)

            queueOut.put(binId)

    def __writeProgress(self, statusTemplate, numProcessed, numItems):
        """Write a single carriage-return terminated status line to stderr.

        statusTemplate -- format string taking (processed, total, percent)
        """

        # nothing to process counts as complete; avoids a division by zero
        percent = float(numProcessed) * 100 / numItems if numItems else 100.0
        statusStr = statusTemplate % (numProcessed, numItems, percent)
        sys.stderr.write('%s\r' % statusStr)
        sys.stderr.flush()

    def __reportProgress(self, statusTemplate, numItems, bReportProgress, queueIn):
        """Consume queueIn until None arrives, updating the status line per item.

        Status is only written when bReportProgress is set and the logger's
        effective level is INFO or lower. The queue is always drained.
        """

        bReport = bReportProgress and self.logger.getEffectiveLevel() <= logging.INFO

        numProcessed = 0
        if bReport:
            self.__writeProgress(statusTemplate, numProcessed, numItems)

        while True:
            item = queueIn.get(block=True, timeout=None)
            if item is None:
                break

            if bReport:
                numProcessed += 1
                self.__writeProgress(statusTemplate, numProcessed, numItems)

        if bReport:
            sys.stderr.write('\n')

    def __reportBinProgress(self, numBins, queueIn):
        """Report number of processed bins."""

        self.__reportProgress(' Finished processing %d of %d (%.2f%%) bins.', numBins, True, queueIn)

    def __alignMarkerGenes(self, markerSeqs, markerStats, bReportHitStats, hmmModelFiles, alignOutputDir, bKeepUnmaskedAlign=False, bReportProgress=True):
        """Align marker genes with HMMs in parallel.

        markerSeqs    -- d[markerId][binId][seqId] -> sequence
        markerStats   -- d[markerId][binId][seqId] -> [e-value, score]
        hmmModelFiles -- d[markerId] -> file containing the marker's HMM
        """

        if bReportProgress:
            self.logger.info("Aligning %d marker genes with %d threads:" % (len(hmmModelFiles), self.totalThreads))

        # process each marker in parallel
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for markerId in hmmModelFiles:
            workerQueue.put(markerId)

        # one sentinel per worker so every worker leaves its loop
        for _ in range(self.totalThreads):
            workerQueue.put(None)

        calcProc = [mp.Process(target=self.__alignMarkerParallel, args=(markerSeqs, markerStats, bReportHitStats, alignOutputDir, hmmModelFiles, bKeepUnmaskedAlign, workerQueue, writerQueue)) for _ in range(self.totalThreads)]
        writeProc = mp.Process(target=self.__reportAlignmentProgress, args=(len(hmmModelFiles), bReportProgress, writerQueue))

        self.__runProcesses(calcProc, writeProc, writerQueue)

    def __alignMarkerParallel(self, markerSeqs, markerStats, bReportHitStats, alignOutputDir, hmmModelFiles, bKeepUnmaskedAlign, queueIn, queueOut):
        """Worker loop aligning each marker id read from queueIn until None is received."""

        while True:
            markerId = queueIn.get(block=True, timeout=None)
            if markerId is None:
                break

            self.__alignMarker(markerId, markerSeqs[markerId], markerStats[markerId], bReportHitStats, alignOutputDir, hmmModelFiles[markerId], bKeepUnmaskedAlign)

            queueOut.put(markerId)

    def __alignMarker(self, markerId, binSeqs, binStats, bReportHitStats, alignOutputDir, hmmModelFile, bKeepUnmaskedAlign):
        """Align the sequences of one marker and write the masked alignment.

        binSeqs  -- d[binId][seqId] -> sequence
        binStats -- d[binId][seqId] -> [e-value, score]; only read when
                    bReportHitStats is set

        Writes <markerId>.masked.faa to alignOutputDir, and keeps
        <markerId>.aligned.faa only if bKeepUnmaskedAlign is set. Nothing is
        aligned when there are no sequences.
        """

        unalignSeqFile = os.path.join(alignOutputDir, markerId + '.unaligned.faa')
        try:
            numSeqs = 0
            with open(unalignSeqFile, 'w') as fout:
                for binId, seqs in binSeqs.items():
                    for seqId, seq in seqs.items():
                        header = '>' + binId + DefaultValues.SEQ_CONCAT_CHAR + seqId
                        if bReportHitStats:
                            header += ' [e-value=%.4g,score=%.1f]' % (binStats[binId][seqId][0], binStats[binId][seqId][1])

                        fout.write(header + '\n')
                        fout.write(seq + '\n')
                        numSeqs += 1

            if numSeqs > 0:
                alignSeqFile = os.path.join(alignOutputDir, markerId + '.aligned.faa')
                HA = HMMERRunner(mode='align')
                HA.align(hmmModelFile, unalignSeqFile, alignSeqFile, writeMode='>', outputFormat=self.outputFormat, trim=False)

                maskedSeqFile = os.path.join(alignOutputDir, markerId + '.masked.faa')
                self.__maskAlignment(alignSeqFile, maskedSeqFile)

                if not bKeepUnmaskedAlign:
                    os.remove(alignSeqFile)
        finally:
            # the unaligned sequences are only an intermediate file
            if os.path.exists(unalignSeqFile):
                os.remove(unalignSeqFile)

    def __reportAlignmentProgress(self, numMarkers, bReportProgress, queueIn):
        """Report number of processed markers."""

        self.__reportProgress(' Finished aligning %d of %d (%.2f%%) marker genes.', numMarkers, bReportProgress, queueIn)

    def __maskAlignment(self, inputFile, outputFile):
        """Read HMMER alignment in STOCKHOLM format and output masked alignment in FASTA format.

        Only columns marked 'x' on the '#=GC RF' line are kept. Text found on
        '#=GS' lines is appended to the FASTA header of the matching sequence.

        Raises ValueError if sequences are present but no '#=GC RF' line.
        """

        # read STOCKHOLM alignment
        mask = None
        seqs = {}
        seqStats = {}
        with open(inputFile) as fin:
            for line in fin:
                line = line.rstrip()
                if line == '' or line == '//':
                    continue

                if line[0] == '#':
                    if 'GC RF' in line:
                        mask = line.split('GC RF')[1].strip()
                    elif '=GS' in line:
                        # read additional sequence information
                        lineSplit = line.split()
                        seqStats[lineSplit[1]] = lineSplit[3].strip()
                    continue

                lineSplit = line.split()
                seqs[lineSplit[0]] = lineSplit[1].upper().replace('.', '-').strip()

        if seqs and mask is None:
            raise ValueError("No '#=GC RF' mask line found in alignment: %s" % inputFile)

        # output masked sequences in FASTA format
        with open(outputFile, 'w') as fout:
            for seqId, seq in seqs.items():
                if seqId in seqStats:
                    fout.write('>%s %s\n' % (seqId, seqStats[seqId]))
                else:
                    fout.write('>' + seqId + '\n')

                maskedSeq = ''.join([residue for residue, maskChar in zip(seq, mask) if maskChar == 'x'])
                fout.write(maskedSeq + '\n')

    def __readBinORFs(self, outDir, binId):
        """Read the amino acid gene sequences of a bin."""

        aaGeneFile = os.path.join(outDir, 'bins', binId, DefaultValues.PRODIGAL_AA)
        return readFasta(aaGeneFile)

    def __extractMarkerSeqsTopHits(self, outDir, resultsParser):
        """Extract marker sequences from top hits (assume all bins use the same HMM file).

        Returns (markerSeqs, markerStats) where
          markerSeqs[markerId][binId][seqId]  -> sequence
          markerStats[markerId][binId][seqId] -> [e-value, score]
        holding, for each marker in each bin, only the hit with the lowest
        e-value.
        """

        markerSeqs = defaultdict(dict)
        markerStats = defaultdict(dict)
        for binId in resultsParser.results:
            # read ORFs for bin
            binORFs = self.__readBinORFs(outDir, binId)

            # extract ORFs hitting a marker
            for markerId, hits in resultsParser.results[binId].markerHits.items():
                markerSeqs[markerId][binId] = {}
                markerStats[markerId][binId] = {}

                if not hits:
                    continue

                # keep the in-place ordering (highest to lowest e-value) that
                # users of the returned parser have always been given
                hits.sort(key=lambda x: x.full_e_value, reverse=True)

                # the top hit is the one with the LOWEST e-value; ties are
                # broken in favour of the higher score
                topHit = min(hits, key=lambda x: (x.full_e_value, -x.full_score))
                markerSeqs[markerId][binId][topHit.target_name] = self.__extractSeq(topHit.target_name, binORFs)
                markerStats[markerId][binId][topHit.target_name] = [topHit.full_e_value, topHit.full_score]

        return markerSeqs, markerStats

    def __extractMarkerSeqsUnique(self, outDir, resultsParser):
        """Extract marker sequences with a single unique hit.

        Returns (markerSeqs, markerStats) in the same layout as
        __extractMarkerSeqsTopHits(); markers with several hits in a bin get
        an empty entry for that bin.
        """

        markerSeqs = defaultdict(dict)
        markerStats = defaultdict(dict)
        for binId in resultsParser.results:
            # read ORFs for bin
            binORFs = self.__readBinORFs(outDir, binId)

            # extract ORFs hitting a marker
            for markerId, hits in resultsParser.results[binId].markerHits.items():
                markerSeqs[markerId][binId] = {}
                markerStats[markerId][binId] = {}

                # only record hits which are unique
                if len(hits) == 1:
                    hit = hits[0]
                    markerSeqs[markerId][binId][hit.target_name] = self.__extractSeq(hit.target_name, binORFs)
                    markerStats[markerId][binId][hit.target_name] = [hit.full_e_value, hit.full_score]

        return markerSeqs, markerStats

    def __stripTerminalStop(self, seq):
        """Remove the final '*' inserted by prodigal, if present."""

        # endswith() is safe for empty sequences, unlike seq[-1]
        if seq.endswith('*'):
            return seq[0:-1]

        return seq

    def __extractSeq(self, seqId, seqs):
        """Extract sequence data.

        An id containing DefaultValues.SEQ_CONCAT_CHAR names several ORFs,
        which are concatenated in the order given.
        """

        if DefaultValues.SEQ_CONCAT_CHAR in seqId:
            orfIds = seqId.split(DefaultValues.SEQ_CONCAT_CHAR)
            return ''.join([self.__stripTerminalStop(seqs[orfId]) for orfId in orfIds])

        return self.__stripTerminalStop(seqs[seqId])

    def __extractMarkersWithMultipleHits(self, outDir, binId, resultsParser, binMarkerSet):
        """Extract markers with multiple hits within a single bin.

        Returns d[markerId][binId][seqId] -> sequence for the selected marker
        genes of the bin that have two or more hits.
        """

        markersWithMultipleHits = defaultdict(dict)

        binORFs = self.__readBinORFs(outDir, binId)

        markerGenes = binMarkerSet.selectedMarkerSet().getMarkerGenes()
        for markerId, hits in resultsParser.results[binId].markerHits.items():
            if markerId not in markerGenes or len(hits) < 2:
                continue

            # sort hits from highest to lowest e-value in order to ensure only the best hit
            # to a given target is retained (later entries overwrite earlier ones)
            hits.sort(key=lambda x: x.full_e_value, reverse=True)

            # Note: this data structure is used to mimic that used by __extractMarkerSeqsTopHits()
            markersWithMultipleHits[markerId][binId] = {}
            for hit in hits:
                markersWithMultipleHits[markerId][binId][hit.target_name] = self.__extractSeq(hit.target_name, binORFs)

        return markersWithMultipleHits

    def __makeAlignmentModels(self, hmmModelFile, modelKeys, hmmModelFiles, bReportProgress=True):
        """Make temporary HMM files used to create HMM alignments.

        hmmModelFiles is filled in place with d[modelId] -> temporary file.
        The caller is responsible for removing these files.
        """

        if bReportProgress:
            self.logger.info("Extracting %d HMMs with %d threads:" % (len(modelKeys), self.totalThreads))

        # process each marker in parallel
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for modelId in modelKeys:
            fetchFilename = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
            hmmModelFiles[modelId] = fetchFilename
            workerQueue.put((modelId, fetchFilename))

        # one sentinel per worker so every worker leaves its loop
        for _ in range(self.totalThreads):
            workerQueue.put((None, None))

        calcProc = [mp.Process(target=self.__extractModel, args=(hmmModelFile, workerQueue, writerQueue)) for _ in range(self.totalThreads)]
        writeProc = mp.Process(target=self.__reportModelExtractionProgress, args=(len(modelKeys), bReportProgress, writerQueue))

        self.__runProcesses(calcProc, writeProc, writerQueue)

    def __extractModel(self, hmmModelFile, queueIn, queueOut):
        """Extract HMM.

        Worker loop: reads (modelId, fetchFilename) pairs from queueIn until
        a pair with a None model id is received.
        """
        HF = HMMERRunner(mode='fetch')

        while True:
            modelId, fetchFilename = queueIn.get(block=True, timeout=None)
            if modelId is None:
                break

            HF.fetch(hmmModelFile, modelId, fetchFilename)

            queueOut.put(modelId)

    def __reportModelExtractionProgress(self, numMarkers, bReportProgress, queueIn):
        """Report number of extracted HMMs."""

        self.__reportProgress(' Finished extracting %d of %d (%.2f%%) HMMs.', numMarkers, bReportProgress, queueIn)