1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkImageRegistrationMethodv4_hxx
19 #define itkImageRegistrationMethodv4_hxx
20 
21 #include "itkImageRegistrationMethodv4.h"
22 
23 #include "itkSmoothingRecursiveGaussianImageFilter.h"
24 #include "itkGradientDescentOptimizerv4.h"
25 #include "itkImageRandomConstIteratorWithIndex.h"
26 #include "itkImageRegionConstIteratorWithIndex.h"
27 #include "itkImageToImageMetricv4.h"
28 #include "itkIterationReporter.h"
29 #include "itkMattesMutualInformationImageToImageMetricv4.h"
30 #include "itkMersenneTwisterRandomVariateGenerator.h"
31 #include "itkRegistrationParameterScalesFromPhysicalShift.h"
32 
33 namespace itk
34 {
35 /**
36  * Constructor
37  */
38 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
39 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
ImageRegistrationMethodv4()40 ::ImageRegistrationMethodv4()
41 {
42   ProcessObject::SetNumberOfRequiredOutputs( 1 );
43   Self::SetPrimaryOutputName( "Transform" );
44 
45   // indexed input are alternating fixed and moving images
46   Self::SetPrimaryInputName( "Fixed" );
47   Self::AddRequiredInputName( "Moving", 1 );
48   ProcessObject::SetNumberOfRequiredInputs( 2 );
49 
50   // optional named inputs
51   Self::SetInput( "InitialTransform", nullptr );
52   Self::SetInput( "FixedInitialTransform", nullptr );
53   Self::SetInput( "MovingInitialTransform", nullptr );
54 
55   this->m_VirtualDomainImage = nullptr;
56 
57   Self::ReleaseDataBeforeUpdateFlagOff();
58 
59   this->m_CurrentLevel = 0;
60   this->m_CurrentIteration = 0;
61   this->m_CurrentMetricValue = 0.0;
62   this->m_CurrentConvergenceValue = 0.0;
63   this->m_IsConverged = false;
64   this->m_NumberOfFixedObjects = 0;
65   this->m_NumberOfMovingObjects = 0;
66 
67   Self::ReleaseDataBeforeUpdateFlagOff();
68 
69   this->m_InPlace = true;
70 
71   this->m_InitializeCenterOfLinearOutputTransform = true;
72 
73   this->m_CompositeTransform = CompositeTransformType::New();
74 
75   using DefaultMetricType = MattesMutualInformationImageToImageMetricv4<FixedImageType, MovingImageType, VirtualImageType, RealType>;
76   typename DefaultMetricType::Pointer mutualInformationMetric = DefaultMetricType::New();
77   mutualInformationMetric->SetNumberOfHistogramBins( 20 );
78   mutualInformationMetric->SetUseMovingImageGradientFilter( false );
79   mutualInformationMetric->SetUseFixedImageGradientFilter( false );
80   mutualInformationMetric->SetUseSampledPointSet( false );
81   this->m_Metric = mutualInformationMetric;
82 
83   using DefaultScalesEstimatorType = RegistrationParameterScalesFromPhysicalShift<DefaultMetricType>;
84   typename DefaultScalesEstimatorType::Pointer scalesEstimator = DefaultScalesEstimatorType::New();
85   scalesEstimator->SetMetric( mutualInformationMetric );
86   scalesEstimator->SetTransformForward( true );
87 
88   using DefaultOptimizerType = GradientDescentOptimizerv4Template<RealType>;
89   typename DefaultOptimizerType::Pointer optimizer = DefaultOptimizerType::New();
90   optimizer->SetLearningRate( 1.0 );
91   optimizer->SetNumberOfIterations( 1000 );
92   optimizer->SetScalesEstimator( scalesEstimator );
93   this->m_Optimizer = optimizer;
94 
95   this->m_OptimizerWeights.SetSize( 0 );
96   this->m_OptimizerWeightsAreIdentity = true;
97 
98   DecoratedOutputTransformPointer transformDecorator =
99         itkDynamicCastInDebugMode< DecoratedOutputTransformType * >( this->MakeOutput(0).GetPointer() );
100   this->ProcessObject::SetNthOutput( 0, transformDecorator );
101   this->m_OutputTransform = transformDecorator->GetModifiable();
102 
103   // By default we set up a 3-level image registration.
104 
105   this->m_NumberOfLevels = 0;
106   this->SetNumberOfLevels( 3 );
107 
108   this->m_ShrinkFactorsPerLevel.resize( this->m_NumberOfLevels );
109   ShrinkFactorsPerDimensionContainerType shrinkFactors;
110   shrinkFactors.Fill( 2 );
111   this->m_ShrinkFactorsPerLevel[0] = shrinkFactors;
112   shrinkFactors.Fill( 1 );
113   this->m_ShrinkFactorsPerLevel[1] = shrinkFactors;
114   shrinkFactors.Fill( 1 );
115   this->m_ShrinkFactorsPerLevel[2] = shrinkFactors;
116 
117   this->m_SmoothingSigmasPerLevel.SetSize( this->m_NumberOfLevels );
118   this->m_SmoothingSigmasPerLevel[0] = 2;
119   this->m_SmoothingSigmasPerLevel[1] = 1;
120   this->m_SmoothingSigmasPerLevel[2] = 0;
121 
122   this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits = true;
123 
124   this->m_ReseedIterator = false;
125   this->m_RandomSeed = Statistics::MersenneTwisterRandomVariateGenerator::GetNextSeed();
126   this->m_CurrentRandomSeed = this->m_RandomSeed;
127 
128   this->m_MetricSamplingStrategy = NONE;
129   this->m_MetricSamplingPercentagePerLevel.SetSize( this->m_NumberOfLevels );
130   this->m_MetricSamplingPercentagePerLevel.Fill( 1.0 );
131 }
132 
133 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
134 void
135 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetFixedImage(SizeValueType index,const FixedImageType * image)136 ::SetFixedImage( SizeValueType index, const FixedImageType *image )
137 {
138   itkDebugMacro( "setting fixed image input " << index << " to " << image );
139   if( image != static_cast<FixedImageType *>( this->ProcessObject::GetInput( 2 * index ) ) )
140     {
141     if( !this->ProcessObject::GetInput( 2 * index ) )
142       {
143       this->m_NumberOfFixedObjects++;
144       }
145     this->ProcessObject::SetNthInput( 2 * index, const_cast<FixedImageType *>( image ) );
146     this->Modified();
147     }
148 }
149 
150 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
151 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::FixedImageType *
152 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetFixedImage(SizeValueType index) const153 ::GetFixedImage( SizeValueType index ) const
154 {
155   itkDebugMacro( "returning fixed image input " << index << " of "
156                                     << static_cast<const FixedImageType *>( this->ProcessObject::GetInput( 2 * index ) ) );
157   return static_cast<const FixedImageType *>( this->ProcessObject::GetInput( 2 * index ) );
158 }
159 
160 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
161 void
162 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetMovingImage(SizeValueType index,const MovingImageType * image)163 ::SetMovingImage( SizeValueType index, const MovingImageType *image )
164 {
165   itkDebugMacro( "setting moving image input " << index << " to " << image );
166   if( image != static_cast<MovingImageType *>( this->ProcessObject::GetInput( 2 * index + 1 ) ) )
167     {
168     if( !this->ProcessObject::GetInput( 2 * index + 1 ) )
169       {
170       this->m_NumberOfMovingObjects++;
171       }
172     this->ProcessObject::SetNthInput( 2 * index + 1, const_cast<MovingImageType *>( image ) );
173     this->Modified();
174     }
175 }
176 
177 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
178 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::MovingImageType *
179 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetMovingImage(SizeValueType index) const180 ::GetMovingImage( SizeValueType index ) const
181 {
182   itkDebugMacro( "returning moving image input " << index << " of "
183                                     << static_cast<const MovingImageType *>( this->ProcessObject::GetInput( 2 * index + 1 ) ) );
184   return static_cast<const MovingImageType *>( this->ProcessObject::GetInput( 2 * index + 1 ) );
185 }
186 
187 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
188 void
189 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetFixedPointSet(SizeValueType index,const PointSetType * pointSet)190 ::SetFixedPointSet( SizeValueType index, const PointSetType *pointSet )
191 {
192   itkDebugMacro( "setting fixed point set input " << index << " to " << pointSet );
193   if( pointSet != static_cast<PointSetType *>( this->ProcessObject::GetInput( 2 * index ) ) )
194     {
195     if( !this->ProcessObject::GetInput( 2 * index ) )
196       {
197       this->m_NumberOfFixedObjects++;
198       }
199     this->ProcessObject::SetNthInput( 2 * index, const_cast<PointSetType *>( pointSet ) );
200     this->Modified();
201     }
202 }
203 
204 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
205 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::PointSetType *
206 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetFixedPointSet(SizeValueType index) const207 ::GetFixedPointSet( SizeValueType index ) const
208 {
209   itkDebugMacro( "returning fixed point set input " << index << " of "
210                                     << static_cast<const PointSetType *>( this->ProcessObject::GetInput( 2 * index ) ) );
211   return static_cast<const PointSetType *>( this->ProcessObject::GetInput( 2 * index ) );
212 }
213 
214 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
215 void
216 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetMovingPointSet(SizeValueType index,const PointSetType * pointSet)217 ::SetMovingPointSet( SizeValueType index, const PointSetType *pointSet )
218 {
219   itkDebugMacro( "setting moving point set input " << index << " to " << pointSet );
220   if( pointSet != static_cast<PointSetType *>( this->ProcessObject::GetInput( 2 * index + 1 ) ) )
221     {
222     if( !this->ProcessObject::GetInput( 2 * index + 1 ) )
223       {
224       this->m_NumberOfMovingObjects++;
225       }
226     this->ProcessObject::SetNthInput( 2 * index + 1, const_cast<PointSetType *>( pointSet ) );
227     this->Modified();
228     }
229 }
230 
231 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
232 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::PointSetType *
233 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetMovingPointSet(SizeValueType index) const234 ::GetMovingPointSet( SizeValueType index ) const
235 {
236   itkDebugMacro( "returning moving point set input " << index << " of "
237                                     << static_cast<const PointSetType *>( this->ProcessObject::GetInput( 2 * index + 1 ) ) );
238   return static_cast<const PointSetType *>( this->ProcessObject::GetInput( 2 * index + 1 ) );
239 }
240 
241 /*
242  * Set optimizer weights and do checking for identity.
243  */
244 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
245 void
246 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetOptimizerWeights(OptimizerWeightsType & weights)247 ::SetOptimizerWeights( OptimizerWeightsType & weights )
248 {
249   if( weights != this->m_OptimizerWeights )
250     {
251     itkDebugMacro( "setting optimizer weights to " << weights );
252 
253     this->m_OptimizerWeights = weights;
254 
255     // Check to see if optimizer weights are identity to avoid unnecessary
256     // computations.
257 
258     this->m_OptimizerWeightsAreIdentity = true;
259     if( this->m_OptimizerWeights.Size() > 0 )
260       {
261       using OptimizerWeightsValueType = typename OptimizerWeightsType::ValueType;
262       auto tolerance = static_cast<OptimizerWeightsValueType>( 1e-4 );
263 
264       for( SizeValueType i = 0; i < this->m_OptimizerWeights.Size(); i++ )
265         {
266         OptimizerWeightsValueType difference =
267           std::fabs( NumericTraits<OptimizerWeightsValueType>::OneValue() - this->m_OptimizerWeights[i] );
268         if( difference > tolerance  )
269           {
270           this->m_OptimizerWeightsAreIdentity = false;
271           break;
272           }
273         }
274       }
275     this->Modified();
276     }
277 }
278 
279 /*
280  * Initialize by setting the interconnects between components.
281  */
282 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
283 void
284 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
InitializeRegistrationAtEachLevel(const SizeValueType level)285 ::InitializeRegistrationAtEachLevel( const SizeValueType level )
286 {
287 
288   // To avoid casting to a multimetric several times, we do it once and use it
289   // throughout this function if the current enumerated metric type is MULTI_METRIC
290   typename MultiMetricType::Pointer multiMetric = dynamic_cast<MultiMetricType *>( this->m_Metric.GetPointer() );
291 
292   // Sanity checks and find the virtual domain image
293 
294   if( level == 0 )
295     {
296     SizeValueType numberOfObjectPairs = static_cast<unsigned int>( 0.5 * this->GetNumberOfIndexedInputs() );
297     if( numberOfObjectPairs == 0 )
298       {
299       itkExceptionMacro( "There are no input objects." );
300       }
301 
302     if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
303       {
304       this->m_NumberOfMetrics = multiMetric->GetNumberOfMetrics();
305       if( this->m_NumberOfMetrics != numberOfObjectPairs )
306         {
307         itkExceptionMacro( "Mismatch between number of image pairs and the number of metrics." );
308         }
309       }
310     else
311       {
312       this->m_NumberOfMetrics = 1;
313       }
314 
315     // The number of image pairs also includes nullptr image pairs for the point set
316     // metrics
317     if( this->m_NumberOfFixedObjects != this->m_NumberOfMovingObjects )
318       {
319       itkExceptionMacro( "The number of fixed and moving images is not equal." );
320       }
321     }
322 
323   if( !this->m_Optimizer )
324     {
325     itkExceptionMacro( "The optimizer is not present." );
326     }
327   if( !this->m_Metric )
328     {
329     itkExceptionMacro( "The metric is not present." );
330     }
331 
332   auto * movingInitialTransform = const_cast<InitialTransformType*>( this->GetMovingInitialTransform() );
333   auto * fixedInitialTransform = const_cast<InitialTransformType*>( this->GetFixedInitialTransform() );
334 
335   this->m_CurrentIteration = 0;
336   this->m_CurrentMetricValue = 0.0;
337   this->m_CurrentConvergenceValue = 0.0;
338   this->m_IsConverged = false;
339 
340   this->InvokeEvent( MultiResolutionIterationEvent() );
341 
342   // For each level, we adapt the current transform.  For many transforms, e.g.
343   // affine, the base transform adaptor does not do anything.  However, in the
344   // case of other transforms, e.g. the b-spline and displacement field transforms
345   // the fixed parameters are changed to reflect an increase in transform resolution.
346   // This could involve increasing the mesh size of the B-spline transform or
347   // increase the resolution of the displacement field.
348 
349   if( this->m_TransformParametersAdaptorsPerLevel[level] )
350     {
351     this->m_TransformParametersAdaptorsPerLevel[level]->SetTransform( this->m_OutputTransform );
352     this->m_TransformParametersAdaptorsPerLevel[level]->AdaptTransformParameters();
353     }
354 
355   // Set-up the composite transform at initialization
356   // Also, find the virtual domain image
357   if( level == 0 )
358     {
359     this->m_CompositeTransform->ClearTransformQueue();
360 
361     // Since we cannot instantiate a null object from an abstract class, we need to initialize the moving
362     // initial transform as an identity transform.
363     // Nevertheless, we do not need add this transform to the composite transform when it is only an
364     // identity transform. Simply by not setting that, we can save lots of time in jacobian computations
365     // of the composite transform since we can avoid some matrix multiplications.
366 
367     // Skip adding an IdentityTransform to the m_CompositeTransform
368     if( movingInitialTransform != nullptr &&
369       std::string( movingInitialTransform->GetNameOfClass() ) != std::string( "IdentityTransform" ) )
370       {
371       this->m_CompositeTransform->AddTransform( movingInitialTransform );
372       }
373 
374     // If the moving initial transform is a composite transform, unroll
375     // it into m_CompositeTransform.
376     this->m_CompositeTransform->FlattenTransformQueue();
377 
378     if( this->m_InitializeCenterOfLinearOutputTransform )
379       {
380       this->InitializeCenterOfLinearOutputTransform();
381       }
382     this->m_CompositeTransform->AddTransform( this->m_OutputTransform );
383 
384     if( this->m_OptimizerWeights.Size() > 0 )
385       {
386       this->m_Optimizer->SetWeights( this->m_OptimizerWeights );
387       }
388 
389     // Get index of first image metric
390     this->m_FirstImageMetricIndex = -1;
391     if( this->m_NumberOfMetrics == 1 && this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC )
392       {
393       this->m_FirstImageMetricIndex = 0;
394       }
395     else if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
396       {
397       for( SizeValueType n = 0; n < this->m_NumberOfMetrics; n++ )
398         {
399         if( multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::IMAGE_METRIC )
400           {
401           this->m_FirstImageMetricIndex = n;
402           break;
403           }
404         }
405       }
406 
407     {
408     VirtualImageBaseConstPointer virtualDomainBaseImage = this->GetCurrentLevelVirtualDomainImage();
409 
410     if( virtualDomainBaseImage.IsNull() && this->m_FirstImageMetricIndex >= 0 )
411       {
412       virtualDomainBaseImage =  this->GetFixedImage( this->m_FirstImageMetricIndex );
413       }
414     this->m_VirtualDomainImage = VirtualImageType::New();
415     this->m_VirtualDomainImage->CopyInformation( virtualDomainBaseImage );
416     this->m_VirtualDomainImage->SetRegions( virtualDomainBaseImage->GetLargestPossibleRegion() );
417     this->m_VirtualDomainImage->Allocate();
418     }
419 
420     this->m_FixedImageMasks.clear();
421     this->m_FixedImageMasks.resize( this->m_NumberOfMetrics );
422     this->m_MovingImageMasks.clear();
423     this->m_MovingImageMasks.resize( this->m_NumberOfMetrics );
424 
425     for( SizeValueType n = 0; n < this->m_NumberOfMetrics; n++ )
426       {
427       this->m_FixedImageMasks[n] = nullptr;
428       this->m_MovingImageMasks[n] = nullptr;
429 
430       if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC ||
431           ( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC &&
432             multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::IMAGE_METRIC ) )
433         {
434 
435         if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
436           {
437           this->m_FixedImageMasks[n] = dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->GetFixedImageMask();
438           this->m_MovingImageMasks[n] = dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->GetMovingImageMask();
439           }
440         else if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC )
441           {
442           this->m_FixedImageMasks[n] = dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->GetFixedImageMask();
443           this->m_MovingImageMasks[n] = dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->GetMovingImageMask();
444           }
445         else
446           {
447           itkExceptionMacro( "Invalid metric type." )
448           }
449         }
450       }
451     }
452   this->m_CompositeTransform->SetOnlyMostRecentTransformToOptimizeOn();
453 
454   // At each resolution and for each image pair (assuming an image metric), we
455   //   1. subsample the reference domain (typically the fixed image) and/or
456   //   2. smooth the fixed and moving images.
457 
458   typename VirtualImageType::Pointer currentLevelVirtualDomainImage = nullptr;
459   if( this->m_VirtualDomainImage.IsNotNull() )
460     {
461     typename ShrinkFilterType::Pointer shrinkFilter = ShrinkFilterType::New();
462     shrinkFilter->SetShrinkFactors( this->m_ShrinkFactorsPerLevel[level] );
463     shrinkFilter->SetInput( this->m_VirtualDomainImage );
464 
465     currentLevelVirtualDomainImage = shrinkFilter->GetOutput();
466     currentLevelVirtualDomainImage->Update();
467     }
468   else
469     {
470     itkExceptionMacro( "A virtual domain image is not found.  It should be specified in one of the metrics." );
471     }
472 
473   if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
474     {
475     if( fixedInitialTransform )
476       {
477       multiMetric->SetFixedTransform( fixedInitialTransform );
478       }
479     else
480       {
481       using IdentityTransformType = IdentityTransform<RealType, ImageDimension>;
482       typename IdentityTransformType::Pointer defaultFixedInitialTransform = IdentityTransformType::New();
483       multiMetric->SetFixedTransform( defaultFixedInitialTransform );
484       }
485     multiMetric->SetMovingTransform( this->m_CompositeTransform );
486     if( currentLevelVirtualDomainImage.IsNotNull() )
487       {
488       multiMetric->SetVirtualDomainFromImage( currentLevelVirtualDomainImage );
489       }
490 
491     for( SizeValueType n = 0; n < multiMetric->GetNumberOfMetrics(); n++ )
492       {
493       if( multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::IMAGE_METRIC )
494         {
495         if( currentLevelVirtualDomainImage.IsNotNull() )
496           {
497           dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->SetVirtualDomainFromImage( currentLevelVirtualDomainImage );
498           }
499         else
500           {
501           itkExceptionMacro( "Virtual domain image is not specified." );
502           }
503         }
504       else if( multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::POINT_SET_METRIC )
505         {
506         if( currentLevelVirtualDomainImage.IsNotNull() )
507           {
508           // This casting is a hack as all the metrics should be coordinated such that
509           // they have identical virtual image types.
510 
511           using CasterType = CastImageFilter<VirtualImageType, typename PointSetMetricType::VirtualImageType>;
512           typename CasterType::Pointer caster = CasterType::New();
513           caster->SetInput( currentLevelVirtualDomainImage );
514           caster->Update();
515 
516           dynamic_cast<PointSetMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->SetVirtualDomainFromImage( caster->GetOutput() );
517           }
518         }
519       }
520     }
521   else if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC )
522     {
523     typename ImageMetricType::Pointer imageMetric = dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() );
524     if( fixedInitialTransform )
525       {
526       imageMetric->SetFixedTransform( fixedInitialTransform );
527       }
528     else
529       {
530       using IdentityTransformType = IdentityTransform<RealType, ImageDimension>;
531       typename IdentityTransformType::Pointer defaultFixedInitialTransform = IdentityTransformType::New();
532       imageMetric->SetFixedTransform( defaultFixedInitialTransform );
533       }
534     imageMetric->SetMovingTransform( this->m_CompositeTransform );
535     if( currentLevelVirtualDomainImage.IsNotNull() )
536       {
537       imageMetric->SetVirtualDomainFromImage( currentLevelVirtualDomainImage );
538       }
539     }
540   else if( this->m_Metric->GetMetricCategory() == MetricType::POINT_SET_METRIC )
541     {
542     typename PointSetMetricType::Pointer pointSetMetric =
543       dynamic_cast<PointSetMetricType *>( this->m_Metric.GetPointer() );
544     if( fixedInitialTransform )
545       {
546       pointSetMetric->SetFixedTransform( fixedInitialTransform );
547       }
548     else
549       {
550       using IdentityTransformType = IdentityTransform<RealType, ImageDimension>;
551       typename IdentityTransformType::Pointer defaultFixedInitialTransform = IdentityTransformType::New();
552       pointSetMetric->SetFixedTransform( defaultFixedInitialTransform );
553       }
554     pointSetMetric->SetMovingTransform( this->m_CompositeTransform );
555     if( currentLevelVirtualDomainImage.IsNotNull() )
556       {
557       // This casting is a hack as all the metrics should be coordinated such that
558       // they have identical virtual image types.
559 
560       using CasterType = CastImageFilter<VirtualImageType, typename PointSetMetricType::VirtualImageType>;
561       typename CasterType::Pointer caster = CasterType::New();
562       caster->SetInput( currentLevelVirtualDomainImage );
563       caster->Update();
564 
565       pointSetMetric->SetVirtualDomainFromImage( caster->GetOutput() );
566       }
567     }
568   else
569     {
570     itkExceptionMacro( "Invalid metric conversion." );
571     }
572 
573   // We update the fixed and moving images for the image metrics and
574   // the fixed and moving point sets for the point set metrics.  Note
575   // that we set the point sets here just like we set the images.
576   // Although this isn't necessary, we want to leave the option for
577   // changing the point sets per level.
578 
579   this->m_FixedSmoothImages.clear();
580   this->m_FixedSmoothImages.resize( this->m_NumberOfMetrics );
581   this->m_MovingSmoothImages.clear();
582   this->m_MovingSmoothImages.resize( this->m_NumberOfMetrics );
583   this->m_FixedPointSets.clear();
584   this->m_FixedPointSets.resize( this->m_NumberOfMetrics );
585   this->m_MovingPointSets.clear();
586   this->m_MovingPointSets.resize( this->m_NumberOfMetrics );
587 
588   for( SizeValueType n = 0; n < this->m_NumberOfMetrics; n++ )
589     {
590     this->m_FixedSmoothImages[n] = nullptr;
591     this->m_MovingSmoothImages[n] = nullptr;
592     this->m_FixedPointSets[n] = nullptr;
593     this->m_MovingPointSets[n] = nullptr;
594 
595     if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC ||
596         ( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC &&
597           multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::IMAGE_METRIC ) )
598       {
599       if ( this->m_SmoothingSigmasPerLevel[level] > 0 )
600         {
601         using FixedImageSmoothingFilterType = SmoothingRecursiveGaussianImageFilter<FixedImageType, FixedImageType>;
602         typename FixedImageSmoothingFilterType::Pointer fixedImageSmoothingFilter = FixedImageSmoothingFilterType::New();
603         typename FixedImageSmoothingFilterType::SigmaArrayType fixedImageSigmaArray( this->m_SmoothingSigmasPerLevel[level] );
604 
605         if( !this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits  )
606           {
607           auto & fixedSpacing  = this->GetFixedImage( n )->GetSpacing();
608           for ( unsigned int i = 0; i < fixedImageSigmaArray.Size(); ++i )
609             {
610             fixedImageSigmaArray[i] *= fixedSpacing[i];
611             }
612           }
613         fixedImageSmoothingFilter->SetSigmaArray( fixedImageSigmaArray );
614         fixedImageSmoothingFilter->SetInput( this->GetFixedImage( n ) );
615 
616         this->m_FixedSmoothImages[n] = fixedImageSmoothingFilter->GetOutput();
617         fixedImageSmoothingFilter->Update();
618         fixedImageSmoothingFilter->GetOutput()->DisconnectPipeline();
619 
620         using MovingImageSmoothingFilterType = SmoothingRecursiveGaussianImageFilter<MovingImageType, MovingImageType>;
621         typename MovingImageSmoothingFilterType::Pointer movingImageSmoothingFilter = MovingImageSmoothingFilterType::New();
622         typename MovingImageSmoothingFilterType::SigmaArrayType movingImageSigmaArray( this->m_SmoothingSigmasPerLevel[level] );
623 
624         if( !this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits  )
625           {
626           auto & movingSpacing  = this->GetMovingImage( n )->GetSpacing();
627           for ( unsigned int i = 0; i < movingImageSigmaArray.Size(); ++i )
628             {
629             movingImageSigmaArray[i] *= movingSpacing[i];
630             }
631           }
632         movingImageSmoothingFilter->SetSigmaArray( movingImageSigmaArray );
633         movingImageSmoothingFilter->SetInput( this->GetMovingImage( n ) );
634 
635         this->m_MovingSmoothImages[n] = movingImageSmoothingFilter->GetOutput();
636         movingImageSmoothingFilter->Update();
637         movingImageSmoothingFilter->GetOutput()->DisconnectPipeline();
638         }
639       else
640         {
641         this->m_MovingSmoothImages[n] = this->GetMovingImage( n );
642         this->m_FixedSmoothImages[n] = this->GetFixedImage( n );
643         }
644 
645       // Update the image metric
646 
647       if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
648         {
649         multiMetric->GetMetricQueue()[n]->SetFixedObject( this->m_FixedSmoothImages[n] );
650         multiMetric->GetMetricQueue()[n]->SetMovingObject( this->m_MovingSmoothImages[n] );
651 
652         dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->SetFixedImageMask( this->m_FixedImageMasks[n] );
653         dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->SetMovingImageMask( this->m_MovingImageMasks[n] );
654         }
655       else if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC )
656         {
657         this->m_Metric->SetFixedObject( this->m_FixedSmoothImages[n] );
658         this->m_Metric->SetMovingObject( this->m_MovingSmoothImages[n] );
659 
660         dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->SetFixedImageMask( this->m_FixedImageMasks[n] );
661         dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->SetMovingImageMask( this->m_MovingImageMasks[n] );
662         }
663       else
664         {
665         itkExceptionMacro( "Invalid metric type." )
666         }
667       }
668     else if( this->m_Metric->GetMetricCategory() == MetricType::POINT_SET_METRIC ||
669         ( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC &&
670           multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::POINT_SET_METRIC ) )
671       {
672       this->m_FixedPointSets[n] = this->GetFixedPointSet( n );
673       this->m_MovingPointSets[n] = this->GetMovingPointSet( n );
674 
675       // Update the point set metric
676 
677       if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
678         {
679         multiMetric->GetMetricQueue()[n]->SetFixedObject( this->GetFixedPointSet( n ) );
680         multiMetric->GetMetricQueue()[n]->SetMovingObject( this->GetMovingPointSet( n ) );
681         }
682       else if( this->m_Metric->GetMetricCategory() == MetricType::POINT_SET_METRIC )
683         {
684         this->m_Metric->SetFixedObject( this->GetFixedPointSet( n ) );
685         this->m_Metric->SetMovingObject( this->GetMovingPointSet( n ) );
686         }
687       else
688         {
689         itkExceptionMacro( "Invalid metric type." )
690         }
691       }
692     else
693       {
694       itkExceptionMacro( "Invalid metric type." )
695       }
696     }
697 
698   if( this->m_MetricSamplingStrategy != NONE )
699     {
700     this->SetMetricSamplePoints();
701     }
702 
703   // Update the optimizer
704 
705   this->m_Optimizer->SetMetric( this->m_Metric );
706 
707   if( ( this->m_Optimizer->GetScales() ).Size() != this->m_OutputTransform->GetNumberOfLocalParameters() )
708     {
709     using ScalesType = typename OptimizerType::ScalesType;
710     ScalesType scales;
711     scales.SetSize( this->m_OutputTransform->GetNumberOfLocalParameters() );
712     scales.Fill( NumericTraits<typename ScalesType::ValueType>::OneValue() );
713     this->m_Optimizer->SetScales( scales );
714     }
715 }
716 
717 
718 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
719 void
720 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
AllocateOutputs()721 ::AllocateOutputs()
722 {
723   const DecoratedInitialTransformType * decoratedInitialTransform = this->GetInitialTransformInput();
724   DecoratedOutputTransformType *decoratedOutputTransform = this->GetOutput();
725 
726   if( decoratedInitialTransform )
727     {
728     const InitialTransformType * initialTransform = decoratedInitialTransform->Get();
729 
730     if( initialTransform )
731       {
732       if( this->GetInPlace() )
733         {
734         // graft the input to the output which may fail if the types
735         // aren't compatible.
736         decoratedOutputTransform->Graft( decoratedInitialTransform );
737 
738         if( decoratedOutputTransform->Get() )
739           {
740           this->m_OutputTransform = decoratedOutputTransform->GetModifiable();
741 
742           // This is generally done in the ReleaseInputs methods,
743           // however we do not need it again
744           const_cast<DecoratedInitialTransformType *>( decoratedInitialTransform )->ReleaseData();
745 
746           // successful in-place grafting
747           itkDebugMacro( "inplace allocation of output transform" );
748           return;
749           }
750         }
751 
752       const auto * initialAsOutputTransform = dynamic_cast<const OutputTransformType*>( initialTransform );
753 
754       if( initialAsOutputTransform )
755         {
756         // Clone performs a deep copy of the parameters and composition
757         this->m_OutputTransform = initialAsOutputTransform->Clone();
758         decoratedOutputTransform->Set( this->m_OutputTransform );
759 
760         // successful deep copy from initial to output
761         itkDebugMacro( "clone copy allocation of output transform" );
762         return;
763         }
764       else
765         {
766         itkExceptionMacro( "Unable to convert InitialTransform input to the OutputTransform type" );
767         }
768 
769       }
770     }
771 
772   // fallback allocation and initialization
773 
774 
775   // initialize to identity? what happens if we re-run with optimized values?
776   itkDebugMacro( "fallback allocation of output transform" );
777 
778   if( !decoratedOutputTransform->Get() )
779     {
780     // the output decorated component is null, allocate
781     OutputTransformPointer ptr;
782     Self::MakeOutputTransform( ptr );
783     decoratedOutputTransform->Set( ptr );
784     }
785 
786   this->m_OutputTransform = this->GetModifiableTransform();
787 }
788 
789 /*
790  * Start the registration
791  */
792 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
793 void
794 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GenerateData()795 ::GenerateData()
796 {
797   this->AllocateOutputs();
798 
799   // Ensure the same seed is used for each update
800   this->m_CurrentRandomSeed = this->m_RandomSeed;
801 
802   for( this->m_CurrentLevel = 0; this->m_CurrentLevel < this->m_NumberOfLevels; this->m_CurrentLevel++ )
803     {
804     this->InitializeRegistrationAtEachLevel( this->m_CurrentLevel );
805 
806     this->m_Metric->Initialize();
807 
808     this->m_Optimizer->StartOptimization();
809     }
810 }
811 
812 /**
813  * Set the moving transform adaptors per stage
814  */
815 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
816 void
817 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetTransformParametersAdaptorsPerLevel(TransformParametersAdaptorsContainerType & adaptors)818 ::SetTransformParametersAdaptorsPerLevel( TransformParametersAdaptorsContainerType & adaptors )
819 {
820   if( this->m_NumberOfLevels != adaptors.size() )
821     {
822     itkExceptionMacro( "The number of levels does not equal the number array size." );
823     }
824   else
825     {
826     itkDebugMacro( "Setting the transform parameter adaptors." );
827     this->m_TransformParametersAdaptorsPerLevel = adaptors;
828     this->Modified();
829     }
830 }
831 
832 /**
833  * Get the moving transform adaptors per stage
834  */
835 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
836 const typename  ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::TransformParametersAdaptorsContainerType &
837 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetTransformParametersAdaptorsPerLevel() const838 ::GetTransformParametersAdaptorsPerLevel() const
839 {
840   return this->m_TransformParametersAdaptorsPerLevel;
841 }
842 
843 /**
844  * Set the number of levels
845  */
846 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
847 void
848 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetNumberOfLevels(const SizeValueType numberOfLevels)849 ::SetNumberOfLevels( const SizeValueType numberOfLevels )
850 {
851   if( this->m_NumberOfLevels != numberOfLevels )
852     {
853     this->m_NumberOfLevels = numberOfLevels;
854 
855     // Set default transform adaptors which don't do anything to the input transform
856     // Similarly, fill in some default values for the shrink factors, smoothing sigmas,
857     // and learning rates.
858 
859     this->m_TransformParametersAdaptorsPerLevel.clear();
860     for( SizeValueType level = 0; level < this->m_NumberOfLevels; level++ )
861       {
862       this->m_TransformParametersAdaptorsPerLevel.push_back( nullptr );
863       }
864 
865     for( SizeValueType level = 0; level < this->m_NumberOfLevels; ++level )
866       {
867       ShrinkFactorsPerDimensionContainerType shrinkFactors;
868       shrinkFactors.Fill( 1 );
869       this->SetShrinkFactorsPerDimension( level, shrinkFactors );
870       }
871 
872     this->m_SmoothingSigmasPerLevel.SetSize( this->m_NumberOfLevels );
873     this->m_SmoothingSigmasPerLevel.Fill( 1.0 );
874 
875     this->m_MetricSamplingPercentagePerLevel.SetSize( this->m_NumberOfLevels );
876     this->m_MetricSamplingPercentagePerLevel.Fill( 1.0 );
877 
878     this->Modified();
879     }
880 }
881 
882 /**
883  * Get the metric samples
884  */
885 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
886 void
887 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetMetricSamplePoints()888 ::SetMetricSamplePoints()
889 {
890   using VirtualDomainImageType = typename ImageMetricType::VirtualImageType;
891   using VirtualDomainRegionType = typename VirtualDomainImageType::RegionType;
892 
893   const VirtualDomainImageType * virtualImage = nullptr;
894   const FixedImageMaskType * fixedMaskImage = nullptr;
895 
896   SizeValueType numberOfLocalMetrics = 1;
897 
898   typename MultiMetricType::Pointer multiMetric = dynamic_cast<MultiMetricType *>( this->m_Metric.GetPointer() );
899   if( multiMetric )
900     {
901     numberOfLocalMetrics = multiMetric->GetNumberOfMetrics();
902     if( numberOfLocalMetrics < 1 )
903       {
904       itkExceptionMacro( "Input multi metric should have at least one metric component." );
905       }
906     else
907       {
908       typename ImageMetricType::Pointer firstMetric = dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[0].GetPointer() );
909       if( firstMetric.IsNotNull() )
910         {
911         virtualImage = firstMetric->GetVirtualImage();
912         fixedMaskImage = firstMetric->GetFixedImageMask();
913         }
914       else
915         {
916         itkExceptionMacro( "Invalid metric conversion." );
917         }
918       }
919     }
920   else
921     {
922     typename ImageMetricType::Pointer singleMetric = dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() );
923     if( singleMetric.IsNotNull() )
924       {
925       virtualImage = singleMetric->GetVirtualImage();
926       fixedMaskImage = singleMetric->GetFixedImageMask();
927       }
928     else
929       {
930       itkExceptionMacro( "Invalid metric conversion." );
931       }
932     }
933 
934   const VirtualDomainRegionType & virtualDomainRegion = virtualImage->GetRequestedRegion();
935   const typename VirtualDomainImageType::SpacingType oneThirdVirtualSpacing = virtualImage->GetSpacing() / 3.0;
936 
937   for( SizeValueType n = 0; n < numberOfLocalMetrics; n++ )
938     {
939     typename MetricSamplePointSetType::Pointer samplePointSet = MetricSamplePointSetType::New();
940     samplePointSet->Initialize();
941 
942     using SamplePointType = typename MetricSamplePointSetType::PointType;
943 
944     using RandomizerType = Statistics::MersenneTwisterRandomVariateGenerator;
945     typename RandomizerType::Pointer randomizer = RandomizerType::New();
946     if (m_ReseedIterator)
947       {
948       randomizer->SetSeed( );
949       }
950     else
951       {
952       randomizer->SetSeed( m_CurrentRandomSeed++ );
953       }
954 
955 
956     unsigned long index = 0;
957 
958     switch( this->m_MetricSamplingStrategy )
959       {
960       case REGULAR:
961         {
962         const auto sampleCount = static_cast<unsigned long>(
963           std::ceil( 1.0 / this->m_MetricSamplingPercentagePerLevel[this->m_CurrentLevel] ) );
964         unsigned long count = sampleCount; //Start at sampleCount to keep behavior backwards identical, using first element.
965         ImageRegionConstIteratorWithIndex<VirtualDomainImageType> It( virtualImage, virtualDomainRegion );
966         for( It.GoToBegin(); !It.IsAtEnd(); ++It )
967           {
968           if( count == sampleCount )
969             {
970             count=0; //Reset counter
971             SamplePointType point;
972             virtualImage->TransformIndexToPhysicalPoint( It.GetIndex(), point );
973 
974             // randomly perturb the point within a voxel (approximately)
975             for( SizeValueType d = 0; d < ImageDimension; d++ )
976               {
977               point[d] += randomizer->GetNormalVariate() * oneThirdVirtualSpacing[d];
978               }
979             if( !fixedMaskImage || fixedMaskImage->IsInsideInWorldSpace(
980                 point ) )
981               {
982               samplePointSet->SetPoint( index, point );
983               ++index;
984               }
985             }
986           ++count;
987           }
988         break;
989         }
990       case RANDOM:
991         {
992         const unsigned long totalVirtualDomainVoxels = virtualDomainRegion.GetNumberOfPixels();
993         const auto sampleCount = static_cast<unsigned long>(
994          static_cast<float>( totalVirtualDomainVoxels )
995                * this->m_MetricSamplingPercentagePerLevel[this->m_CurrentLevel] );
996         ImageRandomConstIteratorWithIndex<VirtualDomainImageType> ItR( virtualImage, virtualDomainRegion );
997         if (m_ReseedIterator)
998           {
999           ItR.ReinitializeSeed();
1000           }
1001         else
1002           {
1003           ItR.ReinitializeSeed( m_CurrentRandomSeed++ );
1004           }
1005         ItR.SetNumberOfSamples( sampleCount );
1006         for( ItR.GoToBegin(); !ItR.IsAtEnd(); ++ItR )
1007           {
1008           SamplePointType point;
1009           virtualImage->TransformIndexToPhysicalPoint( ItR.GetIndex(), point );
1010 
1011           // randomly perturb the point within a voxel (approximately)
1012           for ( unsigned int d = 0; d < ImageDimension; d++ )
1013             {
1014             point[d] += randomizer->GetNormalVariate() * oneThirdVirtualSpacing[d];
1015             }
1016           if( !fixedMaskImage || fixedMaskImage->IsInsideInWorldSpace(
1017               point ) )
1018             {
1019             samplePointSet->SetPoint( index, point );
1020             ++index;
1021             }
1022           }
1023         break;
1024         }
1025       default:
1026         {
1027         itkExceptionMacro( "Invalid sampling strategy requested." );
1028         }
1029       }
1030 
1031     if( multiMetric )
1032       {
1033       dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->SetVirtualSampledPointSet( samplePointSet );
1034       dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->UseSampledPointSetOn();
1035       dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->UseVirtualSampledPointSetOn();
1036       }
1037     else
1038       {
1039       dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->SetVirtualSampledPointSet( samplePointSet );
1040       dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->UseSampledPointSetOn();
1041       dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->UseVirtualSampledPointSetOn();
1042       }
1043     }
1044 }
1045 
1046 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1047 void
1048 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
InitializeCenterOfLinearOutputTransform()1049 ::InitializeCenterOfLinearOutputTransform()
1050 {
1051   using MatrixOffsetTransformType = MatrixOffsetTransformBase<typename OutputTransformType::ScalarType,
1052     ImageDimension, ImageDimension>;
1053 
1054   auto * matrixOffsetOutputTransform = dynamic_cast<MatrixOffsetTransformType *>(
1055                                          this->GetModifiableTransform() );
1056 
1057   if( ! matrixOffsetOutputTransform )
1058     {
1059     return;
1060     }
1061 
1062   SizeValueType numberOfTransforms = this->m_CompositeTransform->GetNumberOfTransforms();
1063 
1064   if( numberOfTransforms == 0 )
1065     {
1066     return;
1067     }
1068 
1069   typename TTransform::InputPointType center = matrixOffsetOutputTransform->GetCenter();
1070 
1071   bool optimalIndexFound = false;
1072 
1073   for( int i = numberOfTransforms - 1; i >= 0; i-- )
1074     {
1075     auto * matrixOffsetTransform = dynamic_cast<MatrixOffsetTransformType *>(
1076                                      this->m_CompositeTransform->GetNthTransformModifiablePointer( i ) );
1077     if( matrixOffsetTransform )
1078       {
1079       center = matrixOffsetTransform->GetCenter();
1080       optimalIndexFound = true;
1081       break;
1082       }
1083     }
1084 
1085   if( ! optimalIndexFound )
1086     {
1087     return;
1088     }
1089   else
1090     {
1091     matrixOffsetOutputTransform->SetCenter( center );
1092     }
1093 }
1094 
1095 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1096 typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::VirtualImageBaseConstPointer
1097 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetCurrentLevelVirtualDomainImage()1098 ::GetCurrentLevelVirtualDomainImage()
1099 {
1100   // Get virtual domain image
1101   VirtualImageBaseConstPointer currentLevelVirtualDomainImage = nullptr;
1102   if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC )
1103     {
1104     currentLevelVirtualDomainImage = dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->GetVirtualImage();
1105     }
1106   else if( this->m_Metric->GetMetricCategory() == MetricType::POINT_SET_METRIC )
1107     {
1108     currentLevelVirtualDomainImage = dynamic_cast<PointSetMetricType *>( this->m_Metric.GetPointer() )->GetVirtualImage();
1109     }
1110   else
1111     {
1112     typename MultiMetricType::Pointer multiMetric = dynamic_cast<MultiMetricType *>( this->m_Metric.GetPointer() );
1113     if( multiMetric->GetMetricQueue()[0]->GetMetricCategory() == MetricType::POINT_SET_METRIC )
1114       {
1115       currentLevelVirtualDomainImage = dynamic_cast<PointSetMetricType *>( multiMetric->GetMetricQueue()[0].GetPointer() )->GetVirtualImage();
1116       }
1117     else
1118       {
1119       currentLevelVirtualDomainImage = dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[0].GetPointer() )->GetVirtualImage();
1120       }
1121     }
1122 
1123   return currentLevelVirtualDomainImage;
1124 }
1125 
1126 /*
1127  * PrintSelf
1128  */
1129 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1130 void
1131 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
PrintSelf(std::ostream & os,Indent indent) const1132 ::PrintSelf( std::ostream & os, Indent indent ) const
1133 {
1134   Superclass::PrintSelf( os, indent );
1135   Indent indent2 = indent.GetNextIndent();
1136 
1137   os << indent << "Number of levels = " << this->m_NumberOfLevels << std::endl;
1138 
1139   for( SizeValueType level = 0; level < this->m_NumberOfLevels; ++level )
1140     {
1141     os << indent << "Shrink factors (level " << level << "): "
1142        << this->m_ShrinkFactorsPerLevel[level] << std::endl;
1143     }
1144   os << indent << "Smoothing sigmas: " << this->m_SmoothingSigmasPerLevel << std::endl;
1145 
1146   if( this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits == true )
1147     {
1148     os << indent2 << "Smoothing sigmas are specified in physical units." << std::endl;
1149     }
1150   else
1151     {
1152     os << indent2 << "Smoothing sigmas are specified in voxel units." << std::endl;
1153     }
1154 
1155   if( this->m_OptimizerWeights.Size() > 0 )
1156     {
1157     os << indent << "Optimizers weights: " << this->m_OptimizerWeights << std::endl;
1158     }
1159 
1160   os << indent << "Metric sampling strategy: " << this->m_MetricSamplingStrategy << std::endl;
1161 
1162   os << indent << "Metric sampling percentage: ";
1163   for( SizeValueType i = 0; i < this->m_NumberOfLevels; i++ )
1164     {
1165     os << this->m_MetricSamplingPercentagePerLevel[i] << " ";
1166     }
1167   os << std::endl;
1168 
1169   os << indent << "ReseedIterator: " << m_ReseedIterator << std::endl;
1170   os << indent << "RandomSeed: " << m_RandomSeed << std::endl;
1171   os << indent << "CurrentRandomSeed: " << m_CurrentRandomSeed << std::endl;
1172 
1173   os << indent << "InPlace: " << ( this->m_InPlace ? "On" : "Off" ) << std::endl;
1174 
1175   os << indent << "InitializeCenterOfLinearOutputTransform: "
1176      << ( m_InitializeCenterOfLinearOutputTransform ? "On" : "Off" ) << std::endl;
1177 }
1178 
1179 /*
1180  *  Get output transform
1181  */
1182 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1183 typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::DecoratedOutputTransformType *
1184 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetOutput()1185 ::GetOutput()
1186 {
1187   return static_cast<DecoratedOutputTransformType *>( this->ProcessObject::GetOutput( 0 ) );
1188 }
1189 
1190 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1191 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::DecoratedOutputTransformType *
1192 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetOutput() const1193 ::GetOutput() const
1194 {
1195   return static_cast<const DecoratedOutputTransformType *>( this->ProcessObject::GetOutput( 0 ) );
1196 }
1197 
1198 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1199 typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::OutputTransformType *
1200 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetModifiableTransform()1201 ::GetModifiableTransform()
1202 {
1203   DecoratedOutputTransformType * temp = this->GetOutput();
1204   // required outputs of process object should always exits
1205   itkAssertInDebugAndIgnoreInReleaseMacro( temp != nullptr );
1206   return temp->GetModifiable();
1207 }
1208 
1209 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1210 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::OutputTransformType *
1211 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetTransform() const1212 ::GetTransform() const
1213 {
1214   const  DecoratedOutputTransformType * temp = this->GetOutput();
1215   // required outputs of process object should always exits
1216   itkAssertInDebugAndIgnoreInReleaseMacro( temp != nullptr );
1217   return temp->Get();
1218 }
1219 
1220 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1221 DataObject::Pointer
1222 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
MakeOutput(DataObjectPointerArraySizeType output)1223 ::MakeOutput( DataObjectPointerArraySizeType output )
1224 {
1225   if (output > 0)
1226   {
1227     itkExceptionMacro("MakeOutput request for an output number larger than the expected number of outputs.");
1228   }
1229   OutputTransformPointer ptr;
1230   Self::MakeOutputTransform(ptr);
1231   DecoratedOutputTransformPointer transformDecorator =  DecoratedOutputTransformType::New();
1232   transformDecorator->Set( ptr );
1233   return transformDecorator.GetPointer();
1234 }
1235 
1236 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1237 void
1238 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
MetricSamplingReinitializeSeed()1239 ::MetricSamplingReinitializeSeed()
1240 {
1241   if (!m_ReseedIterator)
1242     {
1243     m_ReseedIterator = true;
1244     this->Modified();
1245     }
1246 }
1247 
1248 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1249 void
1250 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
MetricSamplingReinitializeSeed(int seed)1251 ::MetricSamplingReinitializeSeed(int seed)
1252 {
1253   if (m_ReseedIterator || m_RandomSeed != seed)
1254     {
1255     m_ReseedIterator = false;
1256     m_RandomSeed = seed;
1257     this->Modified();
1258     }
1259 }
1260 
1261 
1262 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1263 void
1264 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetMetricSamplingPercentage(const RealType samplingPercentage)1265 ::SetMetricSamplingPercentage( const RealType samplingPercentage )
1266 {
1267   MetricSamplingPercentageArrayType samplingPercentagePerLevel;
1268   samplingPercentagePerLevel.SetSize( this->m_NumberOfLevels );
1269   samplingPercentagePerLevel.Fill( samplingPercentage );
1270   this->SetMetricSamplingPercentagePerLevel( samplingPercentagePerLevel );
1271 }
1272 
1273 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1274 void
1275 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetMetricSamplingPercentagePerLevel(const MetricSamplingPercentageArrayType & samplingPercentages)1276 ::SetMetricSamplingPercentagePerLevel( const MetricSamplingPercentageArrayType  &samplingPercentages )
1277 {
1278   if( this->m_MetricSamplingPercentagePerLevel != samplingPercentages )
1279     {
1280     for( typename MetricSamplingPercentageArrayType::const_iterator it = samplingPercentages.begin();
1281          it != samplingPercentages.end(); it++ )
1282       {
1283       if( *it <= 0.0 || *it > 1.0 )
1284         {
1285         itkExceptionMacro("sampling percentage outside expected (0,1] range");
1286         }
1287       }
1288     this->m_MetricSamplingPercentagePerLevel = samplingPercentages;
1289     this->Modified();
1290     }
1291 }
1292 
1293 } // end namespace itk
1294 #endif
1295