1# ----------------------------------------------------------------------------
2# Copyright (c) 2013--, scikit-bio development team.
3#
4# Distributed under the terms of the Modified BSD License.
5#
6# The full license is in the file COPYING.txt, distributed with this software.
7# ----------------------------------------------------------------------------
8
9import collections
10
11import numpy as np
12import pandas as pd
13
14from skbio.tree import DuplicateNodeError, MissingNodeError
15from skbio.diversity._phylogenetic import _nodes_by_counts
16
17
def _validate_counts_vector(counts, suppress_cast=False):
    """Coerce ``counts`` to an ndarray and check that it is a counts vector.

    Parameters
    ----------
    counts : array_like
        Candidate counts vector.
    suppress_cast : bool, optional
        Accepted for API compatibility; not used by this function.

    Returns
    -------
    np.ndarray
        ``np.asarray(counts)``. When ``counts`` is already an ndarray the very
        same object is returned, so callers must not assume a copy.

    Raises
    ------
    ValueError
        If any entry is not real-valued, if the array is not 1-D, or if any
        entry is negative (checked in that order).
    """
    vector = np.asarray(counts)

    # Realness is checked first so complex input is reported as such even
    # when it also has the wrong dimensionality.
    if not np.all(np.isreal(vector)):
        raise ValueError("Counts vector must contain real-valued entries.")
    if vector.ndim != 1:
        raise ValueError("Only 1-D vectors are supported.")
    if (vector < 0).any():
        raise ValueError("Counts vector cannot contain negative values.")

    return vector
33
34
def _validate_counts_matrix(counts, ids=None, suppress_cast=False):
    """Validate and convert input to a 2-D counts matrix.

    Parameters
    ----------
    counts : pd.DataFrame or array_like
        A DataFrame, a single counts vector, or a sequence of counts vectors
        (one per row). A single vector (or empty input) is promoted to a
        one-row matrix.
    ids : sequence, optional
        Row identifiers. When given, its length must equal the number of
        rows in ``counts``.
    suppress_cast : bool, optional
        Forwarded to ``_validate_counts_vector`` for every row.

    Returns
    -------
    np.ndarray
        The validated counts as an ndarray, one row per sample.

    Raises
    ------
    ValueError
        If the row count does not match ``ids``, if the input has more than
        two dimensions, if rows differ in length, or if any row fails
        ``_validate_counts_vector``.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC now lives
    # only in ``collections.abc``.
    from collections.abc import Iterable

    if isinstance(counts, pd.DataFrame):
        if ids is not None and len(counts.index) != len(ids):
            raise ValueError(
                "Number of rows in ``counts``"
                " must be equal to number of provided ``ids``."
            )
        # NOTE(review): DataFrame input is returned without the per-row
        # checks applied to other inputs (e.g. negative values) -- confirm
        # this is intentional.
        return np.asarray(counts)

    # Handle the case where ``counts`` is a single vector by making it a
    # matrix. This has to be done before forcing ``counts`` into an ndarray
    # because we don't yet know that all of the entries are of equal length.
    if len(counts) == 0 or not isinstance(counts[0], Iterable):
        counts = [counts]

    try:
        counts = np.asarray(counts)
    except ValueError as err:
        # NumPy >= 1.24 refuses to build an array from ragged nested
        # sequences; report that in this module's own terms.
        raise ValueError(
            "All rows in ``counts`` must be of equal length."
        ) from err
    if counts.ndim > 2:
        raise ValueError(
            "Only 1-D and 2-D array-like objects can be provided "
            "as input. Provided object has %d dimensions." %
            counts.ndim)

    if ids is not None and len(counts) != len(ids):
        raise ValueError(
            "Number of rows in ``counts`` must be equal "
            "to number of provided ``ids``."
        )

    results = [_validate_counts_vector(row, suppress_cast) for row in counts]
    # Older NumPy turns ragged input into a 1-D object array of sequences
    # instead of raising, so row lengths must still be compared explicitly.
    if len({len(row) for row in counts}) > 1:
        raise ValueError(
            "All rows in ``counts`` must be of equal length."
        )
    return np.asarray(results)
74
75
def _validate_otu_ids_and_tree(counts, otu_ids, tree):
    """Check that ``otu_ids`` and ``tree`` are consistent with ``counts``.

    Parameters
    ----------
    counts : sized
        A counts vector; it must have exactly one entry per OTU id.
    otu_ids : sequence
        OTU identifiers. They must be unique, and each must be a tip name in
        ``tree``.
    tree : tree object
        Presumably an skbio ``TreeNode``. This function uses ``root()``,
        ``traverse()``, ``is_root()``, ``is_tip()``, ``.children``,
        ``.name`` and ``.length``.

    Raises
    ------
    ValueError
        If ``otu_ids`` contains duplicates, if its length differs from that
        of ``counts``, if the root has no children, if the root has more
        than two children (treated as unrooted), or if a non-root node has
        no branch length.
    DuplicateNodeError
        If tip names in ``tree`` are not unique.
    MissingNodeError
        If any OTU id is not a tip name in ``tree``.
    """
    set_otu_ids = set(otu_ids)
    if len(otu_ids) != len(set_otu_ids):
        raise ValueError("``otu_ids`` cannot contain duplicated ids.")

    if len(counts) != len(otu_ids):
        raise ValueError("``otu_ids`` must be the same length as ``counts`` "
                         "vector(s).")

    n_root_children = len(tree.root().children)
    if n_root_children == 0:
        raise ValueError("``tree`` must contain more than just a root node.")

    if n_root_children > 2:
        # this is an imperfect check for whether the tree is rooted or not.
        # can this be improved?
        raise ValueError("``tree`` must be rooted.")

    # all nodes (except the root node) have corresponding branch lengths
    # all tip names in tree are unique
    # all otu_ids correspond to tip names in tree
    branch_lengths = []
    tip_names = []
    for node in tree.traverse():
        if not node.is_root():
            branch_lengths.append(node.length)
        if node.is_tip():
            tip_names.append(node.name)
    set_tip_names = set(tip_names)
    if len(tip_names) != len(set_tip_names):
        raise DuplicateNodeError("All tip names must be unique.")
    if any(length is None for length in branch_lengths):
        raise ValueError("All non-root nodes in ``tree`` must have a branch "
                         "length.")
    missing_tip_names = set_otu_ids - set_tip_names
    if missing_tip_names:
        # Stringify before joining so non-string ids (e.g. ints) yield the
        # intended MissingNodeError rather than a TypeError from str.join;
        # sort so the message is deterministic.
        missing = sorted(str(name) for name in missing_tip_names)
        raise MissingNodeError("All ``otu_ids`` must be present as tip names "
                               "in ``tree``. ``otu_ids`` not corresponding to "
                               "tip names (n=%d): %s" %
                               (len(missing), " ".join(missing)))
118
119
def _vectorize_counts_and_tree(counts, otu_ids, tree):
    """Index ``tree`` and arrange ``counts`` in the matching node order.

    Parameters
    ----------
    counts : array_like
        Counts vector or matrix; promoted with ``np.atleast_2d``.
    otu_ids : array_like
        OTU identifiers, converted with ``np.asarray``.
    tree : tree object
        Must provide ``to_array(nan_length_value=...)`` (presumably an skbio
        ``TreeNode``).

    Returns
    -------
    tuple
        ``(counts_by_node.T, tree_index, branch_lengths)``, where
        ``tree_index`` is ``tree.to_array(nan_length_value=0.0)`` and
        ``branch_lengths`` is ``tree_index['length']``.
    """
    indexed_tree = tree.to_array(nan_length_value=0.0)
    ids_array = np.asarray(otu_ids)
    node_counts = _nodes_by_counts(np.atleast_2d(counts), ids_array,
                                   indexed_tree)

    # Pulled out purely for convenience: this is a reference to the array
    # stored in the index, not a copy.
    lengths = indexed_tree['length']
    return node_counts.T, indexed_tree, lengths
132
133
def _get_phylogenetic_kwargs(counts, **kwargs):
    """Extract the required phylogenetic keyword arguments.

    Parameters
    ----------
    counts : object
        Accepted for signature consistency; not used here.
    **kwargs
        Must contain ``otu_ids`` and ``tree``. Any other entries are passed
        through untouched.

    Returns
    -------
    tuple
        ``(otu_ids, tree, remaining_kwargs)``. ``remaining_kwargs`` is the
        keyword dict with ``otu_ids`` and ``tree`` removed. Because
        ``**kwargs`` builds a fresh dict, the caller's mapping is not
        mutated.

    Raises
    ------
    ValueError
        If ``otu_ids`` or ``tree`` is missing.
    """
    # ``from None`` suppresses the internal KeyError so the traceback shows
    # only the user-facing ValueError.
    try:
        otu_ids = kwargs.pop('otu_ids')
    except KeyError:
        raise ValueError("``otu_ids`` is required for phylogenetic diversity "
                         "metrics.") from None
    try:
        tree = kwargs.pop('tree')
    except KeyError:
        raise ValueError("``tree`` is required for phylogenetic diversity "
                         "metrics.") from None

    return otu_ids, tree, kwargs
147