1 /**
2  * @file methods/rann/ra_search_rules.hpp
3  * @author Parikshit Ram
4  *
5  * Defines the pruning rules and base case rules necessary to perform a
6  * tree-based rank-approximate search (with an arbitrary tree) for the RASearch
7  * class.
8  *
9  * mlpack is free software; you may redistribute it and/or modify it under the
10  * terms of the 3-clause BSD license.  You should have received a copy of the
11  * 3-clause BSD license along with mlpack.  If not, see
12  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
13  */
14 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
15 #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
16 
17 #include <mlpack/core/tree/traversal_info.hpp>
18 
19 #include <queue>
20 
21 namespace mlpack {
22 namespace neighbor {
23 
24 /**
25  * The RASearchRules class is a template helper class used by RASearch class
26  * when performing rank-approximate search via random-sampling.
27  *
28  * @tparam SortPolicy The sort policy for distances.
29  * @tparam MetricType The metric to use for computation.
30  * @tparam TreeType The tree type to use; must adhere to the TreeType API.
31  */
32 template<typename SortPolicy, typename MetricType, typename TreeType>
33 class RASearchRules
34 {
35  public:
36   /**
37    * Construct the RASearchRules object.  This is usually done from within
38    * the RASearch class at search time.
39    *
40    * @param referenceSet Set of reference data.
41    * @param querySet Set of query data.
42    * @param k Number of neighbors to search for.
43    * @param metric Instantiated metric.
44    * @param tau The rank-approximation in percentile of the data.
45    * @param alpha The desired success probability.
46    * @param naive If true, the rank-approximate search will be performed by
47    *      directly sampling the whole set instead of using the stratified
48    *      sampling on the tree.
49    * @param sampleAtLeaves Sample at leaves for faster but less accurate
50    *      computation.
51    * @param firstLeafExact Traverse to the first leaf without approximation.
52    * @param singleSampleLimit The limit on the largest node that can be
53    *     approximated by sampling.
54    * @param sameSet If true, the query and reference set are taken to be the
55    *      same, and a query point will not return itself in the results.
56    */
57   RASearchRules(const arma::mat& referenceSet,
58                 const arma::mat& querySet,
59                 const size_t k,
60                 MetricType& metric,
61                 const double tau = 5,
62                 const double alpha = 0.95,
63                 const bool naive = false,
64                 const bool sampleAtLeaves = false,
65                 const bool firstLeafExact = false,
66                 const size_t singleSampleLimit = 20,
67                 const bool sameSet = false);
68 
69   /**
70    * Store the list of candidates for each query point in the given matrices.
71    *
72    * @param neighbors Matrix storing lists of neighbors for each query point.
73    * @param distances Matrix storing distances of neighbors for each query
74    *     point.
75    */
76   void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
77 
78   /**
79    * Get the distance from the query point to the reference point.
80    * This will update the list of candidates with the new point if appropriate.
81    *
82    * @param queryIndex Index of query point.
83    * @param referenceIndex Index of reference point.
84    */
85   double BaseCase(const size_t queryIndex, const size_t referenceIndex);
86 
87   /**
88    * Get the score for recursion order.  A low score indicates priority for
89    * recursion, while DBL_MAX indicates that the node should not be recursed
90    * into at all (it should be pruned).
91    *
92    * For rank-approximation, the scoring function first checks if pruning
93    * by distance is possible.
94    *   If yes, then the node is given the score of
95    *   'DBL_MAX' and the expected number of samples from that node are
96    *   added to the number of samples made for the query.
97    *
98    *   If no, then the function tries to see if the node can be pruned by
99    *   approximation. If number of samples required from this node is small
100    *   enough, then that number of samples are acquired from this node
101    *   and the score is set to be 'DBL_MAX'.
102    *
103    *   If the pruning by approximation is not possible either, the algorithm
104    *   continues with the usual tree-traversal.
105    *
106    * @param queryIndex Index of query point.
107    * @param referenceNode Candidate node to be recursed into.
108    */
109   double Score(const size_t queryIndex, TreeType& referenceNode);
110 
111   /**
112    * Get the score for recursion order.  A low score indicates priority for
113    * recursion, while DBL_MAX indicates that the node should not be recursed
114    * into at all (it should be pruned).
115    *
116    * For rank-approximation, the scoring function first checks if pruning
117    * by distance is possible.
118    *   If yes, then the node is given the score of
119    *   'DBL_MAX' and the expected number of samples from that node are
120    *   added to the number of samples made for the query.
121    *
122    *   If no, then the function tries to see if the node can be pruned by
123    *   approximation. If number of samples required from this node is small
124    *   enough, then that number of samples are acquired from this node
125    *   and the score is set to be 'DBL_MAX'.
126    *
127    *   If the pruning by approximation is not possible either, the algorithm
128    *   continues with the usual tree-traversal.
129    *
130    * @param queryIndex Index of query point.
131    * @param referenceNode Candidate node to be recursed into.
132    * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
133    */
134   double Score(const size_t queryIndex,
135                TreeType& referenceNode,
136                const double baseCaseResult);
137 
138   /**
139    * Re-evaluate the score for recursion order.  A low score indicates priority
140    * for recursion, while DBL_MAX indicates that the node should not be
141    * recursed into at all (it should be pruned).  This is used when the score
142    * has already been calculated, but another recursion may have modified the
143    * bounds for pruning.  So the old score is checked against the new pruning
144    * bound.
145    *
146    * For rank-approximation, it also checks if the number of samples left
147    * for a query to satisfy the rank constraint is small enough at this
148    * point of the algorithm, then this node is approximated by sampling
149    * and given a new score of 'DBL_MAX'.
150    *
151    * @param queryIndex Index of query point.
152    * @param referenceNode Candidate node to be recursed into.
153    * @param oldScore Old score produced by Score() (or Rescore()).
154    */
155   double Rescore(const size_t queryIndex,
156                  TreeType& referenceNode,
157                  const double oldScore);
158 
159   /**
160    * Get the score for recursion order.  A low score indicates priority for
161    * recursionm while DBL_MAX indicates that the node should not be recursed
162    * into at all (it should be pruned).
163    *
164    * For the rank-approximation, we check if the referenceNode can be
165    * approximated by sampling. If it can be, enough samples are made for
166    * every query in the queryNode. No further query-tree traversal is
167    * performed.
168    *
169    * The 'NumSamplesMade' query stat is propagated up the tree. And then
170    * if pruning occurs (by distance or by sampling), the 'NumSamplesMade'
171    * stat is not propagated down the tree. If no pruning occurs, the
172    * stat is propagated down the tree.
173    *
174    * @param queryNode Candidate query node to recurse into.
175    * @param referenceNode Candidate reference node to recurse into.
176    */
177   double Score(TreeType& queryNode, TreeType& referenceNode);
178 
179   /**
180    * Get the score for recursion order, passing the base case result (in the
181    * situation where it may be needed to calculate the recursion order).  A low
182    * score indicates priority for recursion, while DBL_MAX indicates that the
183    * node should not be recursed into at all (it should be pruned).
184    *
185    * For the rank-approximation, we check if the referenceNode can be
186    * approximated by sampling. If it can be, enough samples are made for
187    * every query in the queryNode. No further query-tree traversal is
188    * performed.
189    *
190    * The 'NumSamplesMade' query stat is propagated up the tree. And then
191    * if pruning occurs (by distance or by sampling), the 'NumSamplesMade'
192    * stat is not propagated down the tree. If no pruning occurs, the
193    * stat is propagated down the tree.
194    *
195    * @param queryNode Candidate query node to recurse into.
196    * @param referenceNode Candidate reference node to recurse into.
197    * @param baseCaseResult Result of BaseCase(queryIndex, referenceNode).
198    */
199   double Score(TreeType& queryNode,
200                TreeType& referenceNode,
201                const double baseCaseResult);
202 
203   /**
204    * Re-evaluate the score for recursion order.  A low score indicates priority
205    * for recursion, while DBL_MAX indicates that the node should not be
206    * recursed into at all (it should be pruned).  This is used when the score
207    * has already been calculated, but another recursion may have modified the
208    * bounds for pruning.  So the old score is checked against the new pruning
209    * bound.
210    *
211    * For the rank-approximation, we check if the referenceNode can be
212    * approximated by sampling. If it can be, enough samples are made for
213    * every query in the queryNode. No further query-tree traversal is
214    * performed.
215    *
216    * The 'NumSamplesMade' query stat is propagated up the tree. And then
217    * if pruning occurs (by distance or by sampling), the 'NumSamplesMade'
218    * stat is not propagated down the tree. If no pruning occurs, the
219    * stat is propagated down the tree.
220    *
221    * @param queryNode Candidate query node to recurse into.
222    * @param referenceNode Candidate reference node to recurse into.
223    * @param oldScore Old score produced by Socre() (or Rescore()).
224    */
225   double Rescore(TreeType& queryNode,
226                  TreeType& referenceNode,
227                  const double oldScore);
228 
229 
NumDistComputations()230   size_t NumDistComputations() { return numDistComputations; }
NumEffectiveSamples()231   size_t NumEffectiveSamples()
232   {
233     if (numSamplesMade.n_elem == 0)
234       return 0;
235     else
236       return arma::sum(numSamplesMade);
237   }
238 
239   typedef typename tree::TraversalInfo<TreeType> TraversalInfoType;
240 
TraversalInfo() const241   const TraversalInfoType& TraversalInfo() const { return traversalInfo; }
TraversalInfo()242   TraversalInfoType& TraversalInfo() { return traversalInfo; }
243 
244   //! Get the minimum number of base cases that must be performed for each query
245   //! point for an acceptable result.  This is only needed in defeatist search
246   //! mode.
MinimumBaseCases() const247   size_t MinimumBaseCases() const { return k; }
248 
249  private:
250   //! The reference set.
251   const arma::mat& referenceSet;
252 
253   //! The query set.
254   const arma::mat& querySet;
255 
256   //! Candidate represents a possible candidate neighbor (distance, index).
257   typedef std::pair<double, size_t> Candidate;
258 
259   //! Compare two candidates based on the distance.
260   struct CandidateCmp {
operator ()mlpack::neighbor::RASearchRules::CandidateCmp261     bool operator()(const Candidate& c1, const Candidate& c2)
262     {
263       return !SortPolicy::IsBetter(c2.first, c1.first);
264     };
265   };
266 
267   //! Use a priority queue to represent the list of candidate neighbors.
268   typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
269       CandidateList;
270 
271   //! Set of candidate neighbors for each point.
272   std::vector<CandidateList> candidates;
273 
274   //! Number of neighbors to search for.
275   const size_t k;
276 
277   //! The instantiated metric.
278   MetricType& metric;
279 
280   //! Whether to sample at leaves or just use all of it.
281   bool sampleAtLeaves;
282 
283   //! Whether to do exact computation on the first leaf before any sampling.
284   bool firstLeafExact;
285 
286   //! The limit on the largest node that can be approximated by sampling.
287   size_t singleSampleLimit;
288 
289   //! The minimum number of samples required per query.
290   size_t numSamplesReqd;
291 
292   //! The number of samples made for every query.
293   arma::Col<size_t> numSamplesMade;
294 
295   //! The sampling ratio.
296   double samplingRatio;
297 
298   //! The number of distance calculations performed during search.
299   size_t numDistComputations;
300 
301   //! If the query and reference set are identical, this is true.
302   bool sameSet;
303 
304   TraversalInfoType traversalInfo;
305 
306   /**
307    * Helper function to insert a point into the list of candidate points.
308    *
309    * @param queryIndex Index of point whose neighbors we are inserting into.
310    * @param neighbor Index of reference point which is being inserted.
311    * @param distance Distance from query point to reference point.
312    */
313   void InsertNeighbor(const size_t queryIndex,
314                       const size_t neighbor,
315                       const double distance);
316 
317   /**
318    * Perform actual scoring for single-tree case.
319    */
320   double Score(const size_t queryIndex,
321                TreeType& referenceNode,
322                const double distance,
323                const double bestDistance);
324 
325   /**
326    * Perform actual scoring for dual-tree case.
327    */
328   double Score(TreeType& queryNode,
329                TreeType& referenceNode,
330                const double distance,
331                const double bestDistance);
332 
333   static_assert(tree::TreeTraits<TreeType>::UniqueNumDescendants, "TreeType "
334       "must provide a unique number of descendants points.");
335 }; // class RASearchRules
336 
337 } // namespace neighbor
338 } // namespace mlpack
339 
340 // Include implementation.
341 #include "ra_search_rules_impl.hpp"
342 
343 #endif // MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP
344