1#!/usr/bin/env python
2
3###############################################################################
4#                                                                             #
5#    This program is free software: you can redistribute it and/or modify     #
6#    it under the terms of the GNU General Public License as published by     #
7#    the Free Software Foundation, either version 3 of the License, or        #
8#    (at your option) any later version.                                      #
9#                                                                             #
10#    This program is distributed in the hope that it will be useful,          #
11#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
12#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            #
13#    GNU General Public License for more details.                             #
14#                                                                             #
15#    You should have received a copy of the GNU General Public License        #
16#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
17#                                                                             #
18###############################################################################
19
"""
Assess performance of marker set selection criteria.

For each simulated test genome, a lineage-specific marker set is selected by
ascending the reference genome tree from the genome's leaf until a lineage
passes a leave-one-out (LOO) stability test. The completeness and
contamination errors (dComp/dCont) recorded in the simulation summary for the
selected marker set are then compared with those of the marker set that
achieved the lowest combined error for the same simulated genome.
"""

__author__ = 'Donovan Parks'
__copyright__ = 'Copyright 2013'
__credits__ = ['Donovan Parks']
__license__ = 'GPL3'
__version__ = '1.0.0'
__maintainer__ = 'Donovan Parks'
__email__ = 'donovan.parks@gmail.com'
__status__ = 'Development'
32
33import os
34import sys
35import argparse
36import multiprocessing as mp
37from collections import defaultdict
38
39import dendropy
40from  dendropy.dataobject.taxon import Taxon
41
42from numpy import mean, std, abs, array, percentile
43
44from checkm.lib.img import IMG
45from lib.markerSetBuilder import MarkerSetBuilder
46
47class MarkerSetSelection(object):
48    def __init__(self):
49        self.simFile = './experiments/simulation.tuning.genus.summary.tsv'
50        self.looRank = 5
51
52        self.markerSetBuilder = MarkerSetBuilder()
53        self.img = IMG()
54
55    def __stabilityTest(self, genomeIds, ubiquityThreshold = 0.97, singleCopyThreshold = 0.97, stabilityThreshold = 0.05):
56        """Test stability of marker set for a group of genomes using LOO-testing."""
57
58        # quick escape for lineage that are clearly stable
59        if len(genomeIds) > 200:
60            return True
61
62        # calculate marker sets using a LOO-testing
63        looMarkerGenes = []
64        for genomeId in genomeIds:
65            looGenomeIds = genomeIds.difference([genomeId])
66
67            # calculate marker genes
68            geneCountTable = self.img.geneCountTable(looGenomeIds)
69            markerGenes = self.markerSetBuilder.markerGenes(looGenomeIds, geneCountTable, ubiquityThreshold*len(looGenomeIds), singleCopyThreshold*len(looGenomeIds))
70            tigrToRemove = self.img.identifyRedundantTIGRFAMs(markerGenes)
71            markerGenes = markerGenes - tigrToRemove
72
73            looMarkerGenes.append(markerGenes)
74
75        # calculate change in marker set for all pairs
76        markerSetSize = []
77        diffMarkerSet = []
78        for i in xrange(0, len(looMarkerGenes)):
79            markerSetSize.append(len(looMarkerGenes[i]))
80            for j in xrange(i+1, len(looMarkerGenes)):
81                symmDiff = looMarkerGenes[i].symmetric_difference(looMarkerGenes[j])
82                diffMarkerSet.append(len(symmDiff))
83
84        print len(genomeIds), mean(diffMarkerSet), mean(markerSetSize)
85        return (float(mean(diffMarkerSet)) / mean(markerSetSize)) <= stabilityThreshold
86
87    def __patristicDist(self, tree, taxa1, taxa2):
88        mrca = tree.mrca(taxon_labels=[taxa1.taxon.label, taxa2.taxon.label])
89
90        if mrca.parent_node == None:
91            # MRCA is the root of the tree
92            return taxa1.distance_from_root() + taxa2.distance_from_root()
93        else:
94
95            dist = taxa1.edge_length
96            parentNode = taxa1.parent_node
97            while parentNode != mrca:
98                dist += parentNode.edge_length
99                parentNode = parentNode.parent_node
100
101
102            dist += taxa2.edge_length
103            parentNode = taxa2.parent_node
104            while parentNode != mrca:
105                dist += parentNode.edge_length
106                parentNode = parentNode.parent_node
107
108            return dist
109
110    def __distToNodePercentileTest(self, genomeNode, markerSetNode, leaves, percentileTest):
111
112        distToBin = self.__distanceToAncestor(genomeNode, markerSetNode)
113
114        distToLeaves = []
115        for leaf in leaves:
116            distToLeaves.append(self.__distanceToAncestor(leaf, markerSetNode))
117
118        return distToBin < percentile(distToLeaves, percentileTest)
119
120    def __selectMarkerSetNode(self, tree, genomeId, metadata, taxonToGenomeIds):
121        """Determine lineage-specific marker set to use for assessing the giving genome."""
122
123        # read genomes removed from tree as a result of duplicate sequences
124        duplicateSeqs = self.markerSetBuilder.readDuplicateSeqs()
125
126        # determine location of genome in tree
127        node = tree.find_node_with_taxon_label('IMG_' + genomeId)
128
129        # ascend tree to root looking for suitable marker set
130        curNode = node.parent_node
131        while curNode != None:
132            uniqueId = curNode.label.split('|')[0]
133
134            genomeIds = set()
135            for leaf in curNode.leaf_nodes():
136                genomeIds.add(leaf.taxon.label.replace('IMG_', ''))
137
138                duplicateGenomes = duplicateSeqs.get(leaf.taxon.label, [])
139                for dup in duplicateGenomes:
140                    genomeIds.add(dup.replace('IMG_', ''))
141
142            # remove genome (LOO-style analysis)
143            print 'Full:', len(genomeIds)
144            genomeIds.difference_update([genomeId])
145            print 'LOO:', len(genomeIds)
146
147            # remove all genomes from the same taxonomic group as the genome of interest
148            taxon = metadata[genomeId]['taxonomy'][self.looRank]
149            genomeIds.difference_update(taxonToGenomeIds[taxon])
150            print 'Rank reduced:', len(genomeIds)
151
152            print uniqueId
153            if len(genomeIds) > 10 and self.__stabilityTest(genomeIds):
154                uidSelected = uniqueId
155                break
156
157            curNode = curNode.parent_node
158            if curNode == None:
159                # reach root so use universal marker set
160                uidSelected = uniqueId
161
162        return uidSelected
163
164    def __bestMarkerSet(self, genomeId, simResults):
165        """Get stats for best marker set."""
166        curBest = 1000
167        bestUID = None
168        for uid, results in simResults[genomeId].iteritems():
169            numDescendants, dComp, dCont = results
170            if (dComp + dCont) < curBest:
171                numDescendantsBest = numDescendants
172                dCompBest = dComp
173                dContBest = dCont
174                bestUID = uid
175                curBest = dComp + dCont
176
177        return bestUID, numDescendantsBest, dCompBest, dContBest
178
179
180    def __workerThread(self, tree, simResults, metadata, taxonToGenomeIds, queueIn, queueOut):
181        """Process each data item in parallel."""
182
183        while True:
184            testGenomeId = queueIn.get(block=True, timeout=None)
185            if testGenomeId == None:
186                break
187
188            uidSelected = self.__selectMarkerSetNode(tree, testGenomeId, metadata, taxonToGenomeIds)
189            numDescendantsSelected, dCompSelected, dContSelected = simResults[testGenomeId][uidSelected]
190
191            # find best marker set
192            bestUID, numDescendantsBest, dCompBest, dContBest = self.__bestMarkerSet(testGenomeId, simResults)
193
194            queueOut.put((testGenomeId, uidSelected, numDescendantsSelected, dCompSelected, dContSelected, bestUID, numDescendantsBest, dCompBest, dContBest))
195
196    def __writerThread(self, numTestGenomes, writerQueue):
197        """Store or write results of worker threads in a single thread."""
198
199        fout = open('./experiments/markerSetSelection.tsv', 'w')
200
201        fout.write('Genome Id\tSelected UID\t# descendants\tSelected dComp\tSelected dCont\tBest UID\t# descendants\tBest dComp\tBest dCont\tdDescendants\tdComp\tdCont\n')
202
203        itemsToProcess = 0
204
205        dComps = []
206        dConts = []
207
208        dCompsPer = []
209        dContsPer = []
210
211        bestComp = []
212        bestCont = []
213
214        selectedComp = []
215        selectedCont = []
216
217        while True:
218            testGenomeId, uidSelected, numDescendantsSelected, dCompSelected, dContSelected, bestUID, numDescendantsBest, dCompBest, dContBest = writerQueue.get(block=True, timeout=None)
219            if testGenomeId == None:
220                break
221
222            itemsToProcess += 1
223            statusStr = '    Finished processing %d of %d (%.2f%%) test genomes.' % (itemsToProcess, numTestGenomes, float(itemsToProcess)*100/(numTestGenomes))
224            sys.stdout.write('%s\r' % statusStr)
225            sys.stdout.flush()
226
227            dComp = abs(dCompSelected - dCompBest)
228            dCont = abs(dContSelected - dContBest)
229            dDescendants = abs(numDescendantsSelected - numDescendantsBest)
230            fout.write('%s\t%s\t%d\t%.4f\t%.4f\t%s\t%d\t%.4f\t%.4f\t%d\t%.4f\t%.4f\n' % (testGenomeId, uidSelected, numDescendantsSelected, dCompSelected, dContSelected, bestUID, numDescendantsBest, dCompBest, dContBest, dDescendants, dComp, dCont))
231
232            dComps.append(dComp)
233            dConts.append(dCont)
234
235            dCompsPer.append(dComp*100.0 / dCompBest)
236            dContsPer.append(dCont*100.0 / max(dContBest, 0.01))
237
238            bestComp.append(dCompBest)
239            bestCont.append(dContBest)
240
241            selectedComp.append(dCompSelected)
242            selectedCont.append(dContSelected)
243
244        sys.stdout.write('\n')
245        fout.close()
246
247        print ''
248        print '  General results:'
249        print '   Best comp: %.2f +/- %.2f' % (mean(bestComp), std(bestComp))
250        print '   Best cont: %.2f +/- %.2f' % (mean(bestCont), std(bestCont))
251        print '   Selected comp: %.2f +/- %.2f' % (mean(selectedComp), std(selectedComp))
252        print '   Selected cont: %.2f +/- %.2f' % (mean(selectedCont), std(selectedCont))
253        print ''
254        print '   Delta comp: %.2f +/- %.2f' % (mean(dComps), std(dComps))
255        print '   Delta cont: %.2f +/- %.2f' % (mean(dConts), std(dConts))
256        print '   Delta comp per error: %.1f +/- %.1f' % (mean(dCompsPer), std(dCompsPer))
257        print '   Delta cont per error: %.1f +/- %.1f' % (mean(dContsPer), std(dContsPer))
258
259    def __distanceToAncestor(self, leaf, ancestor):
260        dist = 0
261
262        curNode = leaf
263        while curNode != ancestor:
264            dist += curNode.edge_length
265
266            curNode = curNode.parent_node
267
268        return dist
269
270    def __bestNodeProperties(self, genomeId, tree, bestUID):
271        # determine location of genome in tree
272        node = tree.find_node_with_taxon_label('IMG_' + genomeId)
273
274        # find node of best marker set
275        curNode = node.parent_node
276        nodesToBin = 0
277        distanceToBin = node.edge_length
278        distanceToLeaves = []
279        while curNode != None:
280            uniqueId = curNode.label.split('|')[0]
281
282            nodesToBin += 1
283
284            if uniqueId == bestUID:
285                for leaf in curNode.leaf_nodes():
286                    if leaf != node:
287                        dist = self.__distanceToAncestor(leaf, curNode)
288                        distanceToLeaves.append(dist)
289                break
290
291            distanceToBin += curNode.edge_length
292
293            curNode = curNode.parent_node
294
295        return nodesToBin, distanceToBin, mean(distanceToLeaves)
296
297    def __propertiesOfBestMarkerSets(self, tree, simResults):
298
299        numDescendants = []
300        nodesToBin = []
301        distanceToBin = []
302        avgDistanceToLeaf = []
303        percDiffs = []
304        for genomeId in simResults:
305            bestUID, numDescendantsBest, _, _ = self.__bestMarkerSet(genomeId, simResults)
306            nodesToBinBest, distanceToBinBest, avgDistanceToLeafBest = self.__bestNodeProperties(genomeId, tree, bestUID)
307
308            numDescendants.append(numDescendantsBest)
309            nodesToBin.append(nodesToBinBest)
310            distanceToBin.append(distanceToBinBest)
311            avgDistanceToLeaf.append(avgDistanceToLeafBest)
312
313            percDiff = abs(distanceToBinBest - avgDistanceToLeafBest) * 100 / distanceToBinBest
314            percDiffs.append(percDiff)
315
316        print '    # descendants: %.2f +/- %.2f' % (mean(numDescendants), std(numDescendants))
317        print '    # nodes to bin: %.2f +/- %.2f' % (mean(nodesToBin), std(nodesToBin))
318        print '    Distance to bin: %.2f +/- %.2f' % (mean(distanceToBin), std(distanceToBin))
319
320        distanceToBin = array(distanceToBin)
321        avgDistanceToLeaf = array(avgDistanceToLeaf)
322        print '    Distance to bin - average distance to leaf: %.2f +/- %.2f' % (mean(abs(distanceToBin - avgDistanceToLeaf)), std(abs(distanceToBin - avgDistanceToLeaf)))
323        print '    Percent difference to average leaf distance: %.2f +/- %.2f' % (mean(percDiffs), std(percDiffs))
324        print ''
325
326    def run(self, numThreads):
327        # read reference tree
328        print '\n  Reading reference genome tree.'
329        treeFile = os.path.join(os.path.dirname(sys.argv[0]), '..', 'data', 'genome_tree', 'genome_tree_prok.refpkg', 'genome_tree.final.tre')
330        tree = dendropy.Tree.get_from_path(treeFile, schema='newick', as_rooted=True, preserve_underscores=True)
331
332        # get all genomes with a given taxon label
333        metadata = self.img.genomeMetadata()
334        taxonToGenomeIds = defaultdict(set)
335        for genomeId in metadata:
336            for t in metadata[genomeId]['taxonomy']:
337                taxonToGenomeIds[t].add(genomeId)
338
339        # read simulation results
340        print '  Reading simulation results.'
341
342        simResults = defaultdict(dict)
343        with open(self.simFile) as f:
344            f.readline()
345            for line in f:
346                lineSplit = line.split('\t')
347
348                simId = lineSplit[0] + '-' + lineSplit[1] + '-' + lineSplit[2] + '-' + lineSplit[3]
349                uid = lineSplit[5].split('|')[0].strip()
350                numDescendants = int(lineSplit[6])
351                comp = float(lineSplit[21])
352                cont = float(lineSplit[23])
353
354                simResults[simId][uid] = [numDescendants, comp, cont]
355
356        #print ''
357        #print '  Properties of best marker sets:'
358        #self.__propertiesOfBestMarkerSets(tree, simResults)
359
360        print '  Evaluating %d test genomes.' % len(simResults)
361        workerQueue = mp.Queue()
362        writerQueue = mp.Queue()
363
364        for testGenomeId in simResults:
365            workerQueue.put(testGenomeId)
366
367        for _ in range(numThreads):
368            workerQueue.put(None)
369
370        workerProc = [mp.Process(target = self.__workerThread, args = (tree, simResults, metadata, taxonToGenomeIds, workerQueue, writerQueue)) for _ in range(numThreads)]
371        writeProc = mp.Process(target = self.__writerThread, args = (len(simResults), writerQueue))
372
373        writeProc.start()
374
375        for p in workerProc:
376            p.start()
377
378        for p in workerProc:
379            p.join()
380
381        writerQueue.put((None, None, None, None, None, None, None, None, None))
382        writeProc.join()
383
if __name__ == '__main__':
    # command-line entry point: only the number of worker processes is configurable
    argParser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    argParser.add_argument('-t', '--threads', type=int, default=40, help='Threads to use')
    parsedArgs = argParser.parse_args()

    MarkerSetSelection().run(parsedArgs.threads)
392