"""Phylogenetic tree utilities built on Bio.Phylo.

Module-level helpers compute partial branch lengths, reroot trees, locate
parents and build terminal-to-terminal distance matrices. The ``PpaTree``
class loads a tree from several formats and searches it for core gene
sets and markers.
"""
from Bio import Phylo
from Bio.Phylo import PhyloXML
from Bio.Phylo import PhyloXMLIO
from collections import defaultdict as ddict
from Bio.Phylo.PhyloXML import Property as Prop
from Bio.Phylo.PhyloXML import Clade as PClade
from Bio.Phylo.BaseTree import Tree as BTree
from Bio.Phylo.BaseTree import Clade as BClade
import string
from numpy import pi as rpi
rpi2 = 2.0*rpi  # 2*pi; NOTE(review): not referenced in the code visible here
import numpy as np
import array as arr
import collections as colls
import sys
#core_test = lambda ok,tot,pr: 1.0-st.binom.sf(ok,tot,pr)

# Separator between taxonomic levels: used to split text-tree rows and
# taxonomy strings, and to join clade names into full names.
lev_sep = "."
19
20# Here are three functions that I'd love to see in Biopython but they
21# are not there (yet).
def partial_branch_length(clade, selective_targets):
    """Return the branch length of the part of *clade* that spans *selective_targets*.

    Sums the branch length of every clade in the subtree rooted at *clade*
    (including *clade* itself) that has at least one terminal whose name is
    in *selective_targets*; subtrees without any target contribute 0.0.
    Missing branch lengths (``None``, e.g. on a tree root) count as 0.0, in
    line with Biopython's ``total_branch_length``.

    Parameters
    ----------
    clade : Bio.Phylo clade
        Root of the subtree to measure.
    selective_targets : iterable of str
        Names of the terminals of interest.

    Returns
    -------
    float
    """
    targets = set(selective_targets)

    def _walk(node):
        # Post-order: returns (subtree contains a target, partial length).
        blen = node.branch_length or 0.0
        if node.is_terminal():
            if node.name in targets:
                return True, blen
            return False, 0.0
        found, total = False, 0.0
        for child in node.clades:
            child_found, child_total = _walk(child)
            if child_found:
                found = True
                total += child_total
        if not found:
            return False, 0.0
        return True, total + blen

    return _walk(clade)[1]
34
def reroot(tree, new_root):
    """Reroot *tree* in place so that *new_root* becomes the outgroup.

    NOTE(review): this appears to be adapted from Biopython's
    ``Tree.root_with_outgroup``; confirm before diverging from it.

    * If *new_root* is already the root (empty path), nothing changes.
    * If *new_root* is a terminal, a new root clade of the same class as the
      old root (carrying the old root's branch length) is created, with
      *new_root* attached to it by a 0.0-length branch.
    * Otherwise *new_root* itself becomes the new (possibly multifurcating)
      root and takes over the old root's branch length.

    The clades on the path between the old root and *new_root* are re-hung
    with their branches reversed; their other descendants are untouched.
    If the old root is left with a single child, it is dropped and its
    branch length is added to that child. Otherwise it is kept as an
    internal node. Finally ``tree.root`` is replaced and ``tree.rooted`` is
    set to True.

    Parameters
    ----------
    tree : Bio.Phylo tree
        Needs ``get_path`` and a ``root`` clade.
    new_root : clade
        Clade of *tree* to use as the outgroup.

    Returns
    -------
    None

    NOTE(review): Biopython's ``get_path`` presumably returns None when
    *new_root* is not in *tree*, which would make ``len()`` below raise
    TypeError. A None ``prev_blen`` reaching the ``+=`` near the end also
    raises TypeError.
    """
    outgroup = new_root
    outgroup_path = tree.get_path(outgroup)
    if len(outgroup_path) == 0:
        # Outgroup is the current root -- no change
        return
    # Branch length carried up the path while the branches are reversed
    prev_blen = outgroup.branch_length
    if outgroup.is_terminal():
        # Create a new root with a 0-length branch to the outgroup
        outgroup.branch_length = 0.0
        new_root = tree.root.__class__(
                branch_length=tree.root.branch_length, clades=[outgroup])
        # The first branch reversal (see the upcoming loop) is modified
        if len(outgroup_path) == 1:
            # Trivial tree like '(A,B);
            new_parent = new_root
        else:
            parent = outgroup_path.pop(-2)
            parent.clades.pop(parent.clades.index(outgroup))
            prev_blen, parent.branch_length = parent.branch_length, prev_blen
            new_root.clades.insert(0, parent)
            new_parent = parent
    else:
        # Use the given outgroup node as the new (trifurcating) root
        new_root = outgroup
        new_root.branch_length = tree.root.branch_length
        new_parent = new_root

    # Tracing the outgroup lineage backwards, reattach the subclades under a
    # new root clade. Reverse the branches directly above the outgroup in
    # the tree, but keep the descendants of those clades as they are.
    for parent in outgroup_path[-2::-1]:
        parent.clades.pop(parent.clades.index(new_parent))
        prev_blen, parent.branch_length = parent.branch_length, prev_blen
        new_parent.clades.insert(0, parent)
        new_parent = parent

    # Finally, handle the original root according to number of descendents
    old_root = tree.root
    if outgroup in old_root.clades:
        assert len(outgroup_path) == 1
        old_root.clades.pop(old_root.clades.index(outgroup))
    else:
        old_root.clades.pop(old_root.clades.index(new_parent))
    # len(clade) is its number of direct children
    if len(old_root) == 1:
        # Delete the old bifurcating root & add branch lengths
        ingroup = old_root.clades[0]
        if ingroup.branch_length:
            ingroup.branch_length += prev_blen
        else:
            ingroup.branch_length = prev_blen
        new_parent.clades.insert(0, ingroup)
        # ENH: If annotations are attached to old_root, do... something.
    else:
        # Keep the old trifurcating/polytomous root as an internal node
        old_root.branch_length = prev_blen
        new_parent.clades.insert(0, old_root)

    tree.root = new_root
    tree.rooted = True
    return
96
def get_parent(tree, child_clade):
    """Return the parent clade of *child_clade* in *tree*, or None.

    ``tree.get_path`` lists the clades from just below the root down to the
    target; the root itself is not included. None is therefore returned when:

    * *child_clade* is the root;
    * *child_clade* hangs directly below the root;
    * *child_clade* is not in *tree* (``get_path`` then yields None).

    ``reroot_mid_fat_edge`` treats None as "attached to the root".
    """
    node_path = tree.get_path(child_clade)
    if not node_path or len(node_path) < 2:
        return None
    return node_path[-2]
100
def reroot_mid_fat_edge( tree, node ):
    """Reroot *tree* in place at the midpoint of the branch above *node*.

    A new PhyloXML clade is inserted between *node* and its parent, and the
    branch length of *node* is split evenly between the two. The tree is
    then rerooted on that new clade with ``reroot``. Does nothing when
    *node* is already the root.
    """
    if node == tree.root:
        return
    half = node.branch_length * 0.5
    node.branch_length = half
    midpoint = PClade(branch_length=half, clades=[node])
    parent = get_parent(tree, node)
    if not parent:
        # No parent reported: node hangs directly below the root, so the
        # midpoint clade takes its place at the front of the root's children.
        siblings = [c for c in tree.root.clades if c != node]
        tree.root.clades = [midpoint] + siblings
    else:
        # Detach node from its parent and append the midpoint clade there.
        siblings = [c for c in parent.clades if c != node]
        parent.clades = siblings + [midpoint]
    reroot(tree, midpoint)
114
def clades2terms( tree, startswith = None ):
    """Map clades of *tree* to their terminal descendants.

    Clades are visited in pre-order starting at ``tree.root``. When
    *startswith* is falsy every clade is kept. Otherwise only clades whose
    name starts with *startswith* are kept. Each value is the result of
    ``clade.get_terminals()``.
    """
    mapping = {}
    pending = [tree.root]
    while pending:
        clade = pending.pop()
        keep = not startswith or (clade.name and clade.name.startswith(startswith))
        if keep:
            mapping[clade] = clade.get_terminals()
        # Push children reversed so they are popped in their original order.
        pending.extend(reversed(clade.clades))
    return mapping
127
def dist_matrix( tree ):
    """Return the pairwise path-length distances between the terminals of *tree*.

    The result is a dict of dicts indexed by terminal name. ``dists[a][b]``
    is the sum of the branch lengths on the path between terminals *a* and
    *b*, with 0.0 on the diagonal. Each branch adds its length to every pair
    of terminals it separates.

    Missing branch lengths (None) count as 0.0, as in Biopython's
    ``Tree.distance``. The root's own branch separates no pair, so it is
    never used.

    Side effect: every non-terminal clade gets an ``ids`` attribute, the set
    of names of its terminals.
    """
    terminals = list(tree.get_terminals())
    term_names = [t.name for t in terminals]
    # can be made faster with recursion
    for n in tree.get_nonterminals():
        n.ids = set(nn.name for nn in n.get_terminals())

    dists = {n: {nn: 0.0 for nn in term_names} for n in term_names}

    def _add_branch(inside, outside, bl):
        # The branch separates every inside terminal from every outside one.
        for t1 in inside:
            for t2 in outside:
                dists[t1][t2] += bl
                dists[t2][t1] += bl

    def dist_matrix_rec( clade ):
        bl = clade.branch_length or 0.0
        if clade.is_terminal():
            _add_branch([clade.name], [t for t in term_names if t != clade.name], bl)
            return
        _add_branch(clade.ids, [t for t in term_names if t not in clade.ids], bl)
        for c in clade.clades:
            dist_matrix_rec(c)

    dist_matrix_rec(tree.root)
    return dists
155
156
157
class PpaTree:
    """Wrapper around a Bio.Phylo tree with core-gene and marker utilities.

    ``__init__`` loads a tree into ``self.tree`` as a PhyloXML tree. Accepted
    formats are PhyloXML, Newick, Nexus, or a text file of ``lev_sep``-joined
    taxonomy paths. Every clade is then annotated with a ``full_name``.

    The other methods:

    * search the tree for core gene sets and markers;
    * evaluate taxonomic precision/recall;
    * prune, subset, rename, reroot, reorder or export the tree.
    """

    # ------------------------------------------------------------------ I/O

    @staticmethod
    def _tsv_rows( fn ):
        """Yield the tab-separated fields of each non-blank line of *fn*.

        Lines are stripped before splitting. The file is closed when the
        generator is exhausted.
        """
        with open(fn) as inpf:
            for line in inpf:
                line = line.strip()
                if line:
                    yield line.split('\t')

    def __load_tree_txt__( self, fn ):
        """Build a tree from a text file listing one taxonomy path per line.

        Only the first tab-separated field of each line is used. It is split
        on ``lev_sep`` into levels, and lines sharing a prefix share the
        corresponding clades. Blank lines are ignored. Every clade gets
        branch length 1.0, and ``self.ignore_branch_len`` is set to 1.

        Returns the tree converted to PhyloXML. Raises IOError (OSError)
        when *fn* cannot be read.
        """
        with open(fn, 'rb') as inpf:
            rows = [l.decode('utf-8').rstrip().split("\t")[0] for l in inpf]
        # Blank lines would otherwise become clades with an empty name.
        clades = [r.split(lev_sep) for r in rows if r]

        tree = BTree()
        tree.root = BClade()

        def add_clade_rec( father, txt_tree ):
            # Group the paths by their first level, keeping first-appearance order.
            children = colls.OrderedDict()
            for path in txt_tree:
                sub_paths = children.setdefault(path[0], [])
                if len(path) > 1:
                    sub_paths.append(path[1:])
            father.clades = []
            for name, sub_paths in children.items():
                nclade = BClade(branch_length=1.0, name=name)
                father.clades.append(nclade)
                if sub_paths:
                    add_clade_rec(nclade, sub_paths)

        add_clade_rec(tree.root, clades)
        self.ignore_branch_len = 1
        return tree.as_phyloxml()

    def __read_tree__( self, fn ):
        """Read *fn*, trying PhyloXML, Newick, Nexus and finally the text format.

        A format is skipped if its parser fails, or (for the Bio.Phylo
        formats) if the parsed tree has a single terminal. Returns the first
        tree parsed successfully, as PhyloXML.

        Raises IOError (after reporting on stderr) when the file cannot be
        opened, and ValueError when no format could parse it.
        """
        for ff in ['phyloxml', 'newick', 'nexus', "txt"]:
            try:
                if ff == 'txt':
                    tree = self.__load_tree_txt__(fn)
                else:
                    tree = Phylo.read(fn, ff)
                    # Presumably a sign that the file is not really in this format.
                    if len(tree.root.get_terminals()) == 1:
                        raise ValueError
            except IOError:
                sys.stderr.write("Error: No tree file found: "+fn+"\n")
                raise
            except Exception:
                # Wrong format (ValueError or a parser-specific error): try the next one.
                continue
            return tree.as_phyloxml()
        sys.stderr.write("Error: unrecognized input format "+fn+"\n")
        raise ValueError

    def __init__( self, filename, warnings = False ):
        """Load the tree stored in *filename* and annotate clade full names.

        With ``filename=None``, an empty wrapper (``self.tree = None``) is
        created. If loading fails, an error is written to stderr and the
        process exits with status 1. The original code exited with status 0,
        which hid the failure from calling scripts.
        """
        self.warnings = warnings
        if filename is None:
            self.tree = None
            return
        try:
            self.tree = self.__read_tree__(filename)
            self.add_full_paths()
        except Exception as e:
            sys.stderr.write("Error: unable to load tree from " + str(filename) +
                             " (" + repr(e) + ")\n")
            sys.exit(1)

    def export( self, out_file ):
        """Convert ``self.tree`` to PhyloXML (in place) and write it to *out_file*."""
        self.tree = self.tree.as_phyloxml()
        Phylo.write(self.tree, out_file, "phyloxml")

    def read_tax_clades( self, tf ):
        """Read a tab-separated file mapping a taxon (first field) to its terminal names (the other fields)."""
        return {row[0]: row[1:] for row in self._tsv_rows(tf)}

    def read_targets( self, tf ):
        """Return target names from *tf*.

        If *tf* contains ':', it is split on ':'. Otherwise it is treated as
        the path of a file with one name per line.
        """
        if ":" in tf:
            return tf.split(":")
        with open(tf) as inpf:
            return [l.strip() for l in inpf]

    # ------------------------------------------------------- annotations

    @staticmethod
    def _name2imgid( name ):
        """Return the integer genome id encoded in a terminal *name*.

        When *name* contains ``t__``, its first three characters are dropped
        before conversion, so ``"t__123"`` and ``"123"`` both give 123.
        """
        return int(name[3:] if "t__" in name else name)

    def _annotate_imgids( self ):
        """Attach genome-id annotations to every clade of ``self.tree``.

        Terminals get ``imgid`` (see ``_name2imgid``) and ``nterminals = 1``.
        Non-terminals get ``imgids``, the set of the ids of their terminals,
        and ``nterminals = len(imgids)``.

        Returns a dict mapping ``str(imgid)`` to the terminal clade.
        """
        ids2clades = {}
        for t in self.tree.get_terminals():
            t.imgid = self._name2imgid(t.name)
            t.nterminals = 1
            ids2clades[str(t.imgid)] = t
        # can be made faster with recursion (but it is not a bottleneck)
        for n in self.tree.get_nonterminals():
            n.imgids = set(nn.imgid for nn in n.get_terminals())
            n.nterminals = len(n.imgids)
        return ids2clades

    def add_full_paths( self ):
        """Set ``clade.full_name`` on every clade.

        The full name joins, with ``lev_sep``, the names of the named clades
        on the path from the current root down to (and including) the clade.
        """
        def _add_full_paths_( clade, path ):
            lpath = path + ([clade.name] if clade.name else [])
            clade.full_name = lev_sep.join(lpath)
            for c in clade.clades:
                _add_full_paths_(c, lpath)
        _add_full_paths_(self.tree.root, [])

    # ------------------------------------------------------- core search

    def core_test( self, ok, tot, pr ):
        """Return ``1 - binom.sf(ok, tot, pr)``, i.e. P(X <= ok) for X ~ Binomial(tot, pr).

        Results are memoized in ``self.ctc``, a nested dict
        ``pr -> tot -> ok``. The dict is created on first use if missing.
        """
        # scipy included here for non-compatibility with scons
        import scipy.stats as st
        if getattr(self, 'ctc', None) is None:
            self.ctc = {}
        cache = self.ctc.setdefault(pr, {}).setdefault(tot, {})
        if ok not in cache:
            cache[ok] = 1.0-st.binom.sf(ok, tot, pr)
        return cache[ok]

    @staticmethod
    def _qm_corrected_count( clade, intersection, count ):
        """Credit genomes hidden below '?'-named children of *clade*.

        For each direct child whose name contains '?', the ids not in
        *intersection* are counted into ``add``. When ``count >= add > 0``,
        ``round(add / 1.99)`` (about half of them) is added to *count*.
        Returns the corrected count.
        """
        add = 0
        for sub in clade.clades:
            if sub.name and "?" in sub.name:
                add += len(sub.imgids - intersection)
        if add and count >= add:
            count += int(round(add/1.99))
        return count

    def is_core( self, clade, targs, er = 0.95 ):
        """Decide whether the genome ids *targs* form a core of *clade*.

        Requires the annotations made by ``_annotate_imgids``. The tests are
        applied in order:

        1. If *clade* has at least 2 genomes but fewer than 2 are in
           *targs*, it is not a core (p-value 0.0).
        2. The target count, corrected by ``_qm_corrected_count``, is tested
           with ``core_test(count, clade.nterminals, er)``. A value below
           0.05, or a count of 0, rejects the clade.
        3. Every descendant non-terminal whose full name has no '?' must
           pass the same test.
        4. If single-genome descendant clades exist, at least one of them
           must contain a target.

        Returns ``(is_core, p_value, intersection)``, where *intersection*
        is None unless *is_core* is True.
        """
        intersection = clade.imgids & targs
        len_intersection = len(intersection)

        if len(clade.imgids) >= 2 and len_intersection < 2:
            return False, 0.0, None

        len_intersection = self._qm_corrected_count(clade, intersection, len_intersection)

        core = self.core_test(len_intersection, clade.nterminals, er)
        if core < 0.05 or len_intersection == 0:
            return False, core, None
        nsubclades, nsubclades_absent = 0, 0
        for subclade in set(clade.get_nonterminals()) - set([clade]):
            if "?" in subclade.full_name:
                continue
            sc_intersection = subclade.imgids & targs
            if subclade.nterminals == 1:
                nsubclades += 1
                if not sc_intersection:
                    nsubclades_absent += 1
                continue
            # Same '?' correction as for the clade itself (previously gated
            # on the outer clade's correction by mistake).
            sc_len_intersection = self._qm_corrected_count(
                subclade, sc_intersection, len(sc_intersection))
            subcore = self.core_test(sc_len_intersection, subclade.nterminals, er)
            if subcore < 0.05:
                return False, core, None
        if nsubclades > 0 and nsubclades == nsubclades_absent:
            return False, core, None
        return True, core, intersection

    def _find_core( self, terminals, er = 0.95, root_name = None, skip_qm = True ):
        """Return the maximal clades in which the ids *terminals* form a core.

        The tree is walked from the root, and children of a clade reported
        as core are not visited. Each entry of the returned list is
        ``(full_name, n_targets_in_clade, n_genomes_in_clade, p_value)``.
        Target terminals and single-genome clades are reported as
        ``(full_name, 1, 1, 1.0)``.

        When *skip_qm* is true, clades whose name contains '?' are skipped.
        *root_name* is accepted for compatibility but does not affect the
        result.
        """
        def _find_core_rec( clade ):
            clname = clade.full_name
            if clade.is_terminal():
                return [(clname, 1, 1, 1.0)] if clade.imgid in terminals else []
            if skip_qm and clade.name and "?" in clade.name:
                return []
            if len(clade.imgids) == 1:
                cimg = next(iter(clade.imgids))
                return [(clname, 1, 1, 1.0)] if cimg in terminals else []
            core, pv, intersection = self.is_core(clade, terminals, er=er)
            if core:
                return [(clname, len(intersection), len(clade.imgids), pv)]
            rets = []
            for c in clade.clades:
                rets += _find_core_rec(c)
            return rets
        return _find_core_rec(self.tree.root)

    def find_cores( self, cl_taxa_file, min_core_size = 1, error_rate = 0.95, subtree = None, skip_qm = True ):
        """Find, for every entry of *cl_taxa_file*, the clades in which it is core.

        Parameters
        ----------
        cl_taxa_file : str
            Tab-separated file. Each line is an integer id (presumably a gene
            cluster) followed by the integer genome ids it occurs in.
        min_core_size : int
            Entries with fewer distinct genome ids are skipped.
        error_rate : float
            Probability passed to ``core_test``.
        subtree : str, optional
            Full name of a clade. The tree is first restricted to it (see
            ``subtree``).
        skip_qm : bool
            Passed to ``_find_core``.

        Returns
        -------
        dict
            Maps each id (int) to the list returned by ``_find_core``.
        """
        if subtree:
            self.subtree('name', subtree)
        self.ctc = {}
        self._annotate_imgids()
        # Recompute full names relative to the (possibly new) root.
        self.add_full_paths()

        ret = {}
        for vec in self._tsv_rows(cl_taxa_file):
            sid = int(vec[0])
            tgts = set(int(s) for s in vec[1:])
            if len(tgts) >= min_core_size:
                ret[sid] = self._find_core(tgts, er=error_rate, root_name=subtree, skip_qm=skip_qm)
        return ret

    # ----------------------------------------------------------- markers

    def markerness( self, coreness, uniqueness, cn_min, cn_max, cn_avg ):
        """Return ``coreness * uniqueness / (cn_max - cn_min + 1) / cn_avg``."""
        return coreness * uniqueness * (1.0 / float(cn_max-cn_min+1)) * 1.0 / cn_avg

    def find_markers( self, cu_file, hitmap_file, core_file ):
        """Score each entry of *hitmap_file* as a marker of its core clade.

        Parameters
        ----------
        cu_file : str
            Tab-separated file: id, then the genome ids whose LCA defines the
            clade.
        hitmap_file : str
            Tab-separated file: id, then the genome ids it hits.
        core_file : str
            Tab-separated file: id, taxonomy label, four integers, then
            floats. The last four values are read as cn_min, cn_max, cn_avg
            and coreness; fields 2 and 3 are read as cobs and ctot.

        Returns
        -------
        dict
            Maps each id to ``[gtax, markerness, coreness, uniqueness, cobs,
            ctot, cn_min, cn_max, cn_avg, n_in, n_out, "|"-joined in-clade
            hits, "|"-joined out-of-clade hits]``.
        """
        self.ctc = {}
        ids2clades = self._annotate_imgids()
        self.add_full_paths()

        cus = {int(v[0]): [int(x) for x in v[1:]] for v in self._tsv_rows(cu_file)}
        cinfo = {int(v[0]): [v[1]] + [int(x) for x in v[2:6]] + [float(x) for x in v[6:]]
                 for v in self._tsv_rows(core_file)}

        ret = {}
        for vec in self._tsv_rows(hitmap_file):
            sid = int(vec[0])
            tgts_l = set(int(s) for s in vec[1:])
            lca = self.lca(cus[sid], ids2clades)
            if lca.is_terminal():
                # NOTE(review): the single genome counts as "in" even when it
                # is not among the hits -- confirm this is intended.
                tin = set([lca.imgid])
                tout = tgts_l - tin
            else:
                tout = tgts_l - lca.imgids
                tin = lca.imgids & tgts_l
            ci = cinfo[sid]
            ltin, ltout = len(tin), len(tout)
            # An entry without hits would otherwise divide by zero.
            uniqueness = float(ltin)/float(ltin+ltout) if ltin + ltout else 0.0
            coreness = float(ci[-1])
            cn_min, cn_max, cn_avg = [float(f) for f in ci[-4:-1]]
            gtax = ci[0]
            cobs, ctot = int(ci[1]), int(ci[2])
            markerness = self.markerness(coreness, uniqueness, cn_min, cn_max, cn_avg)

            ret[sid] = [gtax, markerness, coreness, uniqueness, cobs, ctot, cn_min, cn_max, cn_avg,
                        ltin, ltout, "|".join(str(s) for s in tin), "|".join(str(s) for s in tout)]
        return ret

    def select_markers( self, marker_file, markerness_th = 0.0, max_markers = 200 ):
        """Group marker lines by clade and keep the best ones.

        Each tab-separated line of *marker_file* has the clade id in field 1
        and the markerness in field 2. Lines with markerness below
        *markerness_th* are dropped. Each group is sorted by decreasing
        markerness and truncated to *max_markers*.

        Returns the groups, as lists of split lines.
        """
        cl2markers = colls.defaultdict(list)
        for line in self._tsv_rows(marker_file):
            if float(line[2]) < markerness_th:
                continue
            cl2markers[line[1]].append(line)
        for k, v in cl2markers.items():
            cl2markers[k] = sorted(v, key=lambda x: float(x[2]), reverse=True)[:max_markers]
        return cl2markers.values()

    # ------------------------------------------------------- tree queries

    def get_c2t( self ):
        """Return a dict mapping every clade to the set of names of its terminals."""
        tc2t = {}

        def _get_c2t_( clade ):
            tc2t[clade] = set(l.name for l in clade.get_terminals())
            for c in clade.clades:
                _get_c2t_(c)
        _get_c2t_(self.tree.root)
        return tc2t

    def ltcs( self, terminals, tc2t = None, terminals2clades = None, lca_precomputed = None ):
        """Return the largest clade whose terminals are all among *terminals*.

        The search starts at *lca_precomputed*, or else at the LCA of
        *terminals*. Subtrees smaller than the best size found so far are
        pruned. *tc2t*, as returned by ``get_c2t``, avoids recomputing
        terminal name sets. Among equally large candidates, the last one
        found wins. Returns None when no clade qualifies.
        """
        set_terminals = set(terminals)
        lca = lca_precomputed if lca_precomputed else self.lca(terminals, terminals2clades)

        def _ltcs_rec_( clade, cur_max ):
            # Returns (best clade or None, its size). A size of 0 means nothing
            # was found; it must stay numeric because callers compare it.
            if clade.is_terminal() and clade.name in set_terminals:
                return clade, 1
            terms = tc2t[clade] if tc2t else set(cc.name for cc in clade.get_terminals())
            if len(terms) < cur_max:
                return None, 0
            if terms <= set_terminals:
                return clade, len(terms)
            rets = []
            for c in clade.clades:
                r, tmax = _ltcs_rec_(c, cur_max)
                if tmax >= cur_max:
                    cur_max = tmax
                    if r:
                        rets.append((r, tmax))
            if rets:
                return sorted(rets, key=lambda x: x[1])[-1][0], cur_max
            return None, 0
        return _ltcs_rec_(lca, cur_max=0)[0]

    def lca( self, terminals, terminals2clades = None ):
        """Return the lowest common ancestor of the given terminals.

        With *terminals2clades*, a dict keyed by ``str(id)``, each element of
        *terminals* is looked up there. Otherwise the terminals of the tree
        whose name is in *terminals* are used.
        """
        if terminals2clades:
            clade_targets = [terminals2clades[str(t)] for t in terminals]
        else:
            names = set(terminals)
            clade_targets = [t for t in self.tree.get_terminals() if t.name in names]
        return self.tree.common_ancestor(clade_targets)

    def lcca( self, t, t2c ):
        """Find an ancestor of terminal *t* whose other terminals agree on a different taxon.

        *t2c* maps terminal names to ``lev_sep``-joined taxonomy strings.
        Fields ``[2:-1]`` of each string are compared. Only the deepest 15
        clades on the path to *t* are scanned, highest first. A clade
        qualifies when:

        * its other terminals number at least 2;
        * they share the same last compared level;
        * that level differs from *t*'s;
        * neither level is ``s__sp_`` nor ends with ``unclassified``.

        Returns ``(clade, its terminals, t's compared levels joined)``, or
        ``(None, None, None)`` if no clade qualifies.
        """
        node_path = self.tree.get_path(t)
        if not node_path or len(node_path) < 2:
            return None, None, None
        tlevs = t2c[t].split(lev_sep)[2:-1]
        for p in node_path[-15:]:
            terms = list(p.get_terminals())
            descn = [t2c[l.name].split(lev_sep)[2:-1] for l in terms if l.name != t]
            if len(descn) < 2:
                continue

            l = tlevs[-1]
            descr_l = [d[-1] for d in descn]
            if len(set(descr_l)) == 1 and descr_l[0] != l and \
                l != "s__sp_" and not l.endswith("unclassified") and \
                descr_l[0] != "s__sp_" and not descr_l[0].endswith("unclassified"):
                return p, terms, lev_sep.join(tlevs)
        return None, None, None

    def tax_precision( self, c2t_f, strategy = 'lca' ):
        """Return ``[taxon, str(precision)]`` for each taxon in *c2t_f*.

        Precision is the branch length spanned by the taxon's terminals
        within their LCA, divided by the LCA's total branch length.
        *strategy* is currently unused.
        """
        c2t = self.read_tax_clades(c2t_f)
        res = []
        for c, terms in c2t.items():
            lca = self.lca(terms)
            prec = partial_branch_length(lca, terms) / lca.total_branch_length()
            res.append([c, str(prec)])
        return res

    def tax_recall( self, c2t_f ):
        """For each taxon in *c2t_f*, list the intruders between its LCA and LTCS.

        Each result row is ``[taxon, "name:relative_distance", ...]`` with
        one entry per terminal under the LCA but not under the LTCS. The
        distance from the LTCS is divided by the sum of the two largest
        depths within the LTCS.
        """
        c2t = self.read_tax_clades(c2t_f)
        res = []
        for c, terms in c2t.items():
            lca = self.lca(terms)
            ltcs = self.ltcs(terms)
            out_terms = set(lca.get_terminals()) - set(ltcs.get_terminals())
            outs = [c]
            if out_terms:
                diam = sum(sorted(ltcs.depths().values())[-2:])
                outs += [":".join([t.name, str(self.tree.distance(ltcs, t)/diam)])
                         for t in out_terms]
            res.append(outs)
        return res

    def tax_resolution( self, terminals ):
        """Not implemented; returns None."""
        pass

    # ------------------------------------------------------ tree editing

    def _unique_clade_by_name( self, name ):
        """Return the only clade named *name*; exit with status -1 if there are several or none."""
        ct = list(self.tree.find_clades({"name": name}))
        if len(ct) > 1:
            sys.stderr.write("Error: non-unique target specified.\n")
            sys.exit(-1)
        if not ct:
            sys.stderr.write("Error: target not found.\n")
            sys.exit(-1)
        return ct[0]

    def prune( self, strategy = 'lca', n = None, fn = None, name = None, newname = None ):
        """Prune the tree in place.

        Strategies:

        * ``'root_name'``: collapse the clade named *name*.
        * ``'lca'`` / ``'ltcs'``: collapse the LCA / LTCS of the targets in
          *fn*, given as a list or as anything ``read_targets`` accepts.
        * ``'n_anc'``: remove the *n*-th clade (default 1) on the path to the
          clade named *name*, counting up from *name* itself, together with
          its subtree.

        A collapsed clade loses all its descendants and is renamed
        *newname* if given. Unsupported strategies and invalid targets
        exit with status -1.
        """
        prune = None
        if strategy == 'root_name':
            prune = self._unique_clade_by_name(name)
        elif strategy == 'lca':
            terms = self.read_targets(fn) if isinstance(fn, str) else fn
            prune = self.lca(terms)
        elif strategy == 'ltcs':
            terms = self.read_targets(fn) if isinstance(fn, str) else fn
            prune = self.ltcs(terms)
        elif strategy == 'n_anc':
            if n is None:
                n = 1
            ct = list(self.tree.find_clades({"name": name}))
            if len(ct) > 1:
                sys.stderr.write("Error: non-unique target specified.\n")
                sys.exit(-1)
            node_path = self.tree.get_path(name)
            if not node_path or len(node_path) < n:
                sys.stderr.write("Error: no anchestors or number of anchestors < n.")
                sys.exit(-1)
            toprune = node_path[-n]
            # get_path excludes the root, so when the path holds exactly n
            # clades the parent of the clade to remove is the root itself.
            fat = node_path[-n-1] if len(node_path) > n else self.tree.root
            fat.clades = [cc for cc in fat.clades if cc != toprune]
        else:
            sys.stderr.write(strategy + " not supported yet.")
            sys.exit(-1)
        if prune is not None:
            prune.clades = []
            if newname:
                prune.name = newname

    def subtree( self, strategy, n = None, fn = None ):
        """Replace the tree root with one of its clades.

        Strategies:

        * ``'name'``: use the clade named *fn*. If that does not identify
          exactly one clade, use the first non-terminal whose ``full_name``
          equals *n*; exit with status -1 if there is none.
        * ``'lca'`` / ``'ltcs'``: use the LCA / LTCS of the targets in *fn*.
        """
        newroot = None
        if strategy == 'name':
            # Without a name, {"name": None} would match every unnamed clade.
            ct = list(self.tree.find_clades({"name": fn})) if fn is not None else []
            if len(ct) != 1:
                match = next((cl for cl in self.tree.get_nonterminals() if cl.full_name == n), None)
                if match is None:
                    sys.stderr.write("Error: target not found.")
                    sys.exit(-1)
                ct = [match]
            newroot = ct[0]
        elif strategy == 'lca':
            terms = self.read_targets(fn) if isinstance(fn, str) else fn
            newroot = self.lca(terms)
        elif strategy == 'ltcs':
            terms = self.read_targets(fn) if isinstance(fn, str) else fn
            newroot = self.ltcs(terms)
        if newroot is not None:
            self.tree.root = newroot

    def rename( self, strategy, n = None, terms = None ):
        """Rename a clade to *n*.

        The clade is selected by *strategy*:

        * ``'root_name'``: the unique clade already named *n*;
        * ``'lca'``: the LCA of *terms*;
        * ``'ltcs'``: the LTCS of *terms*.
        """
        newroot = None
        if strategy == 'root_name':
            newroot = self._unique_clade_by_name(n)
        elif strategy == 'lca':
            newroot = self.lca(terms)
        elif strategy == 'ltcs':
            newroot = self.ltcs(terms)
        if newroot is not None:
            newroot.name = n

    @staticmethod
    def _branch_length_key( clade ):
        """Sort key on branch length where a missing length (None) ranks lowest."""
        return clade.branch_length if clade.branch_length is not None else float('-inf')

    def reroot( self, strategy = 'lca', tf = None, n = None ):
        """Reroot the tree in place at the midpoint of a selected branch.

        Strategies:

        * ``'lca'`` / ``'ltcs'``: the branch above the LCA / LTCS of the
          targets read from *tf*.
        * ``'midpoint'``: not implemented; the tree is left unchanged.
        * ``'longest_edge'``: the longest branch.
        * ``'longest_internal_edge'``: the longest branch above a
          non-terminal.
        * ``'longest_internal_edge_n'``: as above, restricted to
          non-terminals with at least *n* terminals.

        Missing branch lengths rank below every real one.
        """
        if strategy in ['lca', 'ltcs']:
            targets = self.read_targets(tf)
            lca = self.lca(targets) if strategy == 'lca' else self.ltcs(targets)
            reroot_mid_fat_edge(self.tree, lca)
        elif strategy == 'midpoint':
            pass
        elif strategy == 'longest_edge':
            nodes = list(self.tree.get_nonterminals()) + list(self.tree.get_terminals())
            longest = max(nodes, key=self._branch_length_key)
            reroot_mid_fat_edge(self.tree, longest)
        elif strategy == 'longest_internal_edge':
            longest = max(self.tree.get_nonterminals(), key=self._branch_length_key)
            if self.tree.root != longest:
                reroot_mid_fat_edge(self.tree, longest)
        elif strategy == 'longest_internal_edge_n':
            def _key( x ):
                if len(x.get_terminals()) >= n:
                    return self._branch_length_key(x)
                return -1.0
            longest = max(self.tree.get_nonterminals(), key=_key)
            reroot_mid_fat_edge(self.tree, longest)

    def reorder_tree( self, reorder_tree ):
        """Link terminals in left-to-right order and record subtree extremes.

        If *reorder_tree* is true, children are first sorted by decreasing
        number of terminals.

        * Every non-root clade gets ``fc``/``lc``: its first and last terminal.
        * Every terminal gets ``pc``/``nc``: the previous and next terminal,
          or None at the ends.
        * ``self._ord_terms`` holds the terminals in order.
        """
        self._ord_terms = []

        def reorder_tree_rec( clade, reorder_tree ):
            if clade.is_terminal():
                self._ord_terms.append(clade)
                return clade, clade
            if reorder_tree:
                clade.clades.sort(key=lambda x: len(x.get_terminals()), reverse=True)
            for c in clade.clades:
                c.fc, c.lc = reorder_tree_rec(c, reorder_tree)
            return clade.clades[0].fc, clade.clades[-1].lc

        reorder_tree_rec(self.tree.root, reorder_tree)
        last = None
        for c in self._ord_terms:
            c.pc = last
            if last is not None:
                last.nc = c
            last = c
        if last is not None:
            last.nc = None

    def get_subtree_leaves( self, full_names = False ):
        """Return ``(clade name or "", [names of its named terminals])`` for every non-terminal.

        Entries are in post-order. *full_names* is currently unused.
        """
        subtrees = []

        def rec_subtree_leaves( clade ):
            if not len(clade.clades):
                return [clade.name]
            leaves = []
            for c in clade.clades:
                leaves += rec_subtree_leaves(c)
            leaves = [l for l in leaves if l]
            subtrees.append((clade.name if clade.name else "", leaves))
            return leaves

        rec_subtree_leaves(self.tree.root)
        return subtrees

    def get_clade_names( self, full_names = False, leaves = True, internals = True ):
        """Return the sorted, de-duplicated clade names.

        Without *full_names*, the plain names of the terminals (if *leaves*)
        and non-terminals (if *internals*) are returned. Unnamed clades
        (None) are skipped, because None cannot be sorted with strings.

        With *full_names*, every named clade contributes the ``lev_sep``-joined
        names of its named ancestors and itself.

        NOTE(review): with *full_names*, the *leaves*/*internals* flags are
        ignored and every named clade is reported. Confirm this is intended.
        """
        if full_names:
            def rec_name( clade, nam = "" ):
                if clade.name:
                    lnam = lev_sep.join([nam, clade.name]) if nam else clade.name
                else:
                    lnam = nam
                ret = [lnam] if lnam else []
                for c in clade.clades:
                    ret += rec_name(c, lnam)
                return ret
            names = set(rec_name(self.tree.root))
        else:
            clades = []
            if leaves:
                clades += self.tree.get_terminals()
            if internals:
                clades += self.tree.get_nonterminals()
            names = set(c.name for c in clades if c.name is not None)
        return sorted(names)
741
742