1from Bio import Phylo
2from Bio.Phylo import PhyloXML
3from Bio.Phylo import PhyloXMLIO
4from collections import defaultdict as ddict
5from Bio.Phylo.PhyloXML import Property as Prop
6from Bio.Phylo.PhyloXML import Clade as PClade
7from Bio.Phylo.BaseTree import Tree as BTree
8from Bio.Phylo.BaseTree import Clade as BClade
9import string
10from numpy import pi as rpi
11rpi2 = 2.0*rpi
12import numpy as np
13import array as arr
14import collections as colls
15import sys
16#core_test = lambda ok,tot,pr: 1.0-st.binom.sf(ok,tot,pr)
17
18lev_sep = "."
19
20# Here are three functions that I'd love to see in Biopython but they
21# are not there (yet).
def partial_branch_length(clade, selective_targets):
    """Sum the branch lengths of the part of ``clade`` that leads to targets.

    A terminal contributes its own branch length when its name is in
    ``selective_targets``.  An internal clade contributes its own branch
    length plus the contributions of its children, but only when at least
    one terminal below it is a target; otherwise it contributes nothing.
    Missing (``None``) branch lengths, as commonly found on a root clade,
    are counted as 0.0 instead of raising ``TypeError``.

    Returns the accumulated length (0.0 when no target lies under ``clade``).
    """
    def _own_length(node):
        return node.branch_length if node.branch_length is not None else 0.0

    def _walk(node):
        # Returns (subtree contains a target, partial length of the subtree).
        # Doing both in one pass avoids re-listing the terminals at every
        # level of the recursion.
        if node.is_terminal():
            if node.name in selective_targets:
                return True, _own_length(node)
            return False, 0.0
        found, total = False, 0.0
        for child in node.clades:
            child_found, child_total = _walk(child)
            if child_found:
                found = True
                total += child_total
        if not found:
            return False, 0.0
        return True, total + _own_length(node)

    return _walk(clade)[1]
34
def reroot(tree, new_root):
    """Re-root ``tree`` in place so that ``new_root`` becomes the outgroup.

    ``new_root`` is a clade of ``tree``.  When it is a terminal, a fresh root
    clade (same class and branch length as the current root) is created with
    the outgroup attached by a 0-length branch; when it is an internal clade
    it becomes the new root itself.  The branches on the path between the old
    root and the outgroup are reversed, shifting branch lengths along the way.

    Side effects: mutates the ``clades`` lists and ``branch_length`` values of
    the clades on that path, sets ``tree.root`` and sets ``tree.rooted`` to
    True.  Returns None.
    """
    outgroup = new_root
    # Path from (but excluding) the current root down to the outgroup.
    outgroup_path = tree.get_path(outgroup)
    if len(outgroup_path) == 0:
        # Outgroup is the current root -- no change
        return
    # Branch length that is handed one step up at every reversal below.
    prev_blen = outgroup.branch_length
    if outgroup.is_terminal():
        # Create a new root with a 0-length branch to the outgroup
        outgroup.branch_length = 0.0
        new_root = tree.root.__class__(
                branch_length=tree.root.branch_length, clades=[outgroup])
        # The first branch reversal (see the upcoming loop) is modified
        if len(outgroup_path) == 1:
            # Trivial tree like '(A,B);': the outgroup hangs off the old root
            new_parent = new_root
        else:
            # Detach the outgroup from its parent and hang the parent under
            # the new root, swapping the branch lengths of the two.
            parent = outgroup_path.pop(-2)
            parent.clades.pop(parent.clades.index(outgroup))
            prev_blen, parent.branch_length = parent.branch_length, prev_blen
            new_root.clades.insert(0, parent)
            new_parent = parent
    else:
        # Use the given outgroup node as the new (trifurcating) root
        new_root = outgroup
        new_root.branch_length = tree.root.branch_length
        new_parent = new_root

    # Tracing the outgroup lineage backwards, reattach the subclades under a
    # new root clade. Reverse the branches directly above the outgroup in
    # the tree, but keep the descendants of those clades as they are.
    for parent in outgroup_path[-2::-1]:
        parent.clades.pop(parent.clades.index(new_parent))
        prev_blen, parent.branch_length = parent.branch_length, prev_blen
        new_parent.clades.insert(0, parent)
        new_parent = parent

    # Finally, handle the original root according to number of descendants
    old_root = tree.root
    if outgroup in old_root.clades:
        assert len(outgroup_path) == 1
        old_root.clades.pop(old_root.clades.index(outgroup))
    else:
        old_root.clades.pop(old_root.clades.index(new_parent))
    if len(old_root) == 1:
        # Delete the old bifurcating root & add branch lengths
        ingroup = old_root.clades[0]
        if ingroup.branch_length:
            ingroup.branch_length += prev_blen
        else:
            ingroup.branch_length = prev_blen
        new_parent.clades.insert(0, ingroup)
        # ENH: If annotations are attached to old_root, do... something.
    else:
        # Keep the old trifurcating/polytomous root as an internal node
        old_root.branch_length = prev_blen
        new_parent.clades.insert(0, old_root)

    tree.root = new_root
    tree.rooted = True
    return
96
def get_parent(tree, child_clade):
    """Return the clade directly above ``child_clade`` in ``tree``.

    Uses ``tree.get_path``, whose result ends with the target.  Returns
    ``None`` when the path has fewer than two entries (the target is the
    root or hangs directly off it, since Bio.Phylo paths exclude the root)
    and also when ``get_path`` reports the target as missing by returning
    ``None`` -- the latter used to raise ``TypeError``.
    """
    node_path = tree.get_path(child_clade)
    if node_path is None or len(node_path) < 2:
        return None
    return node_path[-2]
100
def reroot_mid_fat_edge( tree, node ):
    """Re-root ``tree`` on the middle of the edge above ``node``.

    The edge is split by inserting a new clade that carries half of the
    original branch length and has ``node`` (with the other half) as its only
    child; the tree is then re-rooted on that new clade.  Nothing happens
    when ``node`` already is the root.
    """
    if tree.root == node:
        return
    parent = get_parent( tree, node )
    half = node.branch_length * 0.5
    node.branch_length = half
    midpoint = PClade( branch_length = half, clades = [node] )
    if parent:
        # Replace node by the midpoint clade at the end of the sibling list.
        siblings = [c for c in parent.clades if c != node]
        parent.clades = siblings + [midpoint]
    else:
        # No parent reported: node hangs off the root, midpoint goes first.
        siblings = [c for c in tree.root.clades if c != node]
        tree.root.clades = [midpoint] + siblings
    reroot( tree, midpoint )
114
def clades2terms( tree, startswith = None ):
    """Map clades of ``tree`` to the list of their terminals.

    Every clade (root and terminals included) is visited in pre-order.  When
    ``startswith`` is given, only clades whose name begins with that prefix
    are recorded; otherwise all clades are.
    """
    mapping = {}
    pending = [tree.root]
    while pending:
        node = pending.pop()
        if not startswith or (node.name and node.name.startswith( startswith )):
            mapping[node] = node.get_terminals()
        # Push children reversed so they are popped in their original order.
        pending.extend( reversed( node.clades ) )
    return mapping
127
def dist_matrix( tree ):
    """Return the pairwise branch-length distances between terminals.

    The result is a dict of dicts keyed by terminal name on both levels.
    Each clade's branch length is added to every pair made of one terminal
    below the clade and one terminal outside of it.  As a side effect every
    non-terminal clade gets an ``ids`` attribute holding the set of its
    terminal names.
    """
    leaves = list(tree.get_terminals())
    leaf_names = [leaf.name for leaf in leaves]
    # can be made faster with recursion
    for inner in tree.get_nonterminals():
        inner.ids = set( [leaf.name for leaf in inner.get_terminals()] )

    dists = {row: {col: 0.0 for col in leaf_names} for row in leaf_names}

    def _add_edge( inside, outside, length ):
        # The matrix is kept symmetric: update both cells.
        dists[inside][outside] += length
        dists[outside][inside] += length

    def _accumulate( clade ):
        length = clade.branch_length
        if clade.is_terminal():
            for other in leaf_names:
                if other != clade.name:
                    _add_edge( clade.name, other, length )
            return
        for inside in clade.ids:
            for leaf in leaves:
                if leaf.name not in clade.ids:
                    _add_edge( inside, leaf.name, length )
        for child in clade.clades:
            _accumulate( child )

    _accumulate( tree.root )
    return dists
155
156
157
158class PpaTree:
159
160    def __load_tree_txt__( self, fn ):
161        tree = Phylo.BaseTree.Tree()
162        try:
163            rows = [l.decode('utf-8').rstrip().split("\t")[0] for l in
164                        open(fn, 'rb')]
165        except IOError:
166            raise IOError()
167
168        clades = [r.split(lev_sep) for r in rows]
169
170        tree = BTree()
171        tree.root = BClade()
172
173        def add_clade_rec( father, txt_tree ):
174            fl = set([t[0] for t in txt_tree])
175            father.clades = []
176            for c in fl:
177                nclade = BClade( branch_length = 1.0,
178                                 name = c )
179                father.clades.append( nclade )
180                children = [t[1:] for t in txt_tree if len(t)>1 and t[0] == c]
181                if children:
182                    add_clade_rec( nclade, children )
183
184        add_clade_rec( tree.root, clades )
185        self.ignore_branch_len = 1
186        return tree.as_phyloxml()
187
188
189    def __read_tree__( self, fn ):
190        for ff in ['phyloxml','newick','nexus',"txt"]:
191            try:
192                if ff in ['txt']:
193                    tree = self.__load_tree_txt__( fn )
194                else:
195                    tree = Phylo.read(fn, ff)
196                    if len(tree.root.get_terminals()) == 1:
197                        raise ValueError
198            except ValueError:
199                continue
200            except IOError:
201                sys.stderr.write("Error: No tree file found: "+fn+"\n")
202                raise IOError
203            except Exception:
204                continue
205            else:
206                return tree.as_phyloxml()
207        sys.stderr.write("Error: unrecognized input format "+fn+"\n")
208        raise ValueError
209
210
211    def __init__( self, filename, warnings = False ):
212        self.warnings = warnings
213        if filename is None:
214            self.tree = None
215            return
216        try:
217            self.tree = self.__read_tree__(filename)
218            self.add_full_paths()
219        except:
220            sys.exit(0)
221
222
223    def core_test( self, ok, tot, pr ):
224        # scipy included here for non-compatibility with scons
225        import scipy.stats as st
226        if pr in self.ctc and tot in self.ctc[pr] and ok in self.ctc[pr][tot]:
227            return self.ctc[pr][tot][ok]
228        ret = 1.0-st.binom.sf(ok,tot,pr)
229        if not pr in self.ctc: self.ctc[pr] = {}
230        if not tot in self.ctc[pr]: self.ctc[pr][tot] = {}
231        if not ok in self.ctc[pr][tot]: self.ctc[pr][tot][ok] = ret
232        return ret
233
234    def is_core( self, clade, targs, er = 0.95 ):
235        intersection = clade.imgids & targs
236
237        len_intersection = len(intersection)
238
239        if len(clade.imgids) >= 2 and len_intersection < 2:
240           return False, 0.0, None
241
242        add = 0
243        for subclade in clade.clades:
244            if "?" in subclade.name:
245                out = subclade.imgids - intersection # targs
246                add += len(out)
247        if add and len_intersection >= add:
248            len_intersection += int(round(add/1.99))
249
250        core = self.core_test( len_intersection, clade.nterminals, er )
251        if core < 0.05 or len_intersection == 0:
252            return False, core, None
253        nsubclades, nsubclades_absent = 0, 0
254        for subclade in set(clade.get_nonterminals()) - set([clade]):
255            if "?" in subclade.full_name: # full??/
256                continue
257            if subclade.nterminals == 1:
258                nsubclades += 1 # !!!
259                if len(subclade.imgids & targs) == 0:
260                    nsubclades_absent += 1
261                continue
262
263            sc_intersection = subclade.imgids & targs
264            sc_len_intersection = len(sc_intersection)
265
266            sc_add = 0
267            for sc_subclade in subclade.clades:
268                if "?" in sc_subclade.name:
269                    sc_out = sc_subclade.imgids - sc_intersection
270                    sc_add += len(sc_out)
271            if add and sc_len_intersection >= sc_add:
272                sc_len_intersection += int(round(sc_add/1.99))
273
274            subcore = self.core_test( sc_len_intersection, subclade.nterminals, er )
275            if subcore < 0.05:
276                return False, core, None
277        if nsubclades > 0 and nsubclades == nsubclades_absent:
278            return False, core, None
279        return True, core, intersection
280
281    def _find_core( self, terminals, er = 0.95, root_name = None, skip_qm = True ):
282        #terminals_s = set(terminals)
283        def _find_core_rec( clade ):
284            if root_name:
285                #clname = lev_sep.join( [root_name]+clade.full_name.split(lev_sep)[1:] )
286                #clname = lev_sep.join( clade.full_name[1:] )
287                clname = clade.full_name
288            else:
289                clname = clade.full_name
290            if clade.is_terminal():
291                if clade.imgid in terminals:
292                    #n = terminals[clade.imgid]
293                    return [(clname,1,1,
294                                #n,n,n,
295                                1.0)]
296                return []
297            if skip_qm and clade.name and "?" in clade.name:
298                return []
299            if len(clade.imgids) == 1:
300                cimg = list(clade.imgids)[0]
301                if cimg in terminals:
302                    #n = terminals[cimg]
303                    return [(clname,1,1,
304                                #n,n,n,
305                                1.0)]
306                return []
307            core,pv,intersection = self.is_core( clade, terminals, er = er )
308            if core:
309                #ns = [terminals[ii] for ii in terminals_s if ii in clade.imgids]
310                return [( clname,
311                          len(intersection),len(clade.imgids),
312                          #len(clade.imgids&terminals_s),len(clade.imgids),
313                          #min(ns),max(ns),np.mean(ns),
314                          pv)]
315            rets = []
316            for c in clade.clades:
317                rets += _find_core_rec(c)
318            return rets
319        return  _find_core_rec( self.tree.root )
320
321
322    def add_full_paths( self ):
323
324        def _add_full_paths_( clade, path ):
325            lpath = path + ([clade.name] if clade.name else [])
326            clade.full_name = ".".join( lpath )
327            for c in clade.clades:
328                _add_full_paths_( c, lpath )
329        _add_full_paths_( self.tree.root, [] )
330
331    def find_cores( self, cl_taxa_file, min_core_size = 1, error_rate = 0.95, subtree = None, skip_qm = True ):
332        if subtree:
333            self.subtree( 'name', subtree )
334        self.ctc = {}
335        imgids2terminals = {}
336        for t in self.tree.get_terminals():
337            t.imgid = int(t.name[3:] if "t__"in t.name else t.name)
338            t.nterminals = 1
339            imgids2terminals[t.imgid] = t
340
341        # can be made faster with recursion
342        for n in self.tree.get_nonterminals():
343            n.imgids = set( [nn.imgid for nn in n.get_terminals()]  )
344            n.nterminals = len( n.imgids )
345
346        self.add_full_paths() # unnecessary
347
348        ret = {}
349        for vec in (l.strip().split('\t') for l in open(cl_taxa_file)):
350            sid = int(vec[0])
351            #tgts_l = [int(s) for s in vec[1:]]
352            #tgts = dict([(s,tgts_l.count(s)) for s in set(tgts_l)])
353            tgts = set([int(s) for s in vec[1:]])
354
355            if len(tgts) >= min_core_size:
356                subtree_name = lev_sep.join(subtree.split(lev_sep)[:-1] ) if subtree else None
357                ret[sid] = self._find_core( tgts, er = error_rate, root_name = subtree, skip_qm = skip_qm )
358                #print sid #, ret[sid]
359        return ret
360
361    def markerness( self, coreness, uniqueness, cn_min, cn_max, cn_avg ):
362        return coreness * uniqueness * (1.0 / float(cn_max-cn_min+1)) * 1.0 / cn_avg
363
364    def find_markers( self, cu_file, hitmap_file, core_file ):
365        self.ctc = {}
366        imgids2terminals = {}
367        ids2clades = {}
368        for t in self.tree.get_terminals():
369            t.imgid = int(t.name)
370            t.nterminals = 1
371            imgids2terminals[t.imgid] = t
372            ids2clades[t.name] = t
373
374        # can be made faster with recursion (but it is not a bottleneck)
375        for n in self.tree.get_nonterminals():
376            n.imgids = set( [nn.imgid for nn in n.get_terminals()]  )
377            n.nterminals = len( n.imgids )
378
379        self.add_full_paths() # unnecessary
380
381        cus = dict([(int(l[0]),[int(ll) for ll in l[1:]]) for l in
382                        (line.strip().split('\t') for line in open(cu_file))])
383        cinfo = dict([(int(v[0]),[v[1]] + [int(vv) for vv in v[2:6]] + [float(vv) for vv in v[6:]])
384                        for v in (line.strip().split('\t') for line in open(core_file))])
385
386        ret = {}
387        for vec in (l.strip().split('\t') for l in open(hitmap_file)):
388            sid = int(vec[0])
389            tgts_l = set([int(s) for s in vec[1:]])
390            lca = self.lca( cus[sid], ids2clades )
391            if lca.is_terminal():
392                tin = set([lca.imgid])
393                tout = tgts_l - tin
394            else:
395                tout = tgts_l - lca.imgids
396                tin = lca.imgids & tgts_l
397            ci = cinfo[sid]
398            ltin = len(tin)
399            ltout = len(tout)
400            uniqueness = float(ltin)/float(ltin+ltout)
401            coreness = float( ci[-1] )
402            cn_min, cp_max, cn_avg = [float(f) for f in ci[-4:-1]]
403            gtax = ci[0]
404            cobs, ctot = int(ci[1]), int(ci[2])
405            markerness = self.markerness( coreness, uniqueness, cn_min, cp_max, cn_avg )
406
407            res_lin = [ gtax, markerness, coreness, uniqueness, cobs, ctot, cn_min, cp_max, cn_avg,
408                        ltin, ltout, "|".join([str(s) for s in tin]), "|".join([str(s) for s in tout]) ]
409            ret[sid] = res_lin
410        return ret
411
412
413    def select_markers( self, marker_file, markerness_th = 0.0, max_markers = 200 ):
414        cl2markers = colls.defaultdict( list )
415        for line in (l.strip().split('\t') for l in open( marker_file )):
416            gid = line[1]
417            markerness = float(line[2])
418            if markerness < markerness_th:
419                continue
420            cl2markers[gid].append( line )
421        for k,v in cl2markers.items():
422            cl2markers[k] = sorted(v,key=lambda x:float(x[2]),reverse=True)[:max_markers]
423        return cl2markers.values()
424
425    def get_c2t( self ):
426        tc2t = {}
427
428        def _get_c2t_( clade ):
429            lterms = clade.get_terminals()
430            tc2t[clade] = set([l.name for l in lterms])
431            if clade.is_terminal():
432                return
433            for c in clade.clades:
434                _get_c2t_( c )
435        _get_c2t_( self.tree.root )
436        return tc2t
437
438    def ltcs( self, terminals, tc2t = None, terminals2clades = None, lca_precomputed = None ):
439        set_terminals = set( terminals )
440        lca = lca_precomputed if lca_precomputed else self.lca( terminals, terminals2clades )
441        def _ltcs_rec_( clade, cur_max ):
442            if clade.is_terminal() and clade.name in set_terminals:
443                return clade,1
444            terms = tc2t[clade] if tc2t else set([cc.name for cc in clade.get_terminals()])
445            if len(terms) < cur_max:
446                return None,0
447            if terms <= set_terminals:
448                return clade,len(terms)
449            rets = []
450            for c in clade.clades:
451                r,tmax = _ltcs_rec_( c, cur_max )
452                if tmax >= cur_max:
453                    cur_max = tmax
454                    if r:
455                        rets.append((r,tmax))
456            if rets:
457                return sorted(rets,key=lambda x:x[1])[-1][0],cur_max
458            else:
459                return None,None
460        return _ltcs_rec_( lca, cur_max = 0 )[0]
461
462    def lca( self, terminals, terminals2clades = None ):
463        clade_targets = []
464        if terminals2clades:
465            clade_targets = [terminals2clades[str(t)] for t in terminals]
466        else:
467            clade_targets = [t for t in self.tree.get_terminals() if t.name in terminals]
468            """
469            for t in terminals:
470                ct = list(self.tree.find_clades( {"name": str(t)} ))
471                if len( ct ) > 1:
472                    sys.stderr.write( "Error: non-unique target specified." )
473                    sys.exit(-1)
474                clade_targets.append( ct[0] )
475            """
476        lca = self.tree.common_ancestor( clade_targets )
477        return lca
478
479
480    def lcca( self, t, t2c ):
481        node_path = list(self.tree.get_path(t))
482        if not node_path or len(node_path) < 2:
483            return None,None,None
484        tlevs = t2c[t].split(lev_sep)[2:-1]
485        for p in node_path[-15:]:
486            terms = list(p.get_terminals())
487            descn = [t2c[l.name].split(lev_sep)[2:-1] for l in  terms if l.name!=t]
488            if not descn or len(descn) < 2:
489                continue
490
491            l = tlevs[-1]
492            descr_l = [d[-1] for d in descn]
493            if len(set(descr_l)) == 1 and descr_l[0] != l and \
494                l != "s__sp_" and not l.endswith("unclassified") and \
495                descr_l[0] != "s__sp_" and not descr_l[0].endswith("unclassified"):
496                return p,terms,lev_sep.join(tlevs)
497        return None,None,None
498
499
500    def tax_precision( self, c2t_f, strategy = 'lca' ):
501        c2t = self.read_tax_clades( c2t_f )
502        res = []
503        for c,terms in c2t.items():
504            lca = self.lca( terms )
505            num = partial_branch_length(lca,terms)
506            den = lca.total_branch_length()
507            prec = num / den
508            res.append([c,str(prec)])
509        return res
510
511    def tax_recall( self, c2t_f ):
512        c2t = self.read_tax_clades( c2t_f )
513        res = []
514        for c,terms in c2t.items():
515            lca = self.lca( terms )
516            ltcs = self.ltcs( terms )
517            lca_terms = set(lca.get_terminals())
518            ltcs_terms = set(ltcs.get_terminals())
519            out_terms = lca_terms - ltcs_terms
520            outs = [c]
521            if len(out_terms):
522                diam = sum(sorted(ltcs.depths().values())[-2:])
523                outs += [":".join([t.name,str( self.tree.distance(ltcs,t)/diam )])
524                             for t in out_terms]
525            res.append( outs )
526        return res
527
528    def tax_resolution( self, terminals ):
529        pass
530
531
532    def prune( self, strategy = 'lca', n = None, fn = None, name = None, newname = None ):
533        prune = None
534        if strategy == 'root_name':
535            ct = list(self.tree.find_clades( {"name": name} ))
536            if len( ct ) > 1:
537                sys.stderr.write( "Error: non-unique target specified." )
538                sys.exit(-1)
539            prune = ct[0]
540        elif strategy == 'lca':
541            terms = self.read_targets( fn ) if isinstance(fn,str) else fn
542            prune = self.lca( terms )
543        elif strategy == 'ltcs':
544            terms = self.read_targets( fn ) if isinstance(fn,str) else fn
545            prune = self.ltcs( terms )
546        elif strategy == 'n_anc':
547            if n is None:
548                n = 1
549            ct = list(self.tree.find_clades( {"name": name} ))
550            if len( ct ) > 1:
551                sys.stderr.write( "Error: non-unique target specified.\n" )
552                sys.exit(-1)
553            node_path = list(self.tree.get_path(name))
554            if not node_path or len(node_path) < n:
555                sys.stderr.write( "Error: no anchestors or number of anchestors < n." )
556                sys.exit(-1)
557            toprune = node_path[-n]
558            fat = node_path[-n-1]
559            fat.clades = [cc for cc in fat.clades if cc != toprune]
560            prune = None
561        else:
562            sys.stderr.write( strategy + " not supported yet." )
563            sys.exit(-1)
564        if prune:
565            prune.clades = []
566            if newname:
567                prune.name = newname
568
569    def subtree( self, strategy, n = None, fn = None ):
570        newroot = None
571        if strategy == 'name':
572            ct = list(self.tree.find_clades( {"name": fn} ))
573            if len( ct ) != 1:
574                int_clades = self.tree.get_nonterminals()
575                for cl in int_clades:
576                    if n == cl.full_name:
577                        ct = [cl]
578                        break
579                if not ct:
580                    sys.stderr.write( "Error: target not found." )
581                    sys.exit(-1)
582            newroot = ct[0]
583        elif strategy == 'lca':
584            terms = self.read_targets( fn ) if isinstance(fn,str) else fn
585            newroot = self.lca( terms )
586        elif strategy == 'ltcs':
587            terms = self.read_targets( fn ) if isinstance(fn,str) else fn
588            newroot = self.ltcs( terms )
589        if newroot:
590            self.tree.root = newroot
591
592    def rename( self, strategy, n = None, terms = None ):
593        newroot = None
594        if strategy == 'root_name':
595            ct = list(self.tree.find_clades( {"name": n} ))
596            if len( ct ) > 1:
597                sys.stderr.write( "Error: non-unique target specified.\n" )
598                sys.exit(-1)
599            newroot = ct[0]
600        elif strategy == 'lca':
601            newroot = self.lca( terms )
602        elif strategy == 'ltcs':
603            newroot = self.ltcs( terms )
604        if newroot:
605            newroot.name = n
606
607    def export( self, out_file ):
608        self.tree = self.tree.as_phyloxml()
609        Phylo.write( self.tree, out_file, "phyloxml")
610
611    def read_tax_clades( self, tf ):
612        with open( tf ) as inpf:
613            return dict([(ll[0],ll[1:]) for ll in [l.strip().split('\t') for l in inpf]])
614
615    def read_targets( self, tf ):
616        if tf.count(":"):
617            return tf.split(":")
618        with open( tf ) as inpf:
619            return [l.strip() for l in inpf]
620
621    def reroot( self, strategy = 'lca', tf = None, n = None ):
622        if strategy in [ 'lca', 'ltcs' ]:
623            targets = self.read_targets( tf )
624
625            lca = self.lca( targets ) if strategy == 'lca' else self.ltcs( targets )
626            reroot_mid_fat_edge( self.tree, lca)
627
628            #lca_f = get_parent( self.tree, lca )
629            #
630            #bl = lca.branch_length
631            #new_clade = PClade(branch_length=bl*0.5, clades = [lca])
632            #lca.branch_length = bl*0.5
633            #if lca_f:
634            #    lca_f.clades = [c for c in lca_f.clades if c != lca] + [new_clade]
635            #    reroot( self.tree, new_clade )
636            #else:
637            #    self.tree.root = new_clade
638
639        elif strategy == 'midpoint':
640            pass
641            #self.tree.reroot_at_midpoint(update_splits=True)
642        elif strategy == 'longest_edge':
643            nodes = list(self.tree.get_nonterminals()) + list(self.tree.get_terminals())
644            longest = max( nodes, key=lambda x:x.branch_length )
645            reroot_mid_fat_edge( self.tree, longest )
646            #longest_edge = max( self.ntree.get_edge_set(),
647            #                    key=lambda x:x.length)
648            #self.tree.reroot_at_edge(longest_edge, update_splits=True)
649        elif strategy == 'longest_internal_edge':
650            nodes = list(self.tree.get_nonterminals())
651            longest = max( nodes, key=lambda x:x.branch_length )
652            if self.tree.root != longest:
653                reroot_mid_fat_edge( self.tree, longest )
654            #longest = get_lensorted_int_edges( self.tree )
655            #self.tree.reroot_at_edge( list(longest)[-1], update_splits=True )
656        elif strategy == 'longest_internal_edge_n':
657            nodes = list(self.tree.get_nonterminals())
658            longest = max( nodes, key=lambda x:x.branch_length
659                    if len(x.get_terminals()) >= n else -1.0)
660            reroot_mid_fat_edge( self.tree, longest )
661            #longest = get_lensorted_int_edges( self.tree, n )
662            #self.tree.reroot_at_edge( list(longest)[-1], update_splits=True )
663
664
665    def reorder_tree( self ):
666        self._ord_terms = []
667        def reorder_tree_rec( clade ):
668            if clade.is_terminal():
669                self._ord_terms.append( clade )
670                return clade,clade
671            clade.clades.sort( key=lambda x:len(x.get_terminals()), reverse = True)
672            for c in clade.clades:
673                c.fc,c.lc = reorder_tree_rec( c )
674            return clade.clades[0].fc,clade.clades[-1].lc
675            #clade.fc, clade.lc = clade.clades[0], clade.clades[-1]
676
677        reorder_tree_rec( self.tree.root )
678        last = None # self._ord_terms[-1]
679        for c in self._ord_terms:
680            c.pc = last
681            if last:
682                last.nc = c
683            last = c
684        c.nc = None
685        #self._ord_terms[-1].nc = None # self._ord_terms[0]
686
687
688    def get_subtree_leaves( self, full_names = False ):
689
690        subtrees = []
691        def rec_subtree_leaves( clade ):
692            if not len(clade.clades):
693                return [clade.name]
694            leaves = []
695            for c in clade.clades:
696                leaves += rec_subtree_leaves( c )
697            leaves = [l for l in leaves if l]
698            subtrees.append( (clade.name if clade.name else "",leaves)  )
699            return leaves
700
701        rec_subtree_leaves( self.tree.root )
702        return subtrees
703
704    def get_clade_names( self, full_names = False, leaves = True, internals = True ):
705        clades = []
706        if leaves:
707            clades += self.tree.get_terminals()
708        if internals:
709            clades += self.tree.get_nonterminals()
710        if full_names:
711            def rec_name( clade, nam = "" ):
712                ret = []
713                if not nam and not clade.name:
714                    lnam = ""
715                elif not nam:
716                    lnam = clade.name
717                elif not clade.name:
718                    lnam = nam
719                else:
720                    lnam = lev_sep.join( [nam, clade.name if clade.name else ""] )
721                ret += [lnam] if lnam else []
722                for c in clade.clades:
723                    ret += rec_name(c,lnam)
724                return ret
725            names = set(rec_name(self.tree.root))
726        else:
727            names = set([c.name for c in clades])
728        return sorted(names)
729
730