1 /**
2 * @file core/tree/binary_space_tree/binary_space_tree_impl.hpp
3 *
4 * Implementation of generalized space partitioning tree.
5 *
6 * mlpack is free software; you may redistribute it and/or modify it under the
7 * terms of the 3-clause BSD license. You should have received a copy of the
8 * 3-clause BSD license along with mlpack. If not, see
9 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
10 */
11 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP
12 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_BINARY_SPACE_TREE_IMPL_HPP
13
14 // In case it wasn't included already for some reason.
15 #include "binary_space_tree.hpp"
16
17 #include <mlpack/core/util/log.hpp>
18 #include <queue>
19
20 namespace mlpack {
21 namespace tree {
22
23 // Each of these overloads is kept as a separate function to keep the overhead
24 // from the two std::vectors out, if possible.
25 template<typename MetricType,
26 typename StatisticType,
27 typename MatType,
28 template<typename BoundMetricType, typename...> class BoundType,
29 template<typename SplitBoundType, typename SplitMatType>
30 class SplitType>
31 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(const MatType & data,const size_t maxLeafSize)32 BinarySpaceTree(
33 const MatType& data,
34 const size_t maxLeafSize) :
35 left(NULL),
36 right(NULL),
37 parent(NULL),
38 begin(0), /* This root node starts at index 0, */
39 count(data.n_cols), /* and spans all of the dataset. */
40 bound(data.n_rows),
41 parentDistance(0), // Parent distance for the root is 0: it has no parent.
42 dataset(new MatType(data)) // Copies the dataset.
43 {
44 // Do the actual splitting of this node.
45 SplitType<BoundType<MetricType>, MatType> splitter;
46 SplitNode(maxLeafSize, splitter);
47
48 // Create the statistic depending on if we are a leaf or not.
49 stat = StatisticType(*this);
50 }
51
52 template<typename MetricType,
53 typename StatisticType,
54 typename MatType,
55 template<typename BoundMetricType, typename...> class BoundType,
56 template<typename SplitBoundType, typename SplitMatType>
57 class SplitType>
58 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(const MatType & data,std::vector<size_t> & oldFromNew,const size_t maxLeafSize)59 BinarySpaceTree(
60 const MatType& data,
61 std::vector<size_t>& oldFromNew,
62 const size_t maxLeafSize) :
63 left(NULL),
64 right(NULL),
65 parent(NULL),
66 begin(0),
67 count(data.n_cols),
68 bound(data.n_rows),
69 parentDistance(0), // Parent distance for the root is 0: it has no parent.
70 dataset(new MatType(data)) // Copies the dataset.
71 {
72 // Initialize oldFromNew correctly.
73 oldFromNew.resize(data.n_cols);
74 for (size_t i = 0; i < data.n_cols; ++i)
75 oldFromNew[i] = i; // Fill with unharmed indices.
76
77 // Now do the actual splitting.
78 SplitType<BoundType<MetricType>, MatType> splitter;
79 SplitNode(oldFromNew, maxLeafSize, splitter);
80
81 // Create the statistic depending on if we are a leaf or not.
82 stat = StatisticType(*this);
83 }
84
85 template<typename MetricType,
86 typename StatisticType,
87 typename MatType,
88 template<typename BoundMetricType, typename...> class BoundType,
89 template<typename SplitBoundType, typename SplitMatType>
90 class SplitType>
91 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(const MatType & data,std::vector<size_t> & oldFromNew,std::vector<size_t> & newFromOld,const size_t maxLeafSize)92 BinarySpaceTree(
93 const MatType& data,
94 std::vector<size_t>& oldFromNew,
95 std::vector<size_t>& newFromOld,
96 const size_t maxLeafSize) :
97 left(NULL),
98 right(NULL),
99 parent(NULL),
100 begin(0),
101 count(data.n_cols),
102 bound(data.n_rows),
103 parentDistance(0), // Parent distance for the root is 0: it has no parent.
104 dataset(new MatType(data)) // Copies the dataset.
105 {
106 // Initialize the oldFromNew vector correctly.
107 oldFromNew.resize(data.n_cols);
108 for (size_t i = 0; i < data.n_cols; ++i)
109 oldFromNew[i] = i; // Fill with unharmed indices.
110
111 // Now do the actual splitting.
112 SplitType<BoundType<MetricType>, MatType> splitter;
113 SplitNode(oldFromNew, maxLeafSize, splitter);
114
115 // Create the statistic depending on if we are a leaf or not.
116 stat = StatisticType(*this);
117
118 // Map the newFromOld indices correctly.
119 newFromOld.resize(data.n_cols);
120 for (size_t i = 0; i < data.n_cols; ++i)
121 newFromOld[oldFromNew[i]] = i;
122 }
123
124 template<typename MetricType,
125 typename StatisticType,
126 typename MatType,
127 template<typename BoundMetricType, typename...> class BoundType,
128 template<typename SplitBoundType, typename SplitMatType>
129 class SplitType>
130 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(MatType && data,const size_t maxLeafSize)131 BinarySpaceTree(MatType&& data, const size_t maxLeafSize) :
132 left(NULL),
133 right(NULL),
134 parent(NULL),
135 begin(0),
136 count(data.n_cols),
137 bound(data.n_rows),
138 parentDistance(0), // Parent distance for the root is 0: it has no parent.
139 dataset(new MatType(std::move(data)))
140 {
141 // Do the actual splitting of this node.
142 SplitType<BoundType<MetricType>, MatType> splitter;
143 SplitNode(maxLeafSize, splitter);
144
145 // Create the statistic depending on if we are a leaf or not.
146 stat = StatisticType(*this);
147 }
148
149 template<typename MetricType,
150 typename StatisticType,
151 typename MatType,
152 template<typename BoundMetricType, typename...> class BoundType,
153 template<typename SplitBoundType, typename SplitMatType>
154 class SplitType>
155 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(MatType && data,std::vector<size_t> & oldFromNew,const size_t maxLeafSize)156 BinarySpaceTree(
157 MatType&& data,
158 std::vector<size_t>& oldFromNew,
159 const size_t maxLeafSize) :
160 left(NULL),
161 right(NULL),
162 parent(NULL),
163 begin(0),
164 count(data.n_cols),
165 bound(data.n_rows),
166 parentDistance(0), // Parent distance for the root is 0: it has no parent.
167 dataset(new MatType(std::move(data)))
168 {
169 // Initialize oldFromNew correctly.
170 oldFromNew.resize(dataset->n_cols);
171 for (size_t i = 0; i < dataset->n_cols; ++i)
172 oldFromNew[i] = i; // Fill with unharmed indices.
173
174 // Now do the actual splitting.
175 SplitType<BoundType<MetricType>, MatType> splitter;
176 SplitNode(oldFromNew, maxLeafSize, splitter);
177
178 // Create the statistic depending on if we are a leaf or not.
179 stat = StatisticType(*this);
180 }
181
182 template<typename MetricType,
183 typename StatisticType,
184 typename MatType,
185 template<typename BoundMetricType, typename...> class BoundType,
186 template<typename SplitBoundType, typename SplitMatType>
187 class SplitType>
188 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(MatType && data,std::vector<size_t> & oldFromNew,std::vector<size_t> & newFromOld,const size_t maxLeafSize)189 BinarySpaceTree(
190 MatType&& data,
191 std::vector<size_t>& oldFromNew,
192 std::vector<size_t>& newFromOld,
193 const size_t maxLeafSize) :
194 left(NULL),
195 right(NULL),
196 parent(NULL),
197 begin(0),
198 count(data.n_cols),
199 bound(data.n_rows),
200 parentDistance(0), // Parent distance for the root is 0: it has no parent.
201 dataset(new MatType(std::move(data)))
202 {
203 // Initialize the oldFromNew vector correctly.
204 oldFromNew.resize(dataset->n_cols);
205 for (size_t i = 0; i < dataset->n_cols; ++i)
206 oldFromNew[i] = i; // Fill with unharmed indices.
207
208 // Now do the actual splitting.
209 SplitType<BoundType<MetricType>, MatType> splitter;
210 SplitNode(oldFromNew, maxLeafSize, splitter);
211
212 // Create the statistic depending on if we are a leaf or not.
213 stat = StatisticType(*this);
214
215 // Map the newFromOld indices correctly.
216 newFromOld.resize(dataset->n_cols);
217 for (size_t i = 0; i < dataset->n_cols; ++i)
218 newFromOld[oldFromNew[i]] = i;
219 }
220
221 template<typename MetricType,
222 typename StatisticType,
223 typename MatType,
224 template<typename BoundMetricType, typename...> class BoundType,
225 template<typename SplitBoundType, typename SplitMatType>
226 class SplitType>
227 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(BinarySpaceTree * parent,const size_t begin,const size_t count,SplitType<BoundType<MetricType>,MatType> & splitter,const size_t maxLeafSize)228 BinarySpaceTree(
229 BinarySpaceTree* parent,
230 const size_t begin,
231 const size_t count,
232 SplitType<BoundType<MetricType>, MatType>& splitter,
233 const size_t maxLeafSize) :
234 left(NULL),
235 right(NULL),
236 parent(parent),
237 begin(begin),
238 count(count),
239 bound(parent->Dataset().n_rows),
240 dataset(&parent->Dataset()) // Point to the parent's dataset.
241 {
242 // Perform the actual splitting.
243 SplitNode(maxLeafSize, splitter);
244
245 // Create the statistic depending on if we are a leaf or not.
246 stat = StatisticType(*this);
247 }
248
249 template<typename MetricType,
250 typename StatisticType,
251 typename MatType,
252 template<typename BoundMetricType, typename...> class BoundType,
253 template<typename SplitBoundType, typename SplitMatType>
254 class SplitType>
255 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(BinarySpaceTree * parent,const size_t begin,const size_t count,std::vector<size_t> & oldFromNew,SplitType<BoundType<MetricType>,MatType> & splitter,const size_t maxLeafSize)256 BinarySpaceTree(
257 BinarySpaceTree* parent,
258 const size_t begin,
259 const size_t count,
260 std::vector<size_t>& oldFromNew,
261 SplitType<BoundType<MetricType>, MatType>& splitter,
262 const size_t maxLeafSize) :
263 left(NULL),
264 right(NULL),
265 parent(parent),
266 begin(begin),
267 count(count),
268 bound(parent->Dataset().n_rows),
269 dataset(&parent->Dataset())
270 {
271 // Hopefully the vector is initialized correctly! We can't check that
272 // entirely but we can do a minor sanity check.
273 assert(oldFromNew.size() == dataset->n_cols);
274
275 // Perform the actual splitting.
276 SplitNode(oldFromNew, maxLeafSize, splitter);
277
278 // Create the statistic depending on if we are a leaf or not.
279 stat = StatisticType(*this);
280 }
281
282 template<typename MetricType,
283 typename StatisticType,
284 typename MatType,
285 template<typename BoundMetricType, typename...> class BoundType,
286 template<typename SplitBoundType, typename SplitMatType>
287 class SplitType>
288 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(BinarySpaceTree * parent,const size_t begin,const size_t count,std::vector<size_t> & oldFromNew,std::vector<size_t> & newFromOld,SplitType<BoundType<MetricType>,MatType> & splitter,const size_t maxLeafSize)289 BinarySpaceTree(
290 BinarySpaceTree* parent,
291 const size_t begin,
292 const size_t count,
293 std::vector<size_t>& oldFromNew,
294 std::vector<size_t>& newFromOld,
295 SplitType<BoundType<MetricType>, MatType>& splitter,
296 const size_t maxLeafSize) :
297 left(NULL),
298 right(NULL),
299 parent(parent),
300 begin(begin),
301 count(count),
302 bound(parent->Dataset()->n_rows),
303 dataset(&parent->Dataset())
304 {
305 // Hopefully the vector is initialized correctly! We can't check that
306 // entirely but we can do a minor sanity check.
307 Log::Assert(oldFromNew.size() == dataset->n_cols);
308
309 // Perform the actual splitting.
310 SplitNode(oldFromNew, maxLeafSize, splitter);
311
312 // Create the statistic depending on if we are a leaf or not.
313 stat = StatisticType(*this);
314
315 // Map the newFromOld indices correctly.
316 newFromOld.resize(dataset->n_cols);
317 for (size_t i = 0; i < dataset->n_cols; ++i)
318 newFromOld[oldFromNew[i]] = i;
319 }
320
321 /**
322 * Create a binary space tree by copying the other tree. Be careful! This can
323 * take a long time and use a lot of memory.
324 */
325 template<typename MetricType,
326 typename StatisticType,
327 typename MatType,
328 template<typename BoundMetricType, typename...> class BoundType,
329 template<typename SplitBoundType, typename SplitMatType>
330 class SplitType>
331 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(const BinarySpaceTree & other)332 BinarySpaceTree(
333 const BinarySpaceTree& other) :
334 left(NULL),
335 right(NULL),
336 parent(other.parent),
337 begin(other.begin),
338 count(other.count),
339 bound(other.bound),
340 stat(other.stat),
341 parentDistance(other.parentDistance),
342 furthestDescendantDistance(other.furthestDescendantDistance),
343 minimumBoundDistance(other.minimumBoundDistance),
344 // Copy matrix, but only if we are the root.
345 dataset((other.parent == NULL) ? new MatType(*other.dataset) : NULL)
346 {
347 // Create left and right children (if any).
348 if (other.Left())
349 {
350 left = new BinarySpaceTree(*other.Left());
351 left->Parent() = this; // Set parent to this, not other tree.
352 }
353
354 if (other.Right())
355 {
356 right = new BinarySpaceTree(*other.Right());
357 right->Parent() = this; // Set parent to this, not other tree.
358 }
359
360 // Propagate matrix, but only if we are the root.
361 if (parent == NULL)
362 {
363 std::queue<BinarySpaceTree*> queue;
364 if (left)
365 queue.push(left);
366 if (right)
367 queue.push(right);
368 while (!queue.empty())
369 {
370 BinarySpaceTree* node = queue.front();
371 queue.pop();
372
373 node->dataset = dataset;
374 if (node->left)
375 queue.push(node->left);
376 if (node->right)
377 queue.push(node->right);
378 }
379 }
380 }
381
382 /**
383 * Copy assignment operator: copy the given other tree.
384 */
385 template<typename MetricType,
386 typename StatisticType,
387 typename MatType,
388 template<typename BoundMetricType, typename...> class BoundType,
389 template<typename SplitBoundType, typename SplitMatType>
390 class SplitType>
391 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
392 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
operator =(const BinarySpaceTree & other)393 operator=(const BinarySpaceTree& other)
394 {
395 // Return if it's the same tree.
396 if (this == &other)
397 return *this;
398
399 // Freeing memory that will not be used anymore.
400 delete dataset;
401 delete left;
402 delete right;
403
404 left = NULL;
405 right = NULL;
406 parent = other.Parent();
407 begin = other.Begin();
408 count = other.Count();
409 bound = other.bound;
410 stat = other.stat;
411 parentDistance = other.ParentDistance();
412 furthestDescendantDistance = other.FurthestDescendantDistance();
413 minimumBoundDistance = other.MinimumBoundDistance();
414 // Copy matrix, but only if we are the root.
415 dataset = ((other.parent == NULL) ? new MatType(*other.dataset) : NULL);
416
417 // Create left and right children (if any).
418 if (other.Left())
419 {
420 left = new BinarySpaceTree(*other.Left());
421 left->Parent() = this; // Set parent to this, not other tree.
422 }
423
424 if (other.Right())
425 {
426 right = new BinarySpaceTree(*other.Right());
427 right->Parent() = this; // Set parent to this, not other tree.
428 }
429
430 // Propagate matrix, but only if we are the root.
431 if (parent == NULL)
432 {
433 std::queue<BinarySpaceTree*> queue;
434 if (left)
435 queue.push(left);
436 if (right)
437 queue.push(right);
438 while (!queue.empty())
439 {
440 BinarySpaceTree* node = queue.front();
441 queue.pop();
442
443 node->dataset = dataset;
444 if (node->left)
445 queue.push(node->left);
446 if (node->right)
447 queue.push(node->right);
448 }
449 }
450
451 return *this;
452 }
453
454 /**
455 * Move assignment operator: take ownership of the given tree.
456 */
457 template<typename MetricType,
458 typename StatisticType,
459 typename MatType,
460 template<typename BoundMetricType, typename...> class BoundType,
461 template<typename SplitBoundType, typename SplitMatType>
462 class SplitType>
463 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
464 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
operator =(BinarySpaceTree && other)465 operator=(BinarySpaceTree&& other)
466 {
467 // Return if it's the same tree.
468 if (this == &other)
469 return *this;
470
471 // Freeing memory that will not be used anymore.
472 delete dataset;
473 delete left;
474 delete right;
475
476 parent = other.Parent();
477 left = other.Left();
478 right = other.Right();
479 begin = other.Begin();
480 count = other.Count();
481 bound = std::move(other.bound);
482 stat = std::move(other.stat);
483 parentDistance = other.ParentDistance();
484 furthestDescendantDistance = other.FurthestDescendantDistance();
485 minimumBoundDistance = other.MinimumBoundDistance();
486 dataset = other.dataset;
487
488 other.left = NULL;
489 other.right = NULL;
490 other.parent = NULL;
491 other.begin = 0;
492 other.count = 0;
493 other.parentDistance = 0.0;
494 other.furthestDescendantDistance = 0.0;
495 other.minimumBoundDistance = 0.0;
496 other.dataset = NULL;
497
498 return *this;
499 }
500
501
502 /**
503 * Move constructor.
504 */
505 template<typename MetricType,
506 typename StatisticType,
507 typename MatType,
508 template<typename BoundMetricType, typename...> class BoundType,
509 template<typename SplitBoundType, typename SplitMatType>
510 class SplitType>
511 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(BinarySpaceTree && other)512 BinarySpaceTree(BinarySpaceTree&& other) :
513 left(other.left),
514 right(other.right),
515 parent(other.parent),
516 begin(other.begin),
517 count(other.count),
518 bound(std::move(other.bound)),
519 stat(std::move(other.stat)),
520 parentDistance(other.parentDistance),
521 furthestDescendantDistance(other.furthestDescendantDistance),
522 minimumBoundDistance(other.minimumBoundDistance),
523 dataset(other.dataset)
524 {
525 // Now we are a clone of the other tree. But we must also clear the other
526 // tree's contents, so it doesn't delete anything when it is destructed.
527 other.left = NULL;
528 other.right = NULL;
529 other.parent = NULL;
530 other.begin = 0;
531 other.count = 0;
532 other.parentDistance = 0.0;
533 other.furthestDescendantDistance = 0.0;
534 other.minimumBoundDistance = 0.0;
535 other.dataset = NULL;
536
537 // Set new parent.
538 if (left)
539 left->parent = this;
540 if (right)
541 right->parent = this;
542 }
543
544 /**
545 * Initialize the tree from an archive.
546 */
547 template<typename MetricType,
548 typename StatisticType,
549 typename MatType,
550 template<typename BoundMetricType, typename...> class BoundType,
551 template<typename SplitBoundType, typename SplitMatType>
552 class SplitType>
553 template<typename Archive>
554 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree(Archive & ar,const typename std::enable_if_t<Archive::is_loading::value> *)555 BinarySpaceTree(
556 Archive& ar,
557 const typename std::enable_if_t<Archive::is_loading::value>*) :
558 BinarySpaceTree() // Create an empty BinarySpaceTree.
559 {
560 // We've delegated to the constructor which gives us an empty tree, and now we
561 // can serialize from it.
562 ar >> BOOST_SERIALIZATION_NVP(*this);
563 }
564
565 /**
566 * Deletes this node, deallocating the memory for the children and calling their
567 * destructors in turn. This will invalidate any pointers or references to any
568 * nodes which are children of this one.
569 */
570 template<typename MetricType,
571 typename StatisticType,
572 typename MatType,
573 template<typename BoundMetricType, typename...> class BoundType,
574 template<typename SplitBoundType, typename SplitMatType>
575 class SplitType>
576 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
~BinarySpaceTree()577 ~BinarySpaceTree()
578 {
579 delete left;
580 delete right;
581
582 // If we're the root, delete the matrix.
583 if (!parent)
584 delete dataset;
585 }
586
587 template<typename MetricType,
588 typename StatisticType,
589 typename MatType,
590 template<typename BoundMetricType, typename...> class BoundType,
591 template<typename SplitBoundType, typename SplitMatType>
592 class SplitType>
593 inline bool BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
IsLeaf() const594 SplitType>::IsLeaf() const
595 {
596 return !left;
597 }
598
599 /**
600 * Returns the number of children in this node.
601 */
602 template<typename MetricType,
603 typename StatisticType,
604 typename MatType,
605 template<typename BoundMetricType, typename...> class BoundType,
606 template<typename SplitBoundType, typename SplitMatType>
607 class SplitType>
608 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
NumChildren() const609 SplitType>::NumChildren() const
610 {
611 if (left && right)
612 return 2;
613 if (left)
614 return 1;
615
616 return 0;
617 }
618
619 /**
620 * Return the index of the nearest child node to the given query point. If
621 * this is a leaf node, it will return NumChildren() (invalid index).
622 */
623 template<typename MetricType,
624 typename StatisticType,
625 typename MatType,
626 template<typename BoundMetricType, typename...> class BoundType,
627 template<typename SplitBoundType, typename SplitMatType>
628 class SplitType>
629 template<typename VecType>
630 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
GetNearestChild(const VecType & point,typename std::enable_if_t<IsVector<VecType>::value> *)631 SplitType>::GetNearestChild(
632 const VecType& point,
633 typename std::enable_if_t<IsVector<VecType>::value>*)
634 {
635 if (IsLeaf() || !left || !right)
636 return 0;
637
638 if (left->MinDistance(point) <= right->MinDistance(point))
639 return 0;
640 return 1;
641 }
642
643 /**
644 * Return the index of the furthest child node to the given query point. If
645 * this is a leaf node, it will return NumChildren() (invalid index).
646 */
647 template<typename MetricType,
648 typename StatisticType,
649 typename MatType,
650 template<typename BoundMetricType, typename...> class BoundType,
651 template<typename SplitBoundType, typename SplitMatType>
652 class SplitType>
653 template<typename VecType>
654 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
GetFurthestChild(const VecType & point,typename std::enable_if_t<IsVector<VecType>::value> *)655 SplitType>::GetFurthestChild(
656 const VecType& point,
657 typename std::enable_if_t<IsVector<VecType>::value>*)
658 {
659 if (IsLeaf() || !left || !right)
660 return 0;
661
662 if (left->MaxDistance(point) > right->MaxDistance(point))
663 return 0;
664 return 1;
665 }
666
667 /**
668 * Return the index of the nearest child node to the given query node. If it
669 * can't decide, it will return NumChildren() (invalid index).
670 */
671 template<typename MetricType,
672 typename StatisticType,
673 typename MatType,
674 template<typename BoundMetricType, typename...> class BoundType,
675 template<typename SplitBoundType, typename SplitMatType>
676 class SplitType>
677 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
GetNearestChild(const BinarySpaceTree & queryNode)678 SplitType>::GetNearestChild(const BinarySpaceTree& queryNode)
679 {
680 if (IsLeaf() || !left || !right)
681 return 0;
682
683 ElemType leftDist = left->MinDistance(queryNode);
684 ElemType rightDist = right->MinDistance(queryNode);
685 if (leftDist < rightDist)
686 return 0;
687 if (rightDist < leftDist)
688 return 1;
689 return NumChildren();
690 }
691
692 /**
693 * Return the index of the furthest child node to the given query node. If it
694 * can't decide, it will return NumChildren() (invalid index).
695 */
696 template<typename MetricType,
697 typename StatisticType,
698 typename MatType,
699 template<typename BoundMetricType, typename...> class BoundType,
700 template<typename SplitBoundType, typename SplitMatType>
701 class SplitType>
702 size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
GetFurthestChild(const BinarySpaceTree & queryNode)703 SplitType>::GetFurthestChild(const BinarySpaceTree& queryNode)
704 {
705 if (IsLeaf() || !left || !right)
706 return 0;
707
708 ElemType leftDist = left->MaxDistance(queryNode);
709 ElemType rightDist = right->MaxDistance(queryNode);
710 if (leftDist > rightDist)
711 return 0;
712 if (rightDist > leftDist)
713 return 1;
714 return NumChildren();
715 }
716
717 /**
718 * Return a bound on the furthest point in the node from the center. This
719 * returns 0 unless the node is a leaf.
720 */
721 template<typename MetricType,
722 typename StatisticType,
723 typename MatType,
724 template<typename BoundMetricType, typename...> class BoundType,
725 template<typename SplitBoundType, typename SplitMatType>
726 class SplitType>
727 inline
728 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
729 SplitType>::ElemType
730 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
FurthestPointDistance() const731 SplitType>::FurthestPointDistance() const
732 {
733 if (!IsLeaf())
734 return 0.0;
735
736 // Otherwise return the distance from the center to a corner of the bound.
737 return 0.5 * bound.Diameter();
738 }
739
740 /**
741 * Return the furthest possible descendant distance. This returns the maximum
742 * distance from the center to the edge of the bound and not the empirical
743 * quantity which is the actual furthest descendant distance. So the actual
744 * furthest descendant distance may be less than what this method returns (but
745 * it will never be greater than this).
746 */
747 template<typename MetricType,
748 typename StatisticType,
749 typename MatType,
750 template<typename BoundMetricType, typename...> class BoundType,
751 template<typename SplitBoundType, typename SplitMatType>
752 class SplitType>
753 inline
754 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
755 SplitType>::ElemType
756 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
FurthestDescendantDistance() const757 SplitType>::FurthestDescendantDistance() const
758 {
759 return furthestDescendantDistance;
760 }
761
762 //! Return the minimum distance from the center to any bound edge.
763 template<typename MetricType,
764 typename StatisticType,
765 typename MatType,
766 template<typename BoundMetricType, typename...> class BoundType,
767 template<typename SplitBoundType, typename SplitMatType>
768 class SplitType>
769 inline
770 typename BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
771 SplitType>::ElemType
772 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
MinimumBoundDistance() const773 SplitType>::MinimumBoundDistance() const
774 {
775 return bound.MinWidth() / 2.0;
776 }
777
778 /**
779 * Return the specified child.
780 */
781 template<typename MetricType,
782 typename StatisticType,
783 typename MatType,
784 template<typename BoundMetricType, typename...> class BoundType,
785 template<typename SplitBoundType, typename SplitMatType>
786 class SplitType>
787 inline BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
788 SplitType>&
789 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
Child(const size_t child) const790 SplitType>::Child(const size_t child) const
791 {
792 if (child == 0)
793 return *left;
794 else
795 return *right;
796 }
797
798 /**
799 * Return the number of points contained in this node.
800 */
801 template<typename MetricType,
802 typename StatisticType,
803 typename MatType,
804 template<typename BoundMetricType, typename...> class BoundType,
805 template<typename SplitBoundType, typename SplitMatType>
806 class SplitType>
807 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
NumPoints() const808 SplitType>::NumPoints() const
809 {
810 if (left)
811 return 0;
812
813 return count;
814 }
815
816 /**
817 * Return the number of descendants contained in the node.
818 */
819 template<typename MetricType,
820 typename StatisticType,
821 typename MatType,
822 template<typename BoundMetricType, typename...> class BoundType,
823 template<typename SplitBoundType, typename SplitMatType>
824 class SplitType>
825 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
NumDescendants() const826 SplitType>::NumDescendants() const
827 {
828 return count;
829 }
830
831 /**
832 * Return the index of a particular descendant contained in this node.
833 */
834 template<typename MetricType,
835 typename StatisticType,
836 typename MatType,
837 template<typename BoundMetricType, typename...> class BoundType,
838 template<typename SplitBoundType, typename SplitMatType>
839 class SplitType>
840 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
Descendant(const size_t index) const841 SplitType>::Descendant(const size_t index) const
842 {
843 return (begin + index);
844 }
845
846 /**
847 * Return the index of a particular point contained in this node.
848 */
849 template<typename MetricType,
850 typename StatisticType,
851 typename MatType,
852 template<typename BoundMetricType, typename...> class BoundType,
853 template<typename SplitBoundType, typename SplitMatType>
854 class SplitType>
855 inline size_t BinarySpaceTree<MetricType, StatisticType, MatType, BoundType,
Point(const size_t index) const856 SplitType>::Point(const size_t index) const
857 {
858 return (begin + index);
859 }
860
861 template<typename MetricType,
862 typename StatisticType,
863 typename MatType,
864 template<typename BoundMetricType, typename...> class BoundType,
865 template<typename SplitBoundType, typename SplitMatType>
866 class SplitType>
867 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
SplitNode(const size_t maxLeafSize,SplitType<BoundType<MetricType>,MatType> & splitter)868 SplitNode(const size_t maxLeafSize,
869 SplitType<BoundType<MetricType>, MatType>& splitter)
870 {
871 // We need to expand the bounds of this node properly.
872 UpdateBound(bound);
873
874 // Calculate the furthest descendant distance.
875 furthestDescendantDistance = 0.5 * bound.Diameter();
876
877 // Now, check if we need to split at all.
878 if (count <= maxLeafSize)
879 return; // We can't split this.
880
881 // splitCol denotes the two partitions of the dataset after the split. The
882 // points on its left go to the left child and the others go to the right
883 // child.
884 size_t splitCol;
885
886 // Find the partition of the node. This method does not perform the split.
887 typename Split::SplitInfo splitInfo;
888
889 const bool split = splitter.SplitNode(bound, *dataset, begin, count,
890 splitInfo);
891
892 // The node may not be always split. For instance, if all the points are the
893 // same, we can't split them.
894 if (!split)
895 return;
896
897 // Perform the actual splitting. This will order the dataset such that
898 // points that belong to the left subtree are on the left of splitCol, and
899 // points from the right subtree are on the right side of splitCol.
900 splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo);
901
902 assert(splitCol > begin);
903 assert(splitCol < begin + count);
904
905 // Now that we know the split column, we will recursively split the children
906 // by calling their constructors (which perform this splitting process).
907 left = new BinarySpaceTree(this, begin, splitCol - begin, splitter,
908 maxLeafSize);
909 right = new BinarySpaceTree(this, splitCol, begin + count - splitCol,
910 splitter, maxLeafSize);
911
912 // Calculate parent distances for those two nodes.
913 arma::vec center, leftCenter, rightCenter;
914 Center(center);
915 left->Center(leftCenter);
916 right->Center(rightCenter);
917
918 const ElemType leftParentDistance = bound.Metric().Evaluate(center,
919 leftCenter);
920 const ElemType rightParentDistance = bound.Metric().Evaluate(center,
921 rightCenter);
922
923 left->ParentDistance() = leftParentDistance;
924 right->ParentDistance() = rightParentDistance;
925 }
926
927 template<typename MetricType,
928 typename StatisticType,
929 typename MatType,
930 template<typename BoundMetricType, typename...> class BoundType,
931 template<typename SplitBoundType, typename SplitMatType>
932 class SplitType>
933 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
SplitNode(std::vector<size_t> & oldFromNew,const size_t maxLeafSize,SplitType<BoundType<MetricType>,MatType> & splitter)934 SplitNode(std::vector<size_t>& oldFromNew,
935 const size_t maxLeafSize,
936 SplitType<BoundType<MetricType>, MatType>& splitter)
937 {
938 // We need to expand the bounds of this node properly.
939 UpdateBound(bound);
940
941 // Calculate the furthest descendant distance.
942 furthestDescendantDistance = 0.5 * bound.Diameter();
943
944 // First, check if we need to split at all.
945 if (count <= maxLeafSize)
946 return; // We can't split this.
947
948 // splitCol denotes the two partitions of the dataset after the split. The
949 // points on its left go to the left child and the others go to the right
950 // child.
951 size_t splitCol;
952
953 // Find the partition of the node. This method does not perform the split.
954 typename Split::SplitInfo splitInfo;
955
956 const bool split = splitter.SplitNode(bound, *dataset, begin, count,
957 splitInfo);
958
959 // The node may not be always split. For instance, if all the points are the
960 // same, we can't split them.
961 if (!split)
962 return;
963
964 // Perform the actual splitting. This will order the dataset such that
965 // points that belong to the left subtree are on the left of splitCol, and
966 // points from the right subtree are on the right side of splitCol.
967 splitCol = splitter.PerformSplit(*dataset, begin, count, splitInfo,
968 oldFromNew);
969
970 assert(splitCol > begin);
971 assert(splitCol < begin + count);
972
973 // Now that we know the split column, we will recursively split the children
974 // by calling their constructors (which perform this splitting process).
975 left = new BinarySpaceTree(this, begin, splitCol - begin, oldFromNew,
976 splitter, maxLeafSize);
977 right = new BinarySpaceTree(this, splitCol, begin + count - splitCol,
978 oldFromNew, splitter, maxLeafSize);
979
980 // Calculate parent distances for those two nodes.
981 arma::vec center, leftCenter, rightCenter;
982 Center(center);
983 left->Center(leftCenter);
984 right->Center(rightCenter);
985
986 const ElemType leftParentDistance = bound.Metric().Evaluate(center,
987 leftCenter);
988 const ElemType rightParentDistance = bound.Metric().Evaluate(center,
989 rightCenter);
990
991 left->ParentDistance() = leftParentDistance;
992 right->ParentDistance() = rightParentDistance;
993 }
994
995 template<typename MetricType,
996 typename StatisticType,
997 typename MatType,
998 template<typename BoundMetricType, typename...> class BoundType,
999 template<typename SplitBoundType, typename SplitMatType>
1000 class SplitType>
1001 template<typename BoundType2>
1002 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
UpdateBound(BoundType2 & boundToUpdate)1003 UpdateBound(BoundType2& boundToUpdate)
1004 {
1005 if (count > 0)
1006 boundToUpdate |= dataset->cols(begin, begin + count - 1);
1007 }
1008
1009 template<typename MetricType,
1010 typename StatisticType,
1011 typename MatType,
1012 template<typename BoundMetricType, typename...> class BoundType,
1013 template<typename SplitBoundType, typename SplitMatType>
1014 class SplitType>
1015 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
UpdateBound(bound::HollowBallBound<MetricType> & boundToUpdate)1016 UpdateBound(bound::HollowBallBound<MetricType>& boundToUpdate)
1017 {
1018 if (!parent)
1019 {
1020 if (count > 0)
1021 boundToUpdate |= dataset->cols(begin, begin + count - 1);
1022 return;
1023 }
1024
1025 if (parent->left != NULL && parent->left != this)
1026 {
1027 boundToUpdate.HollowCenter() = parent->left->bound.Center();
1028 boundToUpdate.InnerRadius() = std::numeric_limits<ElemType>::max();
1029 }
1030
1031 if (count > 0)
1032 boundToUpdate |= dataset->cols(begin, begin + count - 1);
1033 }
1034
1035 // Default constructor (private), for boost::serialization.
1036 template<typename MetricType,
1037 typename StatisticType,
1038 typename MatType,
1039 template<typename BoundMetricType, typename...> class BoundType,
1040 template<typename SplitBoundType, typename SplitMatType>
1041 class SplitType>
1042 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
BinarySpaceTree()1043 BinarySpaceTree() :
1044 left(NULL),
1045 right(NULL),
1046 parent(NULL),
1047 begin(0),
1048 count(0),
1049 stat(*this),
1050 parentDistance(0),
1051 furthestDescendantDistance(0),
1052 dataset(NULL)
1053 {
1054 // Nothing to do.
1055 }
1056
1057 /**
1058 * Serialize the tree.
1059 */
1060 template<typename MetricType,
1061 typename StatisticType,
1062 typename MatType,
1063 template<typename BoundMetricType, typename...> class BoundType,
1064 template<typename SplitBoundType, typename SplitMatType>
1065 class SplitType>
1066 template<typename Archive>
1067 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
serialize(Archive & ar,const unsigned int)1068 serialize(Archive& ar, const unsigned int /* version */)
1069 {
1070 // If we're loading, and we have children, they need to be deleted.
1071 if (Archive::is_loading::value)
1072 {
1073 if (left)
1074 delete left;
1075 if (right)
1076 delete right;
1077 if (!parent)
1078 delete dataset;
1079
1080 parent = NULL;
1081 left = NULL;
1082 right = NULL;
1083 }
1084
1085 ar & BOOST_SERIALIZATION_NVP(begin);
1086 ar & BOOST_SERIALIZATION_NVP(count);
1087 ar & BOOST_SERIALIZATION_NVP(bound);
1088 ar & BOOST_SERIALIZATION_NVP(stat);
1089
1090 ar & BOOST_SERIALIZATION_NVP(parentDistance);
1091 ar & BOOST_SERIALIZATION_NVP(furthestDescendantDistance);
1092 ar & BOOST_SERIALIZATION_NVP(dataset);
1093
1094 // Save children last; otherwise boost::serialization gets confused.
1095 bool hasLeft = (left != NULL);
1096 bool hasRight = (right != NULL);
1097
1098 ar & BOOST_SERIALIZATION_NVP(hasLeft);
1099 ar & BOOST_SERIALIZATION_NVP(hasRight);
1100 if (hasLeft)
1101 ar & BOOST_SERIALIZATION_NVP(left);
1102 if (hasRight)
1103 ar & BOOST_SERIALIZATION_NVP(right);
1104
1105 if (Archive::is_loading::value)
1106 {
1107 if (left)
1108 left->parent = this;
1109 if (right)
1110 right->parent = this;
1111 }
1112 }
1113
1114 } // namespace tree
1115 } // namespace mlpack
1116
1117 #endif
1118