1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18 #ifndef itkKdTree_hxx
19 #define itkKdTree_hxx
20
21 #include "itkKdTree.h"
22
23 namespace itk
24 {
25 namespace Statistics
26 {
27 template<typename TSample>
28 KdTreeNonterminalNode<TSample>
KdTreeNonterminalNode(unsigned int partitionDimension,MeasurementType partitionValue,Superclass * left,Superclass * right)29 ::KdTreeNonterminalNode( unsigned int partitionDimension,
30 MeasurementType partitionValue, Superclass *left, Superclass *right ) :
31 m_PartitionDimension(partitionDimension),
32 m_PartitionValue(partitionValue),
33 m_InstanceIdentifier(0),
34 m_Left(left),
35 m_Right(right)
36 {
37 }
38
39 template<typename TSample>
40 void
41 KdTreeNonterminalNode<TSample>
GetParameters(unsigned int & partitionDimension,MeasurementType & partitionValue) const42 ::GetParameters( unsigned int &partitionDimension,
43 MeasurementType &partitionValue ) const
44 {
45 partitionDimension = this->m_PartitionDimension;
46 partitionValue = this->m_PartitionValue;
47 }
48
49 template<typename TSample>
50 KdTreeWeightedCentroidNonterminalNode<TSample>
KdTreeWeightedCentroidNonterminalNode(unsigned int partitionDimension,MeasurementType partitionValue,Superclass * left,Superclass * right,CentroidType & centroid,unsigned int size)51 ::KdTreeWeightedCentroidNonterminalNode( unsigned int partitionDimension,
52 MeasurementType partitionValue, Superclass *left, Superclass *right,
53 CentroidType & centroid, unsigned int size )
54 {
55 this->m_PartitionDimension = partitionDimension;
56 this->m_PartitionValue = partitionValue;
57 this->m_Left = left;
58 this->m_Right = right;
59 this->m_WeightedCentroid = centroid;
60 this->m_MeasurementVectorSize =
61 NumericTraits<CentroidType>::GetLength( centroid );
62
63 this->m_Centroid = this->m_WeightedCentroid / static_cast<double>( size );
64
65 this->m_Size = size;
66 }
67
68 template<typename TSample>
69 void
70 KdTreeWeightedCentroidNonterminalNode<TSample>
GetParameters(unsigned int & partitionDimension,MeasurementType & partitionValue) const71 ::GetParameters( unsigned int &partitionDimension,
72 MeasurementType &partitionValue) const
73 {
74 partitionDimension = this->m_PartitionDimension;
75 partitionValue = this->m_PartitionValue;
76 }
77
78 template<typename TSample>
79 KdTree<TSample>
KdTree()80 ::KdTree()
81 {
82 this->m_EmptyTerminalNode = new KdTreeTerminalNode<TSample>();
83
84 this->m_DistanceMetric = DistanceMetricType::New();
85 this->m_Sample = nullptr;
86 this->m_Root = nullptr;
87 this->m_BucketSize = 16;
88 this->m_MeasurementVectorSize = 0;
89 }
90
91 template<typename TSample>
92 KdTree<TSample>
~KdTree()93 ::~KdTree()
94 {
95 if( this->m_Root != nullptr )
96 {
97 this->DeleteNode( this->m_Root );
98 }
99 delete this->m_EmptyTerminalNode;
100 }
101
102 template<typename TSample>
103 void
104 KdTree<TSample>
PrintSelf(std::ostream & os,Indent indent) const105 ::PrintSelf( std::ostream &os, Indent indent ) const
106 {
107 Superclass::PrintSelf( os, indent );
108
109 os << indent << "Input Sample: ";
110 if( this->m_Sample != nullptr )
111 {
112 os << this->m_Sample << std::endl;
113 }
114 else
115 {
116 os << "not set." << std::endl;
117 }
118 os << indent << "Bucket Size: " << this->m_BucketSize << std::endl;
119 os << indent << "Root Node: ";
120 if( this->m_Root != nullptr )
121 {
122 os << this->m_Root << std::endl;
123 }
124 else
125 {
126 os << "not set." << std::endl;
127 }
128 os << indent << "MeasurementVectorSize: "
129 << this->m_MeasurementVectorSize << std::endl;
130 }
131
132 template<typename TSample>
133 void
134 KdTree<TSample>
DeleteNode(KdTreeNodeType * node)135 ::DeleteNode( KdTreeNodeType *node )
136 {
137 if( node->IsTerminal() )
138 {
139 // terminal node
140 if( node == this->m_EmptyTerminalNode )
141 {
142 // empty node
143 return;
144 }
145 delete node;
146 return;
147 }
148
149 // non-terminal node
150 if( node->Left() != nullptr )
151 {
152 this->DeleteNode( node->Left() );
153 }
154
155 if( node->Right() != nullptr )
156 {
157 this->DeleteNode( node->Right() );
158 }
159
160 delete node;
161 }
162
163 template<typename TSample>
164 void
165 KdTree<TSample>
SetSample(const TSample * sample)166 ::SetSample( const TSample *sample )
167 {
168 this->m_Sample = sample;
169 this->m_MeasurementVectorSize = this->m_Sample->GetMeasurementVectorSize();
170 this->m_DistanceMetric->SetMeasurementVectorSize(
171 this->m_MeasurementVectorSize );
172 this->Modified();
173 }
174
175 template<typename TSample>
176 void
177 KdTree<TSample>
SetBucketSize(unsigned int size)178 ::SetBucketSize(unsigned int size)
179 {
180 this->m_BucketSize = size;
181 }
182
183 template<typename TSample>
184 void
185 KdTree<TSample>
Search(const MeasurementVectorType & query,unsigned int numberOfNeighborsRequested,InstanceIdentifierVectorType & result) const186 ::Search( const MeasurementVectorType & query,
187 unsigned int numberOfNeighborsRequested, InstanceIdentifierVectorType &result ) const
188 {
189 // This function has two different signatures. The other signature, that returns the distances vector too,
190 // is called here; however, its distances vector is discarded.
191 std::vector<double> not_used_distances;
192 this->Search(query, numberOfNeighborsRequested, result, not_used_distances);
193 }
194
195 template<typename TSample>
196 void
197 KdTree<TSample>
Search(const MeasurementVectorType & query,unsigned int numberOfNeighborsRequested,InstanceIdentifierVectorType & result,std::vector<double> & distances) const198 ::Search( const MeasurementVectorType & query,
199 unsigned int numberOfNeighborsRequested, InstanceIdentifierVectorType &result,
200 std::vector<double> &distances ) const
201 {
202 if( numberOfNeighborsRequested > this->Size() )
203 {
204 itkExceptionMacro( "The numberOfNeighborsRequested for the nearest "
205 << "neighbor search should be less than or equal to the number of "
206 << "the measurement vectors." );
207 }
208
209 /* 'distances' is the storage container used internally for the
210 * NearestNeighbors class. The 'distances' vector is modified
211 * by the NearestNeighbors class. By passing in
212 * the 'distances' vector here, we can avoid unnecessary memory
213 * duplications and copy operations.*/
214 NearestNeighbors nearestNeighbors(distances);
215 nearestNeighbors.resize( numberOfNeighborsRequested );
216
217 MeasurementVectorType lowerBound;
218 NumericTraits<MeasurementVectorType>::SetLength( lowerBound,
219 this->m_MeasurementVectorSize );
220 MeasurementVectorType upperBound;
221 NumericTraits<MeasurementVectorType>::SetLength( upperBound,
222 this->m_MeasurementVectorSize );
223
224 for( unsigned int d = 0; d < this->m_MeasurementVectorSize; ++d )
225 {
226 lowerBound[d] = static_cast< MeasurementType >( -std::sqrt(
227 -static_cast< double >( NumericTraits< MeasurementType >::
228 NonpositiveMin() ) ) / 2.0 );
229 upperBound[d] = static_cast< MeasurementType >( std::sqrt(
230 static_cast<double >( NumericTraits< MeasurementType >::max() ) / 2.0 ) );
231 }
232 this->NearestNeighborSearchLoop( this->m_Root, query, lowerBound, upperBound,
233 nearestNeighbors );
234
235 result = nearestNeighbors.GetNeighbors();
236 }
237
238 template<typename TSample>
239 inline int
240 KdTree<TSample>
NearestNeighborSearchLoop(const KdTreeNodeType * node,const MeasurementVectorType & query,MeasurementVectorType & lowerBound,MeasurementVectorType & upperBound,NearestNeighbors & nearestNeighbors) const241 ::NearestNeighborSearchLoop( const KdTreeNodeType *node,
242 const MeasurementVectorType &query, MeasurementVectorType &lowerBound,
243 MeasurementVectorType &upperBound, NearestNeighbors &nearestNeighbors ) const
244 {
245 unsigned int i;
246 InstanceIdentifier tempId;
247 double tempDistance;
248
249 if( node->IsTerminal() )
250 {
251 // terminal node
252 if( node == this->m_EmptyTerminalNode )
253 {
254 // empty node
255 return 0;
256 }
257
258 for( i = 0; i < node->Size(); ++i )
259 {
260 tempId = node->GetInstanceIdentifier(i);
261 tempDistance = this->m_DistanceMetric->Evaluate( query,
262 this->m_Sample->GetMeasurementVector( tempId ) );
263 if( tempDistance < nearestNeighbors.GetLargestDistance() )
264 {
265 nearestNeighbors.ReplaceFarthestNeighbor( tempId, tempDistance );
266 }
267 }
268
269 if( this->BallWithinBounds( query, lowerBound, upperBound,
270 nearestNeighbors.GetLargestDistance() ) )
271 {
272 return 1;
273 }
274
275 return 0;
276 }
277
278 unsigned int partitionDimension;
279 MeasurementType partitionValue;
280 MeasurementType tempValue;
281 node->GetParameters( partitionDimension, partitionValue );
282
283 //
284 // Check the point associated with the nonterminal node
285 // and potentially add it to the list of nearest neighbors
286 //
287 tempId = node->GetInstanceIdentifier(0);
288 tempDistance = this->m_DistanceMetric->Evaluate( query,
289 this->m_Sample->GetMeasurementVector( tempId ) );
290 if( tempDistance < nearestNeighbors.GetLargestDistance() )
291 {
292 nearestNeighbors.ReplaceFarthestNeighbor( tempId, tempDistance );
293 }
294
295 //
296 // Now check both child sub-trees
297 //
298 if( query[partitionDimension] <= partitionValue )
299 {
300 // search the closer child node
301 tempValue = upperBound[partitionDimension];
302 upperBound[partitionDimension] = partitionValue;
303 if( this->NearestNeighborSearchLoop( node->Left(), query, lowerBound,
304 upperBound, nearestNeighbors ) )
305 {
306 return 1;
307 }
308 upperBound[partitionDimension] = tempValue;
309
310 // search the other node, if necessary
311 tempValue = lowerBound[partitionDimension];
312 lowerBound[partitionDimension] = partitionValue;
313 if( this->BoundsOverlapBall( query, lowerBound, upperBound,
314 nearestNeighbors.GetLargestDistance() ) )
315 {
316 this->NearestNeighborSearchLoop( node->Right(), query, lowerBound,
317 upperBound, nearestNeighbors );
318 }
319 lowerBound[partitionDimension] = tempValue;
320 }
321 else
322 {
323 // search the closer child node
324 tempValue = lowerBound[partitionDimension];
325 lowerBound[partitionDimension] = partitionValue;
326 if( this->NearestNeighborSearchLoop( node->Right(), query, lowerBound,
327 upperBound, nearestNeighbors ) )
328 {
329 return 1;
330 }
331 lowerBound[partitionDimension] = tempValue;
332
333 // search the other node, if necessary
334 tempValue = upperBound[partitionDimension];
335 upperBound[partitionDimension] = partitionValue;
336 if( this->BoundsOverlapBall( query, lowerBound, upperBound,
337 nearestNeighbors.GetLargestDistance() ) )
338 {
339 this->NearestNeighborSearchLoop( node->Left(), query, lowerBound,
340 upperBound, nearestNeighbors );
341 }
342 upperBound[partitionDimension] = tempValue;
343 }
344
345 // stop or continue search
346 if( this->BallWithinBounds( query, lowerBound, upperBound,
347 nearestNeighbors.GetLargestDistance() ) )
348 {
349 return 1;
350 }
351
352 return 0;
353 }
354
355 template<typename TSample>
356 void
357 KdTree<TSample>
Search(const MeasurementVectorType & query,double radius,InstanceIdentifierVectorType & result) const358 ::Search( const MeasurementVectorType & query, double radius,
359 InstanceIdentifierVectorType & result ) const
360 {
361 MeasurementVectorType lowerBound;
362 MeasurementVectorType upperBound;
363
364 NumericTraits<MeasurementVectorType>::SetLength( lowerBound,
365 this->m_MeasurementVectorSize );
366 NumericTraits<MeasurementVectorType>::SetLength( upperBound,
367 this->m_MeasurementVectorSize );
368
369 for( unsigned int d = 0; d < this->m_MeasurementVectorSize; ++d )
370 {
371 lowerBound[d] = static_cast<MeasurementType>( -std::sqrt(
372 -static_cast<double>( NumericTraits<MeasurementType>::
373 NonpositiveMin() ) ) / 2.0 );
374 upperBound[d] = static_cast< MeasurementType >( std::sqrt(
375 static_cast<double>( NumericTraits< MeasurementType >::max() ) / 2.0 ) );
376 }
377
378 result.clear();
379 this->SearchLoop( this->m_Root, query, radius, lowerBound, upperBound, result );
380 }
381
382 template<typename TSample>
383 inline int
384 KdTree<TSample>
SearchLoop(const KdTreeNodeType * node,const MeasurementVectorType & query,double radius,MeasurementVectorType & lowerBound,MeasurementVectorType & upperBound,InstanceIdentifierVectorType & neighbors) const385 ::SearchLoop( const KdTreeNodeType *node, const MeasurementVectorType &query,
386 double radius, MeasurementVectorType &lowerBound, MeasurementVectorType
387 &upperBound, InstanceIdentifierVectorType &neighbors ) const
388 {
389 InstanceIdentifier tempId;
390 double tempDistance;
391
392 if( node->IsTerminal() )
393 {
394 // terminal node
395 if( node == this->m_EmptyTerminalNode )
396 {
397 // empty node
398 return 0;
399 }
400
401 for( unsigned int i = 0; i < node->Size(); ++i )
402 {
403 tempId = node->GetInstanceIdentifier( i );
404 tempDistance = this->m_DistanceMetric->Evaluate( query,
405 this->m_Sample->GetMeasurementVector( tempId ) );
406 if( tempDistance <= radius )
407 {
408 neighbors.push_back( tempId );
409 }
410 }
411
412 if( this->BallWithinBounds( query, lowerBound, upperBound, radius ) )
413 {
414 return 1;
415 }
416
417 return 0;
418 }
419 if( node->IsTerminal() == false )
420 {
421 tempId = node->GetInstanceIdentifier( 0 );
422 tempDistance = this->m_DistanceMetric->Evaluate( query,
423 this->m_Sample->GetMeasurementVector( tempId ) );
424 if( tempDistance <= radius )
425 {
426 neighbors.push_back( tempId );
427 }
428 }
429
430 unsigned int partitionDimension;
431 MeasurementType partitionValue;
432 MeasurementType tempValue;
433 node->GetParameters( partitionDimension, partitionValue );
434
435 if( query[partitionDimension] <= partitionValue )
436 {
437 // search the closer child node
438 tempValue = upperBound[partitionDimension];
439 upperBound[partitionDimension] = partitionValue;
440 if( this->SearchLoop( node->Left(), query, radius, lowerBound, upperBound,
441 neighbors ) )
442 {
443 return 1;
444 }
445 upperBound[partitionDimension] = tempValue;
446
447 // search the other node, if necessary
448 tempValue = lowerBound[partitionDimension];
449 lowerBound[partitionDimension] = partitionValue;
450 if( this->BoundsOverlapBall( query, lowerBound, upperBound, radius ) )
451 {
452 this->SearchLoop( node->Right(), query, radius, lowerBound, upperBound,
453 neighbors );
454 }
455 lowerBound[partitionDimension] = tempValue;
456 }
457 else
458 {
459 // search the closer child node
460 tempValue = lowerBound[partitionDimension];
461 lowerBound[partitionDimension] = partitionValue;
462 if( this->SearchLoop( node->Right(), query, radius, lowerBound, upperBound,
463 neighbors ) )
464 {
465 return 1;
466 }
467 lowerBound[partitionDimension] = tempValue;
468
469 // search the other node, if necessary
470 tempValue = upperBound[partitionDimension];
471 upperBound[partitionDimension] = partitionValue;
472 if( this->BoundsOverlapBall( query, lowerBound, upperBound, radius ) )
473 {
474 this->SearchLoop( node->Left(), query, radius, lowerBound, upperBound,
475 neighbors );
476 }
477 upperBound[partitionDimension] = tempValue;
478 }
479
480 // stop or continue search
481 if( this->BallWithinBounds( query, lowerBound, upperBound, radius ) )
482 {
483 return 1;
484 }
485
486 return 0;
487 }
488
489 template<typename TSample>
490 inline bool
491 KdTree<TSample>
BallWithinBounds(const MeasurementVectorType & query,MeasurementVectorType & lowerBound,MeasurementVectorType & upperBound,double radius) const492 ::BallWithinBounds( const MeasurementVectorType & query, MeasurementVectorType
493 &lowerBound, MeasurementVectorType & upperBound, double radius ) const
494 {
495 for( unsigned int d = 0; d < this->m_MeasurementVectorSize; ++d )
496 {
497 if( ( this->m_DistanceMetric->Evaluate( query[d], lowerBound[d] ) <=
498 radius ) || ( this->m_DistanceMetric->Evaluate( query[d],
499 upperBound[d] ) <= radius ) )
500 {
501 return false;
502 }
503 }
504 return true;
505 }
506
507 template<typename TSample>
508 inline bool
509 KdTree<TSample>
BoundsOverlapBall(const MeasurementVectorType & query,MeasurementVectorType & lowerBound,MeasurementVectorType & upperBound,double radius) const510 ::BoundsOverlapBall( const MeasurementVectorType &query, MeasurementVectorType
511 &lowerBound, MeasurementVectorType &upperBound, double radius ) const
512 {
513 double squaredSearchRadius = itk::Math::sqr( radius );
514
515 double sum = 0.0;
516 for( unsigned int d = 0; d < this->m_MeasurementVectorSize; ++d )
517 {
518 if( query[d] <= lowerBound[d] )
519 {
520 sum += itk::Math::sqr( this->m_DistanceMetric->Evaluate( query[d],
521 lowerBound[d] ) );
522 if( sum < squaredSearchRadius )
523 {
524 return true;
525 }
526 }
527 else if( query[d] >= upperBound[d] )
528 {
529 sum += itk::Math::sqr( this->m_DistanceMetric->Evaluate( query[d],
530 upperBound[d] ) );
531 if( sum < squaredSearchRadius )
532 {
533 return true;
534 }
535 }
536 }
537 return false;
538 }
539
540 template<typename TSample>
541 void
542 KdTree<TSample>
PrintTree(std::ostream & os) const543 ::PrintTree( std::ostream & os ) const
544 {
545 constexpr unsigned int topLevel = 0;
546 constexpr unsigned int activeDimension = 0;
547
548 this->PrintTree( this->m_Root, topLevel, activeDimension, os );
549 }
550
551 template<typename TSample>
552 void
553 KdTree<TSample>
PrintTree(KdTreeNodeType * node,unsigned int level,unsigned int activeDimension,std::ostream & os) const554 ::PrintTree( KdTreeNodeType *node, unsigned int level,
555 unsigned int activeDimension, std::ostream &os ) const
556 {
557 level++;
558 if( node->IsTerminal() )
559 {
560 // terminal node
561 if( node == this->m_EmptyTerminalNode )
562 {
563 // empty node
564 os << "Empty node: level = " << level << std::endl;
565 return;
566 }
567 os << "Terminal: level = " << level << " dim = " << activeDimension
568 << std::endl;
569 os << " ";
570 for( unsigned int i = 0; i < node->Size(); ++i )
571 {
572 os << "[" << node->GetInstanceIdentifier( i ) << "] "
573 << this->m_Sample->GetMeasurementVector(
574 node->GetInstanceIdentifier( i ) ) << ", ";
575 }
576 os << std::endl;
577 return;
578 }
579
580 unsigned int partitionDimension;
581 MeasurementType partitionValue;
582
583 node->GetParameters( partitionDimension, partitionValue );
584 typename KdTreeNodeType::CentroidType centroid;
585 node->GetWeightedCentroid(centroid);
586 os << "Nonterminal: level = " << level << std::endl;
587 os << " dim = " << partitionDimension << std::endl;
588 os << " value = " << partitionValue << std::endl;
589 os << " weighted centroid = " << centroid;
590 os << " size = " << node->Size() << std::endl;
591 os << " identifier = " << node->GetInstanceIdentifier( 0 );
592 os << this->m_Sample->GetMeasurementVector( node->GetInstanceIdentifier( 0 ) )
593 << std::endl;
594
595 this->PrintTree( node->Left(), level, partitionDimension, os );
596 this->PrintTree( node->Right(), level, partitionDimension, os );
597 }
598
599 template<typename TSample>
600 void
601 KdTree<TSample>
PlotTree(std::ostream & os) const602 ::PlotTree( std::ostream & os ) const
603 {
604 //
605 // Graph header
606 //
607 os << "digraph G {" << std::endl;
608
609 //
610 // Recursively visit the tree and add entries for the nodes
611 //
612 this->PlotTree( this->m_Root, os );
613
614 //
615 // Graph footer
616 //
617 os << "}" << std::endl;
618 }
619
620 template<typename TSample>
621 void
622 KdTree<TSample>
PlotTree(KdTreeNodeType * node,std::ostream & os) const623 ::PlotTree( KdTreeNodeType *node, std::ostream & os ) const
624 {
625 unsigned int partitionDimension;
626 MeasurementType partitionValue;
627
628 node->GetParameters( partitionDimension, partitionValue );
629
630 KdTreeNodeType *left = node->Left();
631 KdTreeNodeType *right = node->Right();
632
633 char partitionDimensionCharSymbol = ( 'X' + partitionDimension );
634
635 if( node->IsTerminal() )
636 {
637 // terminal node
638 if( node != this->m_EmptyTerminalNode )
639 {
640 os << "\"" << node << "\" [label=\"";
641 for( unsigned int i = 0; i < node->Size(); ++i )
642 {
643 os << this->GetMeasurementVector( node->GetInstanceIdentifier( i ) );
644 os << " ";
645 }
646 os << "\" ];" << std::endl;
647 }
648 }
649 else
650 {
651 os << "\"" << node << "\" [label=\"";
652 os << this->GetMeasurementVector( node->GetInstanceIdentifier( 0 ) );
653 os << " " << partitionDimensionCharSymbol << "=" << partitionValue;
654 os << "\" ];" << std::endl;
655 }
656
657 if( left && ( left != this->m_EmptyTerminalNode ) )
658 {
659 os << "\"" << node << "\" -> \"" << left << "\";" << std::endl;
660 this->PlotTree( left, os );
661 }
662
663 if( right && ( right != this->m_EmptyTerminalNode ) )
664 {
665 os << "\"" << node << "\" -> \"" << right << "\";" << std::endl;
666 this->PlotTree( right, os );
667 }
668 }
669 } // end of namespace Statistics
670 } // end of namespace itk
671
672 #endif
673