1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18
#include "itkTranslationTransform.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkMeanSquaresImageToImageMetric.h"
#include "itkGaussianImageSource.h"
#include "itkStdStreamStateSave.h"

#include <cmath>
#include <cstdlib>
#include <iostream>
26
27 /**
28 * This test uses two 2D-Gaussians (standard deviation RegionSize/2)
29 * One is shifted by 5 pixels from the other.
30 *
31 * This test computes the mean squares value and derivatives
32 * for various shift values in (-10,10).
33 *
34 */
35
itkMeanSquaresImageMetricTest(int,char * [])36 int itkMeanSquaresImageMetricTest(int, char* [] )
37 {
38
39 // Save the format stream variables for std::cout
40 // They will be restored when coutState goes out of scope
41 itk::StdStreamStateSave coutState(std::cout);
42
43 //------------------------------------------------------------
44 // Create two simple images
45 //------------------------------------------------------------
46
47 constexpr unsigned int ImageDimension = 2;
48
49 using PixelType = double;
50
51 using CoordinateRepresentationType = double;
52
53 //Allocate Images
54 using MovingImageType = itk::Image<PixelType,ImageDimension>;
55 using FixedImageType = itk::Image<PixelType,ImageDimension>;
56
57 // Declare Gaussian Sources
58 using MovingImageSourceType = itk::GaussianImageSource< MovingImageType >;
59 using FixedImageSourceType = itk::GaussianImageSource< FixedImageType >;
60
61 // Note: the following declarations are classical arrays
62 FixedImageType::SizeValueType fixedImageSize[] = { 100, 100 };
63 MovingImageType::SizeValueType movingImageSize[] = { 100, 100 };
64
65 FixedImageType::SpacingValueType fixedImageSpacing[] = { 1.0f, 1.0f };
66 MovingImageType::SpacingValueType movingImageSpacing[] = { 1.0f, 1.0f };
67
68 FixedImageType::PointValueType fixedImageOrigin[] = { 0.0f, 0.0f };
69 MovingImageType::PointValueType movingImageOrigin[] = { 0.0f, 0.0f };
70
71 MovingImageSourceType::Pointer movingImageSource = MovingImageSourceType::New();
72 FixedImageSourceType::Pointer fixedImageSource = FixedImageSourceType::New();
73
74 movingImageSource->SetSize( movingImageSize );
75 movingImageSource->SetOrigin( movingImageOrigin );
76 movingImageSource->SetSpacing( movingImageSpacing );
77 movingImageSource->SetNormalized( false );
78 movingImageSource->SetScale( 250.0f );
79
80 fixedImageSource->SetSize( fixedImageSize );
81 fixedImageSource->SetOrigin( fixedImageOrigin );
82 fixedImageSource->SetSpacing( fixedImageSpacing );
83 fixedImageSource->SetNormalized( false );
84 fixedImageSource->SetScale( 250.0f );
85
86 movingImageSource->Update(); // Force the filter to run
87 fixedImageSource->Update(); // Force the filter to run
88
89 MovingImageType::Pointer movingImage = movingImageSource->GetOutput();
90 FixedImageType::Pointer fixedImage = fixedImageSource->GetOutput();
91
92
93 //-----------------------------------------------------------
94 // Set up the Metric
95 //-----------------------------------------------------------
96 using MetricType = itk::MeanSquaresImageToImageMetric<
97 FixedImageType,
98 MovingImageType >;
99
100 using TransformBaseType = MetricType::TransformType;
101 using ParametersType = TransformBaseType::ParametersType;
102
103 MetricType::Pointer metric = MetricType::New();
104
105
106 //-----------------------------------------------------------
107 // Plug the Images into the metric
108 //-----------------------------------------------------------
109 metric->SetFixedImage( fixedImage );
110 metric->SetMovingImage( movingImage );
111
112 //-----------------------------------------------------------
113 // Set up a Transform
114 //-----------------------------------------------------------
115
116 using TransformType = itk::TranslationTransform<
117 CoordinateRepresentationType,
118 ImageDimension >;
119
120 TransformType::Pointer transform = TransformType::New();
121
122 metric->SetTransform( transform );
123
124
125 //------------------------------------------------------------
126 // Set up an Interpolator
127 //------------------------------------------------------------
128 using InterpolatorType = itk::LinearInterpolateImageFunction<
129 MovingImageType,
130 double >;
131
132 InterpolatorType::Pointer interpolator = InterpolatorType::New();
133
134 interpolator->SetInputImage( movingImage );
135
136 metric->SetInterpolator( interpolator );
137
138
139 //------------------------------------------------------------
140 // Define the region over which the metric will be computed
141 //------------------------------------------------------------
142 metric->SetFixedImageRegion( fixedImage->GetBufferedRegion() );
143
144 std::cout << metric << std::endl;
145
146
147 //------------------------------------------------------------
148 // This call is mandatory before start querying the Metric
149 // This method makes all the necessary connections between the
150 // internal components: Interpolator, Transform and Images
151 //------------------------------------------------------------
152 try {
153 metric->Initialize();
154 }
155 catch( itk::ExceptionObject & e )
156 {
157 std::cout << "Metric initialization failed" << std::endl;
158 std::cout << "Reason " << e.GetDescription() << std::endl;
159
160 return EXIT_FAILURE;
161 }
162
163
164 //------------------------------------------------------------
165 // Set up transform parameters
166 //------------------------------------------------------------
167 ParametersType parameters( transform->GetNumberOfParameters() );
168
169 // initialize the offset/vector part
170 for( unsigned int k = 0; k < ImageDimension; k++ )
171 {
172 parameters[k]= 0.0f;
173 }
174
175
176 //---------------------------------------------------------
177 // Print out metric values
178 // for parameters[1] = {-10,10} (arbitrary choice...)
179 //---------------------------------------------------------
180
181 MetricType::MeasureType measure;
182 MetricType::DerivativeType derivative;
183
184 std::cout << "param[1] Metric d(Metric)/d(param[1]) " << std::endl;
185
186 for( double trans = -10; trans <= 5; trans += 0.2 )
187 {
188 parameters[1] = trans;
189 metric->GetValueAndDerivative( parameters, measure, derivative );
190
191 std::cout.width(5);
192 std::cout.precision(5);
193 std::cout << trans;
194 std::cout.width(15);
195 std::cout.precision(5);
196 std::cout << measure;
197 std::cout.width(15);
198 std::cout.precision(5);
199 std::cout << derivative[1];
200 std::cout << std::endl;
201
202 // exercise the other functions
203 metric->GetValue( parameters );
204 metric->GetDerivative( parameters, derivative );
205 }
206
207 // Compute a reference metric and partial derivative with one
208 // thread. NOTE - this test checks for consistency in the answer
209 // computed by differing numbers of threads, not correctness.
210 metric->SetNumberOfWorkUnits(1);
211 metric->Initialize();
212 parameters[1] = 2.0;
213 MetricType::MeasureType referenceMeasure;
214 MetricType::DerivativeType referenceDerivative;
215 referenceMeasure = metric->GetValue(parameters);
216 metric->GetDerivative( parameters, referenceDerivative );
217
218 std::cout << "Testing consistency of the metric value computed by "
219 << "several different thread counts." << std::endl;
220
221 // Now check that the same metric value is computed when the number
222 // of threads is adjusted from 1 to 8.
223 for (int currNumThreadsToTest = 1; currNumThreadsToTest <= 8; currNumThreadsToTest++)
224 {
225 itk::MultiThreaderBase::SetGlobalMaximumNumberOfThreads(currNumThreadsToTest);
226 metric->SetNumberOfWorkUnits(currNumThreadsToTest);
227 metric->Initialize();
228
229 std::cout << "Threads Metric d(Metric)/d(param[1]) " << std::endl;
230
231 measure = metric->GetValue( parameters );
232 metric->GetDerivative( parameters, derivative );
233 std::cout.width(4);
234 std::cout << currNumThreadsToTest;
235 std::cout.width(10);
236 std::cout.precision(5);
237 std::cout << measure;
238 std::cout.width(10);
239 std::cout.precision(5);
240 std::cout << derivative[1];
241 std::cout << std::endl;
242
243 bool sameDerivative = true;
244 for (unsigned int d = 0; d < parameters.Size(); d++)
245 {
246 if ( fabs(derivative[d] - referenceDerivative[d]) > 1e-5 )
247 {
248 sameDerivative = false;
249 break;
250 }
251 }
252
253 if ( fabs(measure - referenceMeasure) > 1e-5 || !sameDerivative )
254 {
255 std::cout << "Testing different number of threads... FAILED" << std::endl;
256 std::cout << "Metric value computed with " << currNumThreadsToTest
257 << " threads is incorrect. Computed value is "
258 << measure << ", should be " << referenceMeasure
259 << ", computed derivative is " << derivative
260 << ", should be " << referenceDerivative << std::endl;
261
262 return EXIT_FAILURE;
263 }
264 }
265 std::cout << "Testing different number of threads... PASSED." << std::endl;
266
267 // Now check that the same metric value is computed when the number
268 // of threads in the metric is set to 8 and the global max number of
269 // threads is reduced to 2. These are arbitrary numbers of threads
270 // used to verify the correctness of the metric under a particular
271 // usage scenario.
272 metric->SetNumberOfWorkUnits(8);
273 constexpr int numThreads = 2;
274 itk::MultiThreaderBase::SetGlobalMaximumNumberOfThreads(numThreads);
275 metric->Initialize();
276
277 std::cout << "Threads Metric d(Metric)/d(param[1]) " << std::endl;
278
279 measure = metric->GetValue( parameters );
280 std::cout.width(4);
281 std::cout << numThreads;
282 std::cout.width(10);
283 std::cout.precision(5);
284 std::cout << measure;
285 std::cout.width(10);
286 std::cout.precision(5);
287 std::cout << derivative[1];
288 std::cout << std::endl;
289 if ( fabs(measure - referenceMeasure) > 1e-5 )
290 {
291 std::cout << "Test reducing global max number of threads... FAILED." << std::endl;
292 std::cout << "Metric value computed with " << numThreads
293 << " threads is incorrect. Computed value is "
294 << measure << ", should be " << referenceMeasure << std::endl;
295
296 return EXIT_FAILURE;
297 }
298 std::cout << "Test reducing global max number of threads... PASSED." << std::endl;
299
300 //-------------------------------------------------------
301 // exercise Print() method
302 //-------------------------------------------------------
303 metric->Print( std::cout );
304
305 //-------------------------------------------------------
306 // exercise misc member functions
307 //-------------------------------------------------------
308 std::cout << "FixedImage: " << metric->GetFixedImage() << std::endl;
309 std::cout << "MovingImage: " << metric->GetMovingImage() << std::endl;
310 std::cout << "Transform: " << metric->GetTransform() << std::endl;
311 std::cout << "Interpolator: " << metric->GetInterpolator() << std::endl;
312 std::cout << "NumberOfPixelsCounted: " << metric->GetNumberOfPixelsCounted() << std::endl;
313 std::cout << "FixedImageRegion: " << metric->GetFixedImageRegion() << std::endl;
314
315 std::cout << "Check case when Target is nullptr" << std::endl;
316 metric->SetFixedImage( nullptr );
317 try
318 {
319 std::cout << "Value = " << metric->GetValue( parameters );
320 std::cout << "If you are reading this message the Metric " << std::endl;
321 std::cout << "is NOT managing exceptions correctly " << std::endl;
322
323 return EXIT_FAILURE;
324 }
325 catch( itk::ExceptionObject & e )
326 {
327 std::cout << "Exception received (as expected) " << std::endl;
328 std::cout << "Description : " << e.GetDescription() << std::endl;
329 std::cout << "Location : " << e.GetLocation() << std::endl;
330 std::cout << "Test for exception throwing... PASSED ! " << std::endl;
331 }
332
333 try
334 {
335 metric->GetValueAndDerivative( parameters, measure, derivative );
336 std::cout << "Value = " << measure << std::endl;
337 std::cout << "If you are reading this message the Metric " << std::endl;
338 std::cout << "is NOT managing exceptions correctly " << std::endl;
339
340 return EXIT_FAILURE;
341 }
342 catch( itk::ExceptionObject & e )
343 {
344 std::cout << "Exception received (as expected) " << std::endl;
345 std::cout << "Description : " << e.GetDescription() << std::endl;
346 std::cout << "Location : " << e.GetLocation() << std::endl;
347 std::cout << "Test for exception throwing... PASSED ! " << std::endl;
348 }
349
350 bool pass;
351 #define TEST_INITIALIZATION_ERROR( ComponentName, badComponent, goodComponent ) \
352 metric->Set##ComponentName( badComponent ); \
353 try \
354 { \
355 pass = false; \
356 metric->Initialize(); \
357 } \
358 catch( itk::ExceptionObject& err ) \
359 { \
360 std::cout << "Caught expected ExceptionObject" << std::endl; \
361 std::cout << err << std::endl; \
362 pass = true; \
363 } \
364 metric->Set##ComponentName( goodComponent ); \
365 \
366 if( !pass ) \
367 { \
368 std::cout << "Test failed." << std::endl; \
369 return EXIT_FAILURE; \
370 }
371
372 TEST_INITIALIZATION_ERROR( Transform, nullptr, transform );
373 TEST_INITIALIZATION_ERROR( FixedImage, nullptr, fixedImage );
374 TEST_INITIALIZATION_ERROR( MovingImage, nullptr, movingImage );
375 TEST_INITIALIZATION_ERROR( Interpolator, nullptr, interpolator );
376
377 std::cout << "Test passed. " << std::endl;
378 return EXIT_SUCCESS;
379
380 }
381