1#!/usr/bin/env python
2
3###############################################################################
4#
# GenomicSignatures - calculate, store, and compare k-mer (e.g., tetranucleotide) signatures of sequences
6#
7###############################################################################
8#                                                                             #
9#    This program is free software: you can redistribute it and/or modify     #
10#    it under the terms of the GNU General Public License as published by     #
11#    the Free Software Foundation, either version 3 of the License, or        #
12#    (at your option) any later version.                                      #
13#                                                                             #
14#    This program is distributed in the hope that it will be useful,          #
15#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
16#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            #
17#    GNU General Public License for more details.                             #
18#                                                                             #
19#    You should have received a copy of the GNU General Public License        #
20#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
21#                                                                             #
22###############################################################################
23
24import sys
25import multiprocessing as mp
26from string import maketrans
27import logging
28
29import numpy as np
30
31from checkm.util.seqUtils import readFasta
32
33
class GenomicSignatures(object):
    """Calculate, store, and compare k-mer genomic signatures of sequences.

    A signature is the relative frequency of each canonical k-mer in a
    sequence. A k-mer and its reverse complement count as the same canonical
    k-mer (the lexicographically lowest of the pair).
    """

    def __init__(self, K, threads):
        """Initialize the signature calculator.

        Parameters
        ----------
        K : int
            Length of the k-mers to count (e.g., 4 for tetranucleotides).
        threads : int
            Number of worker processes used by calculate().
        """
        self.logger = logging.getLogger('timestamp')

        self.K = K
        # translation table used by __revComp to complement each base
        self.compl = maketrans('ACGT', 'TGCA')
        self.kmerCols, self.kmerToCanonicalIndex = self.__makeKmerColNames()

        self.totalThreads = threads

    def __makeKmerColNames(self):
        """Work out the unique canonical k-mers.

        Returns
        -------
        list of str
            Canonical k-mers in sorted order; this order defines the layout
            of every signature vector and the output file columns.
        dict
            Maps every k-mer and its reverse complement to the index of its
            canonical form in the returned list.
        """
        bases = ('A', 'C', 'G', 'T')

        # enumerate all k-mers over the ACGT alphabet
        mers = list(bases)
        for _ in range(1, self.K):
            mers = [mer + base for mer in mers for base in bases]

        # keep one canonical form per k-mer/reverse-complement pair. The
        # explicit sort guarantees the column order instead of relying on the
        # enumeration order (the old `sorted(retList)` discarded its result).
        retList = sorted(set(self.__lexicographicallyLowest(mer) for mer in mers))

        # map both strands of each k-mer to its canonical column
        kmerToCanonicalIndex = {}
        for index, kmer in enumerate(retList):
            kmerToCanonicalIndex[kmer] = index
            kmerToCanonicalIndex[self.__revComp(kmer)] = index

        return retList, kmerToCanonicalIndex

    def __lexicographicallyLowest(self, seq):
        """Return the lexicographically lowest of seq and its reverse complement."""
        return min(seq, self.__revComp(seq))

    def __revComp(self, seq):
        """Return the reverse complement of a sequence."""
        # complement each base via the translation table, then reverse
        return seq.translate(self.compl)[::-1]

    def __calculateResults(self, queueIn, queueOut):
        """Worker loop: compute signatures for (seqId, seq) items from queueIn.

        Results are put on queueOut as (seqId, signature). The loop stops on
        the (None, None) sentinel.
        """
        while True:
            seqId, seq = queueIn.get(block=True, timeout=None)
            if seqId is None:
                break

            queueOut.put((seqId, self.seqSignature(seq)))

    def __storeResults(self, seqFile, outputFile, totalSeqs, writerQueue):
        """Write signatures from writerQueue to a tab-separated file.

        The first row is a header ('Sequence Id' followed by the canonical
        k-mers). Each following row is a sequence id and its signature
        values. Reading stops on the (None, None) sentinel. seqFile is
        unused and kept for interface compatibility.
        """
        showProgress = self.logger.getEffectiveLevel() <= logging.INFO

        with open(outputFile, 'w') as fout:
            fout.write('\t'.join(['Sequence Id'] + list(self.canonicalKmerOrder())))
            fout.write('\n')

            numProcessedSeq = 0
            while True:
                seqId, sig = writerQueue.get(block=True, timeout=None)
                if seqId is None:
                    break

                if showProgress:
                    numProcessedSeq += 1
                    statusStr = '    Finished processing %d of %d (%.2f%%) sequences.' % (numProcessedSeq, totalSeqs, float(numProcessedSeq) * 100 / totalSeqs)
                    sys.stderr.write('%s\r' % statusStr)
                    sys.stderr.flush()

                fout.write(seqId)
                fout.write('\t' + '\t'.join(map(str, sig)))
                fout.write('\n')

        if showProgress:
            sys.stderr.write('\n')

    def canonicalKmerOrder(self):
        """Return the canonical k-mers in signature-vector order."""
        return self.kmerCols

    def seqSignature(self, seq):
        """Return the normalized canonical k-mer frequency vector of seq.

        Matching is case-insensitive. K-mers containing characters other
        than ACGT (e.g., N) are skipped. If seq has no valid k-mers, the
        result is all NaN (0/0), and numpy emits a RuntimeWarning.
        """
        # a plain list is used because single-element increments on a numpy
        # array are slow
        counts = [0] * len(self.kmerCols)
        lookup = self.kmerToCanonicalIndex
        k = self.K

        upperSeq = seq.upper()
        for i in range(0, len(upperSeq) - k + 1):
            kmerIndex = lookup.get(upperSeq[i:i + k])
            if kmerIndex is not None:
                counts[kmerIndex] += 1

        # normalize to relative frequencies
        sig = np.array(counts, dtype=float)
        sig /= np.sum(sig)

        return sig

    def calculate(self, seqFile, outputFile):
        """Calculate the genomic signature of each sequence in seqFile.

        Sequences are read with readFasta and processed by self.totalThreads
        worker processes. A single writer process writes the results to
        outputFile (see __storeResults for the format). If anything fails,
        all still-running processes are terminated and the exception is
        re-raised instead of being swallowed.
        """
        self.logger.info('Determining tetranucleotide signature of each sequence.')

        # process each sequence in parallel
        workerQueue = mp.Queue()
        writerQueue = mp.Queue()

        seqs = readFasta(seqFile)

        for seqId, seq in seqs.iteritems():
            workerQueue.put((seqId, seq))

        # one stop sentinel per worker
        for _ in range(self.totalThreads):
            workerQueue.put((None, None))

        # bound before `try` so cleanup never hits an unbound name
        calcProc = []
        writeProc = None
        try:
            calcProc = [mp.Process(target=self.__calculateResults, args=(workerQueue, writerQueue)) for _ in range(self.totalThreads)]
            writeProc = mp.Process(target=self.__storeResults, args=(seqFile, outputFile, len(seqs), writerQueue))

            writeProc.start()

            for p in calcProc:
                p.start()

            for p in calcProc:
                p.join()

            # all workers finished: tell the writer to stop
            writerQueue.put((None, None))
            writeProc.join()
        except BaseException:
            # make sure all processes are terminated, then propagate the error
            # (including KeyboardInterrupt) so callers see the failure
            for p in calcProc:
                if p.is_alive():
                    p.terminate()

            if writeProc is not None and writeProc.is_alive():
                writeProc.terminate()

            raise

    def distance(self, sig1, sig2):
        """Return the Manhattan (L1) distance between two signature arrays."""
        return np.sum(np.abs(sig1 - sig2))

    def read(self, tetraProfileFile):
        """Read signatures written by calculate().

        The header row is skipped, as are blank lines. Each remaining line
        is split on tabs: the first field is the sequence id and the rest
        are parsed as floats.

        Returns
        -------
        dict
            Maps sequence id to a numpy array of signature values.
        """
        sig = {}
        with open(tetraProfileFile) as f:
            next(f)  # skip header
            for line in f:
                line = line.rstrip('\r\n')
                if not line:
                    continue

                lineSplit = line.split('\t')
                sig[lineSplit[0]] = np.array([float(x) for x in lineSplit[1:]])

        return sig
201