1 /**
2  * @file methods/range_search/range_search_rules.hpp
3  * @author Ryan Curtin
4  *
5  * Rules for range search, so that it can be done with arbitrary tree types.
6  *
7  * mlpack is free software; you may redistribute it and/or modify it under the
8  * terms of the 3-clause BSD license.  You should have received a copy of the
9  * 3-clause BSD license along with mlpack.  If not, see
10  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11  */
12 #ifndef MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
13 #define MLPACK_METHODS_RANGE_SEARCH_RANGE_SEARCH_RULES_HPP
14 
15 #include <mlpack/core/tree/traversal_info.hpp>
16 
17 namespace mlpack {
18 namespace range {
19 
20 /**
21  * The RangeSearchRules class is a template helper class used by RangeSearch
22  * class when performing range searches.
23  *
24  * @tparam MetricType The metric to use for computation.
25  * @tparam TreeType The tree type to use; must adhere to the TreeType API.
26  */
27 template<typename MetricType, typename TreeType>
28 class RangeSearchRules
29 {
30  public:
31   /**
32    * Construct the RangeSearchRules object.  This is usually done from within
33    * the RangeSearch class at search time.
34    *
35    * @param referenceSet Set of reference data.
36    * @param querySet Set of query data.
37    * @param range Range to search for.
38    * @param neighbors Vector to store resulting neighbors in.
39    * @param distances Vector to store resulting distances in.
40    * @param metric Instantiated metric.
41    * @param sameSet If true, the query and reference set are taken to be the
42    *      same, and a query point will not return itself in the results.
43    */
44   RangeSearchRules(const arma::mat& referenceSet,
45                    const arma::mat& querySet,
46                    const math::Range& range,
47                    std::vector<std::vector<size_t> >& neighbors,
48                    std::vector<std::vector<double> >& distances,
49                    MetricType& metric,
50                    const bool sameSet = false);
51 
52   /**
53    * Compute the base case between the given query point and reference point.
54    *
55    * @param queryIndex Index of query point.
56    * @param referenceIndex Index of reference point.
57    */
58   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
59 
60   /**
61    * Get the score for recursion order.  A low score indicates priority for
62    * recursion, while DBL_MAX indicates that the node should not be recursed
63    * into at all (it should be pruned).
64    *
65    * @param queryIndex Index of query point.
66    * @param referenceNode Candidate node to be recursed into.
67    */
68   double Score(const size_t queryIndex, TreeType& referenceNode);
69 
70   /**
71    * Re-evaluate the score for recursion order.  A low score indicates priority
72    * for recursion, while DBL_MAX indicates that the node should not be recursed
73    * into at all (it should be pruned).  This is used when the score has already
74    * been calculated, but another recursion may have modified the bounds for
75    * pruning.  So the old score is checked against the new pruning bound.
76    *
77    * @param queryIndex Index of query point.
78    * @param referenceNode Candidate node to be recursed into.
79    * @param oldScore Old score produced by Score() (or Rescore()).
80    */
81   double Rescore(const size_t queryIndex,
82                  TreeType& referenceNode,
83                  const double oldScore) const;
84 
85   /**
86    * Get the score for recursion order.  A low score indicates priority for
87    * recursion, while DBL_MAX indicates that the node should not be recursed
88    * into at all (it should be pruned).
89    *
90    * @param queryNode Candidate query node to recurse into.
91    * @param referenceNode Candidate reference node to recurse into.
92    */
93   double Score(TreeType& queryNode, TreeType& referenceNode);
94 
95   /**
96    * Re-evaluate the score for recursion order.  A low score indicates priority
97    * for recursion, while DBL_MAX indicates that the node should not be recursed
98    * into at all (it should be pruned).  This is used when the score has already
99    * been calculated, but another recursion may have modified the bounds for
100    * pruning.  So the old score is checked against the new pruning bound.
101    *
102    * @param queryNode Candidate query node to recurse into.
103    * @param referenceNode Candidate reference node to recurse into.
104    * @param oldScore Old score produced by Score() (or Rescore()).
105    */
106   double Rescore(TreeType& queryNode,
107                  TreeType& referenceNode,
108                  const double oldScore) const;
109 
110   typedef typename tree::TraversalInfo<TreeType> TraversalInfoType;
111 
TraversalInfo() const112   const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
TraversalInfo()113   TraversalInfoType& TraversalInfo() { return traversalInfo; }
114 
115   //! Get the number of base cases.
BaseCases() const116   size_t BaseCases() const { return baseCases; }
117   //! Get the number of scores (that is, calls to RangeDistance()).
Scores() const118   size_t Scores() const { return scores; }
119 
120   //! Get the minimum number of base cases we need to perform to have acceptable
121   //! results.
MinimumBaseCases() const122   size_t MinimumBaseCases() const { return 0; }
123 
124  private:
125   //! The reference set.
126   const arma::mat& referenceSet;
127 
128   //! The query set.
129   const arma::mat& querySet;
130 
131   //! The range of distances for which we are searching.
132   const math::Range& range;
133 
134   //! The vector the resultant neighbor indices should be stored in.
135   std::vector<std::vector<size_t> >& neighbors;
136 
137   //! The vector the resultant neighbor distances should be stored in.
138   std::vector<std::vector<double> >& distances;
139 
140   //! The instantiated metric.
141   MetricType& metric;
142 
143   //! If true, the query and reference set are taken to be the same.
144   bool sameSet;
145 
146   //! The last query index.
147   size_t lastQueryIndex;
148   //! The last reference index.
149   size_t lastReferenceIndex;
150 
151   //! Add all the points in the given node to the results for the given query
152   //! point.  If the base case has already been calculated, we make sure to not
153   //! add that to the results twice.
154   void AddResult(const size_t queryIndex,
155                  TreeType& referenceNode);
156 
157   TraversalInfoType traversalInfo;
158 
159   //! The number of base cases.
160   size_t baseCases;
161   //! THe number of scores.
162   size_t scores;
163 };
164 
165 } // namespace range
166 } // namespace mlpack
167 
168 // Include implementation.
169 #include "range_search_rules_impl.hpp"
170 
171 #endif
172