1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18 #ifndef itkImageRegistrationMethodv4_hxx
19 #define itkImageRegistrationMethodv4_hxx
20
21 #include "itkImageRegistrationMethodv4.h"
22
23 #include "itkSmoothingRecursiveGaussianImageFilter.h"
24 #include "itkGradientDescentOptimizerv4.h"
25 #include "itkImageRandomConstIteratorWithIndex.h"
26 #include "itkImageRegionConstIteratorWithIndex.h"
27 #include "itkImageToImageMetricv4.h"
28 #include "itkIterationReporter.h"
29 #include "itkMattesMutualInformationImageToImageMetricv4.h"
30 #include "itkMersenneTwisterRandomVariateGenerator.h"
31 #include "itkRegistrationParameterScalesFromPhysicalShift.h"
32
33 namespace itk
34 {
35 /**
36 * Constructor
37 */
38 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
39 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
ImageRegistrationMethodv4()40 ::ImageRegistrationMethodv4()
41 {
42 ProcessObject::SetNumberOfRequiredOutputs( 1 );
43 Self::SetPrimaryOutputName( "Transform" );
44
45 // indexed input are alternating fixed and moving images
46 Self::SetPrimaryInputName( "Fixed" );
47 Self::AddRequiredInputName( "Moving", 1 );
48 ProcessObject::SetNumberOfRequiredInputs( 2 );
49
50 // optional named inputs
51 Self::SetInput( "InitialTransform", nullptr );
52 Self::SetInput( "FixedInitialTransform", nullptr );
53 Self::SetInput( "MovingInitialTransform", nullptr );
54
55 this->m_VirtualDomainImage = nullptr;
56
57 Self::ReleaseDataBeforeUpdateFlagOff();
58
59 this->m_CurrentLevel = 0;
60 this->m_CurrentIteration = 0;
61 this->m_CurrentMetricValue = 0.0;
62 this->m_CurrentConvergenceValue = 0.0;
63 this->m_IsConverged = false;
64 this->m_NumberOfFixedObjects = 0;
65 this->m_NumberOfMovingObjects = 0;
66
67 Self::ReleaseDataBeforeUpdateFlagOff();
68
69 this->m_InPlace = true;
70
71 this->m_InitializeCenterOfLinearOutputTransform = true;
72
73 this->m_CompositeTransform = CompositeTransformType::New();
74
75 using DefaultMetricType = MattesMutualInformationImageToImageMetricv4<FixedImageType, MovingImageType, VirtualImageType, RealType>;
76 typename DefaultMetricType::Pointer mutualInformationMetric = DefaultMetricType::New();
77 mutualInformationMetric->SetNumberOfHistogramBins( 20 );
78 mutualInformationMetric->SetUseMovingImageGradientFilter( false );
79 mutualInformationMetric->SetUseFixedImageGradientFilter( false );
80 mutualInformationMetric->SetUseSampledPointSet( false );
81 this->m_Metric = mutualInformationMetric;
82
83 using DefaultScalesEstimatorType = RegistrationParameterScalesFromPhysicalShift<DefaultMetricType>;
84 typename DefaultScalesEstimatorType::Pointer scalesEstimator = DefaultScalesEstimatorType::New();
85 scalesEstimator->SetMetric( mutualInformationMetric );
86 scalesEstimator->SetTransformForward( true );
87
88 using DefaultOptimizerType = GradientDescentOptimizerv4Template<RealType>;
89 typename DefaultOptimizerType::Pointer optimizer = DefaultOptimizerType::New();
90 optimizer->SetLearningRate( 1.0 );
91 optimizer->SetNumberOfIterations( 1000 );
92 optimizer->SetScalesEstimator( scalesEstimator );
93 this->m_Optimizer = optimizer;
94
95 this->m_OptimizerWeights.SetSize( 0 );
96 this->m_OptimizerWeightsAreIdentity = true;
97
98 DecoratedOutputTransformPointer transformDecorator =
99 itkDynamicCastInDebugMode< DecoratedOutputTransformType * >( this->MakeOutput(0).GetPointer() );
100 this->ProcessObject::SetNthOutput( 0, transformDecorator );
101 this->m_OutputTransform = transformDecorator->GetModifiable();
102
103 // By default we set up a 3-level image registration.
104
105 this->m_NumberOfLevels = 0;
106 this->SetNumberOfLevels( 3 );
107
108 this->m_ShrinkFactorsPerLevel.resize( this->m_NumberOfLevels );
109 ShrinkFactorsPerDimensionContainerType shrinkFactors;
110 shrinkFactors.Fill( 2 );
111 this->m_ShrinkFactorsPerLevel[0] = shrinkFactors;
112 shrinkFactors.Fill( 1 );
113 this->m_ShrinkFactorsPerLevel[1] = shrinkFactors;
114 shrinkFactors.Fill( 1 );
115 this->m_ShrinkFactorsPerLevel[2] = shrinkFactors;
116
117 this->m_SmoothingSigmasPerLevel.SetSize( this->m_NumberOfLevels );
118 this->m_SmoothingSigmasPerLevel[0] = 2;
119 this->m_SmoothingSigmasPerLevel[1] = 1;
120 this->m_SmoothingSigmasPerLevel[2] = 0;
121
122 this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits = true;
123
124 this->m_ReseedIterator = false;
125 this->m_RandomSeed = Statistics::MersenneTwisterRandomVariateGenerator::GetNextSeed();
126 this->m_CurrentRandomSeed = this->m_RandomSeed;
127
128 this->m_MetricSamplingStrategy = NONE;
129 this->m_MetricSamplingPercentagePerLevel.SetSize( this->m_NumberOfLevels );
130 this->m_MetricSamplingPercentagePerLevel.Fill( 1.0 );
131 }
132
133 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
134 void
135 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetFixedImage(SizeValueType index,const FixedImageType * image)136 ::SetFixedImage( SizeValueType index, const FixedImageType *image )
137 {
138 itkDebugMacro( "setting fixed image input " << index << " to " << image );
139 if( image != static_cast<FixedImageType *>( this->ProcessObject::GetInput( 2 * index ) ) )
140 {
141 if( !this->ProcessObject::GetInput( 2 * index ) )
142 {
143 this->m_NumberOfFixedObjects++;
144 }
145 this->ProcessObject::SetNthInput( 2 * index, const_cast<FixedImageType *>( image ) );
146 this->Modified();
147 }
148 }
149
150 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
151 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::FixedImageType *
152 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetFixedImage(SizeValueType index) const153 ::GetFixedImage( SizeValueType index ) const
154 {
155 itkDebugMacro( "returning fixed image input " << index << " of "
156 << static_cast<const FixedImageType *>( this->ProcessObject::GetInput( 2 * index ) ) );
157 return static_cast<const FixedImageType *>( this->ProcessObject::GetInput( 2 * index ) );
158 }
159
160 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
161 void
162 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetMovingImage(SizeValueType index,const MovingImageType * image)163 ::SetMovingImage( SizeValueType index, const MovingImageType *image )
164 {
165 itkDebugMacro( "setting moving image input " << index << " to " << image );
166 if( image != static_cast<MovingImageType *>( this->ProcessObject::GetInput( 2 * index + 1 ) ) )
167 {
168 if( !this->ProcessObject::GetInput( 2 * index + 1 ) )
169 {
170 this->m_NumberOfMovingObjects++;
171 }
172 this->ProcessObject::SetNthInput( 2 * index + 1, const_cast<MovingImageType *>( image ) );
173 this->Modified();
174 }
175 }
176
177 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
178 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::MovingImageType *
179 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetMovingImage(SizeValueType index) const180 ::GetMovingImage( SizeValueType index ) const
181 {
182 itkDebugMacro( "returning moving image input " << index << " of "
183 << static_cast<const MovingImageType *>( this->ProcessObject::GetInput( 2 * index + 1 ) ) );
184 return static_cast<const MovingImageType *>( this->ProcessObject::GetInput( 2 * index + 1 ) );
185 }
186
187 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
188 void
189 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetFixedPointSet(SizeValueType index,const PointSetType * pointSet)190 ::SetFixedPointSet( SizeValueType index, const PointSetType *pointSet )
191 {
192 itkDebugMacro( "setting fixed point set input " << index << " to " << pointSet );
193 if( pointSet != static_cast<PointSetType *>( this->ProcessObject::GetInput( 2 * index ) ) )
194 {
195 if( !this->ProcessObject::GetInput( 2 * index ) )
196 {
197 this->m_NumberOfFixedObjects++;
198 }
199 this->ProcessObject::SetNthInput( 2 * index, const_cast<PointSetType *>( pointSet ) );
200 this->Modified();
201 }
202 }
203
204 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
205 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::PointSetType *
206 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetFixedPointSet(SizeValueType index) const207 ::GetFixedPointSet( SizeValueType index ) const
208 {
209 itkDebugMacro( "returning fixed point set input " << index << " of "
210 << static_cast<const PointSetType *>( this->ProcessObject::GetInput( 2 * index ) ) );
211 return static_cast<const PointSetType *>( this->ProcessObject::GetInput( 2 * index ) );
212 }
213
214 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
215 void
216 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetMovingPointSet(SizeValueType index,const PointSetType * pointSet)217 ::SetMovingPointSet( SizeValueType index, const PointSetType *pointSet )
218 {
219 itkDebugMacro( "setting moving point set input " << index << " to " << pointSet );
220 if( pointSet != static_cast<PointSetType *>( this->ProcessObject::GetInput( 2 * index + 1 ) ) )
221 {
222 if( !this->ProcessObject::GetInput( 2 * index + 1 ) )
223 {
224 this->m_NumberOfMovingObjects++;
225 }
226 this->ProcessObject::SetNthInput( 2 * index + 1, const_cast<PointSetType *>( pointSet ) );
227 this->Modified();
228 }
229 }
230
231 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
232 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::PointSetType *
233 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetMovingPointSet(SizeValueType index) const234 ::GetMovingPointSet( SizeValueType index ) const
235 {
236 itkDebugMacro( "returning moving point set input " << index << " of "
237 << static_cast<const PointSetType *>( this->ProcessObject::GetInput( 2 * index + 1 ) ) );
238 return static_cast<const PointSetType *>( this->ProcessObject::GetInput( 2 * index + 1 ) );
239 }
240
241 /*
242 * Set optimizer weights and do checking for identity.
243 */
244 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
245 void
246 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetOptimizerWeights(OptimizerWeightsType & weights)247 ::SetOptimizerWeights( OptimizerWeightsType & weights )
248 {
249 if( weights != this->m_OptimizerWeights )
250 {
251 itkDebugMacro( "setting optimizer weights to " << weights );
252
253 this->m_OptimizerWeights = weights;
254
255 // Check to see if optimizer weights are identity to avoid unnecessary
256 // computations.
257
258 this->m_OptimizerWeightsAreIdentity = true;
259 if( this->m_OptimizerWeights.Size() > 0 )
260 {
261 using OptimizerWeightsValueType = typename OptimizerWeightsType::ValueType;
262 auto tolerance = static_cast<OptimizerWeightsValueType>( 1e-4 );
263
264 for( SizeValueType i = 0; i < this->m_OptimizerWeights.Size(); i++ )
265 {
266 OptimizerWeightsValueType difference =
267 std::fabs( NumericTraits<OptimizerWeightsValueType>::OneValue() - this->m_OptimizerWeights[i] );
268 if( difference > tolerance )
269 {
270 this->m_OptimizerWeightsAreIdentity = false;
271 break;
272 }
273 }
274 }
275 this->Modified();
276 }
277 }
278
279 /*
280 * Initialize by setting the interconnects between components.
281 */
282 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
283 void
284 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
InitializeRegistrationAtEachLevel(const SizeValueType level)285 ::InitializeRegistrationAtEachLevel( const SizeValueType level )
286 {
287
288 // To avoid casting to a multimetric several times, we do it once and use it
289 // throughout this function if the current enumerated metric type is MULTI_METRIC
290 typename MultiMetricType::Pointer multiMetric = dynamic_cast<MultiMetricType *>( this->m_Metric.GetPointer() );
291
292 // Sanity checks and find the virtual domain image
293
294 if( level == 0 )
295 {
296 SizeValueType numberOfObjectPairs = static_cast<unsigned int>( 0.5 * this->GetNumberOfIndexedInputs() );
297 if( numberOfObjectPairs == 0 )
298 {
299 itkExceptionMacro( "There are no input objects." );
300 }
301
302 if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
303 {
304 this->m_NumberOfMetrics = multiMetric->GetNumberOfMetrics();
305 if( this->m_NumberOfMetrics != numberOfObjectPairs )
306 {
307 itkExceptionMacro( "Mismatch between number of image pairs and the number of metrics." );
308 }
309 }
310 else
311 {
312 this->m_NumberOfMetrics = 1;
313 }
314
315 // The number of image pairs also includes nullptr image pairs for the point set
316 // metrics
317 if( this->m_NumberOfFixedObjects != this->m_NumberOfMovingObjects )
318 {
319 itkExceptionMacro( "The number of fixed and moving images is not equal." );
320 }
321 }
322
323 if( !this->m_Optimizer )
324 {
325 itkExceptionMacro( "The optimizer is not present." );
326 }
327 if( !this->m_Metric )
328 {
329 itkExceptionMacro( "The metric is not present." );
330 }
331
332 auto * movingInitialTransform = const_cast<InitialTransformType*>( this->GetMovingInitialTransform() );
333 auto * fixedInitialTransform = const_cast<InitialTransformType*>( this->GetFixedInitialTransform() );
334
335 this->m_CurrentIteration = 0;
336 this->m_CurrentMetricValue = 0.0;
337 this->m_CurrentConvergenceValue = 0.0;
338 this->m_IsConverged = false;
339
340 this->InvokeEvent( MultiResolutionIterationEvent() );
341
342 // For each level, we adapt the current transform. For many transforms, e.g.
343 // affine, the base transform adaptor does not do anything. However, in the
344 // case of other transforms, e.g. the b-spline and displacement field transforms
345 // the fixed parameters are changed to reflect an increase in transform resolution.
346 // This could involve increasing the mesh size of the B-spline transform or
347 // increase the resolution of the displacement field.
348
349 if( this->m_TransformParametersAdaptorsPerLevel[level] )
350 {
351 this->m_TransformParametersAdaptorsPerLevel[level]->SetTransform( this->m_OutputTransform );
352 this->m_TransformParametersAdaptorsPerLevel[level]->AdaptTransformParameters();
353 }
354
355 // Set-up the composite transform at initialization
356 // Also, find the virtual domain image
357 if( level == 0 )
358 {
359 this->m_CompositeTransform->ClearTransformQueue();
360
361 // Since we cannot instantiate a null object from an abstract class, we need to initialize the moving
362 // initial transform as an identity transform.
363 // Nevertheless, we do not need add this transform to the composite transform when it is only an
364 // identity transform. Simply by not setting that, we can save lots of time in jacobian computations
365 // of the composite transform since we can avoid some matrix multiplications.
366
367 // Skip adding an IdentityTransform to the m_CompositeTransform
368 if( movingInitialTransform != nullptr &&
369 std::string( movingInitialTransform->GetNameOfClass() ) != std::string( "IdentityTransform" ) )
370 {
371 this->m_CompositeTransform->AddTransform( movingInitialTransform );
372 }
373
374 // If the moving initial transform is a composite transform, unroll
375 // it into m_CompositeTransform.
376 this->m_CompositeTransform->FlattenTransformQueue();
377
378 if( this->m_InitializeCenterOfLinearOutputTransform )
379 {
380 this->InitializeCenterOfLinearOutputTransform();
381 }
382 this->m_CompositeTransform->AddTransform( this->m_OutputTransform );
383
384 if( this->m_OptimizerWeights.Size() > 0 )
385 {
386 this->m_Optimizer->SetWeights( this->m_OptimizerWeights );
387 }
388
389 // Get index of first image metric
390 this->m_FirstImageMetricIndex = -1;
391 if( this->m_NumberOfMetrics == 1 && this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC )
392 {
393 this->m_FirstImageMetricIndex = 0;
394 }
395 else if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
396 {
397 for( SizeValueType n = 0; n < this->m_NumberOfMetrics; n++ )
398 {
399 if( multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::IMAGE_METRIC )
400 {
401 this->m_FirstImageMetricIndex = n;
402 break;
403 }
404 }
405 }
406
407 {
408 VirtualImageBaseConstPointer virtualDomainBaseImage = this->GetCurrentLevelVirtualDomainImage();
409
410 if( virtualDomainBaseImage.IsNull() && this->m_FirstImageMetricIndex >= 0 )
411 {
412 virtualDomainBaseImage = this->GetFixedImage( this->m_FirstImageMetricIndex );
413 }
414 this->m_VirtualDomainImage = VirtualImageType::New();
415 this->m_VirtualDomainImage->CopyInformation( virtualDomainBaseImage );
416 this->m_VirtualDomainImage->SetRegions( virtualDomainBaseImage->GetLargestPossibleRegion() );
417 this->m_VirtualDomainImage->Allocate();
418 }
419
420 this->m_FixedImageMasks.clear();
421 this->m_FixedImageMasks.resize( this->m_NumberOfMetrics );
422 this->m_MovingImageMasks.clear();
423 this->m_MovingImageMasks.resize( this->m_NumberOfMetrics );
424
425 for( SizeValueType n = 0; n < this->m_NumberOfMetrics; n++ )
426 {
427 this->m_FixedImageMasks[n] = nullptr;
428 this->m_MovingImageMasks[n] = nullptr;
429
430 if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC ||
431 ( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC &&
432 multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::IMAGE_METRIC ) )
433 {
434
435 if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
436 {
437 this->m_FixedImageMasks[n] = dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->GetFixedImageMask();
438 this->m_MovingImageMasks[n] = dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->GetMovingImageMask();
439 }
440 else if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC )
441 {
442 this->m_FixedImageMasks[n] = dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->GetFixedImageMask();
443 this->m_MovingImageMasks[n] = dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->GetMovingImageMask();
444 }
445 else
446 {
447 itkExceptionMacro( "Invalid metric type." )
448 }
449 }
450 }
451 }
452 this->m_CompositeTransform->SetOnlyMostRecentTransformToOptimizeOn();
453
454 // At each resolution and for each image pair (assuming an image metric), we
455 // 1. subsample the reference domain (typically the fixed image) and/or
456 // 2. smooth the fixed and moving images.
457
458 typename VirtualImageType::Pointer currentLevelVirtualDomainImage = nullptr;
459 if( this->m_VirtualDomainImage.IsNotNull() )
460 {
461 typename ShrinkFilterType::Pointer shrinkFilter = ShrinkFilterType::New();
462 shrinkFilter->SetShrinkFactors( this->m_ShrinkFactorsPerLevel[level] );
463 shrinkFilter->SetInput( this->m_VirtualDomainImage );
464
465 currentLevelVirtualDomainImage = shrinkFilter->GetOutput();
466 currentLevelVirtualDomainImage->Update();
467 }
468 else
469 {
470 itkExceptionMacro( "A virtual domain image is not found. It should be specified in one of the metrics." );
471 }
472
473 if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
474 {
475 if( fixedInitialTransform )
476 {
477 multiMetric->SetFixedTransform( fixedInitialTransform );
478 }
479 else
480 {
481 using IdentityTransformType = IdentityTransform<RealType, ImageDimension>;
482 typename IdentityTransformType::Pointer defaultFixedInitialTransform = IdentityTransformType::New();
483 multiMetric->SetFixedTransform( defaultFixedInitialTransform );
484 }
485 multiMetric->SetMovingTransform( this->m_CompositeTransform );
486 if( currentLevelVirtualDomainImage.IsNotNull() )
487 {
488 multiMetric->SetVirtualDomainFromImage( currentLevelVirtualDomainImage );
489 }
490
491 for( SizeValueType n = 0; n < multiMetric->GetNumberOfMetrics(); n++ )
492 {
493 if( multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::IMAGE_METRIC )
494 {
495 if( currentLevelVirtualDomainImage.IsNotNull() )
496 {
497 dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->SetVirtualDomainFromImage( currentLevelVirtualDomainImage );
498 }
499 else
500 {
501 itkExceptionMacro( "Virtual domain image is not specified." );
502 }
503 }
504 else if( multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::POINT_SET_METRIC )
505 {
506 if( currentLevelVirtualDomainImage.IsNotNull() )
507 {
508 // This casting is a hack as all the metrics should be coordinated such that
509 // they have identical virtual image types.
510
511 using CasterType = CastImageFilter<VirtualImageType, typename PointSetMetricType::VirtualImageType>;
512 typename CasterType::Pointer caster = CasterType::New();
513 caster->SetInput( currentLevelVirtualDomainImage );
514 caster->Update();
515
516 dynamic_cast<PointSetMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->SetVirtualDomainFromImage( caster->GetOutput() );
517 }
518 }
519 }
520 }
521 else if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC )
522 {
523 typename ImageMetricType::Pointer imageMetric = dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() );
524 if( fixedInitialTransform )
525 {
526 imageMetric->SetFixedTransform( fixedInitialTransform );
527 }
528 else
529 {
530 using IdentityTransformType = IdentityTransform<RealType, ImageDimension>;
531 typename IdentityTransformType::Pointer defaultFixedInitialTransform = IdentityTransformType::New();
532 imageMetric->SetFixedTransform( defaultFixedInitialTransform );
533 }
534 imageMetric->SetMovingTransform( this->m_CompositeTransform );
535 if( currentLevelVirtualDomainImage.IsNotNull() )
536 {
537 imageMetric->SetVirtualDomainFromImage( currentLevelVirtualDomainImage );
538 }
539 }
540 else if( this->m_Metric->GetMetricCategory() == MetricType::POINT_SET_METRIC )
541 {
542 typename PointSetMetricType::Pointer pointSetMetric =
543 dynamic_cast<PointSetMetricType *>( this->m_Metric.GetPointer() );
544 if( fixedInitialTransform )
545 {
546 pointSetMetric->SetFixedTransform( fixedInitialTransform );
547 }
548 else
549 {
550 using IdentityTransformType = IdentityTransform<RealType, ImageDimension>;
551 typename IdentityTransformType::Pointer defaultFixedInitialTransform = IdentityTransformType::New();
552 pointSetMetric->SetFixedTransform( defaultFixedInitialTransform );
553 }
554 pointSetMetric->SetMovingTransform( this->m_CompositeTransform );
555 if( currentLevelVirtualDomainImage.IsNotNull() )
556 {
557 // This casting is a hack as all the metrics should be coordinated such that
558 // they have identical virtual image types.
559
560 using CasterType = CastImageFilter<VirtualImageType, typename PointSetMetricType::VirtualImageType>;
561 typename CasterType::Pointer caster = CasterType::New();
562 caster->SetInput( currentLevelVirtualDomainImage );
563 caster->Update();
564
565 pointSetMetric->SetVirtualDomainFromImage( caster->GetOutput() );
566 }
567 }
568 else
569 {
570 itkExceptionMacro( "Invalid metric conversion." );
571 }
572
573 // We update the fixed and moving images for the image metrics and
574 // the fixed and moving point sets for the point set metrics. Note
575 // that we set the point sets here just like we set the images.
576 // Although this isn't necessary, we want to leave the option for
577 // changing the point sets per level.
578
579 this->m_FixedSmoothImages.clear();
580 this->m_FixedSmoothImages.resize( this->m_NumberOfMetrics );
581 this->m_MovingSmoothImages.clear();
582 this->m_MovingSmoothImages.resize( this->m_NumberOfMetrics );
583 this->m_FixedPointSets.clear();
584 this->m_FixedPointSets.resize( this->m_NumberOfMetrics );
585 this->m_MovingPointSets.clear();
586 this->m_MovingPointSets.resize( this->m_NumberOfMetrics );
587
588 for( SizeValueType n = 0; n < this->m_NumberOfMetrics; n++ )
589 {
590 this->m_FixedSmoothImages[n] = nullptr;
591 this->m_MovingSmoothImages[n] = nullptr;
592 this->m_FixedPointSets[n] = nullptr;
593 this->m_MovingPointSets[n] = nullptr;
594
595 if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC ||
596 ( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC &&
597 multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::IMAGE_METRIC ) )
598 {
599 if ( this->m_SmoothingSigmasPerLevel[level] > 0 )
600 {
601 using FixedImageSmoothingFilterType = SmoothingRecursiveGaussianImageFilter<FixedImageType, FixedImageType>;
602 typename FixedImageSmoothingFilterType::Pointer fixedImageSmoothingFilter = FixedImageSmoothingFilterType::New();
603 typename FixedImageSmoothingFilterType::SigmaArrayType fixedImageSigmaArray( this->m_SmoothingSigmasPerLevel[level] );
604
605 if( !this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits )
606 {
607 auto & fixedSpacing = this->GetFixedImage( n )->GetSpacing();
608 for ( unsigned int i = 0; i < fixedImageSigmaArray.Size(); ++i )
609 {
610 fixedImageSigmaArray[i] *= fixedSpacing[i];
611 }
612 }
613 fixedImageSmoothingFilter->SetSigmaArray( fixedImageSigmaArray );
614 fixedImageSmoothingFilter->SetInput( this->GetFixedImage( n ) );
615
616 this->m_FixedSmoothImages[n] = fixedImageSmoothingFilter->GetOutput();
617 fixedImageSmoothingFilter->Update();
618 fixedImageSmoothingFilter->GetOutput()->DisconnectPipeline();
619
620 using MovingImageSmoothingFilterType = SmoothingRecursiveGaussianImageFilter<MovingImageType, MovingImageType>;
621 typename MovingImageSmoothingFilterType::Pointer movingImageSmoothingFilter = MovingImageSmoothingFilterType::New();
622 typename MovingImageSmoothingFilterType::SigmaArrayType movingImageSigmaArray( this->m_SmoothingSigmasPerLevel[level] );
623
624 if( !this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits )
625 {
626 auto & movingSpacing = this->GetMovingImage( n )->GetSpacing();
627 for ( unsigned int i = 0; i < movingImageSigmaArray.Size(); ++i )
628 {
629 movingImageSigmaArray[i] *= movingSpacing[i];
630 }
631 }
632 movingImageSmoothingFilter->SetSigmaArray( movingImageSigmaArray );
633 movingImageSmoothingFilter->SetInput( this->GetMovingImage( n ) );
634
635 this->m_MovingSmoothImages[n] = movingImageSmoothingFilter->GetOutput();
636 movingImageSmoothingFilter->Update();
637 movingImageSmoothingFilter->GetOutput()->DisconnectPipeline();
638 }
639 else
640 {
641 this->m_MovingSmoothImages[n] = this->GetMovingImage( n );
642 this->m_FixedSmoothImages[n] = this->GetFixedImage( n );
643 }
644
645 // Update the image metric
646
647 if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
648 {
649 multiMetric->GetMetricQueue()[n]->SetFixedObject( this->m_FixedSmoothImages[n] );
650 multiMetric->GetMetricQueue()[n]->SetMovingObject( this->m_MovingSmoothImages[n] );
651
652 dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->SetFixedImageMask( this->m_FixedImageMasks[n] );
653 dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->SetMovingImageMask( this->m_MovingImageMasks[n] );
654 }
655 else if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC )
656 {
657 this->m_Metric->SetFixedObject( this->m_FixedSmoothImages[n] );
658 this->m_Metric->SetMovingObject( this->m_MovingSmoothImages[n] );
659
660 dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->SetFixedImageMask( this->m_FixedImageMasks[n] );
661 dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->SetMovingImageMask( this->m_MovingImageMasks[n] );
662 }
663 else
664 {
665 itkExceptionMacro( "Invalid metric type." )
666 }
667 }
668 else if( this->m_Metric->GetMetricCategory() == MetricType::POINT_SET_METRIC ||
669 ( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC &&
670 multiMetric->GetMetricQueue()[n]->GetMetricCategory() == MetricType::POINT_SET_METRIC ) )
671 {
672 this->m_FixedPointSets[n] = this->GetFixedPointSet( n );
673 this->m_MovingPointSets[n] = this->GetMovingPointSet( n );
674
675 // Update the point set metric
676
677 if( this->m_Metric->GetMetricCategory() == MetricType::MULTI_METRIC )
678 {
679 multiMetric->GetMetricQueue()[n]->SetFixedObject( this->GetFixedPointSet( n ) );
680 multiMetric->GetMetricQueue()[n]->SetMovingObject( this->GetMovingPointSet( n ) );
681 }
682 else if( this->m_Metric->GetMetricCategory() == MetricType::POINT_SET_METRIC )
683 {
684 this->m_Metric->SetFixedObject( this->GetFixedPointSet( n ) );
685 this->m_Metric->SetMovingObject( this->GetMovingPointSet( n ) );
686 }
687 else
688 {
689 itkExceptionMacro( "Invalid metric type." )
690 }
691 }
692 else
693 {
694 itkExceptionMacro( "Invalid metric type." )
695 }
696 }
697
698 if( this->m_MetricSamplingStrategy != NONE )
699 {
700 this->SetMetricSamplePoints();
701 }
702
703 // Update the optimizer
704
705 this->m_Optimizer->SetMetric( this->m_Metric );
706
707 if( ( this->m_Optimizer->GetScales() ).Size() != this->m_OutputTransform->GetNumberOfLocalParameters() )
708 {
709 using ScalesType = typename OptimizerType::ScalesType;
710 ScalesType scales;
711 scales.SetSize( this->m_OutputTransform->GetNumberOfLocalParameters() );
712 scales.Fill( NumericTraits<typename ScalesType::ValueType>::OneValue() );
713 this->m_Optimizer->SetScales( scales );
714 }
715 }
716
717
718 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
719 void
720 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
AllocateOutputs()721 ::AllocateOutputs()
722 {
723 const DecoratedInitialTransformType * decoratedInitialTransform = this->GetInitialTransformInput();
724 DecoratedOutputTransformType *decoratedOutputTransform = this->GetOutput();
725
726 if( decoratedInitialTransform )
727 {
728 const InitialTransformType * initialTransform = decoratedInitialTransform->Get();
729
730 if( initialTransform )
731 {
732 if( this->GetInPlace() )
733 {
734 // graft the input to the output which may fail if the types
735 // aren't compatible.
736 decoratedOutputTransform->Graft( decoratedInitialTransform );
737
738 if( decoratedOutputTransform->Get() )
739 {
740 this->m_OutputTransform = decoratedOutputTransform->GetModifiable();
741
742 // This is generally done in the ReleaseInputs methods,
743 // however we do not need it again
744 const_cast<DecoratedInitialTransformType *>( decoratedInitialTransform )->ReleaseData();
745
746 // successful in-place grafting
747 itkDebugMacro( "inplace allocation of output transform" );
748 return;
749 }
750 }
751
752 const auto * initialAsOutputTransform = dynamic_cast<const OutputTransformType*>( initialTransform );
753
754 if( initialAsOutputTransform )
755 {
756 // Clone performs a deep copy of the parameters and composition
757 this->m_OutputTransform = initialAsOutputTransform->Clone();
758 decoratedOutputTransform->Set( this->m_OutputTransform );
759
760 // successful deep copy from initial to output
761 itkDebugMacro( "clone copy allocation of output transform" );
762 return;
763 }
764 else
765 {
766 itkExceptionMacro( "Unable to convert InitialTransform input to the OutputTransform type" );
767 }
768
769 }
770 }
771
772 // fallback allocation and initialization
773
774
775 // initialize to identity? what happens if we re-run with optimized values?
776 itkDebugMacro( "fallback allocation of output transform" );
777
778 if( !decoratedOutputTransform->Get() )
779 {
780 // the output decorated component is null, allocate
781 OutputTransformPointer ptr;
782 Self::MakeOutputTransform( ptr );
783 decoratedOutputTransform->Set( ptr );
784 }
785
786 this->m_OutputTransform = this->GetModifiableTransform();
787 }
788
789 /*
790 * Start the registration
791 */
792 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
793 void
794 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GenerateData()795 ::GenerateData()
796 {
797 this->AllocateOutputs();
798
799 // Ensure the same seed is used for each update
800 this->m_CurrentRandomSeed = this->m_RandomSeed;
801
802 for( this->m_CurrentLevel = 0; this->m_CurrentLevel < this->m_NumberOfLevels; this->m_CurrentLevel++ )
803 {
804 this->InitializeRegistrationAtEachLevel( this->m_CurrentLevel );
805
806 this->m_Metric->Initialize();
807
808 this->m_Optimizer->StartOptimization();
809 }
810 }
811
812 /**
813 * Set the moving transform adaptors per stage
814 */
815 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
816 void
817 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetTransformParametersAdaptorsPerLevel(TransformParametersAdaptorsContainerType & adaptors)818 ::SetTransformParametersAdaptorsPerLevel( TransformParametersAdaptorsContainerType & adaptors )
819 {
820 if( this->m_NumberOfLevels != adaptors.size() )
821 {
822 itkExceptionMacro( "The number of levels does not equal the number array size." );
823 }
824 else
825 {
826 itkDebugMacro( "Setting the transform parameter adaptors." );
827 this->m_TransformParametersAdaptorsPerLevel = adaptors;
828 this->Modified();
829 }
830 }
831
832 /**
833 * Get the moving transform adaptors per stage
834 */
835 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
836 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::TransformParametersAdaptorsContainerType &
837 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetTransformParametersAdaptorsPerLevel() const838 ::GetTransformParametersAdaptorsPerLevel() const
839 {
840 return this->m_TransformParametersAdaptorsPerLevel;
841 }
842
843 /**
844 * Set the number of levels
845 */
846 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
847 void
848 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetNumberOfLevels(const SizeValueType numberOfLevels)849 ::SetNumberOfLevels( const SizeValueType numberOfLevels )
850 {
851 if( this->m_NumberOfLevels != numberOfLevels )
852 {
853 this->m_NumberOfLevels = numberOfLevels;
854
855 // Set default transform adaptors which don't do anything to the input transform
856 // Similarly, fill in some default values for the shrink factors, smoothing sigmas,
857 // and learning rates.
858
859 this->m_TransformParametersAdaptorsPerLevel.clear();
860 for( SizeValueType level = 0; level < this->m_NumberOfLevels; level++ )
861 {
862 this->m_TransformParametersAdaptorsPerLevel.push_back( nullptr );
863 }
864
865 for( SizeValueType level = 0; level < this->m_NumberOfLevels; ++level )
866 {
867 ShrinkFactorsPerDimensionContainerType shrinkFactors;
868 shrinkFactors.Fill( 1 );
869 this->SetShrinkFactorsPerDimension( level, shrinkFactors );
870 }
871
872 this->m_SmoothingSigmasPerLevel.SetSize( this->m_NumberOfLevels );
873 this->m_SmoothingSigmasPerLevel.Fill( 1.0 );
874
875 this->m_MetricSamplingPercentagePerLevel.SetSize( this->m_NumberOfLevels );
876 this->m_MetricSamplingPercentagePerLevel.Fill( 1.0 );
877
878 this->Modified();
879 }
880 }
881
882 /**
883 * Get the metric samples
884 */
885 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
886 void
887 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetMetricSamplePoints()888 ::SetMetricSamplePoints()
889 {
890 using VirtualDomainImageType = typename ImageMetricType::VirtualImageType;
891 using VirtualDomainRegionType = typename VirtualDomainImageType::RegionType;
892
893 const VirtualDomainImageType * virtualImage = nullptr;
894 const FixedImageMaskType * fixedMaskImage = nullptr;
895
896 SizeValueType numberOfLocalMetrics = 1;
897
898 typename MultiMetricType::Pointer multiMetric = dynamic_cast<MultiMetricType *>( this->m_Metric.GetPointer() );
899 if( multiMetric )
900 {
901 numberOfLocalMetrics = multiMetric->GetNumberOfMetrics();
902 if( numberOfLocalMetrics < 1 )
903 {
904 itkExceptionMacro( "Input multi metric should have at least one metric component." );
905 }
906 else
907 {
908 typename ImageMetricType::Pointer firstMetric = dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[0].GetPointer() );
909 if( firstMetric.IsNotNull() )
910 {
911 virtualImage = firstMetric->GetVirtualImage();
912 fixedMaskImage = firstMetric->GetFixedImageMask();
913 }
914 else
915 {
916 itkExceptionMacro( "Invalid metric conversion." );
917 }
918 }
919 }
920 else
921 {
922 typename ImageMetricType::Pointer singleMetric = dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() );
923 if( singleMetric.IsNotNull() )
924 {
925 virtualImage = singleMetric->GetVirtualImage();
926 fixedMaskImage = singleMetric->GetFixedImageMask();
927 }
928 else
929 {
930 itkExceptionMacro( "Invalid metric conversion." );
931 }
932 }
933
934 const VirtualDomainRegionType & virtualDomainRegion = virtualImage->GetRequestedRegion();
935 const typename VirtualDomainImageType::SpacingType oneThirdVirtualSpacing = virtualImage->GetSpacing() / 3.0;
936
937 for( SizeValueType n = 0; n < numberOfLocalMetrics; n++ )
938 {
939 typename MetricSamplePointSetType::Pointer samplePointSet = MetricSamplePointSetType::New();
940 samplePointSet->Initialize();
941
942 using SamplePointType = typename MetricSamplePointSetType::PointType;
943
944 using RandomizerType = Statistics::MersenneTwisterRandomVariateGenerator;
945 typename RandomizerType::Pointer randomizer = RandomizerType::New();
946 if (m_ReseedIterator)
947 {
948 randomizer->SetSeed( );
949 }
950 else
951 {
952 randomizer->SetSeed( m_CurrentRandomSeed++ );
953 }
954
955
956 unsigned long index = 0;
957
958 switch( this->m_MetricSamplingStrategy )
959 {
960 case REGULAR:
961 {
962 const auto sampleCount = static_cast<unsigned long>(
963 std::ceil( 1.0 / this->m_MetricSamplingPercentagePerLevel[this->m_CurrentLevel] ) );
964 unsigned long count = sampleCount; //Start at sampleCount to keep behavior backwards identical, using first element.
965 ImageRegionConstIteratorWithIndex<VirtualDomainImageType> It( virtualImage, virtualDomainRegion );
966 for( It.GoToBegin(); !It.IsAtEnd(); ++It )
967 {
968 if( count == sampleCount )
969 {
970 count=0; //Reset counter
971 SamplePointType point;
972 virtualImage->TransformIndexToPhysicalPoint( It.GetIndex(), point );
973
974 // randomly perturb the point within a voxel (approximately)
975 for( SizeValueType d = 0; d < ImageDimension; d++ )
976 {
977 point[d] += randomizer->GetNormalVariate() * oneThirdVirtualSpacing[d];
978 }
979 if( !fixedMaskImage || fixedMaskImage->IsInsideInWorldSpace(
980 point ) )
981 {
982 samplePointSet->SetPoint( index, point );
983 ++index;
984 }
985 }
986 ++count;
987 }
988 break;
989 }
990 case RANDOM:
991 {
992 const unsigned long totalVirtualDomainVoxels = virtualDomainRegion.GetNumberOfPixels();
993 const auto sampleCount = static_cast<unsigned long>(
994 static_cast<float>( totalVirtualDomainVoxels )
995 * this->m_MetricSamplingPercentagePerLevel[this->m_CurrentLevel] );
996 ImageRandomConstIteratorWithIndex<VirtualDomainImageType> ItR( virtualImage, virtualDomainRegion );
997 if (m_ReseedIterator)
998 {
999 ItR.ReinitializeSeed();
1000 }
1001 else
1002 {
1003 ItR.ReinitializeSeed( m_CurrentRandomSeed++ );
1004 }
1005 ItR.SetNumberOfSamples( sampleCount );
1006 for( ItR.GoToBegin(); !ItR.IsAtEnd(); ++ItR )
1007 {
1008 SamplePointType point;
1009 virtualImage->TransformIndexToPhysicalPoint( ItR.GetIndex(), point );
1010
1011 // randomly perturb the point within a voxel (approximately)
1012 for ( unsigned int d = 0; d < ImageDimension; d++ )
1013 {
1014 point[d] += randomizer->GetNormalVariate() * oneThirdVirtualSpacing[d];
1015 }
1016 if( !fixedMaskImage || fixedMaskImage->IsInsideInWorldSpace(
1017 point ) )
1018 {
1019 samplePointSet->SetPoint( index, point );
1020 ++index;
1021 }
1022 }
1023 break;
1024 }
1025 default:
1026 {
1027 itkExceptionMacro( "Invalid sampling strategy requested." );
1028 }
1029 }
1030
1031 if( multiMetric )
1032 {
1033 dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->SetVirtualSampledPointSet( samplePointSet );
1034 dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->UseSampledPointSetOn();
1035 dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[n].GetPointer() )->UseVirtualSampledPointSetOn();
1036 }
1037 else
1038 {
1039 dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->SetVirtualSampledPointSet( samplePointSet );
1040 dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->UseSampledPointSetOn();
1041 dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->UseVirtualSampledPointSetOn();
1042 }
1043 }
1044 }
1045
1046 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1047 void
1048 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
InitializeCenterOfLinearOutputTransform()1049 ::InitializeCenterOfLinearOutputTransform()
1050 {
1051 using MatrixOffsetTransformType = MatrixOffsetTransformBase<typename OutputTransformType::ScalarType,
1052 ImageDimension, ImageDimension>;
1053
1054 auto * matrixOffsetOutputTransform = dynamic_cast<MatrixOffsetTransformType *>(
1055 this->GetModifiableTransform() );
1056
1057 if( ! matrixOffsetOutputTransform )
1058 {
1059 return;
1060 }
1061
1062 SizeValueType numberOfTransforms = this->m_CompositeTransform->GetNumberOfTransforms();
1063
1064 if( numberOfTransforms == 0 )
1065 {
1066 return;
1067 }
1068
1069 typename TTransform::InputPointType center = matrixOffsetOutputTransform->GetCenter();
1070
1071 bool optimalIndexFound = false;
1072
1073 for( int i = numberOfTransforms - 1; i >= 0; i-- )
1074 {
1075 auto * matrixOffsetTransform = dynamic_cast<MatrixOffsetTransformType *>(
1076 this->m_CompositeTransform->GetNthTransformModifiablePointer( i ) );
1077 if( matrixOffsetTransform )
1078 {
1079 center = matrixOffsetTransform->GetCenter();
1080 optimalIndexFound = true;
1081 break;
1082 }
1083 }
1084
1085 if( ! optimalIndexFound )
1086 {
1087 return;
1088 }
1089 else
1090 {
1091 matrixOffsetOutputTransform->SetCenter( center );
1092 }
1093 }
1094
1095 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1096 typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::VirtualImageBaseConstPointer
1097 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetCurrentLevelVirtualDomainImage()1098 ::GetCurrentLevelVirtualDomainImage()
1099 {
1100 // Get virtual domain image
1101 VirtualImageBaseConstPointer currentLevelVirtualDomainImage = nullptr;
1102 if( this->m_Metric->GetMetricCategory() == MetricType::IMAGE_METRIC )
1103 {
1104 currentLevelVirtualDomainImage = dynamic_cast<ImageMetricType *>( this->m_Metric.GetPointer() )->GetVirtualImage();
1105 }
1106 else if( this->m_Metric->GetMetricCategory() == MetricType::POINT_SET_METRIC )
1107 {
1108 currentLevelVirtualDomainImage = dynamic_cast<PointSetMetricType *>( this->m_Metric.GetPointer() )->GetVirtualImage();
1109 }
1110 else
1111 {
1112 typename MultiMetricType::Pointer multiMetric = dynamic_cast<MultiMetricType *>( this->m_Metric.GetPointer() );
1113 if( multiMetric->GetMetricQueue()[0]->GetMetricCategory() == MetricType::POINT_SET_METRIC )
1114 {
1115 currentLevelVirtualDomainImage = dynamic_cast<PointSetMetricType *>( multiMetric->GetMetricQueue()[0].GetPointer() )->GetVirtualImage();
1116 }
1117 else
1118 {
1119 currentLevelVirtualDomainImage = dynamic_cast<ImageMetricType *>( multiMetric->GetMetricQueue()[0].GetPointer() )->GetVirtualImage();
1120 }
1121 }
1122
1123 return currentLevelVirtualDomainImage;
1124 }
1125
1126 /*
1127 * PrintSelf
1128 */
1129 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1130 void
1131 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
PrintSelf(std::ostream & os,Indent indent) const1132 ::PrintSelf( std::ostream & os, Indent indent ) const
1133 {
1134 Superclass::PrintSelf( os, indent );
1135 Indent indent2 = indent.GetNextIndent();
1136
1137 os << indent << "Number of levels = " << this->m_NumberOfLevels << std::endl;
1138
1139 for( SizeValueType level = 0; level < this->m_NumberOfLevels; ++level )
1140 {
1141 os << indent << "Shrink factors (level " << level << "): "
1142 << this->m_ShrinkFactorsPerLevel[level] << std::endl;
1143 }
1144 os << indent << "Smoothing sigmas: " << this->m_SmoothingSigmasPerLevel << std::endl;
1145
1146 if( this->m_SmoothingSigmasAreSpecifiedInPhysicalUnits == true )
1147 {
1148 os << indent2 << "Smoothing sigmas are specified in physical units." << std::endl;
1149 }
1150 else
1151 {
1152 os << indent2 << "Smoothing sigmas are specified in voxel units." << std::endl;
1153 }
1154
1155 if( this->m_OptimizerWeights.Size() > 0 )
1156 {
1157 os << indent << "Optimizers weights: " << this->m_OptimizerWeights << std::endl;
1158 }
1159
1160 os << indent << "Metric sampling strategy: " << this->m_MetricSamplingStrategy << std::endl;
1161
1162 os << indent << "Metric sampling percentage: ";
1163 for( SizeValueType i = 0; i < this->m_NumberOfLevels; i++ )
1164 {
1165 os << this->m_MetricSamplingPercentagePerLevel[i] << " ";
1166 }
1167 os << std::endl;
1168
1169 os << indent << "ReseedIterator: " << m_ReseedIterator << std::endl;
1170 os << indent << "RandomSeed: " << m_RandomSeed << std::endl;
1171 os << indent << "CurrentRandomSeed: " << m_CurrentRandomSeed << std::endl;
1172
1173 os << indent << "InPlace: " << ( this->m_InPlace ? "On" : "Off" ) << std::endl;
1174
1175 os << indent << "InitializeCenterOfLinearOutputTransform: "
1176 << ( m_InitializeCenterOfLinearOutputTransform ? "On" : "Off" ) << std::endl;
1177 }
1178
1179 /*
1180 * Get output transform
1181 */
1182 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1183 typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::DecoratedOutputTransformType *
1184 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetOutput()1185 ::GetOutput()
1186 {
1187 return static_cast<DecoratedOutputTransformType *>( this->ProcessObject::GetOutput( 0 ) );
1188 }
1189
1190 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1191 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::DecoratedOutputTransformType *
1192 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetOutput() const1193 ::GetOutput() const
1194 {
1195 return static_cast<const DecoratedOutputTransformType *>( this->ProcessObject::GetOutput( 0 ) );
1196 }
1197
1198 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1199 typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::OutputTransformType *
1200 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetModifiableTransform()1201 ::GetModifiableTransform()
1202 {
1203 DecoratedOutputTransformType * temp = this->GetOutput();
1204 // required outputs of process object should always exits
1205 itkAssertInDebugAndIgnoreInReleaseMacro( temp != nullptr );
1206 return temp->GetModifiable();
1207 }
1208
1209 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1210 const typename ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>::OutputTransformType *
1211 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
GetTransform() const1212 ::GetTransform() const
1213 {
1214 const DecoratedOutputTransformType * temp = this->GetOutput();
1215 // required outputs of process object should always exits
1216 itkAssertInDebugAndIgnoreInReleaseMacro( temp != nullptr );
1217 return temp->Get();
1218 }
1219
1220 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1221 DataObject::Pointer
1222 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
MakeOutput(DataObjectPointerArraySizeType output)1223 ::MakeOutput( DataObjectPointerArraySizeType output )
1224 {
1225 if (output > 0)
1226 {
1227 itkExceptionMacro("MakeOutput request for an output number larger than the expected number of outputs.");
1228 }
1229 OutputTransformPointer ptr;
1230 Self::MakeOutputTransform(ptr);
1231 DecoratedOutputTransformPointer transformDecorator = DecoratedOutputTransformType::New();
1232 transformDecorator->Set( ptr );
1233 return transformDecorator.GetPointer();
1234 }
1235
1236 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1237 void
1238 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
MetricSamplingReinitializeSeed()1239 ::MetricSamplingReinitializeSeed()
1240 {
1241 if (!m_ReseedIterator)
1242 {
1243 m_ReseedIterator = true;
1244 this->Modified();
1245 }
1246 }
1247
1248 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1249 void
1250 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
MetricSamplingReinitializeSeed(int seed)1251 ::MetricSamplingReinitializeSeed(int seed)
1252 {
1253 if (m_ReseedIterator || m_RandomSeed != seed)
1254 {
1255 m_ReseedIterator = false;
1256 m_RandomSeed = seed;
1257 this->Modified();
1258 }
1259 }
1260
1261
1262 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1263 void
1264 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetMetricSamplingPercentage(const RealType samplingPercentage)1265 ::SetMetricSamplingPercentage( const RealType samplingPercentage )
1266 {
1267 MetricSamplingPercentageArrayType samplingPercentagePerLevel;
1268 samplingPercentagePerLevel.SetSize( this->m_NumberOfLevels );
1269 samplingPercentagePerLevel.Fill( samplingPercentage );
1270 this->SetMetricSamplingPercentagePerLevel( samplingPercentagePerLevel );
1271 }
1272
1273 template<typename TFixedImage, typename TMovingImage, typename TTransform, typename TVirtualImage, typename TPointSet>
1274 void
1275 ImageRegistrationMethodv4<TFixedImage, TMovingImage, TTransform, TVirtualImage, TPointSet>
SetMetricSamplingPercentagePerLevel(const MetricSamplingPercentageArrayType & samplingPercentages)1276 ::SetMetricSamplingPercentagePerLevel( const MetricSamplingPercentageArrayType &samplingPercentages )
1277 {
1278 if( this->m_MetricSamplingPercentagePerLevel != samplingPercentages )
1279 {
1280 for( typename MetricSamplingPercentageArrayType::const_iterator it = samplingPercentages.begin();
1281 it != samplingPercentages.end(); it++ )
1282 {
1283 if( *it <= 0.0 || *it > 1.0 )
1284 {
1285 itkExceptionMacro("sampling percentage outside expected (0,1] range");
1286 }
1287 }
1288 this->m_MetricSamplingPercentagePerLevel = samplingPercentages;
1289 this->Modified();
1290 }
1291 }
1292
1293 } // end namespace itk
1294 #endif
1295