1 /*=========================================================================
2 *
3 *  Copyright Insight Software Consortium
4 *
5 *  Licensed under the Apache License, Version 2.0 (the "License");
6 *  you may not use this file except in compliance with the License.
7 *  You may obtain a copy of the License at
8 *
9 *         http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 *  Unless required by applicable law or agreed to in writing, software
12 *  distributed under the License is distributed on an "AS IS" BASIS,
13 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 *  See the License for the specific language governing permissions and
15 *  limitations under the License.
16 *
17 *=========================================================================*/
18 
19 #include "itkImageFileReader.h"
20 #include "itkImageFileWriter.h"
21 
22 #include "itkImageRegistrationMethodv4.h"
23 
24 #include "itkAffineTransform.h"
25 #include "itkANTSNeighborhoodCorrelationImageToImageMetricv4.h"
26 #include "itkBSplineExponentialDiffeomorphicTransform.h"
27 #include "itkBSplineExponentialDiffeomorphicTransformParametersAdaptor.h"
28 #include "itkComposeDisplacementFieldsImageFilter.h"
29 #include "itkVectorMagnitudeImageFilter.h"
30 #include "itkStatisticsImageFilter.h"
31 #include "itkTestingMacros.h"
32 
33 template<typename TFilter>
34 class CommandIterationUpdate : public itk::Command
35 {
36 public:
37   using Self = CommandIterationUpdate;
38   using Superclass = itk::Command;
39   using Pointer = itk::SmartPointer<Self>;
40   itkNewMacro( Self );
41 
42 protected:
43   CommandIterationUpdate() = default;
44 
45 public:
46 
Execute(itk::Object * caller,const itk::EventObject & event)47   void Execute(itk::Object *caller, const itk::EventObject & event) override
48     {
49     Execute( (const itk::Object *) caller, event);
50     }
51 
Execute(const itk::Object * object,const itk::EventObject & event)52   void Execute(const itk::Object * object, const itk::EventObject & event) override
53     {
54     const auto * filter = static_cast< const TFilter * >( object );
55     if( typeid( event ) != typeid( itk::IterationEvent ) )
56       {
57       return;
58       }
59 
60     unsigned int currentLevel = filter->GetCurrentLevel();
61     typename TFilter::ShrinkFactorsPerDimensionContainerType shrinkFactors = filter->GetShrinkFactorsPerDimension( currentLevel );
62     typename TFilter::SmoothingSigmasArrayType smoothingSigmas = filter->GetSmoothingSigmasPerLevel();
63     typename TFilter::TransformParametersAdaptorsContainerType adaptors = filter->GetTransformParametersAdaptorsPerLevel();
64 
65     const itk::ObjectToObjectOptimizerBase * optimizerBase = filter->GetOptimizer();
66     using GradientDescentOptimizerv4Type = itk::GradientDescentOptimizerv4;
67     typename GradientDescentOptimizerv4Type::ConstPointer optimizer =
68       dynamic_cast<const GradientDescentOptimizerv4Type *>(optimizerBase);
69     if( !optimizer )
70       {
71       itkGenericExceptionMacro( "Error dynamic_cast failed" );
72       }
73     typename GradientDescentOptimizerv4Type::DerivativeType gradient = optimizer->GetGradient();
74 
75     /* orig
76     std::cout << "  Current level = " << currentLevel << std::endl;
77     std::cout << "    shrink factor = " << shrinkFactors[currentLevel] << std::endl;
78     std::cout << "    smoothing sigma = " << smoothingSigmas[currentLevel] << std::endl;
79     std::cout << "    required fixed parameters = " << adaptors[currentLevel]->GetRequiredFixedParameters() << std::endl;
80     */
81 
82     //debug:
83     std::cout << "  CL Current level:           " << currentLevel << std::endl;
84     std::cout << "   SF Shrink factor:          " << shrinkFactors << std::endl;
85     std::cout << "   SS Smoothing sigma:        " << smoothingSigmas[currentLevel] << std::endl;
86     std::cout << "   RFP Required fixed params: " << adaptors[currentLevel]->GetRequiredFixedParameters() << std::endl;
87     std::cout << "   LR Final learning rate:    " << optimizer->GetLearningRate() << std::endl;
88     std::cout << "   FM Final metric value:     " << optimizer->GetCurrentMetricValue() << std::endl;
89     std::cout << "   SC Optimizer scales:       " << optimizer->GetScales() << std::endl;
90     std::cout << "   FG Final metric gradient (sample of values): ";
91     if( gradient.GetSize() < 10 )
92       {
93       std::cout << gradient;
94       }
95     else
96       {
97       for( itk::SizeValueType i = 0; i < gradient.GetSize(); i += (gradient.GetSize() / 16) )
98         {
99         std::cout << gradient[i] << " ";
100         }
101       }
102     std::cout << std::endl;
103     }
104 };
105 
106 template <unsigned int VImageDimension>
PerformBSplineExpImageRegistration(int argc,char * argv[])107 int PerformBSplineExpImageRegistration( int argc, char *argv[] )
108 {
109   if( argc < 6 )
110     {
111     std::cout << itkNameOfTestExecutableMacro(argv) << " imageDimension fixedImage movingImage outputImage numberOfAffineIterations numberOfDeformableIterations" << std::endl;
112     exit( 1 );
113     }
114 
115   using PixelType = double;
116   using FixedImageType = itk::Image<PixelType, VImageDimension>;
117   using MovingImageType = itk::Image<PixelType, VImageDimension>;
118 
119   using ImageReaderType = itk::ImageFileReader<FixedImageType>;
120 
121   typename ImageReaderType::Pointer fixedImageReader = ImageReaderType::New();
122   fixedImageReader->SetFileName( argv[2] );
123   fixedImageReader->Update();
124   typename FixedImageType::Pointer fixedImage = fixedImageReader->GetOutput();
125   fixedImage->Update();
126   fixedImage->DisconnectPipeline();
127 
128   typename ImageReaderType::Pointer movingImageReader = ImageReaderType::New();
129   movingImageReader->SetFileName( argv[3] );
130   movingImageReader->Update();
131   typename MovingImageType::Pointer movingImage = movingImageReader->GetOutput();
132   movingImage->Update();
133   movingImage->DisconnectPipeline();
134 
135   using AffineTransformType = itk::AffineTransform<double, VImageDimension>;
136   using AffineRegistrationType = itk::ImageRegistrationMethodv4<FixedImageType, MovingImageType, AffineTransformType>;
137   using GradientDescentOptimizerv4Type = itk::GradientDescentOptimizerv4;
138   typename AffineRegistrationType::Pointer affineSimple = AffineRegistrationType::New();
139   affineSimple->SetFixedImage( fixedImage );
140   affineSimple->SetMovingImage( movingImage );
141 
142   // Smooth by specified gaussian sigmas for each level.  These values are specified in
143   // physical units. Sigmas of zero cause inconsistency between some platforms.
144   {
145   typename AffineRegistrationType::SmoothingSigmasArrayType smoothingSigmasPerLevel;
146   smoothingSigmasPerLevel.SetSize( 3 );
147   smoothingSigmasPerLevel[0] = 2;
148   smoothingSigmasPerLevel[1] = 1;
149   smoothingSigmasPerLevel[2] = 1; //0;
150   affineSimple->SetSmoothingSigmasPerLevel( smoothingSigmasPerLevel );
151   }
152 
153   using GradientDescentOptimizerv4Type = itk::GradientDescentOptimizerv4;
154   typename GradientDescentOptimizerv4Type::Pointer affineOptimizer =
155     dynamic_cast<GradientDescentOptimizerv4Type * >( affineSimple->GetModifiableOptimizer() );
156   if( !affineOptimizer )
157     {
158     itkGenericExceptionMacro( "Error dynamic_cast failed" );
159     }
160 #ifdef NDEBUG
161   affineOptimizer->SetNumberOfIterations( std::stoi( argv[5] ) );
162 #else
163   affineOptimizer->SetNumberOfIterations( 1 );
164 #endif
165 
166   affineOptimizer->SetDoEstimateLearningRateOnce( false ); //true by default
167   affineOptimizer->SetDoEstimateLearningRateAtEachIteration( true );
168 
169   using AffineCommandType = CommandIterationUpdate<AffineRegistrationType>;
170   typename AffineCommandType::Pointer affineObserver = AffineCommandType::New();
171   affineSimple->AddObserver( itk::IterationEvent(), affineObserver );
172 
173   {
174   using ImageMetricType = itk::ImageToImageMetricv4<FixedImageType, MovingImageType>;
175   typename ImageMetricType::Pointer imageMetric = dynamic_cast<ImageMetricType*>( affineSimple->GetModifiableMetric() );
176   if(imageMetric.IsNull())
177     {
178     std::cout << "Test failed - too many pixels different." << std::endl;
179     return EXIT_FAILURE;
180     }
181   imageMetric->SetFloatingPointCorrectionResolution(1e4);
182   }
183 
184   try
185     {
186     std::cout << "Affine txf:" << std::endl;
187     affineSimple->Update();
188     }
189   catch( itk::ExceptionObject &e )
190     {
191     std::cerr << "Exception caught: " << e << std::endl;
192     return EXIT_FAILURE;
193     }
194 
195   {
196   using ImageMetricType = itk::ImageToImageMetricv4<FixedImageType, MovingImageType>;
197   typename ImageMetricType::Pointer imageMetric = dynamic_cast<ImageMetricType*>( affineOptimizer->GetModifiableMetric() );
198   std::cout << "Affine parameters after registration: " << std::endl
199             << affineOptimizer->GetCurrentPosition() << std::endl
200             << "Last LearningRate: " << affineOptimizer->GetLearningRate() << std::endl
201             << "Use FltPtCorrex: " << imageMetric->GetUseFloatingPointCorrection() << std::endl
202             << "FltPtCorrexRes: " << imageMetric->GetFloatingPointCorrectionResolution() << std::endl
203             << "Number of threads used: metric: " << imageMetric->GetNumberOfWorkUnitsUsed()
204             << std::endl << " optimizer: " << affineOptimizer->GetNumberOfWorkUnits() << std::endl;
205   }
206   //
207   // Now do the displacement field transform with gaussian smoothing using
208   // the composite transform.
209   //
210 
211   using RealType = typename AffineRegistrationType::RealType;
212 
213   using CompositeTransformType = itk::CompositeTransform<RealType, VImageDimension>;
214   typename CompositeTransformType::Pointer compositeTransform = CompositeTransformType::New();
215   compositeTransform->AddTransform( affineSimple->GetModifiableTransform() );
216 
217   using VectorType = itk::Vector<RealType, VImageDimension>;
218   VectorType zeroVector( 0.0 );
219   using DisplacementFieldType = itk::Image<VectorType, VImageDimension>;
220   typename DisplacementFieldType::Pointer displacementField = DisplacementFieldType::New();
221   displacementField->CopyInformation( fixedImage );
222   displacementField->SetRegions( fixedImage->GetBufferedRegion() );
223   displacementField->Allocate();
224   displacementField->FillBuffer( zeroVector );
225 
226   using DisplacementFieldTransformType = itk::BSplineExponentialDiffeomorphicTransform<RealType, VImageDimension>;
227 
228   using DisplacementFieldRegistrationType = itk::ImageRegistrationMethodv4<FixedImageType, MovingImageType, DisplacementFieldTransformType>;
229   typename DisplacementFieldRegistrationType::Pointer displacementFieldSimple = DisplacementFieldRegistrationType::New();
230 
231   typename DisplacementFieldTransformType::Pointer fieldTransform = DisplacementFieldTransformType::New();
232 
233   typename DisplacementFieldTransformType::ArrayType updateControlPoints;
234   updateControlPoints.Fill( 10 );
235 
236   typename DisplacementFieldTransformType::ArrayType velocityControlPoints;
237   velocityControlPoints.Fill( 10 );
238 
239   fieldTransform->SetNumberOfControlPointsForTheUpdateField( updateControlPoints );
240   fieldTransform->SetNumberOfControlPointsForTheConstantVelocityField( velocityControlPoints );
241   fieldTransform->SetConstantVelocityField( displacementField );
242   fieldTransform->SetCalculateNumberOfIntegrationStepsAutomatically( true );
243 
244   displacementFieldSimple->SetInitialTransform( fieldTransform );
245   displacementFieldSimple->InPlaceOn();
246 
247   using CorrelationMetricType = itk::ANTSNeighborhoodCorrelationImageToImageMetricv4<FixedImageType, MovingImageType>;
248   typename CorrelationMetricType::Pointer correlationMetric = CorrelationMetricType::New();
249   typename CorrelationMetricType::RadiusType radius;
250   radius.Fill( 4 );
251   correlationMetric->SetRadius( radius );
252   correlationMetric->SetUseMovingImageGradientFilter( false );
253   correlationMetric->SetUseFixedImageGradientFilter( false );
254 
255   //correlationMetric->SetUseFloatingPointCorrection(true);
256   //correlationMetric->SetFloatingPointCorrectionResolution(1e4);
257 
258   using ScalesEstimatorType = itk::RegistrationParameterScalesFromPhysicalShift<CorrelationMetricType>;
259   typename ScalesEstimatorType::Pointer scalesEstimator = ScalesEstimatorType::New();
260   scalesEstimator->SetMetric( correlationMetric );
261   scalesEstimator->SetTransformForward( true );
262   scalesEstimator->SetSmallParameterVariation( 1.0 );
263 
264   typename GradientDescentOptimizerv4Type::Pointer optimizer = GradientDescentOptimizerv4Type::New();
265   optimizer->SetLearningRate( 1.0 );
266 #ifdef NDEBUG
267   optimizer->SetNumberOfIterations( std::stoi( argv[6] ) );
268 #else
269   optimizer->SetNumberOfIterations( 1 );
270 #endif
271   optimizer->SetScalesEstimator( nullptr );
272   optimizer->SetDoEstimateLearningRateOnce( false ); //true by default
273   optimizer->SetDoEstimateLearningRateAtEachIteration( true );
274 
275   displacementFieldSimple->SetFixedImage( fixedImage );
276   displacementFieldSimple->SetMovingImage( movingImage );
277   displacementFieldSimple->SetNumberOfLevels( 3 );
278   displacementFieldSimple->SetMovingInitialTransform( compositeTransform );
279   displacementFieldSimple->SetMetric( correlationMetric );
280   displacementFieldSimple->SetOptimizer( optimizer );
281 
282   // Shrink the virtual domain by specified factors for each level.  See documentation
283   // for the itkShrinkImageFilter for more detailed behavior.
284   typename DisplacementFieldRegistrationType::ShrinkFactorsArrayType shrinkFactorsPerLevel;
285   shrinkFactorsPerLevel.SetSize( 3 );
286   shrinkFactorsPerLevel[0] = 3;
287   shrinkFactorsPerLevel[1] = 2;
288   shrinkFactorsPerLevel[2] = 1;
289   displacementFieldSimple->SetShrinkFactorsPerLevel( shrinkFactorsPerLevel );
290 
291   // Smooth by specified gaussian sigmas for each level.  These values are specified in
292   // physical units.
293   typename DisplacementFieldRegistrationType::SmoothingSigmasArrayType smoothingSigmasPerLevel;
294   smoothingSigmasPerLevel.SetSize( 3 );
295   smoothingSigmasPerLevel[0] = 2;
296   smoothingSigmasPerLevel[1] = 1;
297   smoothingSigmasPerLevel[2] = 1;
298   displacementFieldSimple->SetSmoothingSigmasPerLevel( smoothingSigmasPerLevel );
299 
300   using DisplacementFieldTransformAdaptorType = itk::BSplineExponentialDiffeomorphicTransformParametersAdaptor<DisplacementFieldTransformType>;
301 
302   typename DisplacementFieldRegistrationType::TransformParametersAdaptorsContainerType adaptors;
303 
304   for( unsigned int level = 0; level < shrinkFactorsPerLevel.Size(); level++ )
305     {
306     // We use the shrink image filter to calculate the fixed parameters of the virtual
307     // domain at each level.  To speed up calculation and avoid unnecessary memory
308     // usage, we could calculate these fixed parameters directly.
309 
310     using ShrinkFilterType = itk::ShrinkImageFilter<DisplacementFieldType, DisplacementFieldType>;
311     typename ShrinkFilterType::Pointer shrinkFilter = ShrinkFilterType::New();
312     shrinkFilter->SetShrinkFactors( shrinkFactorsPerLevel[level] );
313     shrinkFilter->SetInput( displacementField );
314     shrinkFilter->Update();
315 
316     typename DisplacementFieldTransformAdaptorType::Pointer fieldTransformAdaptor = DisplacementFieldTransformAdaptorType::New();
317     fieldTransformAdaptor->SetRequiredSpacing( shrinkFilter->GetOutput()->GetSpacing() );
318     fieldTransformAdaptor->SetRequiredSize( shrinkFilter->GetOutput()->GetBufferedRegion().GetSize() );
319     fieldTransformAdaptor->SetRequiredDirection( shrinkFilter->GetOutput()->GetDirection() );
320     fieldTransformAdaptor->SetRequiredOrigin( shrinkFilter->GetOutput()->GetOrigin() );
321 
322     adaptors.push_back( fieldTransformAdaptor );
323     }
324   displacementFieldSimple->SetTransformParametersAdaptorsPerLevel( adaptors );
325 
326   using DisplacementFieldRegistrationCommandType = CommandIterationUpdate<DisplacementFieldRegistrationType>;
327   typename DisplacementFieldRegistrationCommandType::Pointer displacementFieldObserver = DisplacementFieldRegistrationCommandType::New();
328   displacementFieldSimple->AddObserver( itk::IterationEvent(), displacementFieldObserver );
329 
330   try
331     {
332     std::cout << "Displ. txf - bspline update" << std::endl;
333     displacementFieldSimple->Update();
334     }
335   catch( itk::ExceptionObject &e )
336     {
337     std::cerr << "Exception caught: " << e << std::endl;
338     return EXIT_FAILURE;
339     }
340 
341   compositeTransform->AddTransform( displacementFieldSimple->GetModifiableTransform() );
342 
343   std::cout << "After displacement registration: " << std::endl
344             << "Last LearningRate: " << optimizer->GetLearningRate() << std::endl
345             << "Use FltPtCorrex: " << correlationMetric->GetUseFloatingPointCorrection() << std::endl
346             << "FltPtCorrexRes: " << correlationMetric->GetFloatingPointCorrectionResolution() << std::endl
347             << "Number of threads used: metric: " << correlationMetric->GetNumberOfWorkUnitsUsed()
348             << "Number of threads used: metric: " << correlationMetric->GetNumberOfWorkUnitsUsed()
349             << " optimizer: " << displacementFieldSimple->GetOptimizer()->GetNumberOfWorkUnits() << std::endl;
350 
351   using ResampleFilterType = itk::ResampleImageFilter<MovingImageType, FixedImageType>;
352   typename ResampleFilterType::Pointer resampler = ResampleFilterType::New();
353   resampler->SetTransform( compositeTransform );
354   resampler->SetInput( movingImage );
355   resampler->SetSize( fixedImage->GetLargestPossibleRegion().GetSize() );
356   resampler->SetOutputOrigin(  fixedImage->GetOrigin() );
357   resampler->SetOutputSpacing( fixedImage->GetSpacing() );
358   resampler->SetOutputDirection( fixedImage->GetDirection() );
359   resampler->SetDefaultPixelValue( 0 );
360   resampler->Update();
361 
362   using WriterType = itk::ImageFileWriter<FixedImageType>;
363   typename WriterType::Pointer writer = WriterType::New();
364   writer->SetFileName( argv[4] );
365   writer->SetInput( resampler->GetOutput() );
366   writer->Update();
367 
368   // Check identity of forward and inverse transforms
369 
370   using ComposerType = itk::ComposeDisplacementFieldsImageFilter<DisplacementFieldType, DisplacementFieldType>;
371   typename ComposerType::Pointer composer = ComposerType::New();
372   composer->SetDisplacementField( fieldTransform->GetDisplacementField() );
373   composer->SetWarpingField( fieldTransform->GetInverseDisplacementField() );
374   composer->Update();
375 
376   using MagnituderType = itk::VectorMagnitudeImageFilter<DisplacementFieldType, MovingImageType>;
377   typename MagnituderType::Pointer magnituder = MagnituderType::New();
378   magnituder->SetInput( composer->GetOutput() );
379   magnituder->Update();
380 
381   using StatisticsImageFilterType = itk::StatisticsImageFilter<MovingImageType>;
382   typename StatisticsImageFilterType::Pointer stats = StatisticsImageFilterType::New();
383   stats->SetInput( magnituder->GetOutput() );
384   stats->Update();
385 
386   std::cout << "Identity check:" << std::endl;
387   std::cout << "  Min:  " << stats->GetMinimum() << std::endl;
388   std::cout << "  Max:  " << stats->GetMaximum() << std::endl;
389   std::cout << "  Mean:  " << stats->GetMean() << std::endl;
390   std::cout << "  Variance:  " << stats->GetVariance() << std::endl;
391 
392   if( stats->GetMean() > 0.1 )
393     {
394     std::cerr << "Identity test failed." << std::endl;
395     }
396 
397   return EXIT_SUCCESS;
398 }
399 
itkBSplineExponentialImageRegistrationTest(int argc,char * argv[])400 int itkBSplineExponentialImageRegistrationTest( int argc, char *argv[] )
401 {
402   if( argc < 6 )
403     {
404     std::cout << itkNameOfTestExecutableMacro(argv) << " imageDimension fixedImage movingImage outputImage numberOfAffineIterations numberOfDeformableIterations" << std::endl;
405     exit( 1 );
406     }
407 
408   switch( std::stoi( argv[1] ) )
409    {
410    case 2:
411      PerformBSplineExpImageRegistration<2>( argc, argv );
412      break;
413    case 3:
414      PerformBSplineExpImageRegistration<3>( argc, argv );
415      break;
416    default:
417       std::cerr << "Unsupported dimension" << std::endl;
418       exit( EXIT_FAILURE );
419    }
420   return EXIT_SUCCESS;
421 }
422