1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18
19 #include "itkImageFileReader.h"
20 #include "itkImageFileWriter.h"
21
22 #include "itkImageRegistrationMethodv4.h"
23
24 #include "itkAffineTransform.h"
25 #include "itkANTSNeighborhoodCorrelationImageToImageMetricv4.h"
26 #include "itkBSplineExponentialDiffeomorphicTransform.h"
27 #include "itkBSplineExponentialDiffeomorphicTransformParametersAdaptor.h"
28 #include "itkComposeDisplacementFieldsImageFilter.h"
29 #include "itkVectorMagnitudeImageFilter.h"
30 #include "itkStatisticsImageFilter.h"
31 #include "itkTestingMacros.h"
32
33 template<typename TFilter>
34 class CommandIterationUpdate : public itk::Command
35 {
36 public:
37 using Self = CommandIterationUpdate;
38 using Superclass = itk::Command;
39 using Pointer = itk::SmartPointer<Self>;
40 itkNewMacro( Self );
41
42 protected:
43 CommandIterationUpdate() = default;
44
45 public:
46
Execute(itk::Object * caller,const itk::EventObject & event)47 void Execute(itk::Object *caller, const itk::EventObject & event) override
48 {
49 Execute( (const itk::Object *) caller, event);
50 }
51
Execute(const itk::Object * object,const itk::EventObject & event)52 void Execute(const itk::Object * object, const itk::EventObject & event) override
53 {
54 const auto * filter = static_cast< const TFilter * >( object );
55 if( typeid( event ) != typeid( itk::IterationEvent ) )
56 {
57 return;
58 }
59
60 unsigned int currentLevel = filter->GetCurrentLevel();
61 typename TFilter::ShrinkFactorsPerDimensionContainerType shrinkFactors = filter->GetShrinkFactorsPerDimension( currentLevel );
62 typename TFilter::SmoothingSigmasArrayType smoothingSigmas = filter->GetSmoothingSigmasPerLevel();
63 typename TFilter::TransformParametersAdaptorsContainerType adaptors = filter->GetTransformParametersAdaptorsPerLevel();
64
65 const itk::ObjectToObjectOptimizerBase * optimizerBase = filter->GetOptimizer();
66 using GradientDescentOptimizerv4Type = itk::GradientDescentOptimizerv4;
67 typename GradientDescentOptimizerv4Type::ConstPointer optimizer =
68 dynamic_cast<const GradientDescentOptimizerv4Type *>(optimizerBase);
69 if( !optimizer )
70 {
71 itkGenericExceptionMacro( "Error dynamic_cast failed" );
72 }
73 typename GradientDescentOptimizerv4Type::DerivativeType gradient = optimizer->GetGradient();
74
75 /* orig
76 std::cout << " Current level = " << currentLevel << std::endl;
77 std::cout << " shrink factor = " << shrinkFactors[currentLevel] << std::endl;
78 std::cout << " smoothing sigma = " << smoothingSigmas[currentLevel] << std::endl;
79 std::cout << " required fixed parameters = " << adaptors[currentLevel]->GetRequiredFixedParameters() << std::endl;
80 */
81
82 //debug:
83 std::cout << " CL Current level: " << currentLevel << std::endl;
84 std::cout << " SF Shrink factor: " << shrinkFactors << std::endl;
85 std::cout << " SS Smoothing sigma: " << smoothingSigmas[currentLevel] << std::endl;
86 std::cout << " RFP Required fixed params: " << adaptors[currentLevel]->GetRequiredFixedParameters() << std::endl;
87 std::cout << " LR Final learning rate: " << optimizer->GetLearningRate() << std::endl;
88 std::cout << " FM Final metric value: " << optimizer->GetCurrentMetricValue() << std::endl;
89 std::cout << " SC Optimizer scales: " << optimizer->GetScales() << std::endl;
90 std::cout << " FG Final metric gradient (sample of values): ";
91 if( gradient.GetSize() < 10 )
92 {
93 std::cout << gradient;
94 }
95 else
96 {
97 for( itk::SizeValueType i = 0; i < gradient.GetSize(); i += (gradient.GetSize() / 16) )
98 {
99 std::cout << gradient[i] << " ";
100 }
101 }
102 std::cout << std::endl;
103 }
104 };
105
106 template <unsigned int VImageDimension>
PerformBSplineExpImageRegistration(int argc,char * argv[])107 int PerformBSplineExpImageRegistration( int argc, char *argv[] )
108 {
109 if( argc < 6 )
110 {
111 std::cout << itkNameOfTestExecutableMacro(argv) << " imageDimension fixedImage movingImage outputImage numberOfAffineIterations numberOfDeformableIterations" << std::endl;
112 exit( 1 );
113 }
114
115 using PixelType = double;
116 using FixedImageType = itk::Image<PixelType, VImageDimension>;
117 using MovingImageType = itk::Image<PixelType, VImageDimension>;
118
119 using ImageReaderType = itk::ImageFileReader<FixedImageType>;
120
121 typename ImageReaderType::Pointer fixedImageReader = ImageReaderType::New();
122 fixedImageReader->SetFileName( argv[2] );
123 fixedImageReader->Update();
124 typename FixedImageType::Pointer fixedImage = fixedImageReader->GetOutput();
125 fixedImage->Update();
126 fixedImage->DisconnectPipeline();
127
128 typename ImageReaderType::Pointer movingImageReader = ImageReaderType::New();
129 movingImageReader->SetFileName( argv[3] );
130 movingImageReader->Update();
131 typename MovingImageType::Pointer movingImage = movingImageReader->GetOutput();
132 movingImage->Update();
133 movingImage->DisconnectPipeline();
134
135 using AffineTransformType = itk::AffineTransform<double, VImageDimension>;
136 using AffineRegistrationType = itk::ImageRegistrationMethodv4<FixedImageType, MovingImageType, AffineTransformType>;
137 using GradientDescentOptimizerv4Type = itk::GradientDescentOptimizerv4;
138 typename AffineRegistrationType::Pointer affineSimple = AffineRegistrationType::New();
139 affineSimple->SetFixedImage( fixedImage );
140 affineSimple->SetMovingImage( movingImage );
141
142 // Smooth by specified gaussian sigmas for each level. These values are specified in
143 // physical units. Sigmas of zero cause inconsistency between some platforms.
144 {
145 typename AffineRegistrationType::SmoothingSigmasArrayType smoothingSigmasPerLevel;
146 smoothingSigmasPerLevel.SetSize( 3 );
147 smoothingSigmasPerLevel[0] = 2;
148 smoothingSigmasPerLevel[1] = 1;
149 smoothingSigmasPerLevel[2] = 1; //0;
150 affineSimple->SetSmoothingSigmasPerLevel( smoothingSigmasPerLevel );
151 }
152
153 using GradientDescentOptimizerv4Type = itk::GradientDescentOptimizerv4;
154 typename GradientDescentOptimizerv4Type::Pointer affineOptimizer =
155 dynamic_cast<GradientDescentOptimizerv4Type * >( affineSimple->GetModifiableOptimizer() );
156 if( !affineOptimizer )
157 {
158 itkGenericExceptionMacro( "Error dynamic_cast failed" );
159 }
160 #ifdef NDEBUG
161 affineOptimizer->SetNumberOfIterations( std::stoi( argv[5] ) );
162 #else
163 affineOptimizer->SetNumberOfIterations( 1 );
164 #endif
165
166 affineOptimizer->SetDoEstimateLearningRateOnce( false ); //true by default
167 affineOptimizer->SetDoEstimateLearningRateAtEachIteration( true );
168
169 using AffineCommandType = CommandIterationUpdate<AffineRegistrationType>;
170 typename AffineCommandType::Pointer affineObserver = AffineCommandType::New();
171 affineSimple->AddObserver( itk::IterationEvent(), affineObserver );
172
173 {
174 using ImageMetricType = itk::ImageToImageMetricv4<FixedImageType, MovingImageType>;
175 typename ImageMetricType::Pointer imageMetric = dynamic_cast<ImageMetricType*>( affineSimple->GetModifiableMetric() );
176 if(imageMetric.IsNull())
177 {
178 std::cout << "Test failed - too many pixels different." << std::endl;
179 return EXIT_FAILURE;
180 }
181 imageMetric->SetFloatingPointCorrectionResolution(1e4);
182 }
183
184 try
185 {
186 std::cout << "Affine txf:" << std::endl;
187 affineSimple->Update();
188 }
189 catch( itk::ExceptionObject &e )
190 {
191 std::cerr << "Exception caught: " << e << std::endl;
192 return EXIT_FAILURE;
193 }
194
195 {
196 using ImageMetricType = itk::ImageToImageMetricv4<FixedImageType, MovingImageType>;
197 typename ImageMetricType::Pointer imageMetric = dynamic_cast<ImageMetricType*>( affineOptimizer->GetModifiableMetric() );
198 std::cout << "Affine parameters after registration: " << std::endl
199 << affineOptimizer->GetCurrentPosition() << std::endl
200 << "Last LearningRate: " << affineOptimizer->GetLearningRate() << std::endl
201 << "Use FltPtCorrex: " << imageMetric->GetUseFloatingPointCorrection() << std::endl
202 << "FltPtCorrexRes: " << imageMetric->GetFloatingPointCorrectionResolution() << std::endl
203 << "Number of threads used: metric: " << imageMetric->GetNumberOfWorkUnitsUsed()
204 << std::endl << " optimizer: " << affineOptimizer->GetNumberOfWorkUnits() << std::endl;
205 }
206 //
207 // Now do the displacement field transform with gaussian smoothing using
208 // the composite transform.
209 //
210
211 using RealType = typename AffineRegistrationType::RealType;
212
213 using CompositeTransformType = itk::CompositeTransform<RealType, VImageDimension>;
214 typename CompositeTransformType::Pointer compositeTransform = CompositeTransformType::New();
215 compositeTransform->AddTransform( affineSimple->GetModifiableTransform() );
216
217 using VectorType = itk::Vector<RealType, VImageDimension>;
218 VectorType zeroVector( 0.0 );
219 using DisplacementFieldType = itk::Image<VectorType, VImageDimension>;
220 typename DisplacementFieldType::Pointer displacementField = DisplacementFieldType::New();
221 displacementField->CopyInformation( fixedImage );
222 displacementField->SetRegions( fixedImage->GetBufferedRegion() );
223 displacementField->Allocate();
224 displacementField->FillBuffer( zeroVector );
225
226 using DisplacementFieldTransformType = itk::BSplineExponentialDiffeomorphicTransform<RealType, VImageDimension>;
227
228 using DisplacementFieldRegistrationType = itk::ImageRegistrationMethodv4<FixedImageType, MovingImageType, DisplacementFieldTransformType>;
229 typename DisplacementFieldRegistrationType::Pointer displacementFieldSimple = DisplacementFieldRegistrationType::New();
230
231 typename DisplacementFieldTransformType::Pointer fieldTransform = DisplacementFieldTransformType::New();
232
233 typename DisplacementFieldTransformType::ArrayType updateControlPoints;
234 updateControlPoints.Fill( 10 );
235
236 typename DisplacementFieldTransformType::ArrayType velocityControlPoints;
237 velocityControlPoints.Fill( 10 );
238
239 fieldTransform->SetNumberOfControlPointsForTheUpdateField( updateControlPoints );
240 fieldTransform->SetNumberOfControlPointsForTheConstantVelocityField( velocityControlPoints );
241 fieldTransform->SetConstantVelocityField( displacementField );
242 fieldTransform->SetCalculateNumberOfIntegrationStepsAutomatically( true );
243
244 displacementFieldSimple->SetInitialTransform( fieldTransform );
245 displacementFieldSimple->InPlaceOn();
246
247 using CorrelationMetricType = itk::ANTSNeighborhoodCorrelationImageToImageMetricv4<FixedImageType, MovingImageType>;
248 typename CorrelationMetricType::Pointer correlationMetric = CorrelationMetricType::New();
249 typename CorrelationMetricType::RadiusType radius;
250 radius.Fill( 4 );
251 correlationMetric->SetRadius( radius );
252 correlationMetric->SetUseMovingImageGradientFilter( false );
253 correlationMetric->SetUseFixedImageGradientFilter( false );
254
255 //correlationMetric->SetUseFloatingPointCorrection(true);
256 //correlationMetric->SetFloatingPointCorrectionResolution(1e4);
257
258 using ScalesEstimatorType = itk::RegistrationParameterScalesFromPhysicalShift<CorrelationMetricType>;
259 typename ScalesEstimatorType::Pointer scalesEstimator = ScalesEstimatorType::New();
260 scalesEstimator->SetMetric( correlationMetric );
261 scalesEstimator->SetTransformForward( true );
262 scalesEstimator->SetSmallParameterVariation( 1.0 );
263
264 typename GradientDescentOptimizerv4Type::Pointer optimizer = GradientDescentOptimizerv4Type::New();
265 optimizer->SetLearningRate( 1.0 );
266 #ifdef NDEBUG
267 optimizer->SetNumberOfIterations( std::stoi( argv[6] ) );
268 #else
269 optimizer->SetNumberOfIterations( 1 );
270 #endif
271 optimizer->SetScalesEstimator( nullptr );
272 optimizer->SetDoEstimateLearningRateOnce( false ); //true by default
273 optimizer->SetDoEstimateLearningRateAtEachIteration( true );
274
275 displacementFieldSimple->SetFixedImage( fixedImage );
276 displacementFieldSimple->SetMovingImage( movingImage );
277 displacementFieldSimple->SetNumberOfLevels( 3 );
278 displacementFieldSimple->SetMovingInitialTransform( compositeTransform );
279 displacementFieldSimple->SetMetric( correlationMetric );
280 displacementFieldSimple->SetOptimizer( optimizer );
281
282 // Shrink the virtual domain by specified factors for each level. See documentation
283 // for the itkShrinkImageFilter for more detailed behavior.
284 typename DisplacementFieldRegistrationType::ShrinkFactorsArrayType shrinkFactorsPerLevel;
285 shrinkFactorsPerLevel.SetSize( 3 );
286 shrinkFactorsPerLevel[0] = 3;
287 shrinkFactorsPerLevel[1] = 2;
288 shrinkFactorsPerLevel[2] = 1;
289 displacementFieldSimple->SetShrinkFactorsPerLevel( shrinkFactorsPerLevel );
290
291 // Smooth by specified gaussian sigmas for each level. These values are specified in
292 // physical units.
293 typename DisplacementFieldRegistrationType::SmoothingSigmasArrayType smoothingSigmasPerLevel;
294 smoothingSigmasPerLevel.SetSize( 3 );
295 smoothingSigmasPerLevel[0] = 2;
296 smoothingSigmasPerLevel[1] = 1;
297 smoothingSigmasPerLevel[2] = 1;
298 displacementFieldSimple->SetSmoothingSigmasPerLevel( smoothingSigmasPerLevel );
299
300 using DisplacementFieldTransformAdaptorType = itk::BSplineExponentialDiffeomorphicTransformParametersAdaptor<DisplacementFieldTransformType>;
301
302 typename DisplacementFieldRegistrationType::TransformParametersAdaptorsContainerType adaptors;
303
304 for( unsigned int level = 0; level < shrinkFactorsPerLevel.Size(); level++ )
305 {
306 // We use the shrink image filter to calculate the fixed parameters of the virtual
307 // domain at each level. To speed up calculation and avoid unnecessary memory
308 // usage, we could calculate these fixed parameters directly.
309
310 using ShrinkFilterType = itk::ShrinkImageFilter<DisplacementFieldType, DisplacementFieldType>;
311 typename ShrinkFilterType::Pointer shrinkFilter = ShrinkFilterType::New();
312 shrinkFilter->SetShrinkFactors( shrinkFactorsPerLevel[level] );
313 shrinkFilter->SetInput( displacementField );
314 shrinkFilter->Update();
315
316 typename DisplacementFieldTransformAdaptorType::Pointer fieldTransformAdaptor = DisplacementFieldTransformAdaptorType::New();
317 fieldTransformAdaptor->SetRequiredSpacing( shrinkFilter->GetOutput()->GetSpacing() );
318 fieldTransformAdaptor->SetRequiredSize( shrinkFilter->GetOutput()->GetBufferedRegion().GetSize() );
319 fieldTransformAdaptor->SetRequiredDirection( shrinkFilter->GetOutput()->GetDirection() );
320 fieldTransformAdaptor->SetRequiredOrigin( shrinkFilter->GetOutput()->GetOrigin() );
321
322 adaptors.push_back( fieldTransformAdaptor );
323 }
324 displacementFieldSimple->SetTransformParametersAdaptorsPerLevel( adaptors );
325
326 using DisplacementFieldRegistrationCommandType = CommandIterationUpdate<DisplacementFieldRegistrationType>;
327 typename DisplacementFieldRegistrationCommandType::Pointer displacementFieldObserver = DisplacementFieldRegistrationCommandType::New();
328 displacementFieldSimple->AddObserver( itk::IterationEvent(), displacementFieldObserver );
329
330 try
331 {
332 std::cout << "Displ. txf - bspline update" << std::endl;
333 displacementFieldSimple->Update();
334 }
335 catch( itk::ExceptionObject &e )
336 {
337 std::cerr << "Exception caught: " << e << std::endl;
338 return EXIT_FAILURE;
339 }
340
341 compositeTransform->AddTransform( displacementFieldSimple->GetModifiableTransform() );
342
343 std::cout << "After displacement registration: " << std::endl
344 << "Last LearningRate: " << optimizer->GetLearningRate() << std::endl
345 << "Use FltPtCorrex: " << correlationMetric->GetUseFloatingPointCorrection() << std::endl
346 << "FltPtCorrexRes: " << correlationMetric->GetFloatingPointCorrectionResolution() << std::endl
347 << "Number of threads used: metric: " << correlationMetric->GetNumberOfWorkUnitsUsed()
348 << "Number of threads used: metric: " << correlationMetric->GetNumberOfWorkUnitsUsed()
349 << " optimizer: " << displacementFieldSimple->GetOptimizer()->GetNumberOfWorkUnits() << std::endl;
350
351 using ResampleFilterType = itk::ResampleImageFilter<MovingImageType, FixedImageType>;
352 typename ResampleFilterType::Pointer resampler = ResampleFilterType::New();
353 resampler->SetTransform( compositeTransform );
354 resampler->SetInput( movingImage );
355 resampler->SetSize( fixedImage->GetLargestPossibleRegion().GetSize() );
356 resampler->SetOutputOrigin( fixedImage->GetOrigin() );
357 resampler->SetOutputSpacing( fixedImage->GetSpacing() );
358 resampler->SetOutputDirection( fixedImage->GetDirection() );
359 resampler->SetDefaultPixelValue( 0 );
360 resampler->Update();
361
362 using WriterType = itk::ImageFileWriter<FixedImageType>;
363 typename WriterType::Pointer writer = WriterType::New();
364 writer->SetFileName( argv[4] );
365 writer->SetInput( resampler->GetOutput() );
366 writer->Update();
367
368 // Check identity of forward and inverse transforms
369
370 using ComposerType = itk::ComposeDisplacementFieldsImageFilter<DisplacementFieldType, DisplacementFieldType>;
371 typename ComposerType::Pointer composer = ComposerType::New();
372 composer->SetDisplacementField( fieldTransform->GetDisplacementField() );
373 composer->SetWarpingField( fieldTransform->GetInverseDisplacementField() );
374 composer->Update();
375
376 using MagnituderType = itk::VectorMagnitudeImageFilter<DisplacementFieldType, MovingImageType>;
377 typename MagnituderType::Pointer magnituder = MagnituderType::New();
378 magnituder->SetInput( composer->GetOutput() );
379 magnituder->Update();
380
381 using StatisticsImageFilterType = itk::StatisticsImageFilter<MovingImageType>;
382 typename StatisticsImageFilterType::Pointer stats = StatisticsImageFilterType::New();
383 stats->SetInput( magnituder->GetOutput() );
384 stats->Update();
385
386 std::cout << "Identity check:" << std::endl;
387 std::cout << " Min: " << stats->GetMinimum() << std::endl;
388 std::cout << " Max: " << stats->GetMaximum() << std::endl;
389 std::cout << " Mean: " << stats->GetMean() << std::endl;
390 std::cout << " Variance: " << stats->GetVariance() << std::endl;
391
392 if( stats->GetMean() > 0.1 )
393 {
394 std::cerr << "Identity test failed." << std::endl;
395 }
396
397 return EXIT_SUCCESS;
398 }
399
itkBSplineExponentialImageRegistrationTest(int argc,char * argv[])400 int itkBSplineExponentialImageRegistrationTest( int argc, char *argv[] )
401 {
402 if( argc < 6 )
403 {
404 std::cout << itkNameOfTestExecutableMacro(argv) << " imageDimension fixedImage movingImage outputImage numberOfAffineIterations numberOfDeformableIterations" << std::endl;
405 exit( 1 );
406 }
407
408 switch( std::stoi( argv[1] ) )
409 {
410 case 2:
411 PerformBSplineExpImageRegistration<2>( argc, argv );
412 break;
413 case 3:
414 PerformBSplineExpImageRegistration<3>( argc, argv );
415 break;
416 default:
417 std::cerr << "Unsupported dimension" << std::endl;
418 exit( EXIT_FAILURE );
419 }
420 return EXIT_SUCCESS;
421 }
422