1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18 #include "itkMeanSquaresImageToImageMetricv4.h"
19 #include "itkTranslationTransform.h"
20 #include "itkMath.h"
21 #include "itkMath.h"
22
23 /* Simple test to verify that class builds and runs.
24 * Results are not verified. See ImageToImageMetricv4Test
25 * for verification of basic metric functionality.
26 *
27 * TODO Numerical verification.
28 */
29
itkMeanSquaresImageToImageMetricv4Test(int,char ** const)30 int itkMeanSquaresImageToImageMetricv4Test(int, char ** const)
31 {
32
33 constexpr unsigned int imageSize = 5;
34 constexpr unsigned int imageDimensionality = 3;
35 using ImageType = itk::Image< double, imageDimensionality >;
36
37 ImageType::SizeType size;
38 size.Fill( imageSize );
39 ImageType::IndexType index;
40 index.Fill( 0 );
41 ImageType::RegionType region;
42 region.SetSize( size );
43 region.SetIndex( index );
44 ImageType::SpacingType spacing;
45 spacing.Fill(1.0);
46 ImageType::PointType origin;
47 origin.Fill(0);
48 ImageType::DirectionType direction;
49 direction.SetIdentity();
50
51 /* Create simple test images. */
52 ImageType::Pointer fixedImage = ImageType::New();
53 fixedImage->SetRegions( region );
54 fixedImage->SetSpacing( spacing );
55 fixedImage->SetOrigin( origin );
56 fixedImage->SetDirection( direction );
57 fixedImage->Allocate();
58
59 ImageType::Pointer movingImage = ImageType::New();
60 movingImage->SetRegions( region );
61 movingImage->SetSpacing( spacing );
62 movingImage->SetOrigin( origin );
63 movingImage->SetDirection( direction );
64 movingImage->Allocate();
65
66 /* Fill images */
67 itk::ImageRegionIterator<ImageType> itFixed( fixedImage, region );
68 itFixed.GoToBegin();
69 unsigned int count = 1;
70 while( !itFixed.IsAtEnd() )
71 {
72 itFixed.Set( count*count );
73 count++;
74 ++itFixed;
75 }
76
77 itk::ImageRegionIteratorWithIndex<ImageType> itMoving( movingImage, region );
78
79 itMoving.GoToBegin();
80 count = 1;
81
82 while( !itMoving.IsAtEnd() )
83 {
84 itMoving.Set( 1.0/(count*count) );
85 count++;
86 ++itMoving;
87 }
88
89 /* Transforms */
90 using FixedTransformType = itk::TranslationTransform<double,imageDimensionality>;
91 using MovingTransformType = itk::TranslationTransform<double,imageDimensionality>;
92
93 FixedTransformType::Pointer fixedTransform = FixedTransformType::New();
94 MovingTransformType::Pointer movingTransform = MovingTransformType::New();
95
96 fixedTransform->SetIdentity();
97 movingTransform->SetIdentity();
98
99 /* The metric */
100 using MetricType = itk::MeanSquaresImageToImageMetricv4< ImageType, ImageType, ImageType >;
101
102 MetricType::Pointer metric = MetricType::New();
103
104 /* Assign images and transforms.
105 * By not setting a virtual domain image or virtual domain settings,
106 * the metric will use the fixed image for the virtual domain. */
107 metric->SetFixedImage( fixedImage );
108 metric->SetMovingImage( movingImage );
109 metric->SetFixedTransform( fixedTransform );
110 metric->SetMovingTransform( movingTransform );
111
112 /* Initialize. */
113 try
114 {
115 std::cout << "Calling Initialize..." << std::endl;
116 metric->Initialize();
117 }
118 catch( itk::ExceptionObject & exc )
119 {
120 std::cerr << "Caught unexpected exception during Initialize: " << exc << std::endl;
121 return EXIT_FAILURE;
122 }
123
124 // Evaluate with GetValueAndDerivative
125 MetricType::MeasureType valueReturn1, valueReturn2;
126 MetricType::DerivativeType derivativeReturn;
127
128 try
129 {
130 std::cout << "Calling GetValueAndDerivative..." << std::endl;
131 metric->GetValueAndDerivative( valueReturn1, derivativeReturn );
132 }
133 catch( itk::ExceptionObject & exc )
134 {
135 std::cout << "Caught unexpected exception during GetValueAndDerivative: "
136 << exc;
137 return EXIT_FAILURE;
138 }
139
140 /* Re-initialize. */
141 try
142 {
143 std::cout << "Calling Initialize..." << std::endl;
144 metric->Initialize();
145 }
146 catch( itk::ExceptionObject & exc )
147 {
148 std::cerr << "Caught unexpected exception during re-initialize: " << exc << std::endl;
149 return EXIT_FAILURE;
150 }
151
152 try
153 {
154 std::cout << "Calling GetValue..." << std::endl;
155 valueReturn2 = metric->GetValue();
156 }
157 catch( itk::ExceptionObject & exc )
158 {
159 std::cout << "Caught unexpected exception during GetValue: "
160 << exc;
161 return EXIT_FAILURE;
162 }
163
164 // Test same value returned by different methods
165 std::cout << "Check Value return values..." << std::endl;
166 if( itk::Math::NotExactlyEquals(valueReturn1, valueReturn2) )
167 {
168 std::cerr << "Results for Value don't match: " << valueReturn1
169 << ", " << valueReturn2 << std::endl;
170 }
171 else
172 {
173 std::cout << "Metric value = " << valueReturn1 << std::endl;
174 std::cout << "Gradient value = " << derivativeReturn << std::endl;
175 }
176
177 // Test that using floating point correction produces
178 // a different result
179 std::cout << "Testing with different floating point correction settings." << std::endl;
180 MetricType::DerivativeType derivativeWithFPC, derivativeWithOutFPC;
181 metric->SetMaximumNumberOfWorkUnits( 1 );
182 metric->SetUseFloatingPointCorrection( false ); //default
183 metric->GetValueAndDerivative( valueReturn1, derivativeWithOutFPC );
184 metric->SetUseFloatingPointCorrection( true );
185 metric->SetFloatingPointCorrectionResolution( 1e1 ); //severe truncation
186 metric->GetValueAndDerivative( valueReturn1, derivativeWithFPC );
187 if( derivativeWithFPC == derivativeWithOutFPC )
188 {
189 std::cerr << "Expected different derivative result when using floating-point correction: "
190 << "With correction: " << derivativeWithFPC << ", without: " << derivativeWithOutFPC
191 << std::endl;
192 return EXIT_FAILURE;
193 }
194
195 std::cout << "Test passed." << std::endl;
196 return EXIT_SUCCESS;
197 }
198