1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkKdTree_hxx
19 #define itkKdTree_hxx
20 
21 #include "itkKdTree.h"
22 
23 namespace itk
24 {
25 namespace Statistics
26 {
27 template<typename TSample>
28 KdTreeNonterminalNode<TSample>
KdTreeNonterminalNode(unsigned int partitionDimension,MeasurementType partitionValue,Superclass * left,Superclass * right)29 ::KdTreeNonterminalNode( unsigned int partitionDimension,
30                          MeasurementType partitionValue, Superclass *left, Superclass *right ) :
31   m_PartitionDimension(partitionDimension),
32   m_PartitionValue(partitionValue),
33   m_InstanceIdentifier(0),
34   m_Left(left),
35   m_Right(right)
36 {
37 }
38 
39 template<typename TSample>
40 void
41 KdTreeNonterminalNode<TSample>
GetParameters(unsigned int & partitionDimension,MeasurementType & partitionValue) const42 ::GetParameters( unsigned int &partitionDimension,
43   MeasurementType &partitionValue ) const
44 {
45   partitionDimension = this->m_PartitionDimension;
46   partitionValue = this->m_PartitionValue;
47 }
48 
49 template<typename TSample>
50 KdTreeWeightedCentroidNonterminalNode<TSample>
KdTreeWeightedCentroidNonterminalNode(unsigned int partitionDimension,MeasurementType partitionValue,Superclass * left,Superclass * right,CentroidType & centroid,unsigned int size)51 ::KdTreeWeightedCentroidNonterminalNode( unsigned int partitionDimension,
52   MeasurementType partitionValue, Superclass *left, Superclass *right,
53   CentroidType & centroid, unsigned int size )
54 {
55   this->m_PartitionDimension = partitionDimension;
56   this->m_PartitionValue = partitionValue;
57   this->m_Left = left;
58   this->m_Right = right;
59   this->m_WeightedCentroid = centroid;
60   this->m_MeasurementVectorSize =
61     NumericTraits<CentroidType>::GetLength( centroid );
62 
63   this->m_Centroid = this->m_WeightedCentroid / static_cast<double>( size );
64 
65   this->m_Size = size;
66 }
67 
68 template<typename TSample>
69 void
70 KdTreeWeightedCentroidNonterminalNode<TSample>
GetParameters(unsigned int & partitionDimension,MeasurementType & partitionValue) const71 ::GetParameters( unsigned int &partitionDimension,
72   MeasurementType &partitionValue) const
73 {
74   partitionDimension = this->m_PartitionDimension;
75   partitionValue = this->m_PartitionValue;
76 }
77 
78 template<typename TSample>
79 KdTree<TSample>
KdTree()80 ::KdTree()
81 {
82   this->m_EmptyTerminalNode = new KdTreeTerminalNode<TSample>();
83 
84   this->m_DistanceMetric = DistanceMetricType::New();
85   this->m_Sample = nullptr;
86   this->m_Root = nullptr;
87   this->m_BucketSize = 16;
88   this->m_MeasurementVectorSize = 0;
89 }
90 
91 template<typename TSample>
92 KdTree<TSample>
~KdTree()93 ::~KdTree()
94 {
95   if( this->m_Root != nullptr )
96     {
97     this->DeleteNode( this->m_Root );
98     }
99   delete this->m_EmptyTerminalNode;
100 }
101 
102 template<typename TSample>
103 void
104 KdTree<TSample>
PrintSelf(std::ostream & os,Indent indent) const105 ::PrintSelf( std::ostream &os, Indent indent ) const
106 {
107   Superclass::PrintSelf( os, indent );
108 
109   os << indent << "Input Sample: ";
110   if( this->m_Sample != nullptr )
111     {
112     os << this->m_Sample << std::endl;
113     }
114   else
115     {
116     os << "not set." << std::endl;
117     }
118   os << indent << "Bucket Size: " << this->m_BucketSize << std::endl;
119   os << indent << "Root Node: ";
120   if( this->m_Root != nullptr )
121     {
122     os << this->m_Root << std::endl;
123     }
124   else
125     {
126     os << "not set." << std::endl;
127     }
128   os << indent << "MeasurementVectorSize: "
129      << this->m_MeasurementVectorSize << std::endl;
130 }
131 
132 template<typename TSample>
133 void
134 KdTree<TSample>
DeleteNode(KdTreeNodeType * node)135 ::DeleteNode( KdTreeNodeType *node )
136 {
137   if( node->IsTerminal() )
138     {
139     // terminal node
140     if( node == this->m_EmptyTerminalNode )
141       {
142       // empty node
143       return;
144       }
145     delete node;
146     return;
147     }
148 
149   // non-terminal node
150   if( node->Left() != nullptr )
151     {
152     this->DeleteNode( node->Left() );
153     }
154 
155   if( node->Right() != nullptr )
156     {
157     this->DeleteNode( node->Right() );
158     }
159 
160   delete node;
161 }
162 
163 template<typename TSample>
164 void
165 KdTree<TSample>
SetSample(const TSample * sample)166 ::SetSample( const TSample *sample )
167 {
168   this->m_Sample = sample;
169   this->m_MeasurementVectorSize = this->m_Sample->GetMeasurementVectorSize();
170   this->m_DistanceMetric->SetMeasurementVectorSize(
171     this->m_MeasurementVectorSize );
172   this->Modified();
173 }
174 
175 template<typename TSample>
176 void
177 KdTree<TSample>
SetBucketSize(unsigned int size)178 ::SetBucketSize(unsigned int size)
179 {
180   this->m_BucketSize = size;
181 }
182 
183 template<typename TSample>
184 void
185 KdTree<TSample>
Search(const MeasurementVectorType & query,unsigned int numberOfNeighborsRequested,InstanceIdentifierVectorType & result) const186 ::Search( const MeasurementVectorType & query,
187          unsigned int numberOfNeighborsRequested, InstanceIdentifierVectorType &result ) const
188 {
189   // This function has two different signatures. The other signature, that returns the distances vector too,
190   // is called here; however, its distances vector is discarded.
191   std::vector<double> not_used_distances;
192   this->Search(query, numberOfNeighborsRequested, result, not_used_distances);
193 }
194 
195 template<typename TSample>
196 void
197 KdTree<TSample>
Search(const MeasurementVectorType & query,unsigned int numberOfNeighborsRequested,InstanceIdentifierVectorType & result,std::vector<double> & distances) const198 ::Search( const MeasurementVectorType & query,
199   unsigned int numberOfNeighborsRequested, InstanceIdentifierVectorType &result,
200   std::vector<double> &distances ) const
201 {
202   if( numberOfNeighborsRequested > this->Size() )
203     {
204     itkExceptionMacro( "The numberOfNeighborsRequested for the nearest "
205       << "neighbor search should be less than or equal to the number of "
206       << "the measurement vectors." );
207     }
208 
209   /* 'distances' is the storage container used internally for the
210    * NearestNeighbors class.  The 'distances' vector is modified
211    * by the NearestNeighbors class.  By passing in
212    * the 'distances' vector here, we can avoid unnecessary memory
213    * duplications and copy operations.*/
214   NearestNeighbors nearestNeighbors(distances);
215   nearestNeighbors.resize( numberOfNeighborsRequested );
216 
217   MeasurementVectorType lowerBound;
218   NumericTraits<MeasurementVectorType>::SetLength( lowerBound,
219     this->m_MeasurementVectorSize );
220   MeasurementVectorType upperBound;
221   NumericTraits<MeasurementVectorType>::SetLength( upperBound,
222     this->m_MeasurementVectorSize );
223 
224   for(  unsigned int d = 0; d < this->m_MeasurementVectorSize; ++d )
225     {
226     lowerBound[d] = static_cast< MeasurementType >( -std::sqrt(
227       -static_cast< double >( NumericTraits< MeasurementType >::
228       NonpositiveMin() ) ) / 2.0 );
229     upperBound[d] = static_cast< MeasurementType >( std::sqrt(
230       static_cast<double >( NumericTraits< MeasurementType >::max() ) / 2.0 ) );
231     }
232   this->NearestNeighborSearchLoop( this->m_Root, query, lowerBound, upperBound,
233     nearestNeighbors );
234 
235   result = nearestNeighbors.GetNeighbors();
236 }
237 
238 template<typename TSample>
239 inline int
240 KdTree<TSample>
NearestNeighborSearchLoop(const KdTreeNodeType * node,const MeasurementVectorType & query,MeasurementVectorType & lowerBound,MeasurementVectorType & upperBound,NearestNeighbors & nearestNeighbors) const241 ::NearestNeighborSearchLoop( const KdTreeNodeType *node,
242   const MeasurementVectorType &query, MeasurementVectorType &lowerBound,
243   MeasurementVectorType &upperBound, NearestNeighbors &nearestNeighbors ) const
244 {
245   unsigned int       i;
246   InstanceIdentifier tempId;
247   double             tempDistance;
248 
249   if( node->IsTerminal() )
250     {
251     // terminal node
252     if( node == this->m_EmptyTerminalNode )
253       {
254       // empty node
255       return 0;
256       }
257 
258     for(  i = 0; i < node->Size(); ++i )
259       {
260       tempId = node->GetInstanceIdentifier(i);
261       tempDistance = this->m_DistanceMetric->Evaluate( query,
262         this->m_Sample->GetMeasurementVector( tempId ) );
263       if( tempDistance < nearestNeighbors.GetLargestDistance() )
264         {
265         nearestNeighbors.ReplaceFarthestNeighbor( tempId, tempDistance );
266         }
267       }
268 
269     if( this->BallWithinBounds( query, lowerBound, upperBound,
270       nearestNeighbors.GetLargestDistance() ) )
271       {
272       return 1;
273       }
274 
275     return 0;
276     }
277 
278   unsigned int    partitionDimension;
279   MeasurementType partitionValue;
280   MeasurementType tempValue;
281   node->GetParameters( partitionDimension, partitionValue );
282 
283   //
284   // Check the point associated with the nonterminal node
285   // and potentially add it to the list of nearest neighbors
286   //
287   tempId = node->GetInstanceIdentifier(0);
288   tempDistance = this->m_DistanceMetric->Evaluate( query,
289     this->m_Sample->GetMeasurementVector( tempId ) );
290   if( tempDistance < nearestNeighbors.GetLargestDistance() )
291     {
292     nearestNeighbors.ReplaceFarthestNeighbor( tempId, tempDistance );
293     }
294 
295   //
296   // Now check both child sub-trees
297   //
298   if( query[partitionDimension] <= partitionValue )
299     {
300     // search the closer child node
301     tempValue = upperBound[partitionDimension];
302     upperBound[partitionDimension] = partitionValue;
303     if( this->NearestNeighborSearchLoop( node->Left(), query, lowerBound,
304       upperBound, nearestNeighbors ) )
305       {
306       return 1;
307       }
308     upperBound[partitionDimension] = tempValue;
309 
310     // search the other node, if necessary
311     tempValue = lowerBound[partitionDimension];
312     lowerBound[partitionDimension] = partitionValue;
313     if( this->BoundsOverlapBall( query, lowerBound, upperBound,
314       nearestNeighbors.GetLargestDistance() ) )
315       {
316       this->NearestNeighborSearchLoop( node->Right(), query, lowerBound,
317         upperBound, nearestNeighbors );
318       }
319     lowerBound[partitionDimension] = tempValue;
320     }
321   else
322     {
323     // search the closer child node
324     tempValue = lowerBound[partitionDimension];
325     lowerBound[partitionDimension] = partitionValue;
326     if( this->NearestNeighborSearchLoop( node->Right(), query, lowerBound,
327       upperBound, nearestNeighbors ) )
328       {
329       return 1;
330       }
331     lowerBound[partitionDimension] = tempValue;
332 
333     // search the other node, if necessary
334     tempValue = upperBound[partitionDimension];
335     upperBound[partitionDimension] = partitionValue;
336     if( this->BoundsOverlapBall( query, lowerBound, upperBound,
337       nearestNeighbors.GetLargestDistance() ) )
338       {
339       this->NearestNeighborSearchLoop( node->Left(), query, lowerBound,
340         upperBound, nearestNeighbors );
341       }
342     upperBound[partitionDimension] = tempValue;
343     }
344 
345   // stop or continue search
346   if( this->BallWithinBounds( query, lowerBound, upperBound,
347     nearestNeighbors.GetLargestDistance() ) )
348     {
349     return 1;
350     }
351 
352   return 0;
353 }
354 
355 template<typename TSample>
356 void
357 KdTree<TSample>
Search(const MeasurementVectorType & query,double radius,InstanceIdentifierVectorType & result) const358 ::Search( const MeasurementVectorType & query, double radius,
359   InstanceIdentifierVectorType & result ) const
360 {
361   MeasurementVectorType lowerBound;
362   MeasurementVectorType upperBound;
363 
364   NumericTraits<MeasurementVectorType>::SetLength( lowerBound,
365     this->m_MeasurementVectorSize );
366   NumericTraits<MeasurementVectorType>::SetLength( upperBound,
367     this->m_MeasurementVectorSize );
368 
369   for(  unsigned int d = 0; d < this->m_MeasurementVectorSize; ++d )
370     {
371     lowerBound[d] = static_cast<MeasurementType>( -std::sqrt(
372       -static_cast<double>( NumericTraits<MeasurementType>::
373       NonpositiveMin() ) ) / 2.0 );
374     upperBound[d] = static_cast< MeasurementType >( std::sqrt(
375       static_cast<double>( NumericTraits< MeasurementType >::max() ) / 2.0 ) );
376     }
377 
378   result.clear();
379   this->SearchLoop( this->m_Root, query, radius, lowerBound, upperBound, result );
380 }
381 
382 template<typename TSample>
383 inline int
384 KdTree<TSample>
SearchLoop(const KdTreeNodeType * node,const MeasurementVectorType & query,double radius,MeasurementVectorType & lowerBound,MeasurementVectorType & upperBound,InstanceIdentifierVectorType & neighbors) const385 ::SearchLoop( const KdTreeNodeType *node, const MeasurementVectorType &query,
386   double radius, MeasurementVectorType &lowerBound, MeasurementVectorType
387   &upperBound, InstanceIdentifierVectorType &neighbors ) const
388 {
389   InstanceIdentifier tempId;
390   double             tempDistance;
391 
392   if( node->IsTerminal() )
393     {
394     // terminal node
395     if( node == this->m_EmptyTerminalNode )
396       {
397       // empty node
398       return 0;
399       }
400 
401     for( unsigned int i = 0; i < node->Size(); ++i )
402       {
403       tempId = node->GetInstanceIdentifier( i );
404       tempDistance = this->m_DistanceMetric->Evaluate( query,
405         this->m_Sample->GetMeasurementVector( tempId ) );
406       if( tempDistance <= radius )
407         {
408         neighbors.push_back( tempId );
409         }
410       }
411 
412     if( this->BallWithinBounds( query, lowerBound, upperBound, radius ) )
413       {
414       return 1;
415       }
416 
417     return 0;
418     }
419   if( node->IsTerminal() == false )
420     {
421     tempId = node->GetInstanceIdentifier( 0 );
422     tempDistance = this->m_DistanceMetric->Evaluate( query,
423       this->m_Sample->GetMeasurementVector( tempId ) );
424     if( tempDistance <= radius )
425       {
426       neighbors.push_back( tempId );
427       }
428     }
429 
430   unsigned int    partitionDimension;
431   MeasurementType partitionValue;
432   MeasurementType tempValue;
433   node->GetParameters( partitionDimension, partitionValue );
434 
435   if( query[partitionDimension] <= partitionValue )
436     {
437     // search the closer child node
438     tempValue = upperBound[partitionDimension];
439     upperBound[partitionDimension] = partitionValue;
440     if( this->SearchLoop( node->Left(), query, radius, lowerBound, upperBound,
441       neighbors ) )
442       {
443       return 1;
444       }
445     upperBound[partitionDimension] = tempValue;
446 
447     // search the other node, if necessary
448     tempValue = lowerBound[partitionDimension];
449     lowerBound[partitionDimension] = partitionValue;
450     if( this->BoundsOverlapBall( query, lowerBound, upperBound, radius ) )
451       {
452       this->SearchLoop( node->Right(), query, radius, lowerBound, upperBound,
453         neighbors );
454       }
455     lowerBound[partitionDimension] = tempValue;
456     }
457   else
458     {
459     // search the closer child node
460     tempValue = lowerBound[partitionDimension];
461     lowerBound[partitionDimension] = partitionValue;
462     if( this->SearchLoop( node->Right(), query, radius, lowerBound, upperBound,
463       neighbors ) )
464       {
465       return 1;
466       }
467     lowerBound[partitionDimension] = tempValue;
468 
469     // search the other node, if necessary
470     tempValue = upperBound[partitionDimension];
471     upperBound[partitionDimension] = partitionValue;
472     if( this->BoundsOverlapBall( query, lowerBound, upperBound, radius ) )
473       {
474       this->SearchLoop( node->Left(), query, radius, lowerBound, upperBound,
475         neighbors );
476       }
477     upperBound[partitionDimension] = tempValue;
478     }
479 
480   // stop or continue search
481   if( this->BallWithinBounds( query, lowerBound, upperBound, radius ) )
482     {
483     return 1;
484     }
485 
486   return 0;
487 }
488 
489 template<typename TSample>
490 inline bool
491 KdTree<TSample>
BallWithinBounds(const MeasurementVectorType & query,MeasurementVectorType & lowerBound,MeasurementVectorType & upperBound,double radius) const492 ::BallWithinBounds( const MeasurementVectorType & query, MeasurementVectorType
493   &lowerBound, MeasurementVectorType & upperBound, double radius ) const
494 {
495   for( unsigned int d = 0; d < this->m_MeasurementVectorSize; ++d )
496     {
497     if( ( this->m_DistanceMetric->Evaluate( query[d], lowerBound[d] ) <=
498       radius ) || ( this->m_DistanceMetric->Evaluate( query[d],
499       upperBound[d] ) <= radius ) )
500       {
501       return false;
502       }
503     }
504   return true;
505 }
506 
507 template<typename TSample>
508 inline bool
509 KdTree<TSample>
BoundsOverlapBall(const MeasurementVectorType & query,MeasurementVectorType & lowerBound,MeasurementVectorType & upperBound,double radius) const510 ::BoundsOverlapBall( const MeasurementVectorType &query, MeasurementVectorType
511   &lowerBound, MeasurementVectorType &upperBound, double radius ) const
512 {
513   double squaredSearchRadius = itk::Math::sqr( radius );
514 
515   double sum = 0.0;
516   for( unsigned int d = 0; d < this->m_MeasurementVectorSize; ++d )
517     {
518     if( query[d] <= lowerBound[d] )
519       {
520       sum += itk::Math::sqr( this->m_DistanceMetric->Evaluate( query[d],
521         lowerBound[d] ) );
522       if( sum < squaredSearchRadius )
523         {
524         return true;
525         }
526       }
527     else if( query[d] >= upperBound[d] )
528       {
529       sum += itk::Math::sqr( this->m_DistanceMetric->Evaluate( query[d],
530         upperBound[d] ) );
531       if( sum < squaredSearchRadius )
532         {
533         return true;
534         }
535       }
536     }
537   return false;
538 }
539 
540 template<typename TSample>
541 void
542 KdTree<TSample>
PrintTree(std::ostream & os) const543 ::PrintTree( std::ostream & os ) const
544 {
545   constexpr unsigned int topLevel = 0;
546   constexpr unsigned int activeDimension = 0;
547 
548   this->PrintTree( this->m_Root, topLevel, activeDimension, os );
549 }
550 
551 template<typename TSample>
552 void
553 KdTree<TSample>
PrintTree(KdTreeNodeType * node,unsigned int level,unsigned int activeDimension,std::ostream & os) const554 ::PrintTree( KdTreeNodeType *node, unsigned int level,
555   unsigned int activeDimension, std::ostream &os ) const
556 {
557   level++;
558   if( node->IsTerminal() )
559     {
560     // terminal node
561     if( node == this->m_EmptyTerminalNode )
562       {
563       // empty node
564       os << "Empty node: level = " << level << std::endl;
565       return;
566       }
567     os << "Terminal: level = " << level << " dim = " << activeDimension
568        << std::endl;
569     os << "          ";
570     for(  unsigned int i = 0; i < node->Size(); ++i )
571       {
572       os << "[" << node->GetInstanceIdentifier( i ) << "] "
573          << this->m_Sample->GetMeasurementVector(
574          node->GetInstanceIdentifier( i ) ) << ", ";
575       }
576     os << std::endl;
577     return;
578     }
579 
580   unsigned int    partitionDimension;
581   MeasurementType partitionValue;
582 
583   node->GetParameters( partitionDimension, partitionValue );
584   typename KdTreeNodeType::CentroidType centroid;
585   node->GetWeightedCentroid(centroid);
586   os << "Nonterminal: level = " << level << std::endl;
587   os << "             dim = " << partitionDimension << std::endl;
588   os << "             value = " << partitionValue << std::endl;
589   os << "             weighted centroid = " << centroid;
590   os << "             size = " << node->Size() << std::endl;
591   os << "             identifier = " << node->GetInstanceIdentifier( 0 );
592   os << this->m_Sample->GetMeasurementVector( node->GetInstanceIdentifier( 0 ) )
593      << std::endl;
594 
595   this->PrintTree( node->Left(),  level, partitionDimension, os );
596   this->PrintTree( node->Right(), level, partitionDimension, os );
597 }
598 
599 template<typename TSample>
600 void
601 KdTree<TSample>
PlotTree(std::ostream & os) const602 ::PlotTree( std::ostream & os ) const
603 {
604   //
605   // Graph header
606   //
607   os << "digraph G {" << std::endl;
608 
609   //
610   // Recursively visit the tree and add entries for the nodes
611   //
612   this->PlotTree( this->m_Root, os );
613 
614   //
615   // Graph footer
616   //
617   os << "}" << std::endl;
618 }
619 
620 template<typename TSample>
621 void
622 KdTree<TSample>
PlotTree(KdTreeNodeType * node,std::ostream & os) const623 ::PlotTree( KdTreeNodeType *node, std::ostream & os ) const
624 {
625   unsigned int    partitionDimension;
626   MeasurementType partitionValue;
627 
628   node->GetParameters( partitionDimension, partitionValue );
629 
630   KdTreeNodeType *left  = node->Left();
631   KdTreeNodeType *right = node->Right();
632 
633   char partitionDimensionCharSymbol = ( 'X' + partitionDimension );
634 
635   if( node->IsTerminal() )
636     {
637     // terminal node
638     if( node != this->m_EmptyTerminalNode )
639       {
640       os << "\"" << node << "\" [label=\"";
641       for(  unsigned int i = 0; i < node->Size(); ++i )
642         {
643         os << this->GetMeasurementVector( node->GetInstanceIdentifier( i ) );
644         os << " ";
645         }
646       os << "\" ];" << std::endl;
647       }
648     }
649   else
650     {
651     os << "\"" << node << "\" [label=\"";
652     os << this->GetMeasurementVector( node->GetInstanceIdentifier( 0 ) );
653     os << " " << partitionDimensionCharSymbol << "=" << partitionValue;
654     os << "\" ];" << std::endl;
655     }
656 
657   if( left &&  ( left != this->m_EmptyTerminalNode ) )
658     {
659     os << "\"" << node << "\" -> \"" << left << "\";" << std::endl;
660     this->PlotTree( left, os );
661     }
662 
663   if( right && ( right != this->m_EmptyTerminalNode ) )
664     {
665     os << "\"" << node << "\" -> \"" << right << "\";" << std::endl;
666     this->PlotTree( right, os );
667     }
668 }
669 } // end of namespace Statistics
670 } // end of namespace itk
671 
672 #endif
673