1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18
19 #include "itkCorrelationImageToImageMetricv4.h"
20 #include "itkTranslationTransform.h"
21 #include "itkLinearInterpolateImageFunction.h"
22 #include "itkImage.h"
23 #include "itkGaussianImageSource.h"
24 #include "itkCyclicShiftImageFilter.h"
25 #include "itkRegistrationParameterScalesFromPhysicalShift.h"
26 #include "itkGradientDescentOptimizerv4.h"
27 #include "itkImageRegionIteratorWithIndex.h"
28 #include "itkObjectToObjectMultiMetricv4.h"
29 #include "itkMeanSquaresImageToImageMetricv4.h"
30
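/* This test performs a simple registration using a single correlation
 * metric and a multi-variate metric that contains three copies of that
 * metric (the original instance plus a second instance added twice),
 * checking that both registrations give the same result and that they
 * recover the known image shift. It also registers with a multi-variate
 * metric that combines two different metric types (correlation and mean
 * squares).
 */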
36
37 template<typename TFilter>
38 class itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate : public itk::Command
39 {
40 public:
41 using Self = itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate;
42
43 using Superclass = itk::Command;
44 using Pointer = itk::SmartPointer<Self>;
45 itkNewMacro( Self );
46
47 protected:
48 itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate() = default;
49
50 public:
51
Execute(itk::Object * caller,const itk::EventObject & event)52 void Execute(itk::Object *caller, const itk::EventObject & event) override
53 {
54 Execute( (const itk::Object *) caller, event);
55 }
56
Execute(const itk::Object * object,const itk::EventObject & event)57 void Execute(const itk::Object * object, const itk::EventObject & event) override
58 {
59 if( typeid( event ) != typeid( itk::IterationEvent ) )
60 {
61 return;
62 }
63 const auto * optimizer = dynamic_cast< const TFilter * >( object );
64
65 if( !optimizer )
66 {
67 itkGenericExceptionMacro( "Error dynamic_cast failed" );
68 }
69 std::cout << "It- " << optimizer->GetCurrentIteration() << " gradient: " << optimizer->GetGradient() << " metric value: " << optimizer->GetCurrentMetricValue()
70 << " Params: " << const_cast<TFilter*>(optimizer)->GetCurrentPosition() << std::endl;
71 }
72 };
73
74 template<typename TImage>
ObjectToObjectMultiMetricv4RegistrationTestCreateImages(typename TImage::Pointer & fixedImage,typename TImage::Pointer & movingImage,typename TImage::OffsetType & imageShift)75 void ObjectToObjectMultiMetricv4RegistrationTestCreateImages( typename TImage::Pointer & fixedImage, typename TImage::Pointer & movingImage, typename TImage::OffsetType & imageShift )
76 {
77 using PixelType = typename TImage::PixelType;
78 using CoordinateRepresentationType = PixelType;
79
80 // Create two simple images
81 itk::SizeValueType ImageSize = 100;
82 itk::OffsetValueType boundary = 6;
83
84 // Declare Gaussian Sources
85 using GaussianImageSourceType = itk::GaussianImageSource< TImage >;
86
87 typename TImage::SizeType size;
88 size.Fill( ImageSize );
89
90 typename TImage::SpacingType spacing;
91 spacing.Fill( itk::NumericTraits<CoordinateRepresentationType>::OneValue() );
92
93 typename TImage::PointType origin;
94 origin.Fill( itk::NumericTraits<CoordinateRepresentationType>::ZeroValue() );
95
96 typename TImage::DirectionType direction;
97 direction.Fill( itk::NumericTraits<CoordinateRepresentationType>::OneValue() );
98
99 typename GaussianImageSourceType::Pointer fixedImageSource = GaussianImageSourceType::New();
100
101 fixedImageSource->SetSize( size );
102 fixedImageSource->SetOrigin( origin );
103 fixedImageSource->SetSpacing( spacing );
104 fixedImageSource->SetNormalized( false );
105 fixedImageSource->SetScale( 1.0f );
106 fixedImageSource->Update();
107 fixedImage = fixedImageSource->GetOutput();
108
109 // zero-out the boundary
110 itk::ImageRegionIteratorWithIndex<TImage> it( fixedImage, fixedImage->GetLargestPossibleRegion() );
111 for( it.GoToBegin(); ! it.IsAtEnd(); ++it )
112 {
113 for( itk::SizeValueType n=0; n < TImage::ImageDimension; n++ )
114 {
115 if( it.GetIndex()[n] < boundary || (static_cast<itk::OffsetValueType>(size[n]) - it.GetIndex()[n]) <= boundary )
116 {
117 it.Set( itk::NumericTraits<PixelType>::ZeroValue() );
118 break;
119 }
120 }
121 }
122
123 // shift the fixed image to get the moving image
124 using CyclicShiftFilterType = itk::CyclicShiftImageFilter<TImage, TImage>;
125 typename CyclicShiftFilterType::Pointer shiftFilter = CyclicShiftFilterType::New();
126 typename CyclicShiftFilterType::OffsetValueType maxImageShift = boundary-1;
127 imageShift.Fill( maxImageShift );
128 imageShift[0] = maxImageShift / 2;
129 shiftFilter->SetInput( fixedImage );
130 shiftFilter->SetShift( imageShift );
131 shiftFilter->Update();
132 movingImage = shiftFilter->GetOutput();
133 }
134
135 //////////////////////////////////////////////////////////////////////////////////////////////////////////////
136
137 template<typename TMetric>
ObjectToObjectMultiMetricv4RegistrationTestRun(typename TMetric::Pointer & metric,int numberOfIterations,typename TMetric::MeasureType & valueResult,typename TMetric::DerivativeType & derivativeResult,typename TMetric::InternalComputationValueType maxStep,bool estimateStepOnce)138 int ObjectToObjectMultiMetricv4RegistrationTestRun( typename TMetric::Pointer & metric, int numberOfIterations,
139 typename TMetric::MeasureType & valueResult, typename TMetric::DerivativeType & derivativeResult,
140 typename TMetric::InternalComputationValueType maxStep, bool estimateStepOnce )
141 {
142 // calculate initial metric value
143 metric->Initialize();
144 typename TMetric::MeasureType initialValue = metric->GetValue();
145
146 // scales estimator
147 using RegistrationParameterScalesFromShiftType = itk::RegistrationParameterScalesFromPhysicalShift< TMetric >;
148 typename RegistrationParameterScalesFromShiftType::Pointer shiftScaleEstimator = RegistrationParameterScalesFromShiftType::New();
149 shiftScaleEstimator->SetMetric(metric);
150
151 //
152 // optimizer
153 //
154 using OptimizerType = itk::GradientDescentOptimizerv4;
155 typename OptimizerType::Pointer optimizer = OptimizerType::New();
156
157 optimizer->SetMetric( metric );
158 optimizer->SetNumberOfIterations( numberOfIterations );
159 optimizer->SetScalesEstimator( shiftScaleEstimator );
160 optimizer->SetMaximumStepSizeInPhysicalUnits( maxStep );
161 optimizer->SetDoEstimateLearningRateOnce( estimateStepOnce );
162 optimizer->SetDoEstimateLearningRateAtEachIteration( ! estimateStepOnce );
163
164 using CommandType = itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate<OptimizerType>;
165 typename CommandType::Pointer observer = CommandType::New();
166 //optimizer->AddObserver( itk::IterationEvent(), observer );
167
168 optimizer->StartOptimization();
169
170 std::cout << "# of iterations: " << optimizer->GetNumberOfIterations() << std::endl;
171 std::cout << "DoEstimateLearningRateOnce: " << optimizer->GetDoEstimateLearningRateOnce()
172 << " GetDoEstimateLearningRateAtEachIteration: " << optimizer->GetDoEstimateLearningRateAtEachIteration() << std::endl;
173 derivativeResult = optimizer->GetCurrentPosition();
174 std::cout << "Transform final parameters: " << derivativeResult << " mag: " << derivativeResult.magnitude() << std::endl;
175
176 // final metric value
177 valueResult = metric->GetValue();
178 std::cout << "metric value: initial: " << initialValue << ", final: " << valueResult << std::endl;
179
180 // scales
181 std::cout << "scales: " << optimizer->GetScales() << std::endl;
182 std::cout << "optimizer learning rate at end: " << optimizer->GetLearningRate() << std::endl;
183
184 return EXIT_SUCCESS;
185 }
186
187 //////////////////////////////////////////////////////////////
itkObjectToObjectMultiMetricv4RegistrationTest(int argc,char * argv[])188 int itkObjectToObjectMultiMetricv4RegistrationTest(int argc, char *argv[])
189 {
190 constexpr int Dimension = 2;
191 using ImageType = itk::Image< double, Dimension >;
192
193 int numberOfIterations = 30;
194 if( argc > 1 )
195 {
196 numberOfIterations = std::stoi( argv[1] );
197 }
198
199 // create an affine transform
200 using TranslationTransformType = itk::TranslationTransform<double, Dimension>;
201 TranslationTransformType::Pointer translationTransform = TranslationTransformType::New();
202 translationTransform->SetIdentity();
203
204 // create images
205 ImageType::Pointer fixedImage = nullptr, movingImage = nullptr;
206 ImageType::OffsetType imageShift;
207 imageShift.Fill(0);
208 ObjectToObjectMultiMetricv4RegistrationTestCreateImages<ImageType>( fixedImage, movingImage, imageShift );
209
210 using CorrelationMetricType = itk::CorrelationImageToImageMetricv4<ImageType, ImageType>;
211 CorrelationMetricType::Pointer correlationMetric = CorrelationMetricType::New();
212 correlationMetric->SetFixedImage( fixedImage );
213 correlationMetric->SetMovingImage( movingImage );
214 correlationMetric->SetMovingTransform( translationTransform );
215 correlationMetric->Initialize();
216
217 translationTransform->SetIdentity();
218
219 std::cout << std::endl << "*** Single image metric: " << std::endl;
220 CorrelationMetricType::MeasureType singleValueResult = 0.0;
221 CorrelationMetricType::DerivativeType singleDerivativeResult;
222 singleDerivativeResult.Fill(0);
223 ObjectToObjectMultiMetricv4RegistrationTestRun<CorrelationMetricType>( correlationMetric, numberOfIterations, singleValueResult, singleDerivativeResult, 1.0, true );
224
225 std::cout << "*** multi-variate metric: " << std::endl;
226 CorrelationMetricType::Pointer metric2 = CorrelationMetricType::New();
227 metric2->SetFixedImage( fixedImage );
228 metric2->SetMovingImage( movingImage );
229 metric2->SetMovingTransform( translationTransform );
230
231 using MultiMetricType = itk::ObjectToObjectMultiMetricv4<Dimension,Dimension>;
232 MultiMetricType::Pointer multiMetric = MultiMetricType::New();
233 multiMetric->AddMetric( correlationMetric );
234 multiMetric->AddMetric( metric2 );
235 multiMetric->AddMetric( metric2 );
236 multiMetric->Initialize();
237
238 translationTransform->SetIdentity();
239
240 CorrelationMetricType::MeasureType multiValueResult = 0.0;
241 CorrelationMetricType::DerivativeType multiDerivativeResult;
242 multiDerivativeResult.Fill(0);
243 ObjectToObjectMultiMetricv4RegistrationTestRun<MultiMetricType>( multiMetric, numberOfIterations, multiValueResult, multiDerivativeResult, 1.0, true );
244
245 // Comparison between single-metric and multi-variate metric registrations
246 auto tolerance = static_cast<CorrelationMetricType::DerivativeValueType>(1e-6);
247 if( std::fabs( multiDerivativeResult[0] - singleDerivativeResult[0] ) > tolerance ||
248 std::fabs( multiDerivativeResult[1] - singleDerivativeResult[1] ) > tolerance )
249 {
250 std::cerr << "multi-variate registration derivative: " << multiDerivativeResult
251 << " are different from single-variate derivative: " << singleDerivativeResult << std::endl;
252 return EXIT_FAILURE;
253 }
254 if( std::fabs( multiValueResult - singleValueResult ) > tolerance )
255 {
256 std::cerr << "multi-variate registration value: " << multiValueResult
257 << " is different from single-variate value: " << singleValueResult << std::endl;
258 return EXIT_FAILURE;
259 }
260
261 // compare results with truth
262 tolerance = static_cast<CorrelationMetricType::DerivativeValueType>(0.05);
263 if( std::fabs( multiDerivativeResult[0] - imageShift[0] ) / imageShift[0] > tolerance ||
264 std::fabs( multiDerivativeResult[1] - imageShift[1] ) / imageShift[1] > tolerance )
265 {
266 std::cerr << "multi-variate registration results: " << multiDerivativeResult << " are not as expected: " << imageShift << std::endl;
267 return EXIT_FAILURE;
268 }
269
270
271 //
272 // Try with step estimation at every iteration
273 // Comparison between single-metric and multi-variate metric registrations
274 //
275 std::cout << std::endl << "*** Single image metric 2: " << std::endl;
276 translationTransform->SetIdentity();
277 ObjectToObjectMultiMetricv4RegistrationTestRun<CorrelationMetricType>( correlationMetric, numberOfIterations, singleValueResult, singleDerivativeResult, 0.25, false );
278
279 std::cout << std::endl << "*** Multi-variate image metric 2: " << std::endl;
280 translationTransform->SetIdentity();
281 ObjectToObjectMultiMetricv4RegistrationTestRun<MultiMetricType>( multiMetric, numberOfIterations, multiValueResult, multiDerivativeResult, 0.25, false );
282
283 if( std::fabs( multiDerivativeResult[0] - singleDerivativeResult[0] ) > tolerance ||
284 std::fabs( multiDerivativeResult[1] - singleDerivativeResult[1] ) > tolerance )
285 {
286 std::cerr << "multi-variate registration derivative: " << multiDerivativeResult
287 << " are different from single-variate derivative: " << singleDerivativeResult << std::endl;
288 return EXIT_FAILURE;
289 }
290 if( std::fabs( multiValueResult - singleValueResult ) > tolerance )
291 {
292 std::cerr << "multi-variate registration value: " << multiValueResult
293 << " is different from single-variate value: " << singleValueResult << std::endl;
294 return EXIT_FAILURE;
295 }
296
297 // compare results with truth
298 tolerance = static_cast<CorrelationMetricType::DerivativeValueType>(0.05);
299 if( std::fabs( multiDerivativeResult[0] - imageShift[0] ) / imageShift[0] > tolerance ||
300 std::fabs( multiDerivativeResult[1] - imageShift[1] ) / imageShift[1] > tolerance )
301 {
302 std::cerr << "multi-variate registration results: " << multiDerivativeResult << " are not as expected: " << imageShift << std::endl;
303 return EXIT_FAILURE;
304 }
305
306 //
307 // Test with two different metric types
308 //
309 using MeanSquaresMetricType = itk::MeanSquaresImageToImageMetricv4<ImageType, ImageType>;
310 MeanSquaresMetricType::Pointer meanSquaresMetric = MeanSquaresMetricType::New();
311 meanSquaresMetric->SetFixedImage( fixedImage );
312 meanSquaresMetric->SetMovingImage( movingImage );
313 meanSquaresMetric->SetMovingTransform( translationTransform );
314
315 MultiMetricType::Pointer multiMetric2 = MultiMetricType::New();
316 multiMetric2->AddMetric( correlationMetric );
317 multiMetric2->AddMetric( meanSquaresMetric );
318 multiMetric2->Initialize();
319
320 translationTransform->SetIdentity();
321 std::cout << "*** Multi-metric with different metric types: " << std::endl;
322 ObjectToObjectMultiMetricv4RegistrationTestRun<MultiMetricType>( multiMetric2, numberOfIterations, multiValueResult, multiDerivativeResult, 1.0, true );
323
324 // compare results with truth
325 tolerance = static_cast<MeanSquaresMetricType::DerivativeValueType>(0.05);
326 if( std::fabs( multiDerivativeResult[0] - imageShift[0] ) / imageShift[0] > tolerance ||
327 std::fabs( multiDerivativeResult[1] - imageShift[1] ) / imageShift[1] > tolerance )
328 {
329 std::cerr << "multi-variate registration results: " << multiDerivativeResult << " are not as expected: " << imageShift << std::endl;
330 return EXIT_FAILURE;
331 }
332
333 return EXIT_SUCCESS;
334 }
335