1 /*
2  * Medical Image Registration ToolKit (MIRTK)
3  *
4  * Copyright 2013-2015 Imperial College London
5  * Copyright 2013-2015 Andreas Schuh
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
#include "mirtk/RobustPointMatch.h"

#include <vector>

#include "mirtk/Math.h"
#include "mirtk/Memory.h"
#include "mirtk/Pair.h"
#include "mirtk/Array.h"
#include "mirtk/Parallel.h"
#include "mirtk/Profiling.h"
#include "mirtk/PointLocator.h"
#include "mirtk/SparseMatrix.h"
#include "mirtk/VtkMath.h"

#include "vtkNew.h"
#include "vtkSmartPointer.h"
#include "vtkIdList.h"
#include "vtkPointSet.h"
#include "vtkOctreePointLocator.h"
37 
38 
39 namespace mirtk {
40 
41 
42 // =============================================================================
43 // Utilities
44 // =============================================================================
45 
46 namespace RobustPointMatchUtils {
47 
48 
49 // -----------------------------------------------------------------------------
50 class SquaredDistance
51 {
52   vtkSmartPointer<vtkPointSet>           _Target;
53   const Array<int>                      *_Sample;
54   vtkSmartPointer<vtkOctreePointLocator> _Source;
55   int                                    _Num;
56   double                                 _Max;
57   double                                 _Sum, _Sum2;
58   int                                    _Cnt;
59 
60 public:
61 
SquaredDistance()62   SquaredDistance()
63   :
64     _Max(.0), _Sum(.0), _Sum2(.0), _Cnt(0)
65   {}
66 
SquaredDistance(const SquaredDistance & lhs,split)67   SquaredDistance(const SquaredDistance &lhs, split)
68   :
69     _Target(lhs._Target), _Sample(lhs._Sample), _Source(lhs._Source),
70     _Num(lhs._Num), _Max(.0), _Sum(.0), _Sum2(.0), _Cnt(0)
71   {}
72 
join(const SquaredDistance & rhs)73   void join(const SquaredDistance &rhs)
74   {
75     _Max   = max(_Max, rhs._Max);
76     _Sum  += rhs._Sum;
77     _Sum2 += rhs._Sum2;
78     _Cnt  += rhs._Cnt;
79   }
80 
operator ()(const blocked_range<int> & re)81   void operator()(const blocked_range<int> &re)
82   {
83     double p1[3], p2[3], dist2 = .0;
84     vtkSmartPointer<vtkIdList> ids = vtkSmartPointer<vtkIdList>::New();
85     for (int r = re.begin(); r != re.end(); ++r) {
86       _Target->GetPoint(PointCorrespondence::GetPointIndex(_Target, _Sample, r), p1);
87       _Source->FindClosestNPoints(_Num, p1, ids);
88       for (vtkIdType i = 0; i < ids->GetNumberOfIds(); ++i) {
89         _Source->GetDataSet()->GetPoint(ids->GetId(i), p2);
90         dist2 = vtkMath::Distance2BetweenPoints(p1, p2);
91         _Sum  += dist2;
92         _Sum2 += dist2 * dist2;
93       }
94       _Cnt += static_cast<int>(ids->GetNumberOfIds());
95       // Note: ids are sorted from closest to farthest
96       if (dist2 > _Max) _Max = dist2;
97     }
98   }
99 
Add(vtkPointSet * target,const Array<int> * sample,vtkPointSet * source,int num=0)100   void Add(vtkPointSet      *target,
101            const Array<int> *sample,
102            vtkPointSet      *source,
103            int               num = 0)
104   {
105     const int m = PointLocator::GetNumberOfPoints(target, sample);
106     const int n = source->GetNumberOfPoints();
107     if (m == 0 || n == 0) return;
108     if (num <= 0) num = n;
109     _Num    = num;
110     _Target = target;
111     _Sample = sample;
112     _Source = vtkSmartPointer<vtkOctreePointLocator>::New();
113     _Source->SetDataSet(source);
114     _Source->BuildLocator();
115     blocked_range<int> range(0, m);
116     parallel_reduce(range, *this);
117     _Target = NULL;
118     _Source = NULL;
119   }
120 
Add(vtkPointSet * target,const Array<int> * sample,vtkPointSet * source,double source_frac,int min_num=1)121   void Add(vtkPointSet *target, const Array<int> *sample, vtkPointSet *source, double source_frac, int min_num  = 1)
122   {
123     int num = max(min_num, iround(source_frac * source->GetNumberOfPoints()));
124     Add(target, sample, source, num);
125   }
126 
Mean()127   double Mean()  { return _Cnt > 0 ? _Sum / _Cnt : .0; }
Sigma()128   double Sigma() { return _Cnt > 0 ? sqrt(_Sum2 / _Cnt - pow(_Sum / _Cnt, 2)) : .0; }
Max()129   double Max()   { return _Max; }
130 };
131 
132 // -----------------------------------------------------------------------------
133 /// Calculate fuzzy correspondence weights
134 class CalculateCorrespondenceWeights
135 {
136   typedef RobustPointMatch::WeightMatrix WeightMatrix;
137   typedef RobustPointMatch::FeatureList  FeatureList;
138 
139   const RegisteredPointSet *_Target;
140   const Array<int>         *_TargetSample;
141   const FeatureList        *_TargetFeatures;
142   const RegisteredPointSet *_Source;
143   const Array<int>         *_SourceSample;
144   const FeatureList        *_SourceFeatures;
145   int                       _NumberOfPoints;
146   int                       _NumberOfFeatures;
147   double                    _MaxDist;
148   double                    _Temperature;
149   double                    _VarianceOfFeatures;
150   WeightMatrix::Entries    *_CorrWeights;
151   vtkAbstractPointLocator  *_Locator;
152 
CalculateCorrespondenceWeights()153   CalculateCorrespondenceWeights() {}
154 
155 public:
156 
CalculateCorrespondenceWeights(const CalculateCorrespondenceWeights & lhs)157   CalculateCorrespondenceWeights(const CalculateCorrespondenceWeights &lhs)
158   :
159     _Target            (lhs._Target),
160     _TargetSample      (lhs._TargetSample),
161     _TargetFeatures    (lhs._TargetFeatures),
162     _Source            (lhs._Source),
163     _SourceSample      (lhs._SourceSample),
164     _SourceFeatures    (lhs._SourceFeatures),
165     _NumberOfPoints    (lhs._NumberOfPoints),
166     _NumberOfFeatures  (lhs._NumberOfFeatures),
167     _MaxDist           (lhs._MaxDist),
168     _Temperature       (lhs._Temperature),
169     _VarianceOfFeatures(lhs._VarianceOfFeatures),
170     _CorrWeights       (lhs._CorrWeights),
171     _Locator           (lhs._Locator)
172   {}
173 
operator ()(const blocked_range<int> & re) const174   void operator()(const blocked_range<int> &re) const
175   {
176     WeightMatrix::Entries::iterator weight;
177     vtkNew<vtkIdList> ids;
178     double *p1 = Allocate<double>(_NumberOfFeatures); // spatial + optional extra features
179     double *p2 = Allocate<double>(_NumberOfFeatures);
180     for (int i = re.begin(); i != re.end(); ++i) {
181       PointCorrespondence::GetPoint(p1, _Target, _TargetSample, i, _TargetFeatures);
182       _Locator->FindPointsWithinRadius(_MaxDist, p1, ids.GetPointer());
183       if (ids->GetNumberOfIds() > 0) {
184         _CorrWeights[i].resize(_CorrWeights[i].size() + ids->GetNumberOfIds());
185         weight = _CorrWeights[i].end() - ids->GetNumberOfIds();
186         for (vtkIdType j = 0; j < ids->GetNumberOfIds(); ++j, ++weight) {
187           PointCorrespondence::GetPoint(p2, _Source, _SourceSample, ids->GetId(j), _SourceFeatures);
188           weight->first  = PointCorrespondence::GetPointIndex(_Source, _SourceSample, ids->GetId(j));
189           weight->second = exp(- vtkMath::Distance2BetweenPoints(p1, p2) / _Temperature);
190           if (_NumberOfFeatures > 3) {
191             weight->second *= exp(- PointCorrespondence::Distance2BetweenPoints(p1+3, p2+3, _NumberOfFeatures-3) / _VarianceOfFeatures);
192           }
193         }
194       }
195     }
196     Deallocate(p1);
197     Deallocate(p2);
198   }
199 
Run(const RegisteredPointSet * target,const Array<int> * target_sample,const FeatureList * target_features,const RegisteredPointSet * source,const Array<int> * source_sample,const FeatureList * source_features,WeightMatrix::Entries * corrw,WeightMatrix::StorageLayout layout,double temperature,double var_of_features,double min_weight)200   static void Run(const RegisteredPointSet    *target,
201                   const Array<int>            *target_sample,
202                   const FeatureList           *target_features,
203                   const RegisteredPointSet    *source,
204                   const Array<int>            *source_sample,
205                   const FeatureList           *source_features,
206                   WeightMatrix::Entries       *corrw,
207                   WeightMatrix::StorageLayout  layout,
208                   double                       temperature,
209                   double                       var_of_features,
210                   double                       min_weight)
211   {
212     if (min_weight >= 1.0) {
213       cerr << "RobustPointMatch::CalculateWeights: Minimum weight must be less than 1" << endl;
214       exit(1);
215     }
216     if (min_weight <= .0) {
217       cerr << "RobustPointMatch::CalculateWeights: Minimum weight must be positive" << endl;
218       exit(1);
219     }
220     if (layout == WeightMatrix::CCS) {
221       swap(target,          source);
222       swap(target_sample,   source_sample);
223       swap(target_features, source_features);
224     }
225     const int m = PointCorrespondence::GetNumberOfPoints(target, target_sample);
226     const int n = PointCorrespondence::GetNumberOfPoints(source, source_sample);
227     if (m == 0 || n == 0) return;
228     MIRTK_START_TIMING();
229     CalculateCorrespondenceWeights body;
230     body._Target             = target;
231     body._TargetSample       = target_sample;
232     body._TargetFeatures     = target_features;
233     body._Source             = source;
234     body._SourceSample       = source_sample;
235     body._SourceFeatures     = source_features;
236     body._NumberOfPoints     = n;
237     body._NumberOfFeatures   = PointCorrespondence::GetPointDimension(target, target_features);
238     body._MaxDist            = sqrt(temperature * -log(min_weight));
239     body._Temperature        = temperature;
240     body._VarianceOfFeatures = var_of_features;
241     body._CorrWeights        = corrw;
242     vtkSmartPointer<vtkOctreePointLocator> locator;
243     vtkSmartPointer<vtkPointSet>           dataset;
244     dataset = PointCorrespondence::GetPointSet(source, source_sample);
245     locator = vtkSmartPointer<vtkOctreePointLocator>::New();
246     locator->SetDataSet(dataset);
247     locator->BuildLocator();
248     body._Locator = locator;
249     parallel_for(blocked_range<int>(0, m), body);
250     for (int i = 0; i < m; ++i) sort(corrw[i].begin(), corrw[i].end());
251     MIRTK_DEBUG_TIMING(7, "calculating weight for each pair of points");
252   }
253 };
254 
255 // -----------------------------------------------------------------------------
256 class GetCentroidOfPoints
257 {
258   vtkPointSet      *_DataSet;
259   const Array<int> *_Sample;
260   Point             _Sum;
261 
GetCentroidOfPoints()262   GetCentroidOfPoints() : _Sum(.0,.0,.0) {}
263 
264 public:
265 
GetCentroidOfPoints(const GetCentroidOfPoints & lhs,split)266   GetCentroidOfPoints(const GetCentroidOfPoints &lhs, split)
267   :
268     _DataSet(lhs._DataSet),
269     _Sample (lhs._Sample),
270     _Sum    (lhs._Sum)
271   {}
272 
join(const GetCentroidOfPoints & rhs)273   void join(const GetCentroidOfPoints &rhs)
274   {
275     _Sum += rhs._Sum;
276   }
277 
~GetCentroidOfPoints()278   ~GetCentroidOfPoints()
279   {
280   }
281 
operator ()(const blocked_range<int> & re)282   void operator ()(const blocked_range<int> &re)
283   {
284     double p[3];
285     for (int i = re.begin(); i != re.end(); ++i) {
286       _DataSet->GetPoint(PointCorrespondence::GetPointIndex(_DataSet, _Sample, i), p);
287       _Sum += Point(p);
288     }
289   }
290 
Run(vtkPointSet * dataset,const Array<int> * sample,Point & centroid)291   static void Run(vtkPointSet      *dataset,
292                   const Array<int> *sample,
293                   Point            &centroid)
294   {
295     const int n = PointCorrespondence::GetNumberOfPoints(dataset, sample);
296     if (n == 0) return;
297     MIRTK_START_TIMING();
298     GetCentroidOfPoints body;
299     body._DataSet = dataset;
300     body._Sample  = sample;
301     blocked_range<int> i(0, n);
302     parallel_reduce(i, body);
303     centroid = body._Sum / n;
304     MIRTK_DEBUG_TIMING(7, "getting centroid of points");
305   }
306 };
307 
308 // -----------------------------------------------------------------------------
309 /// Calculate outlier weights
310 class CalculateOutlierWeights1
311 {
312   typedef RobustPointMatch::WeightMatrix WeightMatrix;
313 
314   const RegisteredPointSet *_DataSet;
315   const Array<int>         *_Sample;
316   Point                     _Cluster;
317   double                    _Temperature;
318   WeightMatrix::Entries    *_CorrWeights;
319   int                       _N;
320 
CalculateOutlierWeights1()321   CalculateOutlierWeights1() {}
322 
323 public:
324 
CalculateOutlierWeights1(const CalculateOutlierWeights1 & lhs)325   CalculateOutlierWeights1(const CalculateOutlierWeights1 &lhs)
326   :
327     _DataSet    (lhs._DataSet),
328     _Sample     (lhs._Sample),
329     _Cluster    (lhs._Cluster),
330     _Temperature(lhs._Temperature),
331     _CorrWeights(lhs._CorrWeights),
332     _N          (lhs._N)
333   {}
334 
operator ()(const blocked_range<int> & re) const335   void operator()(const blocked_range<int> &re) const
336   {
337     double weight;
338     Point p;
339     for (int i = re.begin(); i != re.end(); ++i) {
340       _DataSet->GetPoint(PointCorrespondence::GetPointIndex(_DataSet, _Sample, i), p);
341       weight = exp(- p.SquaredDistance(_Cluster) / _Temperature);
342       _CorrWeights[i].push_back(MakePair(_N, static_cast<WeightMatrix::EntryType>(weight)));
343     }
344   }
345 
Run(const RegisteredPointSet * dataset,const Array<int> * sample,const Point & cluster,WeightMatrix::Entries * corrw,int n,double temperature)346   static void Run(const RegisteredPointSet *dataset,
347                   const Array<int>         *sample,
348                   const Point              &cluster,
349                   WeightMatrix::Entries    *corrw,
350                   int                       n,
351                   double                    temperature)
352   {
353     const int m = PointCorrespondence::GetNumberOfPoints(dataset, sample);
354     if (m == 0) return;
355     MIRTK_START_TIMING();
356     CalculateOutlierWeights1 body;
357     body._DataSet     = dataset;
358     body._Sample      = sample;
359     body._Cluster     = cluster;
360     body._Temperature = temperature;
361     body._CorrWeights = corrw;
362     body._N           = n;
363     blocked_range<int> i(0, m);
364     parallel_for(i, body);
365     MIRTK_DEBUG_TIMING(7, "calculating outlier weights 1");
366   }
367 };
368 
369 // -----------------------------------------------------------------------------
370 /// Calculate outlier weights
371 class CalculateOutlierWeights2
372 {
373   typedef RobustPointMatch::WeightMatrix WeightMatrix;
374 
375   const RegisteredPointSet *_DataSet;
376   const Array<int>         *_Sample;
377   Point                     _Cluster;
378   double                    _Temperature;
379   WeightMatrix::Entries    *_CorrWeights;
380 
CalculateOutlierWeights2()381   CalculateOutlierWeights2() {}
382 
383 public:
384 
CalculateOutlierWeights2(const CalculateOutlierWeights2 & lhs)385   CalculateOutlierWeights2(const CalculateOutlierWeights2 &lhs)
386   :
387     _DataSet    (lhs._DataSet),
388     _Sample     (lhs._Sample),
389     _Cluster    (lhs._Cluster),
390     _Temperature(lhs._Temperature),
391     _CorrWeights(lhs._CorrWeights)
392   {}
393 
operator ()(const blocked_range<int> & re) const394   void operator()(const blocked_range<int> &re) const
395   {
396     double weight;
397     Point p;
398     for (int i = re.begin(); i != re.end(); ++i) {
399       _DataSet->GetPoint(PointCorrespondence::GetPointIndex(_DataSet, _Sample, i), p);
400       weight = exp(- p.SquaredDistance(_Cluster) / _Temperature);
401       (*_CorrWeights)[i] = MakePair(i, static_cast<WeightMatrix::EntryType>(weight));
402     }
403   }
404 
Run(const RegisteredPointSet * dataset,const Array<int> * sample,const Point & cluster,WeightMatrix::Entries & corrw,double temperature)405   static void Run(const RegisteredPointSet *dataset,
406                   const Array<int>         *sample,
407                   const Point              &cluster,
408                   WeightMatrix::Entries    &corrw,
409                   double                    temperature)
410   {
411     const int n = PointCorrespondence::GetNumberOfPoints(dataset, sample);
412     if (n == 0) return;
413     MIRTK_START_TIMING();
414     corrw.resize(n);
415     CalculateOutlierWeights2 body;
416     body._DataSet     = dataset;
417     body._Sample      = sample;
418     body._Cluster     = cluster;
419     body._Temperature = temperature;
420     body._CorrWeights = &corrw;
421     blocked_range<int> i(0, n);
422     parallel_for(i, body);
423     MIRTK_DEBUG_TIMING(7, "calculating outlier weights 2");
424   }
425 };
426 
427 
428 } // namespace RobustPointMatchUtils
429 using namespace RobustPointMatchUtils;
430 
431 // =============================================================================
432 // Construction/Destruction
433 // =============================================================================
434 
435 // -----------------------------------------------------------------------------
RobustPointMatch()436 RobustPointMatch::RobustPointMatch()
437 :
438   _InitialTemperature(numeric_limits<double>::quiet_NaN()),
439   _AnnealingRate     (numeric_limits<double>::quiet_NaN()),
440   _FinalTemperature  (numeric_limits<double>::quiet_NaN()),
441   _Temperature       (numeric_limits<double>::quiet_NaN()),
442   _VarianceOfFeatures(numeric_limits<double>::quiet_NaN())
443 {
444 }
445 
446 // -----------------------------------------------------------------------------
RobustPointMatch(const RegisteredPointSet * target,const RegisteredPointSet * source)447 RobustPointMatch::RobustPointMatch(const RegisteredPointSet *target,
448                                            const RegisteredPointSet *source)
449 :
450   _InitialTemperature(numeric_limits<double>::quiet_NaN()),
451   _AnnealingRate     (numeric_limits<double>::quiet_NaN()),
452   _FinalTemperature  (numeric_limits<double>::quiet_NaN()),
453   _Temperature       (numeric_limits<double>::quiet_NaN()),
454   _VarianceOfFeatures(numeric_limits<double>::quiet_NaN())
455 {
456   Target(target);
457   Source(source);
458   Initialize();
459 }
460 
461 // -----------------------------------------------------------------------------
RobustPointMatch(const RobustPointMatch & other)462 RobustPointMatch::RobustPointMatch(const RobustPointMatch &other)
463 :
464   FuzzyCorrespondence(other),
465   _InitialTemperature  (other._InitialTemperature),
466   _AnnealingRate       (other._AnnealingRate),
467   _FinalTemperature    (other._FinalTemperature),
468   _Temperature         (other._Temperature),
469   _VarianceOfFeatures  (other._VarianceOfFeatures),
470   _TargetOutlierCluster(other._TargetOutlierCluster),
471   _SourceOutlierCluster(other._SourceOutlierCluster)
472 {
473 }
474 
475 // -----------------------------------------------------------------------------
NewInstance() const476 PointCorrespondence *RobustPointMatch::NewInstance() const
477 {
478   return new RobustPointMatch(*this);
479 }
480 
481 // -----------------------------------------------------------------------------
~RobustPointMatch()482 RobustPointMatch::~RobustPointMatch()
483 {
484 }
485 
486 // -----------------------------------------------------------------------------
Type() const487 RobustPointMatch::TypeId RobustPointMatch::Type() const
488 {
489   return TypeId::RobustPointMatch;
490 }
491 
492 // =============================================================================
493 // Parameters
494 // =============================================================================
495 
496 // -----------------------------------------------------------------------------
Set(const char * name,const char * value)497 bool RobustPointMatch::Set(const char *name, const char *value)
498 {
499   // Initial temperature
500   if (strcmp(name, "Initial temperature") == 0 || strcmp(name, "Temperature") == 0) {
501     return FromString(value, _InitialTemperature);
502   }
503   // Annealing rate, negative value indicates kNN search
504   if (strcmp(name, "Annealing rate") == 0) {
505     return FromString(value, _AnnealingRate) && _AnnealingRate < 1.0;
506   }
507   // Final temperature
508   if (strcmp(name, "Final temperature") == 0) {
509     return FromString(value, _FinalTemperature);
510   }
511   // Variance of extra features (resp. of their differences)
512   if (strcmp(name, "Variance of features") == 0) {
513     return FromString(value, _VarianceOfFeatures);
514   }
515   return FuzzyCorrespondence::Set(name, value);
516 }
517 
518 // -----------------------------------------------------------------------------
Parameter() const519 ParameterList RobustPointMatch::Parameter() const
520 {
521   ParameterList params = FuzzyCorrespondence::Parameter();
522   Insert(params, "Initial temperature",  _InitialTemperature);
523   Insert(params, "Annealing rate",       _AnnealingRate);
524   Insert(params, "Final temperature",    _FinalTemperature);
525   Insert(params, "Variance of features", _VarianceOfFeatures);
526   return params;
527 }
528 
529 // =============================================================================
530 // Correspondences
531 // =============================================================================
532 
533 // -----------------------------------------------------------------------------
Initialize()534 void RobustPointMatch::Initialize()
535 {
536   // Initialize base class
537   FuzzyCorrespondence::Initialize();
538 
539   // Ensure that spatial coordinates are first components of feature vector
540   size_t i;
541   for (i = 0; i < _TargetFeatures.size(); ++i) {
542     if (_TargetFeatures[i]._Index == -1) break;
543   }
544   if (i == _TargetFeatures.size()) {
545     FeatureList::value_type info;
546     info._Name      = "spatial coordinates";
547     info._Index     = -1;
548     info._Slope     = 1.0;
549     info._Intercept = .0;
550     _TargetFeatures.insert(_TargetFeatures.begin(), info);
551   } else if (i != 0) {
552     swap(_TargetFeatures[0], _TargetFeatures[i]);
553     if (_TargetFeatures[0]._Slope == .0) _TargetFeatures[0]._Slope = 1.0;
554   }
555 
556   for (i = 0; i < _SourceFeatures.size(); ++i) {
557     if (_SourceFeatures[i]._Index == -1) break;
558   }
559   if (i == _SourceFeatures.size()) {
560     FeatureList::value_type info;
561     info._Name      = "spatial coordinates";
562     info._Index     = -1;
563     info._Slope     = 1.0;
564     info._Intercept = .0;
565     _SourceFeatures.insert(_SourceFeatures.begin(), info);
566   } else if (i != 0) {
567     swap(_SourceFeatures[0], _SourceFeatures[i]);
568     if (_SourceFeatures[0]._Slope == .0) _SourceFeatures[0]._Slope = 1.0;
569   }
570 
571   // Initialize annealing process
572   this->InitializeAnnealing();
573 
574   // TODO: Determine variance
575   if (IsNaN(_VarianceOfFeatures)) _VarianceOfFeatures = 1.0;
576 }
577 
578 // -----------------------------------------------------------------------------
InitializeAnnealing()579 void RobustPointMatch::InitializeAnnealing()
580 {
581   MIRTK_START_TIMING();
582 
583   // Annealing rate
584   if (IsNaN(_AnnealingRate)) _AnnealingRate = 0.93;
585   if (_AnnealingRate >= 1.0) {
586     cerr << this->NameOfClass() << "::Initialize: ";
587     cerr << "Annealing rate must be less than 1, normally it is in the range [0.9, 0.99]" << endl;
588     cerr << "    Alternatively, set it to a negative integer to specify the number" << endl;
589     cerr << "    of nearest neighbors to consider when choosing a new temperature." << endl;
590     exit(1);
591   }
592 
593   // Annealing rate used below to adjust temperature range
594   double annealing_rate = (_AnnealingRate <= .0 ? .93 : _AnnealingRate);
595 
596   // Temperature range
597   if (IsNaN(_InitialTemperature) || _InitialTemperature <= .0) {
598     MIRTK_LOG_EVENT("Initializing annealing process...\n");
599 
600     // Number of nearest neighbors
601     int n = 10;
602     if (_InitialTemperature < .0) n = ceil(-_InitialTemperature);
603     else if (_AnnealingRate < .0) n = ceil(-_AnnealingRate);
604     const int k = min(n, min(_M, _N));
605     MIRTK_LOG_EVENT("  Considering " << k << " nearest neighors\n");
606 
607     // Determine mean and standard deviation of distances to k nearest neighbors
608     SquaredDistance dist2;
609     vtkSmartPointer<vtkPointSet> source = GetPointSet(_Source, _SourceSample);
610     dist2.Add(_Target->PointSet(), _TargetSample, source, k);
611     MIRTK_LOG_EVENT("  Mean squared distance = " << dist2.Mean() << " (sigma = " << dist2.Sigma() << ")\n");
612 
613     // Set initial temperature of annealing process
614     if (IsNaN(_InitialTemperature) || _InitialTemperature <= .0) {
615       _InitialTemperature = dist2.Mean() + 1.5 * dist2.Sigma();
616     }
617 
618     MIRTK_LOG_EVENT("Initializing annealing proces... done\n");
619   }
620 
621   // Set final temperature of annealing process
622   if (IsNaN(_FinalTemperature)) {
623     _FinalTemperature = pow(annealing_rate, 50) * _InitialTemperature;
624   }
625 
626   MIRTK_LOG_EVENT("Initial temperature = " << _InitialTemperature << "\n");
627   MIRTK_LOG_EVENT("Final   temperature = " << _FinalTemperature   << "\n");
628 
629   // Set initial temperature
630   _Temperature = _InitialTemperature;
631   MIRTK_LOG_EVENT("Temperature = " << _InitialTemperature << "\n");
632 
633   MIRTK_DEBUG_TIMING(6, "initialization of annealing process");
634 }
635 
636 // -----------------------------------------------------------------------------
Upgrade()637 bool RobustPointMatch::Upgrade()
638 {
639   // Have base class check ratio of outliers
640   if (!FuzzyCorrespondence::Upgrade()) return false;
641 
642   // Reduce temperature
643   if (_AnnealingRate < .0) {
644 
645     // Number of nearest neighbors
646     const int k = min(int(ceil(-_AnnealingRate)), min(_M, _N));
647 
648     // Compute statistics of (squared) distances
649     SquaredDistance dist2;
650     vtkSmartPointer<vtkPointSet> source = GetPointSet(_Source, _SourceSample);
651     dist2.Add(_Target->PointSet(), _TargetSample, source, k);
652 
653     // Adjust temperature further to speed up annealing process
654     _Temperature = min(0.96 * _Temperature, dist2.Mean());
655   } else {
656     _Temperature *= _AnnealingRate;
657   }
658 
659   // Stop if final temperature reached
660   if (_Temperature < _FinalTemperature) {
661     MIRTK_LOG_EVENT("Final temperature reached\n");
662     return false;
663   }
664 
665   // Log current temperature
666   MIRTK_LOG_EVENT("Temperature = " << _Temperature << "\n");
667   return true;
668 }
669 
670 // -----------------------------------------------------------------------------
CalculateWeights()671 void RobustPointMatch::CalculateWeights()
672 {
673   MIRTK_START_TIMING();
674 
675   // Size of weight matrix
676   const int m = _M + 1;
677   const int n = _N + 1;
678 
679   // Allocate lists for non-zero weight entries
680   const int nentries = (_Weight.Layout() == WeightMatrix::CRS ? m : n);
681   WeightMatrix::Entries *entries = new WeightMatrix::Entries[nentries];
682 
683   // Calculate correspondence weights
684   CalculateCorrespondenceWeights::Run(_Target, _TargetSample, &_TargetFeatures,
685                                       _Source, _SourceSample, &_SourceFeatures,
686                                       entries, _Weight.Layout(),
687                                       _Temperature, _VarianceOfFeatures, _MinWeight);
688 
689   // Compute cluster to which source outliers are matched
690   // (i.e., centroid of target points!)
691   GetCentroidOfPoints::Run(_Target->PointSet(), _TargetSample, _SourceOutlierCluster);
692 
693   // Compute cluster to which target outliers are matched
694   // (i.e., centroid of source points!)
695   GetCentroidOfPoints::Run(_Source->PointSet(), _SourceSample, _TargetOutlierCluster);
696 
697   // Calculate weights of point assignment to outlier clusters
698   if (_Weight.Layout() == WeightMatrix::CRS) {
699 
700     CalculateOutlierWeights1::Run(_Target, _TargetSample, _TargetOutlierCluster,
701                                   entries, _N, _InitialTemperature);
702 
703     CalculateOutlierWeights2::Run(_Source, _SourceSample, _SourceOutlierCluster,
704                                   entries[_M], _InitialTemperature);
705 
706   } else {
707 
708     CalculateOutlierWeights1::Run(_Source, _SourceSample, _SourceOutlierCluster,
709                                   entries, _M, _InitialTemperature);
710 
711     CalculateOutlierWeights2::Run(_Target, _TargetSample, _TargetOutlierCluster,
712                                   entries[_N], _InitialTemperature);
713 
714   }
715 
716   MIRTK_DEBUG_TIMING(6, "calculating correspondence weights (" << m << "x" << n << ")");
717 
718   // Initialize correspondence matrix
719   MIRTK_RESET_TIMING();
720   _Weight.Initialize(m, n, entries, true);
721   delete[] entries;
722   MIRTK_DEBUG_TIMING(7, "copying sparse matrix entries (NNZ=" << _Weight.NNZ() << ")");
723 }
724 
725 
726 } // namespace mirtk
727