1 // -*- C++ -*-
2 /**
3 * @brief KDTree structure to speed-up queries on large samples
4 *
5 * Copyright 2005-2021 Airbus-EDF-IMACS-ONERA-Phimeca
6 *
7 * This library is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU Lesser General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * This library is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU Lesser General Public License for more details.
16 *
17 * You should have received a copy of the GNU Lesser General Public License
18 * along with this library. If not, see <http://www.gnu.org/licenses/>.
19 *
20 */
21
22 #include "openturns/KDTree.hxx"
23 #include "openturns/Exception.hxx"
24 #include "openturns/SpecFunc.hxx"
25 #include "openturns/Indices.hxx"
26 #include "openturns/SobolSequence.hxx"
27 #include "openturns/PersistentObjectFactory.hxx"
28
29 BEGIN_NAMESPACE_OPENTURNS
30
// Expand the project-provided class-name boilerplate for KDTree
// (presumably what backs the GetClassName() calls used by the string converters -- TODO confirm)
CLASSNAMEINIT(KDTree)

// File-local factory object for KDTree; presumably registers the class with the
// persistence mechanism declared in PersistentObjectFactory.hxx -- TODO confirm
static const Factory<KDTree> Factory_KDTree;
34
35 /**
36 * @class KDNearestNeighboursFinder
37 *
38 * A fixed-size heap to find k-nearest neighbours in a KDTree
39 */
40 class KDNearestNeighboursFinder
41 {
42 public:
43
44 /** Constructor */
KDNearestNeighboursFinder(const Indices & tree,const Sample & sample,const Interval & boundingBox,const UnsignedInteger size)45 KDNearestNeighboursFinder(const Indices & tree, const Sample & sample, const Interval & boundingBox, const UnsignedInteger size)
46 : tree_(tree)
47 , sample_(sample)
48 , boundingBox_(boundingBox)
49 , capacity_(size)
50 , size_(0)
51 , values_(size)
52 , indices_(size)
53 {
54 // Initialize values_[0] to a valid value
55 values_[0] = SpecFunc::MaxScalar;
56 }
57
58 /** Get the indices of the k nearest neighbours of the given point */
getNearestNeighboursIndices(const UnsignedInteger inode,const Point & x,const Bool sorted)59 Indices getNearestNeighboursIndices(const UnsignedInteger inode, const Point & x, const Bool sorted)
60 {
61 if (size_ != 0)
62 {
63 // Clear heap
64 indices_.clear();
65 values_.clear();
66 size_ = 0;
67 values_[0] = SpecFunc::MaxScalar;
68 }
69 Point lowerBoundingBox(boundingBox_.getLowerBound());
70 Point upperBoundingBox(boundingBox_.getUpperBound());
71 collectNearestNeighbours(inode, x, lowerBoundingBox, upperBoundingBox, 0);
72 if (sorted)
73 {
74 /* Sort heap in-place in ascending order.
75 This breaks heap structure, but it does not matter, heap is
76 rebuilt when calling collectNearestNeighbours.
77 */
78 const UnsignedInteger realSize = size_;
79 while (size_ > 1)
80 {
81 // Move largest value at the end
82 std::swap(values_[size_ - 1], values_[0]);
83 std::swap(indices_[size_ - 1], indices_[0]);
84 // Make heap believe that largest value has been removed
85 --size_;
86 // Move new root to a valid location
87 moveNodeDown(0);
88 }
89 // Restore heap size
90 size_ = realSize;
91 }
92
93 return indices_;
94 }
95
96 private:
97 /* Recursive method to find the indices of the k nearest neighbours
98 Strategy:
99 + for a new candidate, if there is still room just add it to the list of neighbours
100 + else replace the worst candidate from the current list by the new candidate
101 Complexity: O(k) at each insertion, O(log(n)) expected insertions
102 */
collectNearestNeighbours(const UnsignedInteger inode,const Point & x,Point & lowerBoundingBox,Point & upperBoundingBox,const UnsignedInteger activeDimension)103 void collectNearestNeighbours(const UnsignedInteger inode,
104 const Point & x,
105 Point & lowerBoundingBox, Point & upperBoundingBox,
106 const UnsignedInteger activeDimension)
107 {
108 const Scalar splitValue = sample_(tree_[3 * inode], activeDimension);
109 const Scalar delta = x[activeDimension] - splitValue;
110 const UnsignedInteger sameSide(delta < 0.0 ? tree_[3 * inode + 1] : tree_[3 * inode + 2]);
111 const UnsignedInteger oppositeSide(delta < 0.0 ? tree_[3 * inode + 2] : tree_[3 * inode + 1]);
112 const UnsignedInteger dimension = sample_.getDimension();
113 const UnsignedInteger nextActiveDimension = (activeDimension + 1) % dimension;
114 const Scalar saveLower = lowerBoundingBox[activeDimension];
115 const Scalar saveUpper = upperBoundingBox[activeDimension];
116 Scalar currentGreatestValidSquaredDistance = values_[0];
117 if (sameSide != 0)
118 {
119 // Update bounding box to match sameSide bounding box
120 if (delta < 0.0)
121 upperBoundingBox[activeDimension] = splitValue;
122 else
123 lowerBoundingBox[activeDimension] = splitValue;
124 // Compute distance between x and sameSide
125 Scalar squaredDistanceBoundingBox = 0.0;
126 for(UnsignedInteger i = 0; i < dimension; ++i)
127 {
128 Scalar difference = std::max(0.0, std::max(x[i] - upperBoundingBox[i], lowerBoundingBox[i] - x[i]));
129 squaredDistanceBoundingBox += difference * difference;
130 }
131 if (squaredDistanceBoundingBox < values_[0])
132 {
133 collectNearestNeighbours(sameSide, x, lowerBoundingBox, upperBoundingBox, nextActiveDimension);
134 currentGreatestValidSquaredDistance = values_[0];
135 }
136 // Restore bounding box
137 if (delta < 0.0)
138 upperBoundingBox[activeDimension] = saveUpper;
139 else
140 lowerBoundingBox[activeDimension] = saveLower;
141 }
142 if (size_ == capacity_ && currentGreatestValidSquaredDistance < delta * delta) return;
143 const UnsignedInteger localIndex = tree_[3 * inode];
144 // Similar to (x - sample_[localIndex]).normSquare() but it is better to avoid Point creation
145 Scalar localSquaredDistance = 0.0;
146 for(UnsignedInteger i = 0; i < dimension; ++i)
147 localSquaredDistance += (x[i] - sample_(localIndex, i)) * (x[i] - sample_(localIndex, i));
148 if (size_ != capacity_)
149 {
150 // Put index/value at the first free node and move it up to a valid location
151 indices_[size_] = localIndex;
152 values_[size_] = localSquaredDistance;
153 moveNodeUp(size_);
154 ++size_;
155 }
156 else if (localSquaredDistance < currentGreatestValidSquaredDistance)
157 {
158 // Heap is full, and current value is smaller than heap largest value.
159 // Replace the largest value by current value and move it down to a
160 // valid location.
161 if (localSquaredDistance < values_[0])
162 {
163 indices_[0] = localIndex;
164 values_[0] = localSquaredDistance;
165 moveNodeDown(0);
166 }
167 }
168 if (oppositeSide != 0)
169 {
170 // Update bounding box to match oppositeSide bounding box
171 if (delta < 0.0)
172 lowerBoundingBox[activeDimension] = splitValue;
173 else
174 upperBoundingBox[activeDimension] = splitValue;
175 // Compute distance between x and oppositeSide
176 Scalar squaredDistanceBoundingBox = 0.0;
177 for(UnsignedInteger i = 0; i < dimension; ++i)
178 {
179 Scalar difference = std::max(0.0, std::max(x[i] - upperBoundingBox[i], lowerBoundingBox[i] - x[i]));
180 squaredDistanceBoundingBox += difference * difference;
181 }
182 if (squaredDistanceBoundingBox < values_[0])
183 collectNearestNeighbours(oppositeSide, x, lowerBoundingBox, upperBoundingBox, nextActiveDimension);
184 // Restore bounding box
185 if (delta < 0.0)
186 lowerBoundingBox[activeDimension] = saveLower;
187 else
188 upperBoundingBox[activeDimension] = saveUpper;
189 }
190 }
191
192 /** Move node down to its final location */
moveNodeDown(const UnsignedInteger index)193 void moveNodeDown(const UnsignedInteger index)
194 {
195 const UnsignedInteger left = (index << 1) + 1;
196 const UnsignedInteger right = left + 1;
197 UnsignedInteger maxValueIndex = index;
198 if (left < size_ && values_[left] > values_[maxValueIndex])
199 {
200 maxValueIndex = left;
201 }
202 if (right < size_ && values_[right] > values_[maxValueIndex])
203 {
204 maxValueIndex = right;
205 }
206 if (index != maxValueIndex)
207 {
208 std::swap(values_[index], values_[maxValueIndex]);
209 std::swap(indices_[index], indices_[maxValueIndex]);
210 moveNodeDown(maxValueIndex);
211 }
212 }
213
214 /** Move node up to its final location */
moveNodeUp(const UnsignedInteger index)215 void moveNodeUp(const UnsignedInteger index)
216 {
217 if (index == 0) return;
218 const UnsignedInteger parent = (index - 1) >> 1;
219 if (values_[index] > values_[parent])
220 {
221 std::swap(values_[index], values_[parent]);
222 std::swap(indices_[index], indices_[parent]);
223 moveNodeUp(parent);
224 }
225 }
226
227 // Reference to tree
228 const Indices & tree_;
229 // Points contained in the tree
230 const Sample sample_;
231 // Points bounding box
232 const Interval boundingBox_;
233 // Maximum heap size
234 const UnsignedInteger capacity_;
235 // Number of used buckets
236 UnsignedInteger size_;
237 // Array containing values
238 Collection<Scalar> values_;
239 // Array containing point indices
240 Indices indices_;
241
242 }; /* class KDNearestNeighboursFinder */
243
244 /**
245 * @class KDTree
246 *
247 * Organize d-dimensional points into a hierarchical tree-like structure
248 */
249
250 /* Default constructor */
KDTree()251 KDTree::KDTree()
252 : NearestNeighbourAlgorithmImplementation()
253 , points_(0, 0)
254 , boundingBox_()
255 , tree_()
256 {
257 // Nothing to do
258 }
259
260 /* Parameters constructor */
KDTree(const Sample & points)261 KDTree::KDTree(const Sample & points)
262 : NearestNeighbourAlgorithmImplementation()
263 , points_(0, 0)
264 , boundingBox_()
265 , tree_()
266 {
267 // Build the tree
268 setSample(points);
269 }
270
271 /* Sample accessor */
getSample() const272 Sample KDTree::getSample() const
273 {
274 return points_;
275 }
276
setSample(const Sample & points)277 void KDTree::setSample(const Sample & points)
278 {
279 if (points == points_) return;
280
281 points_ = points;
282 boundingBox_ = Interval(points_.getDimension());
283 tree_ = Indices(3 * (points_.getSize() + 1));
284
285 // Scramble the order in which the points are inserted in the tree in order to improve its balancing
286 const UnsignedInteger size = points_.getSize();
287 Indices buffer(size);
288 buffer.fill();
289 SobolSequence sequence(1);
290 UnsignedInteger root = 0;
291 UnsignedInteger currentSize = 0;
292 for (UnsignedInteger i = 0; i < points_.getSize(); ++ i)
293 {
294 const UnsignedInteger index = i + static_cast< UnsignedInteger >((size - i) * sequence.generate()[0]);
295 insert(root, currentSize, buffer[index], 0);
296 buffer[index] = buffer[i];
297 }
298 boundingBox_.setLowerBound(points_.getMin());
299 boundingBox_.setUpperBound(points_.getMax());
300 }
301
302 /* Virtual constructor */
clone() const303 KDTree * KDTree::clone() const
304 {
305 return new KDTree( *this );
306 }
307
308 /* Virtual default constructor */
emptyClone() const309 KDTree * KDTree::emptyClone() const
310 {
311 return new KDTree();
312 }
313
314 /* String converter */
__repr__() const315 String KDTree::__repr__() const
316 {
317 return OSS(true) << "class=" << GetClassName()
318 << " root=" << (tree_.getSize() > 0 ? printNode(1) : "NULL");
319 }
320
/* Pretty string converter: same content as __repr__ but built with OSS(false);
   the offset argument is not used */
String KDTree::__str__(const String & ) const
{
  const String rootDescription(tree_.getSize() > 0 ? printNode(1) : "NULL");
  OSS oss(false);
  oss << "class=" << GetClassName();
  oss << " root=" << rootDescription;
  return oss;
}
326
printNode(const UnsignedInteger inode) const327 String KDTree::printNode(const UnsignedInteger inode) const
328 {
329 return OSS() << "class=KDNode"
330 << " index=" << tree_[3 * inode]
331 << " left=" << (tree_[3 * inode + 1] ? printNode(tree_[3 * inode + 1]) : "NULL")
332 << " right=" << (tree_[3 * inode + 2] ? printNode(tree_[3 * inode + 2]) : "NULL");
333 }
334
335 /* Insert the point at given index into the tree */
insert(UnsignedInteger & inode,UnsignedInteger & currentSize,const UnsignedInteger index,const UnsignedInteger activeDimension)336 void KDTree::insert(UnsignedInteger & inode,
337 UnsignedInteger & currentSize,
338 const UnsignedInteger index,
339 const UnsignedInteger activeDimension)
340 {
341 if (!(index < points_.getSize())) throw InvalidArgumentException(HERE) << "Error: expected an index less than " << points_.getSize() << ", got " << index;
342 // We are on a leaf
343 if (inode == 0)
344 {
345 ++currentSize;
346 inode = currentSize;
347 tree_[3 * inode] = index;
348 }
349 else if (points_(index, activeDimension) < points_(tree_[3 * inode], activeDimension))
350 insert(tree_[3 * inode + 1], currentSize, index, (activeDimension + 1) % points_.getDimension());
351 else
352 insert(tree_[3 * inode + 2], currentSize, index, (activeDimension + 1) % points_.getDimension());
353 }
354
355 /* Get the index of the nearest neighbour of the given point */
query(const Point & x) const356 UnsignedInteger KDTree::query(const Point & x) const
357 {
358 if (points_.getSize() == 1) return 0;
359 Scalar smallestDistance = SpecFunc::MaxScalar;
360 Point lowerBoundingBox(boundingBox_.getLowerBound());
361 Point upperBoundingBox(boundingBox_.getUpperBound());
362 return getNearestNeighbourIndex(1, x, smallestDistance, lowerBoundingBox, upperBoundingBox, 0);
363 }
364
getNearestNeighbourIndex(const UnsignedInteger inode,const Point & x,Scalar & bestSquaredDistance,Point & lowerBoundingBox,Point & upperBoundingBox,const UnsignedInteger activeDimension) const365 UnsignedInteger KDTree::getNearestNeighbourIndex(const UnsignedInteger inode,
366 const Point & x,
367 Scalar & bestSquaredDistance,
368 Point & lowerBoundingBox,
369 Point & upperBoundingBox,
370 const UnsignedInteger activeDimension) const
371 {
372 if (!(inode > 0)) throw NotDefinedException(HERE) << "Error: cannot find a nearest neighbour in an empty tree";
373 // Set delta = x[activeDimension] - points_(tree_[3*inode], activeDimension)
374 // sameSide = tree_(inode, 0) if delta < 0, tree_[3*inode+2] else
375 // oppositeSide = tree_[3*inode+2] if delta < 0, tree_(inode, 0) else
376 // Possibilities:
377 // 1) sameSide != 0
378 // 1.1) Go on the same side. On return, the index is the best candidate index. If the associated distance is less than the current best index, update the current best index and the associated squared distance.
379 // 2) Check if the current best squared distance is less than delta^2
380 // 2.1*) If yes, the points associated to inode or to its opposite side can't be better than the current best candidate so return it and the associated squared distance to the upper level
381 // 2.2) If no, check the point associated to the current node and update the current best index and the associated squared distance
382 // 2.3) oppositeSide != 0
383 // 2.4) Go on the opposite side. On return, check if the returned squared distance is less than the current best squared distance and update the current best index and the associated squared distance.
384 // 3*) Return the current best index and the associated squared distance to the upper level
385
386 const Scalar splitValue = points_(tree_[3 * inode], activeDimension);
387 const Scalar delta = x[activeDimension] - splitValue;
388 const UnsignedInteger sameSide(delta < 0.0 ? tree_[3 * inode + 1] : tree_[3 * inode + 2]);
389 const UnsignedInteger oppositeSide(delta < 0.0 ? tree_[3 * inode + 2] : tree_[3 * inode + 1]);
390 UnsignedInteger currentBestIndex = points_.getSize();
391 Scalar currentBestSquaredDistance = bestSquaredDistance;
392 const UnsignedInteger dimension = points_.getDimension();
393 const UnsignedInteger nextActiveDimension = (activeDimension + 1) % dimension;
394 const Scalar saveLower = lowerBoundingBox[activeDimension];
395 const Scalar saveUpper = upperBoundingBox[activeDimension];
396 // 1)
397 if (sameSide != 0)
398 {
399 // 1.1)
400 // Update bounding box to match sameSide bounding box
401 if (delta < 0.0)
402 upperBoundingBox[activeDimension] = splitValue;
403 else
404 lowerBoundingBox[activeDimension] = splitValue;
405 // Compute distance between x and sameSide
406 Scalar squaredDistanceBoundingBox = 0.0;
407 for(UnsignedInteger i = 0; i < dimension; ++i)
408 {
409 Scalar difference = std::max(0.0, std::max(x[i] - upperBoundingBox[i], lowerBoundingBox[i] - x[i]));
410 squaredDistanceBoundingBox += difference * difference;
411 }
412 if (squaredDistanceBoundingBox < currentBestSquaredDistance)
413 {
414 UnsignedInteger candidateBestIndex = getNearestNeighbourIndex(sameSide, x, bestSquaredDistance, lowerBoundingBox, upperBoundingBox, nextActiveDimension);
415 if (bestSquaredDistance < currentBestSquaredDistance)
416 {
417 currentBestSquaredDistance = bestSquaredDistance;
418 currentBestIndex = candidateBestIndex;
419 }
420 }
421 // Restore bounding box
422 if (delta < 0.0)
423 upperBoundingBox[activeDimension] = saveUpper;
424 else
425 lowerBoundingBox[activeDimension] = saveLower;
426 } // sameSide != 0
427 // 2)
428 if (currentBestSquaredDistance < delta * delta)
429 {
430 // 2.1)
431 bestSquaredDistance = currentBestSquaredDistance;
432 return currentBestIndex;
433 }
434 // 2.2)
435 const UnsignedInteger localIndex = tree_[3 * inode];
436 // Similar to (x - points_[localIndex]).normSquare() but it is better to avoid Point creation
437 Scalar localSquaredDistance = 0.0;
438 for(UnsignedInteger i = 0; i < dimension; ++i)
439 localSquaredDistance += (x[i] - points_(localIndex, i)) * (x[i] - points_(localIndex, i));
440 if (localSquaredDistance < currentBestSquaredDistance)
441 {
442 currentBestSquaredDistance = localSquaredDistance;
443 // To send the current best squared distance to the lower levels
444 bestSquaredDistance = localSquaredDistance;
445 currentBestIndex = localIndex;
446 }
447 // 2.3)
448 if (oppositeSide != 0)
449 {
450 // Update bounding box to match oppositeSide bounding box
451 if (delta < 0.0)
452 lowerBoundingBox[activeDimension] = splitValue;
453 else
454 upperBoundingBox[activeDimension] = splitValue;
455 // Compute distance between x and oppositeSide
456 Scalar squaredDistanceBoundingBox = 0.0;
457 for(UnsignedInteger i = 0; i < dimension; ++i)
458 {
459 Scalar difference = std::max(0.0, std::max(x[i] - upperBoundingBox[i], lowerBoundingBox[i] - x[i]));
460 squaredDistanceBoundingBox += difference * difference;
461 }
462 // 2.4)
463 if (squaredDistanceBoundingBox < currentBestSquaredDistance)
464 {
465 UnsignedInteger candidateBestIndex = getNearestNeighbourIndex(oppositeSide, x, bestSquaredDistance, lowerBoundingBox, upperBoundingBox, nextActiveDimension);
466 if (bestSquaredDistance < currentBestSquaredDistance)
467 {
468 currentBestSquaredDistance = bestSquaredDistance;
469 currentBestIndex = candidateBestIndex;
470 }
471 }
472 // Restore bounding box
473 if (delta < 0.0)
474 lowerBoundingBox[activeDimension] = saveLower;
475 else
476 upperBoundingBox[activeDimension] = saveUpper;
477 } // oppositeSide != 0
478 // 3)
479 bestSquaredDistance = currentBestSquaredDistance;
480 return currentBestIndex;
481 }
482
483 /* Get the indices of the k nearest neighbours of the given point */
queryK(const Point & x,const UnsignedInteger k,const Bool sorted) const484 Indices KDTree::queryK(const Point & x, const UnsignedInteger k, const Bool sorted) const
485 {
486 if (k > points_.getSize()) throw InvalidArgumentException(HERE) << "Error: cannot return more neighbours than points in the database!";
487 Indices result(k);
488 // If we need as many neighbours as points in the sample, just return all the possible indices
489 if (k == points_.getSize() && !sorted)
490 {
491 result.fill();
492 }
493 else
494 {
495 KDNearestNeighboursFinder heap(tree_, points_, boundingBox_, k);
496 result = heap.getNearestNeighboursIndices(1, x, sorted);
497 }
498 return result;
499 }
500
501 /** Method save() stores the object through the StorageManager */
save(Advocate & adv) const502 void KDTree::save(Advocate & adv) const
503 {
504 NearestNeighbourAlgorithmImplementation::save(adv);
505 adv.saveAttribute("points_", points_);
506 }
507
508 /** Method load() reloads the object from the StorageManager */
load(Advocate & adv)509 void KDTree::load(Advocate & adv)
510 {
511 NearestNeighbourAlgorithmImplementation::load(adv);
512 Sample points;
513 adv.loadAttribute("points_", points);
514 if (points.getSize() > 0) setSample(points);
515
516 }
517
518 END_NAMESPACE_OPENTURNS
519