1 /** 2 * @file methods/rann/ra_model.hpp 3 * @author Ryan Curtin 4 * 5 * This is a model for rank-approximate nearest neighbor search. It provides an 6 * easy way to serialize a rank-approximate neighbor search model by abstracting 7 * the types of trees and reflecting the RASearch API. 8 * 9 * mlpack is free software; you may redistribute it and/or modify it under the 10 * terms of the 3-clause BSD license. You should have received a copy of the 11 * 3-clause BSD license along with mlpack. If not, see 12 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 13 */ 14 #ifndef MLPACK_METHODS_RANN_RA_MODEL_HPP 15 #define MLPACK_METHODS_RANN_RA_MODEL_HPP 16 17 #include <mlpack/core/tree/binary_space_tree.hpp> 18 #include <mlpack/core/tree/cover_tree.hpp> 19 #include <mlpack/core/tree/rectangle_tree.hpp> 20 #include <mlpack/core/tree/octree.hpp> 21 #include <boost/variant.hpp> 22 #include "ra_search.hpp" 23 24 namespace mlpack { 25 namespace neighbor { 26 27 /** 28 * Alias template for RASearch 29 */ 30 template<typename SortPolicy, 31 template<typename TreeMetricType, 32 typename TreeStatType, 33 typename TreeMatType> class TreeType> 34 using RAType = RASearch<SortPolicy, 35 metric::EuclideanDistance, 36 arma::mat, 37 TreeType>; 38 39 /** 40 * MonoSearchVisitor executes a monochromatic neighbor search on the given 41 * RAType. We don't make any difference for different instantiation of RAType. 42 */ 43 class MonoSearchVisitor : public boost::static_visitor<void> 44 { 45 private: 46 //! Number of neighbors to search for. 47 const size_t k; 48 //! Result matrix for neighbors. 49 arma::Mat<size_t>& neighbors; 50 //! Result matrix for distances. 51 arma::mat& distances; 52 53 public: 54 //! Perform monochromatic nearest neighbor search. 55 template<typename RAType> 56 void operator()(RAType* ra) const; 57 58 //! Construct the MonoSearchVisitor object with the given parameters. MonoSearchVisitor(const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances)59 MonoSearchVisitor(const size_t k, 60 arma::Mat<size_t>& neighbors, 61 arma::mat& distances) : 62 k(k), 63 neighbors(neighbors), 64 distances(distances) 65 {}; 66 }; 67 68 /** 69 * BiSearchVisitor executes a bichromatic neighbor search on the given RAType. 70 * We use template specialization to differentiate those tree types types that 71 * accept leafSize as a parameter. In these cases, before doing neighbor search 72 * a query tree with proper leafSize is built from the querySet. 73 */ 74 template<typename SortPolicy> 75 class BiSearchVisitor : public boost::static_visitor<void> 76 { 77 private: 78 //! The query set for the bichromatic search. 79 const arma::mat& querySet; 80 //! The number of neighbors to search for. 81 const size_t k; 82 //! The results matrix for neighbors. 83 arma::Mat<size_t>& neighbors; 84 //! The result matrix for distances. 85 arma::mat& distances; 86 //! The number of points in a leaf (for BinarySpaceTrees). 87 const size_t leafSize; 88 89 //! Bichromatic neighbor search on the given RAType considering leafSize. 90 template<typename RAType> 91 void SearchLeaf(RAType* ra) const; 92 93 public: 94 //! Alias template necessary for visual c++ compiler. 95 template<template<typename TreeMetricType, 96 typename TreeStatType, 97 typename TreeMatType> class TreeType> 98 using RATypeT = RAType<SortPolicy, TreeType>; 99 100 //! Default Bichromatic neighbor search on the given RAType instance. 101 template<template<typename TreeMetricType, 102 typename TreeStatType, 103 typename TreeMatType> class TreeType> 104 void operator()(RATypeT<TreeType>* ra) const; 105 106 //! Bichromatic search on the given RAType specialized for KDTrees. 107 void operator()(RATypeT<tree::KDTree>* ra) const; 108 109 //! Bichromatic search on the given RAType specialized for octrees. 110 void operator()(RATypeT<tree::Octree>* ra) const; 111 112 //! Construct the BiSearchVisitor. 113 BiSearchVisitor(const arma::mat& querySet, 114 const size_t k, 115 arma::Mat<size_t>& neighbors, 116 arma::mat& distances, 117 const size_t leafSize); 118 }; 119 120 /** 121 * TrainVisitor sets the reference set to a new reference set on the given 122 * RAType. We use template specialization to differentiate those trees that 123 * accept leafSize as a parameter. In these cases, a reference tree with proper 124 * leafSize is built from the referenceSet. 125 */ 126 template<typename SortPolicy> 127 class TrainVisitor : public boost::static_visitor<void> 128 { 129 private: 130 //! The reference set to use for training. 131 arma::mat&& referenceSet; 132 //! The leaf size, used only by BinarySpaceTree. 133 size_t leafSize; 134 135 //! Train on the given RAType considering the leafSize. 136 template<typename RAType> 137 void TrainLeaf(RAType* ra) const; 138 139 public: 140 //! Alias template necessary for visual c++ compiler. 141 template<template<typename TreeMetricType, 142 typename TreeStatType, 143 typename TreeMatType> class TreeType> 144 using RATypeT = RAType<SortPolicy, TreeType>; 145 146 //! Default Train on the given RAType instance. 147 template<template<typename TreeMetricType, 148 typename TreeStatType, 149 typename TreeMatType> class TreeType> 150 void operator()(RATypeT<TreeType>* ra) const; 151 152 //! Train on the given RAType specialized for KDTrees. 153 void operator()(RATypeT<tree::KDTree>* ra) const; 154 155 //! Train on the given RAType specialized for Octrees. 156 void operator()(RATypeT<tree::Octree>* ra) const; 157 158 //! Construct the TrainVisitor object with the given reference set, leafSize 159 //! for BinarySpaceTrees. 160 TrainVisitor(arma::mat&& referenceSet, 161 const size_t leafSize); 162 }; 163 164 /** 165 * Exposes the SingleSampleLimit() method of the given RAType. 166 */ 167 class SingleSampleLimitVisitor : public boost::static_visitor<size_t&> 168 { 169 public: 170 template<typename RAType> 171 size_t& operator()(RAType* ra) const; 172 }; 173 174 /** 175 * Exposes the FirstLeafExact() method of the given RAType. 176 */ 177 class FirstLeafExactVisitor : public boost::static_visitor<bool&> 178 { 179 public: 180 template<typename RAType> 181 bool& operator()(RAType* ra) const; 182 }; 183 184 /** 185 * Exposes the SampleAtLeaves() method of the given RAType. 186 */ 187 class SampleAtLeavesVisitor : public boost::static_visitor<bool&> 188 { 189 public: 190 //! Return SampleAtLeaves (whether or not sampling is done at leaves). 191 template<typename RAType> 192 bool& operator()(RAType *) const; 193 }; 194 195 /** 196 * Exposes the Alpha() method of the given RAType. 197 */ 198 class AlphaVisitor : public boost::static_visitor<double&> 199 { 200 public: 201 //! Return Alpha parameter. 202 template<typename RAType> 203 double& operator()(RAType* ra) const; 204 }; 205 206 /** 207 * Exposes the Tau() method of the given RAType. 208 */ 209 class TauVisitor : public boost::static_visitor<double&> 210 { 211 public: 212 //! Get a reference to the Tau parameter. 213 template<typename RAType> 214 double& operator()(RAType* ra) const; 215 }; 216 217 /** 218 * Exposes the SingleMode() method of the given RAType. 219 */ 220 class SingleModeVisitor : public boost::static_visitor<bool&> 221 { 222 public: 223 //! Get a reference to the SingleMode parameter of the given RASearch object. 224 template<typename RAType> 225 bool& operator()(RAType* ra) const; 226 }; 227 228 /** 229 * Exposes the referenceSet of the given RAType. 230 */ 231 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&> 232 { 233 public: 234 //! Return the reference set. 235 template<typename RAType> 236 const arma::mat& operator()(RAType* ra) const; 237 }; 238 239 /** 240 * DeleteVisitor deletes the give RAType Instance. 241 */ 242 class DeleteVisitor : public boost::static_visitor<void> 243 { 244 public: 245 //! Delete the RAType Object. 246 template<typename RAType> void operator()(RAType* ra) const; 247 }; 248 249 /** 250 * NaiveVisitor exposes the Naive() method of the given RAType. 251 */ 252 class NaiveVisitor : public boost::static_visitor<bool&> 253 { 254 public: 255 /** 256 * Get a reference to the naive parameter of the given RASearch object. 257 */ 258 template<typename RAType> 259 bool& operator()(RAType* ra) const; 260 }; 261 262 /** 263 * The RAModel class provides an abstraction for the RASearch class, abstracting 264 * away the TreeType parameter and allowing it to be specified at runtime in 265 * this class. This class is written for the sake of the 'allkrann' program, 266 * but is not necessarily restricted to that use. 267 * 268 * @param SortPolicy Sorting policy for neighbor searching (see RASearch). 269 */ 270 template<typename SortPolicy> 271 class RAModel 272 { 273 public: 274 /** 275 * The list of tree types we can use with RASearch. Does not include ball 276 * trees; see #338. 277 */ 278 enum TreeTypes 279 { 280 KD_TREE, 281 COVER_TREE, 282 R_TREE, 283 R_STAR_TREE, 284 X_TREE, 285 HILBERT_R_TREE, 286 R_PLUS_TREE, 287 R_PLUS_PLUS_TREE, 288 UB_TREE, 289 OCTREE 290 }; 291 292 private: 293 //! The type of tree being used. 294 TreeTypes treeType; 295 //! The leaf size of the tree being used (useful only for the kd-tree). 296 size_t leafSize; 297 298 //! If true, randomly project into a new basis. 299 bool randomBasis; 300 //! The basis to project into. 301 arma::mat q; 302 303 //! The rank-approximate model. 304 boost::variant<RAType<SortPolicy, tree::KDTree>*, 305 RAType<SortPolicy, tree::StandardCoverTree>*, 306 RAType<SortPolicy, tree::RTree>*, 307 RAType<SortPolicy, tree::RStarTree>*, 308 RAType<SortPolicy, tree::XTree>*, 309 RAType<SortPolicy, tree::HilbertRTree>*, 310 RAType<SortPolicy, tree::RPlusTree>*, 311 RAType<SortPolicy, tree::RPlusPlusTree>*, 312 RAType<SortPolicy, tree::UBTree>*, 313 RAType<SortPolicy, tree::Octree>*> raSearch; 314 315 public: 316 /** 317 * Initialize the RAModel with the given type and whether or not a random 318 * basis should be used. 319 */ 320 RAModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false); 321 322 /** 323 * Copy the given RAModel. 324 * 325 * @param other RAModel to copy. 326 */ 327 RAModel(const RAModel& other); 328 329 /** 330 * Take ownership of the given RAModel. 331 * 332 * @param other RAModel to take ownership of. 333 */ 334 RAModel(RAModel&& other); 335 336 /** 337 * Copy the given RAModel. 338 * 339 * @param other RAModel to copy. 340 */ 341 RAModel& operator=(const RAModel& other); 342 343 /** 344 * Take ownership of the given RAModel. 345 * 346 * @param other RAModel to take ownership of. 347 */ 348 RAModel& operator=(RAModel&& other); 349 350 //! Clean memory, if necessary. 351 ~RAModel(); 352 353 //! Serialize the model. 354 template<typename Archive> 355 void serialize(Archive& ar, const unsigned int /* version */); 356 357 //! Expose the dataset. 358 const arma::mat& Dataset() const; 359 360 //! Get whether or not single-tree search is being used. 361 bool SingleMode() const; 362 //! Modify whether or not single-tree search is being used. 363 bool& SingleMode(); 364 365 //! Get whether or not naive search is being used. 366 bool Naive() const; 367 //! Modify whether or not naive search is being used. 368 bool& Naive(); 369 370 //! Get the rank-approximation in percentile of the data. 371 double Tau() const; 372 //! Modify the rank-approximation in percentile of the data. 373 double& Tau(); 374 375 //! Get the desired success probability. 376 double Alpha() const; 377 //! Modify the desired success probability. 378 double& Alpha(); 379 380 //! Get whether or not sampling is done at the leaves. 381 bool SampleAtLeaves() const; 382 //! Modify whether or not sampling is done at the leaves. 383 bool& SampleAtLeaves(); 384 385 //! Get whether or not we traverse to the first leaf without approximation. 386 bool FirstLeafExact() const; 387 //! Modify whether or not we traverse to the first leaf without approximation. 388 bool& FirstLeafExact(); 389 390 //! Get the limit on the size of a node that can be approximated. 391 size_t SingleSampleLimit() const; 392 //! Modify the limit on the size of a node that can be approximation. 393 size_t& SingleSampleLimit(); 394 395 //! Get the leaf size (only relevant when the kd-tree is used). 396 size_t LeafSize() const; 397 //! Modify the leaf size (only relevant when the kd-tree is used). 398 size_t& LeafSize(); 399 400 //! Get the type of tree being used. 401 TreeTypes TreeType() const; 402 //! Modify the type of tree being used. 403 TreeTypes& TreeType(); 404 405 //! Get whether or not a random basis is being used. 406 bool RandomBasis() const; 407 //! Modify whether or not a random basis is being used. Be sure to rebuild 408 //! the model using BuildModel(). 409 bool& RandomBasis(); 410 411 //! Build the reference tree. 412 void BuildModel(arma::mat&& referenceSet, 413 const size_t leafSize, 414 const bool naive, 415 const bool singleMode); 416 417 //! Perform rank-approximate neighbor search, taking ownership of the query 418 //! set. 419 void Search(arma::mat&& querySet, 420 const size_t k, 421 arma::Mat<size_t>& neighbors, 422 arma::mat& distances); 423 424 /** 425 * Perform rank-approximate neighbor search, using the reference set as the 426 * query set. 427 */ 428 void Search(const size_t k, 429 arma::Mat<size_t>& neighbors, 430 arma::mat& distances); 431 432 //! Get the name of the tree type. 433 std::string TreeName() const; 434 }; 435 436 } // namespace neighbor 437 } // namespace mlpack 438 439 #include "ra_model_impl.hpp" 440 441 #endif 442