1 /**
2  * @file methods/rann/ra_model.hpp
3  * @author Ryan Curtin
4  *
5  * This is a model for rank-approximate nearest neighbor search.  It provides an
6  * easy way to serialize a rank-approximate neighbor search model by abstracting
7  * the types of trees and reflecting the RASearch API.
8  *
9  * mlpack is free software; you may redistribute it and/or modify it under the
10  * terms of the 3-clause BSD license.  You should have received a copy of the
11  * 3-clause BSD license along with mlpack.  If not, see
12  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
13  */
14 #ifndef MLPACK_METHODS_RANN_RA_MODEL_HPP
15 #define MLPACK_METHODS_RANN_RA_MODEL_HPP
16 
17 #include <mlpack/core/tree/binary_space_tree.hpp>
18 #include <mlpack/core/tree/cover_tree.hpp>
19 #include <mlpack/core/tree/rectangle_tree.hpp>
20 #include <mlpack/core/tree/octree.hpp>
21 #include <boost/variant.hpp>
22 #include "ra_search.hpp"
23 
24 namespace mlpack {
25 namespace neighbor {
26 
27 /**
28  * Alias template for RASearch
29  */
30 template<typename SortPolicy,
31          template<typename TreeMetricType,
32                   typename TreeStatType,
33                   typename TreeMatType> class TreeType>
34 using RAType = RASearch<SortPolicy,
35                         metric::EuclideanDistance,
36                         arma::mat,
37                         TreeType>;
38 
39 /**
40  * MonoSearchVisitor executes a monochromatic neighbor search on the given
41  * RAType. We don't make any difference for different instantiation of RAType.
42  */
43 class MonoSearchVisitor : public boost::static_visitor<void>
44 {
45  private:
46   //! Number of neighbors to search for.
47   const size_t k;
48   //! Result matrix for neighbors.
49   arma::Mat<size_t>& neighbors;
50   //! Result matrix for distances.
51   arma::mat& distances;
52 
53  public:
54   //! Perform monochromatic nearest neighbor search.
55   template<typename RAType>
56   void operator()(RAType* ra) const;
57 
58   //! Construct the MonoSearchVisitor object with the given parameters.
MonoSearchVisitor(const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances)59   MonoSearchVisitor(const size_t k,
60                     arma::Mat<size_t>& neighbors,
61                     arma::mat& distances) :
62       k(k),
63       neighbors(neighbors),
64       distances(distances)
65   {};
66 };
67 
68 /**
69  * BiSearchVisitor executes a bichromatic neighbor search on the given RAType.
70  * We use template specialization to differentiate those tree types types that
71  * accept leafSize as a parameter. In these cases, before doing neighbor search
72  * a query tree with proper leafSize is built from the querySet.
73  */
74 template<typename SortPolicy>
75 class BiSearchVisitor : public boost::static_visitor<void>
76 {
77  private:
78   //! The query set for the bichromatic search.
79   const arma::mat& querySet;
80   //! The number of neighbors to search for.
81   const size_t k;
82   //! The results matrix for neighbors.
83   arma::Mat<size_t>& neighbors;
84   //! The result matrix for distances.
85   arma::mat& distances;
86   //! The number of points in a leaf (for BinarySpaceTrees).
87   const size_t leafSize;
88 
89   //! Bichromatic neighbor search on the given RAType considering leafSize.
90   template<typename RAType>
91   void SearchLeaf(RAType* ra) const;
92 
93  public:
94   //! Alias template necessary for visual c++ compiler.
95   template<template<typename TreeMetricType,
96                     typename TreeStatType,
97                     typename TreeMatType> class TreeType>
98   using RATypeT = RAType<SortPolicy, TreeType>;
99 
100   //! Default Bichromatic neighbor search on the given RAType instance.
101   template<template<typename TreeMetricType,
102                     typename TreeStatType,
103                     typename TreeMatType> class TreeType>
104   void operator()(RATypeT<TreeType>* ra) const;
105 
106   //! Bichromatic search on the given RAType specialized for KDTrees.
107   void operator()(RATypeT<tree::KDTree>* ra) const;
108 
109   //! Bichromatic search on the given RAType specialized for octrees.
110   void operator()(RATypeT<tree::Octree>* ra) const;
111 
112   //! Construct the BiSearchVisitor.
113   BiSearchVisitor(const arma::mat& querySet,
114                   const size_t k,
115                   arma::Mat<size_t>& neighbors,
116                   arma::mat& distances,
117                   const size_t leafSize);
118 };
119 
120 /**
121  * TrainVisitor sets the reference set to a new reference set on the given
122  * RAType. We use template specialization to differentiate those trees that
123  * accept leafSize as a parameter. In these cases, a reference tree with proper
124  * leafSize is built from the referenceSet.
125  */
126 template<typename SortPolicy>
127 class TrainVisitor : public boost::static_visitor<void>
128 {
129  private:
130   //! The reference set to use for training.
131   arma::mat&& referenceSet;
132   //! The leaf size, used only by BinarySpaceTree.
133   size_t leafSize;
134 
135   //! Train on the given RAType considering the leafSize.
136   template<typename RAType>
137   void TrainLeaf(RAType* ra) const;
138 
139  public:
140   //! Alias template necessary for visual c++ compiler.
141   template<template<typename TreeMetricType,
142                     typename TreeStatType,
143                     typename TreeMatType> class TreeType>
144   using RATypeT = RAType<SortPolicy, TreeType>;
145 
146   //! Default Train on the given RAType instance.
147   template<template<typename TreeMetricType,
148                     typename TreeStatType,
149                     typename TreeMatType> class TreeType>
150   void operator()(RATypeT<TreeType>* ra) const;
151 
152   //! Train on the given RAType specialized for KDTrees.
153   void operator()(RATypeT<tree::KDTree>* ra) const;
154 
155   //! Train on the given RAType specialized for Octrees.
156   void operator()(RATypeT<tree::Octree>* ra) const;
157 
158   //! Construct the TrainVisitor object with the given reference set, leafSize
159   //! for BinarySpaceTrees.
160   TrainVisitor(arma::mat&& referenceSet,
161                const size_t leafSize);
162 };
163 
164 /**
165  * Exposes the SingleSampleLimit() method of the given RAType.
166  */
167 class SingleSampleLimitVisitor : public boost::static_visitor<size_t&>
168 {
169  public:
170   template<typename RAType>
171   size_t& operator()(RAType* ra) const;
172 };
173 
174 /**
175  * Exposes the FirstLeafExact() method of the given RAType.
176  */
177 class FirstLeafExactVisitor : public boost::static_visitor<bool&>
178 {
179  public:
180   template<typename RAType>
181   bool& operator()(RAType* ra) const;
182 };
183 
184 /**
185  * Exposes the SampleAtLeaves() method of the given RAType.
186  */
187 class SampleAtLeavesVisitor : public boost::static_visitor<bool&>
188 {
189  public:
190   //! Return SampleAtLeaves (whether or not sampling is done at leaves).
191   template<typename RAType>
192   bool& operator()(RAType *) const;
193 };
194 
195 /**
196  * Exposes the Alpha() method of the given RAType.
197  */
198 class AlphaVisitor : public boost::static_visitor<double&>
199 {
200  public:
201   //! Return Alpha parameter.
202   template<typename RAType>
203   double& operator()(RAType* ra) const;
204 };
205 
206 /**
207  * Exposes the Tau() method of the given RAType.
208  */
209 class TauVisitor : public boost::static_visitor<double&>
210 {
211  public:
212   //! Get a reference to the Tau parameter.
213   template<typename RAType>
214   double& operator()(RAType* ra) const;
215 };
216 
217 /**
218  * Exposes the SingleMode() method of the given RAType.
219  */
220 class SingleModeVisitor : public boost::static_visitor<bool&>
221 {
222  public:
223   //! Get a reference to the SingleMode parameter of the given RASearch object.
224   template<typename RAType>
225   bool& operator()(RAType* ra) const;
226 };
227 
228 /**
229  * Exposes the referenceSet of the given RAType.
230  */
231 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
232 {
233  public:
234   //! Return the reference set.
235   template<typename RAType>
236   const arma::mat& operator()(RAType* ra) const;
237 };
238 
239 /**
240  * DeleteVisitor deletes the give RAType Instance.
241  */
242 class DeleteVisitor : public boost::static_visitor<void>
243 {
244  public:
245   //! Delete the RAType Object.
246   template<typename RAType> void operator()(RAType* ra) const;
247 };
248 
249 /**
250  * NaiveVisitor exposes the Naive() method of the given RAType.
251  */
252 class NaiveVisitor : public boost::static_visitor<bool&>
253 {
254  public:
255   /**
256    * Get a reference to the naive parameter of the given RASearch object.
257    */
258   template<typename RAType>
259   bool& operator()(RAType* ra) const;
260 };
261 
262 /**
263  * The RAModel class provides an abstraction for the RASearch class, abstracting
264  * away the TreeType parameter and allowing it to be specified at runtime in
265  * this class.  This class is written for the sake of the 'allkrann' program,
266  * but is not necessarily restricted to that use.
267  *
268  * @param SortPolicy Sorting policy for neighbor searching (see RASearch).
269  */
270 template<typename SortPolicy>
271 class RAModel
272 {
273  public:
274   /**
275    * The list of tree types we can use with RASearch.  Does not include ball
276    * trees; see #338.
277    */
278   enum TreeTypes
279   {
280     KD_TREE,
281     COVER_TREE,
282     R_TREE,
283     R_STAR_TREE,
284     X_TREE,
285     HILBERT_R_TREE,
286     R_PLUS_TREE,
287     R_PLUS_PLUS_TREE,
288     UB_TREE,
289     OCTREE
290   };
291 
292  private:
293   //! The type of tree being used.
294   TreeTypes treeType;
295   //! The leaf size of the tree being used (useful only for the kd-tree).
296   size_t leafSize;
297 
298   //! If true, randomly project into a new basis.
299   bool randomBasis;
300   //! The basis to project into.
301   arma::mat q;
302 
303   //! The rank-approximate model.
304   boost::variant<RAType<SortPolicy, tree::KDTree>*,
305                  RAType<SortPolicy, tree::StandardCoverTree>*,
306                  RAType<SortPolicy, tree::RTree>*,
307                  RAType<SortPolicy, tree::RStarTree>*,
308                  RAType<SortPolicy, tree::XTree>*,
309                  RAType<SortPolicy, tree::HilbertRTree>*,
310                  RAType<SortPolicy, tree::RPlusTree>*,
311                  RAType<SortPolicy, tree::RPlusPlusTree>*,
312                  RAType<SortPolicy, tree::UBTree>*,
313                  RAType<SortPolicy, tree::Octree>*> raSearch;
314 
315  public:
316   /**
317    * Initialize the RAModel with the given type and whether or not a random
318    * basis should be used.
319    */
320   RAModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false);
321 
322   /**
323    * Copy the given RAModel.
324    *
325    * @param other RAModel to copy.
326    */
327   RAModel(const RAModel& other);
328 
329   /**
330    * Take ownership of the given RAModel.
331    *
332    * @param other RAModel to take ownership of.
333    */
334   RAModel(RAModel&& other);
335 
336   /**
337    * Copy the given RAModel.
338    *
339    * @param other RAModel to copy.
340    */
341   RAModel& operator=(const RAModel& other);
342 
343   /**
344    * Take ownership of the given RAModel.
345    *
346    * @param other RAModel to take ownership of.
347    */
348   RAModel& operator=(RAModel&& other);
349 
350   //! Clean memory, if necessary.
351   ~RAModel();
352 
353   //! Serialize the model.
354   template<typename Archive>
355   void serialize(Archive& ar, const unsigned int /* version */);
356 
357   //! Expose the dataset.
358   const arma::mat& Dataset() const;
359 
360   //! Get whether or not single-tree search is being used.
361   bool SingleMode() const;
362   //! Modify whether or not single-tree search is being used.
363   bool& SingleMode();
364 
365   //! Get whether or not naive search is being used.
366   bool Naive() const;
367   //! Modify whether or not naive search is being used.
368   bool& Naive();
369 
370   //! Get the rank-approximation in percentile of the data.
371   double Tau() const;
372   //! Modify the rank-approximation in percentile of the data.
373   double& Tau();
374 
375   //! Get the desired success probability.
376   double Alpha() const;
377   //! Modify the desired success probability.
378   double& Alpha();
379 
380   //! Get whether or not sampling is done at the leaves.
381   bool SampleAtLeaves() const;
382   //! Modify whether or not sampling is done at the leaves.
383   bool& SampleAtLeaves();
384 
385   //! Get whether or not we traverse to the first leaf without approximation.
386   bool FirstLeafExact() const;
387   //! Modify whether or not we traverse to the first leaf without approximation.
388   bool& FirstLeafExact();
389 
390   //! Get the limit on the size of a node that can be approximated.
391   size_t SingleSampleLimit() const;
392   //! Modify the limit on the size of a node that can be approximation.
393   size_t& SingleSampleLimit();
394 
395   //! Get the leaf size (only relevant when the kd-tree is used).
396   size_t LeafSize() const;
397   //! Modify the leaf size (only relevant when the kd-tree is used).
398   size_t& LeafSize();
399 
400   //! Get the type of tree being used.
401   TreeTypes TreeType() const;
402   //! Modify the type of tree being used.
403   TreeTypes& TreeType();
404 
405   //! Get whether or not a random basis is being used.
406   bool RandomBasis() const;
407   //! Modify whether or not a random basis is being used.  Be sure to rebuild
408   //! the model using BuildModel().
409   bool& RandomBasis();
410 
411   //! Build the reference tree.
412   void BuildModel(arma::mat&& referenceSet,
413                   const size_t leafSize,
414                   const bool naive,
415                   const bool singleMode);
416 
417   //! Perform rank-approximate neighbor search, taking ownership of the query
418   //! set.
419   void Search(arma::mat&& querySet,
420               const size_t k,
421               arma::Mat<size_t>& neighbors,
422               arma::mat& distances);
423 
424   /**
425    * Perform rank-approximate neighbor search, using the reference set as the
426    * query set.
427    */
428   void Search(const size_t k,
429               arma::Mat<size_t>& neighbors,
430               arma::mat& distances);
431 
432   //! Get the name of the tree type.
433   std::string TreeName() const;
434 };
435 
436 } // namespace neighbor
437 } // namespace mlpack
438 
439 #include "ra_model_impl.hpp"
440 
441 #endif
442