1 /** 2 * @file methods/range_search/range_search_rules.hpp 3 * @author Ryan Curtin 4 * 5 * Rules for range search, so that it can be done with arbitrary tree types. 6 * 7 * mlpack is free software; you may redistribute it and/or modify it under the 8 * terms of the 3-clause BSD license. You should have received a copy of the 9 * 3-clause BSD license along with mlpack. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 #ifndef MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP 13 #define MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP 14 15 #include <mlpack/core/tree/traversal_info.hpp> 16 17 namespace mlpack { 18 namespace range { 19 20 /** 21 * The RangeSearchRules class is a template helper class used by RangeSearch 22 * class when performing range searches. 23 * 24 * @tparam MetricType The metric to use for computation. 25 * @tparam TreeType The tree type to use; must adhere to the TreeType API. 26 */ 27 template<typename MetricType, typename TreeType> 28 class RangeSearchRules 29 { 30 public: 31 /** 32 * Construct the RangeSearchRules object. This is usually done from within 33 * the RangeSearch class at search time. 34 * 35 * @param referenceSet Set of reference data. 36 * @param querySet Set of query data. 37 * @param range Range to search for. 38 * @param neighbors Vector to store resulting neighbors in. 39 * @param distances Vector to store resulting distances in. 40 * @param metric Instantiated metric. 41 * @param sameSet If true, the query and reference set are taken to be the 42 * same, and a query point will not return itself in the results. 43 */ 44 RangeSearchRules(const arma::mat& referenceSet, 45 const arma::mat& querySet, 46 const math::Range& range, 47 std::vector<std::vector<size_t> >& neighbors, 48 std::vector<std::vector<double> >& distances, 49 MetricType& metric, 50 const bool sameSet = false); 51 52 /** 53 * Compute the base case between the given query point and reference point. 54 * 55 * @param queryIndex Index of query point. 56 * @param referenceIndex Index of reference point. 57 */ 58 double BaseCase(const size_t queryIndex, const size_t referenceIndex); 59 60 /** 61 * Get the score for recursion order. A low score indicates priority for 62 * recursion, while DBL_MAX indicates that the node should not be recursed 63 * into at all (it should be pruned). 64 * 65 * @param queryIndex Index of query point. 66 * @param referenceNode Candidate node to be recursed into. 67 */ 68 double Score(const size_t queryIndex, TreeType& referenceNode); 69 70 /** 71 * Re-evaluate the score for recursion order. A low score indicates priority 72 * for recursion, while DBL_MAX indicates that the node should not be recursed 73 * into at all (it should be pruned). This is used when the score has already 74 * been calculated, but another recursion may have modified the bounds for 75 * pruning. So the old score is checked against the new pruning bound. 76 * 77 * @param queryIndex Index of query point. 78 * @param referenceNode Candidate node to be recursed into. 79 * @param oldScore Old score produced by Score() (or Rescore()). 80 */ 81 double Rescore(const size_t queryIndex, 82 TreeType& referenceNode, 83 const double oldScore) const; 84 85 /** 86 * Get the score for recursion order. A low score indicates priority for 87 * recursion, while DBL_MAX indicates that the node should not be recursed 88 * into at all (it should be pruned). 89 * 90 * @param queryNode Candidate query node to recurse into. 91 * @param referenceNode Candidate reference node to recurse into. 92 */ 93 double Score(TreeType& queryNode, TreeType& referenceNode); 94 95 /** 96 * Re-evaluate the score for recursion order. A low score indicates priority 97 * for recursion, while DBL_MAX indicates that the node should not be recursed 98 * into at all (it should be pruned). This is used when the score has already 99 * been calculated, but another recursion may have modified the bounds for 100 * pruning. So the old score is checked against the new pruning bound. 101 * 102 * @param queryNode Candidate query node to recurse into. 103 * @param referenceNode Candidate reference node to recurse into. 104 * @param oldScore Old score produced by Score() (or Rescore()). 105 */ 106 double Rescore(TreeType& queryNode, 107 TreeType& referenceNode, 108 const double oldScore) const; 109 110 typedef typename tree::TraversalInfo<TreeType> TraversalInfoType; 111 TraversalInfo() const112 const TraversalInfoType& TraversalInfo() const { return traversalInfo; } TraversalInfo()113 TraversalInfoType& TraversalInfo() { return traversalInfo; } 114 115 //! Get the number of base cases. BaseCases() const116 size_t BaseCases() const { return baseCases; } 117 //! Get the number of scores (that is, calls to RangeDistance()). Scores() const118 size_t Scores() const { return scores; } 119 120 //! Get the minimum number of base cases we need to perform to have acceptable 121 //! results. MinimumBaseCases() const122 size_t MinimumBaseCases() const { return 0; } 123 124 private: 125 //! The reference set. 126 const arma::mat& referenceSet; 127 128 //! The query set. 129 const arma::mat& querySet; 130 131 //! The range of distances for which we are searching. 132 const math::Range& range; 133 134 //! The vector the resultant neighbor indices should be stored in. 135 std::vector<std::vector<size_t> >& neighbors; 136 137 //! The vector the resultant neighbor distances should be stored in. 138 std::vector<std::vector<double> >& distances; 139 140 //! The instantiated metric. 141 MetricType& metric; 142 143 //! If true, the query and reference set are taken to be the same. 144 bool sameSet; 145 146 //! The last query index. 147 size_t lastQueryIndex; 148 //! The last reference index. 149 size_t lastReferenceIndex; 150 151 //! Add all the points in the given node to the results for the given query 152 //! point. If the base case has already been calculated, we make sure to not 153 //! add that to the results twice. 154 void AddResult(const size_t queryIndex, 155 TreeType& referenceNode); 156 157 TraversalInfoType traversalInfo; 158 159 //! The number of base cases. 160 size_t baseCases; 161 //! THe number of scores. 162 size_t scores; 163 }; 164 165 } // namespace range 166 } // namespace mlpack 167 168 // Include implementation. 169 #include "range_search_rules_impl.hpp" 170 171 #endif 172