1 /*
2 * Medical Image Registration ToolKit (MIRTK)
3 *
4 * Copyright 2013-2015 Imperial College London
5 * Copyright 2013-2015 Andreas Schuh
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 */
19
20 #include "mirtk/PointSetDistance.h"
21
22 #include "mirtk/Array.h"
23 #include "mirtk/Memory.h"
24 #include "mirtk/Vector3D.h"
25 #include "mirtk/PointSetIO.h"
26
27 #include "vtkSmartPointer.h"
28 #include "vtkFloatArray.h"
29 #include "vtkCellArray.h"
30 #include "vtkVertex.h"
31 #include "vtkPointData.h"
32
33 #include <cstdio>
34
35
36 namespace mirtk {
37
38
39 // =============================================================================
40 // Factory
41 // =============================================================================
42
43 // -----------------------------------------------------------------------------
New(PointSetDistanceMeasure pdm,const char * name,double w)44 PointSetDistance *PointSetDistance::New(PointSetDistanceMeasure pdm, const char *name, double w)
45 {
46 enum EnergyMeasure em = static_cast<enum EnergyMeasure>(pdm);
47 if (PDM_Begin < em && em < PDM_End) {
48 EnergyTerm *term = EnergyTerm::TryNew(em, name, w);
49 if (term) return dynamic_cast<PointSetDistance *>(term);
50 cerr << NameOfType() << "::New: Point set distance measure not available: ";
51 } else {
52 cerr << NameOfType() << "::New: Energy term is not a point set distance measure: ";
53 }
54 cerr << ToString(em) << " (" << em << ")" << endl;
55 exit(1);
56 return NULL;
57 }
58
59 // =============================================================================
60 // Construction/Destruction
61 // =============================================================================
62
63 // -----------------------------------------------------------------------------
AllocateGradientWrtTarget(int m)64 void PointSetDistance::AllocateGradientWrtTarget(int m)
65 {
66 Deallocate(_GradientWrtTarget);
67 if (m > 0) _GradientWrtTarget = Allocate<GradientType>(m);
68 }
69
70 // -----------------------------------------------------------------------------
AllocateGradientWrtSource(int n)71 void PointSetDistance::AllocateGradientWrtSource(int n)
72 {
73 Deallocate(_GradientWrtTarget);
74 if (n > 0) _GradientWrtTarget = Allocate<GradientType>(n);
75 }
76
77 // -----------------------------------------------------------------------------
PointSetDistance(const char * name,double weight)78 PointSetDistance::PointSetDistance(const char *name, double weight)
79 :
80 DataFidelity(name, weight),
81 _Target(NULL),
82 _Source(NULL),
83 _GradientWrtTarget(NULL),
84 _GradientWrtSource(NULL),
85 _InitialUpdate (false)
86 {
87 _ParameterPrefix.push_back("Point set distance ");
88 }
89
90 // -----------------------------------------------------------------------------
/// Copy constructor.
///
/// Input data sets and the initial update flag are shared with / copied from
/// other. New gradient buffers are allocated for this copy:
///
/// @param other Term to copy.
/// @param m     Size of the gradient buffer w.r.t. the target. If negative,
///              one entry per target point is allocated when other has such
///              a buffer (none otherwise); if non-negative, m entries are
///              allocated (none if m is 0).
/// @param n     Same as m, but for the gradient buffer w.r.t. the source.
///
/// Non-negative sizes used to be ignored (0 was always passed on), which was
/// inconsistent with CopyAttributes; they are now honored the same way.
PointSetDistance::PointSetDistance(const PointSetDistance &other, int m, int n)
:
  DataFidelity(other),
  _Target(other._Target),
  _Source(other._Source),
  _GradientWrtTarget(NULL),
  _GradientWrtSource(NULL),
  _InitialUpdate(other._InitialUpdate)
{
  AllocateGradientWrtTarget(m < 0 && _Target && other._GradientWrtTarget ? _Target->NumberOfPoints() : m);
  AllocateGradientWrtSource(n < 0 && _Source && other._GradientWrtSource ? _Source->NumberOfPoints() : n);
}
103
104 // -----------------------------------------------------------------------------
CopyAttributes(const PointSetDistance & other,int m,int n)105 void PointSetDistance::CopyAttributes(const PointSetDistance &other, int m, int n)
106 {
107 _Target = other._Target;
108 _Source = other._Source;
109 _InitialUpdate = other._InitialUpdate;
110 AllocateGradientWrtTarget(m < 0 && _Target && other._GradientWrtTarget ? _Target->NumberOfPoints() : m);
111 AllocateGradientWrtSource(n < 0 && _Source && other._GradientWrtSource ? _Source->NumberOfPoints() : n);
112 }
113
114 // -----------------------------------------------------------------------------
/// Assignment operator; assigns the base class part and then copies the
/// attributes of other via CopyAttributes (default sizes).
PointSetDistance &PointSetDistance::operator =(const PointSetDistance &other)
{
  // Nothing to do on self-assignment
  if (this == &other) return *this;
  DataFidelity::operator =(other);
  CopyAttributes(other);
  return *this;
}
123
124 // -----------------------------------------------------------------------------
~PointSetDistance()125 PointSetDistance::~PointSetDistance()
126 {
127 Deallocate(_GradientWrtTarget);
128 Deallocate(_GradientWrtSource);
129 }
130
131 // =============================================================================
132 // Initialization
133 // =============================================================================
134
135 // -----------------------------------------------------------------------------
Initialize(int m,int n)136 void PointSetDistance::Initialize(int m, int n)
137 {
138 // Initialize base class
139 DataFidelity::Initialize();
140
141 // Check inputs
142 if (_Target == NULL) {
143 cerr << "PointSetDistance::Initialize: Target dataset is NULL" << endl;
144 exit(1);
145 }
146 if (_Source == NULL) {
147 cerr << "PointSetDistance::Initialize: Source dataset is NULL" << endl;
148 exit(1);
149 }
150
151 // Force next update if data sets have their SelfUpdate flag set
152 _InitialUpdate = true;
153
154 // Allocate memory for non-parametric gradient
155 Deallocate(_GradientWrtTarget);
156 Deallocate(_GradientWrtSource);
157 if (_Target->Transformation()) _GradientWrtTarget = Allocate<GradientType>(m);
158 if (_Source->Transformation()) _GradientWrtSource = Allocate<GradientType>(n);
159 }
160
161 // -----------------------------------------------------------------------------
Initialize()162 void PointSetDistance::Initialize()
163 {
164 // Check inputs
165 if (_Target == NULL) {
166 cerr << "PointSetDistance::Initialize: Target dataset is NULL" << endl;
167 exit(1);
168 }
169 if (_Source == NULL) {
170 cerr << "PointSetDistance::Initialize: Source dataset is NULL" << endl;
171 exit(1);
172 }
173
174 // Initialize with allocation of memory for all points
175 Initialize(_Target->NumberOfPoints(), _Source->NumberOfPoints());
176 }
177
178 // -----------------------------------------------------------------------------
Reinitialize(int m,int n)179 void PointSetDistance::Reinitialize(int m, int n)
180 {
181 // Allocate memory for non-parametric gradient
182 Deallocate(_GradientWrtTarget);
183 Deallocate(_GradientWrtSource);
184 if (_Target->Transformation()) _GradientWrtTarget = Allocate<GradientType>(m);
185 if (_Source->Transformation()) _GradientWrtSource = Allocate<GradientType>(n);
186 }
187
188 // -----------------------------------------------------------------------------
Reinitialize()189 void PointSetDistance::Reinitialize()
190 {
191 // Reinitialize with allocation of memory for all points
192 Reinitialize(_Target->NumberOfPoints(), _Source->NumberOfPoints());
193 }
194
195 // =============================================================================
196 // Evaluation
197 // =============================================================================
198
199 // -----------------------------------------------------------------------------
Update(bool)200 void PointSetDistance::Update(bool)
201 {
202 if (_InitialUpdate || _Target->Transformation()) {
203 _Target->Update(_InitialUpdate && _Target->SelfUpdate());
204 }
205 if (_InitialUpdate || _Source->Transformation()) {
206 _Source->Update(_InitialUpdate && _Source->SelfUpdate());
207 }
208 _InitialUpdate = false;
209 }
210
211 // -----------------------------------------------------------------------------
212 void PointSetDistance
ParametricGradient(const RegisteredPointSet * wrt_pset,const Vector3D<double> * np_gradient,double * gradient,double weight)213 ::ParametricGradient(const RegisteredPointSet *wrt_pset,
214 const Vector3D<double> *np_gradient,
215 double *gradient,
216 double weight)
217 {
218 const class Transformation * const T = wrt_pset->Transformation();
219 mirtkAssert(T != NULL, "point set is being transformed");
220 const double t0 = wrt_pset->InputTime();
221 const double t = wrt_pset->Time();
222 T->ParametricGradient(wrt_pset->InputPoints(), np_gradient, gradient, t, t0, weight);
223 }
224
225 // -----------------------------------------------------------------------------
EvaluateGradient(double * gradient,double,double weight)226 void PointSetDistance::EvaluateGradient(double *gradient, double, double weight)
227 {
228 // Get transformations of input data sets
229 const class Transformation * const T1 = _Target->Transformation();
230 const class Transformation * const T2 = _Source->Transformation();
231 // Compute parametric gradient w.r.t target transformation
232 if (T1) {
233 this->NonParametricGradient(_Target, _GradientWrtTarget);
234 this->ParametricGradient (_Target, _GradientWrtTarget, gradient, weight);
235 }
236 // If target and source are transformed by different transformations,
237 // the gradient vector contains first the derivative values w.r.t the
238 // parameters of the target transformation followed by those computed
239 // w.r.t the parameters of the source transformation. Otherwise, if
240 // both point sets are transformed by the same transformation, i.e., a
241 // velocity based transformation integrated half way in both directions,
242 // the derivative values are summed up instead.
243 if (T1 && T2 && !HaveSameDOFs(T1, T2)) gradient += T2->NumberOfDOFs();
244 // Compute parametric gradient w.r.t source transformation
245 if (T2) {
246 this->NonParametricGradient(_Source, _GradientWrtSource);
247 this->ParametricGradient (_Source, _GradientWrtSource, gradient, weight);
248 }
249 }
250
251 // =============================================================================
252 // Debugging
253 // =============================================================================
254
255 // -----------------------------------------------------------------------------
ToFloatArray(const PointSetDistance::GradientType * v,int n)256 inline vtkSmartPointer<vtkFloatArray> ToFloatArray(const PointSetDistance::GradientType *v, int n)
257 {
258 vtkSmartPointer<vtkFloatArray> array = vtkSmartPointer<vtkFloatArray>::New();
259 array->SetNumberOfComponents(3);
260 array->SetNumberOfTuples(n);
261 for (int i = 0; i < n; ++i) array->SetTuple3(i, v[i]._x, v[i]._y, v[i]._z);
262 return array;
263 }
264
265 // -----------------------------------------------------------------------------
WriteDataSets(const char * p,const char * suffix,bool all) const266 void PointSetDistance::WriteDataSets(const char *p, const char *suffix, bool all) const
267 {
268 const int sz = 1024;
269 char fname[sz];
270 string _prefix = Prefix(p);
271 const char *prefix = _prefix.c_str();
272
273 if (_Target->Transformation() || all) {
274 snprintf(fname, sz, "%starget%s%s", prefix, suffix, _Target->DefaultExtension());
275 _Target->Write(fname);
276 }
277 if (_Source->Transformation() || all) {
278 snprintf(fname, sz, "%ssource%s%s", prefix, suffix, _Source->DefaultExtension());
279 _Source->Write(fname);
280 }
281 }
282
283 // -----------------------------------------------------------------------------
WriteGradient(const char * p,const char * suffix) const284 void PointSetDistance::WriteGradient(const char *p, const char *suffix) const
285 {
286 const int sz = 1024;
287 char fname[sz];
288 string _prefix = Prefix(p);
289 const char *prefix = _prefix.c_str();
290
291 if (_GradientWrtTarget) {
292 snprintf(fname, sz, "%sgradient_wrt_target%s.vtp", prefix, suffix);
293 this->WriteGradient(fname, _Target, _GradientWrtTarget);
294 }
295 if (_GradientWrtSource) {
296 snprintf(fname, sz, "%sgradient_wrt_source%s.vtp", prefix, suffix);
297 this->WriteGradient(fname, _Source, _GradientWrtSource);
298 }
299 }
300
301 // -----------------------------------------------------------------------------
WriteGradient(const char * fname,const RegisteredPointSet * data,const GradientType * g,const Array<int> * sample) const302 void PointSetDistance::WriteGradient(const char *fname,
303 const RegisteredPointSet *data,
304 const GradientType *g,
305 const Array<int> *sample) const
306 {
307 bool samples_only = sample && !sample->empty();
308 const vtkIdType n = (samples_only ? sample->size() : data->NumberOfPoints());
309
310 vtkSmartPointer<vtkPoints> points;
311 vtkSmartPointer<vtkCellArray> vertices;
312 vtkSmartPointer<vtkPolyData> output;
313
314 points = vtkSmartPointer<vtkPoints>::New();
315 points->SetNumberOfPoints(n);
316
317 vertices = vtkSmartPointer<vtkCellArray>::New();
318 vertices->Allocate(n);
319
320 double p[3];
321 for (vtkIdType i = 0; i < n; ++i) {
322 data->GetInputPoint((samples_only ? (*sample)[i] : i), p);
323 points ->SetPoint(i, p);
324 vertices->InsertNextCell(1, &i);
325 }
326
327 vtkSmartPointer<vtkDataArray> gradient = ToFloatArray(g, n);
328 gradient->SetName("gradient");
329
330 output = vtkSmartPointer<vtkPolyData>::New();
331 output->SetPoints(points);
332 output->SetVerts(vertices);
333 output->GetPointData()->AddArray(gradient);
334
335 WritePolyData(fname, output);
336 }
337
338
339 } // namespace mirtk
340