1 /**
2  * @file methods/neighbor_search/ns_model_impl.hpp
3  * @author Ryan Curtin
4  *
5  * This is a model for nearest or furthest neighbor search.  It is useful in
6  * that it provides an easy way to serialize a model, abstracts away the
7  * different types of trees, and also reflects the NeighborSearch API and
8  * automatically directs to the right tree type.
9  *
10  * mlpack is free software; you may redistribute it and/or modify it under the
11  * terms of the 3-clause BSD license.  You should have received a copy of the
12  * 3-clause BSD license along with mlpack.  If not, see
13  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
14  */
15 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_IMPL_HPP
16 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_IMPL_HPP
17 
18 // In case it hasn't been included yet.
19 #include "ns_model.hpp"
20 
21 #include <boost/serialization/variant.hpp>
22 
23 namespace mlpack {
24 namespace neighbor {
25 
26 //! Monochromatic neighbor search on the given NSType instance.
27 template<typename NSType>
operator ()(NSType * ns) const28 void MonoSearchVisitor::operator()(NSType *ns) const
29 {
30   if (ns)
31     return ns->Search(k, neighbors, distances);
32   throw std::runtime_error("no neighbor search model initialized");
33 }
34 
35 //! Save parameters for bichromatic neighbor search.
36 template<typename SortPolicy>
BiSearchVisitor(const arma::mat & querySet,const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances,const size_t leafSize,const double tau,const double rho)37 BiSearchVisitor<SortPolicy>::BiSearchVisitor(const arma::mat& querySet,
38                                              const size_t k,
39                                              arma::Mat<size_t>& neighbors,
40                                              arma::mat& distances,
41                                              const size_t leafSize,
42                                              const double tau,
43                                              const double rho) :
44     querySet(querySet),
45     k(k),
46     neighbors(neighbors),
47     distances(distances),
48     leafSize(leafSize),
49     tau(tau),
50     rho(rho)
51 {}
52 
53 //! Default Bichromatic neighbor search on the given NSType instance.
54 template<typename SortPolicy>
55 template<template<typename TreeMetricType,
56                   typename TreeStatType,
57                   typename TreeMatType> class TreeType>
operator ()(NSTypeT<TreeType> * ns) const58 void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
59 {
60   if (ns)
61     return ns->Search(querySet, k, neighbors, distances);
62   throw std::runtime_error("no neighbor search model initialized");
63 }
64 
65 //! Bichromatic neighbor search on the given NSType specialized for KDTrees.
66 template<typename SortPolicy>
operator ()(NSTypeT<tree::KDTree> * ns) const67 void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::KDTree>* ns) const
68 {
69   if (ns)
70     return SearchLeaf(ns);
71   throw std::runtime_error("no neighbor search model initialized");
72 }
73 
74 //! Bichromatic neighbor search on the given NSType specialized for BallTrees.
75 template<typename SortPolicy>
operator ()(NSTypeT<tree::BallTree> * ns) const76 void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
77 {
78   if (ns)
79     return SearchLeaf(ns);
80   throw std::runtime_error("no neighbor search model initialized");
81 }
82 
83 //! Bichromatic neighbor search specialized for SPTrees.
84 template<typename SortPolicy>
operator ()(SpillKNN * ns) const85 void BiSearchVisitor<SortPolicy>::operator()(SpillKNN* ns) const
86 {
87   if (ns)
88   {
89     if (ns->SearchMode() == DUAL_TREE_MODE)
90     {
91       // For Dual Tree Search on SpillTrees, the queryTree must be built with
92       // non overlapping (tau = 0).
93       typename SpillKNN::Tree queryTree(std::move(querySet), 0 /* tau*/,
94           leafSize, rho);
95       ns->Search(queryTree, k, neighbors, distances);
96     }
97     else
98       ns->Search(querySet, k, neighbors, distances);
99   }
100   else
101     throw std::runtime_error("no neighbor search model initialized");
102 }
103 
104 //! Bichromatic neighbor search specialized for octrees.
105 template<typename SortPolicy>
operator ()(NSTypeT<tree::Octree> * ns) const106 void BiSearchVisitor<SortPolicy>::operator()(NSTypeT<tree::Octree>* ns) const
107 {
108   if (ns)
109     return SearchLeaf(ns);
110   throw std::runtime_error("no neighbor search model initialized");
111 }
112 
113 //! Bichromatic neighbor search on the given NSType considering the leafSize.
114 template<typename SortPolicy>
115 template<typename NSType>
SearchLeaf(NSType * ns) const116 void BiSearchVisitor<SortPolicy>::SearchLeaf(NSType* ns) const
117 {
118   if (ns->SearchMode() == DUAL_TREE_MODE)
119   {
120     std::vector<size_t> oldFromNewQueries;
121     typename NSType::Tree queryTree(std::move(querySet), oldFromNewQueries,
122         leafSize);
123 
124     arma::Mat<size_t> neighborsOut;
125     arma::mat distancesOut;
126     ns->Search(queryTree, k, neighborsOut, distancesOut);
127 
128     // Unmap the query points.
129     distances.set_size(distancesOut.n_rows, distancesOut.n_cols);
130     neighbors.set_size(neighborsOut.n_rows, neighborsOut.n_cols);
131     for (size_t i = 0; i < neighborsOut.n_cols; ++i)
132     {
133       neighbors.col(oldFromNewQueries[i]) = neighborsOut.col(i);
134       distances.col(oldFromNewQueries[i]) = distancesOut.col(i);
135     }
136   }
137   else
138     ns->Search(querySet, k, neighbors, distances);
139 }
140 
141 //! Save parameters for Train.
142 template<typename SortPolicy>
TrainVisitor(arma::mat && referenceSet,const size_t leafSize,const double tau,const double rho)143 TrainVisitor<SortPolicy>::TrainVisitor(arma::mat&& referenceSet,
144                                        const size_t leafSize,
145                                        const double tau,
146                                        const double rho) :
147     referenceSet(std::move(referenceSet)),
148     leafSize(leafSize),
149     tau(tau),
150     rho(rho)
151 {}
152 
153 //! Default Train on the given NSType instance.
154 template<typename SortPolicy>
155 template<template<typename TreeMetricType,
156                   typename TreeStatType,
157                   typename TreeMatType> class TreeType>
operator ()(NSTypeT<TreeType> * ns) const158 void TrainVisitor<SortPolicy>::operator()(NSTypeT<TreeType>* ns) const
159 {
160   if (ns)
161     return ns->Train(std::move(referenceSet));
162   throw std::runtime_error("no neighbor search model initialized");
163 }
164 
165 //! Train on the given NSType specialized for KDTrees.
166 template<typename SortPolicy>
operator ()(NSTypeT<tree::KDTree> * ns) const167 void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::KDTree>* ns) const
168 {
169   if (ns)
170     return TrainLeaf(ns);
171   throw std::runtime_error("no neighbor search model initialized");
172 }
173 
174 //! Train on the given NSType specialized for BallTrees.
175 template<typename SortPolicy>
operator ()(NSTypeT<tree::BallTree> * ns) const176 void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::BallTree>* ns) const
177 {
178   if (ns)
179     return TrainLeaf(ns);
180   throw std::runtime_error("no neighbor search model initialized");
181 }
182 
183 //! Train specialized for SPTrees.
184 template<typename SortPolicy>
operator ()(SpillKNN * ns) const185 void TrainVisitor<SortPolicy>::operator()(SpillKNN* ns) const
186 {
187   if (ns)
188   {
189     if (ns->SearchMode() == NAIVE_MODE)
190       ns->Train(std::move(referenceSet));
191     else
192     {
193       typename SpillKNN::Tree tree(std::move(referenceSet), tau, leafSize, rho);
194       ns->Train(std::move(tree));
195     }
196   }
197   else
198     throw std::runtime_error("no neighbor search model initialized");
199 }
200 
201 //! Train specialized for Octrees.
202 template<typename SortPolicy>
operator ()(NSTypeT<tree::Octree> * ns) const203 void TrainVisitor<SortPolicy>::operator()(NSTypeT<tree::Octree>* ns) const
204 {
205   if (ns)
206     return TrainLeaf(ns);
207   throw std::runtime_error("no neighbor search model initialized");
208 }
209 
210 //! Train on the given NSType considering the leafSize.
211 template<typename SortPolicy>
212 template<typename NSType>
TrainLeaf(NSType * ns) const213 void TrainVisitor<SortPolicy>::TrainLeaf(NSType* ns) const
214 {
215   if (ns->SearchMode() == NAIVE_MODE)
216     ns->Train(std::move(referenceSet));
217   else
218   {
219     std::vector<size_t> oldFromNewReferences;
220     typename NSType::Tree referenceTree(std::move(referenceSet),
221         oldFromNewReferences, leafSize);
222     ns->Train(std::move(referenceTree));
223     // Set the mappings.
224     ns->oldFromNewReferences = std::move(oldFromNewReferences);
225   }
226 }
227 
228 //! Return the search mode.
229 template<typename NSType>
operator ()(NSType * ns) const230 NeighborSearchMode& SearchModeVisitor::operator()(NSType* ns) const
231 {
232   if (ns)
233     return ns->SearchMode();
234   throw std::runtime_error("no neighbor search model initialized");
235 }
236 
237 //! Expose the Epsilon method of the given NSType.
238 template<typename NSType>
operator ()(NSType * ns) const239 double& EpsilonVisitor::operator()(NSType* ns) const
240 {
241   if (ns)
242     return ns->Epsilon();
243   throw std::runtime_error("no neighbor search model initialized");
244 }
245 
246 //! Expose the referenceSet of the given NSType.
247 template<typename NSType>
operator ()(NSType * ns) const248 const arma::mat& ReferenceSetVisitor::operator()(NSType* ns) const
249 {
250   if (ns)
251     return ns->ReferenceSet();
252   throw std::runtime_error("no neighbor search model initialized");
253 }
254 
255 //! Clean memory, if necessary.
256 template<typename NSType>
operator ()(NSType * ns) const257 void DeleteVisitor::operator()(NSType* ns) const
258 {
259   if (ns)
260     delete ns;
261 }
262 
263 /**
264  * Initialize the NSModel with the given type and whether or not a random
265  * basis should be used.
266  */
267 template<typename SortPolicy>
NSModel(TreeTypes treeType,bool randomBasis)268 NSModel<SortPolicy>::NSModel(TreeTypes treeType, bool randomBasis) :
269     treeType(treeType),
270     leafSize(20),
271     tau(0),
272     rho(0.7),
273     randomBasis(randomBasis)
274 {
275   // Nothing to do.
276 }
277 
278 template<typename SortPolicy>
NSModel(const NSModel & other)279 NSModel<SortPolicy>::NSModel(const NSModel& other) :
280     treeType(other.treeType),
281     leafSize(other.leafSize),
282     tau(other.tau),
283     rho(other.rho),
284     randomBasis(other.randomBasis),
285     q(other.q),
286     nSearch(other.nSearch)
287 {
288   // Nothing to do.
289 }
290 
291 template<typename SortPolicy>
NSModel(NSModel && other)292 NSModel<SortPolicy>::NSModel(NSModel&& other) :
293     treeType(other.treeType),
294     leafSize(other.leafSize),
295     tau(other.tau),
296     rho(other.rho),
297     randomBasis(other.randomBasis),
298     q(std::move(other.q)),
299     nSearch(other.nSearch)
300 {
301   // Reset parameters of the other model.
302   other.treeType = TreeTypes::KD_TREE;
303   other.leafSize = 20;
304   other.tau = 0;
305   other.rho = 0.7;
306   other.randomBasis = false;
307   other.nSearch = decltype(other.nSearch)();
308 }
309 
310 template<typename SortPolicy>
operator =(const NSModel & other)311 NSModel<SortPolicy>& NSModel<SortPolicy>::operator=(const NSModel& other)
312 {
313   boost::apply_visitor(DeleteVisitor(), nSearch);
314 
315   treeType = other.treeType;
316   leafSize = other.leafSize;
317   tau = other.tau;
318   rho = other.rho;
319   randomBasis = other.randomBasis;
320   q = other.q;
321   nSearch = other.nSearch;
322 
323   return *this;
324 }
325 
326 template<typename SortPolicy>
operator =(NSModel && other)327 NSModel<SortPolicy>& NSModel<SortPolicy>::operator=(NSModel&& other)
328 {
329   boost::apply_visitor(DeleteVisitor(), nSearch);
330 
331   treeType = other.treeType;
332   leafSize = other.leafSize;
333   tau = other.tau;
334   rho = other.rho;
335   randomBasis = other.randomBasis;
336   q = std::move(other.q);
337   // Copy the pointer and type.
338   nSearch = other.nSearch;
339 
340   // Reset parameters of the other model.
341   other.treeType = TreeTypes::KD_TREE;
342   other.leafSize = 20;
343   other.tau = 0;
344   other.rho = 0.7;
345   other.randomBasis = false;
346   other.nSearch = decltype(other.nSearch)();
347 
348   return *this;
349 }
350 
351 //! Clean memory, if necessary.
352 template<typename SortPolicy>
~NSModel()353 NSModel<SortPolicy>::~NSModel()
354 {
355   boost::apply_visitor(DeleteVisitor(), nSearch);
356 }
357 
358 //! Serialize the kNN model.
359 template<typename SortPolicy>
360 template<typename Archive>
serialize(Archive & ar,const unsigned int version)361 void NSModel<SortPolicy>::serialize(Archive& ar, const unsigned int version)
362 {
363   ar & BOOST_SERIALIZATION_NVP(treeType);
364   // Backward compatibility: older versions of NSModel didn't include these
365   // parameters.
366   if (version > 0)
367   {
368     ar & BOOST_SERIALIZATION_NVP(leafSize);
369     ar & BOOST_SERIALIZATION_NVP(tau);
370     ar & BOOST_SERIALIZATION_NVP(rho);
371   }
372   ar & BOOST_SERIALIZATION_NVP(randomBasis);
373   ar & BOOST_SERIALIZATION_NVP(q);
374 
375   // This should never happen, but just in case, be clean with memory.
376   if (Archive::is_loading::value)
377     boost::apply_visitor(DeleteVisitor(), nSearch);
378 
379   ar & BOOST_SERIALIZATION_NVP(nSearch);
380 }
381 
382 //! Expose the dataset.
383 template<typename SortPolicy>
Dataset() const384 const arma::mat& NSModel<SortPolicy>::Dataset() const
385 {
386   return boost::apply_visitor(ReferenceSetVisitor(), nSearch);
387 }
388 
389 //! Access the search mode.
390 template<typename SortPolicy>
SearchMode() const391 NeighborSearchMode NSModel<SortPolicy>::SearchMode() const
392 {
393   return boost::apply_visitor(SearchModeVisitor(), nSearch);
394 }
395 
396 //! Modify the search mode.
397 template<typename SortPolicy>
SearchMode()398 NeighborSearchMode& NSModel<SortPolicy>::SearchMode()
399 {
400   return boost::apply_visitor(SearchModeVisitor(), nSearch);
401 }
402 
403 template<typename SortPolicy>
Epsilon() const404 double NSModel<SortPolicy>::Epsilon() const
405 {
406   return boost::apply_visitor(EpsilonVisitor(), nSearch);
407 }
408 
409 template<typename SortPolicy>
Epsilon()410 double& NSModel<SortPolicy>::Epsilon()
411 {
412   return boost::apply_visitor(EpsilonVisitor(), nSearch);
413 }
414 
415 //! Build the reference tree.
416 template<typename SortPolicy>
BuildModel(arma::mat && referenceSet,const size_t leafSize,const NeighborSearchMode searchMode,const double epsilon)417 void NSModel<SortPolicy>::BuildModel(arma::mat&& referenceSet,
418                                      const size_t leafSize,
419                                      const NeighborSearchMode searchMode,
420                                      const double epsilon)
421 {
422   this->leafSize = leafSize;
423   // Initialize random basis if necessary.
424   if (randomBasis)
425   {
426     Log::Info << "Creating random basis..." << std::endl;
427     while (true)
428     {
429       // [Q, R] = qr(randn(d, d));
430       // Q = Q * diag(sign(diag(R)));
431       arma::mat r;
432       if (arma::qr(q, r, arma::randn<arma::mat>(referenceSet.n_rows,
433               referenceSet.n_rows)))
434       {
435         arma::vec rDiag(r.n_rows);
436         for (size_t i = 0; i < rDiag.n_elem; ++i)
437         {
438           if (r(i, i) < 0)
439             rDiag(i) = -1;
440           else if (r(i, i) > 0)
441             rDiag(i) = 1;
442           else
443             rDiag(i) = 0;
444         }
445 
446         q *= arma::diagmat(rDiag);
447 
448         // Check if the determinant is positive.
449         if (arma::det(q) >= 0)
450           break;
451       }
452     }
453   }
454 
455   // Clean memory, if necessary.
456   boost::apply_visitor(DeleteVisitor(), nSearch);
457 
458   // Do we need to modify the reference set?
459   if (randomBasis)
460     referenceSet = q * referenceSet;
461 
462   if (searchMode != NAIVE_MODE)
463   {
464     Timer::Start("tree_building");
465     Log::Info << "Building reference tree..." << std::endl;
466   }
467 
468   switch (treeType)
469   {
470     case KD_TREE:
471       nSearch = new NSType<SortPolicy, tree::KDTree>(searchMode, epsilon);
472       break;
473     case COVER_TREE:
474       nSearch = new NSType<SortPolicy, tree::StandardCoverTree>(searchMode,
475           epsilon);
476       break;
477     case R_TREE:
478       nSearch = new NSType<SortPolicy, tree::RTree>(searchMode, epsilon);
479       break;
480     case R_STAR_TREE:
481       nSearch = new NSType<SortPolicy, tree::RStarTree>(searchMode, epsilon);
482       break;
483     case BALL_TREE:
484       nSearch = new NSType<SortPolicy, tree::BallTree>(searchMode, epsilon);
485       break;
486     case X_TREE:
487       nSearch = new NSType<SortPolicy, tree::XTree>(searchMode, epsilon);
488       break;
489     case HILBERT_R_TREE:
490       nSearch = new NSType<SortPolicy, tree::HilbertRTree>(searchMode, epsilon);
491       break;
492     case R_PLUS_TREE:
493       nSearch = new NSType<SortPolicy, tree::RPlusTree>(searchMode, epsilon);
494       break;
495     case R_PLUS_PLUS_TREE:
496       nSearch = new NSType<SortPolicy, tree::RPlusPlusTree>(searchMode,
497           epsilon);
498       break;
499     case VP_TREE:
500       nSearch = new NSType<SortPolicy, tree::VPTree>(searchMode, epsilon);
501       break;
502     case RP_TREE:
503       nSearch = new NSType<SortPolicy, tree::RPTree>(searchMode, epsilon);
504       break;
505     case MAX_RP_TREE:
506       nSearch = new NSType<SortPolicy, tree::MaxRPTree>(searchMode, epsilon);
507       break;
508     case SPILL_TREE:
509       nSearch = new SpillKNN(searchMode, epsilon);
510       break;
511     case UB_TREE:
512       nSearch = new NSType<SortPolicy, tree::UBTree>(searchMode, epsilon);
513       break;
514     case OCTREE:
515       nSearch = new NSType<SortPolicy, tree::Octree>(searchMode, epsilon);
516       break;
517   }
518 
519   TrainVisitor<SortPolicy> tn(std::move(referenceSet), leafSize, tau, rho);
520   boost::apply_visitor(tn, nSearch);
521 
522   if (searchMode != NAIVE_MODE)
523   {
524     Timer::Stop("tree_building");
525     Log::Info << "Tree built." << std::endl;
526   }
527 }
528 
529 //! Perform neighbor search.  The query set will be reordered.
530 template<typename SortPolicy>
Search(arma::mat && querySet,const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances)531 void NSModel<SortPolicy>::Search(arma::mat&& querySet,
532                                  const size_t k,
533                                  arma::Mat<size_t>& neighbors,
534                                  arma::mat& distances)
535 {
536   // We may need to map the query set randomly.
537   if (randomBasis)
538     querySet = q * querySet;
539 
540   Log::Info << "Searching for " << k << " neighbors with ";
541 
542   switch (SearchMode())
543   {
544     case NAIVE_MODE:
545       Log::Info << "brute-force (naive) search..." << std::endl;
546       break;
547     case SINGLE_TREE_MODE:
548       Log::Info << "single-tree " << TreeName() << " search..." << std::endl;
549       break;
550     case DUAL_TREE_MODE:
551       Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
552       break;
553     case GREEDY_SINGLE_TREE_MODE:
554       Log::Info << "greedy single-tree " << TreeName() << " search..."
555           << std::endl;
556       break;
557   }
558 
559   BiSearchVisitor<SortPolicy> search(querySet, k, neighbors, distances,
560       leafSize, tau, rho);
561   boost::apply_visitor(search, nSearch);
562 }
563 
564 //! Perform neighbor search.
565 template<typename SortPolicy>
Search(const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances)566 void NSModel<SortPolicy>::Search(const size_t k,
567                                  arma::Mat<size_t>& neighbors,
568                                  arma::mat& distances)
569 {
570   Log::Info << "Searching for " << k << " neighbors with ";
571 
572   switch (SearchMode())
573   {
574     case NAIVE_MODE:
575       Log::Info << "brute-force (naive) search..." << std::endl;
576       break;
577     case SINGLE_TREE_MODE:
578       Log::Info << "single-tree " << TreeName() << " search..." << std::endl;
579       break;
580     case DUAL_TREE_MODE:
581       Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
582       break;
583     case GREEDY_SINGLE_TREE_MODE:
584       Log::Info << "greedy single-tree " << TreeName() << " search..."
585           << std::endl;
586       break;
587   }
588 
589   if (Epsilon() != 0 && SearchMode() != NAIVE_MODE)
590     Log::Info << "Maximum of " << Epsilon() * 100 << "% relative error."
591         << std::endl;
592 
593   MonoSearchVisitor search(k, neighbors, distances);
594   boost::apply_visitor(search, nSearch);
595 }
596 
597 //! Get the name of the tree type.
598 template<typename SortPolicy>
TreeName() const599 std::string NSModel<SortPolicy>::TreeName() const
600 {
601   switch (treeType)
602   {
603     case KD_TREE:
604       return "kd-tree";
605     case COVER_TREE:
606       return "cover tree";
607     case R_TREE:
608       return "R tree";
609     case R_STAR_TREE:
610       return "R* tree";
611     case BALL_TREE:
612       return "ball tree";
613     case X_TREE:
614       return "X tree";
615     case HILBERT_R_TREE:
616       return "Hilbert R tree";
617     case R_PLUS_TREE:
618       return "R+ tree";
619     case R_PLUS_PLUS_TREE:
620       return "R++ tree";
621     case SPILL_TREE:
622       return "Spill tree";
623     case VP_TREE:
624       return "vantage point tree";
625     case RP_TREE:
626       return "random projection tree (mean split)";
627     case MAX_RP_TREE:
628       return "random projection tree (max split)";
629     case UB_TREE:
630       return "UB tree";
631     case OCTREE:
632       return "octree";
633     default:
634       return "unknown tree";
635   }
636 }
637 
638 } // namespace neighbor
639 } // namespace mlpack
640 
641 #endif
642