1from __future__ import absolute_import
2from __future__ import print_function
3# #START_LICENSE###########################################################
4#
5#
6# This file is part of the Environment for Tree Exploration program
7# (ETE).  http://etetoolkit.org
8#
9# ETE is free software: you can redistribute it and/or modify it
10# under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# ETE is distributed in the hope that it will be useful, but WITHOUT
15# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
16# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
17# License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with ETE.  If not, see <http://www.gnu.org/licenses/>.
21#
22#
23#                     ABOUT THE ETE PACKAGE
24#                     =====================
25#
26# ETE is distributed under the GPL copyleft license (2008-2015).
27#
28# If you make use of ETE in published work, please cite:
29#
30# Jaime Huerta-Cepas, Joaquin Dopazo and Toni Gabaldon.
31# ETE: a python Environment for Tree Exploration. Jaime BMC
32# Bioinformatics 2010,:24doi:10.1186/1471-2105-11-24
33#
34# Note that extra references to the specific methods implemented in
35# the toolkit may be available in the documentation.
36#
37# More info at http://etetoolkit.org. Contact: huerta@embl.de
38#
39#
40# #END_LICENSE#############################################################
41from six import StringIO
42import six.moves.cPickle
43from collections import defaultdict
44import logging
45import os
46import time
47import six
48from six.moves import map
49from six.moves import range
# Shared logger used by the pipeline modules; level numbers (24/26/28)
# follow the package's custom verbosity convention.
log = logging.getLogger("main")
51
52from ..master_task import CogSelectorTask
53from ..errors import DataError
54from ..utils import (GLOBALS, print_as_table, generate_node_ids, cmp,
55                     encode_seqname, md5, pjoin, _mean, _median, _max, _min, _std)
56from .. import db
57
# Public API of this module: only the task class is exported.
__all__ = ["BrhCogCreator"]
59
def quote(_x):
    """Return *_x* wrapped in double quotes (used to build SQL IN lists)."""
    return '"%s"' % _x
61
class BrhCogCreator(CogSelectorTask):
    """Task that builds Clusters of Orthologous Groups (COGs) from the
    precomputed best-reciprocal-hit (BRH) relationships stored in the
    project database (see :func:`brh_cogs2`)."""

    def __init__(self, target_sp, out_sp, seqtype, conf, confname):
        # Seed species and the allowed fraction of missing species per COG.
        self.seed = conf[confname]["_seed"]
        self.missing_factor = float(conf[confname]["_species_missing_factor"])
        node_id, clade_id = generate_node_ids(target_sp, out_sp)
        # Initialize task
        CogSelectorTask.__init__(self, node_id, "cog_selector",
                                 "Cog-Selector", None, conf[confname])

        # taskid does not depend on jobs, so I set it manually
        self.cladeid = clade_id
        self.seqtype = seqtype
        self.targets = target_sp
        self.outgroups = out_sp
        self.init()
        self.size = len(target_sp | out_sp)
        self.cog_analysis = None
        self.cogs = None

    def finish(self):
        """Compute the COG selection, translate raw sequence names to
        their encoded db names and store the result.

        Raises DataError if any sequence id cannot be translated.
        """
        all_species = self.targets | self.outgroups
        cogs, cog_analysis = brh_cogs2(db, all_species,
                                       missing_factor=self.missing_factor,
                                       seed_sp=self.seed)
        self.raw_cogs = cogs
        self.cog_analysis = cog_analysis
        self.cogs = []
        for co in cogs:
            encoded_names = db.translate_names(co)
            if len(encoded_names) != len(co):
                print(set(co) - set(encoded_names.keys()))
                raise DataError("Some sequence ids could not be translated")
            self.cogs.append(list(encoded_names.values()))

        # Sort COGs according to the md5 hash of their content: arbitrary
        # but reproducible across runs. The key= form works on Python 2
        # and 3; the previous cmp-style list.sort() raised TypeError on
        # Python 3.
        for cog in self.cogs:
            cog.sort()
        self.cogs.sort(key=lambda x: md5(','.join(x)))
        log.log(28, "%s COGs detected" % len(self.cogs))
        CogSelectorTask.store_data(self, self.cogs, self.cog_analysis)
108
109
def brh_cogs(DB, species, missing_factor=0.0, seed_sp=None, min_score=0):
    """It scans all precalculate BRH relationships among the species
       passed as an argument, and detects Clusters of Orthologs
       according to several criteria:

       min_score: the min coverage/overalp value required for a
       blast to be a reliable hit.

       missing_factor: the min percentage of species in which a
       given seq must have  orthologs.

    """
    log.log(26, "Searching BRH orthologs")
    species = set(map(str, species))

    # Minimum number of species a COG must cover to be kept.
    min_species = len(species) - round(missing_factor * len(species))

    if seed_sp == "auto":
        sp_to_test = list(species)
    elif seed_sp == "largest":
        # NOTE(review): this branch queries the module-level `db`, not the
        # `DB` argument — presumably both point to the same database.
        cmd = """SELECT taxid, size FROM species"""
        db.seqcursor.execute(cmd)
        sp2size = {}
        for tax, counter in db.seqcursor.fetchall():
            if tax in species:
                sp2size[tax] = counter

        # key= works on both Python 2 and 3; the old cmp-style sorted()
        # call raised TypeError under Python 3.
        sorted_sp = sorted(sp2size.items(), key=lambda item: item[1])
        log.log(24, sorted_sp[:6])
        largest_sp = sorted_sp[-1][0]
        sp_to_test = [largest_sp]
        log.log(28, "Using %s as search seed. Proteome size=%s genes" %\
            (largest_sp, sp2size[largest_sp]))
    else:
        sp_to_test = [str(seed_sp)]

    # The following loop tests each possible seed if none is specified.
    log.log(28, "Detecting Clusters of Orthologs groups (COGs)")
    log.log(28, "Min number of species per COG: %d" %min_species)
    cogs_selection = []

    for j, seed in enumerate(sp_to_test):
        log.log(26,"Testing new seed species:%s (%d/%d)", seed, j+1, len(sp_to_test))
        # ortho_pair stores each pair once (ordered taxids), so the seed
        # may appear on either side of the relation.
        species_side1 = ','.join(map(quote, [s for s in species if str(s)>str(seed)]))
        species_side2 = ','.join(map(quote, [s for s in species if str(s)<str(seed)]))
        pairs1 = []
        pairs2 = []
        # Select all ids with matches in the target species, and
        # return the total number of species covered by each of
        # such ids.
        if species_side1 != "":
            cmd = """SELECT seqid1, taxid1, seqid2, taxid2 from ortho_pair WHERE
            taxid1="%s" AND taxid2 IN (%s) """ %\
            (seed, species_side1)
            DB.orthocursor.execute(cmd)
            pairs1 = DB.orthocursor.fetchall()

        if species_side2 != "":
            cmd = """SELECT seqid2, taxid2, seqid1, taxid1 from ortho_pair WHERE
            taxid1 IN (%s) AND taxid2 = "%s" """ %\
            (species_side2, seed)
            DB.orthocursor.execute(cmd)
            pairs2 = DB.orthocursor.fetchall()

        # Group every (species, seqid) hit under its seed sequence.
        cog_candidates = defaultdict(set)
        for seq1, sp1, seq2, sp2 in pairs1 + pairs2:
            s1 = (sp1, seq1)
            s2 = (sp2, seq2)
            cog_candidates[(sp1, seq1)].update([s1, s2])

        all_cogs = [cand for cand in cog_candidates.values() if
                    len(cand) >= min_species]

        # Sanity check: every species must contribute at most one sequence
        # per COG, so member count must equal distinct-species count.
        cog_sizes = [len(cog) for cog in all_cogs]
        cog_spsizes = [len(set([e[0] for e in cog])) for cog in all_cogs]
        if any(s != spsize for s, spsize in zip(cog_sizes, cog_spsizes)):
            raise ValueError("Inconsistent COG found")

        if cog_sizes:
            cogs_selection.append([seed, all_cogs])
        log.log(26, "Found %d COGs" % len(all_cogs))

    def _selection_key(entry):
        # Rank a [seed, all_cogs] entry by (largest COG, mean COG size,
        # number of COGs), all maximized, mirroring the old comparator.
        sizes = [len(cog) for cog in entry[1]]
        return (_max(sizes), round(_mean(sizes)), len(sizes))

    log.log(26, "Finding best COG selection...")
    # key=/reverse=True replaces the Python-2-only cmp-style sort.
    cogs_selection.sort(key=_selection_key, reverse=True)
    lines = []
    for seed, all_cogs in cogs_selection:
        cog_sizes = [len(cog) for cog in all_cogs]
        mx, mn, avg = max(cog_sizes), min(cog_sizes), round(_mean(cog_sizes))
        lines.append([seed, mx, mn, avg, len(all_cogs)])
    analysis_txt = StringIO()
    print_as_table(lines[:25], stdout=analysis_txt,
                   header=["Seed","largest COG", "smallest COGs", "avg COG size", "total COGs"])
    log.log(28, "Analysis details:\n"+analysis_txt.getvalue())
    best_seed, best_cogs = cogs_selection[0]
    cog_sizes = [len(cog) for cog in best_cogs]

    if max(cog_sizes) < len(species):
        raise ValueError("Current COG selection parameters do not permit to cover all species")

    # Encode each member as "species<delimiter>seqid".
    recoded_cogs = []
    for cog in best_cogs:
        named_cog = ["%s%s%s" %(x[0], GLOBALS["spname_delimiter"],x[1]) for x in cog]
        recoded_cogs.append(named_cog)

    return recoded_cogs, analysis_txt.getvalue()
245
def brh_cogs2(DB, species, missing_factor=0.0, seed_sp=None, min_score=0):
    """It scans all precalculate BRH relationships among the species
       passed as an argument, and detects Clusters of Orthologs
       according to several criteria:

       min_score: the min coverage/overalp value required for a
       blast to be a reliable hit.

       missing_factor: the min percentage of species in which a
       given seq must have  orthologs.

    """
    log.log(26, "Searching BRH orthologs")
    species = set(map(str, species))

    # Minimum number of species a COG must cover to be kept.
    min_species = len(species) - round(missing_factor * len(species))

    if seed_sp == "auto":
        sp_to_test = list(species)
    elif seed_sp == "largest":
        # NOTE(review): this branch queries the module-level `db`, not the
        # `DB` argument — presumably both point to the same database.
        cmd = """SELECT taxid, size FROM species"""
        db.seqcursor.execute(cmd)
        sp2size = {}
        for tax, counter in db.seqcursor.fetchall():
            if tax in species:
                sp2size[tax] = counter

        # key= works on both Python 2 and 3; the old cmp-style sorted()
        # call raised TypeError under Python 3.
        sorted_sp = sorted(sp2size.items(), key=lambda item: item[1])
        log.log(24, sorted_sp[:6])
        largest_sp = sorted_sp[-1][0]
        sp_to_test = [largest_sp]
        log.log(28, "Using %s as search seed. Proteome size=%s genes" %\
            (largest_sp, sp2size[largest_sp]))
    else:
        sp_to_test = [str(seed_sp)]

    analysis_txt = StringIO()
    if sp_to_test:
        log.log(26, "Finding best COG selection...")
        seed2size = get_sorted_seeds(seed_sp, species, sp_to_test, min_species, DB)
        size_analysis = []
        for seedname, content in six.iteritems(seed2size):
            cog_sizes = [size for seq, size in content]
            mx, avg = _max(cog_sizes), round(_mean(cog_sizes))
            size_analysis.append([seedname, mx, avg, len(content)])
        # Maximize, in order: largest COG, average COG size and number of
        # COGs. key=/reverse=True replaces the cmp-style sort that raised
        # TypeError under Python 3.
        size_analysis.sort(key=lambda e: (e[1], e[2], e[3]), reverse=True)
        seed = size_analysis[0][0]
        print_as_table(size_analysis[:25], stdout=analysis_txt,
                   header=["Seed","largest COG", "avg COG size", "total COGs"])
        if size_analysis[0][1] < len(species)-1:
            print(size_analysis[0][1])
            raise ValueError("Current COG selection parameters do not permit to cover all species")

    log.log(28, analysis_txt.getvalue())
    log.log(28, "Computing Clusters of Orthologs groups (COGs)")
    log.log(28, "Min number of species per COG: %d" %min_species)
    cogs_selection = []
    log.log(26,"Using seed species:%s", seed)
    # ortho_pair stores each pair once (ordered taxids), so the seed may
    # appear on either side of the relation.
    species_side1 = ','.join(map(quote, [s for s in species if str(s)>str(seed)]))
    species_side2 = ','.join(map(quote, [s for s in species if str(s)<str(seed)]))
    pairs1 = []
    pairs2 = []
    # Select all ids with matches in the target species, and
    # return the total number of species covered by each of
    # such ids.
    if species_side1 != "":
        cmd = """SELECT seqid1, taxid1, seqid2, taxid2 from ortho_pair WHERE
            taxid1="%s" AND taxid2 IN (%s) """ % (seed, species_side1)
        DB.orthocursor.execute(cmd)
        pairs1 = DB.orthocursor.fetchall()

    if species_side2 != "":
        cmd = """SELECT seqid2, taxid2, seqid1, taxid1 from ortho_pair WHERE
            taxid1 IN (%s) AND taxid2 = "%s" """ % (species_side2, seed)
        DB.orthocursor.execute(cmd)
        pairs2 = DB.orthocursor.fetchall()

    # Group every (species, seqid) hit under its seed sequence.
    cog_candidates = defaultdict(set)
    for seq1, sp1, seq2, sp2 in pairs1 + pairs2:
        s1 = (sp1, seq1)
        s2 = (sp2, seq2)
        cog_candidates[(sp1, seq1)].update([s1, s2])

    all_cogs = [cand for cand in cog_candidates.values() if
                len(cand) >= min_species]

    # CHECK CONSISTENCY: the per-seed counts from get_sorted_seeds must
    # agree with the seed sequences recovered by the full query above.
    seqs = set()
    for cand in all_cogs:
        seqs.update([b for a, b in cand if a == seed])
    pre_selected_seqs = set([v[0] for v in seed2size[seed]])
    if len(seqs & pre_selected_seqs) != len(set(seed2size[seed])) or\
            len(seqs & pre_selected_seqs) != len(seqs):
        print("old method seqs", len(seqs), "new seqs", len(set(seed2size[seed])), "Common", len(seqs & pre_selected_seqs))
        raise ValueError("ooops")

    # Sanity check: every species must contribute at most one sequence per
    # COG, so member count must equal distinct-species count.
    cog_sizes = [len(cog) for cog in all_cogs]
    cog_spsizes = [len(set([e[0] for e in cog])) for cog in all_cogs]
    if any(s != spsize for s, spsize in zip(cog_sizes, cog_spsizes)):
        raise ValueError("Inconsistent COG found")

    if cog_sizes:
        cogs_selection.append([seed, all_cogs])
    log.log(26, "Found %d COGs" % len(all_cogs))

    # Encode each member as "species<delimiter>seqid".
    recoded_cogs = []
    for cog in all_cogs:
        named_cog = ["%s%s%s" %(x[0], GLOBALS["spname_delimiter"],x[1]) for x in cog]
        recoded_cogs.append(named_cog)

    return recoded_cogs, analysis_txt.getvalue()
370
371
def get_sorted_seeds(seed, species, sp_to_test, min_species, DB):
    """For every candidate seed species, count how many target species
    each of its sequences has BRH orthologs in, and keep the sequences
    that would form a sufficiently large COG.

    Returns a dict mapping each candidate seed to a list of
    (seqid, n_covered_species) tuples.
    """
    seed2count = {}
    species = set(species)
    ntests = len(sp_to_test)
    for idx, candidate in enumerate(sp_to_test):
        log.log(26,"Testing SIZE of new seed species:%s (%d/%d)", candidate, idx+1, ntests)
        # The candidate may appear on either side of the ortho_pair
        # relation, so two queries are required.
        cmd = """SELECT seqid1, GROUP_CONCAT(taxid2) FROM ortho_pair WHERE
            taxid1="%s" GROUP BY (seqid1)""" % (candidate)
        DB.orthocursor.execute(cmd)
        fwd_pairs = DB.orthocursor.fetchall()

        cmd = """SELECT seqid2, GROUP_CONCAT(taxid1) FROM ortho_pair WHERE
            taxid2 = "%s" GROUP BY seqid2""" % (candidate)
        DB.orthocursor.execute(cmd)
        rev_pairs = DB.orthocursor.fetchall()

        # Distinct target species hit by each candidate-seed sequence.
        coverage = defaultdict(set)
        for seqid, targets in fwd_pairs + rev_pairs:
            coverage[seqid].update(set(targets.split(",")) & species)

        # Drop sequences whose COG would cover too few species.
        valid_seqs = [(seqid, len(hit_sps)) for seqid, hit_sps in coverage.items()
                      if len(hit_sps) >= min_species - 1]

        seed2count[candidate] = valid_seqs
        log.log(28, "Seed species:%s COGs:%s" % (candidate, len(valid_seqs)))
    return seed2count
403
def get_best_selection(cogs_selections, species):
    """Rank alternative COG selections and return the best one plus a
    textual report.

    cogs_selections: list of [seed, missing_sp_allowed, candidates,
    sp2hits] entries; the list is sorted in place, best entry first.
    Returns (best_entry, StringIO containing the analysis table).
    """
    ALL_SPECIES = set(species)
    median_cogs = _median([len(data[2]) for data in cogs_selections])

    def _selection_key(cs):
        # Maximize, in order: represented species, global score,
        # worst-species coverage, best-species coverage, number of COGs.
        seed, missing_sp_allowed, candidates, sp2hits = cs
        score, min_cov, max_cov, _median_cov, _cov_std, _cog_cov = \
            get_cog_score(candidates, sp2hits, median_cogs,
                          ALL_SPECIES - set([seed]))
        # BUGFIX: the old cmp comparator computed both sides from
        # sp2hits_1 (copy-paste), so the species-representation test
        # could never discriminate between entries.
        return (len(sp2hits), score, min_cov, max_cov, len(candidates))

    # key=/reverse=True replaces the Python-2-only cmp-style sort
    # followed by reverse().
    cogs_selections.sort(key=_selection_key, reverse=True)

    header = ['seed',
              'missing sp allowed',
              'spcs covered',
              '#COGs',
              'mean sp coverage)',
              '#COGs for worst sp.',
              '#COGs for best sp.',
              'sp. in COGS(avg)',
              'SCORE' ]
    print_header = True
    best_cog_selection = None
    cog_analysis = StringIO()
    for i, cogs in enumerate(cogs_selections):
        seed, missing_sp_allowed, candidates, sp2hits = cogs
        sp_percent_coverages = [(100*sp2hits.get(sp,0))/float(len(candidates)) for sp in species]
        sp_coverages = [sp2hits.get(sp, 0) for sp in species]
        score, min_cov, max_cov, median_cov, cov_std, cog_cov = get_cog_score(candidates, sp2hits, median_cogs, ALL_SPECIES-set([seed]))

        # First (best-ranked) entry is flagged with "*" and selected.
        if best_cog_selection is None:
            best_cog_selection = i
            flag = "*"
        else:
            flag = " "
        data = (candidates,
                flag+"%10s" %seed,
                missing_sp_allowed,
                "%d (%0.1f%%)" %(len(set(sp2hits.keys()))+1, 100*float(len(ALL_SPECIES))/(len(sp2hits)+1)),
                len(candidates),
                "%0.1f%% +- %0.1f" %(_mean(sp_percent_coverages), _std(sp_percent_coverages)),
                "% 3d (%0.1f%%)" %(min(sp_coverages),100*min(sp_coverages)/float(len(candidates))),
                "% 3d (%0.1f%%)" %(max(sp_coverages),100*max(sp_coverages)/float(len(candidates))),
                cog_cov,
                score
                )
        # data[0] (the raw candidate list) is excluded from the table.
        print_as_table([data[1:]], header=header, print_header=print_header,
                       stdout=cog_analysis)
        print_header = False

    print(cog_analysis.getvalue())
    return cogs_selections[best_cog_selection], cog_analysis
499
def _analyze_cog_selection(all_cogs):
    """Debug helper: plot per-species COG participation counts and the
    distribution of COG sizes.

    NOTE(review): relies on a global `pylab` that this module does not
    import — import matplotlib.pylab before calling this function.
    """
    print("total cogs:", len(all_cogs))
    sp2cogcount = {}
    size2cogs = {}
    for cog in all_cogs:
        for seq in cog:
            # Species prefix of each "species<delimiter>seqid" member.
            sp = seq.split(GLOBALS["spname_delimiter"])[0]
            sp2cogcount[sp] = sp2cogcount.setdefault(sp, 0)+1
        size2cogs.setdefault(len(cog), []).append(cog)

    # key= sort replaces the cmp-style call that raised TypeError on
    # Python 3.
    sorted_spcs = sorted(sp2cogcount.items(), key=lambda item: item[1])
    # Take only first 20 species
    coverages = [s[1] for s in sorted_spcs][:20]
    # NOTE(review): the species name appears concatenated with itself
    # here — looks unintentional, kept for identical labels.
    spnames  = [str(s[0])+ s[0] for s in sorted_spcs][:20]
    pylab.subplot(1,2,1)
    pylab.bar(list(range(len(coverages))), coverages)
    labels = pylab.xticks(pylab.arange(len(spnames)), spnames)
    pylab.subplots_adjust(bottom=0.35)
    pylab.title(str(len(all_cogs))+" COGs")
    pylab.setp(labels[1], 'rotation', 90,fontsize=10, horizontalalignment = 'center')
    pylab.subplot(1,2,2)
    pylab.title("Best COG contains "+str(max(size2cogs.values()))+" species" )
    pylab.bar(list(range(1,216)), [len(size2cogs.get(s, [])) for s in range(1,216)])
    pylab.show()
524
525
def cog_info(candidates, sp2hits):
    """Return (min, max, median) fraction of COGs covered per species.

    candidates: list of candidate COGs; sp2hits: dict mapping species to
    the number of COGs they participate in.
    """
    sp_coverages = [hits/float(len(candidates)) for hits in sp2hits.values()]
    min_cov = _min(sp_coverages)
    # BUGFIX: was _min(), which made max_cov a copy of min_cov.
    max_cov = _max(sp_coverages)
    median_cov = _median(sp_coverages)
    return min_cov, max_cov, median_cov
533
534
def get_cog_score(candidates, sp2hits, max_cogs, all_species):
    """Score a COG selection.

    candidates: list of candidate COGs; sp2hits: dict of per-species hit
    counts; max_cogs: reference COG count used to normalize; all_species:
    species to evaluate coverage for (seed excluded by the caller).
    Returns (score, min_cov, max_cov, median_cov, cov_std, cog_cov).
    """
    # Average COG size relative to the number of species represented.
    cog_cov = _mean([len(cogs) for cogs in candidates])/float(len(sp2hits)+1)
    # Mean number of species present in each COG.
    cog_mean_cov = _mean([len(cogs)/float(len(sp2hits)) for cogs in candidates])

    # Fraction of COGs each species participates in.
    sp_coverages = [sp2hits.get(sp, 0)/float(len(candidates)) for sp in all_species]

    nfactor = len(candidates)/float(max_cogs)  # relative number of COGs
    min_cov = _min(sp_coverages)  # coverage of the worst-covered species
    # BUGFIX: was _min(), which made max_cov a duplicate of min_cov.
    max_cov = _max(sp_coverages)
    median_cov = _median(sp_coverages)
    cov_std = _std(sp_coverages)

    # Overall score is limited by the weakest of the three factors.
    score = _min([nfactor, cog_mean_cov, min_cov])
    return score, min_cov, max_cov, median_cov, cov_std, cog_cov
552
553