1 /**
2 * @file core/tree/binary_space_tree/dual_tree_traverser_impl.hpp
3 * @author Ryan Curtin
4 *
5 * Implementation of the DualTreeTraverser for BinarySpaceTree. This is a way
6 * to perform a dual-tree traversal of two trees. The trees must be the same
7 * type.
8 *
9 * mlpack is free software; you may redistribute it and/or modify it under the
10 * terms of the 3-clause BSD license. You should have received a copy of the
11 * 3-clause BSD license along with mlpack. If not, see
12 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
13 */
14 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
15 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
16
17 // In case it hasn't been included yet.
18 #include "dual_tree_traverser.hpp"
19
20 namespace mlpack {
21 namespace tree {
22
23 template<typename MetricType,
24 typename StatisticType,
25 typename MatType,
26 template<typename BoundMetricType, typename...> class BoundType,
27 template<typename SplitBoundType, typename SplitMatType>
28 class SplitType>
29 template<typename RuleType>
30 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
DualTreeTraverser(RuleType & rule)31 DualTreeTraverser<RuleType>::DualTreeTraverser(RuleType& rule) :
32 rule(rule),
33 numPrunes(0),
34 numVisited(0),
35 numScores(0),
36 numBaseCases(0)
37 { /* Nothing to do. */ }
38
39 template<typename MetricType,
40 typename StatisticType,
41 typename MatType,
42 template<typename BoundMetricType, typename...> class BoundType,
43 template<typename SplitBoundType, typename SplitMatType>
44 class SplitType>
45 template<typename RuleType>
46 void BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>::
Traverse(BinarySpaceTree<MetricType,StatisticType,MatType,BoundType,SplitType> & queryNode,BinarySpaceTree<MetricType,StatisticType,MatType,BoundType,SplitType> & referenceNode)47 DualTreeTraverser<RuleType>::Traverse(
48 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
49 queryNode,
50 BinarySpaceTree<MetricType, StatisticType, MatType, BoundType, SplitType>&
51 referenceNode)
52 {
53 // Increment the visit counter.
54 ++numVisited;
55
56 // Store the current traversal info.
57 traversalInfo = rule.TraversalInfo();
58
59 // If both nodes are root nodes, just score them.
60 if (queryNode.Parent() == NULL && referenceNode.Parent() == NULL)
61 {
62 const double rootScore = rule.Score(queryNode, referenceNode);
63 // If root score is DBL_MAX, don't recurse.
64 if (rootScore == DBL_MAX)
65 {
66 ++numPrunes;
67 return;
68 }
69 }
70
71 // If both are leaves, we must evaluate the base case.
72 if (queryNode.IsLeaf() && referenceNode.IsLeaf())
73 {
74 // Loop through each of the points in each node.
75 const size_t queryEnd = queryNode.Begin() + queryNode.Count();
76 const size_t refEnd = referenceNode.Begin() + referenceNode.Count();
77 for (size_t query = queryNode.Begin(); query < queryEnd; ++query)
78 {
79 // See if we need to investigate this point (this function should be
80 // implemented for the single-tree recursion too). Restore the traversal
81 // information first.
82 rule.TraversalInfo() = traversalInfo;
83 const double childScore = rule.Score(query, referenceNode);
84
85 if (childScore == DBL_MAX)
86 continue; // We can't improve this particular point.
87
88 for (size_t ref = referenceNode.Begin(); ref < refEnd; ++ref)
89 rule.BaseCase(query, ref);
90
91 numBaseCases += referenceNode.Count();
92 }
93 }
94 else if (((!queryNode.IsLeaf()) && referenceNode.IsLeaf()) ||
95 (queryNode.NumDescendants() > 3 * referenceNode.NumDescendants() &&
96 !queryNode.IsLeaf() && !referenceNode.IsLeaf()))
97 {
98 // We have to recurse down the query node. In this case the recursion order
99 // does not matter.
100 const double leftScore = rule.Score(*queryNode.Left(), referenceNode);
101 ++numScores;
102
103 if (leftScore != DBL_MAX)
104 Traverse(*queryNode.Left(), referenceNode);
105 else
106 ++numPrunes;
107
108 // Before recursing, we have to set the traversal information correctly.
109 rule.TraversalInfo() = traversalInfo;
110 const double rightScore = rule.Score(*queryNode.Right(), referenceNode);
111 ++numScores;
112
113 if (rightScore != DBL_MAX)
114 Traverse(*queryNode.Right(), referenceNode);
115 else
116 ++numPrunes;
117 }
118 else if (queryNode.IsLeaf() && (!referenceNode.IsLeaf()))
119 {
120 // We have to recurse down the reference node. In this case the recursion
121 // order does matter. Before recursing, though, we have to set the
122 // traversal information correctly.
123 double leftScore = rule.Score(queryNode, *referenceNode.Left());
124 typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
125 rule.TraversalInfo() = traversalInfo;
126 double rightScore = rule.Score(queryNode, *referenceNode.Right());
127 numScores += 2;
128
129 if (leftScore < rightScore)
130 {
131 // Recurse to the left. Restore the left traversal info. Store the right
132 // traversal info.
133 traversalInfo = rule.TraversalInfo();
134 rule.TraversalInfo() = leftInfo;
135 Traverse(queryNode, *referenceNode.Left());
136
137 // Is it still valid to recurse to the right?
138 rightScore = rule.Rescore(queryNode, *referenceNode.Right(), rightScore);
139
140 if (rightScore != DBL_MAX)
141 {
142 // Restore the right traversal info.
143 rule.TraversalInfo() = traversalInfo;
144 Traverse(queryNode, *referenceNode.Right());
145 }
146 else
147 ++numPrunes;
148 }
149 else if (rightScore < leftScore)
150 {
151 // Recurse to the right.
152 Traverse(queryNode, *referenceNode.Right());
153
154 // Is it still valid to recurse to the left?
155 leftScore = rule.Rescore(queryNode, *referenceNode.Left(), leftScore);
156
157 if (leftScore != DBL_MAX)
158 {
159 // Restore the left traversal info.
160 rule.TraversalInfo() = leftInfo;
161 Traverse(queryNode, *referenceNode.Left());
162 }
163 else
164 ++numPrunes;
165 }
166 else // leftScore is equal to rightScore.
167 {
168 if (leftScore == DBL_MAX)
169 {
170 numPrunes += 2;
171 }
172 else
173 {
174 // Choose the left first. Restore the left traversal info. Store the
175 // right traversal info.
176 traversalInfo = rule.TraversalInfo();
177 rule.TraversalInfo() = leftInfo;
178 Traverse(queryNode, *referenceNode.Left());
179
180 rightScore = rule.Rescore(queryNode, *referenceNode.Right(),
181 rightScore);
182
183 if (rightScore != DBL_MAX)
184 {
185 // Restore the right traversal info.
186 rule.TraversalInfo() = traversalInfo;
187 Traverse(queryNode, *referenceNode.Right());
188 }
189 else
190 ++numPrunes;
191 }
192 }
193 }
194 else
195 {
196 // We have to recurse down both query and reference nodes. Because the
197 // query descent order does not matter, we will go to the left query child
198 // first. Before recursing, we have to set the traversal information
199 // correctly.
200 double leftScore = rule.Score(*queryNode.Left(), *referenceNode.Left());
201 typename RuleType::TraversalInfoType leftInfo = rule.TraversalInfo();
202 rule.TraversalInfo() = traversalInfo;
203 double rightScore = rule.Score(*queryNode.Left(), *referenceNode.Right());
204 typename RuleType::TraversalInfoType rightInfo;
205 numScores += 2;
206
207 if (leftScore < rightScore)
208 {
209 // Recurse to the left. Restore the left traversal info. Store the right
210 // traversal info.
211 rightInfo = rule.TraversalInfo();
212 rule.TraversalInfo() = leftInfo;
213 Traverse(*queryNode.Left(), *referenceNode.Left());
214
215 // Is it still valid to recurse to the right?
216 rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
217 rightScore);
218
219 if (rightScore != DBL_MAX)
220 {
221 // Restore the right traversal info.
222 rule.TraversalInfo() = rightInfo;
223 Traverse(*queryNode.Left(), *referenceNode.Right());
224 }
225 else
226 ++numPrunes;
227 }
228 else if (rightScore < leftScore)
229 {
230 // Recurse to the right.
231 Traverse(*queryNode.Left(), *referenceNode.Right());
232
233 // Is it still valid to recurse to the left?
234 leftScore = rule.Rescore(*queryNode.Left(), *referenceNode.Left(),
235 leftScore);
236
237 if (leftScore != DBL_MAX)
238 {
239 // Restore the left traversal info.
240 rule.TraversalInfo() = leftInfo;
241 Traverse(*queryNode.Left(), *referenceNode.Left());
242 }
243 else
244 ++numPrunes;
245 }
246 else
247 {
248 if (leftScore == DBL_MAX)
249 {
250 numPrunes += 2;
251 }
252 else
253 {
254 // Choose the left first. Restore the left traversal info and store the
255 // right traversal info.
256 rightInfo = rule.TraversalInfo();
257 rule.TraversalInfo() = leftInfo;
258 Traverse(*queryNode.Left(), *referenceNode.Left());
259
260 // Is it still valid to recurse to the right?
261 rightScore = rule.Rescore(*queryNode.Left(), *referenceNode.Right(),
262 rightScore);
263
264 if (rightScore != DBL_MAX)
265 {
266 // Restore the right traversal information.
267 rule.TraversalInfo() = rightInfo;
268 Traverse(*queryNode.Left(), *referenceNode.Right());
269 }
270 else
271 ++numPrunes;
272 }
273 }
274
275 // Restore the main traversal information.
276 rule.TraversalInfo() = traversalInfo;
277
278 // Now recurse down the right query node.
279 leftScore = rule.Score(*queryNode.Right(), *referenceNode.Left());
280 leftInfo = rule.TraversalInfo();
281 rule.TraversalInfo() = traversalInfo;
282 rightScore = rule.Score(*queryNode.Right(), *referenceNode.Right());
283 numScores += 2;
284
285 if (leftScore < rightScore)
286 {
287 // Recurse to the left. Restore the left traversal info. Store the right
288 // traversal info.
289 rightInfo = rule.TraversalInfo();
290 rule.TraversalInfo() = leftInfo;
291 Traverse(*queryNode.Right(), *referenceNode.Left());
292
293 // Is it still valid to recurse to the right?
294 rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
295 rightScore);
296
297 if (rightScore != DBL_MAX)
298 {
299 // Restore the right traversal info.
300 rule.TraversalInfo() = rightInfo;
301 Traverse(*queryNode.Right(), *referenceNode.Right());
302 }
303 else
304 ++numPrunes;
305 }
306 else if (rightScore < leftScore)
307 {
308 // Recurse to the right.
309 Traverse(*queryNode.Right(), *referenceNode.Right());
310
311 // Is it still valid to recurse to the left?
312 leftScore = rule.Rescore(*queryNode.Right(), *referenceNode.Left(),
313 leftScore);
314
315 if (leftScore != DBL_MAX)
316 {
317 // Restore the left traversal info.
318 rule.TraversalInfo() = leftInfo;
319 Traverse(*queryNode.Right(), *referenceNode.Left());
320 }
321 else
322 ++numPrunes;
323 }
324 else
325 {
326 if (leftScore == DBL_MAX)
327 {
328 numPrunes += 2;
329 }
330 else
331 {
332 // Choose the left first. Restore the left traversal info. Store the
333 // right traversal info.
334 rightInfo = rule.TraversalInfo();
335 rule.TraversalInfo() = leftInfo;
336 Traverse(*queryNode.Right(), *referenceNode.Left());
337
338 // Is it still valid to recurse to the right?
339 rightScore = rule.Rescore(*queryNode.Right(), *referenceNode.Right(),
340 rightScore);
341
342 if (rightScore != DBL_MAX)
343 {
344 // Restore the right traversal info.
345 rule.TraversalInfo() = rightInfo;
346 Traverse(*queryNode.Right(), *referenceNode.Right());
347 }
348 else
349 ++numPrunes;
350 }
351 }
352 }
353 }
354
355 } // namespace tree
356 } // namespace mlpack
357
358 #endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_DUAL_TREE_TRAVERSER_IMPL_HPP
359