1# Copyright (C) 2013 by Yanbo Ye (yeyanbo289@gmail.com)
2#
3# This file is part of the Biopython distribution and governed by your
4# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
5# Please see the LICENSE file that should have been included as part of this
6# package.
7
8"""Classes and methods for tree construction."""
9
10import itertools
11import copy
12import numbers
13from Bio.Phylo import BaseTree
14from Bio.Align import MultipleSeqAlignment
15from Bio.Align import substitution_matrices
16
17
class _Matrix:
    """Base class for distance matrix or scoring matrix.

    Accepts a list of names and a lower triangular matrix.::

        matrix = [[0],
                  [1, 0],
                  [2, 3, 0],
                  [4, 5, 6, 0]]
        represents the symmetric matrix of
        [0,1,2,4]
        [1,0,3,5]
        [2,3,0,6]
        [4,5,6,0]

    :Parameters:
        names : list
            names of elements, used for indexing
        matrix : list
            nested list of numerical lists in lower triangular format

    Integer indices may be negative, counting from the end as for Python
    lists.

    Examples
    --------
    >>> from Bio.Phylo.TreeConstruction import _Matrix
    >>> names = ['Alpha', 'Beta', 'Gamma', 'Delta']
    >>> matrix = [[0], [1, 0], [2, 3, 0], [4, 5, 6, 0]]
    >>> m = _Matrix(names, matrix)
    >>> m
    _Matrix(names=['Alpha', 'Beta', 'Gamma', 'Delta'], matrix=[[0], [1, 0], [2, 3, 0], [4, 5, 6, 0]])

    You can use two indices to get or assign an element in the matrix.

    >>> m[1,2]
    3
    >>> m['Beta','Gamma']
    3
    >>> m['Beta','Gamma'] = 4
    >>> m['Beta','Gamma']
    4

    Furthermore, you can use one index to get or assign a list of elements related to that index.

    >>> m[0]
    [0, 1, 2, 4]
    >>> m['Alpha']
    [0, 1, 2, 4]
    >>> m['Alpha'] = [0, 7, 8, 9]
    >>> m[0]
    [0, 7, 8, 9]
    >>> m[0,1]
    7

    Also you can delete or insert a column&row of elements by index.

    >>> m
    _Matrix(names=['Alpha', 'Beta', 'Gamma', 'Delta'], matrix=[[0], [7, 0], [8, 4, 0], [9, 5, 6, 0]])
    >>> del m['Alpha']
    >>> m
    _Matrix(names=['Beta', 'Gamma', 'Delta'], matrix=[[0], [4, 0], [5, 6, 0]])
    >>> m.insert('Alpha', [0, 7, 8, 9] , 0)
    >>> m
    _Matrix(names=['Alpha', 'Beta', 'Gamma', 'Delta'], matrix=[[0], [7, 0], [8, 4, 0], [9, 5, 6, 0]])

    """

    def __init__(self, names, matrix=None):
        """Initialize matrix.

        Arguments are a list of names, and optionally a list of lower
        triangular matrix data (zero matrix used by default).

        Raises TypeError if ``names`` is not a list of strings or ``matrix``
        is not a list of numerical lists; raises ValueError for duplicate
        names, a size mismatch, or data not in lower triangular format.
        """
        # check names
        if not (isinstance(names, list) and all(isinstance(s, str) for s in names)):
            raise TypeError("'names' should be a list of strings")
        if len(set(names)) != len(names):
            raise ValueError("Duplicate names found")
        self.names = names

        # check matrix
        if matrix is None:
            # create a new one with 0 if matrix is not assigned
            self.matrix = [[0] * i for i in range(1, len(self) + 1)]
            return
        # check if all elements are numbers
        if not (
            isinstance(matrix, list)
            and all(isinstance(row, list) for row in matrix)
            and all(isinstance(n, numbers.Number) for row in matrix for n in row)
        ):
            raise TypeError("'matrix' should be a list of numerical lists")
        # check if the same length with names
        if len(matrix) != len(names):
            raise ValueError("'names' and 'matrix' should be the same size")
        # check if is lower triangle format
        if [len(row) for row in matrix] != list(range(1, len(self) + 1)):
            raise ValueError("'matrix' should be in lower triangle format")
        self.matrix = matrix

    def _index_of(self, item):
        """Resolve a single int or str index to a non-negative row index (PRIVATE).

        Negative integers count from the end, as for Python lists.

        Raises TypeError for other index types, ValueError for unknown names
        and IndexError for integers outside the matrix.
        """
        if isinstance(item, str):
            try:
                return self.names.index(item)
            except ValueError:
                raise ValueError("Item not found.") from None
        if not isinstance(item, int):
            raise TypeError("Invalid index type.")
        index = item + len(self) if item < 0 else item
        # Without this, negative indices silently addressed the wrong cells.
        if not 0 <= index < len(self):
            raise IndexError("Index out of range.")
        return index

    def _pair_indices(self, item):
        """Resolve a two-element index to a (row, col) pair of indices (PRIVATE).

        Both elements must be ints, or both must be strs; otherwise
        TypeError is raised.
        """
        first, second = item
        both_ints = isinstance(first, int) and isinstance(second, int)
        both_strs = isinstance(first, str) and isinstance(second, str)
        if not (both_ints or both_strs):
            raise TypeError("Invalid index type.")
        return self._index_of(first), self._index_of(second)

    def __getitem__(self, item):
        """Access value(s) by the index(s) or name(s).

        For a _Matrix object 'dm'::

            dm[i]                   get a value list from the given 'i' to others;
            dm[i, j]                get the value between 'i' and 'j';
            dm['name']              map name to index first
            dm['name1', 'name2']    map name to index first

        """
        # Handle single indexing
        if isinstance(item, (int, str)):
            index = self._index_of(item)
            # Entries left of the diagonal live in row 'index'; the rest of the
            # symmetric row is column 'index' of the rows from the diagonal down.
            return self.matrix[index][:index] + [
                row[index] for row in self.matrix[index:]
            ]
        # Handle double indexing
        elif len(item) == 2:
            row_index, col_index = self._pair_indices(item)
            # Only the lower triangle is stored, so order the pair accordingly.
            if row_index > col_index:
                return self.matrix[row_index][col_index]
            return self.matrix[col_index][row_index]
        else:
            raise TypeError("Invalid index type.")

    def __setitem__(self, item, value):
        """Set value by the index(s) or name(s).

        Similar to __getitem__::

            dm[1] = [1, 0, 3, 4]    set values from '1' to others;
            dm[i, j] = 2            set the value from 'i' to 'j'

        """
        # Handle single indexing
        if isinstance(item, (int, str)):
            index = self._index_of(item)
            # check and assign value
            if not (
                isinstance(value, list)
                and all(isinstance(n, numbers.Number) for n in value)
            ):
                raise TypeError("Invalid value type.")
            if len(value) != len(self):
                raise ValueError("Value not the same size.")
            for i in range(index):
                self.matrix[index][i] = value[i]
            for i in range(index, len(self)):
                self.matrix[i][index] = value[i]
        # Handle double indexing
        elif len(item) == 2:
            row_index, col_index = self._pair_indices(item)
            # check and assign value
            if not isinstance(value, numbers.Number):
                raise TypeError("Invalid value type.")
            if row_index > col_index:
                self.matrix[row_index][col_index] = value
            else:
                self.matrix[col_index][row_index] = value
        else:
            raise TypeError("Invalid index type.")

    def __delitem__(self, item):
        """Delete related distances by the index or name.

        Raises TypeError for a non int/str index, ValueError for an unknown
        name and IndexError for an out-of-range integer.
        """
        index = self._index_of(item)
        # remove the column 'index' from every row below the diagonal
        for row in self.matrix[index + 1 :]:
            del row[index]
        del self.matrix[index]
        # remove name
        del self.names[index]

    def insert(self, name, value, index=None):
        """Insert distances given the name and value.

        All arguments are validated before anything is modified, so a
        failed insert leaves the matrix unchanged.

        :Parameters:
            name : str
                name of a row/col to be inserted; must not already exist
            value : list
                a row/col of values to be inserted, one per element of the
                enlarged matrix (i.e. ``len(self) + 1`` numbers)
            index : int
                position of the new row/col; defaults to the end. Negative
                values count from the end, as for ``list.insert``.

        Raises TypeError for a bad name, index or value type, ValueError for
        a duplicate name or wrongly sized value, and IndexError for an index
        outside ``0 .. len(self)``.
        """
        if not isinstance(name, str):
            raise TypeError("Invalid name type.")
        # insert at the given index or at the end
        size = len(self)
        if index is None:
            index = size
        if not isinstance(index, int):
            raise TypeError("Invalid index type.")
        if index < 0:
            index += size
        if not 0 <= index <= size:
            raise IndexError("Index out of range.")
        if name in self.names:
            raise ValueError("Duplicate names found")
        if not (
            isinstance(value, list)
            and all(isinstance(n, numbers.Number) for n in value)
        ):
            raise TypeError("Invalid value type.")
        if len(value) != size + 1:
            raise ValueError("Value not the same size.")

        # insert name
        self.names.insert(index, name)
        # insert elements of 0, to be assigned
        self.matrix.insert(index, [0] * index)
        for row in self.matrix[index:]:
            row.insert(index, 0)
        # assign value (goes through __setitem__ so subclasses can post-process)
        self[index] = value

    def __len__(self):
        """Matrix length."""
        return len(self.names)

    def __repr__(self):
        """Return Matrix as a string."""
        return self.__class__.__name__ + "(names=%s, matrix=%s)" % tuple(
            map(repr, (self.names, self.matrix))
        )

    def __str__(self):
        """Get a lower triangular matrix string."""
        matrix_string = "\n".join(
            [
                self.names[i] + "\t" + "\t".join([str(n) for n in self.matrix[i]])
                for i in range(0, len(self))
            ]
        )
        matrix_string = matrix_string + "\n\t" + "\t".join(self.names)
        return matrix_string
308
309
class DistanceMatrix(_Matrix):
    """Distance matrix class that can be used for distance based tree algorithms.

    All diagonal elements will be zero no matter what the users provide.
    """

    def __init__(self, names, matrix=None):
        """Initialize the class."""
        _Matrix.__init__(self, names, matrix)
        self._set_zero_diagonal()

    def __setitem__(self, item, value):
        """Set Matrix's items to values, then force the diagonal back to zero."""
        _Matrix.__setitem__(self, item, value)
        self._set_zero_diagonal()

    def _set_zero_diagonal(self):
        """Set all diagonal elements to zero (PRIVATE)."""
        # Row i of the lower-triangular storage ends with the diagonal cell i.
        for i, row in enumerate(self.matrix):
            row[i] = 0

    def format_phylip(self, handle):
        """Write data in Phylip format to a given file-like object or handle.

        The output stream is the input distance matrix format used with Phylip
        programs (e.g. 'neighbor'). See:
        http://evolution.genetics.washington.edu/phylip/doc/neighbor.html

        :Parameters:
            handle : file or file-like object
                A writeable text mode file handle or other object supporting
                the 'write' method, such as StringIO or sys.stdout.

        """
        handle.write(f"    {len(self.names)}\n")
        # Phylip needs space-separated, vertically aligned columns; names are
        # padded to at least 12 characters. default=0 keeps an empty matrix
        # from crashing max() with a ValueError.
        name_width = max(12, max(map(len, self.names), default=0) + 1)
        value_fmts = ("{" + str(x) + ":.4f}" for x in range(1, len(self.matrix) + 1))
        row_fmt = "{0:" + str(name_width) + "s}" + "  ".join(value_fmts) + "\n"
        for i, (name, values) in enumerate(zip(self.names, self.matrix)):
            # Mirror the matrix values across the diagonal
            mirror_values = (self.matrix[j][i] for j in range(i + 1, len(self.matrix)))
            fields = itertools.chain([name], values, mirror_values)
            handle.write(row_fmt.format(*fields))
354
355
# Shim for compatibility with Biopython<1.70 (#1304): keep the old private
# name available as an alias of the public DistanceMatrix class.
_DistanceMatrix = DistanceMatrix
358
359
class DistanceCalculator:
    """Calculate the distance matrix from a DNA or protein alignment.

    Distances are computed from a Multiple Sequence Alignment (MSA) using
    the substitution model given by name.

    Currently only scoring matrices are used.

    :Parameters:
        model : str
            Name of the model matrix to be used to calculate distance.
            The attribute ``dna_models`` contains the available model
            names for DNA sequences and ``protein_models`` for protein
            sequences.
        skip_letters : iterable of str, optional
            Letters whose alignment positions are left out when two
            sequences are compared. If not given (or empty), no letters
            are skipped for the 'identity' model and ``("-", "*")`` are
            skipped for the scoring-matrix models.

    Examples
    --------
    Loading a small PHYLIP alignment from which to compute distances::

        from Bio.Phylo.TreeConstruction import DistanceCalculator
        from Bio import AlignIO
        aln = AlignIO.read(open('TreeConstruction/msa.phy'), 'phylip')
        print(aln)

    Output::

        Alignment with 5 rows and 13 columns
        AACGTGGCCACAT Alpha
        AAGGTCGCCACAC Beta
        CAGTTCGCCACAA Gamma
        GAGATTTCCGCCT Delta
        GAGATCTCCGCCC Epsilon

    DNA calculator with 'identity' model::

        calculator = DistanceCalculator('identity')
        dm = calculator.get_distance(aln)
        print(dm)

    Output::

        Alpha	0
        Beta	0.23076923076923073	0
        Gamma	0.3846153846153846	0.23076923076923073	0
        Delta	0.5384615384615384	0.5384615384615384	0.5384615384615384	0
        Epsilon	0.6153846153846154	0.3846153846153846	0.46153846153846156	0.15384615384615385	0
            Alpha	Beta	Gamma	Delta	Epsilon

    Protein calculator with 'blosum62' model::

        calculator = DistanceCalculator('blosum62')
        dm = calculator.get_distance(aln)
        print(dm)

    Output::

        Alpha	0
        Beta	0.36904761904761907	0
        Gamma	0.49397590361445787	0.25	0
        Delta	0.5853658536585367	0.5476190476190477	0.5662650602409638	0
        Epsilon	0.7	0.3555555555555555	0.48888888888888893	0.2222222222222222	0
            Alpha	Beta	Gamma	Delta	Epsilon

    """

    protein_alphabet = set("ABCDEFGHIKLMNPQRSTVWXYZ")

    dna_models = []
    protein_models = []

    # matrices available: classify every bundled substitution matrix as a
    # protein or DNA model when the class is defined. The loop temporaries
    # are deleted afterwards so they do not linger as class attributes.
    names = substitution_matrices.load()
    for name in names:
        matrix = substitution_matrices.load(name)
        if name == "NUC.4.4":
            # BLAST nucleic acid scoring matrix
            name = "blastn"
        else:
            name = name.lower()
        if protein_alphabet.issubset(set(matrix.alphabet)):
            protein_models.append(name)
        else:
            dna_models.append(name)

    del protein_alphabet
    del name
    del names
    del matrix

    models = ["identity"] + dna_models + protein_models

    def __init__(self, model="identity", skip_letters=None):
        """Initialize with a distance model.

        Raises ValueError if ``model`` is not one of ``models``.
        """
        # Shim for backward compatibility (#491)
        if skip_letters:
            self.skip_letters = skip_letters
        elif model == "identity":
            self.skip_letters = ()
        else:
            self.skip_letters = ("-", "*")

        if model == "identity":
            self.scoring_matrix = None
        elif model in self.models:
            if model == "blastn":
                name = "NUC.4.4"
            else:
                name = model.upper()
            self.scoring_matrix = substitution_matrices.load(name)
        else:
            raise ValueError(
                "Model not supported. Available models: " + ", ".join(self.models)
            )

    def _pairwise(self, seq1, seq2):
        """Calculate pairwise distance from two sequences (PRIVATE).

        Returns a value between 0 (identical sequences) and 1 (completely
        different, or no comparable positions). Positions where either
        sequence has a letter from ``skip_letters`` are left out of both the
        score and the maximum possible score, for every model.

        Raises ValueError if a letter is missing from the scoring matrix.
        """
        score = 0
        max_score = 0
        if self.scoring_matrix is None:
            # Score by character identity over the positions not skipped. The
            # maximum is the number of compared positions; counting skipped
            # positions there would turn every skipped column into a mismatch.
            compared = [
                (l1, l2)
                for l1, l2 in zip(seq1, seq2)
                if l1 not in self.skip_letters and l2 not in self.skip_letters
            ]
            score = sum(l1 == l2 for l1, l2 in compared)
            max_score = len(compared)
        else:
            max_score1 = 0
            max_score2 = 0
            for i in range(0, len(seq1)):
                l1 = seq1[i]
                l2 = seq2[i]
                if l1 in self.skip_letters or l2 in self.skip_letters:
                    continue
                try:
                    max_score1 += self.scoring_matrix[l1, l1]
                except IndexError:
                    raise ValueError(
                        "Bad letter '%s' in sequence '%s' at position '%s'"
                        % (l1, seq1.id, i)
                    ) from None
                try:
                    max_score2 += self.scoring_matrix[l2, l2]
                except IndexError:
                    raise ValueError(
                        "Bad letter '%s' in sequence '%s' at position '%s'"
                        % (l2, seq2.id, i)
                    ) from None
                score += self.scoring_matrix[l1, l2]
            # Take the higher score if the matrix is asymmetrical
            max_score = max(max_score1, max_score2)
        if max_score == 0:
            return 1  # max possible scaled distance
        return 1 - (score * 1.0 / max_score)

    def get_distance(self, msa):
        """Return a DistanceMatrix for MSA object.

        :Parameters:
            msa : MultipleSeqAlignment
                DNA or Protein multiple sequence alignment.

        Raises TypeError if ``msa`` is not a MultipleSeqAlignment.
        """
        if not isinstance(msa, MultipleSeqAlignment):
            raise TypeError("Must provide a MultipleSeqAlignment object.")

        names = [s.id for s in msa]
        dm = DistanceMatrix(names)
        for seq1, seq2 in itertools.combinations(msa, 2):
            dm[seq1.id, seq2.id] = self._pairwise(seq1, seq2)
        return dm
535
536
class TreeConstructor:
    """Abstract base class shared by every tree constructor.

    Concrete constructors override ``build_tree``.
    """

    def build_tree(self, msa):
        """Build and return a tree from a MultipleSeqAlignment object.

        Subclasses must override this; the base version always raises
        NotImplementedError.
        """
        message = "Method not implemented!"
        raise NotImplementedError(message)
546
547
class DistanceTreeConstructor(TreeConstructor):
    """Distance based tree constructor.

    :Parameters:
        method : str
            Distance tree construction method, 'nj'(default) or 'upgma'.
        distance_calculator : DistanceCalculator
            The distance matrix calculator for multiple sequence alignment.
            It must be provided if ``build_tree`` will be called.

    Examples
    --------
    Loading a small PHYLIP alignment from which to compute distances, and then
    build a UPGMA tree::

        from Bio.Phylo.TreeConstruction import DistanceTreeConstructor
        from Bio.Phylo.TreeConstruction import DistanceCalculator
        from Bio import AlignIO
        aln = AlignIO.read(open('TreeConstruction/msa.phy'), 'phylip')
        constructor = DistanceTreeConstructor()
        calculator = DistanceCalculator('identity')
        dm = calculator.get_distance(aln)
        upgmatree = constructor.upgma(dm)
        print(upgmatree)

    Output::

        Tree(rooted=True)
            Clade(branch_length=0, name='Inner4')
                Clade(branch_length=0.18749999999999994, name='Inner1')
                    Clade(branch_length=0.07692307692307693, name='Epsilon')
                    Clade(branch_length=0.07692307692307693, name='Delta')
                Clade(branch_length=0.11057692307692304, name='Inner3')
                    Clade(branch_length=0.038461538461538464, name='Inner2')
                        Clade(branch_length=0.11538461538461536, name='Gamma')
                        Clade(branch_length=0.11538461538461536, name='Beta')
                    Clade(branch_length=0.15384615384615383, name='Alpha')

    Build a NJ Tree::

        njtree = constructor.nj(dm)
        print(njtree)

    Output::

        Tree(rooted=False)
            Clade(branch_length=0, name='Inner3')
                Clade(branch_length=0.18269230769230765, name='Alpha')
                Clade(branch_length=0.04807692307692307, name='Beta')
                Clade(branch_length=0.04807692307692307, name='Inner2')
                    Clade(branch_length=0.27884615384615385, name='Inner1')
                        Clade(branch_length=0.051282051282051266, name='Epsilon')
                        Clade(branch_length=0.10256410256410259, name='Delta')
                    Clade(branch_length=0.14423076923076922, name='Gamma')

    """

    methods = ["nj", "upgma"]

    def __init__(self, distance_calculator=None, method="nj"):
        """Initialize the class.

        Raises TypeError if ``distance_calculator`` is neither None nor a
        DistanceCalculator, or if ``method`` is not one of ``methods``.
        """
        if distance_calculator is None or isinstance(
            distance_calculator, DistanceCalculator
        ):
            self.distance_calculator = distance_calculator
        else:
            raise TypeError("Must provide a DistanceCalculator object.")
        if isinstance(method, str) and method in self.methods:
            self.method = method
        else:
            # str() keeps the intended message for non-string values (e.g.
            # None), where plain concatenation would raise its own TypeError.
            raise TypeError(
                "Bad method: "
                + str(method)
                + ". Available methods: "
                + ", ".join(self.methods)
            )

    def build_tree(self, msa):
        """Construct and return a Tree, Neighbor Joining or UPGMA."""
        if self.distance_calculator:
            dm = self.distance_calculator.get_distance(msa)
            tree = None
            if self.method == "upgma":
                tree = self.upgma(dm)
            else:
                tree = self.nj(dm)
            return tree
        else:
            raise TypeError("Must provide a DistanceCalculator object.")

    def upgma(self, distance_matrix):
        """Construct and return an UPGMA tree.

        Constructs and returns an Unweighted Pair Group Method
        with Arithmetic mean (UPGMA) tree.

        :Parameters:
            distance_matrix : DistanceMatrix
                The distance matrix for tree construction.

        Raises TypeError if ``distance_matrix`` is not a DistanceMatrix and
        ValueError if it is empty.
        """
        if not isinstance(distance_matrix, DistanceMatrix):
            raise TypeError("Must provide a DistanceMatrix object.")
        if len(distance_matrix) == 0:
            raise ValueError("Distance matrix is empty.")

        # make a copy of the distance matrix to be used
        dm = copy.deepcopy(distance_matrix)
        # init terminal clades
        clades = [BaseTree.Clade(None, name) for name in dm.names]
        if len(clades) == 1:
            # A single taxon is already a complete tree; the merge loop below
            # would never run and no root clade would be created.
            clades[0].branch_length = 0
            return BaseTree.Tree(clades[0])
        # init minimum index
        min_i = 0
        min_j = 0
        inner_count = 0
        while len(dm) > 1:
            min_dist = dm[1, 0]
            # find minimum index
            for i in range(1, len(dm)):
                for j in range(0, i):
                    if min_dist >= dm[i, j]:
                        min_dist = dm[i, j]
                        min_i = i
                        min_j = j

            # create clade
            clade1 = clades[min_i]
            clade2 = clades[min_j]
            inner_count += 1
            inner_clade = BaseTree.Clade(None, "Inner" + str(inner_count))
            inner_clade.clades.append(clade1)
            inner_clade.clades.append(clade2)
            # assign branch length: the new node sits at height min_dist / 2,
            # so each child's branch covers the gap down to its own height
            if clade1.is_terminal():
                clade1.branch_length = min_dist / 2
            else:
                clade1.branch_length = min_dist / 2 - self._height_of(clade1)

            if clade2.is_terminal():
                clade2.branch_length = min_dist / 2
            else:
                clade2.branch_length = min_dist / 2 - self._height_of(clade2)

            # update node list
            clades[min_j] = inner_clade
            del clades[min_i]

            # rebuild distance matrix,
            # set the distances of new node at the index of min_j
            # NOTE(review): this is the plain mean of the two merged rows, not
            # weighted by cluster sizes.
            for k in range(0, len(dm)):
                if k != min_i and k != min_j:
                    dm[min_j, k] = (dm[min_i, k] + dm[min_j, k]) / 2

            dm.names[min_j] = "Inner" + str(inner_count)

            del dm[min_i]
        inner_clade.branch_length = 0
        return BaseTree.Tree(inner_clade)

    def nj(self, distance_matrix):
        """Construct and return a Neighbor Joining tree.

        :Parameters:
            distance_matrix : DistanceMatrix
                The distance matrix for tree construction.

        Raises TypeError if ``distance_matrix`` is not a DistanceMatrix and
        ValueError if it is empty.
        """
        if not isinstance(distance_matrix, DistanceMatrix):
            raise TypeError("Must provide a DistanceMatrix object.")
        if len(distance_matrix) == 0:
            raise ValueError("Distance matrix is empty.")

        # make a copy of the distance matrix to be used
        dm = copy.deepcopy(distance_matrix)
        # init terminal clades
        clades = [BaseTree.Clade(None, name) for name in dm.names]
        # init node distance
        node_dist = [0] * len(dm)
        # init minimum index
        min_i = 0
        min_j = 0
        inner_count = 0
        # special cases for Minimum Alignment Matrices
        if len(dm) == 1:
            root = clades[0]

            return BaseTree.Tree(root, rooted=False)
        elif len(dm) == 2:
            # minimum distance will always be [1,0]
            min_i = 1
            min_j = 0
            clade1 = clades[min_i]
            clade2 = clades[min_j]
            clade1.branch_length = dm[min_i, min_j] / 2.0
            clade2.branch_length = dm[min_i, min_j] - clade1.branch_length
            inner_clade = BaseTree.Clade(None, "Inner")
            inner_clade.clades.append(clade1)
            inner_clade.clades.append(clade2)
            clades[0] = inner_clade
            root = clades[0]

            return BaseTree.Tree(root, rooted=False)
        while len(dm) > 2:
            # calculate nodeDist: each node's summed distance to all others,
            # scaled by 1 / (n - 2); dm[i] is that node's full distance row
            for i in range(0, len(dm)):
                node_dist[i] = sum(dm[i]) / (len(dm) - 2)

            # find minimum distance pair
            min_dist = dm[1, 0] - node_dist[1] - node_dist[0]
            min_i = 0
            min_j = 1
            for i in range(1, len(dm)):
                for j in range(0, i):
                    temp = dm[i, j] - node_dist[i] - node_dist[j]
                    if min_dist > temp:
                        min_dist = temp
                        min_i = i
                        min_j = j
            # create clade
            clade1 = clades[min_i]
            clade2 = clades[min_j]
            inner_count += 1
            inner_clade = BaseTree.Clade(None, "Inner" + str(inner_count))
            inner_clade.clades.append(clade1)
            inner_clade.clades.append(clade2)
            # assign branch length
            clade1.branch_length = (
                dm[min_i, min_j] + node_dist[min_i] - node_dist[min_j]
            ) / 2.0
            clade2.branch_length = dm[min_i, min_j] - clade1.branch_length

            # update node list
            clades[min_j] = inner_clade
            del clades[min_i]

            # rebuild distance matrix,
            # set the distances of new node at the index of min_j
            for k in range(0, len(dm)):
                if k != min_i and k != min_j:
                    dm[min_j, k] = (
                        dm[min_i, k] + dm[min_j, k] - dm[min_i, min_j]
                    ) / 2.0

            dm.names[min_j] = "Inner" + str(inner_count)
            del dm[min_i]

        # set the last clade as one of the child of the inner_clade
        root = None
        if clades[0] == inner_clade:
            clades[0].branch_length = 0
            clades[1].branch_length = dm[1, 0]
            clades[0].clades.append(clades[1])
            root = clades[0]
        else:
            clades[0].branch_length = dm[1, 0]
            clades[1].branch_length = 0
            clades[1].clades.append(clades[0])
            root = clades[1]

        return BaseTree.Tree(root, rooted=False)

    def _height_of(self, clade):
        """Calculate clade height -- the longest path to any terminal (PRIVATE).

        For a terminal clade this is its own branch length (the convention
        ``upgma`` relies on). For an internal clade it is the longest sum of
        branch lengths from the clade down to one of its terminals, so the
        branch lengths of intermediate internal clades are included.
        """
        if clade.is_terminal():
            return clade.branch_length
        return max(
            self._height_of(child)
            if child.is_terminal()
            else child.branch_length + self._height_of(child)
            for child in clade.clades
        )
815
816
817# #################### Tree Scoring and Searching Classes #####################
818
819
class Scorer:
    """Abstract base class shared by every tree scoring method."""

    def get_score(self, tree, alignment):
        """Return the score of ``tree`` for ``alignment``.

        Subclasses must override this; the base version always raises
        NotImplementedError.
        """
        message = "Method not implemented!"
        raise NotImplementedError(message)
829
830
class TreeSearcher:
    """Abstract base class shared by every tree searching method."""

    def search(self, starting_tree, alignment):
        """Search for the best tree, starting from ``starting_tree``.

        Subclasses must override this; the base version always raises
        NotImplementedError.
        """
        message = "Method not implemented!"
        raise NotImplementedError(message)
840
841
class NNITreeSearcher(TreeSearcher):
    """Tree searching with Nearest Neighbor Interchanges (NNI) algorithm.

    :Parameters:
        scorer : ParsimonyScorer
            parsimony scorer to calculate the parsimony score of
            different trees during NNI algorithm.

    """

    def __init__(self, scorer):
        """Initialize the class.

        :Raises TypeError: if ``scorer`` is not a Scorer instance.
        """
        if isinstance(scorer, Scorer):
            self.scorer = scorer
        else:
            raise TypeError("Must provide a Scorer object.")

    def search(self, starting_tree, alignment):
        """Implement the TreeSearcher.search method.

        :Parameters:
           starting_tree : Tree
               starting tree of NNI method.
           alignment : MultipleSeqAlignment
               multiple sequence alignment used to calculate parsimony
               score of different NNI trees.

        """
        return self._nni(starting_tree, alignment)

    def _nni(self, starting_tree, alignment):
        """Search for the best parsimony tree using the NNI algorithm (PRIVATE).

        Greedy hill climbing: in every round all NNI neighbors of the current
        best tree are scored and the lowest-scoring one (the first one found
        on ties) becomes the new best tree.  The search stops when no
        neighbor scores strictly lower than the tree the round started from.
        """
        best_tree = starting_tree
        while True:
            best_score = self.scorer.get_score(best_tree, alignment)
            round_start_score = best_score
            for candidate in self._get_neighbors(best_tree):
                score = self.scorer.get_score(candidate, alignment)
                if score < best_score:
                    best_score = score
                    best_tree = candidate
            # stop if no smaller score exists among the neighbors
            if best_score >= round_start_score:
                break
        return best_tree

    def _get_neighbors(self, tree):
        """Get all neighbor trees of the given tree (PRIVATE).

        Currently only for binary rooted trees.  The given tree is rearranged
        in place to produce each neighbor, which is stored as a deep copy,
        and every rearrangement is undone before moving on.
        """
        # Child -> parent map built in a single traversal.  Calling
        # tree.get_path() once per clade would make this O(n^2).
        parents = {}
        for clade in tree.find_clades():
            for child in clade.clades:
                parents[child] = clade
        neighbors = []
        root_childs = []
        for clade in tree.get_nonterminals(order="level"):
            if clade == tree.root:
                left = clade.clades[0]
                right = clade.clades[1]
                root_childs.append(left)
                root_childs.append(right)
                # The two root edges form one edge of the unrooted tree; an
                # interchange across it needs two subtrees on either side.
                if not left.is_terminal() and not right.is_terminal():
                    # make changes around the left_left clade
                    # left_left = left.clades[0]
                    left_right = left.clades[1]
                    right_left = right.clades[0]
                    right_right = right.clades[1]
                    # neighbor 1 (left_left + right_right)
                    del left.clades[1]
                    del right.clades[1]
                    left.clades.append(right_right)
                    right.clades.append(left_right)
                    temp_tree = copy.deepcopy(tree)
                    neighbors.append(temp_tree)
                    # neighbor 2 (left_left + right_left)
                    del left.clades[1]
                    del right.clades[0]
                    left.clades.append(right_left)
                    right.clades.append(right_right)
                    temp_tree = copy.deepcopy(tree)
                    neighbors.append(temp_tree)
                    # change back (left_left + left_right)
                    del left.clades[1]
                    del right.clades[0]
                    left.clades.append(left_right)
                    right.clades.insert(0, right_left)
            elif clade in root_childs:
                # skip root child: its edge was already handled at the root
                continue
            else:
                # method for other clades
                # make changes around the parent clade
                left = clade.clades[0]
                right = clade.clades[1]
                parent = parents[clade]
                if clade == parent.clades[0]:
                    sister = parent.clades[1]
                    # neighbor 1 (parent + right)
                    del parent.clades[1]
                    del clade.clades[1]
                    parent.clades.append(right)
                    clade.clades.append(sister)
                    temp_tree = copy.deepcopy(tree)
                    neighbors.append(temp_tree)
                    # neighbor 2 (parent + left)
                    del parent.clades[1]
                    del clade.clades[0]
                    parent.clades.append(left)
                    clade.clades.append(right)
                    temp_tree = copy.deepcopy(tree)
                    neighbors.append(temp_tree)
                    # change back (parent + sister)
                    del parent.clades[1]
                    del clade.clades[0]
                    parent.clades.append(sister)
                    clade.clades.insert(0, left)
                else:
                    sister = parent.clades[0]
                    # neighbor 1 (parent + right)
                    del parent.clades[0]
                    del clade.clades[1]
                    parent.clades.insert(0, right)
                    clade.clades.append(sister)
                    temp_tree = copy.deepcopy(tree)
                    neighbors.append(temp_tree)
                    # neighbor 2 (parent + left)
                    del parent.clades[0]
                    del clade.clades[0]
                    parent.clades.insert(0, left)
                    clade.clades.append(right)
                    temp_tree = copy.deepcopy(tree)
                    neighbors.append(temp_tree)
                    # change back (parent + sister)
                    del parent.clades[0]
                    del clade.clades[0]
                    parent.clades.insert(0, sister)
                    clade.clades.insert(0, left)
        return neighbors
988
989
990# ######################## Parsimony Classes ##########################
991
992
class ParsimonyScorer(Scorer):
    """Parsimony scorer with a scoring matrix.

    This is a combination of Fitch algorithm and Sankoff algorithm.
    See ParsimonyTreeConstructor for usage.

    :Parameters:
        matrix : _Matrix
            scoring matrix used in parsimony score calculation.  Without a
            matrix (or with an empty one) the Fitch algorithm is used;
            otherwise the Sankoff algorithm is used, with the matrix entries
            as the cost of a change between two states.

    """

    def __init__(self, matrix=None):
        """Initialize the class.

        :Raises TypeError: if ``matrix`` is given (truthy) but is not a _Matrix.
        """
        if not matrix or isinstance(matrix, _Matrix):
            self.matrix = matrix
        else:
            raise TypeError("Must provide a _Matrix object.")

    def get_score(self, tree, alignment):
        """Calculate parsimony score using the Fitch algorithm.

        Calculate and return the parsimony score given a tree and the
        MSA using either the Fitch algorithm (without a penalty matrix)
        or the Sankoff algorithm (with a matrix).

        Both arguments may be modified in place: an unrooted ``tree`` is
        rooted at its midpoint, and ``alignment`` is sorted so that its rows
        line up with the tree's terminals sorted by name.

        :Raises ValueError: if the tree is not bifurcating, or if the taxa
            of the tree and of the alignment do not match one-to-one.
        """
        # make sure the tree is rooted and bifurcating
        if not tree.is_bifurcating():
            raise ValueError("The tree provided should be bifurcating.")
        if not tree.rooted:
            tree.root_at_midpoint()
        # sort tree terminals and alignment so that row j belongs to terms[j]
        terms = tree.get_terminals()
        terms.sort(key=lambda term: term.name)
        alignment.sort()
        # zip() stops at the shorter input, so the counts must be compared
        # explicitly or extra taxa on either side would go unnoticed.
        if len(terms) != len(alignment) or not all(
            t.name == a.id for t, a in zip(terms, alignment)
        ):
            raise ValueError(
                "Taxon names of the input tree should be the same with the alignment."
            )

        use_sankoff = bool(self.matrix)
        if use_sankoff:
            # The cost table does not depend on the column: build it once
            # instead of doing name-based matrix lookups in the inner loop.
            alphabet = self.matrix.names
            cost = [[self.matrix[a, b] for b in alphabet] for a in alphabet]

        score = 0
        for i in range(len(alignment[0])):
            column_i = alignment[:, i]
            # skip constant columns (the same state in every row)
            if column_i == len(column_i) * column_i[0]:
                continue
            if use_sankoff:
                score += self._sankoff_column(tree, terms, column_i, alphabet, cost)
            else:
                score += self._fitch_column(tree, terms, column_i)
        return score

    @staticmethod
    def _fitch_column(tree, terms, column):
        """Return the Fitch (unweighted) parsimony score of one column (PRIVATE).

        ``terms[j]`` is the terminal clade carrying the state ``column[j]``.
        """
        # init by mapping terminal clades to the singleton set of their state
        clade_states = {term: {state} for term, state in zip(terms, column)}
        changes = 0
        for clade in tree.get_nonterminals(order="postorder"):
            left_state = clade_states[clade.clades[0]]
            right_state = clade_states[clade.clades[1]]
            state = left_state & right_state
            if not state:
                # disjoint child state sets: one state change is required
                state = left_state | right_state
                changes += 1
            clade_states[clade] = state
        return changes

    @staticmethod
    def _sankoff_column(tree, terms, column, alphabet, cost):
        """Return the Sankoff (weighted) parsimony score of one column (PRIVATE).

        ``cost[m][n]`` is the penalty for a change between the states
        ``alphabet[m]`` and ``alphabet[n]``; ``terms[j]`` is the terminal
        clade carrying the state ``column[j]``.
        """
        inf = float("inf")
        length = len(alphabet)
        clade_scores = {}
        # a terminal costs 0 in its observed state and inf in any other
        for term, state in zip(terms, column):
            array = [inf] * length
            array[alphabet.index(state)] = 0
            clade_scores[term] = array
        # bottom up: cheapest cost of each state given both child subtrees
        for clade in tree.get_nonterminals(order="postorder"):
            left_score = clade_scores[clade.clades[0]]
            right_score = clade_scores[clade.clades[1]]
            array = []
            for row in cost:
                min_l = min(c + s for c, s in zip(row, left_score))
                min_r = min(c + s for c, s in zip(row, right_score))
                array.append(min_l + min_r)
            clade_scores[clade] = array
        # TODO: resolve internal states
        return min(clade_scores[tree.root])
1093
1094
class ParsimonyTreeConstructor(TreeConstructor):
    """Parsimony tree constructor.

    :Parameters:
        searcher : TreeSearcher
            tree searcher to search the best parsimony tree.
        starting_tree : Tree
            starting tree provided to the searcher.  If omitted, a UPGMA
            tree is built from each alignment passed to ``build_tree``.

    Examples
    --------
    We will load an alignment, and then load various trees which have already been computed from it::

        from Bio import AlignIO, Phylo
        aln = AlignIO.read(open('TreeConstruction/msa.phy'), 'phylip')
        print(aln)

    Output::

        Alignment with 5 rows and 13 columns
        AACGTGGCCACAT Alpha
        AAGGTCGCCACAC Beta
        CAGTTCGCCACAA Gamma
        GAGATTTCCGCCT Delta
        GAGATCTCCGCCC Epsilon

    Load a starting tree::

        starting_tree = Phylo.read('TreeConstruction/nj.tre', 'newick')
        print(starting_tree)

    Output::

        Tree(rooted=False, weight=1.0)
            Clade(branch_length=0.0, name='Inner3')
                Clade(branch_length=0.01421, name='Inner2')
                    Clade(branch_length=0.23927, name='Inner1')
                        Clade(branch_length=0.08531, name='Epsilon')
                        Clade(branch_length=0.13691, name='Delta')
                    Clade(branch_length=0.2923, name='Alpha')
                Clade(branch_length=0.07477, name='Beta')
                Clade(branch_length=0.17523, name='Gamma')

    Build the Parsimony tree from the starting tree::

        scorer = Phylo.TreeConstruction.ParsimonyScorer()
        searcher = Phylo.TreeConstruction.NNITreeSearcher(scorer)
        constructor = Phylo.TreeConstruction.ParsimonyTreeConstructor(searcher, starting_tree)
        pars_tree = constructor.build_tree(aln)
        print(pars_tree)

    Output::

        Tree(rooted=True, weight=1.0)
            Clade(branch_length=0.0)
                Clade(branch_length=0.19732999999999998, name='Inner1')
                    Clade(branch_length=0.13691, name='Delta')
                    Clade(branch_length=0.08531, name='Epsilon')
                Clade(branch_length=0.04194000000000003, name='Inner2')
                    Clade(branch_length=0.01421, name='Inner3')
                        Clade(branch_length=0.17523, name='Gamma')
                        Clade(branch_length=0.07477, name='Beta')
                    Clade(branch_length=0.2923, name='Alpha')

    """

    def __init__(self, searcher, starting_tree=None):
        """Initialize the class."""
        self.searcher = searcher
        self.starting_tree = starting_tree

    def build_tree(self, alignment):
        """Build the tree.

        :Parameters:
            alignment : MultipleSeqAlignment
                multiple sequence alignment to calculate parsimony tree.

        If no starting tree was given to the constructor, a UPGMA tree
        built from ``alignment`` with the 'identity' model is used.  That
        default tree is built afresh on every call, so one constructor can
        be reused with alignments of different taxa.
        """
        starting_tree = self.starting_tree
        if starting_tree is None:
            # Keep the default tree local: caching it on self would make a
            # later call with another alignment start from the wrong taxa.
            dtc = DistanceTreeConstructor(DistanceCalculator("identity"), "upgma")
            starting_tree = dtc.build_tree(alignment)
        return self.searcher.search(starting_tree, alignment)
1180