1<?php
2
3namespace Rubix\ML\Graph\Trees;
4
5use Rubix\ML\DataType;
6use Rubix\ML\Graph\Nodes\Box;
7use Rubix\ML\Datasets\Dataset;
8use Rubix\ML\Datasets\Labeled;
9use Rubix\ML\Graph\Nodes\Hypercube;
10use Rubix\ML\Graph\Nodes\Neighborhood;
11use Rubix\ML\Kernels\Distance\Distance;
12use Rubix\ML\Kernels\Distance\Euclidean;
13use Rubix\ML\Exceptions\InvalidArgumentException;
14use SplObjectStorage;
15
16use function count;
17use function array_slice;
18use function in_array;
19
20/**
21 * K-d Tree
22 *
23 * A multi-dimensional binary search tree for fast nearest neighbor queries.
24 * The K-d tree construction algorithm separates data points into bounded
25 * hypercubes or *bounding boxes* that are used to prune off nodes during
26 * nearest neighbor and range searches.
27 *
28 * [1] J. L. Bentley. (1975). Multidimensional Binary Search Trees Used for
29 * Associative Searching.
30 *
31 * @category    Machine Learning
32 * @package     Rubix/ML
33 * @author      Andrew DalPino
34 */
class KDTree implements BinaryTree, Spatial
{
    /**
     * The maximum number of samples that each neighborhood node can contain.
     *
     * @var int
     */
    protected $maxLeafSize;

    /**
     * The distance function to use when computing the distances.
     *
     * @var \Rubix\ML\Kernels\Distance\Distance
     */
    protected $kernel;

    /**
     * The root node of the tree.
     *
     * @var \Rubix\ML\Graph\Nodes\Box|null
     */
    protected $root;

    /**
     * @param int $maxLeafSize The maximum number of samples a leaf (neighborhood) node may hold, at least 1.
     * @param \Rubix\ML\Kernels\Distance\Distance|null $kernel The distance kernel, a Euclidean kernel when null.
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function __construct(int $maxLeafSize = 30, ?Distance $kernel = null)
    {
        if ($maxLeafSize < 1) {
            throw new InvalidArgumentException('At least one sample is required'
                . " to form a neighborhood, $maxLeafSize given.");
        }

        // NOTE: in_array() is intentionally non-strict. DataType objects are compared with ==
        // (by class and properties); strict mode would require the very same instance.
        if ($kernel and !in_array(DataType::continuous(), $kernel->compatibility())) {
            throw new InvalidArgumentException('Distance kernel must be'
                . ' compatible with continuous features.');
        }

        $this->maxLeafSize = $maxLeafSize;
        $this->kernel = $kernel ?? new Euclidean();
    }

    /**
     * Return the height of the tree i.e. the number of levels.
     *
     * @internal
     *
     * @return int The height of the root node, or 0 when the tree is bare.
     */
    public function height() : int
    {
        return $this->root ? $this->root->height() : 0;
    }

    /**
     * Return the balance factor of the tree. A balanced tree will have
     * a factor of 0 whereas an imbalanced tree will either be positive
     * or negative indicating the direction and degree of the imbalance.
     *
     * @internal
     *
     * @return int The balance of the root node, or 0 when the tree is bare.
     */
    public function balance() : int
    {
        return $this->root ? $this->root->balance() : 0;
    }

    /**
     * Is the tree bare, i.e. has no root node been grown (or was it destroyed)?
     *
     * @internal
     *
     * @return bool
     */
    public function bare() : bool
    {
        return !$this->root;
    }

    /**
     * Return the distance kernel used to compute distances.
     *
     * @internal
     *
     * @return \Rubix\ML\Kernels\Distance\Distance
     */
    public function kernel() : Distance
    {
        return $this->kernel;
    }

    /**
     * Insert a root node and recursively split the dataset until a terminating condition is met.
     *
     * The tree is built iteratively with an explicit stack. Each box yields a left and a right
     * group. A group becomes a neighborhood (leaf) when it holds at most maxLeafSize samples,
     * or when splitting it produces a box that isPoint() (i.e. it cannot be split further).
     * Otherwise the group becomes a new box that is split in turn. Empty groups get no child.
     *
     * @internal
     *
     * @param \Rubix\ML\Datasets\Labeled $dataset
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     */
    public function grow(Labeled $dataset) : void
    {
        if ($dataset->columnType(0) != DataType::continuous() or !$dataset->homogeneous()) {
            throw new InvalidArgumentException('KD Tree only works with continuous features.');
        }

        $this->root = Box::split($dataset);

        $stack = [$this->root];

        while ($current = array_pop($stack)) {
            [$left, $right] = $current->groups();

            // The groups are no longer needed on the node once retrieved
            // (presumably cleanup() drops them to free memory - confirm).
            $current->cleanup();

            if ($left->numRows() > $this->maxLeafSize) {
                $node = Box::split($left);

                if ($node->isPoint()) {
                    // Degenerate split: keep the group as a leaf instead of splitting forever.
                    $current->attachLeft(Neighborhood::terminate($left));
                } else {
                    $current->attachLeft($node);

                    $stack[] = $node;
                }
            } elseif (!$left->empty()) {
                $current->attachLeft(Neighborhood::terminate($left));
            }

            if ($right->numRows() > $this->maxLeafSize) {
                $node = Box::split($right);

                // Mirror the left branch. Without this guard, a right group that cannot be split
                // further (e.g. identical samples) could be re-split indefinitely and never terminate.
                if ($node->isPoint()) {
                    $current->attachRight(Neighborhood::terminate($right));
                } else {
                    $current->attachRight($node);

                    $stack[] = $node;
                }
            } elseif (!$right->empty()) {
                $current->attachRight(Neighborhood::terminate($right));
            }
        }
    }

    /**
     * Run a k nearest neighbors search and return the samples, labels, and distances in a 3-tuple.
     *
     * The search starts at the leaf on the sample's root-to-leaf path and backtracks toward the
     * root. Along the way it only descends into unvisited children whose bounding sides are closer
     * than the current k-th best distance. Results are sorted by ascending distance and contain at
     * most k entries (fewer if the tree holds fewer samples; none if the tree is bare).
     *
     * @internal
     *
     * @param list<int|float> $sample
     * @param int $k The number of neighbors to return, at least 1.
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     * @return array{array[],mixed[],float[]}
     */
    public function nearest(array $sample, int $k = 1) : array
    {
        if ($k < 1) {
            throw new InvalidArgumentException('K must be'
                . " greater than 0, $k given.");
        }

        $visited = new SplObjectStorage();

        $samples = $labels = $distances = [];

        // Popping from the path processes the leaf first and then each ancestor in turn.
        $stack = $this->path($sample);

        while ($current = array_pop($stack)) {
            if ($current instanceof Box) {
                // Distance of the current k-th best candidate, or INF until k candidates are known.
                $radius = $distances[$k - 1] ?? INF;

                foreach ($current->children() as $child) {
                    if (!$visited->contains($child)) {
                        if ($child instanceof Hypercube) {
                            foreach ($child->sides() as $side) {
                                $distance = $this->kernel->compute($sample, $side);

                                if ($distance < $radius) {
                                    $stack[] = $child;

                                    continue 2;
                                }
                            }
                        }

                        // Pruned: no side of the child is closer than the current k-th best distance.
                        // NOTE(review): children that are not Hypercubes are pruned too - presumably
                        // every Box and Neighborhood implements Hypercube; confirm.
                        $visited->attach($child);
                    }
                }

                $visited->attach($current);

                continue;
            }

            if ($current instanceof Neighborhood) {
                $dataset = $current->dataset();

                foreach ($dataset->samples() as $neighbor) {
                    $distances[] = $this->kernel->compute($sample, $neighbor);
                }

                $samples = array_merge($samples, $dataset->samples());
                $labels = array_merge($labels, $dataset->labels());

                // Sort all three lists together by ascending distance, then keep only the best k.
                array_multisort($distances, $samples, $labels);

                if (count($samples) > $k) {
                    $samples = array_slice($samples, 0, $k);
                    $labels = array_slice($labels, 0, $k);
                    $distances = array_slice($distances, 0, $k);
                }

                $visited->attach($current);
            }
        }

        return [$samples, $labels, $distances];
    }

    /**
     * Run a range search over every cluster within radius and return the samples, labels and distances in a 3-tuple.
     *
     * Starting at the root, a child is explored only if at least one of its sides lies within the
     * radius (inclusive). Every sample in an explored neighborhood whose distance is <= radius is
     * returned; the results are in traversal order, not sorted by distance.
     *
     * @internal
     *
     * @param list<int|float> $sample
     * @param float $radius The inclusive search radius, must be non-negative.
     * @throws \Rubix\ML\Exceptions\InvalidArgumentException
     * @return array{array[],mixed[],float[]}
     */
    public function range(array $sample, float $radius) : array
    {
        if ($radius < 0.0) {
            throw new InvalidArgumentException('Radius must be'
                . " non-negative, $radius given.");
        }

        $samples = $labels = $distances = [];

        /** @var list<Box|Neighborhood> */
        $stack = [$this->root];

        while ($current = array_pop($stack)) {
            if ($current instanceof Box) {
                foreach ($current->children() as $child) {
                    if ($child instanceof Hypercube) {
                        foreach ($child->sides() as $side) {
                            $distance = $this->kernel->compute($sample, $side);

                            if ($distance <= $radius) {
                                $stack[] = $child;

                                continue 2;
                            }
                        }
                    }
                }

                continue;
            }

            if ($current instanceof Neighborhood) {
                $dataset = $current->dataset();

                foreach ($dataset->samples() as $i => $neighbor) {
                    $distance = $this->kernel->compute($sample, $neighbor);

                    if ($distance <= $radius) {
                        $samples[] = $neighbor;
                        $labels[] = $dataset->label($i);
                        $distances[] = $distance;
                    }
                }
            }
        }

        return [$samples, $labels, $distances];
    }

    /**
     * Destroy the tree by dropping the reference to its root node.
     *
     * @internal
     */
    public function destroy() : void
    {
        $this->root = null;
    }

    /**
     * Return the path of a sample taken from the root node to a leaf node in an array.
     *
     * At each box the sample goes left when its value in the box's split column is less than
     * the split value, otherwise right. The path stops when a missing child or a non-box node is
     * reached. When the tree is bare the path is [null].
     *
     * @param list<int|float> $sample
     * @return list<\Rubix\ML\Graph\Nodes\BinaryNode|null>
     */
    protected function path(array $sample) : array
    {
        $current = $this->root;

        $path = [$current];

        while ($current instanceof Box) {
            if ($sample[$current->column()] < $current->value()) {
                $current = $current->left();
            } else {
                $current = $current->right();
            }

            if ($current) {
                $path[] = $current;
            }
        }

        return $path;
    }

    /**
     * Return the string representation of the object.
     *
     * @return string
     */
    public function __toString() : string
    {
        return "K-d Tree (max_leaf_size: {$this->maxLeafSize}, kernel: {$this->kernel})";
    }
}
349