1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkFEMRegistrationFilter_hxx
19 #define itkFEMRegistrationFilter_hxx
20 
21 #include "itkFEMRegistrationFilter.h"
22 
23 #include "itkFEMElements.h"
24 #include "itkFEMLoadBC.h"
25 
26 #include "itkMath.h"
27 #include "itkGroupSpatialObject.h"
28 #include "itkLinearInterpolateImageFunction.h"
29 #include "itkSpatialObject.h"
30 #include "itkFEMObjectSpatialObject.h"
31 #include "itkDisplacementFieldJacobianDeterminantFilter.h"
32 #include "itkStatisticsImageFilter.h"
33 #include "itkRecursiveGaussianImageFilter.h"
34 
35 #include "vnl/algo/vnl_determinant.h"
36 #include "itkMath.h"
37 
38 namespace itk
39 {
40 namespace fem
41 {
42 
43 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
FEMRegistrationFilter()44 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::FEMRegistrationFilter() :
45   m_DoLineSearchOnImageEnergy( 1 ),
46   m_LineSearchMaximumIterations( 100 ),
47   m_TotalIterations( 0 ),
48   m_MaxLevel( 1 ),
49   m_FileCount( 0 ),
50   m_CurrentLevel( 0 ),
51   m_WhichMetric( 0 ),
52   m_TimeStep( 1 ),
53   m_Energy( 0.0 ),
54   m_MinE( vnl_huge_val( 0 ) ),
55   m_MinJacobian( 1.0 ),
56   m_Alpha( 1.0 ),
57   m_UseLandmarks( false ),
58   m_UseMassMatrix( true ),
59   m_UseNormalizedGradient( false ),
60   m_CreateMeshFromImage( true ),
61   m_EmployRegridding( 1 ),
62   m_DescentDirection( positive ),
63   m_EnergyReductionFactor( 0.0 ),
64   m_MaximumError( 0.1 ),
65   m_MaximumKernelWidth( 30 )
66 {
67   this->SetNumberOfRequiredInputs(2);
68 
69   m_NumberOfIntegrationPoints.set_size(1);
70   m_NumberOfIntegrationPoints[m_CurrentLevel] = 4;
71 
72   m_MetricWidth.set_size(1);
73   m_MetricWidth[m_CurrentLevel] = 3;
74 
75   m_Maxiters.set_size(1);
76   m_Maxiters[m_CurrentLevel] = 1;
77 
78   m_E.set_size(1);
79   m_E[m_CurrentLevel] = 1.;
80 
81   m_Rho.set_size(1);
82   m_Rho[m_CurrentLevel] = 1.;
83 
84   m_Gamma.set_size(1);
85   m_Gamma[m_CurrentLevel] = 1;
86 
87   m_CurrentLevelImageSize.Fill( 0 );
88 
89   m_MeshPixelsPerElementAtEachResolution.set_size(1);
90   m_MeshPixelsPerElementAtEachResolution[m_CurrentLevel] = 1;
91 
92   m_ImageScaling.Fill( 1 );
93   m_CurrentImageScaling.Fill( 1 );
94   m_FullImageSize.Fill( 0 );
95   m_ImageOrigin.Fill( 0 );
96   m_StandardDeviations.Fill( 1.0 );
97 
98   // Set up the default interpolator
99   typename DefaultInterpolatorType::Pointer interp = DefaultInterpolatorType::New();
100   m_Interpolator = static_cast<InterpolatorType *>( interp.GetPointer() );
101   m_Interpolator->SetInputImage(m_Field);
102 }
103 
104 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
~FEMRegistrationFilter()105 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::~FEMRegistrationFilter()
106 {
107 }
108 
109 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SetMaxLevel(unsigned int level)110 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SetMaxLevel(unsigned int level)
111 {
112   m_MaxLevel = level;
113 
114   m_E.set_size(level);
115   m_Gamma.set_size(level);
116   m_Rho.set_size(level);
117   m_Maxiters.set_size(level);
118   m_MeshPixelsPerElementAtEachResolution.set_size(level);
119   m_NumberOfIntegrationPoints.set_size(level);
120   m_MetricWidth.set_size(level);
121 
122   for(unsigned int i = 0; i < level; i++ )
123     {
124     m_Gamma[i] = 1;
125     m_E[i] = 1.0;
126     m_Rho[i] = 1.0;
127     m_Maxiters[i] = 1;
128     m_MeshPixelsPerElementAtEachResolution[i] = 1;
129     m_NumberOfIntegrationPoints[i] = 4;
130     m_MetricWidth[i] = 3;
131     }
132 }
133 
134 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SetStandardDeviations(double value)135 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SetStandardDeviations(double value)
136 {
137   unsigned int j;
138 
139   for( j = 0; j < ImageDimension; j++ )
140   {
141     if( Math::NotExactlyEquals(value, m_StandardDeviations[j]) )
142     {
143       break;
144     }
145   }
146   if( j < ImageDimension )
147   {
148     this->Modified();
149     for( j = 0; j < ImageDimension; j++ )
150     {
151       m_StandardDeviations[j] = value;
152     }
153   }
154 }
155 
156 
157 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
RunRegistration()158 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::RunRegistration()
159 {
160 
161   MultiResSolve();
162 
163   if( m_Field )
164     {
165     if( m_TotalField )
166       {
167       m_Field = m_TotalField;
168       }
169     this->ComputeJacobian( );
170     WarpImage(m_OriginalMovingImage);
171     }
172 }
173 
174 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SetMovingImage(MovingImageType * R)175 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SetMovingImage(MovingImageType* R)
176 {
177   m_MovingImage = R;
178   if( m_TotalIterations == 0 )
179     {
180     m_OriginalMovingImage = R;
181     this->ProcessObject::SetNthInput( 0, const_cast<MovingImageType *>( R ) );
182     }
183 }
184 
185 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SetFixedImage(FixedImageType * T)186 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SetFixedImage(FixedImageType* T)
187 {
188   m_FixedImage = T;
189   m_FullImageSize = m_FixedImage->GetLargestPossibleRegion().GetSize();
190 
191   if( m_TotalIterations == 0 )
192     {
193     this->ProcessObject::SetNthInput( 1, const_cast<FixedImageType *>( T ) );
194     }
195   VectorType disp;
196   for( unsigned int i = 0; i < ImageDimension; i++ )
197     {
198     disp[i] = 0.0;
199     m_ImageOrigin[i] = 0;
200     }
201 
202   m_CurrentLevelImageSize = m_FullImageSize;
203 }
204 
205 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SetInputFEMObject(FEMObjectType * F,unsigned int level)206 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SetInputFEMObject(FEMObjectType* F,
207                                                                                      unsigned int level)
208 {
209   this->ProcessObject::SetNthInput( 2 + level, const_cast<FEMObjectType *>( F ) );
210 }
211 
212 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
213 typename FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::FEMObjectType *
GetInputFEMObject(unsigned int level)214 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::GetInputFEMObject(unsigned int level)
215 {
216   return static_cast<FEMObjectType *>(this->ProcessObject::GetInput(2 + level) );
217 }
218 
219 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
ChooseMetric(unsigned int which)220 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ChooseMetric(unsigned int which)
221 {
222   // Choose the similarity Function
223 
224   using MetricType0 = itk::MeanSquareRegistrationFunction<FixedImageType, MovingImageType, FieldType>;
225   using MetricType1 = itk::NCCRegistrationFunction<FixedImageType, MovingImageType, FieldType>;
226   using MetricType2 = itk::MIRegistrationFunction<FixedImageType, MovingImageType, FieldType>;
227   using MetricType3 = itk::DemonsRegistrationFunction<FixedImageType, MovingImageType, FieldType>;
228 
229   m_WhichMetric = (unsigned int)which;
230 
231   switch( which )
232     {
233     case 0:
234       m_Metric = MetricType0::New();
235       SetDescentDirectionMinimize();
236       break;
237     case 1:
238       m_Metric = MetricType1::New();
239       SetDescentDirectionMinimize();
240       break;
241     case 2:
242       m_Metric = MetricType2::New();
243       SetDescentDirectionMaximize();
244       break;
245     case 3:
246       m_Metric = MetricType3::New();
247       SetDescentDirectionMinimize();
248       break;
249     default:
250       m_Metric = MetricType0::New();
251       SetDescentDirectionMinimize();
252     }
253 
254   m_Metric->SetGradientStep( m_Gamma[m_CurrentLevel] );
255   m_Metric->SetNormalizeGradient( m_UseNormalizedGradient );
256 }
257 
258 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
WarpImage(const MovingImageType * ImageToWarp)259 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::WarpImage( const MovingImageType * ImageToWarp)
260 {
261   typename WarperType::Pointer warper = WarperType::New();
262   using WarperCoordRepType = typename WarperType::CoordRepType;
263   using InterpolatorType1 =
264       itk::LinearInterpolateImageFunction<MovingImageType, WarperCoordRepType>;
265   typename InterpolatorType1::Pointer interpolator = InterpolatorType1::New();
266 
267   warper = WarperType::New();
268   warper->SetInput( ImageToWarp );
269   warper->SetDisplacementField( m_Field );
270   warper->SetInterpolator( interpolator );
271   warper->SetOutputOrigin( m_FixedImage->GetOrigin() );
272   warper->SetOutputSpacing( m_FixedImage->GetSpacing() );
273   warper->SetOutputDirection( m_FixedImage->GetDirection() );
274   typename FixedImageType::PixelType padValue = 0;
275   warper->SetEdgePaddingValue( padValue );
276   warper->Update();
277 
278   m_WarpedImage = warper->GetOutput();
279 }
280 
281 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
CreateMesh(unsigned int PixelsPerElement,SolverType * solver)282 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::CreateMesh(unsigned int PixelsPerElement,
283                                                                               SolverType *solver)
284 {
285   vnl_vector<unsigned int> pixPerElement;
286   pixPerElement.set_size( ImageDimension );
287   pixPerElement.fill( PixelsPerElement );
288 
289   if( ImageDimension == 2 && dynamic_cast<Element2DC0LinearQuadrilateral *>(&*m_Element) != nullptr )
290     {
291     m_Material->SetYoungsModulus(this->GetElasticity(m_CurrentLevel) );
292 
293     itkDebugMacro( << " Generating regular Quad mesh " << std::endl );
294     typename ImageToMeshType::Pointer meshFilter = ImageToMeshType::New();
295     meshFilter->SetInput( m_MovingImage );
296     meshFilter->SetPixelsPerElement( pixPerElement );
297     meshFilter->SetElement( &*m_Element );
298     meshFilter->SetMaterial( m_Material );
299     meshFilter->Update();
300     m_FEMObject = meshFilter->GetOutput();
301     m_FEMObject->FinalizeMesh();
302     itkDebugMacro( << " Generating regular mesh done " << std::endl );
303     }
304   else if( ImageDimension == 3 && dynamic_cast<Element3DC0LinearHexahedron *>(&*m_Element) != nullptr )
305     {
306     m_Material->SetYoungsModulus( this->GetElasticity(m_CurrentLevel) );
307 
308     itkDebugMacro( << " Generating regular Hex mesh " << std::endl );
309     typename ImageToMeshType::Pointer meshFilter = ImageToMeshType::New();
310     meshFilter->SetInput( m_MovingImage );
311     meshFilter->SetPixelsPerElement( pixPerElement );
312     meshFilter->SetElement( &*m_Element );
313     meshFilter->SetMaterial( m_Material );
314     meshFilter->Update();
315     m_FEMObject = meshFilter->GetOutput();
316     m_FEMObject->FinalizeMesh();
317     itkDebugMacro( << " Generating regular mesh done " << std::endl );
318     }
319   else
320     {
321     FEMException e(__FILE__, __LINE__);
322     e.SetDescription("CreateMesh - wrong image or element type");
323     e.SetLocation(ITK_LOCATION);
324     throw e;
325     }
326 
327   if( m_UseLandmarks )
328     {
329     for( unsigned int i = 0; i < m_LandmarkArray.size(); i++ )
330       {
331       m_FEMObject->AddNextLoad( (m_LandmarkArray[i]) );
332       }
333     }
334 
335   solver->SetInput(m_FEMObject);
336   solver->InitializeInterpolationGrid(m_FixedImage->GetBufferedRegion(),
337                                         m_FixedImage->GetOrigin(),
338                                         m_FixedImage->GetSpacing(),
339                                         m_FixedImage->GetDirection());
340 }
341 
342 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
343 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>
ApplyImageLoads(TMovingImage * movingimg,TFixedImage * fixedimg)344 ::ApplyImageLoads(TMovingImage*  movingimg, TFixedImage* fixedimg )
345 {
346   m_Load = FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ImageMetricLoadType::New();
347   m_Load->SetMovingImage(movingimg);
348   m_Load->SetFixedImage(fixedimg);
349   if( !m_Field )
350     {
351     this->InitializeField();
352     }
353   m_Load->SetDisplacementField( this->GetDisplacementField() );
354   m_Load->SetMetric(m_Metric);
355   m_Load->InitializeMetric();
356   m_Load->SetGamma(m_Gamma[m_CurrentLevel]);
357   ImageSizeType r;
358   for( unsigned int dd = 0; dd < ImageDimension; dd++ )
359     {
360     r[dd] = m_MetricWidth[m_CurrentLevel];
361     }
362   m_Load->SetMetricRadius(r);
363   m_Load->SetNumberOfIntegrationPoints(m_NumberOfIntegrationPoints[m_CurrentLevel]);
364   m_Load->SetGlobalNumber(m_FEMObject->GetNumberOfLoads() + 1);
365   if( m_DescentDirection == positive )
366     {
367     m_Load->SetDescentDirectionMinimize( );
368     }
369   else
370     {
371     m_Load->SetDescentDirectionMaximize( );
372     }
373   m_FEMObject->AddNextLoad(m_Load);
374   m_Load = dynamic_cast<typename FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ImageMetricLoadType *>
375     (&*m_FEMObject->GetLoadWithGlobalNumber(m_FEMObject->GetNumberOfLoads() ) );
376 }
377 
378 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
ApplyLoads(ImageSizeType ImgSz,double * scaling)379 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ApplyLoads(
380   ImageSizeType ImgSz, double* scaling)
381 {
382   // Apply the boundary conditions. We pin the image corners.
383   // First compute which elements these will be.
384   //
385   itkDebugMacro( << " Applying loads " );
386 
387   vnl_vector<Float> pd; pd.set_size(ImageDimension);
388   vnl_vector<Float> pu; pu.set_size(ImageDimension);
389 
390   // Now scale the landmarks
391 
392   itkDebugMacro( " Number of LM loads: " << m_LandmarkArray.size() );
393 
394   // Step over all the loads again to scale them by the global landmark weight.
395   if( !m_LandmarkArray.empty() )
396     {
397     for( unsigned int lmind = 0; lmind < m_LandmarkArray.size(); lmind++ )
398       {
399       m_LandmarkArray[lmind]->GetElementArray()[0] = nullptr;
400 
401       itkDebugMacro( << " Prescale Pt: " << m_LandmarkArray[lmind]->GetTarget() );
402       if( scaling )
403         {
404         m_LandmarkArray[lmind]->ScalePointAndForce(scaling, m_EnergyReductionFactor);
405         itkDebugMacro( << " Postscale Pt: " << m_LandmarkArray[lmind]->GetTarget() << "; scale: " << scaling[0] );
406         }
407 
408       pu = m_LandmarkArray[lmind]->GetSource();
409       pd = m_LandmarkArray[lmind]->GetPoint();
410 
411       int numElements = m_FEMObject->GetNumberOfElements();
412       for( int i = 0; i < numElements; i++ )
413         {
414         if( m_FEMObject->GetElement(i)->GetLocalFromGlobalCoordinates(pu, pd ) )
415           {
416           m_LandmarkArray[lmind]->SetPoint(pd);
417           m_LandmarkArray[lmind]->GetElementArray()[0] =  m_FEMObject->GetElement(i);
418           }
419         }
420 
421       m_LandmarkArray[lmind]->SetGlobalNumber(lmind);
422       LoadLandmark::Pointer l5 = dynamic_cast<LoadLandmark *>( &*m_LandmarkArray[lmind]->CreateAnother() );
423       m_FEMObject->AddNextLoad(l5);
424       }
425     itkDebugMacro( << " Landmarks done" );
426     }
427 
428   // Now apply the BC loads
429   LoadBC::Pointer l1;
430 
431   // Pin one corner of image
432   unsigned int CornerCounter, ii, EdgeCounter = 0;
433 
434   int numNodes = m_FEMObject->GetNumberOfNodes();
435 
436   Element::VectorType coord;
437 
438   bool EdgeFound;
439   unsigned int nodect = 0;
440   for( int i = 0; i < numNodes; i++ )
441     {
442     if( EdgeCounter >= ImageDimension )
443       {
444       return;
445       }
446 
447     coord = m_FEMObject->GetNode(i)->GetCoordinates();
448     CornerCounter = 0;
449     for( ii = 0; ii < ImageDimension; ii++ )
450       {
451       if( Math::AlmostEquals( coord[ii], m_ImageOrigin[ii] )
452        || Math::AlmostEquals( coord[ii], ImgSz[ii] - 1 ) )
453         {
454         CornerCounter++;
455         }
456       }
457 
458     if( CornerCounter == ImageDimension ) // the node is located at a true corner
459       {
460       unsigned int ndofpernode = (*(m_FEMObject->GetNode(i)->m_elements.begin() ) )->GetNumberOfDegreesOfFreedomPerNode();
461       unsigned int numnodesperelt = (*(m_FEMObject->GetNode(i)->m_elements.begin() ) )->GetNumberOfNodes();
462       unsigned int whichnode;
463 
464       unsigned int maxnode = numnodesperelt - 1;
465 
466       for( auto elt = m_FEMObject->GetNode(i)->m_elements.begin();
467            elt != m_FEMObject->GetNode(i)->m_elements.end(); elt++ )
468         {
469         for( whichnode = 0; whichnode <= maxnode; whichnode++ )
470           {
471           coord = (*elt)->GetNode(whichnode)->GetCoordinates();
472           CornerCounter = 0;
473           for( ii = 0; ii < ImageDimension; ii++ )
474             {
475             if( Math::AlmostEquals( coord[ii], m_ImageOrigin[ii] )
476              || Math::AlmostEquals( coord[ii], ImgSz[ii] - 1 ) )
477               {
478               CornerCounter++;
479               }
480             }
481           if( CornerCounter == ImageDimension - 1 )
482             {
483             EdgeFound = true;
484             }
485           else
486             {
487             EdgeFound = false;
488             }
489           if( EdgeFound )
490             {
491             for( unsigned int jj = 0; jj < ndofpernode; jj++ )
492               {
493               itkDebugMacro( " Which node " << whichnode );
494               itkDebugMacro( " Edge coord " << coord );
495 
496               l1 = LoadBC::New();
497               // now we get the element from the node -- we assume we need fix the dof only once
498               // even if more than one element shares it.
499 
500               l1->SetElement(Element::ConstPointer(*elt));
501               unsigned int localdof = whichnode * ndofpernode + jj;
502               l1->SetDegreeOfFreedom(localdof);
503               l1->SetValue(vnl_vector<double>(1, 0.0) );
504 
505               m_FEMObject->AddNextLoad(l1);
506               }
507             EdgeCounter++;
508             }
509           }
510         } // end elt loop
511       }
512     nodect++;
513     itkDebugMacro( << " Node: " << nodect );
514     }
515 }
516 
517 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
IterativeSolve(SolverType * solver)518 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::IterativeSolve(SolverType *solver)
519 {
520   if( !m_Load )
521     {
522     itkExceptionMacro( <<"No Load set" );
523     }
524 
525   bool Done = false;
526   unsigned int iters = 0;
527   m_MinE = 10.e99;
528   Float deltE = 0;
529   while( !Done && iters < m_Maxiters[m_CurrentLevel] )
530     {
531     const Float lastdeltE = deltE;
532     const unsigned int DLS = m_DoLineSearchOnImageEnergy;
533     //  Reset the variational similarity term to zero.
534 
535     Float LastE = m_Load->GetCurrentEnergy();
536     m_Load->SetCurrentEnergy(0.0);
537     m_Load->InitializeMetric();
538 
539     if( !m_Field )
540       {
541       itkExceptionMacro( <<"No Field set" );
542       }
543     solver->SetUseMassMatrix( m_UseMassMatrix );
544 
545     // Solve the system of equations for displacements (u=K^-1*F)
546     solver->Modified();
547     solver->Update();
548     m_Load->PrintCurrentEnergy();
549 
550     if( m_DescentDirection == 1 )
551       {
552       deltE = (LastE - m_Load->GetCurrentEnergy() );
553       }
554     else
555       {
556       deltE = (m_Load->GetCurrentEnergy() - LastE );
557       }
558 
559     if(  DLS == 2 && deltE < 0.0 )
560       {
561       itkDebugMacro( << " Line search " );
562       constexpr float tol = 1.0; // ((0.01  < LastE) ? 0.01 : LastE/10.);
563       LastE = this->GoldenSection(solver, tol, m_LineSearchMaximumIterations);
564       deltE = (m_MinE - LastE);
565       itkDebugMacro( << " Line search done " << std::endl );
566       }
567 
568     iters++;
569 
570     if( deltE == 0.0 )
571       {
572       itkDebugMacro( << " No change in energy " << std::endl);
573       Done = true;
574       }
575     if( (DLS == 0)  && (  iters >= m_Maxiters[m_CurrentLevel] ) )
576       {
577       Done = true;
578       }
579     else if( (DLS > 0) &&
580              ( iters >= m_Maxiters[m_CurrentLevel] || (deltE < 0.0 && iters > 5 && lastdeltE < 0.0) ) )
581       {
582       Done = true;
583       }
584     float curmaxsol = solver->GetCurrentMaxSolution();
585     if( Math::AlmostEquals( curmaxsol, 0.0f ) )
586       {
587       curmaxsol = 1.0;
588       }
589     Float mint = m_Gamma[m_CurrentLevel] / curmaxsol;
590     if( mint > 1 )
591       {
592       mint = 1.0;
593       }
594 
595     if( solver->GetCurrentMaxSolution() < 0.01 && iters > 2 )
596       {
597       Done = true;
598       }
599     solver->AddToDisplacements(mint);
600     m_MinE = LastE;
601 
602     InterpolateVectorField(solver);
603 
604     if( m_EmployRegridding != 0 )
605       {
606       if( iters % m_EmployRegridding == 0  )
607         {
608         this->EnforceDiffeomorphism(1.0, solver, true);
609         }
610       }
611     itkDebugMacro( << " min E: " << m_MinE << "; delt E: " << deltE << "; iters: " << iters << std::endl );
612     m_TotalIterations++;
613     }
614 }
615 
616 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
617 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>
InitializeField()618 ::InitializeField()
619 {
620   m_Field = FieldType::New();
621 
622   m_FieldRegion.SetSize(m_CurrentLevelImageSize );
623   m_Field->SetOrigin( m_FixedImage->GetOrigin() );
624   m_Field->SetSpacing( m_FixedImage->GetSpacing() );
625   m_Field->SetDirection( m_FixedImage->GetDirection() );
626   m_Field->SetLargestPossibleRegion( m_FieldRegion );
627   m_Field->SetBufferedRegion( m_FieldRegion );
628   m_Field->SetLargestPossibleRegion( m_FieldRegion );
629   m_Field->Allocate();
630 
631   VectorType disp;
632   for( unsigned int t = 0; t < ImageDimension; t++ )
633     {
634     disp[t] = 0.0;
635     }
636   FieldIterator fieldIter( m_Field, m_FieldRegion );
637   fieldIter.GoToBegin();
638   for(; !fieldIter.IsAtEnd(); ++fieldIter )
639     {
640     fieldIter.Set(disp);
641     }
642 }
643 
644 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
645 void
InterpolateVectorField(SolverType * solver)646 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::InterpolateVectorField(SolverType *solver)
647 {
648   typename FieldType::Pointer field = m_Field;
649 
650   if( !field )
651     {
652     this->InitializeField();
653     }
654   m_FieldSize = field->GetLargestPossibleRegion().GetSize();
655 
656   itkDebugMacro( << " Interpolating vector field of size " << m_FieldSize);
657 
658   Float rstep, sstep, tstep;
659 
660   vnl_vector<double> Pos;  // solution at the point
661   vnl_vector<double> Sol;  // solution at the local point
662   vnl_vector<double> Gpt;  // global position given by local point
663 
664   VectorType disp;
665   for( unsigned int t = 0; t < ImageDimension; t++ )
666     {
667     disp[t] = 0.0;
668     }
669   FieldIterator fieldIter( field, field->GetLargestPossibleRegion() );
670 
671   fieldIter.GoToBegin();
672   typename FixedImageType::IndexType rindex = fieldIter.GetIndex();
673 
674   Sol.set_size(ImageDimension);
675   Gpt.set_size(ImageDimension);
676 
677   if( ImageDimension == 2 )
678     {
679     Element::ConstPointer eltp;
680     for(; !fieldIter.IsAtEnd(); ++fieldIter )
681       {
682       // Get element pointer from the solver elt pointer image
683       typename FieldType::PointType physicalPoint;
684       rindex = fieldIter.GetIndex();
685       field->TransformIndexToPhysicalPoint(rindex, physicalPoint);
686       for( unsigned int d = 0; d < ImageDimension; d++ )
687         {
688         Gpt[d] = (double) (physicalPoint[d]);
689         }
690 
691       eltp = solver->GetElementAtPoint(Gpt);
692       if( eltp )
693         {
694         eltp->GetLocalFromGlobalCoordinates(Gpt, Pos);
695 
696         unsigned int Nnodes = eltp->GetNumberOfNodes();
697         typename Element::VectorType shapef(Nnodes);
698         shapef = eltp->ShapeFunctions(Pos);
699         Float solval;
700         for( unsigned int f = 0; f < ImageDimension; f++ )
701           {
702           solval = 0.0;
703           for( unsigned int n = 0; n < Nnodes; n++ )
704             {
705             solval += shapef[n] * solver->GetLinearSystem()->GetSolutionValue(
706                 eltp->GetNode(n)->GetDegreeOfFreedom(f), solver->GetTotalSolutionIndex() );
707             }
708           Sol[f] = solval;
709           disp[f] = (Float) 1.0 * Sol[f];
710           }
711         field->SetPixel(rindex, disp );
712         }
713       }
714     }
715 
716   if( ImageDimension == 3 )
717     {
718     // FIXME SHOULD BE 2.0 over meshpixperelt
719     rstep = 1.25 / ( (double)m_MeshPixelsPerElementAtEachResolution[m_CurrentLevel]);
720     sstep = 1.25 / ( (double)m_MeshPixelsPerElementAtEachResolution[m_CurrentLevel]);
721     tstep = 1.25 / ( (double)m_MeshPixelsPerElementAtEachResolution[m_CurrentLevel]);
722 
723     Pos.set_size(ImageDimension);
724     int numElements = solver->GetInput()->GetNumberOfElements();
725     for( int i = 0; i < numElements; i++ )
726       {
727       Element::Pointer eltp = solver->GetInput()->GetElement(i);
728       for( double r = -1.0; r <= 1.0; r = r + rstep )
729         {
730         for( double s = -1.0; s <= 1.0; s = s + sstep )
731           {
732           for( double t = -1.0; t <= 1.0; t = t + tstep )
733             {
734             Pos[0] = r;
735             Pos[1] = s;
736             Pos[2] = t;
737 
738             unsigned int numNodes = eltp->GetNumberOfNodes();
739             typename Element::VectorType shapef(numNodes);
740 
741 #define FASTHEX
742 #ifdef FASTHEX
743 // FIXME temporarily using hexahedron shape f for speed
744             shapef[0] = (1 - r) * (1 - s) * (1 - t) * 0.125;
745             shapef[1] = (1 + r) * (1 - s) * (1 - t) * 0.125;
746             shapef[2] = (1 + r) * (1 + s) * (1 - t) * 0.125;
747             shapef[3] = (1 - r) * (1 + s) * (1 - t) * 0.125;
748             shapef[4] = (1 - r) * (1 - s) * (1 + t) * 0.125;
749             shapef[5] = (1 + r) * (1 - s) * (1 + t) * 0.125;
750             shapef[6] = (1 + r) * (1 + s) * (1 + t) * 0.125;
751             shapef[7] = (1 - r) * (1 + s) * (1 + t) * 0.125;
752 #else
753             shapef = (*eltp)->ShapeFunctions(Pos);
754 #endif
755 
756             Float solval, posval;
757             bool inimage = true;
758             typename FixedImageType::PointType physicalPoint;
759 
760             for( unsigned int f = 0; f < ImageDimension; f++ )
761               {
762               solval = 0.0;
763               posval = 0.0;
764               for( unsigned int n = 0; n < numNodes; n++ )
765                 {
766                 posval += shapef[n] * ( ( (eltp)->GetNodeCoordinates(n) )[f]);
767                 solval += shapef[n] * solver->GetLinearSystem()->GetSolutionValue(
768                     (eltp)->GetNode(n)->GetDegreeOfFreedom(f), solver->GetTotalSolutionIndex() );
769                 }
770               Sol[f] = solval;
771               Gpt[f] = posval;
772               disp[f] = (Float) 1.0 * Sol[f];
773               physicalPoint[f] = Gpt[f];
774               }
775 
776             inimage = m_FixedImage->TransformPhysicalPointToIndex(physicalPoint, rindex);
777             if( inimage )
778               {
779               field->SetPixel(rindex, disp);
780               }
781             }
782           }
783         } // end of for loops
784       }   // end of elt array loop
785 
786     }
787 
788   // Ensure that the values are exact at the nodes. They won't necessarily be unless we use this code.
789   itkDebugMacro( << " Interpolation done " << std::endl);
790 }
791 
792 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
ComputeJacobian()793 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ComputeJacobian( )
794 {
795   m_MinJacobian = 1.0;
796 
797   using JacobianFilterType = typename itk::DisplacementFieldJacobianDeterminantFilter< FieldType, float, FloatImageType >;
798   typename JacobianFilterType::Pointer jacobianFilter = JacobianFilterType::New();
799   jacobianFilter->SetInput( m_Field );
800   jacobianFilter->Update( );
801   m_FloatImage = jacobianFilter->GetOutput();
802 
803   using StatisticsFilterType = typename itk::StatisticsImageFilter< FloatImageType >;
804   typename StatisticsFilterType::Pointer statisticsFilter = StatisticsFilterType::New();
805   statisticsFilter->SetInput( m_FloatImage );
806   statisticsFilter->Update( );
807 
808   m_MinJacobian = statisticsFilter->GetMinimum();
809 
810   itkDebugMacro( << " Min Jacobian: " << m_MinJacobian << std::endl);
811 }
812 
813 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
EnforceDiffeomorphism(float thresh,SolverType * solver,bool onlywriteimages)814 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::EnforceDiffeomorphism(float thresh,
815                                                                                          SolverType *solver,
816                                                                                          bool onlywriteimages )
817 {
818   itkDebugMacro( << " Checking Jacobian using threshold " << thresh );
819 
820   this->ComputeJacobian();
821 
822   if( m_MinJacobian < thresh )
823   {
824     // Smooth deformation field
825     this->SmoothDisplacementField();
826   }
827 
828   typename WarperType::Pointer warper = WarperType::New();
829   using WarperCoordRepType = typename WarperType::CoordRepType;
830   using InterpolatorType1 =
831       itk::LinearInterpolateImageFunction<MovingImageType, WarperCoordRepType>;
832   typename InterpolatorType1::Pointer interpolator = InterpolatorType1::New();
833 
834   // If using landmarks, warp them
835   if( m_UseLandmarks )
836     {
837     itkDebugMacro( << " Warping landmarks: " << m_LandmarkArray.size() );
838 
839     if( !m_LandmarkArray.empty() )
840       {
841       for( unsigned int lmind = 0; lmind < m_LandmarkArray.size(); lmind++ )
842         {
843         itkDebugMacro( << " Old source: " << m_LandmarkArray[lmind]->GetSource() );
844         itkDebugMacro( << " Target: " << m_LandmarkArray[lmind]->GetTarget() );
845 
846         // Convert the source to warped coords.
847         m_LandmarkArray[lmind]->GetSource() = m_LandmarkArray[lmind]->GetSource()
848             + (dynamic_cast<LoadLandmark *>( &*solver->GetOutput()->GetLoadWithGlobalNumber(lmind) )->GetForce() );
849         itkDebugMacro( << " New source: " << m_LandmarkArray[lmind]->GetSource() );
850         itkDebugMacro( << " Target: " << m_LandmarkArray[lmind]->GetTarget() );
851         LoadLandmark::Pointer l5 = dynamic_cast<LoadLandmark *>( &*m_LandmarkArray[lmind]->CreateAnother() );
852         solver->GetOutput()->AddNextLoad(l5);
853         }
854       itkDebugMacro( << " Warping landmarks done " );
855       }
856     else
857       {
858       itkDebugMacro( << " Landmark array empty " );
859       }
860     }
861 
862   // Store the total deformation by composing with the full field
863   if( !m_TotalField && !onlywriteimages )
864     {
865     itkDebugMacro( << " Allocating total deformation field " );
866 
867     m_TotalField = FieldType::New();
868 
869     m_FieldRegion.SetSize(m_Field->GetLargestPossibleRegion().GetSize() );
870     m_TotalField->SetLargestPossibleRegion( m_FieldRegion );
871     m_TotalField->SetBufferedRegion( m_FieldRegion );
872     m_TotalField->SetLargestPossibleRegion( m_FieldRegion );
873     m_TotalField->Allocate();
874 
875     VectorType disp;
876     disp.Fill(0.0);
877 
878     FieldIterator fieldIter( m_TotalField, m_FieldRegion );
879 
880     for( fieldIter.GoToBegin(); !fieldIter.IsAtEnd(); ++fieldIter )
881       {
882       fieldIter.Set(disp);
883       }
884     }
885 
886   if( onlywriteimages )
887     {
888     warper = WarperType::New();
889     warper->SetInput( m_OriginalMovingImage );
890     warper->SetDisplacementField( m_Field );
891     warper->SetInterpolator( interpolator );
892     warper->SetOutputOrigin( m_FixedImage->GetOrigin() );
893     warper->SetOutputSpacing( m_FixedImage->GetSpacing() );
894     warper->SetOutputDirection( m_FixedImage->GetDirection() );
895     typename MovingImageType::PixelType padValue = 0;
896     warper->SetEdgePaddingValue( padValue );
897     warper->Update();
898     m_WarpedImage = warper->GetOutput();
899     }
900   else if( m_TotalField )
901     {
902 
903     typename InterpolatorType::ContinuousIndexType inputIndex;
904 
905     using InterpolatedType = typename InterpolatorType::OutputType;
906 
907     InterpolatedType interpolatedValue;
908 
909     m_Interpolator->SetInputImage(m_Field);
910 
911     typename FixedImageType::IndexType index;
912     FieldIterator totalFieldIter( m_TotalField, m_TotalField->GetLargestPossibleRegion() );
913     totalFieldIter.GoToBegin();
914     unsigned int jj;
915     float pathsteplength = 0;
916     while( !totalFieldIter.IsAtEnd()  )
917       {
918       index = totalFieldIter.GetIndex();
919       for( jj = 0; jj < ImageDimension; jj++ )
920         {
921         inputIndex[jj] = (WarperCoordRepType) index[jj];
922         interpolatedValue[jj] = 0.0;
923         }
924 
925       if ( m_Interpolator->IsInsideBuffer( inputIndex ) )
926         {
927         interpolatedValue =
928             m_Interpolator->EvaluateAtContinuousIndex( inputIndex );
929         }
930       VectorType interped;
931       float temp = 0.0;
932       for( jj = 0; jj < ImageDimension; jj++ )
933         {
934         interped[jj] = interpolatedValue[jj];
935         temp += interped[jj] * interped[jj];
936         }
937       pathsteplength += std::sqrt(temp);
938       m_TotalField->SetPixel(index, m_TotalField->GetPixel(index) + interped);
939       ++totalFieldIter;
940       }
941 
942     itkDebugMacro( << " Incremental path length: " << pathsteplength );
943 
944     // Set the field to zero
945     FieldIterator fieldIter( m_Field, m_Field->GetLargestPossibleRegion() );
946     fieldIter.GoToBegin();
947     while( !fieldIter.IsAtEnd()  )
948       {
949       VectorType disp;
950       disp.Fill(0.0);
951       fieldIter.Set(disp);
952       ++fieldIter;
953       }
954 
955     // Now do the same for the solver
956     int numNodes = solver->GetOutput()->GetNumberOfNodes();
957     for( int i = 0; i < numNodes; i++ )
958       {
959       // Now put it into the solution!
960       for( unsigned int ii = 0; ii < ImageDimension; ii++ )
961         {
962         solver->GetLinearSystemWrapper()->
963         SetSolutionValue( (solver->GetOutput()->GetNode(i) )->GetDegreeOfFreedom(
964                               ii), 0.0, solver->GetTotalSolutionIndex() );
965         solver->GetLinearSystemWrapper()->
966         SetSolutionValue( (solver->GetOutput()->GetNode(i) )->GetDegreeOfFreedom(
967                               ii), 0.0, solver->GetSolutionTMinus1Index() );
968         }
969       }
970 
971     warper = WarperType::New();
972     warper->SetInput( m_OriginalMovingImage );
973     warper->SetDisplacementField( m_TotalField );
974     warper->SetInterpolator( interpolator );
975     warper->SetOutputOrigin( m_FixedImage->GetOrigin() );
976     warper->SetOutputSpacing( m_FixedImage->GetSpacing() );
977     warper->SetOutputDirection( m_FixedImage->GetDirection() );
978     typename FixedImageType::PixelType padValue = 0;
979     warper->SetEdgePaddingValue( padValue );
980     warper->Update();
981 
982     // Set it as the new moving image
983     this->SetMovingImage( warper->GetOutput() );
984 
985     m_WarpedImage = m_MovingImage;
986 
987     m_Load->SetMovingImage( this->GetMovingImage() );
988     }
989   itkDebugMacro( << " Enforcing diffeomorphism done " );
990 }
991 
992 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
993 void
SmoothDisplacementField()994 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SmoothDisplacementField()
995 {
996 
997   using GaussianFilterType = RecursiveGaussianImageFilter< FieldType, FieldType >;
998   typename GaussianFilterType::Pointer smoother = GaussianFilterType::New();
999 
1000   for( unsigned int dim = 0; dim < ImageDimension; ++dim )
1001     {
1002     // Sigma accounts for the subsampling of the pyramid
1003     double sigma = m_StandardDeviations[dim];
1004 
1005     smoother->SetInput( m_Field );
1006     smoother->SetSigma(sigma);
1007     smoother->SetDirection(dim);
1008     smoother->Update();
1009 
1010     m_Field = smoother->GetOutput();
1011     m_Field->DisconnectPipeline();
1012     }
1013 }
1014 
1015 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
1016 typename FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::FieldPointer
ExpandVectorField(ExpandFactorsType * expandFactors,FieldType * field)1017 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ExpandVectorField( ExpandFactorsType* expandFactors,
1018                                                                                  FieldType* field)
1019 {
1020   // Re-size the vector field
1021   if( !field )
1022     {
1023     field = m_Field;
1024     }
1025 
1026   itkDebugMacro( << " Input field size: " << m_Field->GetLargestPossibleRegion().GetSize() );
1027   itkDebugMacro( << " Expand factors: " );
1028   VectorType pad;
1029   for( unsigned int i = 0; i < ImageDimension; i++ )
1030     {
1031     pad[i] = 0.0;
1032     itkDebugMacro( << expandFactors[i] << " " );
1033     }
1034   itkDebugMacro( << std::endl );
1035   typename ExpanderType::Pointer m_FieldExpander = ExpanderType::New();
1036   m_FieldExpander->SetInput(field);
1037   m_FieldExpander->SetExpandFactors( expandFactors );
1038   // use default
1039 // TEST_RMV20100728   m_FieldExpander->SetEdgePaddingValue( pad );
1040   m_FieldExpander->UpdateLargestPossibleRegion();
1041 
1042   m_FieldSize = m_FieldExpander->GetOutput()->GetLargestPossibleRegion().GetSize();
1043 
1044   return m_FieldExpander->GetOutput();
1045 }
1046 
1047 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SampleVectorFieldAtNodes(SolverType * solver)1048 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SampleVectorFieldAtNodes(SolverType *solver)
1049 {
1050   // Here, we need to iterate through the nodes, get the nodal coordinates,
1051   // sample the VF at the node and place the values in the SolutionVector.
1052   int numNodes = solver->GetOutput()->GetNumberOfNodes();
1053   Element::VectorType coord;
1054   VectorType SolutionAtNode;
1055 
1056   m_Interpolator->SetInputImage(m_Field);
1057   for( int i = 0; i < numNodes; i++ )
1058     {
1059     coord = solver->GetOutput()->GetNode(i)->GetCoordinates();
1060     typename InterpolatorType::ContinuousIndexType inputIndex;
1061     using InterpolatedType = typename InterpolatorType::OutputType;
1062     InterpolatedType interpolatedValue;
1063     for( unsigned int jj = 0; jj < ImageDimension; jj++ )
1064       {
1065       inputIndex[jj] = (CoordRepType) coord[jj];
1066       interpolatedValue[jj] = 0.0;
1067       }
1068     if( m_Interpolator->IsInsideBuffer( inputIndex ) )
1069       {
1070       interpolatedValue =
1071         m_Interpolator->EvaluateAtContinuousIndex( inputIndex );
1072       }
1073     for( unsigned int jj = 0; jj < ImageDimension; jj++ )
1074       {
1075       SolutionAtNode[jj] = interpolatedValue[jj];
1076       }
1077     // Now put it into the solution!
1078     for( unsigned int ii = 0; ii < ImageDimension; ii++ )
1079       {
1080       Float Sol = SolutionAtNode[ii];
1081       solver->GetLinearSystemWrapper()->
1082       SetSolutionValue(solver->GetOutput()->GetNode(i)->GetDegreeOfFreedom(ii), Sol, solver->GetTotalSolutionIndex() );
1083       solver->GetLinearSystemWrapper()->
1084       SetSolutionValue(solver->GetOutput()->GetNode(i)->GetDegreeOfFreedom(
1085                          ii), Sol, solver->GetSolutionTMinus1Index() );
1086       }
1087     }
1088 }
1089 
1090 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
PrintVectorField(unsigned int modnum)1091 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::PrintVectorField(unsigned int modnum)
1092 {
1093   FieldIterator fieldIter( m_Field, m_Field->GetLargestPossibleRegion() );
1094 
1095   fieldIter.GoToBegin();
1096   unsigned int ct = 0;
1097 
1098   float max = 0;
1099   while( !fieldIter.IsAtEnd()  )
1100     {
1101     VectorType disp = fieldIter.Get();
1102     if( (ct % modnum) == 0 )
1103       {
1104       itkDebugMacro( << " Field pix: " << fieldIter.Get() << std::endl);
1105       }
1106     for( unsigned int i = 0; i < ImageDimension; i++ )
1107       {
1108       if( std::fabs(disp[i]) > max )
1109         {
1110         max = std::fabs(disp[i]);
1111         }
1112       }
1113     ++fieldIter;
1114     ct++;
1115 
1116     }
1117 
1118   itkDebugMacro( << " Max vec: " << max << std::endl );
1119 }
1120 
1121 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
MultiResSolve()1122 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::MultiResSolve()
1123 {
1124   for( m_CurrentLevel = 0; m_CurrentLevel < m_MaxLevel; m_CurrentLevel++ )
1125     {
1126     itkDebugMacro( << " Beginning level " << m_CurrentLevel << std::endl );
1127 
1128     typename SolverType::Pointer solver = SolverType::New();
1129 
1130     if( m_Maxiters[m_CurrentLevel] > 0 )
1131       {
1132       unsigned int meshResolution = this->m_MeshPixelsPerElementAtEachResolution(m_CurrentLevel);
1133 
1134       solver->SetTimeStep(m_TimeStep);
1135       solver->SetRho(m_Rho[m_CurrentLevel]);
1136       solver->SetAlpha(m_Alpha);
1137 
1138       if( m_CreateMeshFromImage )
1139         {
1140         this->CreateMesh(meshResolution, solver);
1141         }
1142       else
1143         {
1144         m_FEMObject = this->GetInputFEMObject( m_CurrentLevel );
1145         }
1146 
1147       this->ApplyLoads(m_FullImageSize, nullptr);
1148       this->ApplyImageLoads(m_MovingImage, m_FixedImage );
1149 
1150 
1151       unsigned int ndofpernode = m_Element->GetNumberOfDegreesOfFreedomPerNode();
1152       unsigned int numnodesperelt = m_Element->GetNumberOfNodes() + 1;
1153       unsigned int ndof = solver->GetInput()->GetNumberOfDegreesOfFreedom();
1154       unsigned int nzelts;
1155       nzelts = numnodesperelt * ndofpernode * ndof;
1156 
1157       // Used when reading a mesh from file
1158       // nzelts=((2*numnodesperelt*ndofpernode*ndof > 25*ndof) ? 2*numnodesperelt*ndofpernode*ndof : 25*ndof);
1159 
1160       LinearSystemWrapperItpack itpackWrapper;
1161       unsigned int maxits = 2 * solver->GetInput()->GetNumberOfDegreesOfFreedom();
1162       itpackWrapper.SetMaximumNumberIterations(maxits);
1163       itpackWrapper.SetTolerance(1.e-1);
1164       itpackWrapper.JacobianConjugateGradient();
1165       itpackWrapper.SetMaximumNonZeroValuesInMatrix(nzelts);
1166       solver->SetLinearSystemWrapper(&itpackWrapper);
1167       solver->SetUseMassMatrix( m_UseMassMatrix );
1168       if( m_CurrentLevel > 0 )
1169         {
1170         this->SampleVectorFieldAtNodes(solver);
1171         }
1172       this->IterativeSolve(solver);
1173       }
1174 
1175     // Now expand the field for the next level, if necessary.
1176 
1177     itkDebugMacro( << " End level: " << m_CurrentLevel );
1178 
1179     } // end image resolution loop
1180 
1181   if( m_TotalField )
1182     {
1183     itkDebugMacro( << " Copy field: " <<  m_TotalField->GetLargestPossibleRegion().GetSize() );
1184     itkDebugMacro( << " To: " << m_Field->GetLargestPossibleRegion().GetSize() << std::endl);
1185     FieldIterator fieldIter( m_TotalField, m_TotalField->GetLargestPossibleRegion() );
1186     fieldIter.GoToBegin();
1187     for(; !fieldIter.IsAtEnd(); ++fieldIter )
1188       {
1189       typename FixedImageType::IndexType index = fieldIter.GetIndex();
1190       m_TotalField->SetPixel(index, m_TotalField->GetPixel(index)
1191                              + m_Field->GetPixel(index) );
1192       }
1193     }
1194 }
1195 
1196 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
EvaluateResidual(SolverType * solver,Float t)1197 Element::Float FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::EvaluateResidual(SolverType *solver,
1198                                                                                               Float t)
1199 {
1200   Float SimE = m_Load->EvaluateMetricGivenSolution(solver->GetOutput()->GetModifiableElementContainer(), t);
1201   Float maxsim = 1.0;
1202   for( unsigned int i = 0; i < ImageDimension; i++ )
1203     {
1204     maxsim *= (Float)m_FullImageSize[i];
1205     }
1206   if( m_WhichMetric != 0 )
1207     {
1208     SimE = maxsim - SimE;
1209     }
1210   return std::fabs( static_cast<double>(SimE) ); // +defe;
1211 }
1212 
1213 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
FindBracketingTriplet(SolverType * solver,Float * a,Float * b,Float * c)1214 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::FindBracketingTriplet(SolverType *solver, Float* a,
1215                                                                                          Float* b,
1216                                                                                          Float* c)
1217 {
1218   // See Numerical Recipes
1219 
1220   constexpr Float Gold  = 1.618034;
1221   constexpr Float Glimit  = 100.0;
1222   const Float Tiny = 1.e-20;
1223   Float ax = 0.0;
1224   Float bx = 1.0;
1225   Float fa = std::fabs( this->EvaluateResidual(solver, ax) );
1226   Float fb = std::fabs( this->EvaluateResidual(solver, bx) );
1227 
1228   Float dum;
1229 
1230   if( fb > fa )
1231     {
1232     dum = ax; ax = bx; bx = dum;
1233     dum = fb; fb = fa; fa = dum;
1234     }
1235 
1236   Float cx = bx + Gold * (bx - ax);  // first guess for c - the 3rd pt needed to bracket the min
1237   Float fc = std::fabs( this->EvaluateResidual(solver, cx) );
1238 
1239   Float ulim, u, r, q, fu;
1240   while( fb > fc )
1241   // && std::fabs(ax) < 3. && std::fabs(bx) < 3. && std::fabs(cx) < 3.)
1242     {
1243     r = (bx - ax) * (fb - fc);
1244     q = (bx - cx) * (fb - fa);
1245     Float denom = (2.0 * solver->GSSign(solver->GSMax(std::fabs(q - r), Tiny), q - r) );
1246     u = (bx) - ( (bx - cx) * q - (bx - ax) * r) / denom;
1247     ulim = bx + Glimit * (cx - bx);
1248     if( (bx - u) * (u - cx) > 0.0 )
1249       {
1250       fu = std::fabs( this->EvaluateResidual(solver, u) );
1251       if( fu < fc )
1252         {
1253         ax = bx;
1254         bx = u;
1255         *a = ax; *b = bx; *c = cx;
1256         return;
1257         }
1258       else if( fu > fb )
1259         {
1260         cx = u;
1261         *a = ax; *b = bx; *c = cx;
1262         return;
1263         }
1264 
1265       u = cx + Gold * (cx - bx);
1266       fu = std::fabs( this->EvaluateResidual(solver, u) );
1267 
1268       }
1269     else if( (cx - u) * (u - ulim) > 0.0 )
1270       {
1271       fu = std::fabs( this->EvaluateResidual(solver, u) );
1272       if( fu < fc )
1273         {
1274         bx = cx; cx = u; u = cx + Gold * (cx - bx);
1275         fb = fc; fc = fu; fu = std::fabs( this->EvaluateResidual(solver, u) );
1276         }
1277 
1278       }
1279     else if( (u - ulim) * (ulim - cx) >= 0.0 )
1280       {
1281       u = ulim;
1282       fu = std::fabs( this->EvaluateResidual(solver, u) );
1283       }
1284     else
1285       {
1286       u = cx + Gold * (cx - bx);
1287       fu = std::fabs( this->EvaluateResidual(solver, u) );
1288       }
1289 
1290     ax = bx; bx = cx; cx = u;
1291     fa = fb; fb = fc; fc = fu;
1292 
1293     }
1294 
1295   if( std::fabs(ax) > 1.e3 || std::fabs(bx) > 1.e3 || std::fabs(cx) > 1.e3 )
1296     {
1297     ax = -2.0; bx = 1.0; cx = 2.0;
1298     } // to avoid crazy numbers caused by bad bracket (u goes nuts)
1299 
1300   *a = ax; *b = bx; *c = cx;
1301 }
1302 
1303 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
GoldenSection(SolverType * solver,Float tol,unsigned int MaxIters)1304 Element::Float FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::GoldenSection(
1305   SolverType *solver, Float tol, unsigned int MaxIters)
1306 {
1307   // We should now have a, b and c, as well as f(a), f(b), f(c),
1308   // where b gives the minimum energy position;
1309   Float ax, bx, cx;
1310 
1311   this->FindBracketingTriplet(solver, &ax, &bx, &cx);
1312 
1313   constexpr Float R  = 0.6180339;
1314   const Float C = (1.0 - R);
1315 
1316   Float x0 = ax;
1317   Float x1;
1318   Float x2;
1319   Float x3 = cx;
1320   if( std::fabs(cx - bx) > std::fabs(bx - ax) )
1321     {
1322     x1 = bx;
1323     x2 = bx + C * (cx - bx);
1324     }
1325   else
1326     {
1327     x2 = bx;
1328     x1 = bx - C * (bx - ax);
1329     }
1330 
1331   Float f1 = std::fabs( this->EvaluateResidual(solver, x1) );
1332   Float f2 = std::fabs( this->EvaluateResidual(solver, x2) );
1333   unsigned int iters = 0;
1334   while( std::fabs(x3 - x0) > tol * (std::fabs(x1) + std::fabs(x2) ) && iters < MaxIters )
1335     {
1336     iters++;
1337     if( f2 < f1 )
1338       {
1339       x0 = x1; x1 = x2; x2 = R * x1 + C * x3;
1340       f1 = f2; f2 = std::fabs( this->EvaluateResidual(solver, x2) );
1341       }
1342     else
1343       {
1344       x3 = x2; x2 = x1; x1 = R * x2 + C * x0;
1345       f2 = f1; f1 = std::fabs( this->EvaluateResidual(solver, x1) );
1346       }
1347     }
1348 
1349   Float xmin, fmin;
1350   if( f1 < f2 )
1351     {
1352     xmin = x1;
1353     fmin = f1;
1354     }
1355   else
1356     {
1357     xmin = x2;
1358     fmin = f2;
1359     }
1360 
1361   solver->SetEnergyToMin(xmin);
1362   return std::fabs( static_cast<double>(fmin) );
1363 }
1364 
1365 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
AddLandmark(PointType source,PointType target)1366 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::AddLandmark(PointType source, PointType target)
1367 {
1368   typename LoadLandmark::Pointer newLandmark = LoadLandmark::New();
1369 
1370   vnl_vector<Float> localSource;
1371   vnl_vector<Float> localTarget;
1372   localSource.set_size(ImageDimension);
1373   localTarget.set_size(ImageDimension);
1374   for( unsigned int i = 0; i < ImageDimension; i++ )
1375     {
1376     localSource[i] = source[i];
1377     localTarget[i] = target[i];
1378     }
1379 
1380   newLandmark->SetSource( localSource );
1381   newLandmark->SetTarget( localTarget );
1382   newLandmark->SetPoint( localSource );
1383 
1384   m_LandmarkArray.push_back( newLandmark );
1385 }
1386 
1387 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
InsertLandmark(unsigned int index,PointType source,PointType target)1388 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::InsertLandmark(unsigned int index, PointType source,
1389                                                                                   PointType target)
1390 {
1391   typename LoadLandmark::Pointer newLandmark = LoadLandmark::New();
1392 
1393   vnl_vector<Float> localSource;
1394   vnl_vector<Float> localTarget;
1395   localSource.set_size(ImageDimension);
1396   localTarget.set_size(ImageDimension);
1397   for( unsigned int i = 0; i < ImageDimension; i++ )
1398     {
1399     localSource[i] = source[i];
1400     localTarget[i] = target[i];
1401     }
1402 
1403   newLandmark->SetSource( localSource );
1404   newLandmark->SetTarget( localTarget );
1405   newLandmark->SetPoint( localSource );
1406 
1407   m_LandmarkArray.insert( m_LandmarkArray.begin() + index, newLandmark );
1408 }
1409 
1410 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
DeleteLandmark(unsigned int index)1411 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::DeleteLandmark(unsigned int index)
1412 {
1413   m_LandmarkArray.erase( m_LandmarkArray.begin() + index );
1414 }
1415 
1416 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
ClearLandmarks()1417 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ClearLandmarks()
1418 {
1419   m_LandmarkArray.clear();
1420 }
1421 
1422 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
GetLandmark(unsigned int index,PointType & source,PointType & target)1423 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::GetLandmark(unsigned int index, PointType& source,
1424                                                                                PointType& target)
1425 {
1426   Element::VectorType localSource;
1427   Element::VectorType localTarget;
1428 
1429   localTarget = m_LandmarkArray[index]->GetTarget();
1430   localSource = m_LandmarkArray[index]->GetSource();
1431   for( unsigned int i = 0; i < ImageDimension; i++ )
1432     {
1433     source[i] = localSource[i];
1434     target[i] = localTarget[i];
1435     }
1436 }
1437 
1438 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
PrintSelf(std::ostream & os,Indent indent) const1439 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::PrintSelf(std::ostream& os, Indent indent) const
1440 {
1441   Superclass::PrintSelf( os, indent );
1442 
1443   os << indent << "Min E: " << m_MinE << std::endl;
1444   os << indent << "Current Level: " << m_CurrentLevel << std::endl;
1445   os << indent << "Max Iterations: " << m_Maxiters << std::endl;
1446   os << indent << "Max Level: " << m_MaxLevel << std::endl;
1447   os << indent << "Total Iterations: " << m_TotalIterations << std::endl;
1448 
1449   os << indent << "Descent Direction: " << m_DescentDirection << std::endl;
1450   os << indent << "E: " << m_E << std::endl;
1451   os << indent << "Gamma: " << m_Gamma << std::endl;
1452   os << indent << "Rho: " << m_Rho << std::endl;
1453   os << indent << "Time Step: " << m_TimeStep << std::endl;
1454   os << indent << "Alpha: " << m_Alpha << std::endl;
1455 
1456   os << indent << "Smoothing Standard Deviation: " << m_StandardDeviations << std::endl;
1457   os << indent << "Maximum Error: " << m_MaximumError << std::endl;
1458   os << indent << "Maximum Kernel Width: " << m_MaximumKernelWidth << std::endl;
1459 
1460   os << indent << "Pixels Per Element: " << m_MeshPixelsPerElementAtEachResolution << std::endl;
1461   os << indent << "Number of Integration Points: " << m_NumberOfIntegrationPoints << std::endl;
1462   os << indent << "Metric Width: " << m_MetricWidth << std::endl;
1463   os << indent << "Line Search Energy = " << m_DoLineSearchOnImageEnergy << std::endl;
1464   os << indent << "Line Search Maximum Iterations: " << m_LineSearchMaximumIterations << std::endl;
1465   os << indent << "Use Mass Matrix: " << m_UseMassMatrix << std::endl;
1466   os << indent << "Employ Regridding: " << m_EmployRegridding << std::endl;
1467 
1468   os << indent << "Use Landmarks: " << m_UseLandmarks << std::endl;
1469   os << indent << "Use Normalized Gradient: " << m_UseNormalizedGradient << std::endl;
1470   os << indent << "Min Jacobian = " << m_MinJacobian << std::endl;
1471 
1472   os << indent << "Image Scaling: " << m_ImageScaling << std::endl;
1473   os << indent << "Current Image Scaling: " << m_CurrentImageScaling << std::endl;
1474   os << indent << "Full Image Size: " << m_FullImageSize << std::endl;
1475   os << indent << "Image Origin: " << m_ImageOrigin << std::endl;
1476   os << indent << "Create Mesh: " << m_CreateMeshFromImage << std::endl;
1477 
1478   itkPrintSelfObjectMacro( Field );
1479   itkPrintSelfObjectMacro( TotalField );
1480   itkPrintSelfObjectMacro( Load );
1481   itkPrintSelfObjectMacro( WarpedImage );
1482   itkPrintSelfObjectMacro( FEMObject );
1483   itkPrintSelfObjectMacro( Interpolator );
1484 }
1485 
1486 } // end namespace fem
1487 } // end namespace itk
1488 
1489 #endif
1490