1# Author: Jake Vanderplas <vanderplas@astro.washington.edu>
2# License: BSD 3 clause
3
# Only the tree class itself is exported from this module.
__all__ = ['BallTree']

# Keyword substitutions applied to the CLASS_DOC template when the BallTree
# docstring is built (see ``CLASS_DOC.format(**DOC_DICT)`` on the class).
DOC_DICT = {
    'BinaryTree': 'BallTree',
    'binary_tree': 'ball_tree',
}

# Names of the distance-metric classes accepted by this tree type.
# NOTE(review): presumably consumed by the included _binary_tree.pxi, which
# is why it is defined before the ``include`` statement -- confirm.
VALID_METRICS = [
    'EuclideanDistance',
    'SEuclideanDistance',
    'ManhattanDistance',
    'ChebyshevDistance',
    'MinkowskiDistance',
    'WMinkowskiDistance',
    'MahalanobisDistance',
    'HammingDistance',
    'CanberraDistance',
    'BrayCurtisDistance',
    'JaccardDistance',
    'MatchingDistance',
    'DiceDistance',
    'KulsinskiDistance',
    'RogersTanimotoDistance',
    'RussellRaoDistance',
    'SokalMichenerDistance',
    'SokalSneathDistance',
    'PyFuncDistance',
    'HaversineDistance',
]
18
19
20include "_binary_tree.pxi"
21
22# Inherit BallTree from BinaryTree
cdef class BallTree(BinaryTree):
    # All behaviour is inherited from BinaryTree; the ball-specific pieces
    # are the module-level cdef functions defined in this file.  The
    # docstring is rendered from the CLASS_DOC template with the
    # BallTree-specific names in DOC_DICT.
    # NOTE(review): CLASS_DOC is not defined in the visible part of this
    # file; presumably it comes from the included _binary_tree.pxi.
    __doc__ = CLASS_DOC.format(**DOC_DICT)
26
27
28#----------------------------------------------------------------------
# The functions below specialize the Binary Tree as a Ball Tree
30#
31#   Note that these functions use the concept of "reduced distance".
32#   The reduced distance, defined for some metrics, is a quantity which
33#   is more efficient to compute than the distance, but preserves the
34#   relative rankings of the true distance.  For example, the reduced
35#   distance for the Euclidean metric is the squared-euclidean distance.
36#   For some metrics, the reduced distance is simply the distance.
37
cdef int allocate_data(BinaryTree tree, ITYPE_t n_nodes,
                       ITYPE_t n_features) except -1:
    """Allocate the per-node bounds storage used by the Ball Tree.

    A zero-filled array of shape ``(1, n_nodes, n_features)`` is created and
    stored on ``tree.node_bounds_arr``; ``tree.node_bounds`` is then bound to
    that same array.  The other functions in this file read and write row
    ``[0, i_node, :]`` as the centroid of node ``i_node``.

    Returns 0 on success; -1 is reserved as the error return value.
    """
    cdef tuple bounds_shape = (1, n_nodes, n_features)
    tree.node_bounds_arr = np.zeros(bounds_shape, dtype=DTYPE)
    # Bind the typed view to the array just stored on the tree.
    tree.node_bounds = tree.node_bounds_arr
    return 0
44
45
cdef int init_node(BinaryTree tree, NodeData_t[::1] node_data, ITYPE_t i_node,
                   ITYPE_t idx_start, ITYPE_t idx_end) except -1:
    """Initialize node ``i_node`` as a ball enclosing its points.

    The node owns the points ``tree.idx_array[idx_start:idx_end]`` of
    ``tree.data``.  This function

    * writes the (optionally sample-weighted) mean of those points into
      ``tree.node_bounds[0, i_node, :]`` as the ball centroid,
    * stores the largest centroid-to-point distance in
      ``node_data[i_node].radius``, and
    * records ``idx_start`` / ``idx_end`` on ``node_data[i_node]``.

    Returns 0 on success; -1 is reserved as the error return value.

    No guard is applied for an empty range (``idx_start == idx_end``) or a
    zero total sample weight; both lead to a division by zero when the
    centroid is normalized.
    """
    cdef ITYPE_t n_features = tree.data.shape[1]
    cdef ITYPE_t n_points = idx_end - idx_start

    cdef ITYPE_t i, j
    cdef DTYPE_t radius
    cdef DTYPE_t *this_pt

    # Raw pointers into the tree's buffers.  Rows are addressed below as
    # ``data + n_features * row``, which relies on tree.data being laid out
    # contiguously row by row.
    cdef ITYPE_t* idx_array = &tree.idx_array[0]
    cdef DTYPE_t* data = &tree.data[0, 0]
    cdef DTYPE_t* centroid = &tree.node_bounds[0, i_node, 0]

    cdef bint with_sample_weight = tree.sample_weight is not None
    cdef DTYPE_t* sample_weight
    cdef DTYPE_t sum_weight_node
    cdef DTYPE_t weight
    if with_sample_weight:
        # Only dereferenced inside the weighted branch below.
        sample_weight = &tree.sample_weight[0]

    # determine Node centroid
    for j in range(n_features):
        centroid[j] = 0

    if with_sample_weight:
        # Weighted mean: sum(w_i * x_i) / sum(w_i)
        sum_weight_node = 0
        for i in range(idx_start, idx_end):
            weight = sample_weight[idx_array[i]]
            sum_weight_node += weight
            this_pt = data + n_features * idx_array[i]
            for j in range(n_features):
                centroid[j] += this_pt[j] * weight

        for j in range(n_features):
            centroid[j] /= sum_weight_node
    else:
        # Plain arithmetic mean of the node's points
        for i in range(idx_start, idx_end):
            this_pt = data + n_features * idx_array[i]
            for j in range(n_features):
                centroid[j] += this_pt[j]

        for j in range(n_features):
            centroid[j] /= n_points

    # determine Node radius
    # The maximum is taken over reduced distances, which preserve the
    # ranking of the true distances, and converted to a true distance once
    # at the end.
    radius = 0
    for i in range(idx_start, idx_end):
        radius = fmax(radius,
                      tree.rdist(centroid,
                                 data + n_features * idx_array[i],
                                 n_features))

    node_data[i_node].radius = tree.dist_metric._rdist_to_dist(radius)
    node_data[i_node].idx_start = idx_start
    node_data[i_node].idx_end = idx_end
    return 0
101
102
cdef inline DTYPE_t min_dist(BinaryTree tree, ITYPE_t i_node,
                             DTYPE_t* pt) nogil except -1:
    """Lower bound on the distance from ``pt`` to the ball of node ``i_node``.

    Computed as (distance from ``pt`` to the node centroid) minus the node
    radius, clamped at zero for points that fall inside the ball.
    """
    cdef DTYPE_t* centroid = &tree.node_bounds[0, i_node, 0]
    cdef DTYPE_t center_dist = tree.dist(pt, centroid, tree.data.shape[1])
    cdef DTYPE_t gap = center_dist - tree.node_data[i_node].radius
    return fmax(0, gap)
109
110
cdef inline DTYPE_t max_dist(BinaryTree tree, ITYPE_t i_node,
                             DTYPE_t* pt) except -1:
    """Upper bound on the distance from ``pt`` to the ball of node ``i_node``.

    Computed as (distance from ``pt`` to the node centroid) plus the node
    radius.
    """
    cdef DTYPE_t* centroid = &tree.node_bounds[0, i_node, 0]
    cdef DTYPE_t center_dist = tree.dist(pt, centroid, tree.data.shape[1])
    return center_dist + tree.node_data[i_node].radius
117
118
cdef inline int min_max_dist(BinaryTree tree, ITYPE_t i_node, DTYPE_t* pt,
                             DTYPE_t* min_dist, DTYPE_t* max_dist) nogil except -1:
    """Compute both distance bounds from ``pt`` to node ``i_node`` at once.

    The centroid distance is evaluated a single time.  ``min_dist[0]``
    receives that distance minus the node radius (clamped at zero) and
    ``max_dist[0]`` receives that distance plus the node radius.

    Returns 0 on success; -1 is reserved as the error return value.
    """
    cdef DTYPE_t radius = tree.node_data[i_node].radius
    cdef DTYPE_t center_dist = tree.dist(pt,
                                         &tree.node_bounds[0, i_node, 0],
                                         tree.data.shape[1])
    min_dist[0] = fmax(0, center_dist - radius)
    max_dist[0] = center_dist + radius
    return 0
128
129
cdef inline DTYPE_t min_rdist(BinaryTree tree, ITYPE_t i_node,
                              DTYPE_t* pt) nogil except -1:
    """Lower bound on the reduced distance from ``pt`` to node ``i_node``.

    The true-distance bound from ``min_dist`` is converted to a reduced
    distance: through ``euclidean_dist_to_rdist`` when ``tree.euclidean`` is
    set, otherwise through the tree's metric object.
    """
    cdef DTYPE_t lower_bound = min_dist(tree, i_node, pt)
    if tree.euclidean:
        return euclidean_dist_to_rdist(lower_bound)
    return tree.dist_metric._dist_to_rdist(lower_bound)
137
138
cdef inline DTYPE_t max_rdist(BinaryTree tree, ITYPE_t i_node,
                              DTYPE_t* pt) except -1:
    """Upper bound on the reduced distance from ``pt`` to node ``i_node``.

    The true-distance bound from ``max_dist`` is converted to a reduced
    distance: through ``euclidean_dist_to_rdist`` when ``tree.euclidean`` is
    set, otherwise through the tree's metric object.
    """
    cdef DTYPE_t upper_bound = max_dist(tree, i_node, pt)
    if tree.euclidean:
        return euclidean_dist_to_rdist(upper_bound)
    return tree.dist_metric._dist_to_rdist(upper_bound)
146
147
cdef inline DTYPE_t min_dist_dual(BinaryTree tree1, ITYPE_t i_node1,
                                  BinaryTree tree2, ITYPE_t i_node2) except -1:
    """Lower bound on the distance between two nodes' balls.

    Computed as the distance between the two centroids minus both radii,
    clamped at zero when the balls overlap.  The distance is evaluated with
    ``tree1``'s metric and ``tree1``'s feature count.
    """
    cdef DTYPE_t* centroid1 = &tree1.node_bounds[0, i_node1, 0]
    cdef DTYPE_t* centroid2 = &tree2.node_bounds[0, i_node2, 0]
    cdef DTYPE_t center_dist = tree1.dist(centroid2, centroid1,
                                          tree1.data.shape[1])
    cdef DTYPE_t gap = (center_dist - tree1.node_data[i_node1].radius
                        - tree2.node_data[i_node2].radius)
    return fmax(0, gap)
156
157
cdef inline DTYPE_t max_dist_dual(BinaryTree tree1, ITYPE_t i_node1,
                                  BinaryTree tree2, ITYPE_t i_node2) except -1:
    """Upper bound on the distance between two nodes' balls.

    Computed as the distance between the two centroids plus both radii.
    The distance is evaluated with ``tree1``'s metric and ``tree1``'s
    feature count.
    """
    cdef DTYPE_t* centroid1 = &tree1.node_bounds[0, i_node1, 0]
    cdef DTYPE_t* centroid2 = &tree2.node_bounds[0, i_node2, 0]
    cdef DTYPE_t center_dist = tree1.dist(centroid2, centroid1,
                                          tree1.data.shape[1])
    return (center_dist + tree1.node_data[i_node1].radius
            + tree2.node_data[i_node2].radius)
166
167
cdef inline DTYPE_t min_rdist_dual(BinaryTree tree1, ITYPE_t i_node1,
                                   BinaryTree tree2, ITYPE_t i_node2) except -1:
    """Lower bound on the reduced distance between two nodes' balls.

    The true-distance bound from ``min_dist_dual`` is converted to a reduced
    distance: through ``euclidean_dist_to_rdist`` when ``tree1.euclidean`` is
    set, otherwise through ``tree1``'s metric object.
    """
    cdef DTYPE_t lower_bound = min_dist_dual(tree1, i_node1, tree2, i_node2)
    if tree1.euclidean:
        return euclidean_dist_to_rdist(lower_bound)
    return tree1.dist_metric._dist_to_rdist(lower_bound)
177
178
cdef inline DTYPE_t max_rdist_dual(BinaryTree tree1, ITYPE_t i_node1,
                                   BinaryTree tree2, ITYPE_t i_node2) except -1:
    """Upper bound on the reduced distance between two nodes' balls.

    The true-distance bound from ``max_dist_dual`` is converted to a reduced
    distance: through ``euclidean_dist_to_rdist`` when ``tree1.euclidean`` is
    set, otherwise through ``tree1``'s metric object.
    """
    cdef DTYPE_t upper_bound = max_dist_dual(tree1, i_node1, tree2, i_node2)
    if tree1.euclidean:
        return euclidean_dist_to_rdist(upper_bound)
    return tree1.dist_metric._dist_to_rdist(upper_bound)
188