1 /**
2 * @file methods/neighbor_search/ns_model_impl.hpp
3 * @author Ryan Curtin
4 *
5 * This is a model for nearest or furthest neighbor search. It is useful in
6 * that it provides an easy way to serialize a model, abstracts away the
7 * different types of trees, and also reflects the NeighborSearch API and
8 * automatically directs to the right tree type.
9 *
10 * mlpack is free software; you may redistribute it and/or modify it under the
11 * terms of the 3-clause BSD license. You should have received a copy of the
12 * 3-clause BSD license along with mlpack. If not, see
13 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
14 */
15 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_IMPL_HPP
16 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_IMPL_HPP
17
18 // In case it hasn't been included yet.
19 #include "ns_model.hpp"
20
21 #include <boost/serialization/variant.hpp>
22
23 namespace mlpack {
24 namespace neighbor {
25
26 //! Monochromatic neighbor search on the given NSType instance.
27 template<typename NSType>
operator ()(NSType * ns) const28 void MonoSearchVisitor::operator()(NSType *ns) const
29 {
30 if (ns)
31 return ns->Search(k, neighbors, distances);
32 throw std::runtime_error("no neighbor search model initialized");
33 }
34
35 //! Save parameters for bichromatic neighbor search.
36 template<typename SortPolicy>
BiSearchVisitor(const arma::mat & querySet,const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances,const size_t leafSize,const double tau,const double rho)37 BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
38 const size_t k,
39 arma::Mat<size_t>& neighbors,
40 arma::mat& distances,
41 const size_t leafSize,
42 const double tau,
43 const double rho) :
44 querySet(querySet),
45 k(k),
46 neighbors(neighbors),
47 distances(distances),
48 leafSize(leafSize),
49 tau(tau),
50 rho(rho)
51 {}
52
53 //! Default Bichromatic neighbor search on the given NSType instance.
54 template<typename SortPolicy>
55 template<template<typename TreeMetricType,
56 typename TreeStatType,
57 typename TreeMatType> class TreeType>
operator ()(NSTypeT<TreeType> * ns) const58 void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
59 {
60 if (ns)
61 return ns->Search(querySet, k, neighbors, distances);
62 throw std::runtime_error("no neighbor search model initialized");
63 }
64
65 //! Bichromatic neighbor search on the given NSType specialized for KDTrees.
66 template<typename SortPolicy>
operator ()(NSTypeT<tree::KDTree> * ns) const67 void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::KDTree>* ns) const
68 {
69 if (ns)
70 return SearchLeaf(ns);
71 throw std::runtime_error("no neighbor search model initialized");
72 }
73
74 //! Bichromatic neighbor search on the given NSType specialized for BallTrees.
75 template<typename SortPolicy>
operator ()(NSTypeT<tree::BallTree> * ns) const76 void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
77 {
78 if (ns)
79 return SearchLeaf(ns);
80 throw std::runtime_error("no neighbor search model initialized");
81 }
82
83 //! Bichromatic neighbor search specialized for SPTrees.
84 template<typename SortPolicy>
operator ()(SpillKNN * ns) const85 void BiSearchVisitor<SortPolicy>::operator()(SpillKNN* ns) const
86 {
87 if (ns)
88 {
89 if (ns->SearchMode() == DUAL_TREE_MODE)
90 {
91 // For Dual Tree Search on SpillTrees, the queryTree must be built with
92 // non overlapping (tau = 0).
93 typename SpillKNN::Tree queryTree(std::move(querySet), 0 /* tau*/,
94 leafSize, rho);
95 ns->Search(queryTree, k, neighbors, distances);
96 }
97 else
98 ns->Search(querySet, k, neighbors, distances);
99 }
100 else
101 throw std::runtime_error("no neighbor search model initialized");
102 }
103
104 //! Bichromatic neighbor search specialized for octrees.
105 template<typename SortPolicy>
operator ()(NSTypeT<tree::Octree> * ns) const106 void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::Octree>* ns) const
107 {
108 if (ns)
109 return SearchLeaf(ns);
110 throw std::runtime_error("no neighbor search model initialized");
111 }
112
113 //! Bichromatic neighbor search on the given NSType considering the leafSize.
114 template<typename SortPolicy>
115 template<typename NSType>
SearchLeaf(NSType * ns) const116 void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
117 {
118 if (ns->SearchMode() == DUAL_TREE_MODE)
119 {
120 std::vector<size_t> oldFromNewQueries;
121 typename NSType::Tree queryTree(std::move(querySet), oldFromNewQueries,
122 leafSize);
123
124 arma::Mat<size_t> neighborsOut;
125 arma::mat distancesOut;
126 ns->Search(queryTree, k, neighborsOut, distancesOut);
127
128 // Unmap the query points.
129 distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
130 neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
131 for (size_t i = 0; i < neighborsOut.n_cols; ++i)
132 {
133 neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
134 distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
135 }
136 }
137 else
138 ns->Search(querySet, k, neighbors, distances);
139 }
140
141 //! Save parameters for Train.
142 template<typename SortPolicy>
TrainVisitor(arma::mat && referenceSet,const size_t leafSize,const double tau,const double rho)143 TrainVisitor<SortPolicy>::TrainVisitor(arma::mat&& referenceSet,
144 const size_t leafSize,
145 const double tau,
146 const double rho) :
147 referenceSet(std::move(referenceSet)),
148 leafSize(leafSize),
149 tau(tau),
150 rho(rho)
151 {}
152
153 //! Default Train on the given NSType instance.
154 template<typename SortPolicy>
155 template<template<typename TreeMetricType,
156 typename TreeStatType,
157 typename TreeMatType> class TreeType>
operator ()(NSTypeT<TreeType> * ns) const158 void TrainVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
159 {
160 if (ns)
161 return ns->Train(std::move(referenceSet));
162 throw std::runtime_error("no neighbor search model initialized");
163 }
164
165 //! Train on the given NSType specialized for KDTrees.
166 template<typename SortPolicy>
operator ()(NSTypeT<tree::KDTree> * ns) const167 void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::KDTree>* ns) const
168 {
169 if (ns)
170 return TrainLeaf(ns);
171 throw std::runtime_error("no neighbor search model initialized");
172 }
173
174 //! Train on the given NSType specialized for BallTrees.
175 template<typename SortPolicy>
operator ()(NSTypeT<tree::BallTree> * ns) const176 void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
177 {
178 if (ns)
179 return TrainLeaf(ns);
180 throw std::runtime_error("no neighbor search model initialized");
181 }
182
183 //! Train specialized for SPTrees.
184 template<typename SortPolicy>
operator ()(SpillKNN * ns) const185 void TrainVisitor<SortPolicy>::operator()(SpillKNN* ns) const
186 {
187 if (ns)
188 {
189 if (ns->SearchMode() == NAIVE_MODE)
190 ns->Train(std::move(referenceSet));
191 else
192 {
193 typename SpillKNN::Tree tree(std::move(referenceSet), tau, leafSize, rho);
194 ns->Train(std::move(tree));
195 }
196 }
197 else
198 throw std::runtime_error("no neighbor search model initialized");
199 }
200
201 //! Train specialized for Octrees.
202 template<typename SortPolicy>
operator ()(NSTypeT<tree::Octree> * ns) const203 void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::Octree>* ns) const
204 {
205 if (ns)
206 return TrainLeaf(ns);
207 throw std::runtime_error("no neighbor search model initialized");
208 }
209
210 //! Train on the given NSType considering the leafSize.
211 template<typename SortPolicy>
212 template<typename NSType>
TrainLeaf(NSType * ns) const213 void TrainVisitor<SortPolicy>::TrainLeaf(NSType* ns) const
214 {
215 if (ns->SearchMode() == NAIVE_MODE)
216 ns->Train(std::move(referenceSet));
217 else
218 {
219 std::vector<size_t> oldFromNewReferences;
220 typename NSType::Tree referenceTree(std::move(referenceSet),
221 oldFromNewReferences, leafSize);
222 ns->Train(std::move(referenceTree));
223 // Set the mappings.
224 ns->oldFromNewReferences = std::move(oldFromNewReferences);
225 }
226 }
227
228 //! Return the search mode.
229 template<typename NSType>
operator ()(NSType * ns) const230 NeighborSearchMode& SearchModeVisitor::operator()(NSType* ns) const
231 {
232 if (ns)
233 return ns->SearchMode();
234 throw std::runtime_error("no neighbor search model initialized");
235 }
236
237 //! Expose the Epsilon method of the given NSType.
238 template<typename NSType>
operator ()(NSType * ns) const239 double& EpsilonVisitor::operator()(NSType* ns) const
240 {
241 if (ns)
242 return ns->Epsilon();
243 throw std::runtime_error("no neighbor search model initialized");
244 }
245
246 //! Expose the referenceSet of the given NSType.
247 template<typename NSType>
operator ()(NSType * ns) const248 const arma::mat& ReferenceSetVisitor::operator()(NSType* ns) const
249 {
250 if (ns)
251 return ns->ReferenceSet();
252 throw std::runtime_error("no neighbor search model initialized");
253 }
254
255 //! Clean memory, if necessary.
256 template<typename NSType>
operator ()(NSType * ns) const257 void DeleteVisitor::operator()(NSType* ns) const
258 {
259 if (ns)
260 delete ns;
261 }
262
263 /**
264 * Initialize the NSModel with the given type and whether or not a random
265 * basis should be used.
266 */
267 template<typename SortPolicy>
NSModel(TreeTypes treeType,bool randomBasis)268 NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
269 treeType(treeType),
270 leafSize(20),
271 tau(0),
272 rho(0.7),
273 randomBasis(randomBasis)
274 {
275 // Nothing to do.
276 }
277
278 template<typename SortPolicy>
NSModel(const NSModel & other)279 NSModel<SortPolicy>::NSModel(const NSModel& other) :
280 treeType(other.treeType),
281 leafSize(other.leafSize),
282 tau(other.tau),
283 rho(other.rho),
284 randomBasis(other.randomBasis),
285 q(other.q),
286 nSearch(other.nSearch)
287 {
288 // Nothing to do.
289 }
290
291 template<typename SortPolicy>
NSModel(NSModel && other)292 NSModel<SortPolicy>::NSModel(NSModel&& other) :
293 treeType(other.treeType),
294 leafSize(other.leafSize),
295 tau(other.tau),
296 rho(other.rho),
297 randomBasis(other.randomBasis),
298 q(std::move(other.q)),
299 nSearch(other.nSearch)
300 {
301 // Reset parameters of the other model.
302 other.treeType = TreeTypes::KD_TREE;
303 other.leafSize = 20;
304 other.tau = 0;
305 other.rho = 0.7;
306 other.randomBasis = false;
307 other.nSearch = decltype(other.nSearch)();
308 }
309
310 template<typename SortPolicy>
operator =(const NSModel & other)311 NSModel<SortPolicy>& NSModel<SortPolicy>::operator=(const NSModel& other)
312 {
313 boost::apply_visitor(DeleteVisitor(), nSearch);
314
315 treeType = other.treeType;
316 leafSize = other.leafSize;
317 tau = other.tau;
318 rho = other.rho;
319 randomBasis = other.randomBasis;
320 q = other.q;
321 nSearch = other.nSearch;
322
323 return *this;
324 }
325
326 template<typename SortPolicy>
operator =(NSModel && other)327 NSModel<SortPolicy>& NSModel<SortPolicy>::operator=(NSModel&& other)
328 {
329 boost::apply_visitor(DeleteVisitor(), nSearch);
330
331 treeType = other.treeType;
332 leafSize = other.leafSize;
333 tau = other.tau;
334 rho = other.rho;
335 randomBasis = other.randomBasis;
336 q = std::move(other.q);
337 // Copy the pointer and type.
338 nSearch = other.nSearch;
339
340 // Reset parameters of the other model.
341 other.treeType = TreeTypes::KD_TREE;
342 other.leafSize = 20;
343 other.tau = 0;
344 other.rho = 0.7;
345 other.randomBasis = false;
346 other.nSearch = decltype(other.nSearch)();
347
348 return *this;
349 }
350
351 //! Clean memory, if necessary.
352 template<typename SortPolicy>
~NSModel()353 NSModel<SortPolicy>::~NSModel()
354 {
355 boost::apply_visitor(DeleteVisitor(), nSearch);
356 }
357
358 //! Serialize the kNN model.
359 template<typename SortPolicy>
360 template<typename Archive>
serialize(Archive & ar,const unsigned int version)361 void NSModel<SortPolicy>::serialize(Archive& ar, const unsigned int version)
362 {
363 ar & BOOST_SERIALIZATION_NVP(treeType);
364 // Backward compatibility: older versions of NSModel didn't include these
365 // parameters.
366 if (version > 0)
367 {
368 ar & BOOST_SERIALIZATION_NVP(leafSize);
369 ar & BOOST_SERIALIZATION_NVP(tau);
370 ar & BOOST_SERIALIZATION_NVP(rho);
371 }
372 ar & BOOST_SERIALIZATION_NVP(randomBasis);
373 ar & BOOST_SERIALIZATION_NVP(q);
374
375 // This should never happen, but just in case, be clean with memory.
376 if (Archive::is_loading::value)
377 boost::apply_visitor(DeleteVisitor(), nSearch);
378
379 ar & BOOST_SERIALIZATION_NVP(nSearch);
380 }
381
382 //! Expose the dataset.
383 template<typename SortPolicy>
Dataset() const384 const arma::mat& NSModel<SortPolicy>::Dataset() const
385 {
386 return boost::apply_visitor(ReferenceSetVisitor(), nSearch);
387 }
388
389 //! Access the search mode.
390 template<typename SortPolicy>
SearchMode() const391 NeighborSearchMode NSModel<SortPolicy>::SearchMode() const
392 {
393 return boost::apply_visitor(SearchModeVisitor(), nSearch);
394 }
395
396 //! Modify the search mode.
397 template<typename SortPolicy>
SearchMode()398 NeighborSearchMode& NSModel<SortPolicy>::SearchMode()
399 {
400 return boost::apply_visitor(SearchModeVisitor(), nSearch);
401 }
402
403 template<typename SortPolicy>
Epsilon() const404 double NSModel<SortPolicy>::Epsilon() const
405 {
406 return boost::apply_visitor(EpsilonVisitor(), nSearch);
407 }
408
409 template<typename SortPolicy>
Epsilon()410 double& NSModel<SortPolicy>::Epsilon()
411 {
412 return boost::apply_visitor(EpsilonVisitor(), nSearch);
413 }
414
415 //! Build the reference tree.
416 template<typename SortPolicy>
BuildModel(arma::mat && referenceSet,const size_t leafSize,const NeighborSearchMode searchMode,const double epsilon)417 void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
418 const size_t leafSize,
419 const NeighborSearchMode searchMode,
420 const double epsilon)
421 {
422 this->leafSize = leafSize;
423 // Initialize random basis if necessary.
424 if (randomBasis)
425 {
426 Log::Info << "Creating random basis..." << std::endl;
427 while (true)
428 {
429 // [Q, R] = qr(randn(d, d));
430 // Q = Q * diag(sign(diag(R)));
431 arma::mat r;
432 if (arma::qr(q, r, arma::randn<arma::mat>(referenceSet.n_rows,
433 referenceSet.n_rows)))
434 {
435 arma::vec rDiag(r.n_rows);
436 for (size_t i = 0; i < rDiag.n_elem; ++i)
437 {
438 if (r(i, i) < 0)
439 rDiag(i) = -1;
440 else if (r(i, i) > 0)
441 rDiag(i) = 1;
442 else
443 rDiag(i) = 0;
444 }
445
446 q *= arma::diagmat(rDiag);
447
448 // Check if the determinant is positive.
449 if (arma::det(q) >= 0)
450 break;
451 }
452 }
453 }
454
455 // Clean memory, if necessary.
456 boost::apply_visitor(DeleteVisitor(), nSearch);
457
458 // Do we need to modify the reference set?
459 if (randomBasis)
460 referenceSet = q * referenceSet;
461
462 if (searchMode != NAIVE_MODE)
463 {
464 Timer::Start("tree_building");
465 Log::Info << "Building reference tree..." << std::endl;
466 }
467
468 switch (treeType)
469 {
470 case KD_TREE:
471 nSearch = new NSType<SortPolicy, tree::KDTree>(searchMode, epsilon);
472 break;
473 case COVER_TREE:
474 nSearch = new NSType<SortPolicy, tree::StandardCoverTree>(searchMode,
475 epsilon);
476 break;
477 case R_TREE:
478 nSearch = new NSType<SortPolicy, tree::RTree>(searchMode, epsilon);
479 break;
480 case R_STAR_TREE:
481 nSearch = new NSType<SortPolicy, tree::RStarTree>(searchMode, epsilon);
482 break;
483 case BALL_TREE:
484 nSearch = new NSType<SortPolicy, tree::BallTree>(searchMode, epsilon);
485 break;
486 case X_TREE:
487 nSearch = new NSType<SortPolicy, tree::XTree>(searchMode, epsilon);
488 break;
489 case HILBERT_R_TREE:
490 nSearch = new NSType<SortPolicy, tree::HilbertRTree>(searchMode, epsilon);
491 break;
492 case R_PLUS_TREE:
493 nSearch = new NSType<SortPolicy, tree::RPlusTree>(searchMode, epsilon);
494 break;
495 case R_PLUS_PLUS_TREE:
496 nSearch = new NSType<SortPolicy, tree::RPlusPlusTree>(searchMode,
497 epsilon);
498 break;
499 case VP_TREE:
500 nSearch = new NSType<SortPolicy, tree::VPTree>(searchMode, epsilon);
501 break;
502 case RP_TREE:
503 nSearch = new NSType<SortPolicy, tree::RPTree>(searchMode, epsilon);
504 break;
505 case MAX_RP_TREE:
506 nSearch = new NSType<SortPolicy, tree::MaxRPTree>(searchMode, epsilon);
507 break;
508 case SPILL_TREE:
509 nSearch = new SpillKNN(searchMode, epsilon);
510 break;
511 case UB_TREE:
512 nSearch = new NSType<SortPolicy, tree::UBTree>(searchMode, epsilon);
513 break;
514 case OCTREE:
515 nSearch = new NSType<SortPolicy, tree::Octree>(searchMode, epsilon);
516 break;
517 }
518
519 TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau, rho);
520 boost::apply_visitor(tn, nSearch);
521
522 if (searchMode != NAIVE_MODE)
523 {
524 Timer::Stop("tree_building");
525 Log::Info << "Tree built." << std::endl;
526 }
527 }
528
529 //! Perform neighbor search. The query set will be reordered.
530 template<typename SortPolicy>
Search(arma::mat && querySet,const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances)531 void NSModel<SortPolicy>::Search(arma::mat&& querySet,
532 const size_t k,
533 arma::Mat<size_t>& neighbors,
534 arma::mat& distances)
535 {
536 // We may need to map the query set randomly.
537 if (randomBasis)
538 querySet = q * querySet;
539
540 Log::Info << "Searching for " << k << " neighbors with ";
541
542 switch (SearchMode())
543 {
544 case NAIVE_MODE:
545 Log::Info << "brute-force (naive) search..." << std::endl;
546 break;
547 case SINGLE_TREE_MODE:
548 Log::Info << "single-tree " << TreeName() << " search..." << std::endl;
549 break;
550 case DUAL_TREE_MODE:
551 Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
552 break;
553 case GREEDY_SINGLE_TREE_MODE:
554 Log::Info << "greedy single-tree " << TreeName() << " search..."
555 << std::endl;
556 break;
557 }
558
559 BiSearchVisitor<SortPolicy> search(querySet, k, neighbors, distances,
560 leafSize, tau, rho);
561 boost::apply_visitor(search, nSearch);
562 }
563
564 //! Perform neighbor search.
565 template<typename SortPolicy>
Search(const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances)566 void NSModel<SortPolicy>::Search(const size_t k,
567 arma::Mat<size_t>& neighbors,
568 arma::mat& distances)
569 {
570 Log::Info << "Searching for " << k << " neighbors with ";
571
572 switch (SearchMode())
573 {
574 case NAIVE_MODE:
575 Log::Info << "brute-force (naive) search..." << std::endl;
576 break;
577 case SINGLE_TREE_MODE:
578 Log::Info << "single-tree " << TreeName() << " search..." << std::endl;
579 break;
580 case DUAL_TREE_MODE:
581 Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
582 break;
583 case GREEDY_SINGLE_TREE_MODE:
584 Log::Info << "greedy single-tree " << TreeName() << " search..."
585 << std::endl;
586 break;
587 }
588
589 if (Epsilon() != 0 && SearchMode() != NAIVE_MODE)
590 Log::Info << "Maximum of " << Epsilon() * 100 << "% relative error."
591 << std::endl;
592
593 MonoSearchVisitor search(k, neighbors, distances);
594 boost::apply_visitor(search, nSearch);
595 }
596
597 //! Get the name of the tree type.
598 template<typename SortPolicy>
TreeName() const599 std::string NSModel<SortPolicy>::TreeName() const
600 {
601 switch (treeType)
602 {
603 case KD_TREE:
604 return "kd-tree";
605 case COVER_TREE:
606 return "cover tree";
607 case R_TREE:
608 return "R tree";
609 case R_STAR_TREE:
610 return "R* tree";
611 case BALL_TREE:
612 return "ball tree";
613 case X_TREE:
614 return "X tree";
615 case HILBERT_R_TREE:
616 return "Hilbert R tree";
617 case R_PLUS_TREE:
618 return "R+ tree";
619 case R_PLUS_PLUS_TREE:
620 return "R++ tree";
621 case SPILL_TREE:
622 return "Spill tree";
623 case VP_TREE:
624 return "vantage point tree";
625 case RP_TREE:
626 return "random projection tree (mean split)";
627 case MAX_RP_TREE:
628 return "random projection tree (max split)";
629 case UB_TREE:
630 return "UB tree";
631 case OCTREE:
632 return "octree";
633 default:
634 return "unknown tree";
635 }
636 }
637
638 } // namespace neighbor
639 } // namespace mlpack
640
641 #endif
642