1###############################################################################
2#
3# hmmerAlign.py - runs HMMER and provides functions for parsing output
4#
5###############################################################################
6#                                                                             #
7#    This program is free software: you can redistribute it and/or modify     #
8#    it under the terms of the GNU General Public License as published by     #
9#    the Free Software Foundation, either version 3 of the License, or        #
10#    (at your option) any later version.                                      #
11#                                                                             #
12#    This program is distributed in the hope that it will be useful,          #
13#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
14#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            #
15#    GNU General Public License for more details.                             #
16#                                                                             #
17#    You should have received a copy of the GNU General Public License        #
18#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
19#                                                                             #
20###############################################################################
21
22import os
23import sys
24import uuid
25import logging
26import tempfile
27import shutil
28import multiprocessing as mp
29from collections import defaultdict
30
31from checkm.defaultValues import DefaultValues
32from checkm.common import makeSurePathExists
33from checkm.util.seqUtils import readFasta
34from checkm.hmmer import HMMERRunner
35from checkm.resultsParser import ResultsParser
36
37
38class HmmerAligner:
39    def __init__(self, threads):
40        self.logger = logging.getLogger('timestamp')
41        self.totalThreads = threads
42
43        self.outputFormat = 'Pfam'
44
45    def makeAlignmentTopHit(self,
46                               outDir,
47                               hmmModelFile,
48                               hmmTableFile,
49                               binIdToModels,
50                               bIgnoreThresholds,
51                               evalueThreshold,
52                               lengthThreshold,
53                               bReportHitStats,
54                               alignOutputDir,
55                               bKeepUnmaskedAlign=False
56                               ):
57        """Align top hits in each bin. Assumes all bins are using the same marker genes."""
58
59        self.logger.info("Extracting marker genes to align.")
60
61        # parse HMM information
62        resultsParser = ResultsParser(binIdToModels)
63
64        # get HMM hits to each bin
65        resultsParser.parseBinHits(outDir, hmmTableFile, False, bIgnoreThresholds, evalueThreshold, lengthThreshold)
66
67        # extract the ORFs to align
68        markerSeqs, markerStats = self.__extractMarkerSeqsTopHits(outDir, resultsParser)
69
70        # generate individual HMMs required to create multiple sequence alignments
71        binId = binIdToModels.keys()[0]
72        hmmModelFiles = {}
73        self.__makeAlignmentModels(hmmModelFile, binIdToModels[binId], hmmModelFiles)
74
75        # align each of the marker genes
76        makeSurePathExists(alignOutputDir)
77        self.__alignMarkerGenes(markerSeqs, markerStats, bReportHitStats, hmmModelFiles, alignOutputDir, bKeepUnmaskedAlign)
78
79        # remove the temporary HMM files
80        for fileName in hmmModelFiles:
81            os.remove(hmmModelFiles[fileName])
82
83        return resultsParser
84
85    def makeAlignmentToPhyloMarkers(self,
86                                       outDir,
87                                       hmmModelFile,
88                                       hmmTableFile,
89                                       binIdToModels,
90                                       bIgnoreThresholds,
91                                       evalueThreshold,
92                                       lengthThreshold,
93                                       bReportHitStats,
94                                       alignOutputDir,
95                                       bKeepUnmaskedAlign=False
96                                       ):
97        """Align hits to a set of common marker genes."""
98
99        self.logger.info("Extracting marker genes to align.")
100
101        # parse HMM information
102        resultsParser = ResultsParser(binIdToModels)
103
104        # get HMM hits to each bin
105        resultsParser.parseBinHits(outDir, hmmTableFile, False, bIgnoreThresholds, evalueThreshold, lengthThreshold)
106
107        # extract the ORFs to align
108        markerSeqs, markerStats = self.__extractMarkerSeqsUnique(outDir, resultsParser)
109
110        # generate individual HMMs required to create multiple sequence alignments
111        binId = binIdToModels.keys()[0]
112        hmmModelFiles = {}
113        self.__makeAlignmentModels(hmmModelFile, binIdToModels[binId], hmmModelFiles)
114
115        # align each of the marker genes
116        makeSurePathExists(alignOutputDir)
117        self.__alignMarkerGenes(markerSeqs, markerStats, bReportHitStats, hmmModelFiles, alignOutputDir, bKeepUnmaskedAlign)
118
119        # remove the temporary HMM files
120        for fileName in hmmModelFiles:
121            os.remove(hmmModelFiles[fileName])
122
123        return resultsParser
124
125    def makeAlignmentsOfMultipleHits(self,
126                                       outDir,
127                                       markerFile,
128                                       hmmTableFile,
129                                       binIdToModels,
130                                       binIdToBinMarkerSets,
131                                       bIgnoreThresholds,
132                                       evalueThreshold,
133                                       lengthThreshold,
134                                       alignOutputDir,
135                                       ):
136        """Align markers with multiple hits within a bin."""
137
138        makeSurePathExists(alignOutputDir)
139
140        # parse HMM information
141        resultsParser = ResultsParser(binIdToModels)
142
143        # get HMM hits to each bin
144        resultsParser.parseBinHits(outDir, hmmTableFile, False, bIgnoreThresholds, evalueThreshold, lengthThreshold)
145
146        # align any markers with multiple hits in a bin
147        self.logger.info('Aligning marker genes with multiple hits in a single bin:')
148
149        # process each bin in parallel
150        workerQueue = mp.Queue()
151        writerQueue = mp.Queue()
152
153        for binId in binIdToModels:
154            workerQueue.put(binId)
155
156        for _ in range(self.totalThreads):
157            workerQueue.put(None)
158
159        try:
160            calcProc = [mp.Process(target=self.__createMSA, args=(resultsParser, binIdToBinMarkerSets, markerFile, outDir, alignOutputDir, workerQueue, writerQueue)) for _ in range(self.totalThreads)]
161            writeProc = mp.Process(target=self.__reportBinProgress, args=(len(binIdToModels), writerQueue))
162
163            writeProc.start()
164
165            for p in calcProc:
166                p.start()
167
168            for p in calcProc:
169                p.join()
170
171            writerQueue.put(None)
172            writeProc.join()
173        except:
174            # make sure all processes are terminated
175            for p in calcProc:
176                p.terminate()
177
178            writeProc.terminate()
179
180    def __createMSA(self, resultsParser, binIdToBinMarkerSets, hmmModelFile, outDir, alignOutputDir, queueIn, queueOut):
181        """Create multiple sequence alignment for markers with multiple hits in a bin."""
182
183        HF = HMMERRunner(mode='fetch')
184
185        while True:
186            binId = queueIn.get(block=True, timeout=None)
187            if binId == None:
188                break
189
190            markersWithMultipleHits = self.__extractMarkersWithMultipleHits(outDir, binId, resultsParser, binIdToBinMarkerSets[binId])
191
192            if len(markersWithMultipleHits) != 0:
193                # create multiple sequence alignments for markers with multiple hits
194                binAlignOutputDir = os.path.join(alignOutputDir, binId)
195                makeSurePathExists(binAlignOutputDir)
196                for markerId in markersWithMultipleHits:
197                    tempModelFile = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
198                    HF.fetch(hmmModelFile, markerId, tempModelFile)
199
200                    self.__alignMarker(markerId, markersWithMultipleHits[markerId], None, False, binAlignOutputDir, tempModelFile, bKeepUnmaskedAlign=False)
201
202                    os.remove(tempModelFile)
203
204            queueOut.put(binId)
205
206    def __reportBinProgress(self, numBins, queueIn):
207        """Report number of processed bins."""
208
209        numProcessedBins = 0
210        if self.logger.getEffectiveLevel() <= logging.INFO:
211            statusStr = '    Finished processing %d of %d (%.2f%%) bins.' % (numProcessedBins, numBins, float(numProcessedBins) * 100 / numBins)
212            sys.stderr.write('%s\r' % statusStr)
213            sys.stderr.flush()
214
215        while True:
216            binId = queueIn.get(block=True, timeout=None)
217            if binId == None:
218                break
219
220            if self.logger.getEffectiveLevel() <= logging.INFO:
221                numProcessedBins += 1
222                statusStr = '    Finished processing %d of %d (%.2f%%) bins.' % (numProcessedBins, numBins, float(numProcessedBins) * 100 / numBins)
223                sys.stderr.write('%s\r' % statusStr)
224                sys.stderr.flush()
225
226        if self.logger.getEffectiveLevel() <= logging.INFO:
227            sys.stderr.write('\n')
228
229    def __alignMarkerGenes(self, markerSeqs, markerStats, bReportHitStats, hmmModelFiles, alignOutputDir, bKeepUnmaskedAlign=False, bReportProgress=True):
230        """Align marker genes with HMMs in parallel."""
231
232        if bReportProgress:
233            self.logger.info("Aligning %d marker genes with %d threads:" % (len(hmmModelFiles), self.totalThreads))
234
235        # process each bin in parallel
236        workerQueue = mp.Queue()
237        writerQueue = mp.Queue()
238
239        for markerId in hmmModelFiles:
240            workerQueue.put(markerId)
241
242        for _ in range(self.totalThreads):
243            workerQueue.put(None)
244
245        try:
246            calcProc = [mp.Process(target=self.__alignMarkerParallel, args=(markerSeqs, markerStats, bReportHitStats, alignOutputDir, hmmModelFiles, bKeepUnmaskedAlign, workerQueue, writerQueue)) for _ in range(self.totalThreads)]
247            writeProc = mp.Process(target=self.__reportAlignmentProgress, args=(len(hmmModelFiles), bReportProgress, writerQueue))
248
249            writeProc.start()
250
251            for p in calcProc:
252                p.start()
253
254            for p in calcProc:
255                p.join()
256
257            writerQueue.put(None)
258            writeProc.join()
259        except:
260            # make sure all processes are terminated
261            for p in calcProc:
262                p.terminate()
263
264            writeProc.terminate()
265
266    def __alignMarkerParallel(self, markerSeqs, markerStats, bReportHitStats, alignOutputDir, hmmModelFiles, bKeepUnmaskedAlign, queueIn, queueOut):
267        while True:
268            markerId = queueIn.get(block=True, timeout=None)
269            if markerId == None:
270                break
271
272            self.__alignMarker(markerId, markerSeqs[markerId], markerStats[markerId], bReportHitStats, alignOutputDir, hmmModelFiles[markerId], bKeepUnmaskedAlign)
273
274            queueOut.put(markerId)
275
276    def __alignMarker(self, markerId, binSeqs, binStats, bReportHitStats, alignOutputDir, hmmModelFile, bKeepUnmaskedAlign):
277        unalignSeqFile = os.path.join(alignOutputDir, markerId + '.unaligned.faa')
278        fout = open(unalignSeqFile, 'w')
279        numSeqs = 0
280        for binId, seqs in binSeqs.iteritems():
281            for seqId, seq in seqs.iteritems():
282                header = '>' + binId + DefaultValues.SEQ_CONCAT_CHAR + seqId
283                if bReportHitStats:
284                    header += ' [e-value=%.4g,score=%.1f]' % (binStats[binId][seqId][0], binStats[binId][seqId][1])
285
286                fout.write(header + '\n')
287                fout.write(seq + '\n')
288                numSeqs += 1
289        fout.close()
290
291        if numSeqs > 0:
292            alignSeqFile = os.path.join(alignOutputDir, markerId + '.aligned.faa')
293            HA = HMMERRunner(mode='align')
294            HA.align(hmmModelFile, unalignSeqFile, alignSeqFile, writeMode='>', outputFormat=self.outputFormat, trim=False)
295
296            makedSeqFile = os.path.join(alignOutputDir, markerId + '.masked.faa')
297            self.__maskAlignment(alignSeqFile, makedSeqFile)
298
299            if not bKeepUnmaskedAlign:
300                os.remove(alignSeqFile)
301
302        os.remove(unalignSeqFile)
303
304    def __reportAlignmentProgress(self, numMarkers, bReportProgress, queueIn):
305        """Report number of processed markers."""
306
307        numProcessedGenes = 0
308        if bReportProgress and self.logger.getEffectiveLevel() <= logging.INFO:
309            statusStr = '    Finished aligning %d of %d (%.2f%%) marker genes.' % (numProcessedGenes, numMarkers, float(numProcessedGenes) * 100 / numMarkers)
310            sys.stderr.write('%s\r' % statusStr)
311            sys.stderr.flush()
312
313        while True:
314            binId = queueIn.get(block=True, timeout=None)
315            if binId == None:
316                break
317
318            if bReportProgress and self.logger.getEffectiveLevel() <= logging.INFO:
319                numProcessedGenes += 1
320                statusStr = '    Finished aligning %d of %d (%.2f%%) marker genes.' % (numProcessedGenes, numMarkers, float(numProcessedGenes) * 100 / numMarkers)
321                sys.stderr.write('%s\r' % statusStr)
322                sys.stderr.flush()
323
324        if bReportProgress and self.logger.getEffectiveLevel() <= logging.INFO:
325            sys.stderr.write('\n')
326
327    def __maskAlignment(self, inputFile, outputFile):
328        """Read HMMER alignment in STOCKHOLM format and output masked alignment in FASTA format."""
329        # read STOCKHOLM alignment
330        seqs = {}
331        seqStats = {}
332        for line in open(inputFile):
333            line = line.rstrip()
334            if line == '' or line[0] == '#' or line == '//':
335                if 'GC RF' in line:
336                    mask = line.split('GC RF')[1].strip()
337                elif '=GS' in line:
338                    # read additional sequence informations
339                    lineSplit = line.split()
340                    seqId = lineSplit[1]
341                    stats = lineSplit[3].strip()
342                    seqStats[seqId] = stats
343                continue
344            else:
345                lineSplit = line.split()
346                seqs[lineSplit[0]] = lineSplit[1].upper().replace('.', '-').strip()
347
348        # output masked sequences in FASTA format
349        fout = open(outputFile, 'w')
350        for seqId, seq in seqs.iteritems():
351            if seqStats:
352                fout.write('>%s %s\n' % (seqId, seqStats[seqId]))
353            else:
354                fout.write('>' + seqId + '\n')
355
356            maskedSeq = ''.join([seq[i] for i in xrange(0, len(seq)) if mask[i] == 'x'])
357            fout.write(maskedSeq + '\n')
358        fout.close()
359
360    def __extractMarkerSeqsTopHits(self, outDir, resultsParser):
361        """Extract marker sequences from top hits (assume all bins use the same HMM file)."""
362
363        markerSeqs = defaultdict(dict)
364        markerStats = defaultdict(dict)
365        for binId in resultsParser.results:
366            # read ORFs for bin
367            aaGeneFile = os.path.join(outDir, 'bins', binId, DefaultValues.PRODIGAL_AA)
368            binORFs = readFasta(aaGeneFile)
369
370            # extract ORFs hitting a marker
371            for markerId, hits in resultsParser.results[binId].markerHits.iteritems():
372                markerSeqs[markerId][binId] = {}
373                markerStats[markerId][binId] = {}
374
375                # sort hits from highest to lowest e-value in order to ensure only the best hit
376                # to a given target is retained
377                hits.sort(key=lambda x: x.full_e_value, reverse=True)
378                topHit = hits[0]
379                markerSeqs[markerId][binId][topHit.target_name] = self.__extractSeq(topHit.target_name, binORFs)
380                markerStats[markerId][binId][topHit.target_name] = [topHit.full_e_value, topHit.full_score]
381
382        return markerSeqs, markerStats
383
384    def __extractMarkerSeqsUnique(self, outDir, resultsParser):
385        """Extract marker sequences with a single unique hit."""
386
387        markerSeqs = defaultdict(dict)
388        markerStats = defaultdict(dict)
389        for binId in resultsParser.results:
390            # read ORFs for bin
391            aaGeneFile = os.path.join(outDir, 'bins', binId, DefaultValues.PRODIGAL_AA)
392            binORFs = readFasta(aaGeneFile)
393
394            # extract ORFs hitting a marker
395            for markerId, hits in resultsParser.results[binId].markerHits.iteritems():
396                markerSeqs[markerId][binId] = {}
397                markerStats[markerId][binId] = {}
398
399                # only record hits which are unique
400                if len(hits) == 1:
401                    hit = hits[0]
402                    markerSeqs[markerId][binId][hit.target_name] = self.__extractSeq(hit.target_name, binORFs)
403                    markerStats[markerId][binId][hit.target_name] = [hit.full_e_value, hit.full_score]
404
405        return markerSeqs, markerStats
406
407    def __extractSeq(self, seqId, seqs):
408        """Extract sequence data."""
409
410        if DefaultValues.SEQ_CONCAT_CHAR in seqId:
411            seqIds = seqId.split(DefaultValues.SEQ_CONCAT_CHAR)
412
413            seq = ''
414            for seqId in seqIds:
415                tempSeq = seqs[seqId]
416                if tempSeq[-1] == '*':
417                    tempSeq = tempSeq[0:-1]  # remove final '*' inserted by prodigal
418
419                seq += tempSeq
420
421            rtnSeq = seq
422        else:
423            rtnSeq = seqs[seqId]
424
425            if rtnSeq[-1] == '*':
426                rtnSeq = rtnSeq[0:-1]  # remove final '*' inserted by prodigal
427
428        return rtnSeq
429
430    def __extractMarkersWithMultipleHits(self, outDir, binId, resultsParser, binMarkerSet):
431        """Extract markers with multiple hits within a single bin."""
432
433        markersWithMultipleHits = defaultdict(dict)
434
435        aaGeneFile = os.path.join(outDir, 'bins', binId, DefaultValues.PRODIGAL_AA)
436        binORFs = readFasta(aaGeneFile)
437
438        markerGenes = binMarkerSet.selectedMarkerSet().getMarkerGenes()
439        for markerId, hits in resultsParser.results[binId].markerHits.iteritems():
440            if markerId not in markerGenes or len(hits) < 2:
441                continue
442
443            # sort hits from highest to lowest e-value in order to ensure only the best hit
444            # to a given target is retained
445            hits.sort(key=lambda x: x.full_e_value, reverse=True)
446
447            # Note: this data structure is used to mimic that used by __extractMarkerSeqsTopHits()
448            markersWithMultipleHits[markerId][binId] = {}
449            for hit in hits:
450                markersWithMultipleHits[markerId][binId][hit.target_name] = self.__extractSeq(hit.target_name, binORFs)
451
452        return markersWithMultipleHits
453
454    def __makeAlignmentModels(self, hmmModelFile, modelKeys, hmmModelFiles, bReportProgress=True):
455        """Make temporary HMM files used to create HMM alignments."""
456
457        if bReportProgress:
458            self.logger.info("Extracting %d HMMs with %d threads:" % (len(modelKeys), self.totalThreads))
459
460        # process each marker in parallel
461        workerQueue = mp.Queue()
462        writerQueue = mp.Queue()
463
464        for modelId in modelKeys:
465            fetchFilename = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
466            hmmModelFiles[modelId] = fetchFilename
467            workerQueue.put((modelId, fetchFilename))
468
469        for _ in range(self.totalThreads):
470            workerQueue.put((None, None))
471
472        try:
473            calcProc = [mp.Process(target=self.__extractModel, args=(hmmModelFile, workerQueue, writerQueue)) for _ in range(self.totalThreads)]
474            writeProc = mp.Process(target=self.__reportModelExtractionProgress, args=(len(modelKeys), bReportProgress, writerQueue))
475
476            writeProc.start()
477
478            for p in calcProc:
479                p.start()
480
481            for p in calcProc:
482                p.join()
483
484            writerQueue.put(None)
485            writeProc.join()
486        except:
487            # make sure all processes are terminated
488            for p in calcProc:
489                p.terminate()
490
491            writeProc.terminate()
492
493    def __extractModel(self, hmmModelFile, queueIn, queueOut):
494        """Extract HMM."""
495        HF = HMMERRunner(mode='fetch')
496
497        while True:
498            modelId, fetchFilename = queueIn.get(block=True, timeout=None)
499            if modelId == None:
500                break
501
502            HF.fetch(hmmModelFile, modelId, fetchFilename)
503
504            queueOut.put(modelId)
505
506    def __reportModelExtractionProgress(self, numMarkers, bReportProgress, queueIn):
507        """Report number of extracted HMMs."""
508
509        numModelsExtracted = 0
510        if bReportProgress and self.logger.getEffectiveLevel() <= logging.INFO:
511            statusStr = '    Finished extracting %d of %d (%.2f%%) HMMs.' % (numModelsExtracted, numMarkers, float(numModelsExtracted) * 100 / numMarkers)
512            sys.stderr.write('%s\r' % statusStr)
513            sys.stderr.flush()
514
515        while True:
516            modelId = queueIn.get(block=True, timeout=None)
517            if modelId == None:
518                break
519
520            if bReportProgress and self.logger.getEffectiveLevel() <= logging.INFO:
521                numModelsExtracted += 1
522                statusStr = '    Finished extracting %d of %d (%.2f%%) HMMs.' % (numModelsExtracted, numMarkers, float(numModelsExtracted) * 100 / numMarkers)
523                sys.stderr.write('%s\r' % statusStr)
524                sys.stderr.flush()
525
526        if bReportProgress and self.logger.getEffectiveLevel() <= logging.INFO:
527            sys.stderr.write('\n')
528