1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 
19 #include "itkCorrelationImageToImageMetricv4.h"
20 #include "itkTranslationTransform.h"
21 #include "itkLinearInterpolateImageFunction.h"
22 #include "itkImage.h"
23 #include "itkGaussianImageSource.h"
24 #include "itkCyclicShiftImageFilter.h"
25 #include "itkRegistrationParameterScalesFromPhysicalShift.h"
26 #include "itkGradientDescentOptimizerv4.h"
27 #include "itkImageRegionIteratorWithIndex.h"
28 #include "itkObjectToObjectMultiMetricv4.h"
29 #include "itkMeanSquaresImageToImageMetricv4.h"
30 
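/* This test performs a simple registration using a single metric and
 * a multivariate metric built from several copies of that same kind of
 * metric, and checks that both registrations give the same results.
 * The comparison is run once with the learning rate estimated once and
 * once with it re-estimated at every iteration. Finally, a registration
 * is run with a multivariate metric that combines two different metric
 * types, and its result is compared with the known image shift.
 */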
36 
37 template<typename TFilter>
38 class itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate : public itk::Command
39 {
40 public:
41   using Self = itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate;
42 
43   using Superclass = itk::Command;
44   using Pointer = itk::SmartPointer<Self>;
45   itkNewMacro( Self );
46 
47 protected:
48   itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate() = default;
49 
50 public:
51 
Execute(itk::Object * caller,const itk::EventObject & event)52   void Execute(itk::Object *caller, const itk::EventObject & event) override
53     {
54     Execute( (const itk::Object *) caller, event);
55     }
56 
Execute(const itk::Object * object,const itk::EventObject & event)57   void Execute(const itk::Object * object, const itk::EventObject & event) override
58     {
59     if( typeid( event ) != typeid( itk::IterationEvent ) )
60       {
61       return;
62       }
63     const auto * optimizer = dynamic_cast< const TFilter * >( object );
64 
65     if( !optimizer )
66       {
67       itkGenericExceptionMacro( "Error dynamic_cast failed" );
68       }
69     std::cout << "It- " << optimizer->GetCurrentIteration() << " gradient: " << optimizer->GetGradient() << " metric value: " << optimizer->GetCurrentMetricValue()
70               << " Params: " << const_cast<TFilter*>(optimizer)->GetCurrentPosition() << std::endl;
71     }
72 };
73 
74 template<typename TImage>
ObjectToObjectMultiMetricv4RegistrationTestCreateImages(typename TImage::Pointer & fixedImage,typename TImage::Pointer & movingImage,typename TImage::OffsetType & imageShift)75 void ObjectToObjectMultiMetricv4RegistrationTestCreateImages( typename TImage::Pointer & fixedImage, typename TImage::Pointer & movingImage, typename TImage::OffsetType & imageShift )
76 {
77   using PixelType = typename TImage::PixelType;
78   using CoordinateRepresentationType = PixelType;
79 
80   // Create two simple images
81   itk::SizeValueType ImageSize = 100;
82   itk::OffsetValueType boundary = 6;
83 
84    // Declare Gaussian Sources
85   using GaussianImageSourceType = itk::GaussianImageSource< TImage >;
86 
87   typename TImage::SizeType size;
88   size.Fill( ImageSize );
89 
90   typename TImage::SpacingType spacing;
91   spacing.Fill( itk::NumericTraits<CoordinateRepresentationType>::OneValue() );
92 
93   typename TImage::PointType origin;
94   origin.Fill( itk::NumericTraits<CoordinateRepresentationType>::ZeroValue() );
95 
96   typename TImage::DirectionType direction;
97   direction.Fill( itk::NumericTraits<CoordinateRepresentationType>::OneValue() );
98 
99   typename GaussianImageSourceType::Pointer  fixedImageSource = GaussianImageSourceType::New();
100 
101   fixedImageSource->SetSize(    size    );
102   fixedImageSource->SetOrigin(  origin  );
103   fixedImageSource->SetSpacing( spacing );
104   fixedImageSource->SetNormalized( false );
105   fixedImageSource->SetScale( 1.0f );
106   fixedImageSource->Update();
107   fixedImage = fixedImageSource->GetOutput();
108 
109   // zero-out the boundary
110   itk::ImageRegionIteratorWithIndex<TImage> it( fixedImage, fixedImage->GetLargestPossibleRegion() );
111   for( it.GoToBegin(); ! it.IsAtEnd(); ++it )
112     {
113     for( itk::SizeValueType n=0; n < TImage::ImageDimension; n++ )
114       {
115       if( it.GetIndex()[n] < boundary || (static_cast<itk::OffsetValueType>(size[n]) - it.GetIndex()[n]) <= boundary )
116         {
117         it.Set( itk::NumericTraits<PixelType>::ZeroValue() );
118         break;
119         }
120       }
121     }
122 
123   // shift the fixed image to get the moving image
124   using CyclicShiftFilterType = itk::CyclicShiftImageFilter<TImage, TImage>;
125   typename CyclicShiftFilterType::Pointer shiftFilter = CyclicShiftFilterType::New();
126   typename CyclicShiftFilterType::OffsetValueType maxImageShift = boundary-1;
127   imageShift.Fill( maxImageShift );
128   imageShift[0] = maxImageShift / 2;
129   shiftFilter->SetInput( fixedImage );
130   shiftFilter->SetShift( imageShift );
131   shiftFilter->Update();
132   movingImage = shiftFilter->GetOutput();
133 }
134 
135 //////////////////////////////////////////////////////////////////////////////////////////////////////////////
136 
137 template<typename TMetric>
ObjectToObjectMultiMetricv4RegistrationTestRun(typename TMetric::Pointer & metric,int numberOfIterations,typename TMetric::MeasureType & valueResult,typename TMetric::DerivativeType & derivativeResult,typename TMetric::InternalComputationValueType maxStep,bool estimateStepOnce)138 int ObjectToObjectMultiMetricv4RegistrationTestRun( typename TMetric::Pointer & metric, int numberOfIterations,
139                                                     typename TMetric::MeasureType & valueResult, typename TMetric::DerivativeType & derivativeResult,
140                                                     typename TMetric::InternalComputationValueType maxStep, bool estimateStepOnce )
141 {
142   // calculate initial metric value
143   metric->Initialize();
144   typename TMetric::MeasureType initialValue = metric->GetValue();
145 
146   // scales estimator
147   using RegistrationParameterScalesFromShiftType = itk::RegistrationParameterScalesFromPhysicalShift< TMetric >;
148   typename RegistrationParameterScalesFromShiftType::Pointer shiftScaleEstimator = RegistrationParameterScalesFromShiftType::New();
149   shiftScaleEstimator->SetMetric(metric);
150 
151   //
152   // optimizer
153   //
154   using OptimizerType = itk::GradientDescentOptimizerv4;
155   typename OptimizerType::Pointer  optimizer = OptimizerType::New();
156 
157   optimizer->SetMetric( metric );
158   optimizer->SetNumberOfIterations( numberOfIterations );
159   optimizer->SetScalesEstimator( shiftScaleEstimator );
160   optimizer->SetMaximumStepSizeInPhysicalUnits( maxStep );
161   optimizer->SetDoEstimateLearningRateOnce( estimateStepOnce );
162   optimizer->SetDoEstimateLearningRateAtEachIteration( ! estimateStepOnce );
163 
164   using CommandType = itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate<OptimizerType>;
165   typename CommandType::Pointer observer = CommandType::New();
166   //optimizer->AddObserver( itk::IterationEvent(), observer );
167 
168   optimizer->StartOptimization();
169 
170   std::cout << "# of iterations: " << optimizer->GetNumberOfIterations() << std::endl;
171   std::cout << "DoEstimateLearningRateOnce: " << optimizer->GetDoEstimateLearningRateOnce()
172             << " GetDoEstimateLearningRateAtEachIteration: " << optimizer->GetDoEstimateLearningRateAtEachIteration() << std::endl;
173   derivativeResult = optimizer->GetCurrentPosition();
174   std::cout << "Transform final parameters: " << derivativeResult << " mag: " << derivativeResult.magnitude() << std::endl;
175 
176   // final metric value
177   valueResult = metric->GetValue();
178   std::cout << "metric value: initial: " << initialValue << ", final: " << valueResult << std::endl;
179 
180   // scales
181   std::cout << "scales: " << optimizer->GetScales() << std::endl;
182   std::cout << "optimizer learning rate at end: " << optimizer->GetLearningRate() << std::endl;
183 
184   return EXIT_SUCCESS;
185 }
186 
187 //////////////////////////////////////////////////////////////
itkObjectToObjectMultiMetricv4RegistrationTest(int argc,char * argv[])188 int itkObjectToObjectMultiMetricv4RegistrationTest(int argc, char *argv[])
189 {
190   constexpr int Dimension = 2;
191   using ImageType = itk::Image< double, Dimension >;
192 
193   int numberOfIterations = 30;
194   if( argc > 1 )
195     {
196     numberOfIterations = std::stoi( argv[1] );
197     }
198 
199   // create an affine transform
200   using TranslationTransformType = itk::TranslationTransform<double, Dimension>;
201   TranslationTransformType::Pointer translationTransform = TranslationTransformType::New();
202   translationTransform->SetIdentity();
203 
204   // create images
205   ImageType::Pointer fixedImage = nullptr, movingImage = nullptr;
206   ImageType::OffsetType imageShift;
207   imageShift.Fill(0);
208   ObjectToObjectMultiMetricv4RegistrationTestCreateImages<ImageType>( fixedImage, movingImage, imageShift );
209 
210   using CorrelationMetricType = itk::CorrelationImageToImageMetricv4<ImageType, ImageType>;
211   CorrelationMetricType::Pointer correlationMetric = CorrelationMetricType::New();
212   correlationMetric->SetFixedImage( fixedImage );
213   correlationMetric->SetMovingImage( movingImage );
214   correlationMetric->SetMovingTransform( translationTransform );
215   correlationMetric->Initialize();
216 
217   translationTransform->SetIdentity();
218 
219   std::cout << std::endl << "*** Single image metric: " << std::endl;
220   CorrelationMetricType::MeasureType singleValueResult = 0.0;
221   CorrelationMetricType::DerivativeType singleDerivativeResult;
222   singleDerivativeResult.Fill(0);
223   ObjectToObjectMultiMetricv4RegistrationTestRun<CorrelationMetricType>( correlationMetric, numberOfIterations, singleValueResult, singleDerivativeResult, 1.0, true );
224 
225   std::cout << "*** multi-variate metric: " << std::endl;
226   CorrelationMetricType::Pointer metric2 = CorrelationMetricType::New();
227   metric2->SetFixedImage( fixedImage );
228   metric2->SetMovingImage( movingImage );
229   metric2->SetMovingTransform( translationTransform );
230 
231   using MultiMetricType = itk::ObjectToObjectMultiMetricv4<Dimension,Dimension>;
232   MultiMetricType::Pointer multiMetric = MultiMetricType::New();
233   multiMetric->AddMetric( correlationMetric );
234   multiMetric->AddMetric( metric2 );
235   multiMetric->AddMetric( metric2 );
236   multiMetric->Initialize();
237 
238   translationTransform->SetIdentity();
239 
240   CorrelationMetricType::MeasureType multiValueResult = 0.0;
241   CorrelationMetricType::DerivativeType multiDerivativeResult;
242   multiDerivativeResult.Fill(0);
243   ObjectToObjectMultiMetricv4RegistrationTestRun<MultiMetricType>( multiMetric, numberOfIterations, multiValueResult, multiDerivativeResult, 1.0, true );
244 
245   // Comparison between single-metric and multi-variate metric registrations
246   auto tolerance = static_cast<CorrelationMetricType::DerivativeValueType>(1e-6);
247   if( std::fabs( multiDerivativeResult[0] - singleDerivativeResult[0] ) > tolerance ||
248       std::fabs( multiDerivativeResult[1] - singleDerivativeResult[1] ) > tolerance )
249       {
250       std::cerr << "multi-variate registration derivative: " << multiDerivativeResult
251                 << " are different from single-variate derivative: " << singleDerivativeResult << std::endl;
252       return EXIT_FAILURE;
253       }
254   if( std::fabs( multiValueResult - singleValueResult ) > tolerance )
255       {
256       std::cerr << "multi-variate registration value: " << multiValueResult
257                 << " is different from single-variate value: " << singleValueResult << std::endl;
258       return EXIT_FAILURE;
259       }
260 
261   // compare results with truth
262   tolerance = static_cast<CorrelationMetricType::DerivativeValueType>(0.05);
263   if( std::fabs( multiDerivativeResult[0] - imageShift[0] ) / imageShift[0] > tolerance ||
264       std::fabs( multiDerivativeResult[1] - imageShift[1] ) / imageShift[1] > tolerance )
265       {
266       std::cerr << "multi-variate registration results: " << multiDerivativeResult << " are not as expected: " << imageShift << std::endl;
267       return EXIT_FAILURE;
268       }
269 
270 
271   //
272   // Try with step estimation at every iteration
273   // Comparison between single-metric and multi-variate metric registrations
274   //
275   std::cout << std::endl << "*** Single image metric 2: " << std::endl;
276   translationTransform->SetIdentity();
277   ObjectToObjectMultiMetricv4RegistrationTestRun<CorrelationMetricType>( correlationMetric, numberOfIterations, singleValueResult, singleDerivativeResult, 0.25, false );
278 
279   std::cout << std::endl << "*** Multi-variate image metric 2: " << std::endl;
280   translationTransform->SetIdentity();
281   ObjectToObjectMultiMetricv4RegistrationTestRun<MultiMetricType>( multiMetric, numberOfIterations, multiValueResult, multiDerivativeResult, 0.25, false );
282 
283   if( std::fabs( multiDerivativeResult[0] - singleDerivativeResult[0] ) > tolerance ||
284       std::fabs( multiDerivativeResult[1] - singleDerivativeResult[1] ) > tolerance )
285       {
286       std::cerr << "multi-variate registration derivative: " << multiDerivativeResult
287                 << " are different from single-variate derivative: " << singleDerivativeResult << std::endl;
288       return EXIT_FAILURE;
289       }
290   if( std::fabs( multiValueResult - singleValueResult ) > tolerance )
291       {
292       std::cerr << "multi-variate registration value: " << multiValueResult
293                 << " is different from single-variate value: " << singleValueResult << std::endl;
294       return EXIT_FAILURE;
295       }
296 
297   // compare results with truth
298   tolerance = static_cast<CorrelationMetricType::DerivativeValueType>(0.05);
299   if( std::fabs( multiDerivativeResult[0] - imageShift[0] ) / imageShift[0] > tolerance ||
300       std::fabs( multiDerivativeResult[1] - imageShift[1] ) / imageShift[1] > tolerance )
301       {
302       std::cerr << "multi-variate registration results: " << multiDerivativeResult << " are not as expected: " << imageShift << std::endl;
303       return EXIT_FAILURE;
304       }
305 
306   //
307   // Test with two different metric types
308   //
309   using MeanSquaresMetricType = itk::MeanSquaresImageToImageMetricv4<ImageType, ImageType>;
310   MeanSquaresMetricType::Pointer meanSquaresMetric = MeanSquaresMetricType::New();
311   meanSquaresMetric->SetFixedImage( fixedImage );
312   meanSquaresMetric->SetMovingImage( movingImage );
313   meanSquaresMetric->SetMovingTransform( translationTransform );
314 
315   MultiMetricType::Pointer multiMetric2 = MultiMetricType::New();
316   multiMetric2->AddMetric( correlationMetric );
317   multiMetric2->AddMetric( meanSquaresMetric );
318   multiMetric2->Initialize();
319 
320   translationTransform->SetIdentity();
321   std::cout << "*** Multi-metric with different metric types: " << std::endl;
322   ObjectToObjectMultiMetricv4RegistrationTestRun<MultiMetricType>( multiMetric2, numberOfIterations, multiValueResult, multiDerivativeResult, 1.0, true );
323 
324   // compare results with truth
325   tolerance = static_cast<MeanSquaresMetricType::DerivativeValueType>(0.05);
326   if( std::fabs( multiDerivativeResult[0] - imageShift[0] ) / imageShift[0] > tolerance ||
327       std::fabs( multiDerivativeResult[1] - imageShift[1] ) / imageShift[1] > tolerance )
328       {
329       std::cerr << "multi-variate registration results: " << multiDerivativeResult << " are not as expected: " << imageShift << std::endl;
330       return EXIT_FAILURE;
331       }
332 
333   return EXIT_SUCCESS;
334 }
335