1 /**
2 * @file methods/range_search/rs_model_impl.hpp
3 * @author Ryan Curtin
4 *
5 * Implementation of serialize() and inline functions for RSModel.
6 *
7 * mlpack is free software; you may redistribute it and/or modify it under the
8 * terms of the 3-clause BSD license. You should have received a copy of the
9 * 3-clause BSD license along with mlpack. If not, see
10 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11 */
12 #ifndef MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_IMPL_HPP
13 #define MLPACK_METHODS_RANGE_SEARCH_RS_MODEL_IMPL_HPP
14
15 // In case it hasn't been included yet.
16 #include "rs_model.hpp"
17
18 #include <mlpack/core/math/random_basis.hpp>
19 #include <boost/serialization/variant.hpp>
20
21 namespace mlpack {
22 namespace range {
23
24 /**
25 * Initialize the RSModel with the given tree type and whether or not a random
26 * basis should be used.
27 */
RSModel(TreeTypes treeType,bool randomBasis)28 inline RSModel::RSModel(TreeTypes treeType, bool randomBasis) :
29 treeType(treeType),
30 leafSize(0),
31 randomBasis(randomBasis)
32 {
33 // Nothing to do.
34 }
35
36 // Copy constructor.
RSModel(const RSModel & other)37 inline RSModel::RSModel(const RSModel& other) :
38 treeType(other.treeType),
39 leafSize(other.leafSize),
40 randomBasis(other.randomBasis),
41 q(other.q),
42 rSearch(other.rSearch)
43 {
44 // Nothing to do.
45 }
46
47 // Move constructor.
RSModel(RSModel && other)48 inline RSModel::RSModel(RSModel&& other) :
49 treeType(other.treeType),
50 leafSize(other.leafSize),
51 randomBasis(other.randomBasis),
52 q(std::move(other.q)),
53 rSearch(std::move(other.rSearch))
54 {
55 // Reset other model.
56 other.treeType = TreeTypes::KD_TREE;
57 other.leafSize = 0;
58 other.randomBasis = false;
59 other.rSearch = decltype(other.rSearch)();
60 }
61
operator =(RSModel other)62 inline RSModel& RSModel::operator=(RSModel other)
63 {
64 boost::apply_visitor(DeleteVisitor(), rSearch);
65
66 treeType = other.treeType;
67 leafSize = other.leafSize;
68 randomBasis = other.randomBasis;
69 q = std::move(other.q);
70 rSearch = std::move(other.rSearch);
71
72 return *this;
73 }
74
75 // Clean memory, if necessary.
~RSModel()76 inline RSModel::~RSModel()
77 {
78 boost::apply_visitor(DeleteVisitor(), rSearch);
79 }
80
BuildModel(arma::mat && referenceSet,const size_t leafSize,const bool naive,const bool singleMode)81 inline void RSModel::BuildModel(arma::mat&& referenceSet,
82 const size_t leafSize,
83 const bool naive,
84 const bool singleMode)
85 {
86 // Initialize random basis if necessary.
87 if (randomBasis)
88 {
89 Log::Info << "Creating random basis..." << std::endl;
90 math::RandomBasis(q, referenceSet.n_rows);
91 }
92
93 this->leafSize = leafSize;
94
95 // Clean memory, if necessary.
96 boost::apply_visitor(DeleteVisitor(), rSearch);
97
98 // Do we need to modify the reference set?
99 if (randomBasis)
100 referenceSet = q * referenceSet;
101
102 if (!naive)
103 {
104 Timer::Start("tree_building");
105 Log::Info << "Building reference tree..." << std::endl;
106 }
107
108 switch (treeType)
109 {
110 case KD_TREE:
111 rSearch = new RSType<tree::KDTree> (naive, singleMode);
112 break;
113
114 case COVER_TREE:
115 rSearch = new RSType<tree::StandardCoverTree>(naive, singleMode);
116 break;
117
118 case R_TREE:
119 rSearch = new RSType<tree::RTree>(naive, singleMode);
120 break;
121
122 case R_STAR_TREE:
123 rSearch = new RSType<tree::RStarTree>(naive, singleMode);
124 break;
125
126 case BALL_TREE:
127 rSearch = new RSType<tree::BallTree>(naive, singleMode);
128 break;
129
130 case X_TREE:
131 rSearch = new RSType<tree::XTree>(naive, singleMode);
132 break;
133
134 case HILBERT_R_TREE:
135 rSearch = new RSType<tree::HilbertRTree>(naive, singleMode);
136 break;
137
138 case R_PLUS_TREE:
139 rSearch = new RSType<tree::RPlusTree>(naive, singleMode);
140 break;
141
142 case R_PLUS_PLUS_TREE:
143 rSearch = new RSType<tree::RPlusPlusTree>(naive, singleMode);
144 break;
145
146 case VP_TREE:
147 rSearch = new RSType<tree::VPTree>(naive, singleMode);
148 break;
149
150 case RP_TREE:
151 rSearch = new RSType<tree::RPTree>(naive, singleMode);
152 break;
153
154 case MAX_RP_TREE:
155 rSearch = new RSType<tree::MaxRPTree>(naive, singleMode);
156 break;
157
158 case UB_TREE:
159 rSearch = new RSType<tree::UBTree>(naive, singleMode);
160 break;
161
162 case OCTREE:
163 rSearch = new RSType<tree::Octree>(naive, singleMode);
164 break;
165 }
166
167 TrainVisitor tn(std::move(referenceSet), leafSize);
168 boost::apply_visitor(tn, rSearch);
169
170 if (!naive)
171 {
172 Timer::Stop("tree_building");
173 Log::Info << "Tree built." << std::endl;
174 }
175 }
176
177 // Perform range search.
Search(arma::mat && querySet,const math::Range & range,std::vector<std::vector<size_t>> & neighbors,std::vector<std::vector<double>> & distances)178 inline void RSModel::Search(arma::mat&& querySet,
179 const math::Range& range,
180 std::vector<std::vector<size_t>>& neighbors,
181 std::vector<std::vector<double>>& distances)
182 {
183 // We may need to map the query set randomly.
184 if (randomBasis)
185 querySet = q * querySet;
186
187 Log::Info << "Search for points in the range [" << range.Lo() << ", "
188 << range.Hi() << "] with ";
189 if (!Naive() && !SingleMode())
190 Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
191 else if (!Naive())
192 Log::Info << "single-tree " << TreeName() << " search..." << std::endl;
193 else
194 Log::Info << "brute-force (naive) search..." << std::endl;
195
196
197 BiSearchVisitor search(querySet, range, neighbors, distances,
198 leafSize);
199 boost::apply_visitor(search, rSearch);
200 }
201
202 // Perform range search (monochromatic case).
Search(const math::Range & range,std::vector<std::vector<size_t>> & neighbors,std::vector<std::vector<double>> & distances)203 inline void RSModel::Search(const math::Range& range,
204 std::vector<std::vector<size_t>>& neighbors,
205 std::vector<std::vector<double>>& distances)
206 {
207 Log::Info << "Search for points in the range [" << range.Lo() << ", "
208 << range.Hi() << "] with ";
209 if (!Naive() && !SingleMode())
210 Log::Info << "dual-tree " << TreeName() << " search..." << std::endl;
211 else if (!Naive())
212 Log::Info << "single-tree " << TreeName() << " search..." << std::endl;
213 else
214 Log::Info << "brute-force (naive) search..." << std::endl;
215
216 MonoSearchVisitor search(range, neighbors, distances);
217 boost::apply_visitor(search, rSearch);
218 }
219
220 // Get the name of the tree type.
TreeName() const221 inline std::string RSModel::TreeName() const
222 {
223 switch (treeType)
224 {
225 case KD_TREE:
226 return "kd-tree";
227 case COVER_TREE:
228 return "cover tree";
229 case R_TREE:
230 return "R tree";
231 case R_STAR_TREE:
232 return "R* tree";
233 case BALL_TREE:
234 return "ball tree";
235 case X_TREE:
236 return "X tree";
237 case HILBERT_R_TREE:
238 return "Hilbert R tree";
239 case R_PLUS_TREE:
240 return "R+ tree";
241 case R_PLUS_PLUS_TREE:
242 return "R++ tree";
243 case VP_TREE:
244 return "vantage point tree";
245 case RP_TREE:
246 return "random projection tree (mean split)";
247 case MAX_RP_TREE:
248 return "random projection tree (max split)";
249 case UB_TREE:
250 return "UB tree";
251 case OCTREE:
252 return "octree";
253 default:
254 return "unknown tree";
255 }
256 }
257
258 // Clean memory.
CleanMemory()259 inline void RSModel::CleanMemory()
260 {
261 boost::apply_visitor(DeleteVisitor(), rSearch);
262 }
263
264 //! Monochromatic range search on the given RSType instance.
265 template<typename RSType>
operator ()(RSType * rs) const266 void MonoSearchVisitor::operator()(RSType* rs) const
267 {
268 if (rs)
269 return rs->Search(range, neighbors, distances);
270 throw std::runtime_error("no range search model initialized");
271 }
272
273 //! Save parameters for bichromatic range search.
BiSearchVisitor(const arma::mat & querySet,const math::Range & range,std::vector<std::vector<size_t>> & neighbors,std::vector<std::vector<double>> & distances,const size_t leafSize)274 inline BiSearchVisitor::BiSearchVisitor(
275 const arma::mat& querySet,
276 const math::Range& range,
277 std::vector<std::vector<size_t>>& neighbors,
278 std::vector<std::vector<double>>& distances,
279 const size_t leafSize) :
280 querySet(querySet),
281 range(range),
282 neighbors(neighbors),
283 distances(distances),
284 leafSize(leafSize)
285 {}
286
287 //! Default Bichromatic range search on the given RSType instance.
288 template<template<typename TreeMetricType,
289 typename TreeStatType,
290 typename TreeMatType> class TreeType>
operator ()(RSTypeT<TreeType> * rs) const291 void BiSearchVisitor::operator()(RSTypeT<TreeType>* rs) const
292 {
293 if (rs)
294 return rs->Search(querySet, range, neighbors, distances);
295 throw std::runtime_error("no range search model initialized");
296 }
297
298 //! Bichromatic range search on the given RSType specialized for KDTrees.
operator ()(RSTypeT<tree::KDTree> * rs) const299 inline void BiSearchVisitor::operator()(RSTypeT<tree::KDTree>* rs) const
300 {
301 if (rs)
302 return SearchLeaf(rs);
303 throw std::runtime_error("no range search model initialized");
304 }
305
306 //! Bichromatic range search on the given RSType specialized for BallTrees.
operator ()(RSTypeT<tree::BallTree> * rs) const307 inline void BiSearchVisitor::operator()(RSTypeT<tree::BallTree>* rs) const
308 {
309 if (rs)
310 return SearchLeaf(rs);
311 throw std::runtime_error("no range search model initialized");
312 }
313
314 //! Bichromatic range search specialized for Ocrees.
operator ()(RSTypeT<tree::Octree> * rs) const315 inline void BiSearchVisitor::operator()(RSTypeT<tree::Octree>* rs) const
316 {
317 if (rs)
318 return SearchLeaf(rs);
319 throw std::runtime_error("no range search model initialized");
320 }
321
322 //! Bichromatic range search on the given RSType considering the leafSize.
323 template<typename RSType>
SearchLeaf(RSType * rs) const324 void BiSearchVisitor::SearchLeaf(RSType* rs) const
325 {
326 if (!rs->Naive() && !rs->SingleMode())
327 {
328 // Build a second tree and search.
329 Timer::Start("tree_building");
330 Log::Info << "Building query tree..." << std::endl;
331 std::vector<size_t> oldFromNewQueries;
332 typename RSType::Tree queryTree(std::move(querySet), oldFromNewQueries,
333 leafSize);
334 Log::Info << "Tree built." << std::endl;
335 Timer::Stop("tree_building");
336
337 std::vector<std::vector<size_t>> neighborsOut;
338 std::vector<std::vector<double>> distancesOut;
339 rs->Search(&queryTree, range, neighborsOut, distancesOut);
340
341 // Remap the query points.
342 neighbors.resize(queryTree.Dataset().n_cols);
343 distances.resize(queryTree.Dataset().n_cols);
344 for (size_t i = 0; i < queryTree.Dataset().n_cols; ++i)
345 {
346 neighbors[oldFromNewQueries[i]] = neighborsOut[i];
347 distances[oldFromNewQueries[i]] = distancesOut[i];
348 }
349 }
350 else
351 rs->Search(querySet, range, neighbors, distances);
352 }
353
354 //! Save parameters for Train.
TrainVisitor(arma::mat && referenceSet,const size_t leafSize)355 inline TrainVisitor::TrainVisitor(arma::mat&& referenceSet,
356 const size_t leafSize) :
357 referenceSet(std::move(referenceSet)),
358 leafSize(leafSize)
359 {}
360
361 //! Default Train on the given RSType instance.
362 template<template<typename TreeMetricType,
363 typename TreeStatType,
364 typename TreeMatType> class TreeType>
operator ()(RSTypeT<TreeType> * rs) const365 void TrainVisitor::operator()(RSTypeT<TreeType>* rs) const
366 {
367 if (rs)
368 return rs->Train(std::move(referenceSet));
369 throw std::runtime_error("no range search model initialized");
370 }
371
372 //! Train on the given RSType specialized for KDTrees.
operator ()(RSTypeT<tree::KDTree> * rs) const373 inline void TrainVisitor::operator()(RSTypeT<tree::KDTree>* rs) const
374 {
375 if (rs)
376 return TrainLeaf(rs);
377 throw std::runtime_error("no range search model initialized");
378 }
379
380 //! Train on the given RSType specialized for BallTrees.
operator ()(RSTypeT<tree::BallTree> * rs) const381 inline void TrainVisitor::operator()(RSTypeT<tree::BallTree>* rs) const
382 {
383 if (rs)
384 return TrainLeaf(rs);
385 throw std::runtime_error("no range search model initialized");
386 }
387
388 //! Train specialized for Octrees.
operator ()(RSTypeT<tree::Octree> * rs) const389 inline void TrainVisitor::operator()(RSTypeT<tree::Octree>* rs) const
390 {
391 if (rs)
392 return TrainLeaf(rs);
393 throw std::runtime_error("no range search model initialized");
394 }
395
396 //! Train on the given RSType considering the leafSize.
397 template<typename RSType>
TrainLeaf(RSType * rs) const398 void TrainVisitor::TrainLeaf(RSType* rs) const
399 {
400 if (rs->Naive())
401 rs->Train(std::move(referenceSet));
402 else
403 {
404 std::vector<size_t> oldFromNewReferences;
405 typename RSType::Tree* tree =
406 new typename RSType::Tree(std::move(referenceSet), oldFromNewReferences,
407 leafSize);
408 rs->Train(tree);
409
410 // Give the model ownership of the tree and the mappings.
411 rs->treeOwner = true;
412 rs->oldFromNewReferences = std::move(oldFromNewReferences);
413 }
414 }
415
416 //! Expose the referenceSet of the given RSType.
417 template<typename RSType>
operator ()(RSType * rs) const418 const arma::mat& ReferenceSetVisitor::operator()(RSType* rs) const
419 {
420 if (rs)
421 return rs->ReferenceSet();
422 throw std::runtime_error("no range search model initialized");
423 }
424
425 //! For cleaning memory
426 template<typename RSType>
operator ()(RSType * rs) const427 void DeleteVisitor::operator()(RSType* rs) const
428 {
429 if (rs)
430 delete rs;
431 }
432
433 //! Return whether single mode enabled
434 template<typename RSType>
operator ()(RSType * rs) const435 bool& SingleModeVisitor::operator()(RSType* rs) const
436 {
437 if (rs)
438 return rs->SingleMode();
439 throw std::runtime_error("no range search model initialized");
440 }
441
442 //! Exposes Naive() function of given RSType
443 template<typename RSType>
operator ()(RSType * rs) const444 bool& NaiveVisitor::operator()(RSType* rs) const
445 {
446 if (rs)
447 return rs->Naive();
448 throw std::runtime_error("no range search model initialized");
449 }
450
451 // Serialize the model.
452 template<typename Archive>
serialize(Archive & ar,const unsigned int)453 void RSModel::serialize(Archive& ar, const unsigned int /* version */)
454 {
455 ar & BOOST_SERIALIZATION_NVP(treeType);
456 ar & BOOST_SERIALIZATION_NVP(randomBasis);
457 ar & BOOST_SERIALIZATION_NVP(q);
458
459 // This should never happen, but just in case...
460 if (Archive::is_loading::value)
461 boost::apply_visitor(DeleteVisitor(), rSearch);
462
463 // We'll only need to serialize one of the model objects, based on the type.
464 ar & BOOST_SERIALIZATION_NVP(rSearch);
465 }
466
Dataset() const467 inline const arma::mat& RSModel::Dataset() const
468 {
469 return boost::apply_visitor(ReferenceSetVisitor(), rSearch);
470 }
471
SingleMode() const472 inline bool RSModel::SingleMode() const
473 {
474 return boost::apply_visitor(SingleModeVisitor(), rSearch);
475 }
476
SingleMode()477 inline bool& RSModel::SingleMode()
478 {
479 return boost::apply_visitor(SingleModeVisitor(), rSearch);
480 }
481
Naive() const482 inline bool RSModel::Naive() const
483 {
484 return boost::apply_visitor(NaiveVisitor(), rSearch);
485 }
486
Naive()487 inline bool& RSModel::Naive()
488 {
489 return boost::apply_visitor(NaiveVisitor(), rSearch);
490 }
491
492 } // namespace range
493 } // namespace mlpack
494
495 #endif
496