1 /**
2  * @file core/tree/binary_space_tree/binary_space_tree_impl.hpp
3  *
4  * Implementation of generalized space partitioning tree.
5  *
6  * mlpack is free software; you may redistribute it and/or modify it under the
7  * terms of the 3-clause BSD license.  You should have received a copy of the
8  * 3-clause BSD license along with mlpack.  If not, see
9  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
10  */
11 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP
12 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP
13 
14 // In case it wasn't included already for some reason.
15 #include "binary_space_tree.hpp"
16 
17 #include <mlpack/core/util/log.hpp>
18 #include <queue>
19 
20 namespace mlpack {
21 namespace tree {
22 
23 // Each of these overloads is kept as a separate function to keep the overhead
24 // from the two std::vectors out, if possible.
25 template<typename MetricType,
26          typename StatisticType,
27          typename MatType,
28          template<typename BoundMetricType, typename...> class BoundType,
29          template<typename SplitBoundType, typename SplitMatType>
30              class SplitType>
31 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(const MatType & data,const size_t maxLeafSize)32 BinarySpaceTree(
33     const MatType& data,
34     const size_t maxLeafSize) :
35     left(NULL),
36     right(NULL),
37     parent(NULL),
38     begin(0), /* This root node starts at index 0, */
39     count(data.n_cols), /* and spans all of the dataset. */
40     bound(data.n_rows),
41     parentDistance(0), // Parent distance for the root is 0: it has no parent.
42     dataset(new MatType(data)) // Copies the dataset.
43 {
44   // Do the actual splitting of this node.
45   SplitType<BoundType<MetricType>, MatType> splitter;
46   SplitNode(maxLeafSize, splitter);
47 
48   // Create the statistic depending on if we are a leaf or not.
49   stat = StatisticType(*this);
50 }
51 
52 template<typename MetricType,
53          typename StatisticType,
54          typename MatType,
55          template<typename BoundMetricType, typename...> class BoundType,
56          template<typename SplitBoundType, typename SplitMatType>
57              class SplitType>
58 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(const MatType & data,std::vector<size_t> & oldFromNew,const size_t maxLeafSize)59 BinarySpaceTree(
60     const MatType& data,
61     std::vector<size_t>& oldFromNew,
62     const size_t maxLeafSize) :
63     left(NULL),
64     right(NULL),
65     parent(NULL),
66     begin(0),
67     count(data.n_cols),
68     bound(data.n_rows),
69     parentDistance(0), // Parent distance for the root is 0: it has no parent.
70     dataset(new MatType(data)) // Copies the dataset.
71 {
72   // Initialize oldFromNew correctly.
73   oldFromNew.resize(data.n_cols);
74   for (size_t i = 0; i < data.n_cols; ++i)
75     oldFromNew[i] = i; // Fill with unharmed indices.
76 
77   // Now do the actual splitting.
78   SplitType<BoundType<MetricType>, MatType> splitter;
79   SplitNode(oldFromNew, maxLeafSize, splitter);
80 
81   // Create the statistic depending on if we are a leaf or not.
82   stat = StatisticType(*this);
83 }
84 
85 template<typename MetricType,
86          typename StatisticType,
87          typename MatType,
88          template<typename BoundMetricType, typename...> class BoundType,
89          template<typename SplitBoundType, typename SplitMatType>
90              class SplitType>
91 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(const MatType & data,std::vector<size_t> & oldFromNew,std::vector<size_t> & newFromOld,const size_t maxLeafSize)92 BinarySpaceTree(
93     const MatType& data,
94     std::vector<size_t>& oldFromNew,
95     std::vector<size_t>& newFromOld,
96     const size_t maxLeafSize) :
97     left(NULL),
98     right(NULL),
99     parent(NULL),
100     begin(0),
101     count(data.n_cols),
102     bound(data.n_rows),
103     parentDistance(0), // Parent distance for the root is 0: it has no parent.
104     dataset(new MatType(data)) // Copies the dataset.
105 {
106   // Initialize the oldFromNew vector correctly.
107   oldFromNew.resize(data.n_cols);
108   for (size_t i = 0; i < data.n_cols; ++i)
109     oldFromNew[i] = i; // Fill with unharmed indices.
110 
111   // Now do the actual splitting.
112   SplitType<BoundType<MetricType>, MatType> splitter;
113   SplitNode(oldFromNew, maxLeafSize, splitter);
114 
115   // Create the statistic depending on if we are a leaf or not.
116   stat = StatisticType(*this);
117 
118   // Map the newFromOld indices correctly.
119   newFromOld.resize(data.n_cols);
120   for (size_t i = 0; i < data.n_cols; ++i)
121     newFromOld[oldFromNew[i]] = i;
122 }
123 
124 template<typename MetricType,
125          typename StatisticType,
126          typename MatType,
127          template<typename BoundMetricType, typename...> class BoundType,
128          template<typename SplitBoundType, typename SplitMatType>
129              class SplitType>
130 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(MatType && data,const size_t maxLeafSize)131 BinarySpaceTree(MatType&& data, const size_t maxLeafSize) :
132     left(NULL),
133     right(NULL),
134     parent(NULL),
135     begin(0),
136     count(data.n_cols),
137     bound(data.n_rows),
138     parentDistance(0), // Parent distance for the root is 0: it has no parent.
139     dataset(new MatType(std::move(data)))
140 {
141   // Do the actual splitting of this node.
142   SplitType<BoundType<MetricType>, MatType> splitter;
143   SplitNode(maxLeafSize, splitter);
144 
145   // Create the statistic depending on if we are a leaf or not.
146   stat = StatisticType(*this);
147 }
148 
149 template<typename MetricType,
150          typename StatisticType,
151          typename MatType,
152          template<typename BoundMetricType, typename...> class BoundType,
153          template<typename SplitBoundType, typename SplitMatType>
154              class SplitType>
155 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(MatType && data,std::vector<size_t> & oldFromNew,const size_t maxLeafSize)156 BinarySpaceTree(
157     MatType&& data,
158     std::vector<size_t>& oldFromNew,
159     const size_t maxLeafSize) :
160     left(NULL),
161     right(NULL),
162     parent(NULL),
163     begin(0),
164     count(data.n_cols),
165     bound(data.n_rows),
166     parentDistance(0), // Parent distance for the root is 0: it has no parent.
167     dataset(new MatType(std::move(data)))
168 {
169   // Initialize oldFromNew correctly.
170   oldFromNew.resize(dataset->n_cols);
171   for (size_t i = 0; i < dataset->n_cols; ++i)
172     oldFromNew[i] = i; // Fill with unharmed indices.
173 
174   // Now do the actual splitting.
175   SplitType<BoundType<MetricType>, MatType> splitter;
176   SplitNode(oldFromNew, maxLeafSize, splitter);
177 
178   // Create the statistic depending on if we are a leaf or not.
179   stat = StatisticType(*this);
180 }
181 
182 template<typename MetricType,
183          typename StatisticType,
184          typename MatType,
185          template<typename BoundMetricType, typename...> class BoundType,
186          template<typename SplitBoundType, typename SplitMatType>
187              class SplitType>
188 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(MatType && data,std::vector<size_t> & oldFromNew,std::vector<size_t> & newFromOld,const size_t maxLeafSize)189 BinarySpaceTree(
190     MatType&& data,
191     std::vector<size_t>& oldFromNew,
192     std::vector<size_t>& newFromOld,
193     const size_t maxLeafSize) :
194     left(NULL),
195     right(NULL),
196     parent(NULL),
197     begin(0),
198     count(data.n_cols),
199     bound(data.n_rows),
200     parentDistance(0), // Parent distance for the root is 0: it has no parent.
201     dataset(new MatType(std::move(data)))
202 {
203   // Initialize the oldFromNew vector correctly.
204   oldFromNew.resize(dataset->n_cols);
205   for (size_t i = 0; i < dataset->n_cols; ++i)
206     oldFromNew[i] = i; // Fill with unharmed indices.
207 
208   // Now do the actual splitting.
209   SplitType<BoundType<MetricType>, MatType> splitter;
210   SplitNode(oldFromNew, maxLeafSize, splitter);
211 
212   // Create the statistic depending on if we are a leaf or not.
213   stat = StatisticType(*this);
214 
215   // Map the newFromOld indices correctly.
216   newFromOld.resize(dataset->n_cols);
217   for (size_t i = 0; i < dataset->n_cols; ++i)
218     newFromOld[oldFromNew[i]] = i;
219 }
220 
221 template<typename MetricType,
222          typename StatisticType,
223          typename MatType,
224          template<typename BoundMetricType, typename...> class BoundType,
225          template<typename SplitBoundType, typename SplitMatType>
226              class SplitType>
227 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(BinarySpaceTree * parent,const size_t begin,const size_t count,SplitType<BoundType<MetricType>,MatType> & splitter,const size_t maxLeafSize)228 BinarySpaceTree(
229     BinarySpaceTree* parent,
230     const size_t begin,
231     const size_t count,
232     SplitType<BoundType<MetricType>, MatType>& splitter,
233     const size_t maxLeafSize) :
234     left(NULL),
235     right(NULL),
236     parent(parent),
237     begin(begin),
238     count(count),
239     bound(parent->Dataset().n_rows),
240     dataset(&parent->Dataset()) // Point to the parent's dataset.
241 {
242   // Perform the actual splitting.
243   SplitNode(maxLeafSize, splitter);
244 
245   // Create the statistic depending on if we are a leaf or not.
246   stat = StatisticType(*this);
247 }
248 
249 template<typename MetricType,
250          typename StatisticType,
251          typename MatType,
252          template<typename BoundMetricType, typename...> class BoundType,
253          template<typename SplitBoundType, typename SplitMatType>
254              class SplitType>
255 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(BinarySpaceTree * parent,const size_t begin,const size_t count,std::vector<size_t> & oldFromNew,SplitType<BoundType<MetricType>,MatType> & splitter,const size_t maxLeafSize)256 BinarySpaceTree(
257     BinarySpaceTree* parent,
258     const size_t begin,
259     const size_t count,
260     std::vector<size_t>& oldFromNew,
261     SplitType<BoundType<MetricType>, MatType>& splitter,
262     const size_t maxLeafSize) :
263     left(NULL),
264     right(NULL),
265     parent(parent),
266     begin(begin),
267     count(count),
268     bound(parent->Dataset().n_rows),
269     dataset(&parent->Dataset())
270 {
271   // Hopefully the vector is initialized correctly!  We can't check that
272   // entirely but we can do a minor sanity check.
273   assert(oldFromNew.size() == dataset->n_cols);
274 
275   // Perform the actual splitting.
276   SplitNode(oldFromNew, maxLeafSize, splitter);
277 
278   // Create the statistic depending on if we are a leaf or not.
279   stat = StatisticType(*this);
280 }
281 
282 template<typename MetricType,
283          typename StatisticType,
284          typename MatType,
285          template<typename BoundMetricType, typename...> class BoundType,
286          template<typename SplitBoundType, typename SplitMatType>
287              class SplitType>
288 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(BinarySpaceTree * parent,const size_t begin,const size_t count,std::vector<size_t> & oldFromNew,std::vector<size_t> & newFromOld,SplitType<BoundType<MetricType>,MatType> & splitter,const size_t maxLeafSize)289 BinarySpaceTree(
290     BinarySpaceTree* parent,
291     const size_t begin,
292     const size_t count,
293     std::vector<size_t>& oldFromNew,
294     std::vector<size_t>& newFromOld,
295     SplitType<BoundType<MetricType>, MatType>& splitter,
296     const size_t maxLeafSize) :
297     left(NULL),
298     right(NULL),
299     parent(parent),
300     begin(begin),
301     count(count),
302     bound(parent->Dataset()->n_rows),
303     dataset(&parent->Dataset())
304 {
305   // Hopefully the vector is initialized correctly!  We can't check that
306   // entirely but we can do a minor sanity check.
307   Log::Assert(oldFromNew.size() == dataset->n_cols);
308 
309   // Perform the actual splitting.
310   SplitNode(oldFromNew, maxLeafSize, splitter);
311 
312   // Create the statistic depending on if we are a leaf or not.
313   stat = StatisticType(*this);
314 
315   // Map the newFromOld indices correctly.
316   newFromOld.resize(dataset->n_cols);
317   for (size_t i = 0; i < dataset->n_cols; ++i)
318     newFromOld[oldFromNew[i]] = i;
319 }
320 
321 /**
322  * Create a binary space tree by copying the other tree.  Be careful!  This can
323  * take a long time and use a lot of memory.
324  */
325 template<typename MetricType,
326          typename StatisticType,
327          typename MatType,
328          template<typename BoundMetricType, typename...> class BoundType,
329          template<typename SplitBoundType, typename SplitMatType>
330              class SplitType>
331 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(const BinarySpaceTree & other)332 BinarySpaceTree(
333     const BinarySpaceTree& other) :
334     left(NULL),
335     right(NULL),
336     parent(other.parent),
337     begin(other.begin),
338     count(other.count),
339     bound(other.bound),
340     stat(other.stat),
341     parentDistance(other.parentDistance),
342     furthestDescendantDistance(other.furthestDescendantDistance),
343     minimumBoundDistance(other.minimumBoundDistance),
344     // Copy matrix, but only if we are the root.
345     dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL)
346 {
347   // Create left and right children (if any).
348   if (other.Left())
349   {
350     left = new BinarySpaceTree(*other.Left());
351     left->Parent() = this; // Set parent to this, not other tree.
352   }
353 
354   if (other.Right())
355   {
356     right = new BinarySpaceTree(*other.Right());
357     right->Parent() = this; // Set parent to this, not other tree.
358   }
359 
360   // Propagate matrix, but only if we are the root.
361   if (parent == NULL)
362   {
363     std::queue<BinarySpaceTree*> queue;
364     if (left)
365       queue.push(left);
366     if (right)
367       queue.push(right);
368     while (!queue.empty())
369     {
370       BinarySpaceTree* node = queue.front();
371       queue.pop();
372 
373       node->dataset = dataset;
374       if (node->left)
375         queue.push(node->left);
376       if (node->right)
377         queue.push(node->right);
378     }
379   }
380 }
381 
382 /**
383  * Copy assignment operator: copy the given other tree.
384  */
385 template<typename MetricType,
386          typename StatisticType,
387          typename MatType,
388          template<typename BoundMetricType, typename...> class BoundType,
389          template<typename SplitBoundType, typename SplitMatType>
390              class SplitType>
391 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
392 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
operator =(const BinarySpaceTree & other)393 operator=(const BinarySpaceTree& other)
394 {
395   // Return if it's the same tree.
396   if (this == &other)
397     return *this;
398 
399   // Freeing memory that will not be used anymore.
400   delete dataset;
401   delete left;
402   delete right;
403 
404   left = NULL;
405   right = NULL;
406   parent = other.Parent();
407   begin = other.Begin();
408   count = other.Count();
409   bound = other.bound;
410   stat = other.stat;
411   parentDistance = other.ParentDistance();
412   furthestDescendantDistance = other.FurthestDescendantDistance();
413   minimumBoundDistance = other.MinimumBoundDistance();
414   // Copy matrix, but only if we are the root.
415   dataset = ((other.parent == NULL) ? new MatType(*other.dataset) : NULL);
416 
417   // Create left and right children (if any).
418   if (other.Left())
419   {
420     left = new BinarySpaceTree(*other.Left());
421     left->Parent() = this; // Set parent to this, not other tree.
422   }
423 
424   if (other.Right())
425   {
426     right = new BinarySpaceTree(*other.Right());
427     right->Parent() = this; // Set parent to this, not other tree.
428   }
429 
430   // Propagate matrix, but only if we are the root.
431   if (parent == NULL)
432   {
433     std::queue<BinarySpaceTree*> queue;
434     if (left)
435       queue.push(left);
436     if (right)
437       queue.push(right);
438     while (!queue.empty())
439     {
440       BinarySpaceTree* node = queue.front();
441       queue.pop();
442 
443       node->dataset = dataset;
444       if (node->left)
445         queue.push(node->left);
446       if (node->right)
447         queue.push(node->right);
448     }
449   }
450 
451   return *this;
452 }
453 
454 /**
455  * Move assignment operator: take ownership of the given tree.
456  */
457 template<typename MetricType,
458          typename StatisticType,
459          typename MatType,
460          template<typename BoundMetricType, typename...> class BoundType,
461          template<typename SplitBoundType, typename SplitMatType>
462              class SplitType>
463 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
464 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
operator =(BinarySpaceTree && other)465 operator=(BinarySpaceTree&& other)
466 {
467   // Return if it's the same tree.
468   if (this == &other)
469     return *this;
470 
471   // Freeing memory that will not be used anymore.
472   delete dataset;
473   delete left;
474   delete right;
475 
476   parent = other.Parent();
477   left = other.Left();
478   right = other.Right();
479   begin = other.Begin();
480   count = other.Count();
481   bound = std::move(other.bound);
482   stat = std::move(other.stat);
483   parentDistance = other.ParentDistance();
484   furthestDescendantDistance = other.FurthestDescendantDistance();
485   minimumBoundDistance = other.MinimumBoundDistance();
486   dataset = other.dataset;
487 
488   other.left = NULL;
489   other.right = NULL;
490   other.parent = NULL;
491   other.begin = 0;
492   other.count = 0;
493   other.parentDistance = 0.0;
494   other.furthestDescendantDistance = 0.0;
495   other.minimumBoundDistance = 0.0;
496   other.dataset = NULL;
497 
498   return *this;
499 }
500 
501 
502 /**
503  * Move constructor.
504  */
505 template<typename MetricType,
506          typename StatisticType,
507          typename MatType,
508          template<typename BoundMetricType, typename...> class BoundType,
509          template<typename SplitBoundType, typename SplitMatType>
510              class SplitType>
511 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(BinarySpaceTree && other)512 BinarySpaceTree(BinarySpaceTree&& other) :
513     left(other.left),
514     right(other.right),
515     parent(other.parent),
516     begin(other.begin),
517     count(other.count),
518     bound(std::move(other.bound)),
519     stat(std::move(other.stat)),
520     parentDistance(other.parentDistance),
521     furthestDescendantDistance(other.furthestDescendantDistance),
522     minimumBoundDistance(other.minimumBoundDistance),
523     dataset(other.dataset)
524 {
525   // Now we are a clone of the other tree.  But we must also clear the other
526   // tree's contents, so it doesn't delete anything when it is destructed.
527   other.left = NULL;
528   other.right = NULL;
529   other.parent = NULL;
530   other.begin = 0;
531   other.count = 0;
532   other.parentDistance = 0.0;
533   other.furthestDescendantDistance = 0.0;
534   other.minimumBoundDistance = 0.0;
535   other.dataset = NULL;
536 
537   // Set new parent.
538   if (left)
539     left->parent = this;
540   if (right)
541     right->parent = this;
542 }
543 
544 /**
545  * Initialize the tree from an archive.
546  */
547 template<typename MetricType,
548          typename StatisticType,
549          typename MatType,
550          template<typename BoundMetricType, typename...> class BoundType,
551          template<typename SplitBoundType, typename SplitMatType>
552              class SplitType>
553 template<typename Archive>
554 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(Archive & ar,const typename std::enable_if_t<Archive::is_loading::value> *)555 BinarySpaceTree(
556     Archive& ar,
557     const typename std::enable_if_t<Archive::is_loading::value>*) :
558     BinarySpaceTree() // Create an empty BinarySpaceTree.
559 {
560   // We've delegated to the constructor which gives us an empty tree, and now we
561   // can serialize from it.
562   ar >> BOOST_SERIALIZATION_NVP(*this);
563 }
564 
565 /**
566  * Deletes this node, deallocating the memory for the children and calling their
567  * destructors in turn.  This will invalidate any pointers or references to any
568  * nodes which are children of this one.
569  */
570 template<typename MetricType,
571          typename StatisticType,
572          typename MatType,
573          template<typename BoundMetricType, typename...> class BoundType,
574          template<typename SplitBoundType, typename SplitMatType>
575              class SplitType>
576 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
~BinarySpaceTree()577     ~BinarySpaceTree()
578 {
579   delete left;
580   delete right;
581 
582   // If we're the root, delete the matrix.
583   if (!parent)
584     delete dataset;
585 }
586 
587 template<typename MetricType,
588          typename StatisticType,
589          typename MatType,
590          template<typename BoundMetricType, typename...> class BoundType,
591          template<typename SplitBoundType, typename SplitMatType>
592              class SplitType>
593 inline bool BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
IsLeaf() const594                             SplitType>::IsLeaf() const
595 {
596   return !left;
597 }
598 
599 /**
600  * Returns the number of children in this node.
601  */
602 template<typename MetricType,
603          typename StatisticType,
604          typename MatType,
605          template<typename BoundMetricType, typename...> class BoundType,
606          template<typename SplitBoundType, typename SplitMatType>
607              class SplitType>
608 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
NumChildren() const609                               SplitType>::NumChildren() const
610 {
611   if (left && right)
612     return 2;
613   if (left)
614     return 1;
615 
616   return 0;
617 }
618 
619 /**
620  * Return the index of the nearest child node to the given query point.  If
621  * this is a leaf node, it will return NumChildren() (invalid index).
622  */
623 template<typename MetricType,
624          typename StatisticType,
625          typename MatType,
626          template<typename BoundMetricType, typename...> class BoundType,
627          template<typename SplitBoundType, typename SplitMatType>
628              class SplitType>
629 template<typename VecType>
630 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
GetNearestChild(const VecType & point,typename std::enable_if_t<IsVector<VecType>::value> *)631     SplitType>::GetNearestChild(
632     const VecType& point,
633     typename std::enable_if_t<IsVector<VecType>::value>*)
634 {
635   if (IsLeaf() || !left || !right)
636     return 0;
637 
638   if (left->MinDistance(point) <= right->MinDistance(point))
639     return 0;
640   return 1;
641 }
642 
643 /**
644  * Return the index of the furthest child node to the given query point.  If
645  * this is a leaf node, it will return NumChildren() (invalid index).
646  */
647 template<typename MetricType,
648          typename StatisticType,
649          typename MatType,
650          template<typename BoundMetricType, typename...> class BoundType,
651          template<typename SplitBoundType, typename SplitMatType>
652              class SplitType>
653 template<typename VecType>
654 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
GetFurthestChild(const VecType & point,typename std::enable_if_t<IsVector<VecType>::value> *)655     SplitType>::GetFurthestChild(
656     const VecType& point,
657     typename std::enable_if_t<IsVector<VecType>::value>*)
658 {
659   if (IsLeaf() || !left || !right)
660     return 0;
661 
662   if (left->MaxDistance(point) > right->MaxDistance(point))
663     return 0;
664   return 1;
665 }
666 
667 /**
668  * Return the index of the nearest child node to the given query node.  If it
669  * can't decide, it will return NumChildren() (invalid index).
670  */
671 template<typename MetricType,
672          typename StatisticType,
673          typename MatType,
674          template<typename BoundMetricType, typename...> class BoundType,
675          template<typename SplitBoundType, typename SplitMatType>
676              class SplitType>
677 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
GetNearestChild(const BinarySpaceTree & queryNode)678     SplitType>::GetNearestChild(const BinarySpaceTree& queryNode)
679 {
680   if (IsLeaf() || !left || !right)
681     return 0;
682 
683   ElemType leftDist = left->MinDistance(queryNode);
684   ElemType rightDist = right->MinDistance(queryNode);
685   if (leftDist < rightDist)
686     return 0;
687   if (rightDist < leftDist)
688     return 1;
689   return NumChildren();
690 }
691 
692 /**
693  * Return the index of the furthest child node to the given query node.  If it
694  * can't decide, it will return NumChildren() (invalid index).
695  */
696 template<typename MetricType,
697          typename StatisticType,
698          typename MatType,
699          template<typename BoundMetricType, typename...> class BoundType,
700          template<typename SplitBoundType, typename SplitMatType>
701              class SplitType>
702 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
GetFurthestChild(const BinarySpaceTree & queryNode)703     SplitType>::GetFurthestChild(const BinarySpaceTree& queryNode)
704 {
705   if (IsLeaf() || !left || !right)
706     return 0;
707 
708   ElemType leftDist = left->MaxDistance(queryNode);
709   ElemType rightDist = right->MaxDistance(queryNode);
710   if (leftDist > rightDist)
711     return 0;
712   if (rightDist > leftDist)
713     return 1;
714   return NumChildren();
715 }
716 
717 /**
718  * Return a bound on the furthest point in the node from the center.  This
719  * returns 0 unless the node is a leaf.
720  */
721 template<typename MetricType,
722          typename StatisticType,
723          typename MatType,
724          template<typename BoundMetricType, typename...> class BoundType,
725          template<typename SplitBoundType, typename SplitMatType>
726              class SplitType>
727 inline
728 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
729     SplitType>::ElemType
730 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
FurthestPointDistance() const731     SplitType>::FurthestPointDistance() const
732 {
733   if (!IsLeaf())
734     return 0.0;
735 
736   // Otherwise return the distance from the center to a corner of the bound.
737   return 0.5 * bound.Diameter();
738 }
739 
740 /**
741  * Return the furthest possible descendant distance.  This returns the maximum
742  * distance from the center to the edge of the bound and not the empirical
743  * quantity which is the actual furthest descendant distance.  So the actual
744  * furthest descendant distance may be less than what this method returns (but
745  * it will never be greater than this).
746  */
747 template<typename MetricType,
748          typename StatisticType,
749          typename MatType,
750          template<typename BoundMetricType, typename...> class BoundType,
751          template<typename SplitBoundType, typename SplitMatType>
752              class SplitType>
753 inline
754 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
755     SplitType>::ElemType
756 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
FurthestDescendantDistance() const757     SplitType>::FurthestDescendantDistance() const
758 {
759   return furthestDescendantDistance;
760 }
761 
762 //! Return the minimum distance from the center to any bound edge.
763 template<typename MetricType,
764          typename StatisticType,
765          typename MatType,
766          template<typename BoundMetricType, typename...> class BoundType,
767          template<typename SplitBoundType, typename SplitMatType>
768              class SplitType>
769 inline
770 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
771     SplitType>::ElemType
772 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
MinimumBoundDistance() const773     SplitType>::MinimumBoundDistance() const
774 {
775   return bound.MinWidth() / 2.0;
776 }
777 
778 /**
779  * Return the specified child.
780  */
781 template<typename MetricType,
782          typename StatisticType,
783          typename MatType,
784          template<typename BoundMetricType, typename...> class BoundType,
785          template<typename SplitBoundType, typename SplitMatType>
786              class SplitType>
787 inline BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
788                        SplitType>&
789     BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
Child(const size_t child) const790                     SplitType>::Child(const size_t child) const
791 {
792   if (child == 0)
793     return *left;
794   else
795     return *right;
796 }
797 
798 /**
799  * Return the number of points contained in this node.
800  */
801 template<typename MetricType,
802          typename StatisticType,
803          typename MatType,
804          template<typename BoundMetricType, typename...> class BoundType,
805          template<typename SplitBoundType, typename SplitMatType>
806              class SplitType>
807 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
NumPoints() const808                               SplitType>::NumPoints() const
809 {
810   if (left)
811     return 0;
812 
813   return count;
814 }
815 
816 /**
817  * Return the number of descendants contained in the node.
818  */
819 template<typename MetricType,
820          typename StatisticType,
821          typename MatType,
822          template<typename BoundMetricType, typename...> class BoundType,
823          template<typename SplitBoundType, typename SplitMatType>
824              class SplitType>
825 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
NumDescendants() const826                               SplitType>::NumDescendants() const
827 {
828   return count;
829 }
830 
831 /**
832  * Return the index of a particular descendant contained in this node.
833  */
834 template<typename MetricType,
835          typename StatisticType,
836          typename MatType,
837          template<typename BoundMetricType, typename...> class BoundType,
838          template<typename SplitBoundType, typename SplitMatType>
839              class SplitType>
840 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
Descendant(const size_t index) const841                               SplitType>::Descendant(const size_t index) const
842 {
843   return (begin + index);
844 }
845 
846 /**
847  * Return the index of a particular point contained in this node.
848  */
849 template<typename MetricType,
850          typename StatisticType,
851          typename MatType,
852          template<typename BoundMetricType, typename...> class BoundType,
853          template<typename SplitBoundType, typename SplitMatType>
854              class SplitType>
855 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
Point(const size_t index) const856                               SplitType>::Point(const size_t index) const
857 {
858   return (begin + index);
859 }
860 
861 template<typename MetricType,
862          typename StatisticType,
863          typename MatType,
864          template<typename BoundMetricType, typename...> class BoundType,
865          template<typename SplitBoundType, typename SplitMatType>
866              class SplitType>
867 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
SplitNode(const size_t maxLeafSize,SplitType<BoundType<MetricType>,MatType> & splitter)868     SplitNode(const size_t maxLeafSize,
869               SplitType<BoundType<MetricType>, MatType>& splitter)
870 {
871   // We need to expand the bounds of this node properly.
872   UpdateBound(bound);
873 
874   // Calculate the furthest descendant distance.
875   furthestDescendantDistance = 0.5 * bound.Diameter();
876 
877   // Now, check if we need to split at all.
878   if (count <= maxLeafSize)
879     return; // We can't split this.
880 
881   // splitCol denotes the two partitions of the dataset after the split. The
882   // points on its left go to the left child and the others go to the right
883   // child.
884   size_t splitCol;
885 
886   // Find the partition of the node. This method does not perform the split.
887   typename Split::SplitInfo splitInfo;
888 
889   const bool split = splitter.SplitNode(bound, *dataset, begin, count,
890       splitInfo);
891 
892   // The node may not be always split. For instance, if all the points are the
893   // same, we can't split them.
894   if (!split)
895     return;
896 
897   // Perform the actual splitting.  This will order the dataset such that
898   // points that belong to the left subtree are on the left of splitCol, and
899   // points from the right subtree are on the right side of splitCol.
900   splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo);
901 
902   assert(splitCol > begin);
903   assert(splitCol < begin + count);
904 
905   // Now that we know the split column, we will recursively split the children
906   // by calling their constructors (which perform this splitting process).
907   left = new BinarySpaceTree(this, begin, splitCol - begin, splitter,
908       maxLeafSize);
909   right = new BinarySpaceTree(this, splitCol, begin + count - splitCol,
910       splitter, maxLeafSize);
911 
912   // Calculate parent distances for those two nodes.
913   arma::vec center, leftCenter, rightCenter;
914   Center(center);
915   left->Center(leftCenter);
916   right->Center(rightCenter);
917 
918   const ElemType leftParentDistance = bound.Metric().Evaluate(center,
919       leftCenter);
920   const ElemType rightParentDistance = bound.Metric().Evaluate(center,
921       rightCenter);
922 
923   left->ParentDistance() = leftParentDistance;
924   right->ParentDistance() = rightParentDistance;
925 }
926 
927 template<typename MetricType,
928          typename StatisticType,
929          typename MatType,
930          template<typename BoundMetricType, typename...> class BoundType,
931          template<typename SplitBoundType, typename SplitMatType>
932              class SplitType>
933 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
SplitNode(std::vector<size_t> & oldFromNew,const size_t maxLeafSize,SplitType<BoundType<MetricType>,MatType> & splitter)934 SplitNode(std::vector<size_t>& oldFromNew,
935           const size_t maxLeafSize,
936           SplitType<BoundType<MetricType>, MatType>& splitter)
937 {
938   // We need to expand the bounds of this node properly.
939   UpdateBound(bound);
940 
941   // Calculate the furthest descendant distance.
942   furthestDescendantDistance = 0.5 * bound.Diameter();
943 
944   // First, check if we need to split at all.
945   if (count <= maxLeafSize)
946     return; // We can't split this.
947 
948   // splitCol denotes the two partitions of the dataset after the split. The
949   // points on its left go to the left child and the others go to the right
950   // child.
951   size_t splitCol;
952 
953   // Find the partition of the node. This method does not perform the split.
954   typename Split::SplitInfo splitInfo;
955 
956   const bool split = splitter.SplitNode(bound, *dataset, begin, count,
957       splitInfo);
958 
959   // The node may not be always split. For instance, if all the points are the
960   // same, we can't split them.
961   if (!split)
962     return;
963 
964   // Perform the actual splitting.  This will order the dataset such that
965   // points that belong to the left subtree are on the left of splitCol, and
966   // points from the right subtree are on the right side of splitCol.
967   splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo,
968       oldFromNew);
969 
970   assert(splitCol > begin);
971   assert(splitCol < begin + count);
972 
973   // Now that we know the split column, we will recursively split the children
974   // by calling their constructors (which perform this splitting process).
975   left = new BinarySpaceTree(this, begin, splitCol - begin, oldFromNew,
976       splitter, maxLeafSize);
977   right = new BinarySpaceTree(this, splitCol, begin + count - splitCol,
978       oldFromNew, splitter, maxLeafSize);
979 
980   // Calculate parent distances for those two nodes.
981   arma::vec center, leftCenter, rightCenter;
982   Center(center);
983   left->Center(leftCenter);
984   right->Center(rightCenter);
985 
986   const ElemType leftParentDistance = bound.Metric().Evaluate(center,
987       leftCenter);
988   const ElemType rightParentDistance = bound.Metric().Evaluate(center,
989       rightCenter);
990 
991   left->ParentDistance() = leftParentDistance;
992   right->ParentDistance() = rightParentDistance;
993 }
994 
995 template<typename MetricType,
996          typename StatisticType,
997          typename MatType,
998          template<typename BoundMetricType, typename...> class BoundType,
999          template<typename SplitBoundType, typename SplitMatType>
1000              class SplitType>
1001 template<typename BoundType2>
1002 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
UpdateBound(BoundType2 & boundToUpdate)1003 UpdateBound(BoundType2& boundToUpdate)
1004 {
1005   if (count > 0)
1006     boundToUpdate |= dataset->cols(begin, begin + count - 1);
1007 }
1008 
1009 template<typename MetricType,
1010          typename StatisticType,
1011          typename MatType,
1012          template<typename BoundMetricType, typename...> class BoundType,
1013          template<typename SplitBoundType, typename SplitMatType>
1014              class SplitType>
1015 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
UpdateBound(bound::HollowBallBound<MetricType> & boundToUpdate)1016 UpdateBound(bound::HollowBallBound<MetricType>& boundToUpdate)
1017 {
1018   if (!parent)
1019   {
1020     if (count > 0)
1021       boundToUpdate |= dataset->cols(begin, begin + count - 1);
1022     return;
1023   }
1024 
1025   if (parent->left != NULL && parent->left != this)
1026   {
1027     boundToUpdate.HollowCenter() = parent->left->bound.Center();
1028     boundToUpdate.InnerRadius() = std::numeric_limits<ElemType>::max();
1029   }
1030 
1031   if (count > 0)
1032     boundToUpdate |= dataset->cols(begin, begin + count - 1);
1033 }
1034 
1035 // Default constructor (private), for boost::serialization.
1036 template<typename MetricType,
1037          typename StatisticType,
1038          typename MatType,
1039          template<typename BoundMetricType, typename...> class BoundType,
1040          template<typename SplitBoundType, typename SplitMatType>
1041              class SplitType>
1042 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree()1043     BinarySpaceTree() :
1044     left(NULL),
1045     right(NULL),
1046     parent(NULL),
1047     begin(0),
1048     count(0),
1049     stat(*this),
1050     parentDistance(0),
1051     furthestDescendantDistance(0),
1052     dataset(NULL)
1053 {
1054   // Nothing to do.
1055 }
1056 
1057 /**
1058  * Serialize the tree.
1059  */
1060 template<typename MetricType,
1061          typename StatisticType,
1062          typename MatType,
1063          template<typename BoundMetricType, typename...> class BoundType,
1064          template<typename SplitBoundType, typename SplitMatType>
1065              class SplitType>
1066 template<typename Archive>
1067 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
serialize(Archive & ar,const unsigned int)1068     serialize(Archive& ar, const unsigned int /* version */)
1069 {
1070   // If we're loading, and we have children, they need to be deleted.
1071   if (Archive::is_loading::value)
1072   {
1073     if (left)
1074       delete left;
1075     if (right)
1076       delete right;
1077     if (!parent)
1078       delete dataset;
1079 
1080     parent = NULL;
1081     left = NULL;
1082     right = NULL;
1083   }
1084 
1085   ar & BOOST_SERIALIZATION_NVP(begin);
1086   ar & BOOST_SERIALIZATION_NVP(count);
1087   ar & BOOST_SERIALIZATION_NVP(bound);
1088   ar & BOOST_SERIALIZATION_NVP(stat);
1089 
1090   ar & BOOST_SERIALIZATION_NVP(parentDistance);
1091   ar & BOOST_SERIALIZATION_NVP(furthestDescendantDistance);
1092   ar & BOOST_SERIALIZATION_NVP(dataset);
1093 
1094   // Save children last; otherwise boost::serialization gets confused.
1095   bool hasLeft = (left != NULL);
1096   bool hasRight = (right != NULL);
1097 
1098   ar & BOOST_SERIALIZATION_NVP(hasLeft);
1099   ar & BOOST_SERIALIZATION_NVP(hasRight);
1100   if (hasLeft)
1101     ar & BOOST_SERIALIZATION_NVP(left);
1102   if (hasRight)
1103     ar & BOOST_SERIALIZATION_NVP(right);
1104 
1105   if (Archive::is_loading::value)
1106   {
1107     if (left)
1108       left->parent = this;
1109     if (right)
1110       right->parent = this;
1111   }
1112 }
1113 
1114 } // namespace tree
1115 } // namespace mlpack
1116 
1117 #endif
1118