1 /*
2 * Medical Image Registration ToolKit (MIRTK)
3 *
4 * Copyright 2013-2015 Imperial College London
5 * Copyright 2013-2015 Andreas Schuh
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 */
19
20 #include "mirtk/RobustPointMatch.h"
21
22 #include "mirtk/Math.h"
23 #include "mirtk/Memory.h"
24 #include "mirtk/Pair.h"
25 #include "mirtk/Array.h"
26 #include "mirtk/Parallel.h"
27 #include "mirtk/Profiling.h"
28 #include "mirtk/PointLocator.h"
29 #include "mirtk/SparseMatrix.h"
30 #include "mirtk/VtkMath.h"
31
32 #include "vtkNew.h"
33 #include "vtkSmartPointer.h"
34 #include "vtkIdList.h"
35 #include "vtkPointSet.h"
36 #include "vtkOctreePointLocator.h"
37
38
39 namespace mirtk {
40
41
42 // =============================================================================
43 // Utilities
44 // =============================================================================
45
46 namespace RobustPointMatchUtils {
47
48
49 // -----------------------------------------------------------------------------
50 class SquaredDistance
51 {
52 vtkSmartPointer<vtkPointSet> _Target;
53 const Array<int> *_Sample;
54 vtkSmartPointer<vtkOctreePointLocator> _Source;
55 int _Num;
56 double _Max;
57 double _Sum, _Sum2;
58 int _Cnt;
59
60 public:
61
SquaredDistance()62 SquaredDistance()
63 :
64 _Max(.0), _Sum(.0), _Sum2(.0), _Cnt(0)
65 {}
66
SquaredDistance(const SquaredDistance & lhs,split)67 SquaredDistance(const SquaredDistance &lhs, split)
68 :
69 _Target(lhs._Target), _Sample(lhs._Sample), _Source(lhs._Source),
70 _Num(lhs._Num), _Max(.0), _Sum(.0), _Sum2(.0), _Cnt(0)
71 {}
72
join(const SquaredDistance & rhs)73 void join(const SquaredDistance &rhs)
74 {
75 _Max = max(_Max, rhs._Max);
76 _Sum += rhs._Sum;
77 _Sum2 += rhs._Sum2;
78 _Cnt += rhs._Cnt;
79 }
80
operator ()(const blocked_range<int> & re)81 void operator()(const blocked_range<int> &re)
82 {
83 double p1[3], p2[3], dist2 = .0;
84 vtkSmartPointer<vtkIdList> ids = vtkSmartPointer<vtkIdList>::New();
85 for (int r = re.begin(); r != re.end(); ++r) {
86 _Target->GetPoint(PointCorrespondence::GetPointIndex(_Target, _Sample, r), p1);
87 _Source->FindClosestNPoints(_Num, p1, ids);
88 for (vtkIdType i = 0; i < ids->GetNumberOfIds(); ++i) {
89 _Source->GetDataSet()->GetPoint(ids->GetId(i), p2);
90 dist2 = vtkMath::Distance2BetweenPoints(p1, p2);
91 _Sum += dist2;
92 _Sum2 += dist2 * dist2;
93 }
94 _Cnt += static_cast<int>(ids->GetNumberOfIds());
95 // Note: ids are sorted from closest to farthest
96 if (dist2 > _Max) _Max = dist2;
97 }
98 }
99
Add(vtkPointSet * target,const Array<int> * sample,vtkPointSet * source,int num=0)100 void Add(vtkPointSet *target,
101 const Array<int> *sample,
102 vtkPointSet *source,
103 int num = 0)
104 {
105 const int m = PointLocator::GetNumberOfPoints(target, sample);
106 const int n = source->GetNumberOfPoints();
107 if (m == 0 || n == 0) return;
108 if (num <= 0) num = n;
109 _Num = num;
110 _Target = target;
111 _Sample = sample;
112 _Source = vtkSmartPointer<vtkOctreePointLocator>::New();
113 _Source->SetDataSet(source);
114 _Source->BuildLocator();
115 blocked_range<int> range(0, m);
116 parallel_reduce(range, *this);
117 _Target = NULL;
118 _Source = NULL;
119 }
120
Add(vtkPointSet * target,const Array<int> * sample,vtkPointSet * source,double source_frac,int min_num=1)121 void Add(vtkPointSet *target, const Array<int> *sample, vtkPointSet *source, double source_frac, int min_num = 1)
122 {
123 int num = max(min_num, iround(source_frac * source->GetNumberOfPoints()));
124 Add(target, sample, source, num);
125 }
126
Mean()127 double Mean() { return _Cnt > 0 ? _Sum / _Cnt : .0; }
Sigma()128 double Sigma() { return _Cnt > 0 ? sqrt(_Sum2 / _Cnt - pow(_Sum / _Cnt, 2)) : .0; }
Max()129 double Max() { return _Max; }
130 };
131
132 // -----------------------------------------------------------------------------
133 /// Calculate fuzzy correspondence weights
134 class CalculateCorrespondenceWeights
135 {
136 typedef RobustPointMatch::WeightMatrix WeightMatrix;
137 typedef RobustPointMatch::FeatureList FeatureList;
138
139 const RegisteredPointSet *_Target;
140 const Array<int> *_TargetSample;
141 const FeatureList *_TargetFeatures;
142 const RegisteredPointSet *_Source;
143 const Array<int> *_SourceSample;
144 const FeatureList *_SourceFeatures;
145 int _NumberOfPoints;
146 int _NumberOfFeatures;
147 double _MaxDist;
148 double _Temperature;
149 double _VarianceOfFeatures;
150 WeightMatrix::Entries *_CorrWeights;
151 vtkAbstractPointLocator *_Locator;
152
CalculateCorrespondenceWeights()153 CalculateCorrespondenceWeights() {}
154
155 public:
156
CalculateCorrespondenceWeights(const CalculateCorrespondenceWeights & lhs)157 CalculateCorrespondenceWeights(const CalculateCorrespondenceWeights &lhs)
158 :
159 _Target (lhs._Target),
160 _TargetSample (lhs._TargetSample),
161 _TargetFeatures (lhs._TargetFeatures),
162 _Source (lhs._Source),
163 _SourceSample (lhs._SourceSample),
164 _SourceFeatures (lhs._SourceFeatures),
165 _NumberOfPoints (lhs._NumberOfPoints),
166 _NumberOfFeatures (lhs._NumberOfFeatures),
167 _MaxDist (lhs._MaxDist),
168 _Temperature (lhs._Temperature),
169 _VarianceOfFeatures(lhs._VarianceOfFeatures),
170 _CorrWeights (lhs._CorrWeights),
171 _Locator (lhs._Locator)
172 {}
173
operator ()(const blocked_range<int> & re) const174 void operator()(const blocked_range<int> &re) const
175 {
176 WeightMatrix::Entries::iterator weight;
177 vtkNew<vtkIdList> ids;
178 double *p1 = Allocate<double>(_NumberOfFeatures); // spatial + optional extra features
179 double *p2 = Allocate<double>(_NumberOfFeatures);
180 for (int i = re.begin(); i != re.end(); ++i) {
181 PointCorrespondence::GetPoint(p1, _Target, _TargetSample, i, _TargetFeatures);
182 _Locator->FindPointsWithinRadius(_MaxDist, p1, ids.GetPointer());
183 if (ids->GetNumberOfIds() > 0) {
184 _CorrWeights[i].resize(_CorrWeights[i].size() + ids->GetNumberOfIds());
185 weight = _CorrWeights[i].end() - ids->GetNumberOfIds();
186 for (vtkIdType j = 0; j < ids->GetNumberOfIds(); ++j, ++weight) {
187 PointCorrespondence::GetPoint(p2, _Source, _SourceSample, ids->GetId(j), _SourceFeatures);
188 weight->first = PointCorrespondence::GetPointIndex(_Source, _SourceSample, ids->GetId(j));
189 weight->second = exp(- vtkMath::Distance2BetweenPoints(p1, p2) / _Temperature);
190 if (_NumberOfFeatures > 3) {
191 weight->second *= exp(- PointCorrespondence::Distance2BetweenPoints(p1+3, p2+3, _NumberOfFeatures-3) / _VarianceOfFeatures);
192 }
193 }
194 }
195 }
196 Deallocate(p1);
197 Deallocate(p2);
198 }
199
Run(const RegisteredPointSet * target,const Array<int> * target_sample,const FeatureList * target_features,const RegisteredPointSet * source,const Array<int> * source_sample,const FeatureList * source_features,WeightMatrix::Entries * corrw,WeightMatrix::StorageLayout layout,double temperature,double var_of_features,double min_weight)200 static void Run(const RegisteredPointSet *target,
201 const Array<int> *target_sample,
202 const FeatureList *target_features,
203 const RegisteredPointSet *source,
204 const Array<int> *source_sample,
205 const FeatureList *source_features,
206 WeightMatrix::Entries *corrw,
207 WeightMatrix::StorageLayout layout,
208 double temperature,
209 double var_of_features,
210 double min_weight)
211 {
212 if (min_weight >= 1.0) {
213 cerr << "RobustPointMatch::CalculateWeights: Minimum weight must be less than 1" << endl;
214 exit(1);
215 }
216 if (min_weight <= .0) {
217 cerr << "RobustPointMatch::CalculateWeights: Minimum weight must be positive" << endl;
218 exit(1);
219 }
220 if (layout == WeightMatrix::CCS) {
221 swap(target, source);
222 swap(target_sample, source_sample);
223 swap(target_features, source_features);
224 }
225 const int m = PointCorrespondence::GetNumberOfPoints(target, target_sample);
226 const int n = PointCorrespondence::GetNumberOfPoints(source, source_sample);
227 if (m == 0 || n == 0) return;
228 MIRTK_START_TIMING();
229 CalculateCorrespondenceWeights body;
230 body._Target = target;
231 body._TargetSample = target_sample;
232 body._TargetFeatures = target_features;
233 body._Source = source;
234 body._SourceSample = source_sample;
235 body._SourceFeatures = source_features;
236 body._NumberOfPoints = n;
237 body._NumberOfFeatures = PointCorrespondence::GetPointDimension(target, target_features);
238 body._MaxDist = sqrt(temperature * -log(min_weight));
239 body._Temperature = temperature;
240 body._VarianceOfFeatures = var_of_features;
241 body._CorrWeights = corrw;
242 vtkSmartPointer<vtkOctreePointLocator> locator;
243 vtkSmartPointer<vtkPointSet> dataset;
244 dataset = PointCorrespondence::GetPointSet(source, source_sample);
245 locator = vtkSmartPointer<vtkOctreePointLocator>::New();
246 locator->SetDataSet(dataset);
247 locator->BuildLocator();
248 body._Locator = locator;
249 parallel_for(blocked_range<int>(0, m), body);
250 for (int i = 0; i < m; ++i) sort(corrw[i].begin(), corrw[i].end());
251 MIRTK_DEBUG_TIMING(7, "calculating weight for each pair of points");
252 }
253 };
254
255 // -----------------------------------------------------------------------------
256 class GetCentroidOfPoints
257 {
258 vtkPointSet *_DataSet;
259 const Array<int> *_Sample;
260 Point _Sum;
261
GetCentroidOfPoints()262 GetCentroidOfPoints() : _Sum(.0,.0,.0) {}
263
264 public:
265
GetCentroidOfPoints(const GetCentroidOfPoints & lhs,split)266 GetCentroidOfPoints(const GetCentroidOfPoints &lhs, split)
267 :
268 _DataSet(lhs._DataSet),
269 _Sample (lhs._Sample),
270 _Sum (lhs._Sum)
271 {}
272
join(const GetCentroidOfPoints & rhs)273 void join(const GetCentroidOfPoints &rhs)
274 {
275 _Sum += rhs._Sum;
276 }
277
~GetCentroidOfPoints()278 ~GetCentroidOfPoints()
279 {
280 }
281
operator ()(const blocked_range<int> & re)282 void operator ()(const blocked_range<int> &re)
283 {
284 double p[3];
285 for (int i = re.begin(); i != re.end(); ++i) {
286 _DataSet->GetPoint(PointCorrespondence::GetPointIndex(_DataSet, _Sample, i), p);
287 _Sum += Point(p);
288 }
289 }
290
Run(vtkPointSet * dataset,const Array<int> * sample,Point & centroid)291 static void Run(vtkPointSet *dataset,
292 const Array<int> *sample,
293 Point ¢roid)
294 {
295 const int n = PointCorrespondence::GetNumberOfPoints(dataset, sample);
296 if (n == 0) return;
297 MIRTK_START_TIMING();
298 GetCentroidOfPoints body;
299 body._DataSet = dataset;
300 body._Sample = sample;
301 blocked_range<int> i(0, n);
302 parallel_reduce(i, body);
303 centroid = body._Sum / n;
304 MIRTK_DEBUG_TIMING(7, "getting centroid of points");
305 }
306 };
307
308 // -----------------------------------------------------------------------------
309 /// Calculate outlier weights
310 class CalculateOutlierWeights1
311 {
312 typedef RobustPointMatch::WeightMatrix WeightMatrix;
313
314 const RegisteredPointSet *_DataSet;
315 const Array<int> *_Sample;
316 Point _Cluster;
317 double _Temperature;
318 WeightMatrix::Entries *_CorrWeights;
319 int _N;
320
CalculateOutlierWeights1()321 CalculateOutlierWeights1() {}
322
323 public:
324
CalculateOutlierWeights1(const CalculateOutlierWeights1 & lhs)325 CalculateOutlierWeights1(const CalculateOutlierWeights1 &lhs)
326 :
327 _DataSet (lhs._DataSet),
328 _Sample (lhs._Sample),
329 _Cluster (lhs._Cluster),
330 _Temperature(lhs._Temperature),
331 _CorrWeights(lhs._CorrWeights),
332 _N (lhs._N)
333 {}
334
operator ()(const blocked_range<int> & re) const335 void operator()(const blocked_range<int> &re) const
336 {
337 double weight;
338 Point p;
339 for (int i = re.begin(); i != re.end(); ++i) {
340 _DataSet->GetPoint(PointCorrespondence::GetPointIndex(_DataSet, _Sample, i), p);
341 weight = exp(- p.SquaredDistance(_Cluster) / _Temperature);
342 _CorrWeights[i].push_back(MakePair(_N, static_cast<WeightMatrix::EntryType>(weight)));
343 }
344 }
345
Run(const RegisteredPointSet * dataset,const Array<int> * sample,const Point & cluster,WeightMatrix::Entries * corrw,int n,double temperature)346 static void Run(const RegisteredPointSet *dataset,
347 const Array<int> *sample,
348 const Point &cluster,
349 WeightMatrix::Entries *corrw,
350 int n,
351 double temperature)
352 {
353 const int m = PointCorrespondence::GetNumberOfPoints(dataset, sample);
354 if (m == 0) return;
355 MIRTK_START_TIMING();
356 CalculateOutlierWeights1 body;
357 body._DataSet = dataset;
358 body._Sample = sample;
359 body._Cluster = cluster;
360 body._Temperature = temperature;
361 body._CorrWeights = corrw;
362 body._N = n;
363 blocked_range<int> i(0, m);
364 parallel_for(i, body);
365 MIRTK_DEBUG_TIMING(7, "calculating outlier weights 1");
366 }
367 };
368
369 // -----------------------------------------------------------------------------
370 /// Calculate outlier weights
371 class CalculateOutlierWeights2
372 {
373 typedef RobustPointMatch::WeightMatrix WeightMatrix;
374
375 const RegisteredPointSet *_DataSet;
376 const Array<int> *_Sample;
377 Point _Cluster;
378 double _Temperature;
379 WeightMatrix::Entries *_CorrWeights;
380
CalculateOutlierWeights2()381 CalculateOutlierWeights2() {}
382
383 public:
384
CalculateOutlierWeights2(const CalculateOutlierWeights2 & lhs)385 CalculateOutlierWeights2(const CalculateOutlierWeights2 &lhs)
386 :
387 _DataSet (lhs._DataSet),
388 _Sample (lhs._Sample),
389 _Cluster (lhs._Cluster),
390 _Temperature(lhs._Temperature),
391 _CorrWeights(lhs._CorrWeights)
392 {}
393
operator ()(const blocked_range<int> & re) const394 void operator()(const blocked_range<int> &re) const
395 {
396 double weight;
397 Point p;
398 for (int i = re.begin(); i != re.end(); ++i) {
399 _DataSet->GetPoint(PointCorrespondence::GetPointIndex(_DataSet, _Sample, i), p);
400 weight = exp(- p.SquaredDistance(_Cluster) / _Temperature);
401 (*_CorrWeights)[i] = MakePair(i, static_cast<WeightMatrix::EntryType>(weight));
402 }
403 }
404
Run(const RegisteredPointSet * dataset,const Array<int> * sample,const Point & cluster,WeightMatrix::Entries & corrw,double temperature)405 static void Run(const RegisteredPointSet *dataset,
406 const Array<int> *sample,
407 const Point &cluster,
408 WeightMatrix::Entries &corrw,
409 double temperature)
410 {
411 const int n = PointCorrespondence::GetNumberOfPoints(dataset, sample);
412 if (n == 0) return;
413 MIRTK_START_TIMING();
414 corrw.resize(n);
415 CalculateOutlierWeights2 body;
416 body._DataSet = dataset;
417 body._Sample = sample;
418 body._Cluster = cluster;
419 body._Temperature = temperature;
420 body._CorrWeights = &corrw;
421 blocked_range<int> i(0, n);
422 parallel_for(i, body);
423 MIRTK_DEBUG_TIMING(7, "calculating outlier weights 2");
424 }
425 };
426
427
428 } // namespace RobustPointMatchUtils
429 using namespace RobustPointMatchUtils;
430
431 // =============================================================================
432 // Construction/Destruction
433 // =============================================================================
434
435 // -----------------------------------------------------------------------------
RobustPointMatch()436 RobustPointMatch::RobustPointMatch()
437 :
438 _InitialTemperature(numeric_limits<double>::quiet_NaN()),
439 _AnnealingRate (numeric_limits<double>::quiet_NaN()),
440 _FinalTemperature (numeric_limits<double>::quiet_NaN()),
441 _Temperature (numeric_limits<double>::quiet_NaN()),
442 _VarianceOfFeatures(numeric_limits<double>::quiet_NaN())
443 {
444 }
445
446 // -----------------------------------------------------------------------------
RobustPointMatch(const RegisteredPointSet * target,const RegisteredPointSet * source)447 RobustPointMatch::RobustPointMatch(const RegisteredPointSet *target,
448 const RegisteredPointSet *source)
449 :
450 _InitialTemperature(numeric_limits<double>::quiet_NaN()),
451 _AnnealingRate (numeric_limits<double>::quiet_NaN()),
452 _FinalTemperature (numeric_limits<double>::quiet_NaN()),
453 _Temperature (numeric_limits<double>::quiet_NaN()),
454 _VarianceOfFeatures(numeric_limits<double>::quiet_NaN())
455 {
456 Target(target);
457 Source(source);
458 Initialize();
459 }
460
461 // -----------------------------------------------------------------------------
/// Copy constructor: copies base class state, annealing parameters, the
/// current temperature, and both outlier cluster positions
RobustPointMatch::RobustPointMatch(const RobustPointMatch &other)
:
  FuzzyCorrespondence(other),
  _InitialTemperature  (other._InitialTemperature),
  _AnnealingRate       (other._AnnealingRate),
  _FinalTemperature    (other._FinalTemperature),
  _Temperature         (other._Temperature),
  _VarianceOfFeatures  (other._VarianceOfFeatures),
  _TargetOutlierCluster(other._TargetOutlierCluster),
  _SourceOutlierCluster(other._SourceOutlierCluster)
{}
474
475 // -----------------------------------------------------------------------------
NewInstance() const476 PointCorrespondence *RobustPointMatch::NewInstance() const
477 {
478 return new RobustPointMatch(*this);
479 }
480
481 // -----------------------------------------------------------------------------
~RobustPointMatch()482 RobustPointMatch::~RobustPointMatch()
483 {
484 }
485
486 // -----------------------------------------------------------------------------
Type() const487 RobustPointMatch::TypeId RobustPointMatch::Type() const
488 {
489 return TypeId::RobustPointMatch;
490 }
491
492 // =============================================================================
493 // Parameters
494 // =============================================================================
495
496 // -----------------------------------------------------------------------------
Set(const char * name,const char * value)497 bool RobustPointMatch::Set(const char *name, const char *value)
498 {
499 // Initial temperature
500 if (strcmp(name, "Initial temperature") == 0 || strcmp(name, "Temperature") == 0) {
501 return FromString(value, _InitialTemperature);
502 }
503 // Annealing rate, negative value indicates kNN search
504 if (strcmp(name, "Annealing rate") == 0) {
505 return FromString(value, _AnnealingRate) && _AnnealingRate < 1.0;
506 }
507 // Final temperature
508 if (strcmp(name, "Final temperature") == 0) {
509 return FromString(value, _FinalTemperature);
510 }
511 // Variance of extra features (resp. of their differences)
512 if (strcmp(name, "Variance of features") == 0) {
513 return FromString(value, _VarianceOfFeatures);
514 }
515 return FuzzyCorrespondence::Set(name, value);
516 }
517
518 // -----------------------------------------------------------------------------
Parameter() const519 ParameterList RobustPointMatch::Parameter() const
520 {
521 ParameterList params = FuzzyCorrespondence::Parameter();
522 Insert(params, "Initial temperature", _InitialTemperature);
523 Insert(params, "Annealing rate", _AnnealingRate);
524 Insert(params, "Final temperature", _FinalTemperature);
525 Insert(params, "Variance of features", _VarianceOfFeatures);
526 return params;
527 }
528
529 // =============================================================================
530 // Correspondences
531 // =============================================================================
532
533 // -----------------------------------------------------------------------------
Initialize()534 void RobustPointMatch::Initialize()
535 {
536 // Initialize base class
537 FuzzyCorrespondence::Initialize();
538
539 // Ensure that spatial coordinates are first components of feature vector
540 size_t i;
541 for (i = 0; i < _TargetFeatures.size(); ++i) {
542 if (_TargetFeatures[i]._Index == -1) break;
543 }
544 if (i == _TargetFeatures.size()) {
545 FeatureList::value_type info;
546 info._Name = "spatial coordinates";
547 info._Index = -1;
548 info._Slope = 1.0;
549 info._Intercept = .0;
550 _TargetFeatures.insert(_TargetFeatures.begin(), info);
551 } else if (i != 0) {
552 swap(_TargetFeatures[0], _TargetFeatures[i]);
553 if (_TargetFeatures[0]._Slope == .0) _TargetFeatures[0]._Slope = 1.0;
554 }
555
556 for (i = 0; i < _SourceFeatures.size(); ++i) {
557 if (_SourceFeatures[i]._Index == -1) break;
558 }
559 if (i == _SourceFeatures.size()) {
560 FeatureList::value_type info;
561 info._Name = "spatial coordinates";
562 info._Index = -1;
563 info._Slope = 1.0;
564 info._Intercept = .0;
565 _SourceFeatures.insert(_SourceFeatures.begin(), info);
566 } else if (i != 0) {
567 swap(_SourceFeatures[0], _SourceFeatures[i]);
568 if (_SourceFeatures[0]._Slope == .0) _SourceFeatures[0]._Slope = 1.0;
569 }
570
571 // Initialize annealing process
572 this->InitializeAnnealing();
573
574 // TODO: Determine variance
575 if (IsNaN(_VarianceOfFeatures)) _VarianceOfFeatures = 1.0;
576 }
577
578 // -----------------------------------------------------------------------------
InitializeAnnealing()579 void RobustPointMatch::InitializeAnnealing()
580 {
581 MIRTK_START_TIMING();
582
583 // Annealing rate
584 if (IsNaN(_AnnealingRate)) _AnnealingRate = 0.93;
585 if (_AnnealingRate >= 1.0) {
586 cerr << this->NameOfClass() << "::Initialize: ";
587 cerr << "Annealing rate must be less than 1, normally it is in the range [0.9, 0.99]" << endl;
588 cerr << " Alternatively, set it to a negative integer to specify the number" << endl;
589 cerr << " of nearest neighbors to consider when choosing a new temperature." << endl;
590 exit(1);
591 }
592
593 // Annealing rate used below to adjust temperature range
594 double annealing_rate = (_AnnealingRate <= .0 ? .93 : _AnnealingRate);
595
596 // Temperature range
597 if (IsNaN(_InitialTemperature) || _InitialTemperature <= .0) {
598 MIRTK_LOG_EVENT("Initializing annealing process...\n");
599
600 // Number of nearest neighbors
601 int n = 10;
602 if (_InitialTemperature < .0) n = ceil(-_InitialTemperature);
603 else if (_AnnealingRate < .0) n = ceil(-_AnnealingRate);
604 const int k = min(n, min(_M, _N));
605 MIRTK_LOG_EVENT(" Considering " << k << " nearest neighors\n");
606
607 // Determine mean and standard deviation of distances to k nearest neighbors
608 SquaredDistance dist2;
609 vtkSmartPointer<vtkPointSet> source = GetPointSet(_Source, _SourceSample);
610 dist2.Add(_Target->PointSet(), _TargetSample, source, k);
611 MIRTK_LOG_EVENT(" Mean squared distance = " << dist2.Mean() << " (sigma = " << dist2.Sigma() << ")\n");
612
613 // Set initial temperature of annealing process
614 if (IsNaN(_InitialTemperature) || _InitialTemperature <= .0) {
615 _InitialTemperature = dist2.Mean() + 1.5 * dist2.Sigma();
616 }
617
618 MIRTK_LOG_EVENT("Initializing annealing proces... done\n");
619 }
620
621 // Set final temperature of annealing process
622 if (IsNaN(_FinalTemperature)) {
623 _FinalTemperature = pow(annealing_rate, 50) * _InitialTemperature;
624 }
625
626 MIRTK_LOG_EVENT("Initial temperature = " << _InitialTemperature << "\n");
627 MIRTK_LOG_EVENT("Final temperature = " << _FinalTemperature << "\n");
628
629 // Set initial temperature
630 _Temperature = _InitialTemperature;
631 MIRTK_LOG_EVENT("Temperature = " << _InitialTemperature << "\n");
632
633 MIRTK_DEBUG_TIMING(6, "initialization of annealing process");
634 }
635
636 // -----------------------------------------------------------------------------
Upgrade()637 bool RobustPointMatch::Upgrade()
638 {
639 // Have base class check ratio of outliers
640 if (!FuzzyCorrespondence::Upgrade()) return false;
641
642 // Reduce temperature
643 if (_AnnealingRate < .0) {
644
645 // Number of nearest neighbors
646 const int k = min(int(ceil(-_AnnealingRate)), min(_M, _N));
647
648 // Compute statistics of (squared) distances
649 SquaredDistance dist2;
650 vtkSmartPointer<vtkPointSet> source = GetPointSet(_Source, _SourceSample);
651 dist2.Add(_Target->PointSet(), _TargetSample, source, k);
652
653 // Adjust temperature further to speed up annealing process
654 _Temperature = min(0.96 * _Temperature, dist2.Mean());
655 } else {
656 _Temperature *= _AnnealingRate;
657 }
658
659 // Stop if final temperature reached
660 if (_Temperature < _FinalTemperature) {
661 MIRTK_LOG_EVENT("Final temperature reached\n");
662 return false;
663 }
664
665 // Log current temperature
666 MIRTK_LOG_EVENT("Temperature = " << _Temperature << "\n");
667 return true;
668 }
669
670 // -----------------------------------------------------------------------------
CalculateWeights()671 void RobustPointMatch::CalculateWeights()
672 {
673 MIRTK_START_TIMING();
674
675 // Size of weight matrix
676 const int m = _M + 1;
677 const int n = _N + 1;
678
679 // Allocate lists for non-zero weight entries
680 const int nentries = (_Weight.Layout() == WeightMatrix::CRS ? m : n);
681 WeightMatrix::Entries *entries = new WeightMatrix::Entries[nentries];
682
683 // Calculate correspondence weights
684 CalculateCorrespondenceWeights::Run(_Target, _TargetSample, &_TargetFeatures,
685 _Source, _SourceSample, &_SourceFeatures,
686 entries, _Weight.Layout(),
687 _Temperature, _VarianceOfFeatures, _MinWeight);
688
689 // Compute cluster to which source outliers are matched
690 // (i.e., centroid of target points!)
691 GetCentroidOfPoints::Run(_Target->PointSet(), _TargetSample, _SourceOutlierCluster);
692
693 // Compute cluster to which target outliers are matched
694 // (i.e., centroid of source points!)
695 GetCentroidOfPoints::Run(_Source->PointSet(), _SourceSample, _TargetOutlierCluster);
696
697 // Calculate weights of point assignment to outlier clusters
698 if (_Weight.Layout() == WeightMatrix::CRS) {
699
700 CalculateOutlierWeights1::Run(_Target, _TargetSample, _TargetOutlierCluster,
701 entries, _N, _InitialTemperature);
702
703 CalculateOutlierWeights2::Run(_Source, _SourceSample, _SourceOutlierCluster,
704 entries[_M], _InitialTemperature);
705
706 } else {
707
708 CalculateOutlierWeights1::Run(_Source, _SourceSample, _SourceOutlierCluster,
709 entries, _M, _InitialTemperature);
710
711 CalculateOutlierWeights2::Run(_Target, _TargetSample, _TargetOutlierCluster,
712 entries[_N], _InitialTemperature);
713
714 }
715
716 MIRTK_DEBUG_TIMING(6, "calculating correspondence weights (" << m << "x" << n << ")");
717
718 // Initialize correspondence matrix
719 MIRTK_RESET_TIMING();
720 _Weight.Initialize(m, n, entries, true);
721 delete[] entries;
722 MIRTK_DEBUG_TIMING(7, "copying sparse matrix entries (NNZ=" << _Weight.NNZ() << ")");
723 }
724
725
726 } // namespace mirtk
727