1 /*
2  * Medical Image Registration ToolKit (MIRTK)
3  *
4  * Copyright 2013-2015 Imperial College London
5  * Copyright 2013-2015 Andreas Schuh
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 #include "mirtk/PointSetDistance.h"
21 
22 #include "mirtk/Array.h"
23 #include "mirtk/Memory.h"
24 #include "mirtk/Vector3D.h"
25 #include "mirtk/PointSetIO.h"
26 
27 #include "vtkSmartPointer.h"
28 #include "vtkFloatArray.h"
29 #include "vtkCellArray.h"
30 #include "vtkVertex.h"
31 #include "vtkPointData.h"
32 
33 #include <cstdio>
34 
35 
36 namespace mirtk {
37 
38 
39 // =============================================================================
40 // Factory
41 // =============================================================================
42 
43 // -----------------------------------------------------------------------------
New(PointSetDistanceMeasure pdm,const char * name,double w)44 PointSetDistance *PointSetDistance::New(PointSetDistanceMeasure pdm, const char *name, double w)
45 {
46   enum EnergyMeasure em = static_cast<enum EnergyMeasure>(pdm);
47   if (PDM_Begin < em && em < PDM_End) {
48     EnergyTerm *term = EnergyTerm::TryNew(em, name, w);
49     if (term) return dynamic_cast<PointSetDistance *>(term);
50     cerr << NameOfType() << "::New: Point set distance measure not available: ";
51   } else {
52     cerr << NameOfType() << "::New: Energy term is not a point set distance measure: ";
53   }
54   cerr << ToString(em) << " (" << em << ")" << endl;
55   exit(1);
56   return NULL;
57 }
58 
59 // =============================================================================
60 // Construction/Destruction
61 // =============================================================================
62 
63 // -----------------------------------------------------------------------------
AllocateGradientWrtTarget(int m)64 void PointSetDistance::AllocateGradientWrtTarget(int m)
65 {
66   Deallocate(_GradientWrtTarget);
67   if (m > 0) _GradientWrtTarget = Allocate<GradientType>(m);
68 }
69 
70 // -----------------------------------------------------------------------------
AllocateGradientWrtSource(int n)71 void PointSetDistance::AllocateGradientWrtSource(int n)
72 {
73   Deallocate(_GradientWrtTarget);
74   if (n > 0) _GradientWrtTarget = Allocate<GradientType>(n);
75 }
76 
77 // -----------------------------------------------------------------------------
PointSetDistance(const char * name,double weight)78 PointSetDistance::PointSetDistance(const char *name, double weight)
79 :
80   DataFidelity(name, weight),
81   _Target(NULL),
82   _Source(NULL),
83   _GradientWrtTarget(NULL),
84   _GradientWrtSource(NULL),
85   _InitialUpdate    (false)
86 {
87   _ParameterPrefix.push_back("Point set distance ");
88 }
89 
90 // -----------------------------------------------------------------------------
// Copy constructor with optional explicit gradient buffer sizes.
//
// The input point sets are shared with the other instance (the pointers are
// copied); the gradient buffers are newly allocated and their values are not
// copied.
//
// m: Number of target gradient entries to allocate. When negative, the size
//    is the number of target points, provided that the other instance has a
//    target and a target gradient buffer; otherwise nothing is allocated.
// n: Same as m, but for the gradient w.r.t. the source.
//
// A non-negative m or n is passed on to the allocation unchanged, which is
// the same rule as used by CopyAttributes.
PointSetDistance::PointSetDistance(const PointSetDistance &other, int m, int n)
:
  DataFidelity(other),
  _Target(other._Target),
  _Source(other._Source),
  _GradientWrtTarget(NULL),
  _GradientWrtSource(NULL),
  _InitialUpdate(other._InitialUpdate)
{
  AllocateGradientWrtTarget(m < 0 && _Target && other._GradientWrtTarget ? _Target->NumberOfPoints() : m);
  AllocateGradientWrtSource(n < 0 && _Source && other._GradientWrtSource ? _Source->NumberOfPoints() : n);
}
103 
104 // -----------------------------------------------------------------------------
CopyAttributes(const PointSetDistance & other,int m,int n)105 void PointSetDistance::CopyAttributes(const PointSetDistance &other, int m, int n)
106 {
107   _Target        = other._Target;
108   _Source        = other._Source;
109   _InitialUpdate = other._InitialUpdate;
110   AllocateGradientWrtTarget(m < 0 && _Target && other._GradientWrtTarget ? _Target->NumberOfPoints() : m);
111   AllocateGradientWrtSource(n < 0 && _Source && other._GradientWrtSource ? _Source->NumberOfPoints() : n);
112 }
113 
114 // -----------------------------------------------------------------------------
// Assignment operator.
//
// Self-assignment is a no-op. Otherwise the base class part is assigned
// first, followed by the attributes of this class via CopyAttributes.
PointSetDistance &PointSetDistance::operator =(const PointSetDistance &other)
{
  if (this == &other) return *this;
  DataFidelity::operator =(other);
  CopyAttributes(other);
  return *this;
}
123 
124 // -----------------------------------------------------------------------------
~PointSetDistance()125 PointSetDistance::~PointSetDistance()
126 {
127   Deallocate(_GradientWrtTarget);
128   Deallocate(_GradientWrtSource);
129 }
130 
131 // =============================================================================
132 // Initialization
133 // =============================================================================
134 
135 // -----------------------------------------------------------------------------
Initialize(int m,int n)136 void PointSetDistance::Initialize(int m, int n)
137 {
138   // Initialize base class
139   DataFidelity::Initialize();
140 
141   // Check inputs
142   if (_Target == NULL) {
143     cerr << "PointSetDistance::Initialize: Target dataset is NULL" << endl;
144     exit(1);
145   }
146   if (_Source == NULL) {
147     cerr << "PointSetDistance::Initialize: Source dataset is NULL" << endl;
148     exit(1);
149   }
150 
151   // Force next update if data sets have their SelfUpdate flag set
152   _InitialUpdate = true;
153 
154   // Allocate memory for non-parametric gradient
155   Deallocate(_GradientWrtTarget);
156   Deallocate(_GradientWrtSource);
157   if (_Target->Transformation()) _GradientWrtTarget = Allocate<GradientType>(m);
158   if (_Source->Transformation()) _GradientWrtSource = Allocate<GradientType>(n);
159 }
160 
161 // -----------------------------------------------------------------------------
Initialize()162 void PointSetDistance::Initialize()
163 {
164   // Check inputs
165   if (_Target == NULL) {
166     cerr << "PointSetDistance::Initialize: Target dataset is NULL" << endl;
167     exit(1);
168   }
169   if (_Source == NULL) {
170     cerr << "PointSetDistance::Initialize: Source dataset is NULL" << endl;
171     exit(1);
172   }
173 
174   // Initialize with allocation of memory for all points
175   Initialize(_Target->NumberOfPoints(), _Source->NumberOfPoints());
176 }
177 
178 // -----------------------------------------------------------------------------
Reinitialize(int m,int n)179 void PointSetDistance::Reinitialize(int m, int n)
180 {
181   // Allocate memory for non-parametric gradient
182   Deallocate(_GradientWrtTarget);
183   Deallocate(_GradientWrtSource);
184   if (_Target->Transformation()) _GradientWrtTarget = Allocate<GradientType>(m);
185   if (_Source->Transformation()) _GradientWrtSource = Allocate<GradientType>(n);
186 }
187 
188 // -----------------------------------------------------------------------------
Reinitialize()189 void PointSetDistance::Reinitialize()
190 {
191   // Reinitialize with allocation of memory for all points
192   Reinitialize(_Target->NumberOfPoints(), _Source->NumberOfPoints());
193 }
194 
195 // =============================================================================
196 // Evaluation
197 // =============================================================================
198 
199 // -----------------------------------------------------------------------------
Update(bool)200 void PointSetDistance::Update(bool)
201 {
202   if (_InitialUpdate || _Target->Transformation()) {
203     _Target->Update(_InitialUpdate && _Target->SelfUpdate());
204   }
205   if (_InitialUpdate || _Source->Transformation()) {
206     _Source->Update(_InitialUpdate && _Source->SelfUpdate());
207   }
208   _InitialUpdate = false;
209 }
210 
211 // -----------------------------------------------------------------------------
212 void PointSetDistance
ParametricGradient(const RegisteredPointSet * wrt_pset,const Vector3D<double> * np_gradient,double * gradient,double weight)213 ::ParametricGradient(const RegisteredPointSet *wrt_pset,
214                      const Vector3D<double>   *np_gradient,
215                      double                   *gradient,
216                      double                    weight)
217 {
218   const class Transformation * const T = wrt_pset->Transformation();
219   mirtkAssert(T != NULL, "point set is being transformed");
220   const double t0 = wrt_pset->InputTime();
221   const double t  = wrt_pset->Time();
222   T->ParametricGradient(wrt_pset->InputPoints(), np_gradient, gradient, t, t0, weight);
223 }
224 
225 // -----------------------------------------------------------------------------
EvaluateGradient(double * gradient,double,double weight)226 void PointSetDistance::EvaluateGradient(double *gradient, double, double weight)
227 {
228   // Get transformations of input data sets
229   const class Transformation * const T1 = _Target->Transformation();
230   const class Transformation * const T2 = _Source->Transformation();
231   // Compute parametric gradient w.r.t target transformation
232   if (T1) {
233     this->NonParametricGradient(_Target, _GradientWrtTarget);
234     this->ParametricGradient   (_Target, _GradientWrtTarget, gradient, weight);
235   }
236   // If target and source are transformed by different transformations,
237   // the gradient vector contains first the derivative values w.r.t the
238   // parameters of the target transformation followed by those computed
239   // w.r.t the parameters of the source transformation. Otherwise, if
240   // both point sets are transformed by the same transformation, i.e., a
241   // velocity based transformation integrated half way in both directions,
242   // the derivative values are summed up instead.
243   if (T1 && T2 && !HaveSameDOFs(T1, T2)) gradient += T2->NumberOfDOFs();
244   // Compute parametric gradient w.r.t source transformation
245   if (T2) {
246     this->NonParametricGradient(_Source, _GradientWrtSource);
247     this->ParametricGradient   (_Source, _GradientWrtSource, gradient, weight);
248   }
249 }
250 
251 // =============================================================================
252 // Debugging
253 // =============================================================================
254 
255 // -----------------------------------------------------------------------------
ToFloatArray(const PointSetDistance::GradientType * v,int n)256 inline vtkSmartPointer<vtkFloatArray> ToFloatArray(const PointSetDistance::GradientType *v, int n)
257 {
258   vtkSmartPointer<vtkFloatArray> array = vtkSmartPointer<vtkFloatArray>::New();
259   array->SetNumberOfComponents(3);
260   array->SetNumberOfTuples(n);
261   for (int i = 0; i < n; ++i) array->SetTuple3(i, v[i]._x, v[i]._y, v[i]._z);
262   return array;
263 }
264 
265 // -----------------------------------------------------------------------------
WriteDataSets(const char * p,const char * suffix,bool all) const266 void PointSetDistance::WriteDataSets(const char *p, const char *suffix, bool all) const
267 {
268   const int   sz = 1024;
269   char        fname[sz];
270   string _prefix = Prefix(p);
271   const char  *prefix = _prefix.c_str();
272 
273   if (_Target->Transformation() || all) {
274     snprintf(fname, sz, "%starget%s%s", prefix, suffix, _Target->DefaultExtension());
275     _Target->Write(fname);
276   }
277   if (_Source->Transformation() || all) {
278     snprintf(fname, sz, "%ssource%s%s", prefix, suffix, _Source->DefaultExtension());
279     _Source->Write(fname);
280   }
281 }
282 
283 // -----------------------------------------------------------------------------
WriteGradient(const char * p,const char * suffix) const284 void PointSetDistance::WriteGradient(const char *p, const char *suffix) const
285 {
286   const int   sz = 1024;
287   char        fname[sz];
288   string _prefix = Prefix(p);
289   const char  *prefix = _prefix.c_str();
290 
291   if (_GradientWrtTarget) {
292     snprintf(fname, sz, "%sgradient_wrt_target%s.vtp", prefix, suffix);
293     this->WriteGradient(fname, _Target, _GradientWrtTarget);
294   }
295   if (_GradientWrtSource) {
296     snprintf(fname, sz, "%sgradient_wrt_source%s.vtp", prefix, suffix);
297     this->WriteGradient(fname, _Source, _GradientWrtSource);
298   }
299 }
300 
301 // -----------------------------------------------------------------------------
WriteGradient(const char * fname,const RegisteredPointSet * data,const GradientType * g,const Array<int> * sample) const302 void PointSetDistance::WriteGradient(const char               *fname,
303                                      const RegisteredPointSet *data,
304                                      const GradientType       *g,
305                                      const Array<int>         *sample) const
306 {
307   bool samples_only = sample && !sample->empty();
308   const vtkIdType n = (samples_only ? sample->size() : data->NumberOfPoints());
309 
310   vtkSmartPointer<vtkPoints>    points;
311   vtkSmartPointer<vtkCellArray> vertices;
312   vtkSmartPointer<vtkPolyData>  output;
313 
314   points = vtkSmartPointer<vtkPoints>::New();
315   points->SetNumberOfPoints(n);
316 
317   vertices = vtkSmartPointer<vtkCellArray>::New();
318   vertices->Allocate(n);
319 
320   double p[3];
321   for (vtkIdType i = 0; i < n; ++i) {
322     data->GetInputPoint((samples_only ? (*sample)[i] : i), p);
323     points  ->SetPoint(i, p);
324     vertices->InsertNextCell(1, &i);
325   }
326 
327   vtkSmartPointer<vtkDataArray> gradient = ToFloatArray(g, n);
328   gradient->SetName("gradient");
329 
330   output = vtkSmartPointer<vtkPolyData>::New();
331   output->SetPoints(points);
332   output->SetVerts(vertices);
333   output->GetPointData()->AddArray(gradient);
334 
335   WritePolyData(fname, output);
336 }
337 
338 
339 } // namespace mirtk
340