1 /**
2  * @file methods/neighbor_search/ns_model.hpp
3  * @author Ryan Curtin
4  *
5  * This is a model for nearest or furthest neighbor search.  It is useful in
6  * that it provides an easy way to serialize a model, abstracts away the
7  * different types of trees, and also reflects the NeighborSearch API and
8  * automatically directs to the right tree type.
9  *
10  * mlpack is free software; you may redistribute it and/or modify it under the
11  * terms of the 3-clause BSD license.  You should have received a copy of the
12  * 3-clause BSD license along with mlpack.  If not, see
13  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
14  */
15 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
16 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
17 
18 #include <mlpack/core/tree/binary_space_tree.hpp>
19 #include <mlpack/core/tree/cover_tree.hpp>
20 #include <mlpack/core/tree/rectangle_tree.hpp>
21 #include <mlpack/core/tree/spill_tree.hpp>
22 #include <mlpack/core/tree/octree.hpp>
23 #include <boost/variant.hpp>
24 #include "neighbor_search.hpp"
25 
26 namespace mlpack {
27 namespace neighbor {
28 
29 /**
30  * Alias template for euclidean neighbor search.
31  */
32 template<typename SortPolicy,
33          template<typename TreeMetricType,
34                   typename TreeStatType,
35                   typename TreeMatType> class TreeType>
36 using NSType = NeighborSearch<SortPolicy,
37                               metric::EuclideanDistance,
38                               arma::mat,
39                               TreeType,
40                               TreeType<metric::EuclideanDistance,
41                                   NeighborSearchStat<SortPolicy>,
42                                   arma::mat>::template DualTreeTraverser>;
43 
44 /**
45  * MonoSearchVisitor executes a monochromatic neighbor search on the given
46  * NSType. We don't make any difference for different instantiations of NSType.
47  */
48 class MonoSearchVisitor : public boost::static_visitor<void>
49 {
50  private:
51   //! Number of neighbors to search for.
52   const size_t k;
53   //! Result matrix for neighbors.
54   arma::Mat<size_t>& neighbors;
55   //! Result matrix for distances.
56   arma::mat& distances;
57 
58  public:
59   //! Perform monochromatic nearest neighbor search.
60   template<typename NSType>
61   void operator()(NSType* ns) const;
62 
63   //! Construct the MonoSearchVisitor object with the given parameters.
MonoSearchVisitor(const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances)64   MonoSearchVisitor(const size_t k,
65                     arma::Mat<size_t>& neighbors,
66                     arma::mat& distances) :
67       k(k),
68       neighbors(neighbors),
69       distances(distances)
70   {};
71 };
72 
73 /**
74  * BiSearchVisitor executes a bichromatic neighbor search on the given NSType.
75  * We use template specialization to differentiate those tree types that
76  * accept leafSize as a parameter. In these cases, before doing neighbor search,
77  * a query tree with proper leafSize is built from the querySet.
78  */
79 template<typename SortPolicy>
80 class BiSearchVisitor : public boost::static_visitor<void>
81 {
82  private:
83   //! The query set for the bichromatic search.
84   const arma::mat& querySet;
85   //! The number of neighbors to search for.
86   const size_t k;
87   //! The result matrix for neighbors.
88   arma::Mat<size_t>& neighbors;
89   //! The result matrix for distances.
90   arma::mat& distances;
91   //! The number of points in a leaf (for BinarySpaceTrees).
92   const size_t leafSize;
93   //! Overlapping size (for spill trees).
94   const double tau;
95   //! Balance threshold (for spill trees).
96   const double rho;
97 
98   //! Bichromatic neighbor search on the given NSType considering the leafSize.
99   template<typename NSType>
100   void SearchLeaf(NSType* ns) const;
101 
102  public:
103   //! Alias template necessary for visual c++ compiler.
104   template<template<typename TreeMetricType,
105                     typename TreeStatType,
106                     typename TreeMatType> class TreeType>
107   using NSTypeT = NSType<SortPolicy, TreeType>;
108 
109   //! Default Bichromatic neighbor search on the given NSType instance.
110   template<template<typename TreeMetricType,
111                     typename TreeStatType,
112                     typename TreeMatType> class TreeType>
113   void operator()(NSTypeT<TreeType>* ns) const;
114 
115   //! Bichromatic neighbor search on the given NSType specialized for KDTrees.
116   void operator()(NSTypeT<tree::KDTree>* ns) const;
117 
118   //! Bichromatic neighbor search on the given NSType specialized for BallTrees.
119   void operator()(NSTypeT<tree::BallTree>* ns) const;
120 
121   //! Bichromatic neighbor search specialized for SPTrees.
122   void operator()(SpillKNN* ns) const;
123 
124   //! Bichromatic neighbor search specialized for octrees.
125   void operator()(NSTypeT<tree::Octree>* ns) const;
126 
127   //! Construct the BiSearchVisitor.
128   BiSearchVisitor(const arma::mat& querySet,
129                   const size_t k,
130                   arma::Mat<size_t>& neighbors,
131                   arma::mat& distances,
132                   const size_t leafSize,
133                   const double tau,
134                   const double rho);
135 };
136 
137 /**
138  * TrainVisitor sets the reference set to a new reference set on the given
139  * NSType. We use template specialization to differentiate those tree types that
140  * accept leafSize as a parameter. In these cases, a reference tree with proper
141  * leafSize is built from the referenceSet.
142  */
143 template<typename SortPolicy>
144 class TrainVisitor : public boost::static_visitor<void>
145 {
146  private:
147   //! The reference set to use for training.
148   arma::mat&& referenceSet;
149   //! The leaf size, used only by BinarySpaceTree.
150   size_t leafSize;
151   //! Overlapping size (for spill trees).
152   const double tau;
153   //! Balance threshold (for spill trees).
154   const double rho;
155 
156   //! Train on the given NSType considering the leafSize.
157   template<typename NSType>
158   void TrainLeaf(NSType* ns) const;
159 
160  public:
161   //! Alias template necessary for visual c++ compiler.
162   template<template<typename TreeMetricType,
163                     typename TreeStatType,
164                     typename TreeMatType> class TreeType>
165   using NSTypeT = NSType<SortPolicy, TreeType>;
166 
167   //! Default Train on the given NSType instance.
168   template<template<typename TreeMetricType,
169                     typename TreeStatType,
170                     typename TreeMatType> class TreeType>
171   void operator()(NSTypeT<TreeType>* ns) const;
172 
173   //! Train on the given NSType specialized for KDTrees.
174   void operator()(NSTypeT<tree::KDTree>* ns) const;
175 
176   //! Train on the given NSType specialized for BallTrees.
177   void operator()(NSTypeT<tree::BallTree>* ns) const;
178 
179   //! Train specialized for SPTrees.
180   void operator()(SpillKNN* ns) const;
181 
182   //! Train specialized for octrees.
183   void operator()(NSTypeT<tree::Octree>* ns) const;
184 
185   //! Construct the TrainVisitor object with the given reference set, leafSize
186   //! for BinarySpaceTrees, and tau and rho for spill trees.
187   TrainVisitor(arma::mat&& referenceSet,
188                const size_t leafSize,
189                const double tau,
190                const double rho);
191 };
192 
193 /**
194  * SearchModeVisitor exposes the SearchMode() method of the given NSType.
195  */
196 class SearchModeVisitor : public boost::static_visitor<NeighborSearchMode&>
197 {
198  public:
199   //! Return the search mode.
200   template<typename NSType>
201   NeighborSearchMode& operator()(NSType* ns) const;
202 };
203 
204 /**
205  * EpsilonVisitor exposes the Epsilon method of the given NSType.
206  */
207 class EpsilonVisitor : public boost::static_visitor<double&>
208 {
209  public:
210   //! Return epsilon, the approximation parameter.
211   template<typename NSType>
212   double& operator()(NSType *ns) const;
213 };
214 
215 /**
216  * ReferenceSetVisitor exposes the referenceSet of the given NSType.
217  */
218 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
219 {
220  public:
221   //! Return the reference set.
222   template<typename NSType>
223   const arma::mat& operator()(NSType *ns) const;
224 };
225 
226 /**
227  * DeleteVisitor deletes the given NSType instance.
228  */
229 class DeleteVisitor : public boost::static_visitor<void>
230 {
231  public:
232   //! Delete the NSType object.
233   template<typename NSType>
234   void operator()(NSType *ns) const;
235 };
236 
237 /**
238  * The NSModel class provides an easy way to serialize a model, abstracts away
239  * the different types of trees, and also reflects the NeighborSearch API.  This
240  * class is meant to be used by the command-line mlpack_knn and mlpack_kfn
241  * programs, and thus does not have the same complete functionality and
242  * flexibility as the NeighborSearch class.  So if you are using it outside of
243  * mlpack_knn and mlpack_kfn, be aware that it is limited!
244  *
245  * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
246  */
247 template<typename SortPolicy>
248 class NSModel
249 {
250  public:
251   //! Enum type to identify each accepted tree type.
252   enum TreeTypes
253   {
254     KD_TREE,
255     COVER_TREE,
256     R_TREE,
257     R_STAR_TREE,
258     BALL_TREE,
259     X_TREE,
260     HILBERT_R_TREE,
261     R_PLUS_TREE,
262     R_PLUS_PLUS_TREE,
263     VP_TREE,
264     RP_TREE,
265     MAX_RP_TREE,
266     SPILL_TREE,
267     UB_TREE,
268     OCTREE
269   };
270 
271  private:
272   //! Tree type considered for neighbor search.
273   TreeTypes treeType;
274 
275   //! For tree types that accept the maxLeafSize parameter.
276   size_t leafSize;
277 
278   //! Overlapping size (for spill trees).
279   double tau;
280   //! Balance threshold (for spill trees).
281   double rho;
282 
283   //! If true, random projections are used.
284   bool randomBasis;
285   //! This is the random projection matrix; only used if randomBasis is true.
286   arma::mat q;
287 
288   /**
289    * nSearch holds an instance of the NeigborSearch class for the current
290    * treeType. It is initialized every time BuildModel is executed.
291    * We access to the contained value through the visitor classes defined above.
292    */
293   boost::variant<NSType<SortPolicy, tree::KDTree>*,
294                  NSType<SortPolicy, tree::StandardCoverTree>*,
295                  NSType<SortPolicy, tree::RTree>*,
296                  NSType<SortPolicy, tree::RStarTree>*,
297                  NSType<SortPolicy, tree::BallTree>*,
298                  NSType<SortPolicy, tree::XTree>*,
299                  NSType<SortPolicy, tree::HilbertRTree>*,
300                  NSType<SortPolicy, tree::RPlusTree>*,
301                  NSType<SortPolicy, tree::RPlusPlusTree>*,
302                  NSType<SortPolicy, tree::VPTree>*,
303                  NSType<SortPolicy, tree::RPTree>*,
304                  NSType<SortPolicy, tree::MaxRPTree>*,
305                  SpillKNN*,
306                  NSType<SortPolicy, tree::UBTree>*,
307                  NSType<SortPolicy, tree::Octree>*> nSearch;
308 
309  public:
310   /**
311    * Initialize the NSModel with the given type and whether or not a random
312    * basis should be used.
313    *
314    * @param treeType Type of tree to use.
315    * @param randomBasis Whether or not to project the points onto a random basis
316    *      before searching.
317    */
318   NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false);
319 
320   /**
321    * Copy the given NSModel.
322    *
323    * @param other Model to copy.
324    */
325   NSModel(const NSModel& other);
326 
327   /**
328    * Take ownership of the given NSModel.
329    *
330    * @param other Model to take ownership of.
331    */
332   NSModel(NSModel&& other);
333 
334   /**
335    * Copy the given NSModel.
336    *
337    * @param other Model to copy.
338    */
339   NSModel& operator=(const NSModel& other);
340 
341   /**
342    * Take ownership of the given NSModel.
343    *
344    * @param other Model to take ownership of.
345    */
346   NSModel& operator=(NSModel&& other);
347 
348   //! Clean memory, if necessary.
349   ~NSModel();
350 
351   //! Serialize the neighbor search model.
352   template<typename Archive>
353   void serialize(Archive& ar, const unsigned int /* version */);
354 
355   //! Expose the dataset.
356   const arma::mat& Dataset() const;
357 
358   //! Expose SearchMode.
359   NeighborSearchMode SearchMode() const;
360   NeighborSearchMode& SearchMode();
361 
362   //! Expose Epsilon.
363   double Epsilon() const;
364   double& Epsilon();
365 
366   //! Expose leafSize.
LeafSize() const367   size_t LeafSize() const { return leafSize; }
LeafSize()368   size_t& LeafSize() { return leafSize; }
369 
370   //! Expose tau.
Tau() const371   double Tau() const { return tau; }
Tau()372   double& Tau() { return tau; }
373 
374   //! Expose rho.
Rho() const375   double Rho() const { return rho; }
Rho()376   double& Rho() { return rho; }
377 
378   //! Expose treeType.
TreeType() const379   TreeTypes TreeType() const { return treeType; }
TreeType()380   TreeTypes& TreeType() { return treeType; }
381 
382   //! Expose randomBasis.
RandomBasis() const383   bool RandomBasis() const { return randomBasis; }
RandomBasis()384   bool& RandomBasis() { return randomBasis; }
385 
386   //! Build the reference tree.
387   void BuildModel(arma::mat&& referenceSet,
388                   const size_t leafSize,
389                   const NeighborSearchMode searchMode,
390                   const double epsilon = 0);
391 
392   //! Perform neighbor search.  The query set will be reordered.
393   void Search(arma::mat&& querySet,
394               const size_t k,
395               arma::Mat<size_t>& neighbors,
396               arma::mat& distances);
397 
398   //! Perform monochromatic neighbor search.
399   void Search(const size_t k,
400               arma::Mat<size_t>& neighbors,
401               arma::mat& distances);
402 
403   //! Return a string representation of the current tree type.
404   std::string TreeName() const;
405 };
406 
407 } // namespace neighbor
408 } // namespace mlpack
409 
410 //! Set the serialization version of the NSModel class.
411 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
412     mlpack::neighbor::NSModel<SortPolicy>, 1);
413 
414 // Include implementation.
415 #include "ns_model_impl.hpp"
416 
417 #endif
418