# Author: Jake Vanderplas <vanderplas@astro.washington.edu>
# License: BSD 3 clause

__all__ = ['BallTree']

# Substitutions applied to the CLASS_DOC template when building the
# BallTree class docstring (see ``BallTree.__doc__`` below).
DOC_DICT = {'BinaryTree': 'BallTree', 'binary_tree': 'ball_tree'}

# Names of the distance-metric classes listed as valid for this tree.
# NOTE(review): nothing in this file reads VALID_METRICS; presumably it is
# consumed by the code pulled in from _binary_tree.pxi -- confirm there.
VALID_METRICS = ['EuclideanDistance', 'SEuclideanDistance',
                 'ManhattanDistance', 'ChebyshevDistance',
                 'MinkowskiDistance', 'WMinkowskiDistance',
                 'MahalanobisDistance', 'HammingDistance',
                 'CanberraDistance', 'BrayCurtisDistance',
                 'JaccardDistance', 'MatchingDistance',
                 'DiceDistance', 'KulsinskiDistance',
                 'RogersTanimotoDistance', 'RussellRaoDistance',
                 'SokalMichenerDistance', 'SokalSneathDistance',
                 'PyFuncDistance', 'HaversineDistance']


# Cython textual include.  BinaryTree, CLASS_DOC and the C-level types used
# below are not defined in this file; presumably they come from this include.
include "_binary_tree.pxi"

# Inherit BallTree from BinaryTree
cdef class BallTree(BinaryTree):
    # The class docstring is the CLASS_DOC template with the BallTree
    # names from DOC_DICT filled in.
    __doc__ = CLASS_DOC.format(**DOC_DICT)
    pass


#----------------------------------------------------------------------
# The functions below specialize the Binary Tree as a Ball Tree
#
# Note that these functions use the concept of "reduced distance".
# The reduced distance, defined for some metrics, is a quantity which
# is more efficient to compute than the distance, but preserves the
# relative rankings of the true distance.  For example, the reduced
# distance for the Euclidean metric is the squared-euclidean distance.
# For some metrics, the reduced distance is simply the distance.
cdef int allocate_data(BinaryTree tree, ITYPE_t n_nodes,
                       ITYPE_t n_features) except -1:
    """Allocate arrays needed for the Ball Tree

    A zero-filled array of shape (1, n_nodes, n_features) is created and
    stored on ``tree`` as both ``node_bounds_arr`` and ``node_bounds``.
    ``init_node`` later writes the centroid of node ``i`` into
    ``node_bounds[0, i, :]``.

    Returns 0 on success; -1 is the error return value used to propagate
    a Python exception.
    """
    tree.node_bounds_arr = np.zeros((1, n_nodes, n_features), dtype=DTYPE)
    tree.node_bounds = tree.node_bounds_arr
    return 0


cdef int init_node(BinaryTree tree, NodeData_t[::1] node_data, ITYPE_t i_node,
                   ITYPE_t idx_start, ITYPE_t idx_end) except -1:
    """Initialize the node for the dataset stored in tree.data

    Computes the centroid and radius of node ``i_node`` from the points
    ``tree.idx_array[idx_start:idx_end]`` and records them.

    Parameters
    ----------
    tree : BinaryTree
        Tree whose ``data``, ``idx_array``, ``node_bounds`` and (optional)
        ``sample_weight`` are read; the centroid is written into
        ``tree.node_bounds[0, i_node, :]``.
    node_data : NodeData_t[::1]
        Per-node records; entry ``i_node`` receives ``radius``,
        ``idx_start`` and ``idx_end``.
    i_node : ITYPE_t
        Index of the node being initialized.
    idx_start, idx_end : ITYPE_t
        Half-open range into ``tree.idx_array`` of the points in this node.

    Returns
    -------
    0 on success; -1 is the error return value used to propagate a Python
    exception.

    Notes
    -----
    NOTE(review): an empty range (idx_start == idx_end) or a zero total
    sample weight makes the centroid normalization divide by zero;
    callers presumably never pass such a node -- confirm.
    """
    cdef ITYPE_t n_features = tree.data.shape[1]
    cdef ITYPE_t n_points = idx_end - idx_start

    cdef ITYPE_t i, j
    cdef DTYPE_t radius
    cdef DTYPE_t *this_pt

    # Raw pointers into the tree's buffers; rows of ``data`` are addressed
    # manually as ``data + n_features * row``.
    cdef ITYPE_t* idx_array = &tree.idx_array[0]
    cdef DTYPE_t* data = &tree.data[0, 0]
    cdef DTYPE_t* centroid = &tree.node_bounds[0, i_node, 0]

    cdef bint with_sample_weight = tree.sample_weight is not None
    cdef DTYPE_t* sample_weight
    cdef DTYPE_t sum_weight_node
    cdef DTYPE_t weight
    if with_sample_weight:
        sample_weight = &tree.sample_weight[0]

    # determine Node centroid
    for j in range(n_features):
        centroid[j] = 0

    if with_sample_weight:
        # weighted mean: sum(w_i * x_i) / sum(w_i)
        sum_weight_node = 0
        for i in range(idx_start, idx_end):
            # look the weight up once per point rather than per feature
            weight = sample_weight[idx_array[i]]
            sum_weight_node += weight
            this_pt = data + n_features * idx_array[i]
            for j in range(n_features):
                centroid[j] += this_pt[j] * weight

        for j in range(n_features):
            centroid[j] /= sum_weight_node
    else:
        # unweighted mean of the node's points
        for i in range(idx_start, idx_end):
            this_pt = data + n_features * idx_array[i]
            for j in range(n_features):
                centroid[j] += this_pt[j]

        for j in range(n_features):
            centroid[j] /= n_points

    # determine Node radius: the largest reduced distance from the centroid
    # to any point of the node, converted back to a true distance below
    radius = 0
    for i in range(idx_start, idx_end):
        radius = fmax(radius,
                      tree.rdist(centroid,
                                 data + n_features * idx_array[i],
                                 n_features))

    node_data[i_node].radius = tree.dist_metric._rdist_to_dist(radius)
    node_data[i_node].idx_start = idx_start
    node_data[i_node].idx_end = idx_end
    return 0
cdef inline DTYPE_t min_dist(BinaryTree tree, ITYPE_t i_node,
                             DTYPE_t* pt) nogil except -1:
    """Minimum distance between a point and a node

    The distance from ``pt`` to the node centroid, reduced by the node
    radius and clipped at zero.
    """
    cdef DTYPE_t* centroid = &tree.node_bounds[0, i_node, 0]
    cdef DTYPE_t center_dist = tree.dist(pt, centroid, tree.data.shape[1])
    return fmax(0, center_dist - tree.node_data[i_node].radius)


cdef inline DTYPE_t max_dist(BinaryTree tree, ITYPE_t i_node,
                             DTYPE_t* pt) except -1:
    """Maximum distance between a point and a node

    The distance from ``pt`` to the node centroid, increased by the node
    radius.
    """
    cdef DTYPE_t* centroid = &tree.node_bounds[0, i_node, 0]
    cdef DTYPE_t center_dist = tree.dist(pt, centroid, tree.data.shape[1])
    return center_dist + tree.node_data[i_node].radius


cdef inline int min_max_dist(BinaryTree tree, ITYPE_t i_node, DTYPE_t* pt,
                             DTYPE_t* min_dist, DTYPE_t* max_dist) nogil except -1:
    """Minimum and maximum distance between a point and a node

    Both bounds are derived from a single centroid-distance evaluation
    and written to ``min_dist[0]`` and ``max_dist[0]``.  Returns 0.
    """
    cdef DTYPE_t* centroid = &tree.node_bounds[0, i_node, 0]
    cdef DTYPE_t center_dist = tree.dist(pt, centroid, tree.data.shape[1])
    cdef DTYPE_t radius = tree.node_data[i_node].radius
    min_dist[0] = fmax(0, center_dist - radius)
    max_dist[0] = center_dist + radius
    return 0


cdef inline DTYPE_t min_rdist(BinaryTree tree, ITYPE_t i_node,
                              DTYPE_t* pt) nogil except -1:
    """Minimum reduced-distance between a point and a node

    The minimum true distance converted to a reduced distance, using the
    Euclidean conversion when ``tree.euclidean`` is set and the tree's
    metric otherwise.
    """
    cdef DTYPE_t lower = min_dist(tree, i_node, pt)
    if tree.euclidean:
        return euclidean_dist_to_rdist(lower)
    return tree.dist_metric._dist_to_rdist(lower)


cdef inline DTYPE_t max_rdist(BinaryTree tree, ITYPE_t i_node,
                              DTYPE_t* pt) except -1:
    """Maximum reduced-distance between a point and a node

    The maximum true distance converted to a reduced distance, using the
    Euclidean conversion when ``tree.euclidean`` is set and the tree's
    metric otherwise.
    """
    cdef DTYPE_t upper = max_dist(tree, i_node, pt)
    if tree.euclidean:
        return euclidean_dist_to_rdist(upper)
    return tree.dist_metric._dist_to_rdist(upper)


cdef inline DTYPE_t min_dist_dual(BinaryTree tree1, ITYPE_t i_node1,
                                  BinaryTree tree2, ITYPE_t i_node2) except -1:
    """Minimum distance between two nodes

    The distance between the two node centroids, reduced by both node
    radii and clipped at zero.
    """
    cdef DTYPE_t* centroid1 = &tree1.node_bounds[0, i_node1, 0]
    cdef DTYPE_t* centroid2 = &tree2.node_bounds[0, i_node2, 0]
    cdef DTYPE_t center_dist = tree1.dist(centroid2, centroid1,
                                          tree1.data.shape[1])
    cdef DTYPE_t radius1 = tree1.node_data[i_node1].radius
    cdef DTYPE_t radius2 = tree2.node_data[i_node2].radius
    return fmax(0, center_dist - radius1 - radius2)


cdef inline DTYPE_t max_dist_dual(BinaryTree tree1, ITYPE_t i_node1,
                                  BinaryTree tree2, ITYPE_t i_node2) except -1:
    """Maximum distance between two nodes

    The distance between the two node centroids, increased by both node
    radii.
    """
    cdef DTYPE_t* centroid1 = &tree1.node_bounds[0, i_node1, 0]
    cdef DTYPE_t* centroid2 = &tree2.node_bounds[0, i_node2, 0]
    cdef DTYPE_t center_dist = tree1.dist(centroid2, centroid1,
                                          tree1.data.shape[1])
    cdef DTYPE_t radius1 = tree1.node_data[i_node1].radius
    cdef DTYPE_t radius2 = tree2.node_data[i_node2].radius
    return center_dist + radius1 + radius2


cdef inline DTYPE_t min_rdist_dual(BinaryTree tree1, ITYPE_t i_node1,
                                   BinaryTree tree2, ITYPE_t i_node2) except -1:
    """Minimum reduced distance between two nodes

    The minimum true distance between the nodes converted to a reduced
    distance, using the Euclidean conversion when ``tree1.euclidean`` is
    set and ``tree1``'s metric otherwise.
    """
    cdef DTYPE_t lower = min_dist_dual(tree1, i_node1, tree2, i_node2)
    if tree1.euclidean:
        return euclidean_dist_to_rdist(lower)
    return tree1.dist_metric._dist_to_rdist(lower)


cdef inline DTYPE_t max_rdist_dual(BinaryTree tree1, ITYPE_t i_node1,
                                   BinaryTree tree2, ITYPE_t i_node2) except -1:
    """Maximum reduced distance between two nodes

    The maximum true distance between the nodes converted to a reduced
    distance, using the Euclidean conversion when ``tree1.euclidean`` is
    set and ``tree1``'s metric otherwise.
    """
    cdef DTYPE_t upper = max_dist_dual(tree1, i_node1, tree2, i_node2)
    if tree1.euclidean:
        return euclidean_dist_to_rdist(upper)
    return tree1.dist_metric._dist_to_rdist(upper)