<?php

namespace Rubix\ML\Graph\Trees;

use Rubix\ML\DataType;
use Rubix\ML\Graph\Nodes\Box;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Graph\Nodes\Hypercube;
use Rubix\ML\Graph\Nodes\Neighborhood;
use Rubix\ML\Kernels\Distance\Distance;
use Rubix\ML\Kernels\Distance\Euclidean;
use Rubix\ML\Exceptions\InvalidArgumentException;
use SplObjectStorage;

use function count;
use function array_slice;
use function in_array;

/**
 * K-d Tree
 *
 * A multi-dimensional binary search tree for fast nearest neighbor queries.
 * The K-d tree construction algorithm separates data points into bounded
 * hypercubes or *bounding boxes* that are used to prune off nodes during
 * nearest neighbor and range searches.
 *
 * [1] J. L. Bentley. (1975). Multidimensional Binary Search Trees Used for
 * Associative Searching.
 *
 * @category    Machine Learning
 * @package     Rubix/ML
 * @author      Andrew DalPino
 */
class KDTree implements BinaryTree, Spatial
{
    /**
     * The maximum number of samples that each neighborhood node can contain.
     *
     * @var int
     */
    protected $maxLeafSize;

    /**
     * The distance function to use when computing the distances.
     *
     * @var \Rubix\ML\Kernels\Distance\Distance
     */
    protected $kernel;

    /**
     * The root node of the tree.
     *
     * @var \Rubix\ML\Graph\Nodes\Box|null
     */
    protected $root;

    /**
     * @param int $maxLeafSize the largest partition that is turned into a leaf without splitting
     * @param \Rubix\ML\Kernels\Distance\Distance|null $kernel defaults to Euclidean when null
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException if the leaf size is less than 1 or
     *     the kernel does not list continuous features in its compatibility
     */
    public function __construct(int $maxLeafSize = 30, ?Distance $kernel = null)
    {
        if ($maxLeafSize < 1) {
            throw new InvalidArgumentException('At least one sample is required'
                . " to form a neighborhood, $maxLeafSize given.");
        }

        // Loose in_array() is intentional here: the needle is an object, so a
        // strict search would compare identity rather than value.
        // NOTE(review): presumably DataType::continuous() returns a fresh
        // instance on every call - confirm before ever switching to strict.
        if ($kernel and !in_array(DataType::continuous(), $kernel->compatibility())) {
            throw new InvalidArgumentException('Distance kernel must be'
                . ' compatible with continuous features.');
        }

        $this->maxLeafSize = $maxLeafSize;
        $this->kernel = $kernel ?? new Euclidean();
    }

    /**
     * Return the height of the tree i.e. the number of levels.
     * A bare tree has a height of 0.
     *
     * @internal
     *
     * @return int
     */
    public function height() : int
    {
        return $this->root ? $this->root->height() : 0;
    }

    /**
     * Return the balance factor of the tree. A balanced tree will have
     * a factor of 0 whereas an imbalanced tree will either be positive
     * or negative indicating the direction and degree of the imbalance.
     * A bare tree reports a factor of 0.
     *
     * @internal
     *
     * @return int
     */
    public function balance() : int
    {
        return $this->root ? $this->root->balance() : 0;
    }

    /**
     * Is the tree bare? True when no root node has been set.
     *
     * @internal
     *
     * @return bool
     */
    public function bare() : bool
    {
        return !$this->root;
    }

    /**
     * Return the distance kernel used to compute distances.
     *
     * @internal
     *
     * @return \Rubix\ML\Kernels\Distance\Distance
     */
    public function kernel() : Distance
    {
        return $this->kernel;
    }

    /**
     * Insert a root node and recursively split the dataset until a terminating condition is met.
     *
     * A partition becomes a neighborhood (leaf) when it holds no more than the
     * max leaf size, or when the box produced by splitting it reports itself
     * as a point. Empty partitions are not attached at all.
     *
     * @internal
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException if the first column is not continuous
     *     or the dataset is not homogeneous
     */
    public function grow(Labeled $dataset) : void
    {
        // Loose != is intentional: DataType objects are compared by value.
        if ($dataset->columnType(0) != DataType::continuous() or !$dataset->homogeneous()) {
            throw new InvalidArgumentException('KD Tree only works with continuous features.');
        }

        $this->root = Box::split($dataset);

        // Explicit stack in place of recursion over the boxes still to be expanded.
        $stack = [$this->root];

        while ($current = array_pop($stack)) {
            [$left, $right] = $current->groups();

            // The groups are held in local variables from here on.
            // NOTE(review): presumably cleanup() drops the node's own copy - confirm.
            $current->cleanup();

            if ($left->numRows() > $this->maxLeafSize) {
                $node = Box::split($left);

                // A point box is not split further, so the partition becomes a
                // leaf even though it is larger than the max leaf size.
                if ($node->isPoint()) {
                    $current->attachLeft(Neighborhood::terminate($left));
                } else {
                    $current->attachLeft($node);

                    $stack[] = $node;
                }
            } elseif (!$left->empty()) {
                $current->attachLeft(Neighborhood::terminate($left));
            }

            if ($right->numRows() > $this->maxLeafSize) {
                $node = Box::split($right);

                // Same termination rule as the left partition. Without it an
                // oversized right partition that splits into a point box would
                // always be pushed back onto the stack for another split.
                if ($node->isPoint()) {
                    $current->attachRight(Neighborhood::terminate($right));
                } else {
                    $current->attachRight($node);

                    $stack[] = $node;
                }
            } elseif (!$right->empty()) {
                $current->attachRight(Neighborhood::terminate($right));
            }
        }
    }

    /**
     * Run a k nearest neighbors search and return the samples, labels, and distances in a 3-tuple.
     *
     * The three returned arrays are aligned by index and ordered by ascending
     * distance. Fewer than k entries are returned if fewer samples were
     * visited. A bare tree yields three empty arrays.
     *
     * @internal
     *
     * @param list<int|float> $sample
     * @param int $k the number of neighbors to return, must be at least 1
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException if k is less than 1
     * @return array{array[],mixed[],float[]}
     */
    public function nearest(array $sample, int $k = 1) : array
    {
        // A negative k would make array_slice() below drop elements from the
        // end of the result instead of limiting it, so reject it up front.
        if ($k < 1) {
            throw new InvalidArgumentException('The number of nearest'
                . " neighbors must be greater than 0, $k given.");
        }

        $visited = new SplObjectStorage();

        $samples = $labels = $distances = [];

        // Start from the root-to-leaf path of the sample so the leaf it falls
        // into is popped first, then work back up towards the root.
        $stack = $this->path($sample);

        while ($current = array_pop($stack)) {
            if ($current instanceof Box) {
                // $distances is kept sorted and cut to k entries below, so
                // index k - 1 is the current k-th best distance. Until k
                // neighbors have been found nothing can be pruned.
                $radius = $distances[$k - 1] ?? INF;

                foreach ($current->children() as $child) {
                    if (!$visited->contains($child)) {
                        if ($child instanceof Hypercube) {
                            foreach ($child->sides() as $side) {
                                $distance = $this->kernel->compute($sample, $side);

                                // One side within the radius is enough to
                                // queue the child; move on to the next child.
                                if ($distance < $radius) {
                                    $stack[] = $child;

                                    continue 2;
                                }
                            }
                        }

                        // Not queued: mark it so it is not considered again.
                        $visited->attach($child);
                    }
                }

                $visited->attach($current);

                continue;
            }

            if ($current instanceof Neighborhood) {
                $dataset = $current->dataset();

                foreach ($dataset->samples() as $neighbor) {
                    $distances[] = $this->kernel->compute($sample, $neighbor);
                }

                $samples = array_merge($samples, $dataset->samples());
                $labels = array_merge($labels, $dataset->labels());

                // Sort by distance and reorder samples and labels to match.
                array_multisort($distances, $samples, $labels);

                if (count($samples) > $k) {
                    $samples = array_slice($samples, 0, $k);
                    $labels = array_slice($labels, 0, $k);
                    $distances = array_slice($distances, 0, $k);
                }

                $visited->attach($current);
            }
        }

        return [$samples, $labels, $distances];
    }

    /**
     * Run a range search over every cluster within radius and return the samples, labels and distances in a 3-tuple.
     *
     * The three returned arrays are aligned by index but are not sorted by
     * distance. A bare tree yields three empty arrays.
     *
     * @internal
     *
     * @param list<int|float> $sample
     * @param float $radius the inclusive distance within which a sample is returned
     * @return array{array[],mixed[],float[]}
     */
    public function range(array $sample, float $radius) : array
    {
        $samples = $labels = $distances = [];

        /** @var list<Box|Neighborhood> */
        $stack = [$this->root];

        while ($current = array_pop($stack)) {
            if ($current instanceof Box) {
                foreach ($current->children() as $child) {
                    if ($child instanceof Hypercube) {
                        foreach ($child->sides() as $side) {
                            $distance = $this->kernel->compute($sample, $side);

                            // One side within the radius is enough to queue
                            // the child; move on to the next child.
                            if ($distance <= $radius) {
                                $stack[] = $child;

                                continue 2;
                            }
                        }
                    }
                }

                continue;
            }

            if ($current instanceof Neighborhood) {
                $dataset = $current->dataset();

                foreach ($dataset->samples() as $i => $neighbor) {
                    $distance = $this->kernel->compute($sample, $neighbor);

                    if ($distance <= $radius) {
                        $samples[] = $neighbor;
                        $labels[] = $dataset->label($i);
                        $distances[] = $distance;
                    }
                }
            }
        }

        return [$samples, $labels, $distances];
    }

    /**
     * Destroy the tree.
     *
     * @internal
     */
    public function destroy() : void
    {
        $this->root = null;
    }

    /**
     * Return the path of a sample taken from the root node to a leaf node in an array.
     *
     * The first element is always the root, which is null for a bare tree.
     * The descent stops at the first node that is not a box, or when the
     * chosen side of a box has no child attached.
     *
     * @param list<int|float> $sample
     * @return list<\Rubix\ML\Graph\Nodes\BinaryNode|null>
     */
    protected function path(array $sample) : array
    {
        $current = $this->root;

        $path = [$current];

        while ($current instanceof Box) {
            // Values below the split value descend left, all others right.
            if ($sample[$current->column()] < $current->value()) {
                $current = $current->left();
            } else {
                $current = $current->right();
            }

            if ($current) {
                $path[] = $current;
            }
        }

        return $path;
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return "K-d Tree (max_leaf_size: {$this->maxLeafSize}, kernel: {$this->kernel})";
    }
}