1 /*
2 * Medical Image Registration ToolKit (MIRTK)
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "mirtk/Common.h"
18 #include "mirtk/Options.h"
19
20 #include "mirtk/PointCorrespondence.h"
21 #include "mirtk/PointSetIO.h"
22 #include "mirtk/RegisteredPointSet.h"
23
24 #include "mirtk/Transformation.h"
25 #include "mirtk/TransformationModel.h"
26 #include "mirtk/HomogeneousTransformation.h"
27 #include "mirtk/MultiLevelTransformation.h"
28
29 using namespace mirtk;
30
31
// TODO: Support non-rigid transformations as well. For a FFD, in particular also support an MFFD,
// which implements multi-level scattered data approximation with cubic B-splines.
34
35
36 // =============================================================================
37 // Help
38 // =============================================================================
39
40 // -----------------------------------------------------------------------------
PrintHelp(const char * name)41 void PrintHelp(const char *name)
42 {
43 cout << "\n";
44 cout << "Usage: " << name << " [options]\n";
45 cout << "\n";
46 cout << "Description:\n";
47 cout << " Register point sets or surfaces by iteratively approximating the residual target registration error\n";
48 cout << " given current point correspondences and transformation estimate. By default, correspondences are defined\n";
49 cout << " by closest points. The default settings implement the iterative closest points (ICP) algorithm. When the\n";
50 cout << " input point sets are surface meshes, :option:`-closest-surface-points` can be used as correspondences\n";
51 cout << " instead of discrete points in order to increase the accuracy of the surface registration at the expense\n";
52 cout << " of a slightly more costly point correspondence update at each iteration.\n";
53 cout << "\n";
54 cout << "Required options:\n";
55 cout << " -t, -target <path>...\n";
56 cout << " One or more target point sets or surfaces. If multiple targets are specified,\n";
57 cout << " the number of target point sets must match the number of :option:`-source` point sets.\n";
58 cout << " In this case, the respective pairs of target and source point sets are registered\n";
59 cout << " to one another using the same target to source transformation. This option can\n";
60 cout << " be given multiple times to append target point sets.\n";
61 cout << " -s, -source <path>...\n";
62 cout << " One or more source point sets or surfaces. If a single :option:`-target` point set is given,\n";
63 cout << " all source point sets are registered to this common target point set. Otherwise,\n";
64 cout << " the number of target and source point sets must be the same. This option can be\n";
65 cout << " can be given multiple times to append source point sets.\n";
66 cout << " -o, -dofout <path>\n";
67 cout << " Output transformation file (.dof, .dof.gz).\n";
68 cout << "\n";
69 cout << "Optional arguments:\n";
70 cout << " -i, -dofin <path>|Id|identity\n";
71 cout << " File path of input transformation from which initial -model parameters are derived.\n";
72 cout << " -m, -model Rigid|Similarity|Affine\n";
73 cout << " Transformation model to use. (default: Rigid)\n";
74 cout << " -n, -iterations <n>\n";
75 cout << " Number of iterations. (default: 100)\n";
76 cout << " -c, -cor, -corr, -correspondence <name>\n";
77 cout << " Name of correspondence type, e.g., 'closest point', 'closest cell'. (default: closest point)\n";
78 cout << " -cp, -closest-point\n";
79 cout << " Alias for :option:`-correspondence 'closest point'`.\n";
80 cout << " -csp, -closest-surface-point, -closest-cell\n";
81 cout << " Alias for :option:`-correspondence 'closest surface point'`.\n";
82 cout << " -p, -par, -corpar, -corrpar <name> <value>\n";
83 cout << " Set parameter of chosen correspondence type.\n";
84 cout << " -f, -feature <name> [<weight>]\n";
85 cout << " Point or cell features based on which correspondences are determined.\n";
86 cout << " By default, the spatial 3D coordinates of the input points are used.\n";
87 cout << " -[no]symmetric [on|off]\n";
88 cout << " Approximate symmetric registration error. When this flag is set, the distance\n";
89 cout << " from every source point to its corresponding target point is considered in addition\n";
90 cout << " to the distance of each target point to its corresponding source point. (default: off)\n";
91 cout << " -[no]inverse [on|off]\n";
92 cout << " By default, the output transformation is applied to the target points.\n";
93 cout << " When this flag is set, the output transformation is applied to the source\n";
94 cout << " points instead. (default: off)\n";
95 PrintStandardOptions(cout);
96 cout << "\n";
97 cout.flush();
98 }
99
100 // =============================================================================
101 // Auxiliary functions
102 // =============================================================================
103
104 // -----------------------------------------------------------------------------
/// Index of the target point set paired with the i-th source point set.
///
/// When only a single target point set is given (n == 1), it is shared by all
/// sources; otherwise the i-th source is paired with the i-th target.
///
/// \param[in] m Number of source point sets (not needed to compute the index).
/// \param[in] n Number of target point sets.
/// \param[in] i Index of the source point set.
inline int TargetIndex(int m, int n, int i)
{
  (void)m;
  if (n != 1) return i;
  return 0;
}
109
110 // -----------------------------------------------------------------------------
Update(Array<UniquePtr<PointCorrespondence>> & cmaps)111 void Update(Array<UniquePtr<PointCorrespondence>> &cmaps)
112 {
113 for (auto &cmap : cmaps) {
114 cmap->Update();
115 }
116 }
117
118 // -----------------------------------------------------------------------------
Update(Array<RegisteredPointSet> & psets)119 void Update(Array<RegisteredPointSet> &psets)
120 {
121 for (auto &pset : psets) {
122 pset.Update();
123 }
124 }
125
126 // -----------------------------------------------------------------------------
Update(Array<RegisteredPointSet> & targets,Array<RegisteredPointSet> & sources,Array<UniquePtr<PointCorrespondence>> & cmaps)127 void Update(Array<RegisteredPointSet> &targets,
128 Array<RegisteredPointSet> &sources,
129 Array<UniquePtr<PointCorrespondence>> &cmaps)
130 {
131 Update(targets);
132 Update(sources);
133 Update(cmaps);
134 }
135
136 // -----------------------------------------------------------------------------
EvaluateRMSError(const Array<RegisteredPointSet> & targets,const Array<RegisteredPointSet> & sources,const Array<UniquePtr<PointCorrespondence>> & cmaps,bool symmetric=false)137 double EvaluateRMSError(const Array<RegisteredPointSet> &targets,
138 const Array<RegisteredPointSet> &sources,
139 const Array<UniquePtr<PointCorrespondence>> &cmaps,
140 bool symmetric = false)
141 {
142 Point p, q;
143
144 const int m = static_cast<int>(sources.size());
145 const int n = static_cast<int>(targets.size());
146
147 double error = 0;
148 int count = 0;
149
150 for (int i = 0; i < m; ++i) {
151 const int j = TargetIndex(m, n, i);
152 auto &cmap = cmaps[i];
153 auto &source = sources[i];
154 auto &target = targets[j];
155 for (int t = 0; t < target.NumberOfPoints(); ++t) {
156 target.GetPoint(t, p);
157 if (cmap->GetSourcePoint(t, q)) {
158 error += pow(q._x - p._x, 2) + pow(q._y - p._y, 2) + pow(q._z - p._z, 2);
159 ++count;
160 }
161 }
162 if (symmetric) {
163 for (int s = 0; s < source.NumberOfPoints(); ++s) {
164 source.GetInputPoint(s, q);
165 if (cmap->GetInputTargetPoint(s, p)) {
166 error += pow(q._x - p._x, 2) + pow(q._y - p._y, 2) + pow(q._z - p._z, 2);
167 ++count;
168 }
169 }
170 }
171 }
172
173 if (count == 0) FatalError("No corresponding points found");
174 return sqrt(error / count);
175 }
176
177 // -----------------------------------------------------------------------------
Fit(Transformation * dof,const Array<RegisteredPointSet> & targets,const Array<RegisteredPointSet> & sources,const Array<UniquePtr<PointCorrespondence>> & cmaps,bool symmetric=false,bool inverse=false)178 void Fit(Transformation *dof,
179 const Array<RegisteredPointSet> &targets,
180 const Array<RegisteredPointSet> &sources,
181 const Array<UniquePtr<PointCorrespondence>> &cmaps,
182 bool symmetric = false, bool inverse = false)
183 {
184 const int m = static_cast<int>(sources.size());
185 const int n = static_cast<int>(targets.size());
186
187 int no = 0;
188 for (int i = 0; i < m; ++i) {
189 const int j = TargetIndex(m, n, i);
190 no += targets[j].NumberOfPoints();
191 if (symmetric) no += sources[i].NumberOfPoints();
192 }
193
194 Array<double> x(no), y(no), z(no), dx(no), dy(no), dz(no);
195 Point p, q;
196
197 int k = 0;
198 for (int i = 0; i < m; ++i) {
199 const int j = TargetIndex(m, n, i);
200 auto &cmap = cmaps[i];
201 auto &source = sources[i];
202 auto &target = targets[j];
203 for (int t = 0; t < target.NumberOfPoints(); ++t) {
204 target.GetInputPoint(t, p);
205 if (cmap->GetInputSourcePoint(t, q)) {
206 if (inverse) swap(p, q);
207 x[k] = p._x;
208 y[k] = p._y;
209 z[k] = p._z;
210 dx[k] = q._x - p._x;
211 dy[k] = q._y - p._y;
212 dz[k] = q._z - p._z;
213 ++k;
214 }
215 }
216 if (symmetric) {
217 for (int s = 0; s < source.NumberOfPoints(); ++s) {
218 source.GetInputPoint(s, q);
219 if (cmap->GetInputTargetPoint(s, p)) {
220 if (inverse) swap(p, q);
221 x[k] = p._x;
222 y[k] = p._y;
223 z[k] = p._z;
224 dx[k] = q._x - p._x;
225 dy[k] = q._y - p._y;
226 dz[k] = q._z - p._z;
227 ++k;
228 }
229 }
230 }
231 }
232 dof->ApproximateAsNew(x.data(), y.data(), z.data(), dx.data(), dy.data(), dz.data(), k);
233 }
234
235 // -----------------------------------------------------------------------------
/// Write one progress line "<iter>. RMS error = <error>" to the given stream.
///
/// The RMS error is printed with 5 significant digits. The caller's width,
/// precision, and format flags are restored before returning.
///
/// \param[in,out] os    Output stream.
/// \param[in]     i     Iteration number (right-aligned in a field of width 3).
/// \param[in]     error RMS error value.
/// \param[in]     flush Whether to flush the stream after writing.
void PrintProgress(ostream &os, int i, double error, bool flush = true)
{
  // Save the caller's formatting state so that it can be reinstated afterwards
  const ios::fmtflags saved_flags     = os.flags();
  const streamsize    saved_precision = os.precision();
  const streamsize    saved_width     = os.width();
  os.width(0);
  os.precision(5);

  os << setw(3) << i << ". RMS error = " << error << "\n";
  if (flush) os << std::flush;

  os.flags(saved_flags);
  os.precision(saved_precision);
  os.width(saved_width);
}
249
250 // =============================================================================
251 // Main
252 // =============================================================================
253
254 // -----------------------------------------------------------------------------
main(int argc,char * argv[])255 int main(int argc, char *argv[])
256 {
257 EXPECTS_POSARGS(0);
258
259 Array<string> target_names;
260 Array<string> source_names;
261
262 TransformationModel model = TM_Rigid;
263 string dofin_name;
264 string dofout_name;
265
266 PointCorrespondence::TypeId ctype = PointCorrespondence::ClosestPoint;
267 Array<string> feature_name;
268 Array<double> feature_weight;
269 ParameterList param;
270
271 int iterations = 100;
272 double epsilon = 0.01;
273 bool inverse = false;
274 bool symmetric = false;
275
276 verbose = 1;
277
278 for (ALL_OPTIONS) {
279 if (OPTION("-t") || OPTION("-target")) {
280 do {
281 target_names.push_back(ARGUMENT);
282 } while (HAS_ARGUMENT);
283 }
284 else if (OPTION("-s") || OPTION("-source")) {
285 do {
286 source_names.push_back(ARGUMENT);
287 } while (HAS_ARGUMENT);
288 }
289 else if (OPTION("-i") || OPTION("-dofin")) {
290 dofin_name = ARGUMENT;
291 }
292 else if (OPTION("-o") || OPTION("-dofout")) {
293 dofout_name = ARGUMENT;
294 }
295 else if (OPTION("-m") || OPTION("-model")) {
296 const char *arg = ARGUMENT;
297 if (!FromString(arg, model)) FatalError("Invalid -model argument: " << arg);
298 }
299 else if (OPTION("-n") || OPTION("-iterations")) {
300 PARSE_ARGUMENT(iterations);
301 }
302 else if (OPTION("-c") || OPTION("-cor") || OPTION("-corr") || OPTION("-correspondence")) {
303 const char *arg = ARGUMENT;
304 if (!FromString(arg, ctype)) FatalError("Invalid -correspondence argument: " << arg);
305 }
306 else if (OPTION("-p") || OPTION("-par") || OPTION("-corpar") || OPTION("-corrpar")) {
307 Insert(param, ARGUMENT, ARGUMENT);
308 }
309 else if (OPTION("-f") || OPTION("-feature")) {
310 feature_name.push_back(ARGUMENT);
311 if (HAS_ARGUMENT) feature_weight.push_back(atof(ARGUMENT));
312 else feature_weight.push_back(1.0);
313 }
314 else if (OPTION("-cp") || OPTION("-closest-point")) {
315 ctype = PointCorrespondence::ClosestPoint;
316 }
317 else if (OPTION("-csp") || OPTION("-closest-surface-point") || OPTION("-closest-cell")) {
318 ctype = PointCorrespondence::ClosestCell;
319 }
320 else if (OPTION("-epsilon")) {
321 PARSE_ARGUMENT(epsilon);
322 }
323 else HANDLE_BOOL_OPTION(inverse);
324 else HANDLE_BOOL_OPTION(symmetric);
325 else HANDLE_COMMON_OR_UNKNOWN_OPTION();
326 }
327
328 const int m = static_cast<int>(source_names.size());
329 const int n = static_cast<int>(target_names.size());
330
331 // Check required arguments
332 if (dofout_name.empty()) {
333 FatalError("Option -dofout is required!");
334 }
335 if (m == 0 || n == 0) {
336 FatalError("Options -target and -source are required!");
337 }
338 if (n > 1 && n != m) {
339 FatalError("Either specify a single -target or one target for each -source point set!");
340 }
341 if (!IsLinear(model)) {
342 FatalError("Currently only Rigid, Similarity, and Affine transformation -model supported!");
343 }
344
345 // By default, use point coordinates as featurs for point matching
346 if (feature_name.empty()) {
347 feature_name.push_back("spatial coordinates");
348 feature_weight.push_back(1.0);
349 }
350
351 // Initialize transformation
352 UniquePtr<Transformation> dof(Transformation::New(ToTransformationType(model)));
353 if (!dofin_name.empty() && dofin_name != "Id" && dofin_name != "Identity" && dofin_name != "identity") {
354 if (verbose) cout << "Reading transformation...", cout.flush();
355 UniquePtr<Transformation> dofin(Transformation::New(dofin_name.c_str()));
356 const MultiLevelTransformation *mffd = dynamic_cast<const MultiLevelTransformation *>(dofin.get());
357 const HomogeneousTransformation *ilin = nullptr;
358 if (mffd) {
359 ilin = mffd->GetGlobalTransformation();
360 } else {
361 ilin = dynamic_cast<const HomogeneousTransformation *>(dofin.get());
362 if (!ilin) {
363 FatalError("Input -dofin must be a linear or multi-level transformation");
364 }
365 }
366 HomogeneousTransformation *lin = dynamic_cast<HomogeneousTransformation *>(dof.get());
367 mirtkAssert(lin != nullptr, "expected transformation for model=" << model << " to be of type HomogeneousTransformation");
368 lin->CopyFrom(ilin);
369 if (verbose) cout << " done" << endl;
370 }
371
372 // Initialize point sets
373 if (verbose) cout << "Initialize point sets...", cout.flush();
374 Array<RegisteredPointSet> sources(m);
375 Array<RegisteredPointSet> targets(n);
376 for (int i = 0; i < m; ++i) {
377 auto pset = ReadPointSet(source_names[i].c_str());
378 if (pset->GetNumberOfPoints() == 0) {
379 FatalError("Failed to open source point set or point set contains no points: " << source_names[i]);
380 }
381 sources[i].InputPointSet(pset);
382 if (inverse) {
383 sources[i].Transformation(dof.get());
384 }
385 sources[i].Initialize();
386 }
387 for (int j = 0; j < n; ++j) {
388 auto pset = ReadPointSet(target_names[j].c_str());
389 if (pset->GetNumberOfPoints() == 0) {
390 FatalError("Failed to open target point set or point set contains no points: " << target_names[j]);
391 }
392 targets[j].InputPointSet(pset);
393 if (!inverse) {
394 targets[j].Transformation(dof.get());
395 }
396 targets[j].Initialize();
397 }
398 if (verbose) cout << " done" << endl;
399
400 // Initialize correspondence maps
401 if (verbose) cout << "Initialize correspondence maps...", cout.flush();
402 Array<UniquePtr<PointCorrespondence>> cmaps(sources.size());
403 for (int i = 0; i < m; ++i) {
404 const int j = TargetIndex(m, n, i);
405 auto &cmap = cmaps[i];
406 cmap.reset(PointCorrespondence::New(ctype));
407 cmap->FromTargetToSource(true);
408 cmap->FromSourceToTarget(symmetric);
409 cmap->Parameter(param);
410 cmap->Source(&sources[i]);
411 cmap->Target(&targets[j]);
412 for (size_t f = 0; f < feature_name.size(); ++f) {
413 cmap->AddFeature(feature_name[f].c_str(), feature_weight[f]);
414 }
415 cmap->Initialize();
416 }
417 if (verbose) cout << " done" << endl;
418
419 // Iterate least squares fitting
420 double error, last_error = numeric_limits<double>::infinity();
421 for (int iter = 0; true; ++iter) {
422 // Update correspondences
423 Update(targets, sources, cmaps);
424 // Check for convergence
425 error = EvaluateRMSError(targets, sources, cmaps);
426 if (verbose) PrintProgress(cout, iter, error);
427 if (last_error - error < epsilon) {
428 if (verbose) {
429 cout << "Converged after " << iter << " iterations." << endl;
430 }
431 break;
432 }
433 last_error = error;
434 if (iter >= iterations) {
435 if (verbose) {
436 cout << "Terminated after " << iter << " iterations." << endl;
437 }
438 break;
439 }
440 // Update transformation
441 Fit(dof.get(), targets, sources, cmaps, symmetric, inverse);
442 }
443
444 // Write resulting transformation
445 dof->Write(dofout_name.c_str());
446
447 return 0;
448 }
449