1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18 #ifndef itkFEMRegistrationFilter_hxx
19 #define itkFEMRegistrationFilter_hxx
20
21 #include "itkFEMRegistrationFilter.h"
22
23 #include "itkFEMElements.h"
24 #include "itkFEMLoadBC.h"
25
26 #include "itkMath.h"
27 #include "itkGroupSpatialObject.h"
28 #include "itkLinearInterpolateImageFunction.h"
29 #include "itkSpatialObject.h"
30 #include "itkFEMObjectSpatialObject.h"
31 #include "itkDisplacementFieldJacobianDeterminantFilter.h"
32 #include "itkStatisticsImageFilter.h"
33 #include "itkRecursiveGaussianImageFilter.h"
34
35 #include "vnl/algo/vnl_determinant.h"
36 #include "itkMath.h"
37
38 namespace itk
39 {
40 namespace fem
41 {
42
43 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
FEMRegistrationFilter()44 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::FEMRegistrationFilter() :
45 m_DoLineSearchOnImageEnergy( 1 ),
46 m_LineSearchMaximumIterations( 100 ),
47 m_TotalIterations( 0 ),
48 m_MaxLevel( 1 ),
49 m_FileCount( 0 ),
50 m_CurrentLevel( 0 ),
51 m_WhichMetric( 0 ),
52 m_TimeStep( 1 ),
53 m_Energy( 0.0 ),
54 m_MinE( vnl_huge_val( 0 ) ),
55 m_MinJacobian( 1.0 ),
56 m_Alpha( 1.0 ),
57 m_UseLandmarks( false ),
58 m_UseMassMatrix( true ),
59 m_UseNormalizedGradient( false ),
60 m_CreateMeshFromImage( true ),
61 m_EmployRegridding( 1 ),
62 m_DescentDirection( positive ),
63 m_EnergyReductionFactor( 0.0 ),
64 m_MaximumError( 0.1 ),
65 m_MaximumKernelWidth( 30 )
66 {
67 this->SetNumberOfRequiredInputs(2);
68
69 m_NumberOfIntegrationPoints.set_size(1);
70 m_NumberOfIntegrationPoints[m_CurrentLevel] = 4;
71
72 m_MetricWidth.set_size(1);
73 m_MetricWidth[m_CurrentLevel] = 3;
74
75 m_Maxiters.set_size(1);
76 m_Maxiters[m_CurrentLevel] = 1;
77
78 m_E.set_size(1);
79 m_E[m_CurrentLevel] = 1.;
80
81 m_Rho.set_size(1);
82 m_Rho[m_CurrentLevel] = 1.;
83
84 m_Gamma.set_size(1);
85 m_Gamma[m_CurrentLevel] = 1;
86
87 m_CurrentLevelImageSize.Fill( 0 );
88
89 m_MeshPixelsPerElementAtEachResolution.set_size(1);
90 m_MeshPixelsPerElementAtEachResolution[m_CurrentLevel] = 1;
91
92 m_ImageScaling.Fill( 1 );
93 m_CurrentImageScaling.Fill( 1 );
94 m_FullImageSize.Fill( 0 );
95 m_ImageOrigin.Fill( 0 );
96 m_StandardDeviations.Fill( 1.0 );
97
98 // Set up the default interpolator
99 typename DefaultInterpolatorType::Pointer interp = DefaultInterpolatorType::New();
100 m_Interpolator = static_cast<InterpolatorType *>( interp.GetPointer() );
101 m_Interpolator->SetInputImage(m_Field);
102 }
103
104 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
~FEMRegistrationFilter()105 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::~FEMRegistrationFilter()
106 {
107 }
108
109 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SetMaxLevel(unsigned int level)110 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SetMaxLevel(unsigned int level)
111 {
112 m_MaxLevel = level;
113
114 m_E.set_size(level);
115 m_Gamma.set_size(level);
116 m_Rho.set_size(level);
117 m_Maxiters.set_size(level);
118 m_MeshPixelsPerElementAtEachResolution.set_size(level);
119 m_NumberOfIntegrationPoints.set_size(level);
120 m_MetricWidth.set_size(level);
121
122 for(unsigned int i = 0; i < level; i++ )
123 {
124 m_Gamma[i] = 1;
125 m_E[i] = 1.0;
126 m_Rho[i] = 1.0;
127 m_Maxiters[i] = 1;
128 m_MeshPixelsPerElementAtEachResolution[i] = 1;
129 m_NumberOfIntegrationPoints[i] = 4;
130 m_MetricWidth[i] = 3;
131 }
132 }
133
134 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SetStandardDeviations(double value)135 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SetStandardDeviations(double value)
136 {
137 unsigned int j;
138
139 for( j = 0; j < ImageDimension; j++ )
140 {
141 if( Math::NotExactlyEquals(value, m_StandardDeviations[j]) )
142 {
143 break;
144 }
145 }
146 if( j < ImageDimension )
147 {
148 this->Modified();
149 for( j = 0; j < ImageDimension; j++ )
150 {
151 m_StandardDeviations[j] = value;
152 }
153 }
154 }
155
156
157 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
RunRegistration()158 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::RunRegistration()
159 {
160
161 MultiResSolve();
162
163 if( m_Field )
164 {
165 if( m_TotalField )
166 {
167 m_Field = m_TotalField;
168 }
169 this->ComputeJacobian( );
170 WarpImage(m_OriginalMovingImage);
171 }
172 }
173
174 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SetMovingImage(MovingImageType * R)175 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SetMovingImage(MovingImageType* R)
176 {
177 m_MovingImage = R;
178 if( m_TotalIterations == 0 )
179 {
180 m_OriginalMovingImage = R;
181 this->ProcessObject::SetNthInput( 0, const_cast<MovingImageType *>( R ) );
182 }
183 }
184
185 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SetFixedImage(FixedImageType * T)186 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SetFixedImage(FixedImageType* T)
187 {
188 m_FixedImage = T;
189 m_FullImageSize = m_FixedImage->GetLargestPossibleRegion().GetSize();
190
191 if( m_TotalIterations == 0 )
192 {
193 this->ProcessObject::SetNthInput( 1, const_cast<FixedImageType *>( T ) );
194 }
195 VectorType disp;
196 for( unsigned int i = 0; i < ImageDimension; i++ )
197 {
198 disp[i] = 0.0;
199 m_ImageOrigin[i] = 0;
200 }
201
202 m_CurrentLevelImageSize = m_FullImageSize;
203 }
204
205 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SetInputFEMObject(FEMObjectType * F,unsigned int level)206 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SetInputFEMObject(FEMObjectType* F,
207 unsigned int level)
208 {
209 this->ProcessObject::SetNthInput( 2 + level, const_cast<FEMObjectType *>( F ) );
210 }
211
212 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
213 typename FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::FEMObjectType *
GetInputFEMObject(unsigned int level)214 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::GetInputFEMObject(unsigned int level)
215 {
216 return static_cast<FEMObjectType *>(this->ProcessObject::GetInput(2 + level) );
217 }
218
219 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
ChooseMetric(unsigned int which)220 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ChooseMetric(unsigned int which)
221 {
222 // Choose the similarity Function
223
224 using MetricType0 = itk::MeanSquareRegistrationFunction<FixedImageType, MovingImageType, FieldType>;
225 using MetricType1 = itk::NCCRegistrationFunction<FixedImageType, MovingImageType, FieldType>;
226 using MetricType2 = itk::MIRegistrationFunction<FixedImageType, MovingImageType, FieldType>;
227 using MetricType3 = itk::DemonsRegistrationFunction<FixedImageType, MovingImageType, FieldType>;
228
229 m_WhichMetric = (unsigned int)which;
230
231 switch( which )
232 {
233 case 0:
234 m_Metric = MetricType0::New();
235 SetDescentDirectionMinimize();
236 break;
237 case 1:
238 m_Metric = MetricType1::New();
239 SetDescentDirectionMinimize();
240 break;
241 case 2:
242 m_Metric = MetricType2::New();
243 SetDescentDirectionMaximize();
244 break;
245 case 3:
246 m_Metric = MetricType3::New();
247 SetDescentDirectionMinimize();
248 break;
249 default:
250 m_Metric = MetricType0::New();
251 SetDescentDirectionMinimize();
252 }
253
254 m_Metric->SetGradientStep( m_Gamma[m_CurrentLevel] );
255 m_Metric->SetNormalizeGradient( m_UseNormalizedGradient );
256 }
257
258 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
WarpImage(const MovingImageType * ImageToWarp)259 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::WarpImage( const MovingImageType * ImageToWarp)
260 {
261 typename WarperType::Pointer warper = WarperType::New();
262 using WarperCoordRepType = typename WarperType::CoordRepType;
263 using InterpolatorType1 =
264 itk::LinearInterpolateImageFunction<MovingImageType, WarperCoordRepType>;
265 typename InterpolatorType1::Pointer interpolator = InterpolatorType1::New();
266
267 warper = WarperType::New();
268 warper->SetInput( ImageToWarp );
269 warper->SetDisplacementField( m_Field );
270 warper->SetInterpolator( interpolator );
271 warper->SetOutputOrigin( m_FixedImage->GetOrigin() );
272 warper->SetOutputSpacing( m_FixedImage->GetSpacing() );
273 warper->SetOutputDirection( m_FixedImage->GetDirection() );
274 typename FixedImageType::PixelType padValue = 0;
275 warper->SetEdgePaddingValue( padValue );
276 warper->Update();
277
278 m_WarpedImage = warper->GetOutput();
279 }
280
281 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
CreateMesh(unsigned int PixelsPerElement,SolverType * solver)282 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::CreateMesh(unsigned int PixelsPerElement,
283 SolverType *solver)
284 {
285 vnl_vector<unsigned int> pixPerElement;
286 pixPerElement.set_size( ImageDimension );
287 pixPerElement.fill( PixelsPerElement );
288
289 if( ImageDimension == 2 && dynamic_cast<Element2DC0LinearQuadrilateral *>(&*m_Element) != nullptr )
290 {
291 m_Material->SetYoungsModulus(this->GetElasticity(m_CurrentLevel) );
292
293 itkDebugMacro( << " Generating regular Quad mesh " << std::endl );
294 typename ImageToMeshType::Pointer meshFilter = ImageToMeshType::New();
295 meshFilter->SetInput( m_MovingImage );
296 meshFilter->SetPixelsPerElement( pixPerElement );
297 meshFilter->SetElement( &*m_Element );
298 meshFilter->SetMaterial( m_Material );
299 meshFilter->Update();
300 m_FEMObject = meshFilter->GetOutput();
301 m_FEMObject->FinalizeMesh();
302 itkDebugMacro( << " Generating regular mesh done " << std::endl );
303 }
304 else if( ImageDimension == 3 && dynamic_cast<Element3DC0LinearHexahedron *>(&*m_Element) != nullptr )
305 {
306 m_Material->SetYoungsModulus( this->GetElasticity(m_CurrentLevel) );
307
308 itkDebugMacro( << " Generating regular Hex mesh " << std::endl );
309 typename ImageToMeshType::Pointer meshFilter = ImageToMeshType::New();
310 meshFilter->SetInput( m_MovingImage );
311 meshFilter->SetPixelsPerElement( pixPerElement );
312 meshFilter->SetElement( &*m_Element );
313 meshFilter->SetMaterial( m_Material );
314 meshFilter->Update();
315 m_FEMObject = meshFilter->GetOutput();
316 m_FEMObject->FinalizeMesh();
317 itkDebugMacro( << " Generating regular mesh done " << std::endl );
318 }
319 else
320 {
321 FEMException e(__FILE__, __LINE__);
322 e.SetDescription("CreateMesh - wrong image or element type");
323 e.SetLocation(ITK_LOCATION);
324 throw e;
325 }
326
327 if( m_UseLandmarks )
328 {
329 for( unsigned int i = 0; i < m_LandmarkArray.size(); i++ )
330 {
331 m_FEMObject->AddNextLoad( (m_LandmarkArray[i]) );
332 }
333 }
334
335 solver->SetInput(m_FEMObject);
336 solver->InitializeInterpolationGrid(m_FixedImage->GetBufferedRegion(),
337 m_FixedImage->GetOrigin(),
338 m_FixedImage->GetSpacing(),
339 m_FixedImage->GetDirection());
340 }
341
342 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
343 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>
ApplyImageLoads(TMovingImage * movingimg,TFixedImage * fixedimg)344 ::ApplyImageLoads(TMovingImage* movingimg, TFixedImage* fixedimg )
345 {
346 m_Load = FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ImageMetricLoadType::New();
347 m_Load->SetMovingImage(movingimg);
348 m_Load->SetFixedImage(fixedimg);
349 if( !m_Field )
350 {
351 this->InitializeField();
352 }
353 m_Load->SetDisplacementField( this->GetDisplacementField() );
354 m_Load->SetMetric(m_Metric);
355 m_Load->InitializeMetric();
356 m_Load->SetGamma(m_Gamma[m_CurrentLevel]);
357 ImageSizeType r;
358 for( unsigned int dd = 0; dd < ImageDimension; dd++ )
359 {
360 r[dd] = m_MetricWidth[m_CurrentLevel];
361 }
362 m_Load->SetMetricRadius(r);
363 m_Load->SetNumberOfIntegrationPoints(m_NumberOfIntegrationPoints[m_CurrentLevel]);
364 m_Load->SetGlobalNumber(m_FEMObject->GetNumberOfLoads() + 1);
365 if( m_DescentDirection == positive )
366 {
367 m_Load->SetDescentDirectionMinimize( );
368 }
369 else
370 {
371 m_Load->SetDescentDirectionMaximize( );
372 }
373 m_FEMObject->AddNextLoad(m_Load);
374 m_Load = dynamic_cast<typename FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ImageMetricLoadType *>
375 (&*m_FEMObject->GetLoadWithGlobalNumber(m_FEMObject->GetNumberOfLoads() ) );
376 }
377
template <typename TMovingImage, typename TFixedImage, typename TFemObject>
void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ApplyLoads(
  ImageSizeType ImgSz, double* scaling)
{
  // Add landmark loads and boundary-condition loads to m_FEMObject.
  //
  // Pass 1 (landmarks), for each entry of m_LandmarkArray:
  //   - optionally rescale it (ScalePointAndForce) when `scaling` is non-null;
  //   - locate the mesh element containing its source point;
  //   - append the object returned by CreateAnother() to m_FEMObject
  //     (presumably a copy of the landmark - TODO confirm in LoadLandmark).
  // Pass 2 (boundary conditions): for nodes whose coordinates lie on the image
  //   boundary in every dimension (image corners), look at the nodes of each
  //   element sharing that corner node. Every such node that lies on the
  //   boundary in exactly ImageDimension - 1 dimensions gets a LoadBC fixing
  //   each of its degrees of freedom to 0.
  //   NOTE(review): the corner node itself is not constrained, only these
  //   edge neighbours.
  //
  // ImgSz   : image size. A coordinate equal to ImgSz[d] - 1 counts as the upper
  //           boundary and one equal to m_ImageOrigin[d] as the lower boundary
  //           (this assumes node coordinates are in index units - TODO confirm).
  // scaling : optional per-dimension factors for ScalePointAndForce; may be nullptr.
  //
  // Apply the boundary conditions. We pin the image corners.
  // First compute which elements these will be.
  //
  itkDebugMacro( << " Applying loads " );

  vnl_vector<Float> pd; pd.set_size(ImageDimension);
  vnl_vector<Float> pu; pu.set_size(ImageDimension);

  // Now scale the landmarks

  itkDebugMacro( " Number of LM loads: " << m_LandmarkArray.size() );

  // Step over all the loads again to scale them by the global landmark weight.
  if( !m_LandmarkArray.empty() )
  {
    for( unsigned int lmind = 0; lmind < m_LandmarkArray.size(); lmind++ )
    {
      // Clear the element association; it stays nullptr if no element contains the point.
      m_LandmarkArray[lmind]->GetElementArray()[0] = nullptr;

      itkDebugMacro( << " Prescale Pt: " << m_LandmarkArray[lmind]->GetTarget() );
      if( scaling )
      {
        m_LandmarkArray[lmind]->ScalePointAndForce(scaling, m_EnergyReductionFactor);
        itkDebugMacro( << " Postscale Pt: " << m_LandmarkArray[lmind]->GetTarget() << "; scale: " << scaling[0] );
      }

      pu = m_LandmarkArray[lmind]->GetSource();
      pd = m_LandmarkArray[lmind]->GetPoint();

      // Linear search over all elements. There is no early exit, so if several
      // elements contain the point, the last one found wins.
      int numElements = m_FEMObject->GetNumberOfElements();
      for( int i = 0; i < numElements; i++ )
      {
        if( m_FEMObject->GetElement(i)->GetLocalFromGlobalCoordinates(pu, pd ) )
        {
          m_LandmarkArray[lmind]->SetPoint(pd);
          m_LandmarkArray[lmind]->GetElementArray()[0] = m_FEMObject->GetElement(i);
        }
      }

      // Global number == index in m_LandmarkArray; EnforceDiffeomorphism looks
      // landmark loads up by this number.
      m_LandmarkArray[lmind]->SetGlobalNumber(lmind);
      LoadLandmark::Pointer l5 = dynamic_cast<LoadLandmark *>( &*m_LandmarkArray[lmind]->CreateAnother() );
      m_FEMObject->AddNextLoad(l5);
    }
    itkDebugMacro( << " Landmarks done" );
  }

  // Now apply the BC loads
  LoadBC::Pointer l1;

  // Pin one corner of image
  unsigned int CornerCounter, ii, EdgeCounter = 0;

  int numNodes = m_FEMObject->GetNumberOfNodes();

  Element::VectorType coord;

  bool EdgeFound;
  unsigned int nodect = 0;
  for( int i = 0; i < numNodes; i++ )
  {
    // Stop once ImageDimension edge nodes have been constrained (counted over
    // all elements, so a node shared by several elements counts repeatedly).
    if( EdgeCounter >= ImageDimension )
    {
      return;
    }

    // Count in how many dimensions this node lies on the lower or upper image boundary.
    coord = m_FEMObject->GetNode(i)->GetCoordinates();
    CornerCounter = 0;
    for( ii = 0; ii < ImageDimension; ii++ )
    {
      if( Math::AlmostEquals( coord[ii], m_ImageOrigin[ii] )
          || Math::AlmostEquals( coord[ii], ImgSz[ii] - 1 ) )
      {
        CornerCounter++;
      }
    }

    if( CornerCounter == ImageDimension ) // the node is located at a true corner
    {
      // DOF count and node count are taken from the first element sharing the node;
      // all elements are assumed to be of the same type.
      unsigned int ndofpernode = (*(m_FEMObject->GetNode(i)->m_elements.begin() ) )->GetNumberOfDegreesOfFreedomPerNode();
      unsigned int numnodesperelt = (*(m_FEMObject->GetNode(i)->m_elements.begin() ) )->GetNumberOfNodes();
      unsigned int whichnode;

      unsigned int maxnode = numnodesperelt - 1;

      for( auto elt = m_FEMObject->GetNode(i)->m_elements.begin();
           elt != m_FEMObject->GetNode(i)->m_elements.end(); elt++ )
      {
        for( whichnode = 0; whichnode <= maxnode; whichnode++ )
        {
          coord = (*elt)->GetNode(whichnode)->GetCoordinates();
          CornerCounter = 0;
          for( ii = 0; ii < ImageDimension; ii++ )
          {
            if( Math::AlmostEquals( coord[ii], m_ImageOrigin[ii] )
                || Math::AlmostEquals( coord[ii], ImgSz[ii] - 1 ) )
            {
              CornerCounter++;
            }
          }
          // A node on the boundary in all but one dimension lies on an image edge.
          if( CornerCounter == ImageDimension - 1 )
          {
            EdgeFound = true;
          }
          else
          {
            EdgeFound = false;
          }
          if( EdgeFound )
          {
            // Fix every DOF of this edge node to zero displacement.
            for( unsigned int jj = 0; jj < ndofpernode; jj++ )
            {
              itkDebugMacro( " Which node " << whichnode );
              itkDebugMacro( " Edge coord " << coord );

              l1 = LoadBC::New();
              // now we get the element from the node -- we assume we need fix the dof only once
              // even if more than one element shares it.
              // NOTE(review): no de-duplication is done here; a node shared by
              // several elements around the corner gets one LoadBC per element.

              l1->SetElement(Element::ConstPointer(*elt));
              unsigned int localdof = whichnode * ndofpernode + jj;
              l1->SetDegreeOfFreedom(localdof);
              l1->SetValue(vnl_vector<double>(1, 0.0) );

              m_FEMObject->AddNextLoad(l1);
            }
            EdgeCounter++;
          }
        }
      } // end elt loop
    }
    nodect++;
    itkDebugMacro( << " Node: " << nodect );
  }
}
516
template <typename TMovingImage, typename TFixedImage, typename TFemObject>
void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::IterativeSolve(SolverType *solver)
{
  // Run the FEM solve / displacement update loop for the current resolution level.
  //
  // Each iteration:
  //   1. resets the image-metric energy;
  //   2. re-solves the linear system;
  //   3. measures the energy change, optionally running a golden-section line search;
  //   4. adds the scaled solution to the displacements;
  //   5. rasterizes the result into m_Field;
  //   6. optionally regrids via EnforceDiffeomorphism.
  //
  // The loop stops when any of these holds:
  //   - m_Maxiters[m_CurrentLevel] is reached;
  //   - the energy change is exactly zero;
  //   - (with line search enabled) the energy got worse twice in a row after 5 iterations;
  //   - the largest solution value drops below 0.01 after 2 iterations.
  //
  // Throws (itkExceptionMacro) if no image load or no displacement field is set.
  if( !m_Load )
  {
    itkExceptionMacro( <<"No Load set" );
  }

  bool Done = false;
  unsigned int iters = 0;
  m_MinE = 10.e99;
  Float deltE = 0;
  while( !Done && iters < m_Maxiters[m_CurrentLevel] )
  {
    const Float lastdeltE = deltE;
    // 0 = no line search, 2 = golden-section line search when energy worsens;
    // any value > 0 also enables the "worsened twice" stopping rule below.
    const unsigned int DLS = m_DoLineSearchOnImageEnergy;
    // Reset the variational similarity term to zero.

    Float LastE = m_Load->GetCurrentEnergy();
    m_Load->SetCurrentEnergy(0.0);
    m_Load->InitializeMetric();

    if( !m_Field )
    {
      itkExceptionMacro( <<"No Field set" );
    }
    solver->SetUseMassMatrix( m_UseMassMatrix );

    // Solve the system of equations for displacements (u=K^-1*F)
    solver->Modified();
    solver->Update();
    m_Load->PrintCurrentEnergy();

    // deltE > 0 means the energy improved in the current descent direction.
    // NOTE(review): 1 is presumably the value of `positive` (minimize); ApplyImageLoads
    // compares against `positive` directly - confirm against the header.
    if( m_DescentDirection == 1 )
    {
      deltE = (LastE - m_Load->GetCurrentEnergy() );
    }
    else
    {
      deltE = (m_Load->GetCurrentEnergy() - LastE );
    }

    if( DLS == 2 && deltE < 0.0 )
    {
      itkDebugMacro( << " Line search " );
      constexpr float tol = 1.0; // ((0.01 < LastE) ? 0.01 : LastE/10.);
      LastE = this->GoldenSection(solver, tol, m_LineSearchMaximumIterations);
      deltE = (m_MinE - LastE);
      itkDebugMacro( << " Line search done " << std::endl );
    }

    iters++;

    // Exact floating-point comparison: only a bit-identical energy stops here.
    if( deltE == 0.0 )
    {
      itkDebugMacro( << " No change in energy " << std::endl);
      Done = true;
    }
    if( (DLS == 0) && ( iters >= m_Maxiters[m_CurrentLevel] ) )
    {
      Done = true;
    }
    else if( (DLS > 0) &&
             ( iters >= m_Maxiters[m_CurrentLevel] || (deltE < 0.0 && iters > 5 && lastdeltE < 0.0) ) )
    {
      Done = true;
    }
    // Step size: gamma divided by the largest solution value, clamped to at most 1
    // (a zero maximum is treated as 1 to avoid dividing by zero).
    float curmaxsol = solver->GetCurrentMaxSolution();
    if( Math::AlmostEquals( curmaxsol, 0.0f ) )
    {
      curmaxsol = 1.0;
    }
    Float mint = m_Gamma[m_CurrentLevel] / curmaxsol;
    if( mint > 1 )
    {
      mint = 1.0;
    }

    if( solver->GetCurrentMaxSolution() < 0.01 && iters > 2 )
    {
      Done = true;
    }
    solver->AddToDisplacements(mint);
    m_MinE = LastE;

    // Rasterize the updated FEM solution into m_Field.
    InterpolateVectorField(solver);

    // Every m_EmployRegridding iterations, check the Jacobian and smooth the field
    // if needed. onlywriteimages == true, so EnforceDiffeomorphism only refreshes
    // m_WarpedImage and does not accumulate into m_TotalField.
    if( m_EmployRegridding != 0 )
    {
      if( iters % m_EmployRegridding == 0 )
      {
        this->EnforceDiffeomorphism(1.0, solver, true);
      }
    }
    itkDebugMacro( << " min E: " << m_MinE << "; delt E: " << deltE << "; iters: " << iters << std::endl );
    m_TotalIterations++;
  }
}
615
616 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
617 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>
InitializeField()618 ::InitializeField()
619 {
620 m_Field = FieldType::New();
621
622 m_FieldRegion.SetSize(m_CurrentLevelImageSize );
623 m_Field->SetOrigin( m_FixedImage->GetOrigin() );
624 m_Field->SetSpacing( m_FixedImage->GetSpacing() );
625 m_Field->SetDirection( m_FixedImage->GetDirection() );
626 m_Field->SetLargestPossibleRegion( m_FieldRegion );
627 m_Field->SetBufferedRegion( m_FieldRegion );
628 m_Field->SetLargestPossibleRegion( m_FieldRegion );
629 m_Field->Allocate();
630
631 VectorType disp;
632 for( unsigned int t = 0; t < ImageDimension; t++ )
633 {
634 disp[t] = 0.0;
635 }
636 FieldIterator fieldIter( m_Field, m_FieldRegion );
637 fieldIter.GoToBegin();
638 for(; !fieldIter.IsAtEnd(); ++fieldIter )
639 {
640 fieldIter.Set(disp);
641 }
642 }
643
644 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
645 void
InterpolateVectorField(SolverType * solver)646 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::InterpolateVectorField(SolverType *solver)
647 {
648 typename FieldType::Pointer field = m_Field;
649
650 if( !field )
651 {
652 this->InitializeField();
653 }
654 m_FieldSize = field->GetLargestPossibleRegion().GetSize();
655
656 itkDebugMacro( << " Interpolating vector field of size " << m_FieldSize);
657
658 Float rstep, sstep, tstep;
659
660 vnl_vector<double> Pos; // solution at the point
661 vnl_vector<double> Sol; // solution at the local point
662 vnl_vector<double> Gpt; // global position given by local point
663
664 VectorType disp;
665 for( unsigned int t = 0; t < ImageDimension; t++ )
666 {
667 disp[t] = 0.0;
668 }
669 FieldIterator fieldIter( field, field->GetLargestPossibleRegion() );
670
671 fieldIter.GoToBegin();
672 typename FixedImageType::IndexType rindex = fieldIter.GetIndex();
673
674 Sol.set_size(ImageDimension);
675 Gpt.set_size(ImageDimension);
676
677 if( ImageDimension == 2 )
678 {
679 Element::ConstPointer eltp;
680 for(; !fieldIter.IsAtEnd(); ++fieldIter )
681 {
682 // Get element pointer from the solver elt pointer image
683 typename FieldType::PointType physicalPoint;
684 rindex = fieldIter.GetIndex();
685 field->TransformIndexToPhysicalPoint(rindex, physicalPoint);
686 for( unsigned int d = 0; d < ImageDimension; d++ )
687 {
688 Gpt[d] = (double) (physicalPoint[d]);
689 }
690
691 eltp = solver->GetElementAtPoint(Gpt);
692 if( eltp )
693 {
694 eltp->GetLocalFromGlobalCoordinates(Gpt, Pos);
695
696 unsigned int Nnodes = eltp->GetNumberOfNodes();
697 typename Element::VectorType shapef(Nnodes);
698 shapef = eltp->ShapeFunctions(Pos);
699 Float solval;
700 for( unsigned int f = 0; f < ImageDimension; f++ )
701 {
702 solval = 0.0;
703 for( unsigned int n = 0; n < Nnodes; n++ )
704 {
705 solval += shapef[n] * solver->GetLinearSystem()->GetSolutionValue(
706 eltp->GetNode(n)->GetDegreeOfFreedom(f), solver->GetTotalSolutionIndex() );
707 }
708 Sol[f] = solval;
709 disp[f] = (Float) 1.0 * Sol[f];
710 }
711 field->SetPixel(rindex, disp );
712 }
713 }
714 }
715
716 if( ImageDimension == 3 )
717 {
718 // FIXME SHOULD BE 2.0 over meshpixperelt
719 rstep = 1.25 / ( (double)m_MeshPixelsPerElementAtEachResolution[m_CurrentLevel]);
720 sstep = 1.25 / ( (double)m_MeshPixelsPerElementAtEachResolution[m_CurrentLevel]);
721 tstep = 1.25 / ( (double)m_MeshPixelsPerElementAtEachResolution[m_CurrentLevel]);
722
723 Pos.set_size(ImageDimension);
724 int numElements = solver->GetInput()->GetNumberOfElements();
725 for( int i = 0; i < numElements; i++ )
726 {
727 Element::Pointer eltp = solver->GetInput()->GetElement(i);
728 for( double r = -1.0; r <= 1.0; r = r + rstep )
729 {
730 for( double s = -1.0; s <= 1.0; s = s + sstep )
731 {
732 for( double t = -1.0; t <= 1.0; t = t + tstep )
733 {
734 Pos[0] = r;
735 Pos[1] = s;
736 Pos[2] = t;
737
738 unsigned int numNodes = eltp->GetNumberOfNodes();
739 typename Element::VectorType shapef(numNodes);
740
741 #define FASTHEX
742 #ifdef FASTHEX
743 // FIXME temporarily using hexahedron shape f for speed
744 shapef[0] = (1 - r) * (1 - s) * (1 - t) * 0.125;
745 shapef[1] = (1 + r) * (1 - s) * (1 - t) * 0.125;
746 shapef[2] = (1 + r) * (1 + s) * (1 - t) * 0.125;
747 shapef[3] = (1 - r) * (1 + s) * (1 - t) * 0.125;
748 shapef[4] = (1 - r) * (1 - s) * (1 + t) * 0.125;
749 shapef[5] = (1 + r) * (1 - s) * (1 + t) * 0.125;
750 shapef[6] = (1 + r) * (1 + s) * (1 + t) * 0.125;
751 shapef[7] = (1 - r) * (1 + s) * (1 + t) * 0.125;
752 #else
753 shapef = (*eltp)->ShapeFunctions(Pos);
754 #endif
755
756 Float solval, posval;
757 bool inimage = true;
758 typename FixedImageType::PointType physicalPoint;
759
760 for( unsigned int f = 0; f < ImageDimension; f++ )
761 {
762 solval = 0.0;
763 posval = 0.0;
764 for( unsigned int n = 0; n < numNodes; n++ )
765 {
766 posval += shapef[n] * ( ( (eltp)->GetNodeCoordinates(n) )[f]);
767 solval += shapef[n] * solver->GetLinearSystem()->GetSolutionValue(
768 (eltp)->GetNode(n)->GetDegreeOfFreedom(f), solver->GetTotalSolutionIndex() );
769 }
770 Sol[f] = solval;
771 Gpt[f] = posval;
772 disp[f] = (Float) 1.0 * Sol[f];
773 physicalPoint[f] = Gpt[f];
774 }
775
776 inimage = m_FixedImage->TransformPhysicalPointToIndex(physicalPoint, rindex);
777 if( inimage )
778 {
779 field->SetPixel(rindex, disp);
780 }
781 }
782 }
783 } // end of for loops
784 } // end of elt array loop
785
786 }
787
788 // Ensure that the values are exact at the nodes. They won't necessarily be unless we use this code.
789 itkDebugMacro( << " Interpolation done " << std::endl);
790 }
791
792 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
ComputeJacobian()793 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ComputeJacobian( )
794 {
795 m_MinJacobian = 1.0;
796
797 using JacobianFilterType = typename itk::DisplacementFieldJacobianDeterminantFilter< FieldType, float, FloatImageType >;
798 typename JacobianFilterType::Pointer jacobianFilter = JacobianFilterType::New();
799 jacobianFilter->SetInput( m_Field );
800 jacobianFilter->Update( );
801 m_FloatImage = jacobianFilter->GetOutput();
802
803 using StatisticsFilterType = typename itk::StatisticsImageFilter< FloatImageType >;
804 typename StatisticsFilterType::Pointer statisticsFilter = StatisticsFilterType::New();
805 statisticsFilter->SetInput( m_FloatImage );
806 statisticsFilter->Update( );
807
808 m_MinJacobian = statisticsFilter->GetMinimum();
809
810 itkDebugMacro( << " Min Jacobian: " << m_MinJacobian << std::endl);
811 }
812
813 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
EnforceDiffeomorphism(float thresh,SolverType * solver,bool onlywriteimages)814 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::EnforceDiffeomorphism(float thresh,
815 SolverType *solver,
816 bool onlywriteimages )
817 {
818 itkDebugMacro( << " Checking Jacobian using threshold " << thresh );
819
820 this->ComputeJacobian();
821
822 if( m_MinJacobian < thresh )
823 {
824 // Smooth deformation field
825 this->SmoothDisplacementField();
826 }
827
828 typename WarperType::Pointer warper = WarperType::New();
829 using WarperCoordRepType = typename WarperType::CoordRepType;
830 using InterpolatorType1 =
831 itk::LinearInterpolateImageFunction<MovingImageType, WarperCoordRepType>;
832 typename InterpolatorType1::Pointer interpolator = InterpolatorType1::New();
833
834 // If using landmarks, warp them
835 if( m_UseLandmarks )
836 {
837 itkDebugMacro( << " Warping landmarks: " << m_LandmarkArray.size() );
838
839 if( !m_LandmarkArray.empty() )
840 {
841 for( unsigned int lmind = 0; lmind < m_LandmarkArray.size(); lmind++ )
842 {
843 itkDebugMacro( << " Old source: " << m_LandmarkArray[lmind]->GetSource() );
844 itkDebugMacro( << " Target: " << m_LandmarkArray[lmind]->GetTarget() );
845
846 // Convert the source to warped coords.
847 m_LandmarkArray[lmind]->GetSource() = m_LandmarkArray[lmind]->GetSource()
848 + (dynamic_cast<LoadLandmark *>( &*solver->GetOutput()->GetLoadWithGlobalNumber(lmind) )->GetForce() );
849 itkDebugMacro( << " New source: " << m_LandmarkArray[lmind]->GetSource() );
850 itkDebugMacro( << " Target: " << m_LandmarkArray[lmind]->GetTarget() );
851 LoadLandmark::Pointer l5 = dynamic_cast<LoadLandmark *>( &*m_LandmarkArray[lmind]->CreateAnother() );
852 solver->GetOutput()->AddNextLoad(l5);
853 }
854 itkDebugMacro( << " Warping landmarks done " );
855 }
856 else
857 {
858 itkDebugMacro( << " Landmark array empty " );
859 }
860 }
861
862 // Store the total deformation by composing with the full field
863 if( !m_TotalField && !onlywriteimages )
864 {
865 itkDebugMacro( << " Allocating total deformation field " );
866
867 m_TotalField = FieldType::New();
868
869 m_FieldRegion.SetSize(m_Field->GetLargestPossibleRegion().GetSize() );
870 m_TotalField->SetLargestPossibleRegion( m_FieldRegion );
871 m_TotalField->SetBufferedRegion( m_FieldRegion );
872 m_TotalField->SetLargestPossibleRegion( m_FieldRegion );
873 m_TotalField->Allocate();
874
875 VectorType disp;
876 disp.Fill(0.0);
877
878 FieldIterator fieldIter( m_TotalField, m_FieldRegion );
879
880 for( fieldIter.GoToBegin(); !fieldIter.IsAtEnd(); ++fieldIter )
881 {
882 fieldIter.Set(disp);
883 }
884 }
885
886 if( onlywriteimages )
887 {
888 warper = WarperType::New();
889 warper->SetInput( m_OriginalMovingImage );
890 warper->SetDisplacementField( m_Field );
891 warper->SetInterpolator( interpolator );
892 warper->SetOutputOrigin( m_FixedImage->GetOrigin() );
893 warper->SetOutputSpacing( m_FixedImage->GetSpacing() );
894 warper->SetOutputDirection( m_FixedImage->GetDirection() );
895 typename MovingImageType::PixelType padValue = 0;
896 warper->SetEdgePaddingValue( padValue );
897 warper->Update();
898 m_WarpedImage = warper->GetOutput();
899 }
900 else if( m_TotalField )
901 {
902
903 typename InterpolatorType::ContinuousIndexType inputIndex;
904
905 using InterpolatedType = typename InterpolatorType::OutputType;
906
907 InterpolatedType interpolatedValue;
908
909 m_Interpolator->SetInputImage(m_Field);
910
911 typename FixedImageType::IndexType index;
912 FieldIterator totalFieldIter( m_TotalField, m_TotalField->GetLargestPossibleRegion() );
913 totalFieldIter.GoToBegin();
914 unsigned int jj;
915 float pathsteplength = 0;
916 while( !totalFieldIter.IsAtEnd() )
917 {
918 index = totalFieldIter.GetIndex();
919 for( jj = 0; jj < ImageDimension; jj++ )
920 {
921 inputIndex[jj] = (WarperCoordRepType) index[jj];
922 interpolatedValue[jj] = 0.0;
923 }
924
925 if ( m_Interpolator->IsInsideBuffer( inputIndex ) )
926 {
927 interpolatedValue =
928 m_Interpolator->EvaluateAtContinuousIndex( inputIndex );
929 }
930 VectorType interped;
931 float temp = 0.0;
932 for( jj = 0; jj < ImageDimension; jj++ )
933 {
934 interped[jj] = interpolatedValue[jj];
935 temp += interped[jj] * interped[jj];
936 }
937 pathsteplength += std::sqrt(temp);
938 m_TotalField->SetPixel(index, m_TotalField->GetPixel(index) + interped);
939 ++totalFieldIter;
940 }
941
942 itkDebugMacro( << " Incremental path length: " << pathsteplength );
943
944 // Set the field to zero
945 FieldIterator fieldIter( m_Field, m_Field->GetLargestPossibleRegion() );
946 fieldIter.GoToBegin();
947 while( !fieldIter.IsAtEnd() )
948 {
949 VectorType disp;
950 disp.Fill(0.0);
951 fieldIter.Set(disp);
952 ++fieldIter;
953 }
954
955 // Now do the same for the solver
956 int numNodes = solver->GetOutput()->GetNumberOfNodes();
957 for( int i = 0; i < numNodes; i++ )
958 {
959 // Now put it into the solution!
960 for( unsigned int ii = 0; ii < ImageDimension; ii++ )
961 {
962 solver->GetLinearSystemWrapper()->
963 SetSolutionValue( (solver->GetOutput()->GetNode(i) )->GetDegreeOfFreedom(
964 ii), 0.0, solver->GetTotalSolutionIndex() );
965 solver->GetLinearSystemWrapper()->
966 SetSolutionValue( (solver->GetOutput()->GetNode(i) )->GetDegreeOfFreedom(
967 ii), 0.0, solver->GetSolutionTMinus1Index() );
968 }
969 }
970
971 warper = WarperType::New();
972 warper->SetInput( m_OriginalMovingImage );
973 warper->SetDisplacementField( m_TotalField );
974 warper->SetInterpolator( interpolator );
975 warper->SetOutputOrigin( m_FixedImage->GetOrigin() );
976 warper->SetOutputSpacing( m_FixedImage->GetSpacing() );
977 warper->SetOutputDirection( m_FixedImage->GetDirection() );
978 typename FixedImageType::PixelType padValue = 0;
979 warper->SetEdgePaddingValue( padValue );
980 warper->Update();
981
982 // Set it as the new moving image
983 this->SetMovingImage( warper->GetOutput() );
984
985 m_WarpedImage = m_MovingImage;
986
987 m_Load->SetMovingImage( this->GetMovingImage() );
988 }
989 itkDebugMacro( << " Enforcing diffeomorphism done " );
990 }
991
992 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
993 void
SmoothDisplacementField()994 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SmoothDisplacementField()
995 {
996
997 using GaussianFilterType = RecursiveGaussianImageFilter< FieldType, FieldType >;
998 typename GaussianFilterType::Pointer smoother = GaussianFilterType::New();
999
1000 for( unsigned int dim = 0; dim < ImageDimension; ++dim )
1001 {
1002 // Sigma accounts for the subsampling of the pyramid
1003 double sigma = m_StandardDeviations[dim];
1004
1005 smoother->SetInput( m_Field );
1006 smoother->SetSigma(sigma);
1007 smoother->SetDirection(dim);
1008 smoother->Update();
1009
1010 m_Field = smoother->GetOutput();
1011 m_Field->DisconnectPipeline();
1012 }
1013 }
1014
1015 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
1016 typename FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::FieldPointer
ExpandVectorField(ExpandFactorsType * expandFactors,FieldType * field)1017 FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ExpandVectorField( ExpandFactorsType* expandFactors,
1018 FieldType* field)
1019 {
1020 // Re-size the vector field
1021 if( !field )
1022 {
1023 field = m_Field;
1024 }
1025
1026 itkDebugMacro( << " Input field size: " << m_Field->GetLargestPossibleRegion().GetSize() );
1027 itkDebugMacro( << " Expand factors: " );
1028 VectorType pad;
1029 for( unsigned int i = 0; i < ImageDimension; i++ )
1030 {
1031 pad[i] = 0.0;
1032 itkDebugMacro( << expandFactors[i] << " " );
1033 }
1034 itkDebugMacro( << std::endl );
1035 typename ExpanderType::Pointer m_FieldExpander = ExpanderType::New();
1036 m_FieldExpander->SetInput(field);
1037 m_FieldExpander->SetExpandFactors( expandFactors );
1038 // use default
1039 // TEST_RMV20100728 m_FieldExpander->SetEdgePaddingValue( pad );
1040 m_FieldExpander->UpdateLargestPossibleRegion();
1041
1042 m_FieldSize = m_FieldExpander->GetOutput()->GetLargestPossibleRegion().GetSize();
1043
1044 return m_FieldExpander->GetOutput();
1045 }
1046
1047 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
SampleVectorFieldAtNodes(SolverType * solver)1048 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::SampleVectorFieldAtNodes(SolverType *solver)
1049 {
1050 // Here, we need to iterate through the nodes, get the nodal coordinates,
1051 // sample the VF at the node and place the values in the SolutionVector.
1052 int numNodes = solver->GetOutput()->GetNumberOfNodes();
1053 Element::VectorType coord;
1054 VectorType SolutionAtNode;
1055
1056 m_Interpolator->SetInputImage(m_Field);
1057 for( int i = 0; i < numNodes; i++ )
1058 {
1059 coord = solver->GetOutput()->GetNode(i)->GetCoordinates();
1060 typename InterpolatorType::ContinuousIndexType inputIndex;
1061 using InterpolatedType = typename InterpolatorType::OutputType;
1062 InterpolatedType interpolatedValue;
1063 for( unsigned int jj = 0; jj < ImageDimension; jj++ )
1064 {
1065 inputIndex[jj] = (CoordRepType) coord[jj];
1066 interpolatedValue[jj] = 0.0;
1067 }
1068 if( m_Interpolator->IsInsideBuffer( inputIndex ) )
1069 {
1070 interpolatedValue =
1071 m_Interpolator->EvaluateAtContinuousIndex( inputIndex );
1072 }
1073 for( unsigned int jj = 0; jj < ImageDimension; jj++ )
1074 {
1075 SolutionAtNode[jj] = interpolatedValue[jj];
1076 }
1077 // Now put it into the solution!
1078 for( unsigned int ii = 0; ii < ImageDimension; ii++ )
1079 {
1080 Float Sol = SolutionAtNode[ii];
1081 solver->GetLinearSystemWrapper()->
1082 SetSolutionValue(solver->GetOutput()->GetNode(i)->GetDegreeOfFreedom(ii), Sol, solver->GetTotalSolutionIndex() );
1083 solver->GetLinearSystemWrapper()->
1084 SetSolutionValue(solver->GetOutput()->GetNode(i)->GetDegreeOfFreedom(
1085 ii), Sol, solver->GetSolutionTMinus1Index() );
1086 }
1087 }
1088 }
1089
1090 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
PrintVectorField(unsigned int modnum)1091 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::PrintVectorField(unsigned int modnum)
1092 {
1093 FieldIterator fieldIter( m_Field, m_Field->GetLargestPossibleRegion() );
1094
1095 fieldIter.GoToBegin();
1096 unsigned int ct = 0;
1097
1098 float max = 0;
1099 while( !fieldIter.IsAtEnd() )
1100 {
1101 VectorType disp = fieldIter.Get();
1102 if( (ct % modnum) == 0 )
1103 {
1104 itkDebugMacro( << " Field pix: " << fieldIter.Get() << std::endl);
1105 }
1106 for( unsigned int i = 0; i < ImageDimension; i++ )
1107 {
1108 if( std::fabs(disp[i]) > max )
1109 {
1110 max = std::fabs(disp[i]);
1111 }
1112 }
1113 ++fieldIter;
1114 ct++;
1115
1116 }
1117
1118 itkDebugMacro( << " Max vec: " << max << std::endl );
1119 }
1120
1121 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
MultiResSolve()1122 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::MultiResSolve()
1123 {
1124 for( m_CurrentLevel = 0; m_CurrentLevel < m_MaxLevel; m_CurrentLevel++ )
1125 {
1126 itkDebugMacro( << " Beginning level " << m_CurrentLevel << std::endl );
1127
1128 typename SolverType::Pointer solver = SolverType::New();
1129
1130 if( m_Maxiters[m_CurrentLevel] > 0 )
1131 {
1132 unsigned int meshResolution = this->m_MeshPixelsPerElementAtEachResolution(m_CurrentLevel);
1133
1134 solver->SetTimeStep(m_TimeStep);
1135 solver->SetRho(m_Rho[m_CurrentLevel]);
1136 solver->SetAlpha(m_Alpha);
1137
1138 if( m_CreateMeshFromImage )
1139 {
1140 this->CreateMesh(meshResolution, solver);
1141 }
1142 else
1143 {
1144 m_FEMObject = this->GetInputFEMObject( m_CurrentLevel );
1145 }
1146
1147 this->ApplyLoads(m_FullImageSize, nullptr);
1148 this->ApplyImageLoads(m_MovingImage, m_FixedImage );
1149
1150
1151 unsigned int ndofpernode = m_Element->GetNumberOfDegreesOfFreedomPerNode();
1152 unsigned int numnodesperelt = m_Element->GetNumberOfNodes() + 1;
1153 unsigned int ndof = solver->GetInput()->GetNumberOfDegreesOfFreedom();
1154 unsigned int nzelts;
1155 nzelts = numnodesperelt * ndofpernode * ndof;
1156
1157 // Used when reading a mesh from file
1158 // nzelts=((2*numnodesperelt*ndofpernode*ndof > 25*ndof) ? 2*numnodesperelt*ndofpernode*ndof : 25*ndof);
1159
1160 LinearSystemWrapperItpack itpackWrapper;
1161 unsigned int maxits = 2 * solver->GetInput()->GetNumberOfDegreesOfFreedom();
1162 itpackWrapper.SetMaximumNumberIterations(maxits);
1163 itpackWrapper.SetTolerance(1.e-1);
1164 itpackWrapper.JacobianConjugateGradient();
1165 itpackWrapper.SetMaximumNonZeroValuesInMatrix(nzelts);
1166 solver->SetLinearSystemWrapper(&itpackWrapper);
1167 solver->SetUseMassMatrix( m_UseMassMatrix );
1168 if( m_CurrentLevel > 0 )
1169 {
1170 this->SampleVectorFieldAtNodes(solver);
1171 }
1172 this->IterativeSolve(solver);
1173 }
1174
1175 // Now expand the field for the next level, if necessary.
1176
1177 itkDebugMacro( << " End level: " << m_CurrentLevel );
1178
1179 } // end image resolution loop
1180
1181 if( m_TotalField )
1182 {
1183 itkDebugMacro( << " Copy field: " << m_TotalField->GetLargestPossibleRegion().GetSize() );
1184 itkDebugMacro( << " To: " << m_Field->GetLargestPossibleRegion().GetSize() << std::endl);
1185 FieldIterator fieldIter( m_TotalField, m_TotalField->GetLargestPossibleRegion() );
1186 fieldIter.GoToBegin();
1187 for(; !fieldIter.IsAtEnd(); ++fieldIter )
1188 {
1189 typename FixedImageType::IndexType index = fieldIter.GetIndex();
1190 m_TotalField->SetPixel(index, m_TotalField->GetPixel(index)
1191 + m_Field->GetPixel(index) );
1192 }
1193 }
1194 }
1195
1196 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
EvaluateResidual(SolverType * solver,Float t)1197 Element::Float FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::EvaluateResidual(SolverType *solver,
1198 Float t)
1199 {
1200 Float SimE = m_Load->EvaluateMetricGivenSolution(solver->GetOutput()->GetModifiableElementContainer(), t);
1201 Float maxsim = 1.0;
1202 for( unsigned int i = 0; i < ImageDimension; i++ )
1203 {
1204 maxsim *= (Float)m_FullImageSize[i];
1205 }
1206 if( m_WhichMetric != 0 )
1207 {
1208 SimE = maxsim - SimE;
1209 }
1210 return std::fabs( static_cast<double>(SimE) ); // +defe;
1211 }
1212
1213 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
FindBracketingTriplet(SolverType * solver,Float * a,Float * b,Float * c)1214 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::FindBracketingTriplet(SolverType *solver, Float* a,
1215 Float* b,
1216 Float* c)
1217 {
1218 // See Numerical Recipes
1219
1220 constexpr Float Gold = 1.618034;
1221 constexpr Float Glimit = 100.0;
1222 const Float Tiny = 1.e-20;
1223 Float ax = 0.0;
1224 Float bx = 1.0;
1225 Float fa = std::fabs( this->EvaluateResidual(solver, ax) );
1226 Float fb = std::fabs( this->EvaluateResidual(solver, bx) );
1227
1228 Float dum;
1229
1230 if( fb > fa )
1231 {
1232 dum = ax; ax = bx; bx = dum;
1233 dum = fb; fb = fa; fa = dum;
1234 }
1235
1236 Float cx = bx + Gold * (bx - ax); // first guess for c - the 3rd pt needed to bracket the min
1237 Float fc = std::fabs( this->EvaluateResidual(solver, cx) );
1238
1239 Float ulim, u, r, q, fu;
1240 while( fb > fc )
1241 // && std::fabs(ax) < 3. && std::fabs(bx) < 3. && std::fabs(cx) < 3.)
1242 {
1243 r = (bx - ax) * (fb - fc);
1244 q = (bx - cx) * (fb - fa);
1245 Float denom = (2.0 * solver->GSSign(solver->GSMax(std::fabs(q - r), Tiny), q - r) );
1246 u = (bx) - ( (bx - cx) * q - (bx - ax) * r) / denom;
1247 ulim = bx + Glimit * (cx - bx);
1248 if( (bx - u) * (u - cx) > 0.0 )
1249 {
1250 fu = std::fabs( this->EvaluateResidual(solver, u) );
1251 if( fu < fc )
1252 {
1253 ax = bx;
1254 bx = u;
1255 *a = ax; *b = bx; *c = cx;
1256 return;
1257 }
1258 else if( fu > fb )
1259 {
1260 cx = u;
1261 *a = ax; *b = bx; *c = cx;
1262 return;
1263 }
1264
1265 u = cx + Gold * (cx - bx);
1266 fu = std::fabs( this->EvaluateResidual(solver, u) );
1267
1268 }
1269 else if( (cx - u) * (u - ulim) > 0.0 )
1270 {
1271 fu = std::fabs( this->EvaluateResidual(solver, u) );
1272 if( fu < fc )
1273 {
1274 bx = cx; cx = u; u = cx + Gold * (cx - bx);
1275 fb = fc; fc = fu; fu = std::fabs( this->EvaluateResidual(solver, u) );
1276 }
1277
1278 }
1279 else if( (u - ulim) * (ulim - cx) >= 0.0 )
1280 {
1281 u = ulim;
1282 fu = std::fabs( this->EvaluateResidual(solver, u) );
1283 }
1284 else
1285 {
1286 u = cx + Gold * (cx - bx);
1287 fu = std::fabs( this->EvaluateResidual(solver, u) );
1288 }
1289
1290 ax = bx; bx = cx; cx = u;
1291 fa = fb; fb = fc; fc = fu;
1292
1293 }
1294
1295 if( std::fabs(ax) > 1.e3 || std::fabs(bx) > 1.e3 || std::fabs(cx) > 1.e3 )
1296 {
1297 ax = -2.0; bx = 1.0; cx = 2.0;
1298 } // to avoid crazy numbers caused by bad bracket (u goes nuts)
1299
1300 *a = ax; *b = bx; *c = cx;
1301 }
1302
1303 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
GoldenSection(SolverType * solver,Float tol,unsigned int MaxIters)1304 Element::Float FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::GoldenSection(
1305 SolverType *solver, Float tol, unsigned int MaxIters)
1306 {
1307 // We should now have a, b and c, as well as f(a), f(b), f(c),
1308 // where b gives the minimum energy position;
1309 Float ax, bx, cx;
1310
1311 this->FindBracketingTriplet(solver, &ax, &bx, &cx);
1312
1313 constexpr Float R = 0.6180339;
1314 const Float C = (1.0 - R);
1315
1316 Float x0 = ax;
1317 Float x1;
1318 Float x2;
1319 Float x3 = cx;
1320 if( std::fabs(cx - bx) > std::fabs(bx - ax) )
1321 {
1322 x1 = bx;
1323 x2 = bx + C * (cx - bx);
1324 }
1325 else
1326 {
1327 x2 = bx;
1328 x1 = bx - C * (bx - ax);
1329 }
1330
1331 Float f1 = std::fabs( this->EvaluateResidual(solver, x1) );
1332 Float f2 = std::fabs( this->EvaluateResidual(solver, x2) );
1333 unsigned int iters = 0;
1334 while( std::fabs(x3 - x0) > tol * (std::fabs(x1) + std::fabs(x2) ) && iters < MaxIters )
1335 {
1336 iters++;
1337 if( f2 < f1 )
1338 {
1339 x0 = x1; x1 = x2; x2 = R * x1 + C * x3;
1340 f1 = f2; f2 = std::fabs( this->EvaluateResidual(solver, x2) );
1341 }
1342 else
1343 {
1344 x3 = x2; x2 = x1; x1 = R * x2 + C * x0;
1345 f2 = f1; f1 = std::fabs( this->EvaluateResidual(solver, x1) );
1346 }
1347 }
1348
1349 Float xmin, fmin;
1350 if( f1 < f2 )
1351 {
1352 xmin = x1;
1353 fmin = f1;
1354 }
1355 else
1356 {
1357 xmin = x2;
1358 fmin = f2;
1359 }
1360
1361 solver->SetEnergyToMin(xmin);
1362 return std::fabs( static_cast<double>(fmin) );
1363 }
1364
1365 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
AddLandmark(PointType source,PointType target)1366 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::AddLandmark(PointType source, PointType target)
1367 {
1368 typename LoadLandmark::Pointer newLandmark = LoadLandmark::New();
1369
1370 vnl_vector<Float> localSource;
1371 vnl_vector<Float> localTarget;
1372 localSource.set_size(ImageDimension);
1373 localTarget.set_size(ImageDimension);
1374 for( unsigned int i = 0; i < ImageDimension; i++ )
1375 {
1376 localSource[i] = source[i];
1377 localTarget[i] = target[i];
1378 }
1379
1380 newLandmark->SetSource( localSource );
1381 newLandmark->SetTarget( localTarget );
1382 newLandmark->SetPoint( localSource );
1383
1384 m_LandmarkArray.push_back( newLandmark );
1385 }
1386
1387 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
InsertLandmark(unsigned int index,PointType source,PointType target)1388 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::InsertLandmark(unsigned int index, PointType source,
1389 PointType target)
1390 {
1391 typename LoadLandmark::Pointer newLandmark = LoadLandmark::New();
1392
1393 vnl_vector<Float> localSource;
1394 vnl_vector<Float> localTarget;
1395 localSource.set_size(ImageDimension);
1396 localTarget.set_size(ImageDimension);
1397 for( unsigned int i = 0; i < ImageDimension; i++ )
1398 {
1399 localSource[i] = source[i];
1400 localTarget[i] = target[i];
1401 }
1402
1403 newLandmark->SetSource( localSource );
1404 newLandmark->SetTarget( localTarget );
1405 newLandmark->SetPoint( localSource );
1406
1407 m_LandmarkArray.insert( m_LandmarkArray.begin() + index, newLandmark );
1408 }
1409
1410 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
DeleteLandmark(unsigned int index)1411 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::DeleteLandmark(unsigned int index)
1412 {
1413 m_LandmarkArray.erase( m_LandmarkArray.begin() + index );
1414 }
1415
1416 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
ClearLandmarks()1417 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::ClearLandmarks()
1418 {
1419 m_LandmarkArray.clear();
1420 }
1421
1422 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
GetLandmark(unsigned int index,PointType & source,PointType & target)1423 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::GetLandmark(unsigned int index, PointType& source,
1424 PointType& target)
1425 {
1426 Element::VectorType localSource;
1427 Element::VectorType localTarget;
1428
1429 localTarget = m_LandmarkArray[index]->GetTarget();
1430 localSource = m_LandmarkArray[index]->GetSource();
1431 for( unsigned int i = 0; i < ImageDimension; i++ )
1432 {
1433 source[i] = localSource[i];
1434 target[i] = localTarget[i];
1435 }
1436 }
1437
1438 template <typename TMovingImage, typename TFixedImage, typename TFemObject>
PrintSelf(std::ostream & os,Indent indent) const1439 void FEMRegistrationFilter<TMovingImage, TFixedImage, TFemObject>::PrintSelf(std::ostream& os, Indent indent) const
1440 {
1441 Superclass::PrintSelf( os, indent );
1442
1443 os << indent << "Min E: " << m_MinE << std::endl;
1444 os << indent << "Current Level: " << m_CurrentLevel << std::endl;
1445 os << indent << "Max Iterations: " << m_Maxiters << std::endl;
1446 os << indent << "Max Level: " << m_MaxLevel << std::endl;
1447 os << indent << "Total Iterations: " << m_TotalIterations << std::endl;
1448
1449 os << indent << "Descent Direction: " << m_DescentDirection << std::endl;
1450 os << indent << "E: " << m_E << std::endl;
1451 os << indent << "Gamma: " << m_Gamma << std::endl;
1452 os << indent << "Rho: " << m_Rho << std::endl;
1453 os << indent << "Time Step: " << m_TimeStep << std::endl;
1454 os << indent << "Alpha: " << m_Alpha << std::endl;
1455
1456 os << indent << "Smoothing Standard Deviation: " << m_StandardDeviations << std::endl;
1457 os << indent << "Maximum Error: " << m_MaximumError << std::endl;
1458 os << indent << "Maximum Kernel Width: " << m_MaximumKernelWidth << std::endl;
1459
1460 os << indent << "Pixels Per Element: " << m_MeshPixelsPerElementAtEachResolution << std::endl;
1461 os << indent << "Number of Integration Points: " << m_NumberOfIntegrationPoints << std::endl;
1462 os << indent << "Metric Width: " << m_MetricWidth << std::endl;
1463 os << indent << "Line Search Energy = " << m_DoLineSearchOnImageEnergy << std::endl;
1464 os << indent << "Line Search Maximum Iterations: " << m_LineSearchMaximumIterations << std::endl;
1465 os << indent << "Use Mass Matrix: " << m_UseMassMatrix << std::endl;
1466 os << indent << "Employ Regridding: " << m_EmployRegridding << std::endl;
1467
1468 os << indent << "Use Landmarks: " << m_UseLandmarks << std::endl;
1469 os << indent << "Use Normalized Gradient: " << m_UseNormalizedGradient << std::endl;
1470 os << indent << "Min Jacobian = " << m_MinJacobian << std::endl;
1471
1472 os << indent << "Image Scaling: " << m_ImageScaling << std::endl;
1473 os << indent << "Current Image Scaling: " << m_CurrentImageScaling << std::endl;
1474 os << indent << "Full Image Size: " << m_FullImageSize << std::endl;
1475 os << indent << "Image Origin: " << m_ImageOrigin << std::endl;
1476 os << indent << "Create Mesh: " << m_CreateMeshFromImage << std::endl;
1477
1478 itkPrintSelfObjectMacro( Field );
1479 itkPrintSelfObjectMacro( TotalField );
1480 itkPrintSelfObjectMacro( Load );
1481 itkPrintSelfObjectMacro( WarpedImage );
1482 itkPrintSelfObjectMacro( FEMObject );
1483 itkPrintSelfObjectMacro( Interpolator );
1484 }
1485
1486 } // end namespace fem
1487 } // end namespace itk
1488
1489 #endif
1490