1 /**
2  * @file core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
3  * @author Ryan Curtin
4  *
5  * Implementation of the DualTreeTraverser for BinarySpaceTree.  This is a way
6  * to perform a dual-tree traversal of two trees.  The trees must be the same
7  * type.
8  *
9  * mlpack is free software; you may redistribute it and/or modify it under the
10  * terms of the 3-clause BSD license.  You should have received a copy of the
11  * 3-clause BSD license along with mlpack.  If not, see
12  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
13  */
14 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
15 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
16 
17 // In case it hasn't been included yet.
18 #include "dual_tree_traverser.hpp"
19 
20 namespace mlpack {
21 namespace tree {
22 
23 template<typename MetricType,
24          typename StatisticType,
25          typename MatType,
26          template<typename BoundMetricType, typename...> class BoundType,
27          template<typename SplitBoundType, typename SplitMatType>
28              class SplitType>
29 template<typename RuleType>
30 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
DualTreeTraverser(RuleType & rule)31 DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
32     rule(rule),
33     numPrunes(0),
34     numVisited(0),
35     numScores(0),
36     numBaseCases(0)
37 { /* Nothing to do. */ }
38 
39 template<typename MetricType,
40          typename StatisticType,
41          typename MatType,
42          template<typename BoundMetricType, typename...> class BoundType,
43          template<typename SplitBoundType, typename SplitMatType>
44              class SplitType>
45 template<typename RuleType>
46 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
Traverse(BinarySpaceTree<MetricType,StatisticType,MatType,BoundType,SplitType> & queryNode,BinarySpaceTree<MetricType,StatisticType,MatType,BoundType,SplitType> & referenceNode)47 DualTreeTraverser<RuleType>::Traverse(
48     BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
49         queryNode,
50     BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
51         referenceNode)
52 {
53   // Increment the visit counter.
54   ++numVisited;
55 
56   // Store the current traversal info.
57   traversalInfo = rule.TraversalInfo();
58 
59   // If both nodes are root nodes, just score them.
60   if (queryNode.Parent() == NULL && referenceNode.Parent() == NULL)
61   {
62     const double rootScore = rule.Score(queryNode, referenceNode);
63     // If root score is DBL_MAX, don't recurse.
64     if (rootScore == DBL_MAX)
65     {
66       ++numPrunes;
67       return;
68     }
69   }
70 
71   // If both are leaves, we must evaluate the base case.
72   if (queryNode.IsLeaf() && referenceNode.IsLeaf())
73   {
74     // Loop through each of the points in each node.
75     const size_t queryEnd = queryNode.Begin() + queryNode.Count();
76     const size_t refEnd = referenceNode.Begin() + referenceNode.Count();
77     for (size_t query = queryNode.Begin(); query < queryEnd; ++query)
78     {
79       // See if we need to investigate this point (this function should be
80       // implemented for the single-tree recursion too).  Restore the traversal
81       // information first.
82       rule.TraversalInfo() = traversalInfo;
83       const double childScore = rule.Score(query, referenceNode);
84 
85       if (childScore == DBL_MAX)
86         continue; // We can't improve this particular point.
87 
88       for (size_t ref = referenceNode.Begin(); ref < refEnd; ++ref)
89         rule.BaseCase(query, ref);
90 
91       numBaseCases += referenceNode.Count();
92     }
93   }
94   else if (((!queryNode.IsLeaf()) && referenceNode.IsLeaf()) ||
95            (queryNode.NumDescendants() > 3 * referenceNode.NumDescendants() &&
96             !queryNode.IsLeaf() && !referenceNode.IsLeaf()))
97   {
98     // We have to recurse down the query node.  In this case the recursion order
99     // does not matter.
100     const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
101     ++numScores;
102 
103     if (leftScore != DBL_MAX)
104       Traverse(*queryNode.Left(), referenceNode);
105     else
106       ++numPrunes;
107 
108     // Before recursing, we have to set the traversal information correctly.
109     rule.TraversalInfo() = traversalInfo;
110     const double rightScore = rule.Score(*queryNode.Right(), referenceNode);
111     ++numScores;
112 
113     if (rightScore != DBL_MAX)
114       Traverse(*queryNode.Right(), referenceNode);
115     else
116       ++numPrunes;
117   }
118   else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
119   {
120     // We have to recurse down the reference node.  In this case the recursion
121     // order does matter.  Before recursing, though, we have to set the
122     // traversal information correctly.
123     double leftScore = rule.Score(queryNode, *referenceNode.Left());
124     typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
125     rule.TraversalInfo() = traversalInfo;
126     double rightScore = rule.Score(queryNode, *referenceNode.Right());
127     numScores += 2;
128 
129     if (leftScore < rightScore)
130     {
131       // Recurse to the left.  Restore the left traversal info.  Store the right
132       // traversal info.
133       traversalInfo = rule.TraversalInfo();
134       rule.TraversalInfo() = leftInfo;
135       Traverse(queryNode, *referenceNode.Left());
136 
137       // Is it still valid to recurse to the right?
138       rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
139 
140       if (rightScore != DBL_MAX)
141       {
142         // Restore the right traversal info.
143         rule.TraversalInfo() = traversalInfo;
144         Traverse(queryNode, *referenceNode.Right());
145       }
146       else
147         ++numPrunes;
148     }
149     else if (rightScore < leftScore)
150     {
151       // Recurse to the right.
152       Traverse(queryNode, *referenceNode.Right());
153 
154       // Is it still valid to recurse to the left?
155       leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
156 
157       if (leftScore != DBL_MAX)
158       {
159         // Restore the left traversal info.
160         rule.TraversalInfo() = leftInfo;
161         Traverse(queryNode, *referenceNode.Left());
162       }
163       else
164         ++numPrunes;
165     }
166     else // leftScore is equal to rightScore.
167     {
168       if (leftScore == DBL_MAX)
169       {
170         numPrunes += 2;
171       }
172       else
173       {
174         // Choose the left first.  Restore the left traversal info.  Store the
175         // right traversal info.
176         traversalInfo = rule.TraversalInfo();
177         rule.TraversalInfo() = leftInfo;
178         Traverse(queryNode, *referenceNode.Left());
179 
180         rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
181             rightScore);
182 
183         if (rightScore != DBL_MAX)
184         {
185           // Restore the right traversal info.
186           rule.TraversalInfo() = traversalInfo;
187           Traverse(queryNode, *referenceNode.Right());
188         }
189         else
190           ++numPrunes;
191       }
192     }
193   }
194   else
195   {
196     // We have to recurse down both query and reference nodes.  Because the
197     // query descent order does not matter, we will go to the left query child
198     // first.  Before recursing, we have to set the traversal information
199     // correctly.
200     double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
201     typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
202     rule.TraversalInfo() = traversalInfo;
203     double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
204     typename RuleType::TraversalInfoType rightInfo;
205     numScores += 2;
206 
207     if (leftScore < rightScore)
208     {
209       // Recurse to the left.  Restore the left traversal info.  Store the right
210       // traversal info.
211       rightInfo = rule.TraversalInfo();
212       rule.TraversalInfo() = leftInfo;
213       Traverse(*queryNode.Left(), *referenceNode.Left());
214 
215       // Is it still valid to recurse to the right?
216       rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
217           rightScore);
218 
219       if (rightScore != DBL_MAX)
220       {
221         // Restore the right traversal info.
222         rule.TraversalInfo() = rightInfo;
223         Traverse(*queryNode.Left(), *referenceNode.Right());
224       }
225       else
226         ++numPrunes;
227     }
228     else if (rightScore < leftScore)
229     {
230       // Recurse to the right.
231       Traverse(*queryNode.Left(), *referenceNode.Right());
232 
233       // Is it still valid to recurse to the left?
234       leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
235           leftScore);
236 
237       if (leftScore != DBL_MAX)
238       {
239         // Restore the left traversal info.
240         rule.TraversalInfo() = leftInfo;
241         Traverse(*queryNode.Left(), *referenceNode.Left());
242       }
243       else
244         ++numPrunes;
245     }
246     else
247     {
248       if (leftScore == DBL_MAX)
249       {
250         numPrunes += 2;
251       }
252       else
253       {
254         // Choose the left first.  Restore the left traversal info and store the
255         // right traversal info.
256         rightInfo = rule.TraversalInfo();
257         rule.TraversalInfo() = leftInfo;
258         Traverse(*queryNode.Left(), *referenceNode.Left());
259 
260         // Is it still valid to recurse to the right?
261         rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
262             rightScore);
263 
264         if (rightScore != DBL_MAX)
265         {
266           // Restore the right traversal information.
267           rule.TraversalInfo() = rightInfo;
268           Traverse(*queryNode.Left(), *referenceNode.Right());
269         }
270         else
271           ++numPrunes;
272       }
273     }
274 
275     // Restore the main traversal information.
276     rule.TraversalInfo() = traversalInfo;
277 
278     // Now recurse down the right query node.
279     leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
280     leftInfo = rule.TraversalInfo();
281     rule.TraversalInfo() = traversalInfo;
282     rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
283     numScores += 2;
284 
285     if (leftScore < rightScore)
286     {
287       // Recurse to the left.  Restore the left traversal info.  Store the right
288       // traversal info.
289       rightInfo = rule.TraversalInfo();
290       rule.TraversalInfo() = leftInfo;
291       Traverse(*queryNode.Right(), *referenceNode.Left());
292 
293       // Is it still valid to recurse to the right?
294       rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
295           rightScore);
296 
297       if (rightScore != DBL_MAX)
298       {
299         // Restore the right traversal info.
300         rule.TraversalInfo() = rightInfo;
301         Traverse(*queryNode.Right(), *referenceNode.Right());
302       }
303       else
304         ++numPrunes;
305     }
306     else if (rightScore < leftScore)
307     {
308       // Recurse to the right.
309       Traverse(*queryNode.Right(), *referenceNode.Right());
310 
311       // Is it still valid to recurse to the left?
312       leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
313           leftScore);
314 
315       if (leftScore != DBL_MAX)
316       {
317         // Restore the left traversal info.
318         rule.TraversalInfo() = leftInfo;
319         Traverse(*queryNode.Right(), *referenceNode.Left());
320       }
321       else
322         ++numPrunes;
323     }
324     else
325     {
326       if (leftScore == DBL_MAX)
327       {
328         numPrunes += 2;
329       }
330       else
331       {
332         // Choose the left first.  Restore the left traversal info.  Store the
333         // right traversal info.
334         rightInfo = rule.TraversalInfo();
335         rule.TraversalInfo() = leftInfo;
336         Traverse(*queryNode.Right(), *referenceNode.Left());
337 
338         // Is it still valid to recurse to the right?
339         rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
340             rightScore);
341 
342         if (rightScore != DBL_MAX)
343         {
344           // Restore the right traversal info.
345           rule.TraversalInfo() = rightInfo;
346           Traverse(*queryNode.Right(), *referenceNode.Right());
347         }
348         else
349           ++numPrunes;
350       }
351     }
352   }
353 }
354 
355 } // namespace tree
356 } // namespace mlpack
357 
358 #endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
359