1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 
#include "itkTranslationTransform.h"
#include "itkLinearInterpolateImageFunction.h"
#include "itkMeanSquaresImageToImageMetric.h"
#include "itkGaussianImageSource.h"
#include "itkStdStreamStateSave.h"

#include <cmath>
#include <cstdlib>
#include <iostream>
26 
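/**
 *  This test uses two 2D Gaussian images produced by itk::GaussianImageSource.
 *  The test was originally described as using Gaussians with standard
 *  deviation RegionSize/2, one shifted by 5 pixels from the other; note that
 *  this file configures both sources with identical size, spacing, origin and
 *  scale and does not set a mean or sigma explicitly.
 *
 *  This test computes the mean squares value and derivatives while the
 *  second translation parameter is swept from -10 to 5 in steps of 0.2.
 *  It then checks that the metric value and derivative stay consistent when
 *  the number of work units / threads is varied, and that missing components
 *  are reported through exceptions.
 *
 */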
35 
itkMeanSquaresImageMetricTest(int,char * [])36 int itkMeanSquaresImageMetricTest(int, char* [] )
37 {
38 
39 // Save the format stream variables for std::cout
40 // They will be restored when coutState goes out of scope
41   itk::StdStreamStateSave coutState(std::cout);
42 
43 //------------------------------------------------------------
44 // Create two simple images
45 //------------------------------------------------------------
46 
47   constexpr unsigned int ImageDimension = 2;
48 
49   using PixelType = double;
50 
51   using CoordinateRepresentationType = double;
52 
53   //Allocate Images
54   using MovingImageType = itk::Image<PixelType,ImageDimension>;
55   using FixedImageType = itk::Image<PixelType,ImageDimension>;
56 
57   // Declare Gaussian Sources
58   using MovingImageSourceType = itk::GaussianImageSource< MovingImageType >;
59   using FixedImageSourceType = itk::GaussianImageSource< FixedImageType  >;
60 
61   // Note: the following declarations are classical arrays
62   FixedImageType::SizeValueType fixedImageSize[]     = {  100,  100 };
63   MovingImageType::SizeValueType movingImageSize[]    = {  100,  100 };
64 
65   FixedImageType::SpacingValueType fixedImageSpacing[]  = { 1.0f, 1.0f };
66   MovingImageType::SpacingValueType movingImageSpacing[] = { 1.0f, 1.0f };
67 
68   FixedImageType::PointValueType fixedImageOrigin[]   = { 0.0f, 0.0f };
69   MovingImageType::PointValueType movingImageOrigin[]  = { 0.0f, 0.0f };
70 
71   MovingImageSourceType::Pointer movingImageSource = MovingImageSourceType::New();
72   FixedImageSourceType::Pointer  fixedImageSource  = FixedImageSourceType::New();
73 
74   movingImageSource->SetSize(    movingImageSize    );
75   movingImageSource->SetOrigin(  movingImageOrigin  );
76   movingImageSource->SetSpacing( movingImageSpacing );
77   movingImageSource->SetNormalized( false );
78   movingImageSource->SetScale( 250.0f );
79 
80   fixedImageSource->SetSize(    fixedImageSize    );
81   fixedImageSource->SetOrigin(  fixedImageOrigin  );
82   fixedImageSource->SetSpacing( fixedImageSpacing );
83   fixedImageSource->SetNormalized( false );
84   fixedImageSource->SetScale( 250.0f );
85 
86   movingImageSource->Update(); // Force the filter to run
87   fixedImageSource->Update();  // Force the filter to run
88 
89   MovingImageType::Pointer movingImage = movingImageSource->GetOutput();
90   FixedImageType::Pointer  fixedImage  = fixedImageSource->GetOutput();
91 
92 
93 //-----------------------------------------------------------
94 // Set up  the Metric
95 //-----------------------------------------------------------
96   using MetricType = itk::MeanSquaresImageToImageMetric<
97                                        FixedImageType,
98                                        MovingImageType >;
99 
100   using TransformBaseType = MetricType::TransformType;
101   using ParametersType = TransformBaseType::ParametersType;
102 
103   MetricType::Pointer  metric = MetricType::New();
104 
105 
106 //-----------------------------------------------------------
107 // Plug the Images into the metric
108 //-----------------------------------------------------------
109   metric->SetFixedImage( fixedImage );
110   metric->SetMovingImage( movingImage );
111 
112 //-----------------------------------------------------------
113 // Set up a Transform
114 //-----------------------------------------------------------
115 
116   using TransformType = itk::TranslationTransform<
117                         CoordinateRepresentationType,
118                         ImageDimension >;
119 
120   TransformType::Pointer transform = TransformType::New();
121 
122   metric->SetTransform( transform );
123 
124 
125 //------------------------------------------------------------
126 // Set up an Interpolator
127 //------------------------------------------------------------
128   using InterpolatorType = itk::LinearInterpolateImageFunction<
129                     MovingImageType,
130                     double >;
131 
132   InterpolatorType::Pointer interpolator = InterpolatorType::New();
133 
134   interpolator->SetInputImage( movingImage );
135 
136   metric->SetInterpolator( interpolator );
137 
138 
139 //------------------------------------------------------------
140 // Define the region over which the metric will be computed
141 //------------------------------------------------------------
142   metric->SetFixedImageRegion( fixedImage->GetBufferedRegion() );
143 
144   std::cout << metric << std::endl;
145 
146 
147 //------------------------------------------------------------
148 // This call is mandatory before start querying the Metric
149 // This method makes all the necessary connections between the
150 // internal components: Interpolator, Transform and Images
151 //------------------------------------------------------------
152   try {
153     metric->Initialize();
154     }
155   catch( itk::ExceptionObject & e )
156     {
157     std::cout << "Metric initialization failed" << std::endl;
158     std::cout << "Reason " << e.GetDescription() << std::endl;
159 
160     return EXIT_FAILURE;
161     }
162 
163 
164 //------------------------------------------------------------
165 // Set up transform parameters
166 //------------------------------------------------------------
167   ParametersType parameters( transform->GetNumberOfParameters() );
168 
169   // initialize the offset/vector part
170   for( unsigned int k = 0; k < ImageDimension; k++ )
171     {
172     parameters[k]= 0.0f;
173     }
174 
175 
176 //---------------------------------------------------------
177 // Print out metric values
178 // for parameters[1] = {-10,10}  (arbitrary choice...)
179 //---------------------------------------------------------
180 
181   MetricType::MeasureType     measure;
182   MetricType::DerivativeType  derivative;
183 
184   std::cout << "param[1]   Metric    d(Metric)/d(param[1]) " << std::endl;
185 
186   for( double trans = -10; trans <= 5; trans += 0.2  )
187     {
188     parameters[1] = trans;
189     metric->GetValueAndDerivative( parameters, measure, derivative );
190 
191     std::cout.width(5);
192     std::cout.precision(5);
193     std::cout << trans;
194     std::cout.width(15);
195     std::cout.precision(5);
196     std::cout << measure;
197     std::cout.width(15);
198     std::cout.precision(5);
199     std::cout << derivative[1];
200     std::cout << std::endl;
201 
202     // exercise the other functions
203     metric->GetValue( parameters );
204     metric->GetDerivative( parameters, derivative );
205     }
206 
207   // Compute a reference metric and partial derivative with one
208   // thread. NOTE - this test checks for consistency in the answer
209   // computed by differing numbers of threads, not correctness.
210   metric->SetNumberOfWorkUnits(1);
211   metric->Initialize();
212   parameters[1] = 2.0;
213   MetricType::MeasureType    referenceMeasure;
214   MetricType::DerivativeType referenceDerivative;
215   referenceMeasure = metric->GetValue(parameters);
216   metric->GetDerivative( parameters, referenceDerivative );
217 
218   std::cout << "Testing consistency of the metric value computed by "
219             << "several different thread counts." << std::endl;
220 
221   // Now check that the same metric value is computed when the number
222   // of threads is adjusted from 1 to 8.
223   for (int currNumThreadsToTest = 1; currNumThreadsToTest <= 8; currNumThreadsToTest++)
224     {
225     itk::MultiThreaderBase::SetGlobalMaximumNumberOfThreads(currNumThreadsToTest);
226     metric->SetNumberOfWorkUnits(currNumThreadsToTest);
227     metric->Initialize();
228 
229     std::cout << "Threads Metric    d(Metric)/d(param[1]) " << std::endl;
230 
231     measure = metric->GetValue( parameters );
232     metric->GetDerivative( parameters, derivative );
233     std::cout.width(4);
234     std::cout << currNumThreadsToTest;
235     std::cout.width(10);
236     std::cout.precision(5);
237     std::cout << measure;
238     std::cout.width(10);
239     std::cout.precision(5);
240     std::cout << derivative[1];
241     std::cout << std::endl;
242 
243     bool sameDerivative = true;
244     for (unsigned int d = 0; d < parameters.Size(); d++)
245       {
246       if ( fabs(derivative[d] - referenceDerivative[d]) > 1e-5 )
247         {
248         sameDerivative = false;
249         break;
250         }
251       }
252 
253     if ( fabs(measure - referenceMeasure) > 1e-5 || !sameDerivative )
254       {
255       std::cout << "Testing different number of threads... FAILED" << std::endl;
256       std::cout << "Metric value computed with " << currNumThreadsToTest
257                 << " threads is incorrect. Computed value is "
258                 << measure << ", should be " << referenceMeasure
259                 << ", computed derivative is " << derivative
260                 << ", should be " << referenceDerivative << std::endl;
261 
262       return EXIT_FAILURE;
263       }
264     }
265   std::cout << "Testing different number of threads... PASSED." << std::endl;
266 
267   // Now check that the same metric value is computed when the number
268   // of threads in the metric is set to 8 and the global max number of
269   // threads is reduced to 2. These are arbitrary numbers of threads
270   // used to verify the correctness of the metric under a particular
271   // usage scenario.
272   metric->SetNumberOfWorkUnits(8);
273   constexpr int numThreads = 2;
274   itk::MultiThreaderBase::SetGlobalMaximumNumberOfThreads(numThreads);
275   metric->Initialize();
276 
277   std::cout << "Threads Metric    d(Metric)/d(param[1]) " << std::endl;
278 
279   measure = metric->GetValue( parameters );
280   std::cout.width(4);
281   std::cout << numThreads;
282   std::cout.width(10);
283   std::cout.precision(5);
284   std::cout << measure;
285   std::cout.width(10);
286   std::cout.precision(5);
287   std::cout << derivative[1];
288   std::cout << std::endl;
289   if ( fabs(measure - referenceMeasure) > 1e-5 )
290     {
291     std::cout << "Test reducing global max number of threads... FAILED." << std::endl;
292     std::cout << "Metric value computed with " << numThreads
293               << " threads is incorrect. Computed value is "
294               << measure << ", should be " << referenceMeasure << std::endl;
295 
296     return EXIT_FAILURE;
297     }
298   std::cout << "Test reducing global max number of threads... PASSED." << std::endl;
299 
300 //-------------------------------------------------------
301 // exercise Print() method
302 //-------------------------------------------------------
303   metric->Print( std::cout );
304 
305 //-------------------------------------------------------
306 // exercise misc member functions
307 //-------------------------------------------------------
308   std::cout << "FixedImage: " << metric->GetFixedImage() << std::endl;
309   std::cout << "MovingImage: " << metric->GetMovingImage() << std::endl;
310   std::cout << "Transform: " << metric->GetTransform() << std::endl;
311   std::cout << "Interpolator: " << metric->GetInterpolator() << std::endl;
312   std::cout << "NumberOfPixelsCounted: " << metric->GetNumberOfPixelsCounted() << std::endl;
313   std::cout << "FixedImageRegion: " << metric->GetFixedImageRegion() << std::endl;
314 
315   std::cout << "Check case when Target is nullptr" << std::endl;
316   metric->SetFixedImage( nullptr );
317   try
318     {
319     std::cout << "Value = " << metric->GetValue( parameters );
320     std::cout << "If you are reading this message the Metric " << std::endl;
321     std::cout << "is NOT managing exceptions correctly    " << std::endl;
322 
323     return EXIT_FAILURE;
324     }
325   catch( itk::ExceptionObject & e )
326     {
327     std::cout << "Exception received (as expected) "    << std::endl;
328     std::cout << "Description : " << e.GetDescription() << std::endl;
329     std::cout << "Location    : " << e.GetLocation()    << std::endl;
330     std::cout << "Test for exception throwing... PASSED ! " << std::endl;
331     }
332 
333   try
334     {
335     metric->GetValueAndDerivative( parameters, measure, derivative );
336     std::cout << "Value = " << measure << std::endl;
337     std::cout << "If you are reading this message the Metric " << std::endl;
338     std::cout << "is NOT managing exceptions correctly    " << std::endl;
339 
340     return EXIT_FAILURE;
341     }
342   catch( itk::ExceptionObject & e )
343     {
344     std::cout << "Exception received (as expected) "    << std::endl;
345     std::cout << "Description : " << e.GetDescription() << std::endl;
346     std::cout << "Location    : " << e.GetLocation()    << std::endl;
347     std::cout << "Test for exception throwing... PASSED ! "  << std::endl;
348     }
349 
350  bool pass;
351 #define TEST_INITIALIZATION_ERROR( ComponentName, badComponent, goodComponent ) \
352   metric->Set##ComponentName( badComponent ); \
353   try \
354     { \
355     pass = false; \
356     metric->Initialize(); \
357     } \
358   catch( itk::ExceptionObject& err ) \
359     { \
360     std::cout << "Caught expected ExceptionObject" << std::endl; \
361     std::cout << err << std::endl; \
362     pass = true; \
363     } \
364   metric->Set##ComponentName( goodComponent ); \
365   \
366   if( !pass ) \
367     { \
368     std::cout << "Test failed." << std::endl; \
369     return EXIT_FAILURE; \
370     }
371 
372   TEST_INITIALIZATION_ERROR( Transform, nullptr, transform );
373   TEST_INITIALIZATION_ERROR( FixedImage, nullptr, fixedImage );
374   TEST_INITIALIZATION_ERROR( MovingImage, nullptr, movingImage );
375   TEST_INITIALIZATION_ERROR( Interpolator, nullptr, interpolator );
376 
377   std::cout << "Test passed. " << std::endl;
378   return EXIT_SUCCESS;
379 
380 }
381