1 //                                               -*- C++ -*-
2 /**
3  *  @brief KDTree structure to speed-up queries on large samples
4  *
5  *  Copyright 2005-2021 Airbus-EDF-IMACS-ONERA-Phimeca
6  *
7  *  This library is free software: you can redistribute it and/or modify
8  *  it under the terms of the GNU Lesser General Public License as published by
9  *  the Free Software Foundation, either version 3 of the License, or
10  *  (at your option) any later version.
11  *
12  *  This library is distributed in the hope that it will be useful,
13  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
14  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  *  GNU Lesser General Public License for more details.
16  *
17  *  You should have received a copy of the GNU Lesser General Public License
18  *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
19  *
20  */
21 
22 #include "openturns/KDTree.hxx"
23 #include "openturns/Exception.hxx"
24 #include "openturns/SpecFunc.hxx"
25 #include "openturns/Indices.hxx"
26 #include "openturns/SobolSequence.hxx"
27 #include "openturns/PersistentObjectFactory.hxx"
28 
29 BEGIN_NAMESPACE_OPENTURNS
30 
CLASSNAMEINIT(KDTree)

// NOTE(review): presumably registers KDTree with the persistence layer so that
// stored instances can be rebuilt through load(); confirm against
// PersistentObjectFactory.hxx
static const Factory<KDTree> Factory_KDTree;
34 
35 /**
36  * @class KDNearestNeighboursFinder
37  *
38  * A fixed-size heap to find k-nearest neighbours in a KDTree
39  */
40 class KDNearestNeighboursFinder
41 {
42 public:
43 
44   /** Constructor */
KDNearestNeighboursFinder(const Indices & tree,const Sample & sample,const Interval & boundingBox,const UnsignedInteger size)45   KDNearestNeighboursFinder(const Indices & tree, const Sample & sample, const Interval & boundingBox, const UnsignedInteger size)
46     : tree_(tree)
47     , sample_(sample)
48     , boundingBox_(boundingBox)
49     , capacity_(size)
50     , size_(0)
51     , values_(size)
52     , indices_(size)
53   {
54     // Initialize values_[0] to a valid value
55     values_[0] = SpecFunc::MaxScalar;
56   }
57 
58   /** Get the indices of the k nearest neighbours of the given point */
getNearestNeighboursIndices(const UnsignedInteger inode,const Point & x,const Bool sorted)59   Indices getNearestNeighboursIndices(const UnsignedInteger inode, const Point & x, const Bool sorted)
60   {
61     if (size_ != 0)
62     {
63       // Clear heap
64       indices_.clear();
65       values_.clear();
66       size_ = 0;
67       values_[0] = SpecFunc::MaxScalar;
68     }
69     Point lowerBoundingBox(boundingBox_.getLowerBound());
70     Point upperBoundingBox(boundingBox_.getUpperBound());
71     collectNearestNeighbours(inode, x, lowerBoundingBox, upperBoundingBox, 0);
72     if (sorted)
73     {
74       /* Sort heap in-place in ascending order.
75          This breaks heap structure, but it does not matter, heap is
76          rebuilt when calling collectNearestNeighbours.
77        */
78       const UnsignedInteger realSize = size_;
79       while (size_ > 1)
80       {
81         // Move largest value at the end
82         std::swap(values_[size_ - 1], values_[0]);
83         std::swap(indices_[size_ - 1], indices_[0]);
84         // Make heap believe that largest value has been removed
85         --size_;
86         // Move new root to a valid location
87         moveNodeDown(0);
88       }
89       // Restore heap size
90       size_ = realSize;
91     }
92 
93     return indices_;
94   }
95 
96 private:
97   /* Recursive method to find the indices of the k nearest neighbours
98      Strategy:
99      + for a new candidate, if there is still room just add it to the list of neighbours
100      + else replace the worst candidate from the current list by the new candidate
101      Complexity: O(k) at each insertion, O(log(n)) expected insertions
102   */
collectNearestNeighbours(const UnsignedInteger inode,const Point & x,Point & lowerBoundingBox,Point & upperBoundingBox,const UnsignedInteger activeDimension)103   void collectNearestNeighbours(const UnsignedInteger inode,
104                                 const Point & x,
105                                 Point & lowerBoundingBox, Point & upperBoundingBox,
106                                 const UnsignedInteger activeDimension)
107   {
108     const Scalar splitValue = sample_(tree_[3 * inode], activeDimension);
109     const Scalar delta = x[activeDimension] - splitValue;
110     const UnsignedInteger sameSide(delta < 0.0 ? tree_[3 * inode + 1] : tree_[3 * inode + 2]);
111     const UnsignedInteger oppositeSide(delta < 0.0 ? tree_[3 * inode + 2] : tree_[3 * inode + 1]);
112     const UnsignedInteger dimension = sample_.getDimension();
113     const UnsignedInteger nextActiveDimension = (activeDimension + 1) % dimension;
114     const Scalar saveLower = lowerBoundingBox[activeDimension];
115     const Scalar saveUpper = upperBoundingBox[activeDimension];
116     Scalar currentGreatestValidSquaredDistance = values_[0];
117     if (sameSide != 0)
118     {
119       // Update bounding box to match sameSide bounding box
120       if (delta < 0.0)
121         upperBoundingBox[activeDimension] = splitValue;
122       else
123         lowerBoundingBox[activeDimension] = splitValue;
124       // Compute distance between x and sameSide
125       Scalar squaredDistanceBoundingBox = 0.0;
126       for(UnsignedInteger i = 0; i < dimension; ++i)
127       {
128         Scalar difference = std::max(0.0, std::max(x[i] - upperBoundingBox[i], lowerBoundingBox[i] - x[i]));
129         squaredDistanceBoundingBox += difference * difference;
130       }
131       if (squaredDistanceBoundingBox < values_[0])
132       {
133         collectNearestNeighbours(sameSide, x, lowerBoundingBox, upperBoundingBox, nextActiveDimension);
134         currentGreatestValidSquaredDistance = values_[0];
135       }
136       // Restore bounding box
137       if (delta < 0.0)
138         upperBoundingBox[activeDimension] = saveUpper;
139       else
140         lowerBoundingBox[activeDimension] = saveLower;
141     }
142     if (size_ == capacity_ && currentGreatestValidSquaredDistance < delta * delta) return;
143     const UnsignedInteger localIndex = tree_[3 * inode];
144     // Similar to (x - sample_[localIndex]).normSquare() but it is better to avoid Point creation
145     Scalar localSquaredDistance = 0.0;
146     for(UnsignedInteger i = 0; i < dimension; ++i)
147       localSquaredDistance += (x[i] - sample_(localIndex, i)) * (x[i] - sample_(localIndex, i));
148     if (size_ != capacity_)
149     {
150       // Put index/value at the first free node and move it up to a valid location
151       indices_[size_] = localIndex;
152       values_[size_] = localSquaredDistance;
153       moveNodeUp(size_);
154       ++size_;
155     }
156     else if (localSquaredDistance < currentGreatestValidSquaredDistance)
157     {
158       // Heap is full, and current value is smaller than heap largest value.
159       // Replace the largest value by current value and move it down to a
160       // valid location.
161       if (localSquaredDistance < values_[0])
162       {
163         indices_[0] = localIndex;
164         values_[0] = localSquaredDistance;
165         moveNodeDown(0);
166       }
167     }
168     if (oppositeSide != 0)
169     {
170       // Update bounding box to match oppositeSide bounding box
171       if (delta < 0.0)
172         lowerBoundingBox[activeDimension] = splitValue;
173       else
174         upperBoundingBox[activeDimension] = splitValue;
175       // Compute distance between x and oppositeSide
176       Scalar squaredDistanceBoundingBox = 0.0;
177       for(UnsignedInteger i = 0; i < dimension; ++i)
178       {
179         Scalar difference = std::max(0.0, std::max(x[i] - upperBoundingBox[i], lowerBoundingBox[i] - x[i]));
180         squaredDistanceBoundingBox += difference * difference;
181       }
182       if (squaredDistanceBoundingBox < values_[0])
183         collectNearestNeighbours(oppositeSide, x, lowerBoundingBox, upperBoundingBox, nextActiveDimension);
184       // Restore bounding box
185       if (delta < 0.0)
186         lowerBoundingBox[activeDimension] = saveLower;
187       else
188         upperBoundingBox[activeDimension] = saveUpper;
189     }
190   }
191 
192   /** Move node down to its final location */
moveNodeDown(const UnsignedInteger index)193   void moveNodeDown(const UnsignedInteger index)
194   {
195     const UnsignedInteger left = (index << 1) + 1;
196     const UnsignedInteger right = left + 1;
197     UnsignedInteger maxValueIndex = index;
198     if (left < size_ && values_[left] > values_[maxValueIndex])
199     {
200       maxValueIndex = left;
201     }
202     if (right < size_ && values_[right] > values_[maxValueIndex])
203     {
204       maxValueIndex = right;
205     }
206     if (index != maxValueIndex)
207     {
208       std::swap(values_[index], values_[maxValueIndex]);
209       std::swap(indices_[index], indices_[maxValueIndex]);
210       moveNodeDown(maxValueIndex);
211     }
212   }
213 
214   /** Move node up to its final location */
moveNodeUp(const UnsignedInteger index)215   void moveNodeUp(const UnsignedInteger index)
216   {
217     if (index == 0) return;
218     const UnsignedInteger parent = (index - 1) >> 1;
219     if (values_[index] > values_[parent])
220     {
221       std::swap(values_[index], values_[parent]);
222       std::swap(indices_[index], indices_[parent]);
223       moveNodeUp(parent);
224     }
225   }
226 
227   // Reference to tree
228   const Indices & tree_;
229   // Points contained in the tree
230   const Sample sample_;
231   // Points bounding box
232   const Interval boundingBox_;
233   // Maximum heap size
234   const UnsignedInteger capacity_;
235   // Number of used buckets
236   UnsignedInteger size_;
237   // Array containing values
238   Collection<Scalar> values_;
239   // Array containing point indices
240   Indices indices_;
241 
242 }; /* class KDNearestNeighboursFinder */
243 
244 /**
245  * @class KDTree
246  *
247  * Organize d-dimensional points into a hierarchical tree-like structure
248  */
249 
250 /* Default constructor */
KDTree()251 KDTree::KDTree()
252   : NearestNeighbourAlgorithmImplementation()
253   , points_(0, 0)
254   , boundingBox_()
255   , tree_()
256 {
257   // Nothing to do
258 }
259 
260 /* Parameters constructor */
KDTree(const Sample & points)261 KDTree::KDTree(const Sample & points)
262   : NearestNeighbourAlgorithmImplementation()
263   , points_(0, 0)
264   , boundingBox_()
265   , tree_()
266 {
267   // Build the tree
268   setSample(points);
269 }
270 
271 /* Sample accessor */
getSample() const272 Sample KDTree::getSample() const
273 {
274   return points_;
275 }
276 
setSample(const Sample & points)277 void KDTree::setSample(const Sample & points)
278 {
279   if (points == points_) return;
280 
281   points_ = points;
282   boundingBox_ = Interval(points_.getDimension());
283   tree_ = Indices(3 * (points_.getSize() + 1));
284 
285   // Scramble the order in which the points are inserted in the tree in order to improve its balancing
286   const UnsignedInteger size = points_.getSize();
287   Indices buffer(size);
288   buffer.fill();
289   SobolSequence sequence(1);
290   UnsignedInteger root = 0;
291   UnsignedInteger currentSize = 0;
292   for (UnsignedInteger i = 0; i < points_.getSize(); ++ i)
293   {
294     const UnsignedInteger index = i + static_cast< UnsignedInteger >((size - i) * sequence.generate()[0]);
295     insert(root, currentSize, buffer[index], 0);
296     buffer[index] = buffer[i];
297   }
298   boundingBox_.setLowerBound(points_.getMin());
299   boundingBox_.setUpperBound(points_.getMax());
300 }
301 
302 /* Virtual constructor */
clone() const303 KDTree * KDTree::clone() const
304 {
305   return new KDTree( *this );
306 }
307 
308 /* Virtual default constructor */
emptyClone() const309 KDTree * KDTree::emptyClone() const
310 {
311   return new KDTree();
312 }
313 
314 /* String converter */
__repr__() const315 String KDTree::__repr__() const
316 {
317   return OSS(true) << "class=" << GetClassName()
318          << " root=" << (tree_.getSize() > 0 ? printNode(1) : "NULL");
319 }
320 
/* Pretty string converter: same content as __repr__, the offset argument is not used */
String KDTree::__str__(const String & ) const
{
  const String rootDescription(tree_.getSize() > 0 ? printNode(1) : "NULL");
  return OSS(false) << "class=" << GetClassName()
         << " root=" << rootDescription;
}
326 
printNode(const UnsignedInteger inode) const327 String KDTree::printNode(const UnsignedInteger inode) const
328 {
329   return OSS() << "class=KDNode"
330          << " index=" << tree_[3 * inode]
331          << " left=" << (tree_[3 * inode + 1] ? printNode(tree_[3 * inode + 1]) : "NULL")
332          << " right=" << (tree_[3 * inode + 2] ? printNode(tree_[3 * inode + 2]) : "NULL");
333 }
334 
335 /* Insert the point at given index into the tree */
insert(UnsignedInteger & inode,UnsignedInteger & currentSize,const UnsignedInteger index,const UnsignedInteger activeDimension)336 void KDTree::insert(UnsignedInteger & inode,
337                     UnsignedInteger & currentSize,
338                     const UnsignedInteger index,
339                     const UnsignedInteger activeDimension)
340 {
341   if (!(index < points_.getSize())) throw InvalidArgumentException(HERE) << "Error: expected an index less than " << points_.getSize() << ", got " << index;
342   // We are on a leaf
343   if (inode == 0)
344   {
345     ++currentSize;
346     inode = currentSize;
347     tree_[3 * inode] = index;
348   }
349   else if (points_(index, activeDimension) < points_(tree_[3 * inode], activeDimension))
350     insert(tree_[3 * inode + 1], currentSize, index, (activeDimension + 1) % points_.getDimension());
351   else
352     insert(tree_[3 * inode + 2], currentSize, index, (activeDimension + 1) % points_.getDimension());
353 }
354 
355 /* Get the index of the nearest neighbour of the given point */
query(const Point & x) const356 UnsignedInteger KDTree::query(const Point & x) const
357 {
358   if (points_.getSize() == 1) return 0;
359   Scalar smallestDistance = SpecFunc::MaxScalar;
360   Point lowerBoundingBox(boundingBox_.getLowerBound());
361   Point upperBoundingBox(boundingBox_.getUpperBound());
362   return getNearestNeighbourIndex(1, x, smallestDistance, lowerBoundingBox, upperBoundingBox, 0);
363 }
364 
/* Recursive search of the nearest neighbour of x in the subtree rooted at inode
 *
 * bestSquaredDistance is an input/output argument: on input, the smallest
 * squared distance found so far; on output, the smallest squared distance after
 * this subtree has been explored.
 * lowerBoundingBox and upperBoundingBox describe the bounding box of the subtree;
 * they are modified during the search and restored before returning.
 * activeDimension is the component used to split the points at this node.
 * The returned value is the index of the best point found in this subtree, or
 * points_.getSize() if no point of the subtree is strictly closer than the
 * input value of bestSquaredDistance.
 */
UnsignedInteger KDTree::getNearestNeighbourIndex(const UnsignedInteger inode,
    const Point & x,
    Scalar & bestSquaredDistance,
    Point & lowerBoundingBox,
    Point & upperBoundingBox,
    const UnsignedInteger activeDimension) const
{
  if (!(inode > 0)) throw NotDefinedException(HERE) << "Error: cannot find a nearest neighbour in an empty tree";
  // Set delta = x[activeDimension] - points_(tree_[3*inode], activeDimension)
  // sameSide = tree_[3*inode+1] if delta < 0, tree_[3*inode+2] else
  // oppositeSide = tree_[3*inode+2] if delta < 0, tree_[3*inode+1] else
  // Possibilities:
  // 1) sameSide != 0
  // 1.1) If the bounding box of the same side is closer than the current best squared distance, go on the same side. On return, if the returned squared distance is less than the current best squared distance, update the current best index and the associated squared distance.
  // 2) Check if the current best squared distance is less than delta^2
  // 2.1*) If yes, the points associated to inode or to its opposite side can't be better than the current best candidate so return it and the associated squared distance to the upper level
  // 2.2) If no, check the point associated to the current node and update the current best index and the associated squared distance
  // 2.3) oppositeSide != 0
  // 2.4) If the bounding box of the opposite side is closer than the current best squared distance, go on the opposite side. On return, check if the returned squared distance is less than the current best squared distance and update the current best index and the associated squared distance.
  // 3*) Return the current best index and the associated squared distance to the upper level

  const Scalar splitValue = points_(tree_[3 * inode], activeDimension);
  const Scalar delta = x[activeDimension] - splitValue;
  const UnsignedInteger sameSide(delta < 0.0 ? tree_[3 * inode + 1] : tree_[3 * inode + 2]);
  const UnsignedInteger oppositeSide(delta < 0.0 ? tree_[3 * inode + 2] : tree_[3 * inode + 1]);
  // points_.getSize() is not a valid index: it flags "nothing better found yet"
  UnsignedInteger currentBestIndex = points_.getSize();
  Scalar currentBestSquaredDistance = bestSquaredDistance;
  const UnsignedInteger dimension = points_.getDimension();
  const UnsignedInteger nextActiveDimension = (activeDimension + 1) % dimension;
  // Saved in order to restore the bounding box after each descent
  const Scalar saveLower = lowerBoundingBox[activeDimension];
  const Scalar saveUpper = upperBoundingBox[activeDimension];
  // 1)
  if (sameSide != 0)
  {
    // 1.1)
    // Update bounding box to match sameSide bounding box
    if (delta < 0.0)
      upperBoundingBox[activeDimension] = splitValue;
    else
      lowerBoundingBox[activeDimension] = splitValue;
    // Compute squared distance between x and sameSide bounding box (0 if x is inside)
    Scalar squaredDistanceBoundingBox = 0.0;
    for(UnsignedInteger i = 0; i < dimension; ++i)
    {
      Scalar difference = std::max(0.0, std::max(x[i] - upperBoundingBox[i], lowerBoundingBox[i] - x[i]));
      squaredDistanceBoundingBox += difference * difference;
    }
    if (squaredDistanceBoundingBox < currentBestSquaredDistance)
    {
      UnsignedInteger candidateBestIndex = getNearestNeighbourIndex(sameSide, x, bestSquaredDistance, lowerBoundingBox, upperBoundingBox, nextActiveDimension);
      if (bestSquaredDistance < currentBestSquaredDistance)
      {
        currentBestSquaredDistance = bestSquaredDistance;
        currentBestIndex = candidateBestIndex;
      }
    }
    // Restore bounding box
    if (delta < 0.0)
      upperBoundingBox[activeDimension] = saveUpper;
    else
      lowerBoundingBox[activeDimension] = saveLower;
  } // sameSide != 0
  // 2)
  if (currentBestSquaredDistance < delta * delta)
  {
    // 2.1)
    bestSquaredDistance = currentBestSquaredDistance;
    return currentBestIndex;
  }
  // 2.2)
  const UnsignedInteger localIndex = tree_[3 * inode];
  // Similar to (x - points_[localIndex]).normSquare() but it is better to avoid Point creation
  Scalar localSquaredDistance = 0.0;
  for(UnsignedInteger i = 0; i < dimension; ++i)
    localSquaredDistance += (x[i] - points_(localIndex, i)) * (x[i] - points_(localIndex, i));
  if (localSquaredDistance < currentBestSquaredDistance)
  {
    currentBestSquaredDistance = localSquaredDistance;
    // To send the current best squared distance to the lower levels
    bestSquaredDistance = localSquaredDistance;
    currentBestIndex = localIndex;
  }
  // 2.3)
  if (oppositeSide != 0)
  {
    // Update bounding box to match oppositeSide bounding box
    if (delta < 0.0)
      lowerBoundingBox[activeDimension] = splitValue;
    else
      upperBoundingBox[activeDimension] = splitValue;
    // Compute squared distance between x and oppositeSide bounding box (0 if x is inside)
    Scalar squaredDistanceBoundingBox = 0.0;
    for(UnsignedInteger i = 0; i < dimension; ++i)
    {
      Scalar difference = std::max(0.0, std::max(x[i] - upperBoundingBox[i], lowerBoundingBox[i] - x[i]));
      squaredDistanceBoundingBox += difference * difference;
    }
    // 2.4)
    if (squaredDistanceBoundingBox < currentBestSquaredDistance)
    {
      UnsignedInteger candidateBestIndex = getNearestNeighbourIndex(oppositeSide, x, bestSquaredDistance, lowerBoundingBox, upperBoundingBox, nextActiveDimension);
      if (bestSquaredDistance < currentBestSquaredDistance)
      {
        currentBestSquaredDistance = bestSquaredDistance;
        currentBestIndex = candidateBestIndex;
      }
    }
    // Restore bounding box
    if (delta < 0.0)
      lowerBoundingBox[activeDimension] = saveLower;
    else
      upperBoundingBox[activeDimension] = saveUpper;
  } // oppositeSide != 0
  // 3)
  bestSquaredDistance = currentBestSquaredDistance;
  return currentBestIndex;
}
482 
483 /* Get the indices of the k nearest neighbours of the given point */
queryK(const Point & x,const UnsignedInteger k,const Bool sorted) const484 Indices KDTree::queryK(const Point & x, const UnsignedInteger k, const Bool sorted) const
485 {
486   if (k > points_.getSize()) throw InvalidArgumentException(HERE) << "Error: cannot return more neighbours than points in the database!";
487   Indices result(k);
488   // If we need as many neighbours as points in the sample, just return all the possible indices
489   if (k == points_.getSize() && !sorted)
490   {
491     result.fill();
492   }
493   else
494   {
495     KDNearestNeighboursFinder heap(tree_, points_, boundingBox_, k);
496     result = heap.getNearestNeighboursIndices(1, x, sorted);
497   }
498   return result;
499 }
500 
501 /** Method save() stores the object through the StorageManager */
save(Advocate & adv) const502 void KDTree::save(Advocate & adv) const
503 {
504   NearestNeighbourAlgorithmImplementation::save(adv);
505   adv.saveAttribute("points_", points_);
506 }
507 
508 /** Method load() reloads the object from the StorageManager */
load(Advocate & adv)509 void KDTree::load(Advocate & adv)
510 {
511   NearestNeighbourAlgorithmImplementation::load(adv);
512   Sample points;
513   adv.loadAttribute("points_", points);
514   if (points.getSize() > 0) setSample(points);
515 
516 }
517 
518 END_NAMESPACE_OPENTURNS
519