1 /*
2  * Medical Image Registration ToolKit (MIRTK)
3  *
4  * Copyright 2013-2015 Imperial College London
5  * Copyright 2013-2015 Andreas Schuh
6  *
7  * Licensed under the Apache License, Version 2.0 (the "License");
8  * you may not use this file except in compliance with the License.
9  * You may obtain a copy of the License at
10  *
11  *     http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing, software
14  * distributed under the License is distributed on an "AS IS" BASIS,
15  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16  * See the License for the specific language governing permissions and
17  * limitations under the License.
18  */
19 
20 #include "mirtk/GenericRegistrationDebugger.h"
21 
22 #include "mirtk/Config.h" // WINDOWS
23 #include "mirtk/Event.h"
24 #include "mirtk/Point.h"
25 #include "mirtk/Matrix.h"
26 #include "mirtk/GenericImage.h"
27 #include "mirtk/GenericRegistrationFilter.h"
28 #include "mirtk/HomogeneousTransformation.h"
29 #include "mirtk/FreeFormTransformation.h"
30 #include "mirtk/MultiLevelTransformation.h"
31 #include "mirtk/ImageSimilarity.h"
32 
33 #ifdef HAVE_VTK
34 #  include "mirtk/Vtk.h"
35 #  include "vtkPoints.h"
36 #  include "vtkPointData.h"
37 #  include "vtkShortArray.h"
38 #  include "vtkFloatArray.h"
39 #  include "vtkStructuredGrid.h"
40 #  include "vtkXMLStructuredGridWriter.h"
41 #endif // HAVE_VTK
42 
43 #ifdef HAVE_MIRTK_PointSet
44 #  include "vtkPointSet.h"
45 #  include "mirtk/PointSetIO.h"
46 #endif // HAVE_MIRTK_PointSet
47 
48 #include "mirtk/CommonExport.h"
49 
50 #include <cstdio>
51 
52 
53 namespace mirtk {
54 
55 
// Global "debug" verbosity flag (cf. mirtk/Options.h). HandleEvent below writes
// additional output only at higher values: >= 2 precomputed image derivatives,
// >= 3 energy gradients, >= 4 FFD lattice as VTK dataset (point set only
// registration), >= 5 data and transformation after each line search step.
MIRTK_Common_EXPORT extern int debug;
58 
59 
60 // -----------------------------------------------------------------------------
/// Copy a string into a fixed-size character buffer.
///
/// At most sz-1 characters are copied and the result is always null-terminated,
/// so strings that do not fit are truncated instead of overflowing the buffer.
/// When sz is zero, nothing is written.
///
/// @param out Destination buffer of at least sz bytes.
/// @param sz  Size of the destination buffer in bytes.
/// @param str String to copy.
void CopyString(char *out, size_t sz, const string &str)
{
  // snprintf bounds the write by sz on every platform; a plain strcpy would
  // overrun out whenever str.size() >= sz (e.g. a long output prefix).
  snprintf(out, sz, "%s", str.c_str());
}
69 
70 // -----------------------------------------------------------------------------
71 template <class TReal>
WriteGradient(const char * fname,FreeFormTransformation * ffd,int l,const TReal * g)72 void WriteGradient(const char *fname, FreeFormTransformation *ffd, int l, const TReal *g)
73 {
74   GenericImage<TReal> gradient(ffd->Attributes(), 3);
75   int xdof, ydof, zdof;
76   for (int k = 0; k < ffd->Z(); ++k)
77   for (int j = 0; j < ffd->Y(); ++j)
78   for (int i = 0; i < ffd->X(); ++i) {
79     ffd->IndexToDOFs(ffd->LatticeToIndex(i, j, k, l), xdof, ydof, zdof);
80     gradient(i, j, k, 0) = g[xdof];
81     gradient(i, j, k, 1) = g[ydof];
82     gradient(i, j, k, 2) = g[zdof];
83   }
84   gradient.Write(fname);
85 }
86 
87 // -----------------------------------------------------------------------------
WriteAsVTKDataSet(const char * fname,FreeFormTransformation * ffd,const double * g=NULL)88 static void WriteAsVTKDataSet(const char *fname, FreeFormTransformation *ffd, const double *g = NULL)
89 {
90 #ifdef HAVE_VTK
91   vtkSmartPointer<vtkPoints>         pos  = vtkSmartPointer<vtkPoints>::New();
92   vtkSmartPointer<vtkShortArray>     stat = vtkSmartPointer<vtkShortArray>::New();
93   vtkSmartPointer<vtkFloatArray>     coef = vtkSmartPointer<vtkFloatArray>::New();
94   vtkSmartPointer<vtkFloatArray>     disp = vtkSmartPointer<vtkFloatArray>::New();
95   vtkSmartPointer<vtkFloatArray>     grad = (g ? vtkSmartPointer<vtkFloatArray>::New() : NULL);
96   vtkSmartPointer<vtkStructuredGrid> grid = vtkSmartPointer<vtkStructuredGrid>::New();
97 
98   pos->SetNumberOfPoints(ffd->NumberOfCPs());
99 
100   stat->SetName("status");
101   stat->SetNumberOfComponents(1);
102   stat->SetNumberOfTuples(ffd->NumberOfCPs());
103 
104   coef->SetName("coefficient");
105   coef->SetNumberOfComponents(3);
106   coef->SetNumberOfTuples(ffd->NumberOfCPs());
107 
108   disp->SetName("displacement");
109   disp->SetNumberOfComponents(3);
110   disp->SetNumberOfTuples(ffd->NumberOfCPs());
111 
112   if (grad) {
113     grad->SetName("gradient");
114     grad->SetNumberOfComponents(3);
115     grad->SetNumberOfTuples(ffd->NumberOfCPs());
116   }
117 
118   int    i, j, k;
119   double x1, y1, z1, x2, y2, z2;
120   for (int cp = 0; cp < ffd->NumberOfCPs(); ++cp) {
121     ffd->IndexToLattice(cp, i, j, k);
122     x1 = i, y1 = j, z1 = k;
123     ffd->LatticeToWorld(x1, y1, z1);
124     pos->SetPoint(cp, x1, y1, z1);
125     x2 = x1, y2 = y1, z2 = z1;
126     ffd->Transform(x2, y2, z2);
127     disp->SetTuple3(cp, x2 - x1, y2 - y1, z2 - z1);
128     stat->SetTuple1(cp, ffd->IsActive(cp));
129     ffd->Get(i, j, k, x2, y2, z2);
130     coef->SetTuple3(cp, x2, y2, z2);
131     if (grad) {
132       ffd->IndexToDOFs(cp, i, j, k);
133       grad->SetTuple3(cp, g[i], g[j], g[k]);
134     }
135   }
136 
137   grid->SetDimensions(ffd->X(), ffd->Y(), ffd->Z());
138   grid->SetPoints(pos);
139   grid->GetPointData()->SetScalars(stat);
140   grid->GetPointData()->SetVectors(coef);
141   grid->GetPointData()->AddArray(disp);
142   if (grad) grid->GetPointData()->AddArray(grad);
143 
144   vtkSmartPointer<vtkXMLStructuredGridWriter> writer = vtkSmartPointer<vtkXMLStructuredGridWriter>::New();
145   writer->SetFileName(fname);
146   writer->SetCompressorTypeToZLib();
147   SetVTKInput(writer, grid);
148   writer->Update();
149 #endif // HAVE_VTK
150 }
151 
152 // -----------------------------------------------------------------------------
WriteTransformation(const char * fname,HomogeneousTransformation * lin,const Point & target_offset,const Point & source_offset)153 void WriteTransformation(const char                *fname,
154                          HomogeneousTransformation *lin,
155                          const Point               &target_offset,
156                          const Point               &source_offset)
157 {
158   const Matrix mat = lin->GetMatrix();
159   Matrix pre (4, 4);
160   Matrix post(4, 4);
161   pre .Ident();
162   post.Ident();
163   pre (0, 3) = - target_offset._x;
164   pre (1, 3) = - target_offset._y;
165   pre (2, 3) = - target_offset._z;
166   post(0, 3) = + source_offset._x;
167   post(1, 3) = + source_offset._y;
168   post(2, 3) = + source_offset._z;
169   lin->PutMatrix(post * mat * pre);
170   lin->Write(fname);
171   lin->PutMatrix(mat);
172 }
173 
174 // -----------------------------------------------------------------------------
GenericRegistrationDebugger(const char * prefix)175 GenericRegistrationDebugger::GenericRegistrationDebugger(const char *prefix)
176 :
177   _Prefix      (prefix),
178   _LevelPrefix (true),
179   _Registration(NULL)
180 {
181 }
182 
183 // -----------------------------------------------------------------------------
/// Destructor.
///
/// Nothing is released here: within this file, the _Registration pointer is
/// only assigned and cleared by HandleEvent and never deleted by this class.
GenericRegistrationDebugger::~GenericRegistrationDebugger()
{
}
187 
188 // -----------------------------------------------------------------------------
HandleEvent(Observable * obj,Event event,const void * data)189 void GenericRegistrationDebugger::HandleEvent(Observable *obj, Event event, const void *data)
190 {
191   GenericRegistrationFilter * const r = _Registration;
192 
193   const int sz = 256;
194   char prefix[sz];
195   char suffix[sz];
196   char fname [sz];
197 
198   // ---------------------------------------------------------------------------
199   // ---------------------------------------------------------------------------
200   // Initialize/update debugger, set common file name prefix/suffix
201   switch (event) {
202 
203     // -------------------------------------------------------------------------
204     // Attach/detach debugger
205     case RegisteredEvent:
206       _Registration = dynamic_cast<GenericRegistrationFilter *>(obj);
207       if (!_Registration) {
208         cerr << "GenericRegistrationDebugger::HandleEvent: Cannot attach debugger to object which is not of type GenericRegistrationFilter" << endl;
209         exit(1);
210       }
211       _Iteration = 0;
212       break;
213     case UnregisteredEvent:
214       _Registration = NULL;
215       break;
216 
217     // -------------------------------------------------------------------------
218     // Start/end
219     case StartEvent:
220       _Level = reinterpret_cast<const struct Iteration *>(data)->Count() + 1; // Iter()
221       if (_LevelPrefix) _Iteration = 0;
222       // Get pointers to similarity terms
223       _Similarity.clear();
224       for (int i = 0; i < r->_Energy.NumberOfTerms(); ++i) {
225         ImageSimilarity *similarity = dynamic_cast<ImageSimilarity *>(r->_Energy.Term(i));
226         if (similarity) _Similarity.push_back(similarity);
227       }
228       // Do not add a break statement here!
229     case EndEvent:
230       snprintf(prefix, sz, "%slevel_%d_", _Prefix.c_str(), _Level);
231       suffix[0] = '\0';
232       break;
233 
234     // -------------------------------------------------------------------------
235     // Iteration
236     case IterationStartEvent:
237     case IterationEvent:
238       ++_Iteration;
239       return; // No data to write yet
240     case LineSearchIterationStartEvent:
241       _LineIteration = reinterpret_cast<const struct Iteration *>(data)->Iter();
242       return; // No data to write yet
243 
244     case LineSearchStartEvent:
245       if (_LevelPrefix) {
246         snprintf(prefix, sz, "%slevel_%d_",         _Prefix.c_str(), _Level);
247         snprintf(suffix, sz, "_%03d",               _Iteration);
248       } else {
249         CopyString(prefix, sz, _Prefix);
250         snprintf(suffix, sz, "_%03d",               _Iteration);
251       }
252       break;
253     case AcceptedStepEvent:
254       if (_LevelPrefix) {
255         snprintf(prefix, sz, "%slevel_%d_",         _Prefix.c_str(), _Level);
256         snprintf(suffix, sz, "_%03d_%03d_accepted", _Iteration, _LineIteration);
257       } else {
258         CopyString(prefix, sz, _Prefix);
259         snprintf(suffix, sz, "_%03d_%03d_accepted", _Iteration, _LineIteration);
260       }
261       break;
262     case RejectedStepEvent:
263       if (_LevelPrefix) {
264         snprintf(prefix, sz, "%slevel_%d_",         _Prefix.c_str(), _Level);
265         snprintf(suffix, sz, "_%03d_%03d_rejected", _Iteration, _LineIteration);
266       } else {
267         CopyString(prefix, sz, _Prefix);
268         snprintf(suffix, sz, "_%03d_%03d_rejected", _Iteration, _LineIteration);
269       }
270       break;
271 
272     // -------------------------------------------------------------------------
273     // Ignored event
274     default: return;
275   }
276 
277   MultiLevelTransformation  *mffd = NULL;
278   FreeFormTransformation    *ffd  = NULL;
279   HomogeneousTransformation *lin  = NULL;
280 
281   if (r) {
282     (mffd = dynamic_cast<MultiLevelTransformation  *>(r->_Transformation)) ||
283     (ffd  = dynamic_cast<FreeFormTransformation    *>(r->_Transformation)) ||
284     (lin  = dynamic_cast<HomogeneousTransformation *>(r->_Transformation));
285   }
286   if (mffd) {
287     for (int l = 0; l < mffd->NumberOfLevels(); ++l) {
288       if (mffd->LocalTransformationIsActive(l)) {
289         if (ffd) {
290           ffd = NULL;
291           break;
292         }
293         ffd = mffd->GetLocalTransformation(l);
294       }
295     }
296   }
297 
298   // ---------------------------------------------------------------------------
299   // ---------------------------------------------------------------------------
300   // Write debug information
301   switch (event) {
302 
303     // -------------------------------------------------------------------------
304     // Write initial state
305     case StartEvent: {
306 
307       // Write input images and their derivatives
308       for (size_t i = 0; i < r->_Image[r->_CurrentLevel].size(); ++i) {
309         snprintf(fname, sz, "%simage_%02zu", prefix, i+1);
310         r->_Image[r->_CurrentLevel][i].Write(fname);
311         if (debug >= 2) {
312           BaseImage *gradient = NULL;
313           BaseImage *hessian  = NULL;
314           for (size_t j = 0; j < _Similarity.size(); ++j) {
315             if (_Similarity[j]->Target()->InputImage() == &r->_Image[r->_CurrentLevel][i]) {
316               if (_Similarity[j]->Target()->PrecomputeDerivatives()) {
317                 gradient = _Similarity[j]->Target()->InputGradient();
318                 hessian  = _Similarity[j]->Target()->InputHessian();
319               }
320               break;
321             }
322             if (_Similarity[j]->Source()->InputImage() == &r->_Image[r->_CurrentLevel][i]) {
323               if (_Similarity[j]->Source()->PrecomputeDerivatives()) {
324                 gradient = _Similarity[j]->Source()->InputGradient();
325                 hessian  = _Similarity[j]->Source()->InputHessian();
326               }
327               break;
328             }
329           }
330           if (gradient) {
331             snprintf(fname, sz, "%simage_%02zu_gradient", prefix, i+1);
332             gradient->Write(fname);
333           }
334           if (hessian) {
335             snprintf(fname, sz, "%simage_%02zu_hessian", prefix, i+1);
336             hessian->Write(fname);
337           }
338         }
339       }
340 
341       // Write input domain mask
342       if (r->_Mask[r->_CurrentLevel]) {
343         snprintf(fname, sz, "%smask", prefix);
344         r->_Mask[r->_CurrentLevel]->Write(fname);
345       }
346 
347       // Write input point set
348       #ifdef HAVE_MIRTK_PointSet
349         for (size_t i = 0; i < r->_PointSet[r->_CurrentLevel].size(); ++i) {
350           vtkPointSet *pointset = r->_PointSet[r->_CurrentLevel][i];
351           snprintf(fname, sz, "%spointset_%02zu%s", prefix, i+1, DefaultExtension(pointset));
352           WritePointSet(fname, pointset);
353         }
354       #endif // HAVE_MIRTK_PointSet
355 
356     } break;
357 
358     // -------------------------------------------------------------------------
359     // Write intermediate results after each gradient step
360     case LineSearchStartEvent: {
361 
362       // Energy gradient vector
363       const double * const gradient = reinterpret_cast<const LineSearchStep *>(data)->_Direction;
364 
365       // Write input and other debug output of energy terms
366       r->_Energy.WriteDataSets(prefix, suffix, _Iteration == 1);
367 
368       if (debug >= 3) {
369 
370         // Write non-parametric gradient of data fidelity terms
371         r->_Energy.WriteGradient(prefix, suffix);
372 
373         // Write energy gradient(s) w.r.t control points
374         if (ffd) {
375           if (ffd->T() > 1) {
376             for (int l = 0; l < ffd->T(); ++l) {
377               snprintf(fname, sz, "%senergy_gradient_t%02d%s", prefix, l+1, suffix);
378               WriteGradient(fname, ffd, l, gradient);
379             }
380           } else {
381             snprintf(fname, sz, "%senergy_gradient%s", prefix, suffix);
382             WriteGradient(fname, ffd, 0, gradient);
383           }
384         } else if (mffd) {
385           const double *g = gradient;
386           for (int i = 0; i < mffd->NumberOfLevels(); ++i) {
387             if (!mffd->LocalTransformationIsActive(i)) continue;
388             ffd = mffd->GetLocalTransformation(i);
389             if (ffd->T() > 1) {
390               for (int l = 0; l < ffd->T(); ++l) {
391                 snprintf(fname, sz, "%senergy_gradient_wrt_ffd_%d_t%02d%s", prefix, i+1, l+1, suffix);
392                 WriteGradient(fname, ffd, l, gradient);
393               }
394             } else {
395               snprintf(fname, sz, "%senergy_gradient_wrt_ffd_%d_%s", prefix, i+1, suffix);
396               WriteGradient(fname, ffd, 0, g);
397             }
398             g += ffd->NumberOfDOFs();
399           }
400           ffd = NULL;
401         } else if (lin) {
402           snprintf(fname, sz, "%senergy_gradient%s.txt", prefix, suffix);
403           ofstream of(fname);
404           for (int dof = 0; dof < r->_Energy.NumberOfDOFs(); ++dof) {
405             of << gradient[dof] << "\n";
406           }
407           of.close();
408         }
409 
410       }
411 
412       // Write current transformation estimate
413       snprintf(fname, sz, "%stransformation%s.dof.gz", prefix, suffix);
414       if (lin) {
415         WriteTransformation(fname, lin, r->_TargetOffset, r->_SourceOffset);
416       } else {
417         r->_Transformation->Write(fname);
418         if (ffd && r->_Input.empty() && r->NumberOfPointSets() > 0 && debug >= 4) {
419           snprintf(fname, sz, "%stransformation%s.vtp", prefix, suffix);
420           WriteAsVTKDataSet(fname, ffd, gradient);
421         }
422       }
423     } break;
424 
425     case AcceptedStepEvent:
426     case RejectedStepEvent: {
427       if (debug >= 5) {
428 
429         // Write updated input of data fidelity terms
430         r->_Energy.WriteDataSets(prefix, suffix, false);
431 
432         // Write current transformation estimate
433         snprintf(fname, sz, "%stransformation%s.dof.gz", prefix, suffix);
434         if (lin) WriteTransformation(fname, lin, r->_TargetOffset, r->_SourceOffset);
435         else     r->_Transformation->Write(fname);
436 
437       }
438     } break;
439 
440     // -------------------------------------------------------------------------
441     // Unhandled event
442     default: break;
443   }
444 }
445 
446 
447 } // namespace mirtk
448