1 /** 2 * @file methods/rann/ra_search_rules.hpp 3 * @author Parikshit Ram 4 * 5 * Defines the pruning rules and base case rules necessary to perform a 6 * tree-based rank-approximate search (with an arbitrary tree) for the RASearch 7 * class. 8 * 9 * mlpack is free software; you may redistribute it and/or modify it under the 10 * terms of the 3-clause BSD license. You should have received a copy of the 11 * 3-clause BSD license along with mlpack. If not, see 12 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 13 */ 14 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP 15 #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP 16 17 #include <mlpack/core/tree/traversal_info.hpp> 18 19 #include <queue> 20 21 namespace mlpack { 22 namespace neighbor { 23 24 /** 25 * The RASearchRules class is a template helper class used by RASearch class 26 * when performing rank-approximate search via random-sampling. 27 * 28 * @tparam SortPolicy The sort policy for distances. 29 * @tparam MetricType The metric to use for computation. 30 * @tparam TreeType The tree type to use; must adhere to the TreeType API. 31 */ 32 template<typename SortPolicy, typename MetricType, typename TreeType> 33 class RASearchRules 34 { 35 public: 36 /** 37 * Construct the RASearchRules object. This is usually done from within 38 * the RASearch class at search time. 39 * 40 * @param referenceSet Set of reference data. 41 * @param querySet Set of query data. 42 * @param k Number of neighbors to search for. 43 * @param metric Instantiated metric. 44 * @param tau The rank-approximation in percentile of the data. 45 * @param alpha The desired success probability. 46 * @param naive If true, the rank-approximate search will be performed by 47 * directly sampling the whole set instead of using the stratified 48 * sampling on the tree. 49 * @param sampleAtLeaves Sample at leaves for faster but less accurate 50 * computation. 51 * @param firstLeafExact Traverse to the first leaf without approximation. 52 * @param singleSampleLimit The limit on the largest node that can be 53 * approximated by sampling. 54 * @param sameSet If true, the query and reference set are taken to be the 55 * same, and a query point will not return itself in the results. 56 */ 57 RASearchRules(const arma::mat& referenceSet, 58 const arma::mat& querySet, 59 const size_t k, 60 MetricType& metric, 61 const double tau = 5, 62 const double alpha = 0.95, 63 const bool naive = false, 64 const bool sampleAtLeaves = false, 65 const bool firstLeafExact = false, 66 const size_t singleSampleLimit = 20, 67 const bool sameSet = false); 68 69 /** 70 * Store the list of candidates for each query point in the given matrices. 71 * 72 * @param neighbors Matrix storing lists of neighbors for each query point. 73 * @param distances Matrix storing distances of neighbors for each query 74 * point. 75 */ 76 void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances); 77 78 /** 79 * Get the distance from the query point to the reference point. 80 * This will update the list of candidates with the new point if appropriate. 81 * 82 * @param queryIndex Index of query point. 83 * @param referenceIndex Index of reference point. 84 */ 85 double BaseCase(const size_t queryIndex, const size_t referenceIndex); 86 87 /** 88 * Get the score for recursion order. A low score indicates priority for 89 * recursion, while DBL_MAX indicates that the node should not be recursed 90 * into at all (it should be pruned). 91 * 92 * For rank-approximation, the scoring function first checks if pruning 93 * by distance is possible. 94 * If yes, then the node is given the score of 95 * 'DBL_MAX' and the expected number of samples from that node are 96 * added to the number of samples made for the query. 97 * 98 * If no, then the function tries to see if the node can be pruned by 99 * approximation. If number of samples required from this node is small 100 * enough, then that number of samples are acquired from this node 101 * and the score is set to be 'DBL_MAX'. 102 * 103 * If the pruning by approximation is not possible either, the algorithm 104 * continues with the usual tree-traversal. 105 * 106 * @param queryIndex Index of query point. 107 * @param referenceNode Candidate node to be recursed into. 108 */ 109 double Score(const size_t queryIndex, TreeType& referenceNode); 110 111 /** 112 * Get the score for recursion order. A low score indicates priority for 113 * recursion, while DBL_MAX indicates that the node should not be recursed 114 * into at all (it should be pruned). 115 * 116 * For rank-approximation, the scoring function first checks if pruning 117 * by distance is possible. 118 * If yes, then the node is given the score of 119 * 'DBL_MAX' and the expected number of samples from that node are 120 * added to the number of samples made for the query. 121 * 122 * If no, then the function tries to see if the node can be pruned by 123 * approximation. If number of samples required from this node is small 124 * enough, then that number of samples are acquired from this node 125 * and the score is set to be 'DBL_MAX'. 126 * 127 * If the pruning by approximation is not possible either, the algorithm 128 * continues with the usual tree-traversal. 129 * 130 * @param queryIndex Index of query point. 131 * @param referenceNode Candidate node to be recursed into. 132 * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode). 133 */ 134 double Score(const size_t queryIndex, 135 TreeType& referenceNode, 136 const double baseCaseResult); 137 138 /** 139 * Re-evaluate the score for recursion order. A low score indicates priority 140 * for recursion, while DBL_MAX indicates that the node should not be 141 * recursed into at all (it should be pruned). This is used when the score 142 * has already been calculated, but another recursion may have modified the 143 * bounds for pruning. So the old score is checked against the new pruning 144 * bound. 145 * 146 * For rank-approximation, it also checks if the number of samples left 147 * for a query to satisfy the rank constraint is small enough at this 148 * point of the algorithm, then this node is approximated by sampling 149 * and given a new score of 'DBL_MAX'. 150 * 151 * @param queryIndex Index of query point. 152 * @param referenceNode Candidate node to be recursed into. 153 * @param oldScore Old score produced by Score() (or Rescore()). 154 */ 155 double Rescore(const size_t queryIndex, 156 TreeType& referenceNode, 157 const double oldScore); 158 159 /** 160 * Get the score for recursion order. A low score indicates priority for 161 * recursionm while DBL_MAX indicates that the node should not be recursed 162 * into at all (it should be pruned). 163 * 164 * For the rank-approximation, we check if the referenceNode can be 165 * approximated by sampling. If it can be, enough samples are made for 166 * every query in the queryNode. No further query-tree traversal is 167 * performed. 168 * 169 * The 'NumSamplesMade' query stat is propagated up the tree. And then 170 * if pruning occurs (by distance or by sampling), the 'NumSamplesMade' 171 * stat is not propagated down the tree. If no pruning occurs, the 172 * stat is propagated down the tree. 173 * 174 * @param queryNode Candidate query node to recurse into. 175 * @param referenceNode Candidate reference node to recurse into. 176 */ 177 double Score(TreeType& queryNode, TreeType& referenceNode); 178 179 /** 180 * Get the score for recursion order, passing the base case result (in the 181 * situation where it may be needed to calculate the recursion order). A low 182 * score indicates priority for recursion, while DBL_MAX indicates that the 183 * node should not be recursed into at all (it should be pruned). 184 * 185 * For the rank-approximation, we check if the referenceNode can be 186 * approximated by sampling. If it can be, enough samples are made for 187 * every query in the queryNode. No further query-tree traversal is 188 * performed. 189 * 190 * The 'NumSamplesMade' query stat is propagated up the tree. And then 191 * if pruning occurs (by distance or by sampling), the 'NumSamplesMade' 192 * stat is not propagated down the tree. If no pruning occurs, the 193 * stat is propagated down the tree. 194 * 195 * @param queryNode Candidate query node to recurse into. 196 * @param referenceNode Candidate reference node to recurse into. 197 * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode). 198 */ 199 double Score(TreeType& queryNode, 200 TreeType& referenceNode, 201 const double baseCaseResult); 202 203 /** 204 * Re-evaluate the score for recursion order. A low score indicates priority 205 * for recursion, while DBL_MAX indicates that the node should not be 206 * recursed into at all (it should be pruned). This is used when the score 207 * has already been calculated, but another recursion may have modified the 208 * bounds for pruning. So the old score is checked against the new pruning 209 * bound. 210 * 211 * For the rank-approximation, we check if the referenceNode can be 212 * approximated by sampling. If it can be, enough samples are made for 213 * every query in the queryNode. No further query-tree traversal is 214 * performed. 215 * 216 * The 'NumSamplesMade' query stat is propagated up the tree. And then 217 * if pruning occurs (by distance or by sampling), the 'NumSamplesMade' 218 * stat is not propagated down the tree. If no pruning occurs, the 219 * stat is propagated down the tree. 220 * 221 * @param queryNode Candidate query node to recurse into. 222 * @param referenceNode Candidate reference node to recurse into. 223 * @param oldScore Old score produced by Socre() (or Rescore()). 224 */ 225 double Rescore(TreeType& queryNode, 226 TreeType& referenceNode, 227 const double oldScore); 228 229 NumDistComputations()230 size_t NumDistComputations() { return numDistComputations; } NumEffectiveSamples()231 size_t NumEffectiveSamples() 232 { 233 if (numSamplesMade.n_elem == 0) 234 return 0; 235 else 236 return arma::sum(numSamplesMade); 237 } 238 239 typedef typename tree::TraversalInfo<TreeType> TraversalInfoType; 240 TraversalInfo() const241 const TraversalInfoType& TraversalInfo() const { return traversalInfo; } TraversalInfo()242 TraversalInfoType& TraversalInfo() { return traversalInfo; } 243 244 //! Get the minimum number of base cases that must be performed for each query 245 //! point for an acceptable result. This is only needed in defeatist search 246 //! mode. MinimumBaseCases() const247 size_t MinimumBaseCases() const { return k; } 248 249 private: 250 //! The reference set. 251 const arma::mat& referenceSet; 252 253 //! The query set. 254 const arma::mat& querySet; 255 256 //! Candidate represents a possible candidate neighbor (distance, index). 257 typedef std::pair<double, size_t> Candidate; 258 259 //! Compare two candidates based on the distance. 260 struct CandidateCmp { operator ()mlpack::neighbor::RASearchRules::CandidateCmp261 bool operator()(const Candidate& c1, const Candidate& c2) 262 { 263 return !SortPolicy::IsBetter(c2.first, c1.first); 264 }; 265 }; 266 267 //! Use a priority queue to represent the list of candidate neighbors. 268 typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp> 269 CandidateList; 270 271 //! Set of candidate neighbors for each point. 272 std::vector<CandidateList> candidates; 273 274 //! Number of neighbors to search for. 275 const size_t k; 276 277 //! The instantiated metric. 278 MetricType& metric; 279 280 //! Whether to sample at leaves or just use all of it. 281 bool sampleAtLeaves; 282 283 //! Whether to do exact computation on the first leaf before any sampling. 284 bool firstLeafExact; 285 286 //! The limit on the largest node that can be approximated by sampling. 287 size_t singleSampleLimit; 288 289 //! The minimum number of samples required per query. 290 size_t numSamplesReqd; 291 292 //! The number of samples made for every query. 293 arma::Col<size_t> numSamplesMade; 294 295 //! The sampling ratio. 296 double samplingRatio; 297 298 //! The number of distance calculations performed during search. 299 size_t numDistComputations; 300 301 //! If the query and reference set are identical, this is true. 302 bool sameSet; 303 304 TraversalInfoType traversalInfo; 305 306 /** 307 * Helper function to insert a point into the list of candidate points. 308 * 309 * @param queryIndex Index of point whose neighbors we are inserting into. 310 * @param neighbor Index of reference point which is being inserted. 311 * @param distance Distance from query point to reference point. 312 */ 313 void InsertNeighbor(const size_t queryIndex, 314 const size_t neighbor, 315 const double distance); 316 317 /** 318 * Perform actual scoring for single-tree case. 319 */ 320 double Score(const size_t queryIndex, 321 TreeType& referenceNode, 322 const double distance, 323 const double bestDistance); 324 325 /** 326 * Perform actual scoring for dual-tree case. 327 */ 328 double Score(TreeType& queryNode, 329 TreeType& referenceNode, 330 const double distance, 331 const double bestDistance); 332 333 static_assert(tree::TreeTraits<TreeType>::UniqueNumDescendants, "TreeType " 334 "must provide a unique number of descendants points."); 335 }; // class RASearchRules 336 337 } // namespace neighbor 338 } // namespace mlpack 339 340 // Include implementation. 341 #include "ra_search_rules_impl.hpp" 342 343 #endif // MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP 344