1 /**
2  * @file methods/range_search/rs_model_impl.hpp
3  * @author Ryan Curtin
4  *
5  * Implementation of serialize() and inline functions for RSModel.
6  *
7  * mlpack is free software; you may redistribute it and/or modify it under the
8  * terms of the 3-clause BSD license.  You should have received a copy of the
9  * 3-clause BSD license along with mlpack.  If not, see
10  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11  */
12 #ifndef MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_IMPL_HPP
13 #define MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "rs_model.hpp"
17 
18 #include <mlpack/core/math/random_basis.hpp>
19 #include <boost/serialization/variant.hpp>
20 
21 namespace mlpack {
22 namespace range {
23 
24 /**
25  * Initialize the RSModel with the given tree type and whether or not a random
26  * basis should be used.
27  */
RSModel(TreeTypes treeType,bool randomBasis)28 inline RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
29     treeType(treeType),
30     leafSize(0),
31     randomBasis(randomBasis)
32 {
33   // Nothing to do.
34 }
35 
36 // Copy constructor.
RSModel(const RSModel & other)37 inline RSModel::RSModel(const RSModel& other) :
38     treeType(other.treeType),
39     leafSize(other.leafSize),
40     randomBasis(other.randomBasis),
41     q(other.q),
42     rSearch(other.rSearch)
43 {
44   // Nothing to do.
45 }
46 
47 // Move constructor.
RSModel(RSModel && other)48 inline RSModel::RSModel(RSModel&& other) :
49     treeType(other.treeType),
50     leafSize(other.leafSize),
51     randomBasis(other.randomBasis),
52     q(std::move(other.q)),
53     rSearch(std::move(other.rSearch))
54 {
55   // Reset other model.
56   other.treeType = TreeTypes::KD_TREE;
57   other.leafSize = 0;
58   other.randomBasis = false;
59   other.rSearch = decltype(other.rSearch)();
60 }
61 
operator =(RSModel other)62 inline RSModel& RSModel::operator=(RSModel other)
63 {
64   boost::apply_visitor(DeleteVisitor(), rSearch);
65 
66   treeType = other.treeType;
67   leafSize = other.leafSize;
68   randomBasis = other.randomBasis;
69   q = std::move(other.q);
70   rSearch = std::move(other.rSearch);
71 
72   return *this;
73 }
74 
75 // Clean memory, if necessary.
~RSModel()76 inline RSModel::~RSModel()
77 {
78   boost::apply_visitor(DeleteVisitor(), rSearch);
79 }
80 
BuildModel(arma::mat && referenceSet,const size_t leafSize,const bool naive,const bool singleMode)81 inline void RSModel::BuildModel(arma::mat&& referenceSet,
82                                 const size_t leafSize,
83                                 const bool naive,
84                                 const bool singleMode)
85 {
86   // Initialize random basis if necessary.
87   if (randomBasis)
88   {
89     Log::Info << "Creating random basis..." << std::endl;
90     math::RandomBasis(q, referenceSet.n_rows);
91   }
92 
93   this->leafSize = leafSize;
94 
95   // Clean memory, if necessary.
96   boost::apply_visitor(DeleteVisitor(), rSearch);
97 
98   // Do we need to modify the reference set?
99   if (randomBasis)
100     referenceSet = q * referenceSet;
101 
102   if (!naive)
103   {
104     Timer::Start("tree_building");
105     Log::Info << "Building reference tree..." << std::endl;
106   }
107 
108   switch (treeType)
109   {
110     case KD_TREE:
111       rSearch = new RSType<tree::KDTree> (naive, singleMode);
112       break;
113 
114     case COVER_TREE:
115       rSearch = new RSType<tree::StandardCoverTree>(naive, singleMode);
116       break;
117 
118     case R_TREE:
119       rSearch = new RSType<tree::RTree>(naive, singleMode);
120       break;
121 
122     case R_STAR_TREE:
123       rSearch = new RSType<tree::RStarTree>(naive, singleMode);
124       break;
125 
126     case BALL_TREE:
127       rSearch = new RSType<tree::BallTree>(naive, singleMode);
128       break;
129 
130     case X_TREE:
131       rSearch = new RSType<tree::XTree>(naive, singleMode);
132       break;
133 
134     case HILBERT_R_TREE:
135       rSearch = new RSType<tree::HilbertRTree>(naive, singleMode);
136       break;
137 
138     case R_PLUS_TREE:
139       rSearch = new RSType<tree::RPlusTree>(naive, singleMode);
140       break;
141 
142     case R_PLUS_PLUS_TREE:
143       rSearch = new RSType<tree::RPlusPlusTree>(naive, singleMode);
144       break;
145 
146     case VP_TREE:
147       rSearch = new RSType<tree::VPTree>(naive, singleMode);
148       break;
149 
150     case RP_TREE:
151       rSearch = new RSType<tree::RPTree>(naive, singleMode);
152       break;
153 
154     case MAX_RP_TREE:
155       rSearch = new RSType<tree::MaxRPTree>(naive, singleMode);
156       break;
157 
158     case UB_TREE:
159       rSearch = new RSType<tree::UBTree>(naive, singleMode);
160       break;
161 
162     case OCTREE:
163       rSearch = new RSType<tree::Octree>(naive, singleMode);
164       break;
165   }
166 
167   TrainVisitor tn(std::move(referenceSet), leafSize);
168   boost::apply_visitor(tn, rSearch);
169 
170   if (!naive)
171   {
172     Timer::Stop("tree_building");
173     Log::Info << "Tree built." << std::endl;
174   }
175 }
176 
177 // Perform range search.
Search(arma::mat && querySet,const math::Range & range,std::vector<std::vector<size_t>> & neighbors,std::vector<std::vector<double>> & distances)178 inline void RSModel::Search(arma::mat&& querySet,
179                             const math::Range& range,
180                             std::vector<std::vector<size_t>>& neighbors,
181                             std::vector<std::vector<double>>& distances)
182 {
183   // We may need to map the query set randomly.
184   if (randomBasis)
185     querySet = q * querySet;
186 
187   Log::Info << "Search for points in the range [" << range.Lo() << ", "
188       << range.Hi() << "] with ";
189   if (!Naive() && !SingleMode())
190     Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
191   else if (!Naive())
192     Log::Info << "single-tree " << TreeName() << " search..." << std::endl;
193   else
194     Log::Info << "brute-force (naive) search..." << std::endl;
195 
196 
197   BiSearchVisitor search(querySet, range, neighbors, distances,
198       leafSize);
199   boost::apply_visitor(search, rSearch);
200 }
201 
202 // Perform range search (monochromatic case).
Search(const math::Range & range,std::vector<std::vector<size_t>> & neighbors,std::vector<std::vector<double>> & distances)203 inline void RSModel::Search(const math::Range& range,
204                             std::vector<std::vector<size_t>>& neighbors,
205                             std::vector<std::vector<double>>& distances)
206 {
207   Log::Info << "Search for points in the range [" << range.Lo() << ", "
208       << range.Hi() << "] with ";
209   if (!Naive() && !SingleMode())
210     Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
211   else if (!Naive())
212     Log::Info << "single-tree " << TreeName() << " search..." << std::endl;
213   else
214     Log::Info << "brute-force (naive) search..." << std::endl;
215 
216   MonoSearchVisitor search(range, neighbors, distances);
217   boost::apply_visitor(search, rSearch);
218 }
219 
220 // Get the name of the tree type.
TreeName() const221 inline std::string RSModel::TreeName() const
222 {
223   switch (treeType)
224   {
225     case KD_TREE:
226       return "kd-tree";
227     case COVER_TREE:
228       return "cover tree";
229     case R_TREE:
230       return "R tree";
231     case R_STAR_TREE:
232       return "R* tree";
233     case BALL_TREE:
234       return "ball tree";
235     case X_TREE:
236       return "X tree";
237     case HILBERT_R_TREE:
238       return "Hilbert R tree";
239     case R_PLUS_TREE:
240       return "R+ tree";
241     case R_PLUS_PLUS_TREE:
242       return "R++ tree";
243     case VP_TREE:
244       return "vantage point tree";
245     case RP_TREE:
246       return "random projection tree (mean split)";
247     case MAX_RP_TREE:
248       return "random projection tree (max split)";
249     case UB_TREE:
250       return "UB tree";
251     case OCTREE:
252       return "octree";
253     default:
254       return "unknown tree";
255   }
256 }
257 
258 // Clean memory.
CleanMemory()259 inline void RSModel::CleanMemory()
260 {
261   boost::apply_visitor(DeleteVisitor(), rSearch);
262 }
263 
264 //! Monochromatic range search on the given RSType instance.
265 template<typename RSType>
operator ()(RSType * rs) const266 void MonoSearchVisitor::operator()(RSType* rs) const
267 {
268   if (rs)
269     return rs->Search(range, neighbors, distances);
270   throw std::runtime_error("no range search model initialized");
271 }
272 
273 //! Save parameters for bichromatic range search.
BiSearchVisitor(const arma::mat & querySet,const math::Range & range,std::vector<std::vector<size_t>> & neighbors,std::vector<std::vector<double>> & distances,const size_t leafSize)274 inline BiSearchVisitor::BiSearchVisitor(
275     const arma::mat& querySet,
276     const math::Range& range,
277     std::vector<std::vector<size_t>>& neighbors,
278     std::vector<std::vector<double>>& distances,
279     const size_t leafSize) :
280     querySet(querySet),
281     range(range),
282     neighbors(neighbors),
283     distances(distances),
284     leafSize(leafSize)
285 {}
286 
287 //! Default Bichromatic range search on the given RSType instance.
288 template<template<typename TreeMetricType,
289                   typename TreeStatType,
290                   typename TreeMatType> class TreeType>
operator ()(RSTypeT<TreeType> * rs) const291 void BiSearchVisitor::operator()(RSTypeT<TreeType>* rs) const
292 {
293   if (rs)
294     return rs->Search(querySet, range, neighbors, distances);
295   throw std::runtime_error("no range search model initialized");
296 }
297 
298 //! Bichromatic range search on the given RSType specialized for KDTrees.
operator ()(RSTypeT<tree::KDTree> * rs) const299 inline void BiSearchVisitor::operator()(RSTypeT<tree::KDTree>* rs) const
300 {
301   if (rs)
302     return SearchLeaf(rs);
303   throw std::runtime_error("no range search model initialized");
304 }
305 
306 //! Bichromatic range search on the given RSType specialized for BallTrees.
operator ()(RSTypeT<tree::BallTree> * rs) const307 inline void BiSearchVisitor::operator()(RSTypeT<tree::BallTree>* rs) const
308 {
309   if (rs)
310     return SearchLeaf(rs);
311   throw std::runtime_error("no range search model initialized");
312 }
313 
314 //! Bichromatic range search specialized for Ocrees.
operator ()(RSTypeT<tree::Octree> * rs) const315 inline void BiSearchVisitor::operator()(RSTypeT<tree::Octree>* rs) const
316 {
317   if (rs)
318     return SearchLeaf(rs);
319   throw std::runtime_error("no range search model initialized");
320 }
321 
322 //! Bichromatic range search on the given RSType considering the leafSize.
323 template<typename RSType>
SearchLeaf(RSType * rs) const324 void BiSearchVisitor::SearchLeaf(RSType* rs) const
325 {
326   if (!rs->Naive() && !rs->SingleMode())
327   {
328     // Build a second tree and search.
329     Timer::Start("tree_building");
330     Log::Info << "Building query tree..." << std::endl;
331     std::vector<size_t> oldFromNewQueries;
332     typename RSType::Tree queryTree(std::move(querySet), oldFromNewQueries,
333         leafSize);
334     Log::Info << "Tree built." << std::endl;
335     Timer::Stop("tree_building");
336 
337     std::vector<std::vector<size_t>> neighborsOut;
338     std::vector<std::vector<double>> distancesOut;
339     rs->Search(&queryTree, range, neighborsOut, distancesOut);
340 
341     // Remap the query points.
342     neighbors.resize(queryTree.Dataset().n_cols);
343     distances.resize(queryTree.Dataset().n_cols);
344     for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
345     {
346       neighbors[oldFromNewQueries[i]] = neighborsOut[i];
347       distances[oldFromNewQueries[i]] = distancesOut[i];
348     }
349   }
350   else
351     rs->Search(querySet, range, neighbors, distances);
352 }
353 
354 //! Save parameters for Train.
TrainVisitor(arma::mat && referenceSet,const size_t leafSize)355 inline TrainVisitor::TrainVisitor(arma::mat&& referenceSet,
356                                   const size_t leafSize) :
357     referenceSet(std::move(referenceSet)),
358     leafSize(leafSize)
359 {}
360 
361 //! Default Train on the given RSType instance.
362 template<template<typename TreeMetricType,
363                   typename TreeStatType,
364                   typename TreeMatType> class TreeType>
operator ()(RSTypeT<TreeType> * rs) const365 void TrainVisitor::operator()(RSTypeT<TreeType>* rs) const
366 {
367   if (rs)
368     return rs->Train(std::move(referenceSet));
369   throw std::runtime_error("no range search model initialized");
370 }
371 
372 //! Train on the given RSType specialized for KDTrees.
operator ()(RSTypeT<tree::KDTree> * rs) const373 inline void TrainVisitor::operator()(RSTypeT<tree::KDTree>* rs) const
374 {
375   if (rs)
376     return TrainLeaf(rs);
377   throw std::runtime_error("no range search model initialized");
378 }
379 
380 //! Train on the given RSType specialized for BallTrees.
operator ()(RSTypeT<tree::BallTree> * rs) const381 inline void TrainVisitor::operator()(RSTypeT<tree::BallTree>* rs) const
382 {
383   if (rs)
384     return TrainLeaf(rs);
385   throw std::runtime_error("no range search model initialized");
386 }
387 
388 //! Train specialized for Octrees.
operator ()(RSTypeT<tree::Octree> * rs) const389 inline void TrainVisitor::operator()(RSTypeT<tree::Octree>* rs) const
390 {
391   if (rs)
392     return TrainLeaf(rs);
393   throw std::runtime_error("no range search model initialized");
394 }
395 
396 //! Train on the given RSType considering the leafSize.
397 template<typename RSType>
TrainLeaf(RSType * rs) const398 void TrainVisitor::TrainLeaf(RSType* rs) const
399 {
400   if (rs->Naive())
401     rs->Train(std::move(referenceSet));
402   else
403   {
404     std::vector<size_t> oldFromNewReferences;
405     typename RSType::Tree* tree =
406         new typename RSType::Tree(std::move(referenceSet), oldFromNewReferences,
407         leafSize);
408     rs->Train(tree);
409 
410     // Give the model ownership of the tree and the mappings.
411     rs->treeOwner = true;
412     rs->oldFromNewReferences = std::move(oldFromNewReferences);
413   }
414 }
415 
416 //! Expose the referenceSet of the given RSType.
417 template<typename RSType>
operator ()(RSType * rs) const418 const arma::mat& ReferenceSetVisitor::operator()(RSType* rs) const
419 {
420   if (rs)
421     return rs->ReferenceSet();
422   throw std::runtime_error("no range search model initialized");
423 }
424 
425 //! For cleaning memory
426 template<typename RSType>
operator ()(RSType * rs) const427 void DeleteVisitor::operator()(RSType* rs) const
428 {
429   if (rs)
430     delete rs;
431 }
432 
433 //! Return whether single mode enabled
434 template<typename RSType>
operator ()(RSType * rs) const435 bool& SingleModeVisitor::operator()(RSType* rs) const
436 {
437   if (rs)
438     return rs->SingleMode();
439   throw std::runtime_error("no range search model initialized");
440 }
441 
442 //! Exposes Naive() function of given RSType
443 template<typename RSType>
operator ()(RSType * rs) const444 bool& NaiveVisitor::operator()(RSType* rs) const
445 {
446   if (rs)
447     return rs->Naive();
448   throw std::runtime_error("no range search model initialized");
449 }
450 
451 // Serialize the model.
452 template<typename Archive>
serialize(Archive & ar,const unsigned int)453 void RSModel::serialize(Archive& ar, const unsigned int /* version */)
454 {
455   ar & BOOST_SERIALIZATION_NVP(treeType);
456   ar & BOOST_SERIALIZATION_NVP(randomBasis);
457   ar & BOOST_SERIALIZATION_NVP(q);
458 
459   // This should never happen, but just in case...
460   if (Archive::is_loading::value)
461     boost::apply_visitor(DeleteVisitor(), rSearch);
462 
463   // We'll only need to serialize one of the model objects, based on the type.
464   ar & BOOST_SERIALIZATION_NVP(rSearch);
465 }
466 
Dataset() const467 inline const arma::mat& RSModel::Dataset() const
468 {
469   return boost::apply_visitor(ReferenceSetVisitor(), rSearch);
470 }
471 
SingleMode() const472 inline bool RSModel::SingleMode() const
473 {
474   return boost::apply_visitor(SingleModeVisitor(), rSearch);
475 }
476 
SingleMode()477 inline bool& RSModel::SingleMode()
478 {
479   return boost::apply_visitor(SingleModeVisitor(), rSearch);
480 }
481 
Naive() const482 inline bool RSModel::Naive() const
483 {
484   return boost::apply_visitor(NaiveVisitor(), rSearch);
485 }
486 
Naive()487 inline bool& RSModel::Naive()
488 {
489   return boost::apply_visitor(NaiveVisitor(), rSearch);
490 }
491 
492 } // namespace range
493 } // namespace mlpack
494 
495 #endif
496