1 /*
2  * Medical Image Registration ToolKit (MIRTK)
3  *
4  * Copyright 2013-2015 Imperial College London
5  * Copyright 2013-2015 Andreas Schuh
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 #include "mirtk/PointCorrespondence.h"
21 
22 #include "mirtk/Vtk.h"
23 #include "mirtk/Math.h"
24 #include "mirtk/Pair.h"
25 #include "mirtk/Array.h"
26 #include "mirtk/Algorithm.h"
27 #include "mirtk/Vector.h"
28 #include "mirtk/Matrix.h"
29 #include "mirtk/SparseMatrix.h"
30 #include "mirtk/Parallel.h"
31 #include "mirtk/Profiling.h"
32 #include "mirtk/Transformation.h"
33 #include "mirtk/SpectralDecomposition.h"
34 #include "mirtk/PointSetIO.h"
35 
36 #include "mirtk/CommonExport.h"
37 
38 #include "vtkSmartPointer.h"
39 #include "vtkDataArray.h"
40 #include "vtkPointData.h"
41 #include "vtkOctreePointLocator.h"
42 #include "vtkPMaskPoints.h"
43 #include "vtkFloatArray.h"
44 #include "vtkCharArray.h"
45 #include "vtkPolyData.h"
46 #include "vtkPolyDataNormals.h"
47 #include "vtkWindowedSincPolyDataFilter.h"
48 
49 #include "mirtk/FiducialMatch.h"
50 #include "mirtk/ClosestPoint.h"
51 #include "mirtk/ClosestPointLabel.h"
52 #include "mirtk/ClosestCell.h"
53 #include "mirtk/SpectralMatch.h"
54 #include "mirtk/RobustClosestPoint.h"
55 #include "mirtk/RobustPointMatch.h"
56 
57 
58 namespace mirtk {
59 
60 
61 // Global debug level (cf. mirtk/Options.h)
62 MIRTK_Common_EXPORT extern int debug;
63 
64 
65 // =============================================================================
66 // Factory method
67 // =============================================================================
68 
69 // -----------------------------------------------------------------------------
New(TypeId type)70 PointCorrespondence *PointCorrespondence::New(TypeId type)
71 {
72   switch (type) {
73     case FiducialMatch:       return new class FiducialMatch();
74     case ClosestPoint:        return new class ClosestPoint();
75     case ClosestPointLabel:   return new class ClosestPointLabel();
76     case ClosestCell:         return new class ClosestCell();
77     case SpectralMatch:       return new class SpectralMatch();
78     case RobustClosestPoint:  return new class RobustClosestPoint();
79     case RobustPointMatch:    return new class RobustPointMatch();
80     default:
81       cerr << "PointCorrespondence::New: Unknown type = " << type << endl;
82       exit(1);
83   }
84 }
85 
86 // -----------------------------------------------------------------------------
New(const char * type_name)87 PointCorrespondence *PointCorrespondence::New(const char *type_name)
88 {
89   TypeId type = Unknown;
90   if (!FromString(type_name, type)) {
91     cerr << "PointCorrespondence::New: Unknown type = " << type_name << endl;
92     exit(1);
93   }
94   return New(type);
95 }
96 
97 // -----------------------------------------------------------------------------
FromString(const char * str,PointCorrespondence::TypeId & type)98 template <> bool FromString(const char *str, PointCorrespondence::TypeId &type)
99 {
100   if (strcmp(str, "Index") == 0 || strcmp(str, "Fiducial") == 0) {
101     type = PointCorrespondence::FiducialMatch;
102   } else if (strcmp(str, "CP")            == 0 ||
103              strcmp(str, "ClosestPoint")  == 0 ||
104              strcmp(str, "Closest Point") == 0 ||
105              strcmp(str, "Closest point") == 0) {
106     type = PointCorrespondence::ClosestPoint;
107   } else if (strcmp(str, "ClosestPointLabel")   == 0 ||
108              strcmp(str, "Closest Point Label") == 0 ||
109              strcmp(str, "Closest point label") == 0) {
110     type = PointCorrespondence::ClosestPointLabel;
111   } else if (strcmp(str, "CSP")                   == 0 ||
112              strcmp(str, "ClosestCell")           == 0 ||
113              strcmp(str, "Closest Cell")          == 0 ||
114              strcmp(str, "Closest cell")          == 0 ||
115              strcmp(str, "Closest Surface Point") == 0 ||
116              strcmp(str, "Closest surface point") == 0 ||
117              strcmp(str, "Closest")               == 0) {
118     type = PointCorrespondence::ClosestCell;
119   } else if (strcmp(str, "SM")                   == 0 ||
120              strcmp(str, "SpectralMatch")   == 0 ||
121              strcmp(str, "Spectral Match") == 0 ||
122              strcmp(str, "Spectral match") == 0) {
123     type = PointCorrespondence::SpectralMatch;
124   } else if (strcmp(str, "RCP")                  == 0 ||
125              strcmp(str, "RobustClosestPoint")   == 0 ||
126              strcmp(str, "Robust Closest Point") == 0 ||
127              strcmp(str, "Robust closest point") == 0) {
128     type = PointCorrespondence::RobustClosestPoint;
129   } else if (strcmp(str, "RPM")                == 0 ||
130              strcmp(str, "RobustPointMatch")   == 0 ||
131              strcmp(str, "Robust Point Match") == 0 ||
132              strcmp(str, "Robust point match") == 0) {
133     type = PointCorrespondence::RobustPointMatch;
134   } else {
135     type = PointCorrespondence::Unknown;
136   }
137   return (type != PointCorrespondence::Unknown);
138 }
139 
140 // -----------------------------------------------------------------------------
ToString(const PointCorrespondence::TypeId & type,int w,char c,bool left)141 template <> string ToString(const PointCorrespondence::TypeId &type, int w, char c, bool left)
142 {
143   string str;
144   switch (type) {
145     case PointCorrespondence::FiducialMatch:      str = "Index";              break;
146     case PointCorrespondence::ClosestPoint:       str = "ClosestPoint";       break;
147     case PointCorrespondence::ClosestPointLabel:  str = "ClosestPointLabel";  break;
148     case PointCorrespondence::ClosestCell:        str = "ClosestCell";        break;
149     case PointCorrespondence::SpectralMatch:      str = "SpectralMatch";      break;
150     case PointCorrespondence::RobustClosestPoint: str = "RobustClosestPoint"; break;
151     case PointCorrespondence::RobustPointMatch:   str = "RobustPointMatch";   break;
152     default:                                      str = "Unknown";            break;
153   }
154   return ToString(str, w, c, left);
155 }
156 
157 // =============================================================================
158 // Auxiliary functions
159 // =============================================================================
160 
161 namespace PointCorrespondenceUtils {
162 
163 
164 // -----------------------------------------------------------------------------
165 /// Auxiliary functor used by SamplePoints
166 class FindIndicesOfPointsClosestToSamples
167 {
168   vtkPointSet             *_Samples;
169   vtkAbstractPointLocator *_Locator;
170   Array<int>              *_Index;
171 
172 public:
173 
operator ()(const blocked_range<vtkIdType> & re) const174   void operator()(const blocked_range<vtkIdType> &re) const
175   {
176     double pt[3];
177     Array<int> &index = (*_Index);
178     for (vtkIdType i = re.begin(); i != re.end(); ++i) {
179       _Samples->GetPoint(i, pt);
180       index[i] = _Locator->FindClosestPoint(pt);
181     }
182   }
183 
Run(vtkAbstractPointLocator * locator,vtkPointSet * samples,Array<int> & index)184   static void Run(vtkAbstractPointLocator *locator, vtkPointSet *samples, Array<int> &index)
185   {
186     index.resize(samples->GetNumberOfPoints());
187     FindIndicesOfPointsClosestToSamples body;
188     body._Samples = samples;
189     body._Locator = locator;
190     body._Index   = &index;
191     blocked_range<vtkIdType> pts(0, samples->GetNumberOfPoints());
192     parallel_for(pts, body);
193     sort(index.begin(), index.end());
194     index.erase(unique(index.begin(), index.end()), index.end());
195   }
196 };
197 
198 // -----------------------------------------------------------------------------
/// Draw a subset of the points of the given point set.
///
/// \param[in]  pointset   Point set from which samples are drawn.
/// \param[out] indices    Sorted, unique indices of the drawn points. The list is
///                        always cleared first and is left empty when no samples
///                        are drawn (empty --> use all points).
/// \param[in]  maxnum     Maximum number of samples. When zero and \p maxdist is
///                        positive, the number of samples is estimated from
///                        \p maxdist instead. No samples are drawn when the
///                        resulting number is not positive.
/// \param[in]  maxdist    Only used when \p maxnum is zero: the number of samples
///                        is the number of points divided by the average number of
///                        points found within radius maxdist/2 of each point.
/// \param[in]  stratified Selects random mode type 2 instead of 1 of the
///                        vtkPMaskPoints filter (presumably spatially stratified
///                        vs. plain random sampling -- confirm against VTK docs).
void SamplePoints(vtkPointSet *pointset, Array<int> &indices,
                  int maxnum, double maxdist, bool stratified)
{
  MIRTK_START_TIMING();

  // Reset set of drawn samples (empty --> use all points)
  indices.clear();

  // Skip if input data set contains no points (failsafe)
  if (pointset->GetNumberOfPoints() == 0) return;

  // Build point locator
  //
  // Due to a bug in vtkKdTreePointLocator, calling BuildLocator
  // is not sufficient to make FindClosestPoint thread-safe as it does
  // not call vtkBSPIntersections::BuildRegionsList
  // (cf. http://www.vtk.org/Bug/view.php?id=15206 ).
  //
  // The locator is only needed (and hence only built) when samples may be drawn.
  vtkSmartPointer<vtkOctreePointLocator> locator;
  if (maxnum > 0 || maxdist > .0) {
    locator = vtkSmartPointer<vtkOctreePointLocator>::New();
    locator->SetDataSet(pointset);
    locator->BuildLocator();
  }

  // Determine maximum number of samples
  int nsamples;
  if (maxnum == 0 && maxdist > .0) {
    const double r = maxdist / 2.0;
    // Count number of points within radius r of each input point
    vtkNew<vtkIdList> ids;
    double p[3];
    double sum = .0;
    for (vtkIdType i = 0; i < pointset->GetNumberOfPoints(); ++i) {
      pointset->GetPoint(i, p);
      locator->FindPointsWithinRadius(r, p, ids.GetPointer());
      sum += ids->GetNumberOfIds();
    }
    // Divide number of points by average number of points within radius r
    nsamples = iround(pointset->GetNumberOfPoints() / (sum / pointset->GetNumberOfPoints()));
  } else {
    nsamples = maxnum;
  }

  // Extract specified maximum number of samples
  if (nsamples > 0) {
    vtkSmartPointer<vtkPMaskPoints> sampler;
    sampler = vtkSmartPointer<vtkPMaskPoints>::New();
    sampler->GenerateVerticesOff();
    sampler->SetMaximumNumberOfPoints(nsamples);
    sampler->SetRandomModeType(stratified ? 2 : 1);
    sampler->RandomModeOn();
    sampler->ProportionalMaximumNumberOfPointsOn();
    SetVTKInput(sampler, pointset);
    // Separate timer of this scope used for the two steps below
    MIRTK_START_TIMING();
    sampler->Update();
    MIRTK_DEBUG_TIMING(6, "uniformly subsampling point set");
    MIRTK_RESET_TIMING();
    // Map each sample drawn by the filter to the index of the closest input
    // point; duplicates are removed, hence fewer than nsamples indices may result
    FindIndicesOfPointsClosestToSamples::Run(locator, sampler->GetOutput(), indices);
    MIRTK_DEBUG_TIMING(6, "finding closest sample points");
  }

  MIRTK_DEBUG_TIMING(5, "subsampling of point set (#samples = " << nsamples << ")");
}
262 
263 // -----------------------------------------------------------------------------
264 /// Maximum range of spatial coordinates
MaxSpatialRange(vtkPointSet * target,vtkPointSet * source)265 double MaxSpatialRange(vtkPointSet *target, vtkPointSet *source)
266 {
267   double target_range[6], source_range[6], range[3];
268   target->GetBounds(target_range);
269   source->GetBounds(source_range);
270   range[0] = max(target_range[1] - target_range[0], source_range[1] - source_range[0]);
271   range[1] = max(target_range[3] - target_range[2], source_range[3] - source_range[2]);
272   range[2] = max(target_range[5] - target_range[4], source_range[5] - source_range[4]);
273   return max(max(range[0], range[1]), range[2]);
274 }
275 
276 // -----------------------------------------------------------------------------
SpectralPoints(vtkPointSet * input)277 vtkSmartPointer<vtkPointSet> SpectralPoints(vtkPointSet *input)
278 {
279   vtkDataArray *eigenmodes = input->GetPointData()->GetArray("eigenmodes");
280   if (!eigenmodes) return NULL;
281 
282   vtkSmartPointer<vtkPoints> points = vtkSmartPointer<vtkPoints>::New();
283   points->SetNumberOfPoints(input->GetNumberOfPoints());
284 
285   double *p = new double[max(3, eigenmodes->GetNumberOfComponents())];
286   for (vtkIdType i = 0; i < points->GetNumberOfPoints(); ++i) {
287     eigenmodes->GetTuple(i, p);
288     points->SetPoint(i, p);
289   }
290   delete[] p;
291 
292   vtkSmartPointer<vtkPointSet> output;
293   output.TakeReference(input->NewInstance());
294   output->DeepCopy (input);
295   output->SetPoints(points);
296   return output;
297 }
298 
299 // -----------------------------------------------------------------------------
ComputeNormals(vtkPolyData * dataset)300 int ComputeNormals(vtkPolyData *dataset)
301 {
302   // Smooth polydata
303   vtkSmartPointer<vtkWindowedSincPolyDataFilter> smoother;
304   smoother = vtkSmartPointer<vtkWindowedSincPolyDataFilter>::New();
305   smoother->FeatureEdgeSmoothingOff();
306   smoother->SetNumberOfIterations(25);
307   smoother->SetPassBand(.1);
308   smoother->NormalizeCoordinatesOn();
309   SetVTKInput(smoother, dataset);
310   // Calculate normals
311   vtkSmartPointer<vtkPolyDataNormals> calculator;
312   calculator = vtkSmartPointer<vtkPolyDataNormals>::New();
313   calculator->SplittingOff();
314   calculator->ConsistencyOn();
315   calculator->AutoOrientNormalsOn();
316   SetVTKConnection(calculator, smoother);
317   calculator->Update();
318   vtkDataArray *normals = calculator->GetOutput()->GetPointData()->GetNormals();
319   normals->SetName("Normals");
320   dataset->GetPointData()->SetNormals(normals);
321   int i = -1;
322   dataset->GetPointData()->GetArray("Normals", i);
323   return i;
324 }
325 
326 // -----------------------------------------------------------------------------
ComputeSpectralNormals(vtkPolyData * dataset)327 int ComputeSpectralNormals(vtkPolyData *dataset)
328 {
329   vtkSmartPointer<vtkPointSet> spectral_pointset = SpectralPoints(dataset);
330   vtkPolyData *spectral_polydata = vtkPolyData::SafeDownCast(spectral_pointset);
331   int i = -1;
332   if (spectral_polydata) {
333     ComputeNormals(spectral_polydata);
334     vtkDataArray *normals = dataset->GetPointData()->GetArray("spectral normals", i);
335     if (normals) normals->DeepCopy(spectral_polydata->GetPointData()->GetNormals());
336     else {
337       i = dataset->GetPointData()->AddArray(spectral_polydata->GetPointData()->GetNormals());
338     }
339   }
340   return i;
341 }
342 
343 // -----------------------------------------------------------------------------
GetEigenmodes(const RegisteredPointSet * dataset,int k,Matrix & m,Vector & v)344 int GetEigenmodes(const RegisteredPointSet *dataset, int k, Matrix &m, Vector &v)
345 {
346   int           i = -2;
347   vtkDataArray *s = dataset->PointSet()->GetPointData()->GetArray("eigenmodes", i);
348   if (dataset->Transformation() || s == NULL || s->GetNumberOfComponents() < k || v.Rows() < k) {
349     vtkPolyData *polydata = vtkPolyData::SafeDownCast(dataset->PointSet());
350     if (!polydata) {
351       cerr << "PointCorrespondence: Can only compute eigenmodes of surface mesh" << endl;
352       exit(1);
353     }
354     if (SpectralDecomposition::ComputeEigenmodes(polydata, k, m, v) != k) {
355       cerr << "PointCorrespondence: Failed to compute " << k << " eigenmodes" << endl;
356       exit(1);
357     }
358     return -2;
359   } else {
360     m.Initialize(static_cast<int>(s->GetNumberOfTuples()), k);
361     double *row = new double[k];
362     for (vtkIdType r = 0; r < s->GetNumberOfTuples(); ++r) {
363       s->GetTuple(r, row);
364       for (int c = 0; c < k; ++c) m(r, c) = row[c];
365     }
366     delete[] row;
367     return i;
368   }
369 }
370 
371 // -----------------------------------------------------------------------------
GetNormalizationParametersOfEigenmodes(double & slope,double & intercept,vtkDataArray * m1,vtkDataArray * m2,Vector w=Vector ())372 void GetNormalizationParametersOfEigenmodes(double &slope, double &intercept,
373                                             vtkDataArray *m1, vtkDataArray *m2,
374                                             Vector w = Vector())
375 {
376   const int k = min(m1->GetNumberOfComponents(), m2->GetNumberOfComponents());
377   if (w.Rows() < k) w.Resize(k, 1.0);
378   double range[2], minval = numeric_limits<double>::infinity(), maxrange = .0;
379   for (int c = 0; c < k; ++c) {
380     m1->GetRange(range, c);
381     minval   = min(minval,   w(c) * range[0]);
382     maxrange = max(maxrange, w(c) * (range[1] - range[0]));
383     m2->GetRange(range, c);
384     minval   = min(minval,   w(c) * range[0]);
385     maxrange = max(maxrange, w(c) * (range[1] - range[0]));
386   }
387   slope     = 2.0 / maxrange;
388   intercept = - slope * minval - 1.0;
389 }
390 
391 // -----------------------------------------------------------------------------
/// (Re-)compute eigenmodes of datasets independent of each other and
/// correct for sign ambiguity and order afterwards using bipartite matching
///
/// Nothing is done unless both feature lists contain an entry named "eigenmodes".
/// On return, the _Index, _Slope, and _Intercept of these two entries are updated.
/// The program is terminated with exit code 1 when either point set is no vtkPolyData.
///
/// \param[in]     target             Target point set.
/// \param[in,out] target_eigenvalues Eigenvalues of target (cf. GetEigenmodes, MatchEigenmodes).
/// \param[in,out] target_features    Features of target points.
/// \param[in]     source             Source point set.
/// \param[in,out] source_eigenvalues Eigenvalues of source (cf. GetEigenmodes, MatchEigenmodes).
/// \param[in,out] source_features    Features of source points.
/// \param[in]     k                  Number of eigenmodes.
void UpdateEigenmodes(const RegisteredPointSet         *target,
                      Vector                           &target_eigenvalues,
                      PointCorrespondence::FeatureList &target_features,
                      const RegisteredPointSet         *source,
                      Vector                           &source_eigenvalues,
                      PointCorrespondence::FeatureList &source_features,
                      int                               k)
{
  using namespace SpectralDecomposition;
  // Get feature entries
  PointCorrespondence::FeatureList::iterator target_info;
  for (target_info = target_features.begin(); target_info != target_features.end(); ++target_info) {
    if (target_info->_Name == "eigenmodes") break;
  }
  if (target_info == target_features.end()) return;
  PointCorrespondence::FeatureList::iterator source_info;
  for (source_info = source_features.begin(); source_info != source_features.end(); ++source_info) {
    if (source_info->_Name == "eigenmodes") break;
  }
  if (source_info == source_features.end()) return;
  // Cast to vtkPolyData
  vtkPolyData *target_polydata = vtkPolyData::SafeDownCast(target->PointSet());
  vtkPolyData *source_polydata = vtkPolyData::SafeDownCast(source->PointSet());
  if (!target_polydata || !source_polydata) {
    // NOTE(review): message prefix names FiducialRegistrationError although this
    //               function belongs to PointCorrespondenceUtils -- likely copy/paste
    cerr << "FiducialRegistrationError::UpdateEigenmodes: Point sets must be surface meshes" << endl;
    exit(1);
  }
  // Get/compute eigenmodes of each dataset
  // (returned index is -2 when the eigenmodes had to be recomputed)
  Matrix m1, m2;
  target_info->_Index = GetEigenmodes(target, k, m1, target_eigenvalues);
  source_info->_Index = GetEigenmodes(source, k, m2, source_eigenvalues);
  // Match sign and order of eigenmodes
  //
  // The eigenmodes of the source are matched to those of the target only if
  // the source is transformed while the target is not; in all other cases the
  // eigenmodes of the target are matched to those of the source. The index of
  // the matched dataset is invalidated to have its point data updated below.
  if (source->Transformation() != NULL && target->Transformation() == NULL) {
    MatchEigenmodes(target->Points(), m1, target_eigenvalues,
                    source->Points(), m2, source_eigenvalues);
    source_info->_Index = -2; // indicate that eigenmodes need to be updated
  } else {
    MatchEigenmodes(source->Points(), m2, source_eigenvalues,
                    target->Points(), m1, target_eigenvalues);
    target_info->_Index = -2; // indicate that eigenmodes need to be updated
  }
  // Set eigenmodes as point data of datasets to be registered
  // (only needed for eigenmodes which were recomputed and/or matched)
  if (target_info->_Index < 0) {
    target_info->_Index = SetEigenmodes(target_polydata, m1, 0, k, "eigenmodes");
  }
  if (source_info->_Index < 0) {
    source_info->_Index = SetEigenmodes(source_polydata, m2, 0, k, "eigenmodes");
  }
  // Set rescaling parameters s.t. eigenmodes are normalized to range [-1 +1]
  // (same parameters are used for both datasets)
  double slope, intercept;
  vtkDataArray *target_eigenmodes = target->PointSet()->GetPointData()->GetArray(target_info->_Index);
  vtkDataArray *source_eigenmodes = source->PointSet()->GetPointData()->GetArray(source_info->_Index);
  GetNormalizationParametersOfEigenmodes(slope, intercept, target_eigenmodes, source_eigenmodes);
  target_info->_Slope     = source_info->_Slope     = slope;
  target_info->_Intercept = source_info->_Intercept = intercept;
}
450 
451 // -----------------------------------------------------------------------------
UpdateEigenmodes(const RegisteredPointSet * target,PointCorrespondence::FeatureList & target_features,const RegisteredPointSet * source,PointCorrespondence::FeatureList & source_features,Vector & eigenvalues,int k)452 void UpdateEigenmodes(const RegisteredPointSet         *target,
453                       PointCorrespondence::FeatureList &target_features,
454                       const RegisteredPointSet         *source,
455                       PointCorrespondence::FeatureList &source_features,
456                       Vector                           &eigenvalues,
457                       int                               k)
458 {
459   using namespace SpectralDecomposition;
460   typedef GenericSparseMatrix<double> SparseMatrix;
461   vtkDataArray *target_eigenmodes, *source_eigenmodes;
462   // Get feature entries
463   PointCorrespondence::FeatureList::iterator target_info;
464   for (target_info = target_features.begin(); target_info != target_features.end(); ++target_info) {
465     if (target_info->_Name == "eigenmodes") break;
466   }
467   if (target_info == target_features.end()) return;
468   PointCorrespondence::FeatureList::iterator source_info;
469   for (source_info = source_features.begin(); source_info != source_features.end(); ++source_info) {
470     if (source_info->_Name == "eigenmodes") break;
471   }
472   if (source_info == source_features.end()) return;
473   // Cast to vtkPolyData
474   vtkPolyData *target_polydata = vtkPolyData::SafeDownCast(target->PointSet());
475   vtkPolyData *source_polydata = vtkPolyData::SafeDownCast(source->PointSet());
476   if (!target_polydata || !source_polydata) {
477     cerr << "FiducialRegistrationError::UpdateEigenmodes: Point sets must be surface meshes" << endl;
478     exit(1);
479   }
480   // Compute initial spectral coordinates
481   if (ComputeEigenmodes(target_polydata, source_polydata, k) < k) {
482     cerr << "FiducialRegistrationError::UpdateEigenmodes: Failed to find " << k << " initial eigenmodes" << endl;
483     exit(1);
484   }
485   // Sample points for which inter-dataset links will be added
486   Array<int> target_sample, source_sample;
487   SamplePoints(target, target_sample, max(10, target->NumberOfPoints() / 10));
488   SamplePoints(source, source_sample, max(10, source->NumberOfPoints() / 10));
489   // Find initial point correspondences
490   PointLocator::FeatureList target_match_features;
491   PointLocator::FeatureList source_match_features;
492   PointLocator::FeatureInfo feature;
493   vtkPointData * const targetPD = target_polydata->GetPointData();
494   vtkPointData * const sourcePD = source_polydata->GetPointData();
495   // Use computed spectral coordinates for initial match
496   int target_index, source_index;
497   target_eigenmodes = targetPD->GetArray(feature._Name.c_str(), target_index);
498   source_eigenmodes = sourcePD->GetArray(feature._Name.c_str(), source_index);
499   GetNormalizationParametersOfEigenmodes(feature._Slope, feature._Intercept,
500                                          target_eigenmodes, source_eigenmodes);
501   feature._Name   = "eigenmodes";
502   feature._Weight = target_info->_Weight;
503   feature._Index  = target_index;
504   target_match_features.push_back(feature);
505   feature._Weight = source_info->_Weight;
506   feature._Index  = source_index;
507   source_match_features.push_back(feature);
508   // Optionally also use normals of 3D spectral points
509   feature._Name = "spectral normals";
510   for (size_t i = 0; i < target_features.size(); ++i) {
511     if (target_features[i]._Name == feature._Name) {
512       ComputeSpectralNormals(target_polydata);
513       feature._Weight    = target_features[i]._Weight;
514       feature._Slope     = target_features[i]._Slope;
515       feature._Intercept = target_features[i]._Intercept;
516       targetPD->GetArray(feature._Name.c_str(), feature._Index);
517       target_match_features.push_back(feature);
518       break;
519     }
520   }
521   for (size_t i = 0; i < source_features.size(); ++i) {
522     if (source_features[i]._Name == feature._Name) {
523       ComputeSpectralNormals(source_polydata);
524       sourcePD->GetArray(feature._Name.c_str(), feature._Index);
525       feature._Weight    = source_features[i]._Weight;
526       feature._Slope     = source_features[i]._Slope;
527       feature._Intercept = source_features[i]._Intercept;
528       source_match_features.push_back(feature);
529       break;
530     }
531   }
532   // Find nearest neighbors of point samples
533   Array<int>    corr12, corr21;
534   Array<double> dist12, dist21;
535   corr12 = PointLocator::FindClosestPoint(target->PointSet(), &target_sample, &target_match_features,
536                                           source->PointSet(), &source_sample, &source_match_features, &dist12);
537   corr21 = PointLocator::FindClosestPoint(source->PointSet(), &source_sample, &source_match_features,
538                                           target->PointSet(), &target_sample, &target_match_features, &dist21);
539   // Set intra-mesh affinity weights
540   const int m = target->NumberOfPoints();
541   const int n = source->NumberOfPoints();
542   SparseMatrix::Entries *cols = new SparseMatrix::Entries[m + n];
543   AdjacencyMatrix(cols, SparseMatrix::CCS, 0, 0, vtkPolyData::SafeDownCast(target->InputPointSet()));
544   AdjacencyMatrix(cols, SparseMatrix::CCS, m, m, vtkPolyData::SafeDownCast(source->InputPointSet()));
545   // Set inter-mesh affinity weights
546   const int ncorr12 = static_cast<int>(corr12.size());
547   for (int i = 0; i < ncorr12; ++i) {
548     dist12[i] = 1.0 / (sqrt(dist12[i]) + EPSILON);
549     const int r = corr12[i] + m;
550     const int c = PointCorrespondence::GetPointIndex(target, &target_sample, i);
551     cols[c].push_back(MakePair(r, dist12[i]));
552     cols[r].push_back(MakePair(c, dist12[i]));
553   }
554   const int ncorr21 = static_cast<int>(corr21.size());
555   for (int i = 0; i < ncorr21; ++i) {
556     dist21[i] = 1.0 / (sqrt(dist21[i]) + EPSILON);
557     const int r = corr21[i];
558     const int c = PointCorrespondence::GetPointIndex(source, &source_sample, i) + m;
559     cols[c].push_back(MakePair(r, dist21[i]));
560     cols[r].push_back(MakePair(c, dist21[i]));
561   }
562   // Compute graph Laplacian of joint connectivity graph
563   SparseMatrix L(SparseMatrix::CCS);
564   L.Initialize(m + n, m + n, cols);
565   NormalizedLaplacian(L, L);
566   delete[] cols;
567   // Compute eigenmodes of joint graph Laplacian
568   Matrix eigenmodes;
569   if (ComputeEigenmodes(L, k+2, eigenmodes, eigenvalues) < k) {
570     cerr << "FiducialRegistrationError::UpdateEigenmodes: Failed to compute " << k << " eigenmodes" << endl;
571     exit(1);
572   }
573   eigenvalues.Resize(k);
574   // Set eigenmodes as point data
575   target_info->_Index = SetEigenmodes(target_polydata, eigenmodes, 0, k, "eigenmodes");
576   source_info->_Index = SetEigenmodes(source_polydata, eigenmodes, m, k, "eigenmodes");
577   target_eigenmodes   = targetPD->GetArray(target_info->_Index);
578   source_eigenmodes   = sourcePD->GetArray(source_info->_Index);
579   // Calculate weight of eigenmodes, more importance to lower frequencies
580   double wsum = .0;
581   for (int c = 0; c < k; ++c) {
582     wsum += 1.0 / sqrt(eigenvalues(c));
583   }
584   Vector weight(k);
585   for (int c = k; c >= 0; --c) {
586     weight(c) = (1.0 / sqrt(eigenvalues(c))) / wsum;
587   }
588   // Set rescaling parameters s.t. eigenmodes are normalized to range [-1 +1]
589   double slope, intercept;
590   GetNormalizationParametersOfEigenmodes(slope, intercept, target_eigenmodes, source_eigenmodes, weight);
591   target_info->_Slope     = source_info->_Slope     = slope;
592   target_info->_Intercept = source_info->_Intercept = intercept;
593 }
594 
595 // -----------------------------------------------------------------------------
UpdateEigenmodesIfUsed(const RegisteredPointSet * target,Vector & target_eigenvalues,PointCorrespondence::FeatureList & target_features,const RegisteredPointSet * source,Vector & source_eigenvalues,PointCorrespondence::FeatureList & source_features,int k,bool diffeo)596 void UpdateEigenmodesIfUsed(const RegisteredPointSet         *target,
597                             Vector                           &target_eigenvalues,
598                             PointCorrespondence::FeatureList &target_features,
599                             const RegisteredPointSet         *source,
600                             Vector                           &source_eigenvalues,
601                             PointCorrespondence::FeatureList &source_features,
602                             int k, bool diffeo)
603 {
604   // Skip if no spectral coordinates used as features
605   if (k == 0) {
606     target_eigenvalues.Clear();
607     source_eigenvalues.Clear();
608     return;
609   }
610   size_t i;
611   for (i = 0; i < target_features.size(); ++i) {
612     if ((target_features[i]._Name == "eigenmodes" ||
613          target_features[i]._Name == "spectral normals")
614         && target_features[i]._Weight != .0) break;
615   }
616   if (i == target_features.size()) {
617     target_eigenvalues.Clear();
618     source_eigenvalues.Clear();
619     return;
620   }
621   // Compute eigenmodes
622   if (diffeo) {
623     UpdateEigenmodes(target, target_features,
624                      source, source_features, source_eigenvalues, k);
625     target_eigenvalues = source_eigenvalues;
626   } else {
627     UpdateEigenmodes(target, target_eigenvalues, target_features,
628                      source, source_eigenvalues, source_features, k);
629   }
630 }
631 
632 
633 } // namespace PointCorrespondenceUtils
634 using namespace PointCorrespondenceUtils;
635 
636 // =============================================================================
637 // Construction/Destruction
638 // =============================================================================
639 
640 // -----------------------------------------------------------------------------
/// Construct point correspondence map for the given pair of point sets.
///
/// Only the given pointers are stored. No point samples are set, the feature
/// lists are empty, 5 spectral coordinates are used by default, the
/// diffeomorphic spectral decomposition is enabled, spectral points are not
/// updated, and correspondences are enabled in both directions with
/// TargetToSource as default direction.
///
/// \param[in] target Target point set.
/// \param[in] source Source point set.
PointCorrespondence::PointCorrespondence(const RegisteredPointSet *target,
                                         const RegisteredPointSet *source)
:
  _M(0), _N(0), _NumberOfFeatures(0),
  _Target(target), _TargetSample(NULL), _TargetFeatures(),
  _Source(source), _SourceSample(NULL), _SourceFeatures(),
  _DimensionOfSpectralPoints(5),
  _DiffeomorphicSpectralDecomposition(true),
  _UpdateSpectralPoints(false),
  _FromTargetToSource(true),
  _FromSourceToTarget(true),
  _DefaultDirection(TargetToSource)
{
}
655 
656 // -----------------------------------------------------------------------------
/// Copy construct point correspondence map by member-wise copy.
///
/// The point set and point sample members are copied as stored by \p other
/// (these are initialized from pointers/NULL by the other constructor, i.e.,
/// presumably a shallow copy of non-owned objects -- TODO confirm ownership).
///
/// \param[in] other Point correspondence map to copy.
PointCorrespondence::PointCorrespondence(const PointCorrespondence &other)
:
  _M(other._M), _N(other._N), _NumberOfFeatures(other._NumberOfFeatures),
  _Target(other._Target), _TargetSample(other._TargetSample), _TargetFeatures(other._TargetFeatures),
  _Source(other._Source), _SourceSample(other._SourceSample), _SourceFeatures(other._SourceFeatures),
  _DimensionOfSpectralPoints(other._DimensionOfSpectralPoints),
  _DiffeomorphicSpectralDecomposition(other._DiffeomorphicSpectralDecomposition),
  _UpdateSpectralPoints (other._UpdateSpectralPoints),
  _FromTargetToSource(other._FromTargetToSource),
  _FromSourceToTarget(other._FromSourceToTarget),
  _DefaultDirection(other._DefaultDirection)
{
}
670 
671 // -----------------------------------------------------------------------------
~PointCorrespondence()672 PointCorrespondence::~PointCorrespondence()
673 {
674 }
675 
676 // =============================================================================
677 // Parameters
678 // =============================================================================
679 
680 // -----------------------------------------------------------------------------
GetPointDataIndexByCaseInsensitiveName(vtkPointData * pd,const string & name)681 int PointCorrespondence::GetPointDataIndexByCaseInsensitiveName(vtkPointData *pd, const string &name)
682 {
683   string lname = ToLower(name);
684   for (int i = 0; i < pd->GetNumberOfArrays(); ++i) {
685     const char *array_name = pd->GetArrayName(i);
686     if (array_name == NULL) continue;
687     string lower_name = ToLower(array_name);
688     if (lower_name == lname) return i;
689   }
690   return -1;
691 }
692 
693 // -----------------------------------------------------------------------------
Set(const char * param,const char * value)694 bool PointCorrespondence::Set(const char *param, const char *value)
695 {
696   if (strcmp(param, "No. of spectral coordinates")    == 0 ||
697       strcmp(param, "Number of spectral coordinates") == 0 ||
698       strcmp(param, "Dimension of spectral points")   == 0 ||
699       strcmp(param, "Spectral coordinates number")    == 0) {
700     return FromString(value, _DimensionOfSpectralPoints);
701   }
702   if (strcmp(param, "Diffeomorphic spectral matching")      == 0 ||
703       strcmp(param, "Diffeomorphic spectral coordinates")   == 0 ||
704       strcmp(param, "Diffeomorphic spectral decomposition") == 0) {
705     return FromString(value, _DiffeomorphicSpectralDecomposition);
706   }
707   if (strcmp(param, "Spectral coordinates update") == 0) {
708     return FromString(value, _UpdateSpectralPoints);
709   }
710 
711   size_t len = strlen(param);
712   if (len > 7 && strcmp(param + len - 7, " weight") == 0) {
713     double weight;
714     if (!FromString(value, weight) || weight < .0) return false;
715     string feature_name(param, param + len - 7);
716     AddFeature(feature_name.c_str(), weight);
717     return true;
718   }
719 
720   return false;
721 }
722 
723 // -----------------------------------------------------------------------------
Parameter() const724 ParameterList PointCorrespondence::Parameter() const
725 {
726   ParameterList params = Observable::Parameter();
727   Insert(params, "Diffeomorphic spectral decompositon", ToString(_DiffeomorphicSpectralDecomposition));
728   Insert(params, "No. of spectral coordinates", ToString(_DimensionOfSpectralPoints));
729   Insert(params, "Spectral coordinates update", ToString(_UpdateSpectralPoints));
730   string name;
731   for (size_t i = 0; i < _TargetFeatures.size(); ++i) {
732     name = _TargetFeatures[i]._Name;
733     if (name.empty()) continue;
734     if (name == "eigenmodes") name = "spectral coordinates";
735     name += " weight";
736     Insert(params, name, ToString(_TargetFeatures[i]._Slope));
737   }
738   return params;
739 }
740 
741 // -----------------------------------------------------------------------------
TransformFeatureName(const char * name)742 static string TransformFeatureName(const char *name)
743 {
744   string feature_name = ToLower(name);
745   if (feature_name == "spatial normals") {
746     feature_name = "normals";
747   } else if (feature_name == "spectral coordinates" ||
748              feature_name == "spectral points") {
749     feature_name = "eigenmodes";
750   } else if (feature_name == "spatial coordinates" ||
751              feature_name == "spatial point"       ||
752              feature_name == "spatial points"      ||
753              feature_name == "points") {
754     feature_name = "spatial coordinates";
755   }
756   return feature_name;
757 }
758 
759 // -----------------------------------------------------------------------------
AddFeature(const char * name,double weight,double slope,double intercept)760 bool PointCorrespondence::AddFeature(const char *name, double weight, double slope, double intercept)
761 {
762   const string feature_name = TransformFeatureName(name);
763   int index = -2; // < -1: determined after data is set by CompleteFeatureInfo
764   if (feature_name == "spatial coordinates") index = -1;
765   _TargetFeatures.push_back(FeatureInfo(feature_name, index, weight, slope, intercept));
766   _SourceFeatures.push_back(FeatureInfo(feature_name, index, weight, slope, intercept));
767   return true;
768 }
769 
770 // -----------------------------------------------------------------------------
RemoveFeature(const char * name)771 void PointCorrespondence::RemoveFeature(const char *name)
772 {
773   const string feature_name = TransformFeatureName(name);
774   for (FeatureList::iterator i = _TargetFeatures.begin(); i != _TargetFeatures.end(); ++i) {
775     if (i->_Name == feature_name) {
776       FeatureList::iterator pos = i; --i;
777       _TargetFeatures.erase(pos);
778     }
779   }
780   for (FeatureList::iterator i = _SourceFeatures.begin(); i != _SourceFeatures.end(); ++i) {
781     if (i->_Name == feature_name) {
782       FeatureList::iterator pos = i; --i;
783       _SourceFeatures.erase(pos);
784     }
785   }
786 }
787 
788 // -----------------------------------------------------------------------------
CompleteFeatureInfo(const RegisteredPointSet * input,FeatureList & feature)789 void PointCorrespondence::CompleteFeatureInfo(const RegisteredPointSet *input, FeatureList &feature)
790 {
791   vtkPointData * const inputPD = input->PointSet()->GetPointData();
792   for (FeatureList::iterator i = feature.begin(); i != feature.end(); ++i) {
793     if (i->_Index < -1) {
794       if (i->_Name.empty()) {
795         cerr << "PointCorrespondence::CompleteFeatureInfo: Encountered feature without name" << endl;
796         exit(1);
797       }
798       i->_Index = GetPointDataIndexByCaseInsensitiveName(inputPD, i->_Name.c_str());
799       // Spectral features are computed and added by UpdateEigenmodesIfUsed when missing
800       if (i->_Index < 0 && i->_Name != "eigenmodes" && i->_Name != "spectral normals") {
801         cerr << "PointCorrespondence::CompleteFeatureInfo: Missing point data " << i->_Name << endl;
802         exit(1);
803       }
804     } else if (i->_Index >= inputPD->GetNumberOfArrays()) {
805       cerr << "PointCorrespondence::CompleteFeatureInfo: Feature index is out of bounds" << endl;
806       exit(1);
807     }
808   }
809 }
810 
811 // =============================================================================
812 // Correspondences
813 // =============================================================================
814 
815 // -----------------------------------------------------------------------------
Initialize()816 void PointCorrespondence::Initialize()
817 {
818   // Check that all inputs are set
819   if (_Target == NULL) {
820     cerr << "PointCorrespondence::Initialize: Target data set not set" << endl;
821     exit(1);
822   }
823   if (_Source == NULL) {
824     cerr << "PointCorrespondence::Initialize: Source data set not set" << endl;
825     exit(1);
826   }
827 
828   // Ensure that correspondences in at least one direction are requested
829   if (!_FromTargetToSource && !_FromSourceToTarget) {
830     cerr << "PointCorrespondence::Initialize: At least one direction,"
831             " either FromTargetToSource or FromSourceToTarget must be enabled" << endl;
832     exit(1);
833   }
834 
835   // Fill missing feature info which requires inspection of the input dataset
836   CompleteFeatureInfo(_Target, _TargetFeatures);
837   CompleteFeatureInfo(_Source, _SourceFeatures);
838 
839   // Determine size of feature vectors
840   _NumberOfFeatures = GetPointDimension(_Target, &_TargetFeatures);
841   if (GetPointDimension(_Source, &_SourceFeatures) != _NumberOfFeatures) {
842     cerr << "PointCorrespondence::Initialize: Mismatching feature vector size" << endl;
843     exit(1);
844   }
845 
846   // Initialize this class
847   PointCorrespondence::Init();
848 }
849 
850 // -----------------------------------------------------------------------------
Reinitialize()851 void PointCorrespondence::Reinitialize()
852 {
853   // Reinitialize this class
854   PointCorrespondence::Init();
855 }
856 
857 // -----------------------------------------------------------------------------
Init()858 void PointCorrespondence::Init()
859 {
860   // Determine number of points (samples)
861   _M = GetNumberOfPoints(_Target, _TargetSample);
862   _N = GetNumberOfPoints(_Source, _SourceSample);
863   if (_M == 0) {
864     cerr << "PointCorrespondence::Initialize: Target data set has no points!" << endl;
865     exit(1);
866   }
867   if (_N == 0) {
868     cerr << "PointCorrespondence::Initialize: Source data set has no points!" << endl;
869     exit(1);
870   }
871 
872   // Compute spectral coordinates if needed
873   UpdateEigenmodesIfUsed(_Target, _TargetEigenvalues, _TargetFeatures,
874                          _Source, _SourceEigenvalues, _SourceFeatures,
875                          _DimensionOfSpectralPoints, _DiffeomorphicSpectralDecomposition);
876 }
877 
878 // -----------------------------------------------------------------------------
Update()879 void PointCorrespondence::Update()
880 {
881   // Recompute spectral coordinates (optional/experimental)
882   if (_UpdateSpectralPoints) {
883     UpdateEigenmodesIfUsed(_Target, _TargetEigenvalues, _TargetFeatures,
884                            _Source, _SourceEigenvalues, _SourceFeatures,
885                            _DimensionOfSpectralPoints, _DiffeomorphicSpectralDecomposition);
886   }
887 }
888 
889 // -----------------------------------------------------------------------------
Upgrade()890 bool PointCorrespondence::Upgrade()
891 {
892   return false;
893 }
894 
895 // -----------------------------------------------------------------------------
GetTargetIndex(int) const896 int PointCorrespondence::GetTargetIndex(int) const
897 {
898   return -1;
899 }
900 
901 // -----------------------------------------------------------------------------
GetSourceIndex(int) const902 int PointCorrespondence::GetSourceIndex(int) const
903 {
904   return -1;
905 }
906 
907 // =============================================================================
908 // Debugging
909 // =============================================================================
910 
911 // -----------------------------------------------------------------------------
WriteDataSets(const char * prefix,const char * suffix,bool all) const912 void PointCorrespondence::WriteDataSets(const char *prefix, const char *suffix, bool all) const
913 {
914   if (all || debug >= 4) {
915     const int sz = 1024;
916     char      fname[sz];
917     snprintf(fname, sz, "%sspectral_target_points%s.vtp", prefix, suffix);
918     this->WriteSpectralPoints(fname, _Target->PointSet());
919     snprintf(fname, sz, "%sspectral_source_points%s.vtp", prefix, suffix);
920     this->WriteSpectralPoints(fname, _Source->PointSet());
921   }
922 }
923 
924 // -----------------------------------------------------------------------------
WriteSpectralPoints(const char * fname,vtkPointSet * d) const925 void PointCorrespondence::WriteSpectralPoints(const char *fname, vtkPointSet *d) const
926 {
927   vtkDataArray *modes = d->GetPointData()->GetArray("eigenmodes");
928   if (!modes) return;
929 
930   vtkSmartPointer<vtkPoints> points = vtkSmartPointer<vtkPoints>::New();
931   points->SetNumberOfPoints(d->GetNumberOfPoints());
932 
933   double *p = new double[max(3, int(modes->GetNumberOfComponents()))];
934   for (int c = 0; c < 3; ++c) p[c] = .0;
935   for (vtkIdType i = 0; i < points->GetNumberOfPoints(); ++i) {
936     modes ->GetTuple(i, p);
937     points->SetPoint(i, p);
938   }
939   delete[] p;
940 
941   vtkSmartPointer<vtkPointSet> output;
942   output.TakeReference(d->NewInstance());
943   output->ShallowCopy(d);
944   output->SetPoints(points);
945   WritePointSet(fname, output);
946 }
947 
948 
949 } // namespace mirtk
950