1###############################################################################
2#
3# markerGeneFinder.py - identify marker genes in genome bins
4#
5###############################################################################
6#                                                                             #
7#    This program is free software: you can redistribute it and/or modify     #
8#    it under the terms of the GNU General Public License as published by     #
9#    the Free Software Foundation, either version 3 of the License, or        #
10#    (at your option) any later version.                                      #
11#                                                                             #
12#    This program is distributed in the hope that it will be useful,          #
13#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
14#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            #
15#    GNU General Public License for more details.                             #
16#                                                                             #
17#    You should have received a copy of the GNU General Public License        #
18#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
19#                                                                             #
20###############################################################################
21
22import os
23import sys
24import shutil
25import multiprocessing as mp
26import logging
27
28from checkm.common import binIdFromFilename, makeSurePathExists
29from checkm.defaultValues import DefaultValues
30
31from checkm.hmmer import HMMERRunner
32from checkm.prodigal import ProdigalRunner
33
34from checkm.markerSets import MarkerSetParser
35from checkm.hmmerModelParser import HmmModelParser
36
37
class MarkerGeneFinder():
    """Identify marker genes within binned sequences using Prodigal and HMMER.

    Bins are distributed over a pool of worker processes. Each worker calls
    genes with Prodigal (unless genes were already called) and searches them
    against the marker HMMs with HMMER. A separate reporter process parses
    the per-bin HMM model files and writes a progress line to stderr.
    """

    def __init__(self, threads):
        """Store the thread budget and obtain the 'timestamp' logger.

        threads: total number of threads; used both as the number of worker
                 processes and to derive the per-search '--cpu' value.
        """
        self.logger = logging.getLogger('timestamp')
        self.totalThreads = threads

    def find(self, binFiles, outDir, tableOut, hmmerOut, markerFile, bKeepAlignment, bNucORFs, bCalledGenes):
        """Identify marker genes in each bin using prodigal and HMMER.

        binFiles:       sequence of bin files; each one is handed to Prodigal,
                        or used directly as the amino acid gene file when
                        bCalledGenes is set.
        outDir:         output directory; results for a bin are written to
                        <outDir>/bins/<binId>.
        tableOut:       name of the HMMER table file within each bin directory.
        hmmerOut:       name of the HMMER output file within each bin directory.
        markerFile:     marker file passed to MarkerSetParser.createHmmModelFile.
        bKeepAlignment: if False, '--noali' is added to the HMMER options.
        bNucORFs:       flag forwarded to ProdigalRunner.
        bCalledGenes:   if True, Prodigal is skipped.

        Returns a standard dict mapping each processed bin id to the models
        parsed from its HMM model file. An empty bin list yields an empty
        dict without starting any process.

        Any exception raised in the parent while the processes are being
        started or joined (including KeyboardInterrupt) is re-raised after
        the child processes that are still alive have been terminated.
        """

        # nothing to process; returning early also avoids a division by zero
        # when the thread budget is split across the bins below
        if len(binFiles) == 0:
            return {}

        # make sure HMMER and prodigal are on system path
        HMMERRunner()

        if not bCalledGenes:
            ProdigalRunner('')

        # process each fasta file
        self.threadsPerSearch = max(1, int(self.totalThreads / len(binFiles)))
        self.logger.info("Identifying marker genes in %d bins with %d threads:" % (len(binFiles), self.totalThreads))

        # process each bin in parallel
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        for binFile in binFiles:
            workerQueue.put(binFile)

        # one None sentinel per worker so every worker leaves its loop
        for _ in range(self.totalThreads):
            workerQueue.put(None)

        binIdToModels = mp.Manager().dict()

        # bound before the try block so the cleanup handler can always use them
        calcProc = []
        writeProc = None
        try:
            calcProc = [mp.Process(target=self.__processBin, args=(outDir, tableOut, hmmerOut, markerFile, bKeepAlignment, bNucORFs, bCalledGenes, workerQueue, writerQueue)) for _ in range(self.totalThreads)]
            writeProc = mp.Process(target=self.__reportProgress, args=(len(binFiles), binIdToModels, writerQueue))

            writeProc.start()

            for p in calcProc:
                p.start()

            for p in calcProc:
                p.join()

            # all workers are done; tell the reporter to stop
            writerQueue.put((None, None))
            writeProc.join()
        except BaseException:
            # make sure all processes are terminated; processes that were
            # never started or have already exited are skipped
            for p in calcProc:
                if p.is_alive():
                    p.terminate()

            if writeProc is not None and writeProc.is_alive():
                writeProc.terminate()

            # do not hide the failure: partial results must not be returned
            # as if the run had completed
            raise

        # create a standard dictionary from the managed dictionary
        return {binId: binIdToModels[binId] for binId in binIdToModels.keys()}

    def __processBin(self, outDir, tableOut, hmmerOut, markerFile, bKeepAlignment, bNucORFs, bCalledGenes, queueIn, queueOut):
        """Worker loop: process bins taken from queueIn until None is received.

        For every bin a (binId, hmmModelFile) tuple is put on queueOut once
        the HMMER search has been run. The remaining arguments have the same
        meaning as in find().
        """

        markerSetParser = MarkerSetParser(self.threadsPerSearch)

        # the HMMER options do not depend on the bin, so build them once
        keepAlignStr = ''
        if not bKeepAlignment:
            keepAlignStr = '--noali'
        searchOptions = '--cpu ' + str(self.threadsPerSearch) + ' --notextw -E 0.1 --domE 0.1 ' + keepAlignStr

        while True:
            binFile = queueIn.get(block=True, timeout=None)
            if binFile is None:
                break

            binId = binIdFromFilename(binFile)
            binDir = os.path.join(outDir, 'bins', binId)
            makeSurePathExists(binDir)

            # run Prodigal
            if not bCalledGenes:
                prodigal = ProdigalRunner(binDir)
                if not prodigal.areORFsCalled(bNucORFs):
                    prodigal.run(binFile, bNucORFs)
                aaGeneFile = prodigal.aaGeneFile
            else:
                # genes were already called: the bin file itself is the gene
                # file and is copied into the bin directory
                aaGeneFile = binFile
                shutil.copyfile(aaGeneFile, os.path.join(binDir, DefaultValues.PRODIGAL_AA))

            # extract HMMs into temporary file
            hmmModelFile = markerSetParser.createHmmModelFile(binId, markerFile)

            # run HMMER
            hmmer = HMMERRunner()
            tableOutPath = os.path.join(binDir, tableOut)
            hmmerOutPath = os.path.join(binDir, hmmerOut)

            hmmer.search(hmmModelFile, aaGeneFile, tableOutPath, hmmerOutPath,
                         searchOptions,
                         bKeepAlignment)

            queueOut.put((binId, hmmModelFile))

    def __writeStatus(self, numProcessedBins, numBins):
        """Write the progress line to stderr if the logger level is INFO or lower."""

        if self.logger.getEffectiveLevel() <= logging.INFO:
            statusStr = '    Finished processing %d of %d (%.2f%%) bins.' % (numProcessedBins, numBins, float(numProcessedBins) * 100 / numBins)
            # carriage return so the next status overwrites this line
            sys.stderr.write('%s\r' % statusStr)
            sys.stderr.flush()

    def __reportProgress(self, numBins, binIdToModels, queueIn):
        """Report number of processed bins and collect the parsed models.

        numBins:       total number of bins being processed; must be greater
                       than zero as it is used to compute a percentage.
        binIdToModels: mapping that receives the parsed models for each bin id.
        queueIn:       queue of (binId, hmmModelFile) tuples; a tuple whose
                       binId is None ends the loop.

        The HMM model file of each bin, and its '.ssi' index if present, is
        removed once the models have been parsed.
        """

        numProcessedBins = 0
        self.__writeStatus(numProcessedBins, numBins)

        while True:
            binId, hmmModelFile = queueIn.get(block=True, timeout=None)
            if binId is None:
                break

            # parse HMM file
            # (This is done here as pushing the models onto the shared queue is too memory intensive)
            modelParser = HmmModelParser(hmmModelFile)
            models = modelParser.models()

            binIdToModels[binId] = models

            if os.path.exists(hmmModelFile):
                os.remove(hmmModelFile)

            indexFile = hmmModelFile + '.ssi'
            if os.path.exists(indexFile):
                os.remove(indexFile)

            numProcessedBins += 1
            self.__writeStatus(numProcessedBins, numBins)

        # finish the progress line
        if self.logger.getEffectiveLevel() <= logging.INFO:
            sys.stderr.write('\n')
174