1 /*
2  * Medical Image Registration ToolKit (MIRTK)
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "mirtk/Common.h"
18 #include "mirtk/Options.h"
19 
20 #include "mirtk/PointCorrespondence.h"
21 #include "mirtk/PointSetIO.h"
22 #include "mirtk/RegisteredPointSet.h"
23 
24 #include "mirtk/Transformation.h"
25 #include "mirtk/TransformationModel.h"
26 #include "mirtk/HomogeneousTransformation.h"
27 #include "mirtk/MultiLevelTransformation.h"
28 
29 using namespace mirtk;
30 
31 
// TODO: Also support non-rigid transformations. In the case of an FFD, especially support an MFFD,
//       which implements multi-level scattered data approximation with cubic B-splines.
34 
35 
36 // =============================================================================
37 // Help
38 // =============================================================================
39 
40 // -----------------------------------------------------------------------------
PrintHelp(const char * name)41 void PrintHelp(const char *name)
42 {
43   cout << "\n";
44   cout << "Usage: " << name << " [options]\n";
45   cout << "\n";
46   cout << "Description:\n";
47   cout << "  Register point sets or surfaces by iteratively approximating the residual target registration error\n";
48   cout << "  given current point correspondences and transformation estimate. By default, correspondences are defined\n";
49   cout << "  by closest points. The default settings implement the iterative closest points (ICP) algorithm. When the\n";
50   cout << "  input point sets are surface meshes, :option:`-closest-surface-points` can be used as correspondences\n";
51   cout << "  instead of discrete points in order to increase the accuracy of the surface registration at the expense\n";
52   cout << "  of a slightly more costly point correspondence update at each iteration.\n";
53   cout << "\n";
54   cout << "Required options:\n";
55   cout << "  -t, -target <path>...\n";
56   cout << "    One or more target point sets or surfaces. If multiple targets are specified,\n";
57   cout << "    the number of target point sets must match the number of :option:`-source` point sets.\n";
58   cout << "    In this case, the respective pairs of target and source point sets are registered\n";
59   cout << "    to one another using the same target to source transformation. This option can\n";
60   cout << "    be given multiple times to append target point sets.\n";
61   cout << "  -s, -source <path>...\n";
62   cout << "    One or more source point sets or surfaces. If a single :option:`-target` point set is given,\n";
63   cout << "    all source point sets are registered to this common target point set. Otherwise,\n";
64   cout << "    the number of target and source point sets must be the same. This option can be\n";
65   cout << "    can be given multiple times to append source point sets.\n";
66   cout << "  -o, -dofout <path>\n";
67   cout << "    Output transformation file (.dof, .dof.gz).\n";
68   cout << "\n";
69   cout << "Optional arguments:\n";
70   cout << "  -i, -dofin <path>|Id|identity\n";
71   cout << "    File path of input transformation from which initial -model parameters are derived.\n";
72   cout << "  -m, -model Rigid|Similarity|Affine\n";
73   cout << "    Transformation model to use. (default: Rigid)\n";
74   cout << "  -n, -iterations <n>\n";
75   cout << "    Number of iterations. (default: 100)\n";
76   cout << "  -c, -cor, -corr, -correspondence <name>\n";
77   cout << "    Name of correspondence type, e.g., 'closest point', 'closest cell'. (default: closest point)\n";
78   cout << "  -cp, -closest-point\n";
79   cout << "    Alias for :option:`-correspondence 'closest point'`.\n";
80   cout << "  -csp, -closest-surface-point, -closest-cell\n";
81   cout << "    Alias for :option:`-correspondence 'closest surface point'`.\n";
82   cout << "  -p, -par, -corpar, -corrpar <name> <value>\n";
83   cout << "    Set parameter of chosen correspondence type.\n";
84   cout << "  -f, -feature <name> [<weight>]\n";
85   cout << "    Point or cell features based on which correspondences are determined.\n";
86   cout << "    By default, the spatial 3D coordinates of the input points are used.\n";
87   cout << "  -[no]symmetric [on|off]\n";
88   cout << "    Approximate symmetric registration error. When this flag is set, the distance\n";
89   cout << "    from every source point to its corresponding target point is considered in addition\n";
90   cout << "    to the distance of each target point to its corresponding source point. (default: off)\n";
91   cout << "  -[no]inverse [on|off]\n";
92   cout << "    By default, the output transformation is applied to the target points.\n";
93   cout << "    When this flag is set, the output transformation is applied to the source\n";
94   cout << "    points instead. (default: off)\n";
95   PrintStandardOptions(cout);
96   cout << "\n";
97   cout.flush();
98 }
99 
100 // =============================================================================
101 // Auxiliary functions
102 // =============================================================================
103 
104 // -----------------------------------------------------------------------------
TargetIndex(int m,int n,int i)105 inline int TargetIndex(int m, int n, int i)
106 {
107   return n == 1 ? 0 : i;
108 }
109 
110 // -----------------------------------------------------------------------------
Update(Array<UniquePtr<PointCorrespondence>> & cmaps)111 void Update(Array<UniquePtr<PointCorrespondence>> &cmaps)
112 {
113   for (auto &cmap : cmaps) {
114     cmap->Update();
115   }
116 }
117 
118 // -----------------------------------------------------------------------------
Update(Array<RegisteredPointSet> & psets)119 void Update(Array<RegisteredPointSet> &psets)
120 {
121   for (auto &pset : psets) {
122     pset.Update();
123   }
124 }
125 
126 // -----------------------------------------------------------------------------
Update(Array<RegisteredPointSet> & targets,Array<RegisteredPointSet> & sources,Array<UniquePtr<PointCorrespondence>> & cmaps)127 void Update(Array<RegisteredPointSet> &targets,
128             Array<RegisteredPointSet> &sources,
129             Array<UniquePtr<PointCorrespondence>> &cmaps)
130 {
131   Update(targets);
132   Update(sources);
133   Update(cmaps);
134 }
135 
136 // -----------------------------------------------------------------------------
EvaluateRMSError(const Array<RegisteredPointSet> & targets,const Array<RegisteredPointSet> & sources,const Array<UniquePtr<PointCorrespondence>> & cmaps,bool symmetric=false)137 double EvaluateRMSError(const Array<RegisteredPointSet> &targets,
138                         const Array<RegisteredPointSet> &sources,
139                         const Array<UniquePtr<PointCorrespondence>> &cmaps,
140                         bool symmetric = false)
141 {
142   Point p, q;
143 
144   const int m = static_cast<int>(sources.size());
145   const int n = static_cast<int>(targets.size());
146 
147   double error = 0;
148   int count = 0;
149 
150   for (int i = 0; i < m; ++i) {
151     const int j = TargetIndex(m, n, i);
152     auto &cmap = cmaps[i];
153     auto &source = sources[i];
154     auto &target = targets[j];
155     for (int t = 0; t < target.NumberOfPoints(); ++t) {
156       target.GetPoint(t, p);
157       if (cmap->GetSourcePoint(t, q)) {
158         error += pow(q._x - p._x, 2) + pow(q._y - p._y, 2) + pow(q._z - p._z, 2);
159         ++count;
160       }
161     }
162     if (symmetric) {
163       for (int s = 0; s < source.NumberOfPoints(); ++s) {
164         source.GetInputPoint(s, q);
165         if (cmap->GetInputTargetPoint(s, p)) {
166           error += pow(q._x - p._x, 2) + pow(q._y - p._y, 2) + pow(q._z - p._z, 2);
167           ++count;
168         }
169       }
170     }
171   }
172 
173   if (count == 0) FatalError("No corresponding points found");
174   return sqrt(error / count);
175 }
176 
177 // -----------------------------------------------------------------------------
Fit(Transformation * dof,const Array<RegisteredPointSet> & targets,const Array<RegisteredPointSet> & sources,const Array<UniquePtr<PointCorrespondence>> & cmaps,bool symmetric=false,bool inverse=false)178 void Fit(Transformation *dof,
179          const Array<RegisteredPointSet> &targets,
180          const Array<RegisteredPointSet> &sources,
181          const Array<UniquePtr<PointCorrespondence>> &cmaps,
182          bool symmetric = false, bool inverse = false)
183 {
184   const int m = static_cast<int>(sources.size());
185   const int n = static_cast<int>(targets.size());
186 
187   int no = 0;
188   for (int i = 0; i < m; ++i) {
189     const int j = TargetIndex(m, n, i);
190     no += targets[j].NumberOfPoints();
191     if (symmetric) no += sources[i].NumberOfPoints();
192   }
193 
194   Array<double> x(no), y(no), z(no), dx(no), dy(no), dz(no);
195   Point p, q;
196 
197   int k = 0;
198   for (int i = 0; i < m; ++i) {
199     const int j = TargetIndex(m, n, i);
200     auto &cmap = cmaps[i];
201     auto &source = sources[i];
202     auto &target = targets[j];
203     for (int t = 0; t < target.NumberOfPoints(); ++t) {
204       target.GetInputPoint(t, p);
205       if (cmap->GetInputSourcePoint(t, q)) {
206         if (inverse) swap(p, q);
207         x[k] = p._x;
208         y[k] = p._y;
209         z[k] = p._z;
210         dx[k] = q._x - p._x;
211         dy[k] = q._y - p._y;
212         dz[k] = q._z - p._z;
213         ++k;
214       }
215     }
216     if (symmetric) {
217       for (int s = 0; s < source.NumberOfPoints(); ++s) {
218         source.GetInputPoint(s, q);
219         if (cmap->GetInputTargetPoint(s, p)) {
220           if (inverse) swap(p, q);
221           x[k] = p._x;
222           y[k] = p._y;
223           z[k] = p._z;
224           dx[k] = q._x - p._x;
225           dy[k] = q._y - p._y;
226           dz[k] = q._z - p._z;
227           ++k;
228         }
229       }
230     }
231   }
232   dof->ApproximateAsNew(x.data(), y.data(), z.data(), dx.data(), dy.data(), dz.data(), k);
233 }
234 
235 // -----------------------------------------------------------------------------
PrintProgress(ostream & os,int i,double error,bool flush=true)236 void PrintProgress(ostream &os, int i, double error, bool flush = true)
237 {
238   const streamsize w = os.width(0);
239   const streamsize p = os.precision(5);
240   const ios::fmtflags f = os.flags();
241 
242   os << setw(3) << i << ". RMS error = " << setprecision(5) << error << "\n";
243   if (flush) os.flush();
244 
245   os.width(w);
246   os.precision(p);
247   os.flags(f);
248 }
249 
250 // =============================================================================
251 // Main
252 // =============================================================================
253 
254 // -----------------------------------------------------------------------------
main(int argc,char * argv[])255 int main(int argc, char *argv[])
256 {
257   EXPECTS_POSARGS(0);
258 
259   Array<string> target_names;
260   Array<string> source_names;
261 
262   TransformationModel model = TM_Rigid;
263   string dofin_name;
264   string dofout_name;
265 
266   PointCorrespondence::TypeId ctype = PointCorrespondence::ClosestPoint;
267   Array<string> feature_name;
268   Array<double> feature_weight;
269   ParameterList param;
270 
271   int iterations = 100;
272   double epsilon = 0.01;
273   bool inverse = false;
274   bool symmetric = false;
275 
276   verbose = 1;
277 
278   for (ALL_OPTIONS) {
279     if (OPTION("-t") || OPTION("-target")) {
280       do {
281         target_names.push_back(ARGUMENT);
282       } while (HAS_ARGUMENT);
283     }
284     else if (OPTION("-s") || OPTION("-source")) {
285       do {
286         source_names.push_back(ARGUMENT);
287       } while (HAS_ARGUMENT);
288     }
289     else if (OPTION("-i") || OPTION("-dofin")) {
290       dofin_name = ARGUMENT;
291     }
292     else if (OPTION("-o") || OPTION("-dofout")) {
293       dofout_name = ARGUMENT;
294     }
295     else if (OPTION("-m") || OPTION("-model")) {
296       const char *arg = ARGUMENT;
297       if (!FromString(arg, model)) FatalError("Invalid -model argument: " << arg);
298     }
299     else if (OPTION("-n") || OPTION("-iterations")) {
300       PARSE_ARGUMENT(iterations);
301     }
302     else if (OPTION("-c") || OPTION("-cor") || OPTION("-corr") || OPTION("-correspondence")) {
303       const char *arg = ARGUMENT;
304       if (!FromString(arg, ctype)) FatalError("Invalid -correspondence argument: " << arg);
305     }
306     else if (OPTION("-p") || OPTION("-par") || OPTION("-corpar") || OPTION("-corrpar")) {
307       Insert(param, ARGUMENT, ARGUMENT);
308     }
309     else if (OPTION("-f") || OPTION("-feature")) {
310       feature_name.push_back(ARGUMENT);
311       if (HAS_ARGUMENT) feature_weight.push_back(atof(ARGUMENT));
312       else              feature_weight.push_back(1.0);
313     }
314     else if (OPTION("-cp") || OPTION("-closest-point")) {
315       ctype = PointCorrespondence::ClosestPoint;
316     }
317     else if (OPTION("-csp") || OPTION("-closest-surface-point") || OPTION("-closest-cell")) {
318       ctype = PointCorrespondence::ClosestCell;
319     }
320     else if (OPTION("-epsilon")) {
321       PARSE_ARGUMENT(epsilon);
322     }
323     else HANDLE_BOOL_OPTION(inverse);
324     else HANDLE_BOOL_OPTION(symmetric);
325     else HANDLE_COMMON_OR_UNKNOWN_OPTION();
326   }
327 
328   const int m = static_cast<int>(source_names.size());
329   const int n = static_cast<int>(target_names.size());
330 
331   // Check required arguments
332   if (dofout_name.empty()) {
333     FatalError("Option -dofout is required!");
334   }
335   if (m == 0 || n == 0) {
336     FatalError("Options -target and -source are required!");
337   }
338   if (n > 1 && n != m) {
339     FatalError("Either specify a single -target or one target for each -source point set!");
340   }
341   if (!IsLinear(model)) {
342     FatalError("Currently only Rigid, Similarity, and Affine transformation -model supported!");
343   }
344 
345   // By default, use point coordinates as featurs for point matching
346   if (feature_name.empty()) {
347     feature_name.push_back("spatial coordinates");
348     feature_weight.push_back(1.0);
349   }
350 
351   // Initialize transformation
352   UniquePtr<Transformation> dof(Transformation::New(ToTransformationType(model)));
353   if (!dofin_name.empty() && dofin_name != "Id" && dofin_name != "Identity" && dofin_name != "identity") {
354     if (verbose) cout << "Reading transformation...", cout.flush();
355     UniquePtr<Transformation> dofin(Transformation::New(dofin_name.c_str()));
356     const MultiLevelTransformation *mffd = dynamic_cast<const MultiLevelTransformation *>(dofin.get());
357     const HomogeneousTransformation *ilin = nullptr;
358     if (mffd) {
359       ilin = mffd->GetGlobalTransformation();
360     } else {
361       ilin = dynamic_cast<const HomogeneousTransformation *>(dofin.get());
362       if (!ilin) {
363         FatalError("Input -dofin must be a linear or multi-level transformation");
364       }
365     }
366     HomogeneousTransformation *lin = dynamic_cast<HomogeneousTransformation *>(dof.get());
367     mirtkAssert(lin != nullptr, "expected transformation for model=" << model << " to be of type HomogeneousTransformation");
368     lin->CopyFrom(ilin);
369     if (verbose) cout << " done" << endl;
370   }
371 
372   // Initialize point sets
373   if (verbose) cout << "Initialize point sets...", cout.flush();
374   Array<RegisteredPointSet> sources(m);
375   Array<RegisteredPointSet> targets(n);
376   for (int i = 0; i < m; ++i) {
377     auto pset = ReadPointSet(source_names[i].c_str());
378     if (pset->GetNumberOfPoints() == 0) {
379       FatalError("Failed to open source point set or point set contains no points: " << source_names[i]);
380     }
381     sources[i].InputPointSet(pset);
382     if (inverse) {
383       sources[i].Transformation(dof.get());
384     }
385     sources[i].Initialize();
386   }
387   for (int j = 0; j < n; ++j) {
388     auto pset = ReadPointSet(target_names[j].c_str());
389     if (pset->GetNumberOfPoints() == 0) {
390       FatalError("Failed to open target point set or point set contains no points: " << target_names[j]);
391     }
392     targets[j].InputPointSet(pset);
393     if (!inverse) {
394       targets[j].Transformation(dof.get());
395     }
396     targets[j].Initialize();
397   }
398   if (verbose) cout << " done" << endl;
399 
400   // Initialize correspondence maps
401   if (verbose) cout << "Initialize correspondence maps...", cout.flush();
402   Array<UniquePtr<PointCorrespondence>> cmaps(sources.size());
403   for (int i = 0; i < m; ++i) {
404     const int j = TargetIndex(m, n, i);
405     auto &cmap = cmaps[i];
406     cmap.reset(PointCorrespondence::New(ctype));
407     cmap->FromTargetToSource(true);
408     cmap->FromSourceToTarget(symmetric);
409     cmap->Parameter(param);
410     cmap->Source(&sources[i]);
411     cmap->Target(&targets[j]);
412     for (size_t f = 0; f < feature_name.size(); ++f) {
413       cmap->AddFeature(feature_name[f].c_str(), feature_weight[f]);
414     }
415     cmap->Initialize();
416   }
417   if (verbose) cout << " done" << endl;
418 
419   // Iterate least squares fitting
420   double error, last_error = numeric_limits<double>::infinity();
421   for (int iter = 0; true; ++iter) {
422     // Update correspondences
423     Update(targets, sources, cmaps);
424     // Check for convergence
425     error = EvaluateRMSError(targets, sources, cmaps);
426     if (verbose) PrintProgress(cout, iter, error);
427     if (last_error - error < epsilon) {
428       if (verbose) {
429         cout << "Converged after " << iter << " iterations." << endl;
430       }
431       break;
432     }
433     last_error = error;
434     if (iter >= iterations) {
435       if (verbose) {
436         cout << "Terminated after " << iter << " iterations." << endl;
437       }
438       break;
439     }
440     // Update transformation
441     Fit(dof.get(), targets, sources, cmaps, symmetric, inverse);
442   }
443 
444   // Write resulting transformation
445   dof->Write(dofout_name.c_str());
446 
447   return 0;
448 }
449