1###############################################################################
2#
3# binTools.py - functions for exploring and modifying bins
4#
5###############################################################################
6#                                                                             #
7#    This program is free software: you can redistribute it and/or modify     #
8#    it under the terms of the GNU General Public License as published by     #
9#    the Free Software Foundation, either version 3 of the License, or        #
10#    (at your option) any later version.                                      #
11#                                                                             #
12#    This program is distributed in the hope that it will be useful,          #
13#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
14#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            #
15#    GNU General Public License for more details.                             #
16#                                                                             #
17#    You should have received a copy of the GNU General Public License        #
18#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
19#                                                                             #
20###############################################################################
21
22import os
23import sys
24import logging
25import gzip
26
27import numpy as np
28
29from common import binIdFromFilename, checkFileExists, readDistribution, findNearest
30from checkm.util.seqUtils import readFasta, writeFasta, baseCount
31from checkm.genomicSignatures import GenomicSignatures
32from checkm.prodigal import ProdigalGeneFeatureParser
33from checkm.defaultValues import DefaultValues
34
35
class BinTools():
    """Functions for exploring and modifying bins."""

    def __init__(self, threads=1):
        # 'threads' is accepted for interface compatibility; it is not used by this class.
        self.logger = logging.getLogger('timestamp')

    def __removeSeqs(self, seqs, seqsToRemove):
        """Remove sequences in place.

        seqs: dict of sequence id -> sequence; modified in place.
        seqsToRemove: iterable of sequence ids. An id may appear more than
            once; each distinct id is removed once.

        Logs an error and exits with status 1 if any id is not in seqs.
        """
        # materialise once: avoids exhausting a one-shot iterable during the
        # missing-id check and makes repeated ids harmless
        idsToRemove = set(seqsToRemove)

        missingSeqIds = idsToRemove.difference(seqs)
        if missingSeqIds:
            self.logger.error('  [Error] Missing sequence(s) specified for removal: ' + ', '.join(sorted(missingSeqIds)) + '\n')
            sys.exit(1)

        for seqId in idsToRemove:
            del seqs[seqId]

    def __addSeqs(self, seqs, refSeqs, seqsToAdd):
        """Copy sequences from refSeqs into seqs in place.

        seqs: dict of sequence id -> sequence; modified in place.
        refSeqs: dict of sequence id -> sequence to copy from.
        seqsToAdd: iterable of sequence ids to copy.

        Logs an error and exits with status 1 if any id is not in refSeqs.
        """
        # materialise once so a one-shot iterable is not consumed by the check
        idsToAdd = list(seqsToAdd)

        missingSeqIds = set(idsToAdd).difference(refSeqs)
        if missingSeqIds:
            self.logger.error('  [Error] Missing sequence(s) specified for addition: ' + ', '.join(sorted(missingSeqIds)) + '\n')
            sys.exit(1)

        for seqId in idsToAdd:
            seqs[seqId] = refSeqs[seqId]

    def modify(self, binFile, seqFile, seqsToAdd, seqsToRemove, outputFile):
        """Add and/or remove sequences from a bin and write the result.

        binFile: FASTA file of the bin to modify.
        seqFile: FASTA file to take added sequences from (read only when
            seqsToAdd is not None).
        seqsToAdd: sequence ids to add, or None.
        seqsToRemove: sequence ids to remove, or None. Removal happens after
            addition.
        outputFile: FASTA file the modified bin is written to.
        """
        binSeqs = readFasta(binFile)

        if seqsToAdd is not None:
            refSeqs = readFasta(seqFile)
            self.__addSeqs(binSeqs, refSeqs, seqsToAdd)

        if seqsToRemove is not None:
            self.__removeSeqs(binSeqs, seqsToRemove)

        writeFasta(binSeqs, outputFile)

    def removeOutliers(self, binFile, outlierFile, outputFile):
        """Remove the sequences listed as outliers for this bin.

        binFile: FASTA file of the bin; its bin id comes from binIdFromFilename.
        outlierFile: tab-separated file with a header line. Column 1 is the bin
            id and column 2 the sequence id, as written by identifyOutliers.
        outputFile: FASTA file the cleaned bin is written to.
        """
        binSeqs = readFasta(binFile)
        binIdToModify = binIdFromFilename(binFile)

        checkFileExists(outlierFile)
        seqsToRemove = []
        with open(outlierFile) as fin:
            next(fin, None)  # skip header line
            for line in fin:
                # strip the line terminator so a 2-column row does not leave
                # '\n' on the sequence id
                lineSplit = line.rstrip('\r\n').split('\t')
                if len(lineSplit) >= 2 and lineSplit[0] == binIdToModify:
                    seqsToRemove.append(lineSplit[1])

        if seqsToRemove:
            self.__removeSeqs(binSeqs, seqsToRemove)

        writeFasta(binSeqs, outputFile)

    def unique(self, binFiles):
        """Report sequences that occur in more than one bin.

        Also warns about sequence ids repeated within a single bin. Files ending
        in '.gz' are read with gzip. Results are printed to stdout.
        """
        # collect the sequence ids of every bin, warning on in-bin duplicates
        binSeqs = {}
        for f in binFiles:
            binId = binIdFromFilename(f)
            openFile = gzip.open if f.endswith('.gz') else open

            seqIds = set()
            with openFile(f) as fin:
                for line in fin:
                    if not line.startswith('>'):
                        continue

                    fields = line[1:].split(None, 1)
                    if not fields:
                        continue  # header line without an identifier

                    seqId = fields[0]
                    if seqId in seqIds:
                        print('  [Warning] Sequence %s found multiple times in bin %s.' % (seqId, binId))
                    seqIds.add(seqId)

            binSeqs[binId] = seqIds

        # compare every unordered pair of bins once
        bDuplicates = False
        binIds = list(binSeqs.keys())
        for i, binIdA in enumerate(binIds):
            for binIdB in binIds[i + 1:]:
                seqInter = binSeqs[binIdA].intersection(binSeqs[binIdB])
                if seqInter:
                    bDuplicates = True
                    print('  Sequences shared between %s and %s: ' % (binIdA, binIdB))
                    for seqId in sorted(seqInter):
                        print('    ' + seqId)
                    print('')

        if not bDuplicates:
            print('  No sequences assigned to multiple bins.')

    def gcDist(self, seqs):
        """GC statistics for a bin.

        seqs: dict of sequence id -> sequence.

        Returns (meanGC, deltaGCs, GCs):
            GCs: per-sequence GC fraction, in seqs iteration order.
            meanGC: pooled GC fraction over all counted bases.
            deltaGCs: numpy array of GCs - meanGC.
        GC fractions use only the a/c/g/t counts returned by baseCount.
        """
        GCs = []
        gcTotal = 0
        basesTotal = 0
        for seq in seqs.values():
            a, c, g, t = baseCount(seq)
            gc = g + c
            bases = a + c + g + t

            GCs.append(float(gc) / bases)

            gcTotal += gc
            basesTotal += bases

        meanGC = float(gcTotal) / basesTotal
        deltaGCs = np.array(GCs) - meanGC

        return meanGC, deltaGCs, GCs

    def codingDensityDist(self, seqs, prodigalParser):
        """Coding density statistics for a bin.

        seqs: dict of sequence id -> sequence.
        prodigalParser: object whose codingBases(seqId) gives the number of
            coding bases on a sequence.

        Returns (meanCD, deltaCDs, CDs):
            CDs: per-sequence coding fraction, in seqs iteration order.
            meanCD: coding bases over total bases for the whole bin.
            deltaCDs: numpy array of CDs - meanCD.
        """
        CDs = []
        codingBasesTotal = 0
        basesTotal = 0
        for seqId, seq in seqs.items():
            codingBases = prodigalParser.codingBases(seqId)
            seqLen = len(seq)

            CDs.append(float(codingBases) / seqLen)
            codingBasesTotal += codingBases
            basesTotal += seqLen

        meanCD = float(codingBasesTotal) / basesTotal
        deltaCDs = np.array(CDs) - meanCD

        return meanCD, deltaCDs, CDs

    def binTetraSig(self, seqs, tetraSigs):
        """Length-weighted tetranucleotide signature of a bin.

        seqs: dict of sequence id -> sequence.
        tetraSigs: dict of sequence id -> signature. Presumably numpy arrays
            (they are scaled by a float and summed with +=); TODO confirm.

        Returns the sum over sequences of signature * (length / bin length).
        Returns None if seqs is empty.
        """
        binSize = sum(len(seq) for seq in seqs.values())

        binSig = None
        for seqId, seq in seqs.items():
            # the product is a new object, so += below never mutates tetraSigs
            weightedTetraSig = tetraSigs[seqId] * (float(len(seq)) / binSize)
            if binSig is None:
                binSig = weightedTetraSig
            else:
                binSig += weightedTetraSig

        return binSig

    def tetraDiffDist(self, seqs, genomicSig, tetraSigs, binSig):
        """Tetranucleotide distance (TD) statistics for a bin.

        Returns (meanTD, deltaTDs). deltaTDs is a float numpy array holding
        genomicSig.distance(tetraSigs[seqId], binSig) for each sequence, in
        seqs iteration order.
        """
        deltaTDs = np.array([genomicSig.distance(tetraSigs[seqId], binSig) for seqId in seqs],
                            dtype=float)

        return np.mean(deltaTDs), deltaTDs

    def identifyOutliers(self, outDir, binFiles, tetraProfileFile, distribution, reportType, outputFile):
        """Identify sequences with outlying GC, coding density or tetranucleotide distance.

        outDir: output directory; the Prodigal GFF file for each bin is expected
            at outDir/bins/<binId>/DefaultValues.PRODIGAL_GFF.
        binFiles: FASTA files of the bins to examine.
        tetraProfileFile: tetranucleotide profile file, read with GenomicSignatures.
        distribution: percentage of the reference distribution used to set bounds.
        reportType: 'any' reports a sequence outlying in at least one statistic;
            'all' reports only sequences outlying in GC, CD and TD.
        outputFile: tab-separated report that is written.

        Logs an error and exits with status 1 if a bin's gene feature file is missing.
        """
        self.logger.info('Reading reference distributions.')
        gcBounds = readDistribution('gc_dist')
        cdBounds = readDistribution('cd_dist')
        tdBounds = readDistribution('td_dist')

        # the TD distribution does not depend on the bin, so resolve its keys once
        tdSeqLens = list(tdBounds.keys())
        tdBoundKey = findNearest(list(tdBounds[tdSeqLens[0]].keys()), distribution)

        with open(outputFile, 'w') as fout:
            fout.write('Bin Id\tSequence Id\tSequence length\tOutlying distributions')
            fout.write('\tSequence GC\tMean bin GC\tLower GC bound (%s%%)\tUpper GC bound (%s%%)' % (distribution, distribution))
            fout.write('\tSequence CD\tMean bin CD\tLower CD bound (%s%%)' % distribution)
            fout.write('\tSequence TD\tMean bin TD\tUpper TD bound (%s%%)\n' % distribution)

            # the tetranucleotide profile is shared by all bins; read it once,
            # on the first bin, instead of once per bin
            genomicSig = None
            tetraSigs = None

            for processedBins, binFile in enumerate(binFiles, 1):
                binId = binIdFromFilename(binFile)
                self.logger.info('Finding outliers in %s (%d of %d).' % (binId, processedBins, len(binFiles)))

                seqs = readFasta(binFile)

                meanGC, deltaGCs, seqGC = self.gcDist(seqs)

                if tetraSigs is None:
                    genomicSig = GenomicSignatures(K=4, threads=1)
                    tetraSigs = genomicSig.read(tetraProfileFile)
                binSig = self.binTetraSig(seqs, tetraSigs)
                meanTD, deltaTDs = self.tetraDiffDist(seqs, genomicSig, tetraSigs, binSig)

                gffFile = os.path.join(outDir, 'bins', binId, DefaultValues.PRODIGAL_GFF)
                if not os.path.exists(gffFile):
                    self.logger.error('  [Error] Missing gene feature file (%s). This command is not compatible with the --genes option.\n' % DefaultValues.PRODIGAL_GFF)
                    sys.exit(1)

                prodigalParser = ProdigalGeneFeatureParser(gffFile)
                meanCD, deltaCDs, CDs = self.codingDensityDist(seqs, prodigalParser)

                # pick the reference GC distribution nearest this bin's mean GC;
                # its percentile keys are taken from its first sequence-length entry
                closestGC = findNearest(np.array(list(gcBounds.keys())), meanGC)
                gcSeqLens = list(gcBounds[closestGC].keys())
                gcPercentiles = list(gcBounds[closestGC][gcSeqLens[0]].keys())
                gcLowerBoundKey = findNearest(gcPercentiles, (100 - distribution) / 2.0)
                gcUpperBoundKey = findNearest(gcPercentiles, (100 + distribution) / 2.0)

                # same for coding density (only a lower bound is used)
                closestCD = findNearest(np.array(list(cdBounds.keys())), meanCD)
                cdSeqLens = list(cdBounds[closestCD].keys())
                cdPercentiles = list(cdBounds[closestCD][cdSeqLens[0]].keys())
                cdLowerBoundKey = findNearest(cdPercentiles, (100 - distribution) / 2.0)

                # deltaGCs/deltaCDs/deltaTDs were built in seqs iteration order
                for index, (seqId, seq) in enumerate(seqs.items()):
                    seqLen = len(seq)

                    closestSeqLen = findNearest(gcSeqLens, seqLen)
                    gcLowerBound = gcBounds[closestGC][closestSeqLen][gcLowerBoundKey]
                    gcUpperBound = gcBounds[closestGC][closestSeqLen][gcUpperBoundKey]

                    closestSeqLen = findNearest(cdSeqLens, seqLen)
                    cdLowerBound = cdBounds[closestCD][closestSeqLen][cdLowerBoundKey]

                    closestSeqLen = findNearest(tdSeqLens, seqLen)
                    tdBound = tdBounds[closestSeqLen][tdBoundKey]

                    outlyingDists = []
                    if deltaGCs[index] < gcLowerBound or deltaGCs[index] > gcUpperBound:
                        outlyingDists.append('GC')

                    if deltaCDs[index] < cdLowerBound:
                        outlyingDists.append('CD')

                    if deltaTDs[index] > tdBound:
                        outlyingDists.append('TD')

                    if (reportType == 'any' and len(outlyingDists) >= 1) or (reportType == 'all' and len(outlyingDists) == 3):
                        fout.write(binId + '\t' + seqId + '\t%d' % seqLen + '\t' + ','.join(outlyingDists))
                        fout.write('\t%.1f\t%.1f\t%.1f\t%.1f' % (seqGC[index] * 100, meanGC * 100, (meanGC + gcLowerBound) * 100, (meanGC + gcUpperBound) * 100))
                        fout.write('\t%.1f\t%.1f\t%.1f' % (CDs[index] * 100, meanCD * 100, (meanCD + cdLowerBound) * 100))
                        fout.write('\t%.3f\t%.3f\t%.3f' % (deltaTDs[index], meanTD, tdBound) + '\n')
297