1 /** 2 * @file methods/neighbor_search/ns_model.hpp 3 * @author Ryan Curtin 4 * 5 * This is a model for nearest or furthest neighbor search. It is useful in 6 * that it provides an easy way to serialize a model, abstracts away the 7 * different types of trees, and also reflects the NeighborSearch API and 8 * automatically directs to the right tree type. 9 * 10 * mlpack is free software; you may redistribute it and/or modify it under the 11 * terms of the 3-clause BSD license. You should have received a copy of the 12 * 3-clause BSD license along with mlpack. If not, see 13 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 14 */ 15 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP 16 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP 17 18 #include <mlpack/core/tree/binary_space_tree.hpp> 19 #include <mlpack/core/tree/cover_tree.hpp> 20 #include <mlpack/core/tree/rectangle_tree.hpp> 21 #include <mlpack/core/tree/spill_tree.hpp> 22 #include <mlpack/core/tree/octree.hpp> 23 #include <boost/variant.hpp> 24 #include "neighbor_search.hpp" 25 26 namespace mlpack { 27 namespace neighbor { 28 29 /** 30 * Alias template for euclidean neighbor search. 31 */ 32 template<typename SortPolicy, 33 template<typename TreeMetricType, 34 typename TreeStatType, 35 typename TreeMatType> class TreeType> 36 using NSType = NeighborSearch<SortPolicy, 37 metric::EuclideanDistance, 38 arma::mat, 39 TreeType, 40 TreeType<metric::EuclideanDistance, 41 NeighborSearchStat<SortPolicy>, 42 arma::mat>::template DualTreeTraverser>; 43 44 /** 45 * MonoSearchVisitor executes a monochromatic neighbor search on the given 46 * NSType. We don't make any difference for different instantiations of NSType. 47 */ 48 class MonoSearchVisitor : public boost::static_visitor<void> 49 { 50 private: 51 //! Number of neighbors to search for. 52 const size_t k; 53 //! Result matrix for neighbors. 54 arma::Mat<size_t>& neighbors; 55 //! Result matrix for distances. 56 arma::mat& distances; 57 58 public: 59 //! Perform monochromatic nearest neighbor search. 60 template<typename NSType> 61 void operator()(NSType* ns) const; 62 63 //! Construct the MonoSearchVisitor object with the given parameters. MonoSearchVisitor(const size_t k,arma::Mat<size_t> & neighbors,arma::mat & distances)64 MonoSearchVisitor(const size_t k, 65 arma::Mat<size_t>& neighbors, 66 arma::mat& distances) : 67 k(k), 68 neighbors(neighbors), 69 distances(distances) 70 {}; 71 }; 72 73 /** 74 * BiSearchVisitor executes a bichromatic neighbor search on the given NSType. 75 * We use template specialization to differentiate those tree types that 76 * accept leafSize as a parameter. In these cases, before doing neighbor search, 77 * a query tree with proper leafSize is built from the querySet. 78 */ 79 template<typename SortPolicy> 80 class BiSearchVisitor : public boost::static_visitor<void> 81 { 82 private: 83 //! The query set for the bichromatic search. 84 const arma::mat& querySet; 85 //! The number of neighbors to search for. 86 const size_t k; 87 //! The result matrix for neighbors. 88 arma::Mat<size_t>& neighbors; 89 //! The result matrix for distances. 90 arma::mat& distances; 91 //! The number of points in a leaf (for BinarySpaceTrees). 92 const size_t leafSize; 93 //! Overlapping size (for spill trees). 94 const double tau; 95 //! Balance threshold (for spill trees). 96 const double rho; 97 98 //! Bichromatic neighbor search on the given NSType considering the leafSize. 99 template<typename NSType> 100 void SearchLeaf(NSType* ns) const; 101 102 public: 103 //! Alias template necessary for visual c++ compiler. 104 template<template<typename TreeMetricType, 105 typename TreeStatType, 106 typename TreeMatType> class TreeType> 107 using NSTypeT = NSType<SortPolicy, TreeType>; 108 109 //! Default Bichromatic neighbor search on the given NSType instance. 110 template<template<typename TreeMetricType, 111 typename TreeStatType, 112 typename TreeMatType> class TreeType> 113 void operator()(NSTypeT<TreeType>* ns) const; 114 115 //! Bichromatic neighbor search on the given NSType specialized for KDTrees. 116 void operator()(NSTypeT<tree::KDTree>* ns) const; 117 118 //! Bichromatic neighbor search on the given NSType specialized for BallTrees. 119 void operator()(NSTypeT<tree::BallTree>* ns) const; 120 121 //! Bichromatic neighbor search specialized for SPTrees. 122 void operator()(SpillKNN* ns) const; 123 124 //! Bichromatic neighbor search specialized for octrees. 125 void operator()(NSTypeT<tree::Octree>* ns) const; 126 127 //! Construct the BiSearchVisitor. 128 BiSearchVisitor(const arma::mat& querySet, 129 const size_t k, 130 arma::Mat<size_t>& neighbors, 131 arma::mat& distances, 132 const size_t leafSize, 133 const double tau, 134 const double rho); 135 }; 136 137 /** 138 * TrainVisitor sets the reference set to a new reference set on the given 139 * NSType. We use template specialization to differentiate those tree types that 140 * accept leafSize as a parameter. In these cases, a reference tree with proper 141 * leafSize is built from the referenceSet. 142 */ 143 template<typename SortPolicy> 144 class TrainVisitor : public boost::static_visitor<void> 145 { 146 private: 147 //! The reference set to use for training. 148 arma::mat&& referenceSet; 149 //! The leaf size, used only by BinarySpaceTree. 150 size_t leafSize; 151 //! Overlapping size (for spill trees). 152 const double tau; 153 //! Balance threshold (for spill trees). 154 const double rho; 155 156 //! Train on the given NSType considering the leafSize. 157 template<typename NSType> 158 void TrainLeaf(NSType* ns) const; 159 160 public: 161 //! Alias template necessary for visual c++ compiler. 162 template<template<typename TreeMetricType, 163 typename TreeStatType, 164 typename TreeMatType> class TreeType> 165 using NSTypeT = NSType<SortPolicy, TreeType>; 166 167 //! Default Train on the given NSType instance. 168 template<template<typename TreeMetricType, 169 typename TreeStatType, 170 typename TreeMatType> class TreeType> 171 void operator()(NSTypeT<TreeType>* ns) const; 172 173 //! Train on the given NSType specialized for KDTrees. 174 void operator()(NSTypeT<tree::KDTree>* ns) const; 175 176 //! Train on the given NSType specialized for BallTrees. 177 void operator()(NSTypeT<tree::BallTree>* ns) const; 178 179 //! Train specialized for SPTrees. 180 void operator()(SpillKNN* ns) const; 181 182 //! Train specialized for octrees. 183 void operator()(NSTypeT<tree::Octree>* ns) const; 184 185 //! Construct the TrainVisitor object with the given reference set, leafSize 186 //! for BinarySpaceTrees, and tau and rho for spill trees. 187 TrainVisitor(arma::mat&& referenceSet, 188 const size_t leafSize, 189 const double tau, 190 const double rho); 191 }; 192 193 /** 194 * SearchModeVisitor exposes the SearchMode() method of the given NSType. 195 */ 196 class SearchModeVisitor : public boost::static_visitor<NeighborSearchMode&> 197 { 198 public: 199 //! Return the search mode. 200 template<typename NSType> 201 NeighborSearchMode& operator()(NSType* ns) const; 202 }; 203 204 /** 205 * EpsilonVisitor exposes the Epsilon method of the given NSType. 206 */ 207 class EpsilonVisitor : public boost::static_visitor<double&> 208 { 209 public: 210 //! Return epsilon, the approximation parameter. 211 template<typename NSType> 212 double& operator()(NSType *ns) const; 213 }; 214 215 /** 216 * ReferenceSetVisitor exposes the referenceSet of the given NSType. 217 */ 218 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&> 219 { 220 public: 221 //! Return the reference set. 222 template<typename NSType> 223 const arma::mat& operator()(NSType *ns) const; 224 }; 225 226 /** 227 * DeleteVisitor deletes the given NSType instance. 228 */ 229 class DeleteVisitor : public boost::static_visitor<void> 230 { 231 public: 232 //! Delete the NSType object. 233 template<typename NSType> 234 void operator()(NSType *ns) const; 235 }; 236 237 /** 238 * The NSModel class provides an easy way to serialize a model, abstracts away 239 * the different types of trees, and also reflects the NeighborSearch API. This 240 * class is meant to be used by the command-line mlpack_knn and mlpack_kfn 241 * programs, and thus does not have the same complete functionality and 242 * flexibility as the NeighborSearch class. So if you are using it outside of 243 * mlpack_knn and mlpack_kfn, be aware that it is limited! 244 * 245 * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort. 246 */ 247 template<typename SortPolicy> 248 class NSModel 249 { 250 public: 251 //! Enum type to identify each accepted tree type. 252 enum TreeTypes 253 { 254 KD_TREE, 255 COVER_TREE, 256 R_TREE, 257 R_STAR_TREE, 258 BALL_TREE, 259 X_TREE, 260 HILBERT_R_TREE, 261 R_PLUS_TREE, 262 R_PLUS_PLUS_TREE, 263 VP_TREE, 264 RP_TREE, 265 MAX_RP_TREE, 266 SPILL_TREE, 267 UB_TREE, 268 OCTREE 269 }; 270 271 private: 272 //! Tree type considered for neighbor search. 273 TreeTypes treeType; 274 275 //! For tree types that accept the maxLeafSize parameter. 276 size_t leafSize; 277 278 //! Overlapping size (for spill trees). 279 double tau; 280 //! Balance threshold (for spill trees). 281 double rho; 282 283 //! If true, random projections are used. 284 bool randomBasis; 285 //! This is the random projection matrix; only used if randomBasis is true. 286 arma::mat q; 287 288 /** 289 * nSearch holds an instance of the NeigborSearch class for the current 290 * treeType. It is initialized every time BuildModel is executed. 291 * We access to the contained value through the visitor classes defined above. 292 */ 293 boost::variant<NSType<SortPolicy, tree::KDTree>*, 294 NSType<SortPolicy, tree::StandardCoverTree>*, 295 NSType<SortPolicy, tree::RTree>*, 296 NSType<SortPolicy, tree::RStarTree>*, 297 NSType<SortPolicy, tree::BallTree>*, 298 NSType<SortPolicy, tree::XTree>*, 299 NSType<SortPolicy, tree::HilbertRTree>*, 300 NSType<SortPolicy, tree::RPlusTree>*, 301 NSType<SortPolicy, tree::RPlusPlusTree>*, 302 NSType<SortPolicy, tree::VPTree>*, 303 NSType<SortPolicy, tree::RPTree>*, 304 NSType<SortPolicy, tree::MaxRPTree>*, 305 SpillKNN*, 306 NSType<SortPolicy, tree::UBTree>*, 307 NSType<SortPolicy, tree::Octree>*> nSearch; 308 309 public: 310 /** 311 * Initialize the NSModel with the given type and whether or not a random 312 * basis should be used. 313 * 314 * @param treeType Type of tree to use. 315 * @param randomBasis Whether or not to project the points onto a random basis 316 * before searching. 317 */ 318 NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false); 319 320 /** 321 * Copy the given NSModel. 322 * 323 * @param other Model to copy. 324 */ 325 NSModel(const NSModel& other); 326 327 /** 328 * Take ownership of the given NSModel. 329 * 330 * @param other Model to take ownership of. 331 */ 332 NSModel(NSModel&& other); 333 334 /** 335 * Copy the given NSModel. 336 * 337 * @param other Model to copy. 338 */ 339 NSModel& operator=(const NSModel& other); 340 341 /** 342 * Take ownership of the given NSModel. 343 * 344 * @param other Model to take ownership of. 345 */ 346 NSModel& operator=(NSModel&& other); 347 348 //! Clean memory, if necessary. 349 ~NSModel(); 350 351 //! Serialize the neighbor search model. 352 template<typename Archive> 353 void serialize(Archive& ar, const unsigned int /* version */); 354 355 //! Expose the dataset. 356 const arma::mat& Dataset() const; 357 358 //! Expose SearchMode. 359 NeighborSearchMode SearchMode() const; 360 NeighborSearchMode& SearchMode(); 361 362 //! Expose Epsilon. 363 double Epsilon() const; 364 double& Epsilon(); 365 366 //! Expose leafSize. LeafSize() const367 size_t LeafSize() const { return leafSize; } LeafSize()368 size_t& LeafSize() { return leafSize; } 369 370 //! Expose tau. Tau() const371 double Tau() const { return tau; } Tau()372 double& Tau() { return tau; } 373 374 //! Expose rho. Rho() const375 double Rho() const { return rho; } Rho()376 double& Rho() { return rho; } 377 378 //! Expose treeType. TreeType() const379 TreeTypes TreeType() const { return treeType; } TreeType()380 TreeTypes& TreeType() { return treeType; } 381 382 //! Expose randomBasis. RandomBasis() const383 bool RandomBasis() const { return randomBasis; } RandomBasis()384 bool& RandomBasis() { return randomBasis; } 385 386 //! Build the reference tree. 387 void BuildModel(arma::mat&& referenceSet, 388 const size_t leafSize, 389 const NeighborSearchMode searchMode, 390 const double epsilon = 0); 391 392 //! Perform neighbor search. The query set will be reordered. 393 void Search(arma::mat&& querySet, 394 const size_t k, 395 arma::Mat<size_t>& neighbors, 396 arma::mat& distances); 397 398 //! Perform monochromatic neighbor search. 399 void Search(const size_t k, 400 arma::Mat<size_t>& neighbors, 401 arma::mat& distances); 402 403 //! Return a string representation of the current tree type. 404 std::string TreeName() const; 405 }; 406 407 } // namespace neighbor 408 } // namespace mlpack 409 410 //! Set the serialization version of the NSModel class. 411 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>, 412 mlpack::neighbor::NSModel<SortPolicy>, 1); 413 414 // Include implementation. 415 #include "ns_model_impl.hpp" 416 417 #endif 418