1###############################################################################
2#
3# aminoAcidIdentity.py - calculate AAI between aligned marker genes
4#
5###############################################################################
6#                                                                             #
7#    This program is free software: you can redistribute it and/or modify     #
8#    it under the terms of the GNU General Public License as published by     #
9#    the Free Software Foundation, either version 3 of the License, or        #
10#    (at your option) any later version.                                      #
11#                                                                             #
12#    This program is distributed in the hope that it will be useful,          #
13#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
14#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            #
15#    GNU General Public License for more details.                             #
16#                                                                             #
17#    You should have received a copy of the GNU General Public License        #
18#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
19#                                                                             #
20###############################################################################
21
22import os
23import sys
24import logging
25from collections import defaultdict
26
27from checkm.defaultValues import DefaultValues
28from checkm.common import getBinIdsFromOutDir
29from checkm.util.seqUtils import readFasta
30
31
class AminoAcidIdentity():
    """Calculate AAI between sequences aligned to an HMM.

    For every bin, all copies of each multi-copy marker gene are compared
    pairwise. Raw AAI values and derived strain-heterogeneity statistics are
    stored on the instance.
    """
    def __init__(self):
        self.logger = logging.getLogger('timestamp')
        # binId -> markerId -> list of pairwise AAI values (filled by run())
        self.aaiRawScores = defaultdict(dict)
        # binId -> markerId -> fraction of pairs with AAI above the strain threshold
        self.aaiHetero = defaultdict(dict)
        # binId -> percentage of all multi-copy pairs above the strain threshold
        self.aaiMeanBinHetero = {}

    def run(self, aaiStrainThreshold, outDir, alignmentOutputFile):
        """Calculate AAI between input alignments.

        Masked alignments are read from
        ``<outDir>/storage/aai_qa/<binId>/*.masked.faa``. For each file, AAI is
        computed between every pair of sequences. The results are appended to
        ``self.aaiRawScores``. ``self.aaiHetero`` and ``self.aaiMeanBinHetero``
        are then recomputed from those raw scores.

        Args:
            aaiStrainThreshold: AAI value above which a pair of gene copies
                is counted as strain-level heterogeneity.
            outDir: Output directory containing the ``storage/aai_qa`` data.
            alignmentOutputFile: Optional path. If truthy, every compared pair
                and its AAI is written to this file.

        Exits the process (``sys.exit(1)``) if two sequences in the same
        marker file are assigned to different bins.
        """
        self.logger.info('Calculating AAI between multi-copy marker genes.')

        # open the optional report up front; the finally clause guarantees it
        # is closed even when an error (or sys.exit) aborts processing
        fout = open(alignmentOutputFile, 'w') if alignmentOutputFile else None
        try:
            # calculate AAI for duplicate marker genes
            binIds = getBinIdsFromOutDir(outDir)
            aaiOutputDir = os.path.join(outDir, 'storage', 'aai_qa')
            for binId in binIds:
                binPath = os.path.join(aaiOutputDir, binId)
                if not os.path.exists(binPath):
                    continue

                # sorted so the optional report is reproducible between runs
                for f in sorted(os.listdir(binPath)):
                    if f.endswith('.masked.faa'):
                        self._markerAAI(binId, binPath, f, fout)
        finally:
            if fout is not None:
                fout.close()

        # calculate strain heterogeneity for each marker gene in each bin
        self.aaiHetero, self.aaiMeanBinHetero = self.strainHetero(self.aaiRawScores, aaiStrainThreshold)

    def _markerAAI(self, binId, binPath, fileName, fout):
        """Record pairwise AAI between all copies of one marker gene in a bin."""
        markerId = fileName[0:fileName.find('.')]

        # materialise (id, seq) pairs once instead of rebuilding keys() per pair
        seqs = list(readFasta(os.path.join(binPath, fileName)).items())

        # calculate AAI between all pairs of seqs
        for i, (seqIdI, seqI) in enumerate(seqs):
            binIdI = self._binIdFromSeqId(seqIdI)

            for seqIdJ, seqJ in seqs[i + 1:]:
                binIdJ = self._binIdFromSeqId(seqIdJ)

                if binIdI != binIdJ:
                    # something is wrong as the bin Ids should always be the same
                    self.logger.error('  [Error] Bin ids do not match.')
                    sys.exit(1)

                aai = self.aai(seqI, seqJ)

                if fout is not None:
                    fout.write(binId + ',' + markerId + '\n')
                    fout.write(seqIdI + '\t' + seqI + '\n')
                    fout.write(seqIdJ + '\t' + seqJ + '\n')
                    fout.write('AAI: %.3f\n' % aai)
                    fout.write('\n')

                if binIdI not in self.aaiRawScores:
                    self.aaiRawScores[binIdI] = defaultdict(list)
                self.aaiRawScores[binIdI][markerId].append(aai)

    @staticmethod
    def _binIdFromSeqId(seqId):
        """Return the bin id prefix of a sequence id (text before SEQ_CONCAT_CHAR)."""
        # partition() yields the whole id when the separator is absent; slicing
        # with find() would return -1 there and silently drop the last character
        return seqId.partition(DefaultValues.SEQ_CONCAT_CHAR)[0]

    def strainHetero(self, aaiScores, aaiStrainThreshold):
        """Calculate strain heterogeneity.

        Args:
            aaiScores: Mapping of binId -> markerId -> list of pairwise AAI values.
            aaiStrainThreshold: A pair is counted as strain heterogeneity when
                its AAI is strictly greater than this value.

        Returns:
            A tuple ``(aaiHetero, aaiMeanBinHetero)``:
            - ``aaiHetero``: binId -> markerId -> fraction (0-1) of that marker's
              pairs above the threshold.
            - ``aaiMeanBinHetero``: binId -> percentage (0-100) of all the bin's
              pairs above the threshold.
            Markers or bins without any scored pairs get 0.0 instead of
            dividing by zero.
        """
        aaiHetero = defaultdict(dict)
        aaiMeanBinHetero = {}

        for binId, markerScores in aaiScores.items():
            strainCount = 0
            multiCopyPairs = 0

            aaiHetero[binId] = {}

            for markerId, scores in markerScores.items():
                localStrainCount = sum(1 for score in scores if score > aaiStrainThreshold)
                strainCount += localStrainCount
                multiCopyPairs += len(scores)

                if scores:
                    aaiHetero[binId][markerId] = float(localStrainCount) / len(scores)
                else:
                    aaiHetero[binId][markerId] = 0.0

            if multiCopyPairs:
                aaiMeanBinHetero[binId] = 100 * float(strainCount) / multiCopyPairs
            else:
                aaiMeanBinHetero[binId] = 0.0

        return aaiHetero, aaiMeanBinHetero

    def aai(self, seq1, seq2):
        """Calculate amino acid identity between two aligned sequences.

        Leading and trailing columns in which either sequence has a gap ('-')
        are ignored, so missing data at the start or end of a partial gene
        does not count against identity. Within the remaining region:
        - columns where both sequences have a gap are skipped;
        - a gap aligned to a residue counts as a mismatch;
        - all other columns are compared directly.

        Args:
            seq1: First aligned sequence.
            seq2: Second aligned sequence; must be the same length as seq1.

        Returns:
            Fraction of compared columns that are identical, or 0.0 when no
            columns remain to compare.

        Raises:
            ValueError: If the two sequences differ in length.
        """
        if len(seq1) != len(seq2):
            raise ValueError('Aligned sequences must have the same length.')

        def hasGap(idx):
            return seq1[idx] == '-' or seq2[idx] == '-'

        # skip missing data at the start of the alignment
        startIndex = 0
        while startIndex < len(seq1) and hasGap(startIndex):
            startIndex += 1

        # skip missing data at the end of the alignment
        endIndex = len(seq1)
        while endIndex > startIndex and hasGap(endIndex - 1):
            endIndex -= 1

        mismatches = 0
        seqLen = 0
        for c1, c2 in zip(seq1[startIndex:endIndex], seq2[startIndex:endIndex]):
            if c1 != c2:
                mismatches += 1
                seqLen += 1
            elif c1 != '-':
                # identical residues; columns gapped in both sequences are skipped
                seqLen += 1

        if seqLen == 0:
            return 0.0

        return 1.0 - (float(mismatches) / seqLen)
162