1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18
19 /**
20 * Test program for MeanSquaresImageToImageMetricv4 and
21 * GradientDescentOptimizerv4 classes with vector (color)
22 * images.
23 *
24 * Perform a registration using user-supplied images.
25 * No numerical verification is performed. Test passes as long
26 * as no exception occurs.
27 */
28 #include "itkMeanSquaresImageToImageMetricv4.h"
29 #include "itkGradientDescentOptimizerv4.h"
30 #include "itkRegistrationParameterScalesFromPhysicalShift.h"
31 #include "itkVectorImageToImageMetricTraitsv4.h"
32
33 #include "itkGaussianSmoothingOnUpdateDisplacementFieldTransform.h"
34
35 #include "itkCastImageFilter.h"
36
37 #include "itkCommand.h"
38 #include "itkImageFileReader.h"
39 #include "itkImageFileWriter.h"
40
41 #include <iomanip>
42 #include "itkTestingMacros.h"
43
itkMeanSquaresImageToImageMetricv4VectorRegistrationTest(int argc,char * argv[])44 int itkMeanSquaresImageToImageMetricv4VectorRegistrationTest(int argc, char *argv[])
45 {
46
47 if( argc < 4 )
48 {
49 std::cerr << "Missing Parameters " << std::endl;
50 std::cerr << "Usage: " << itkNameOfTestExecutableMacro(argv);
51 std::cerr << " fixedImageFile movingImageFile ";
52 std::cerr << " outputImageFile ";
53 std::cerr << " [numberOfAffineIterations numberOfDisplacementIterations] ";
54 std::cerr << std::endl;
55 return EXIT_FAILURE;
56 }
57
58 std::cout << argc << std::endl;
59 unsigned int numberOfAffineIterations = 2;
60 unsigned int numberOfDisplacementIterations = 2;
61 if( argc >= 5 )
62 {
63 numberOfAffineIterations = std::stoi( argv[4] );
64 }
65 if( argc >= 6 )
66 {
67 numberOfDisplacementIterations = std::stoi( argv[5] );
68 }
69 std::cout << " affine iterations "<< numberOfAffineIterations << " displacementIterations " << numberOfDisplacementIterations << std::endl;
70
71 constexpr unsigned int Dimension = 2;
72
73 // RGBPixel type is not supported by GradientRecursiveGaussianFilter at this point.
74 //using PixelType = itk::RGBPixel<FloatType>;
75 using PixelType = itk::Vector<double, 3>;
76
77 using FixedImageType = itk::Image< PixelType, Dimension >;
78 using MovingImageType = itk::Image< PixelType, Dimension >;
79
80 using FixedImageReaderType = itk::ImageFileReader< FixedImageType >;
81 using MovingImageReaderType = itk::ImageFileReader< MovingImageType >;
82
83 FixedImageReaderType::Pointer fixedImageReader = FixedImageReaderType::New();
84 MovingImageReaderType::Pointer movingImageReader = MovingImageReaderType::New();
85
86 fixedImageReader->SetFileName( argv[1] );
87 movingImageReader->SetFileName( argv[2] );
88
89 //get the images
90 fixedImageReader->Update();
91 FixedImageType::Pointer fixedImage = fixedImageReader->GetOutput();
92 movingImageReader->Update();
93 MovingImageType::Pointer movingImage = movingImageReader->GetOutput();
94
95 /** define a resample filter that will ultimately be used to deform the image */
96 using ResampleFilterType = itk::ResampleImageFilter< MovingImageType, FixedImageType >;
97 ResampleFilterType::Pointer resample = ResampleFilterType::New();
98
99 /** create a composite transform holder for other transforms */
100 using CompositeType = itk::CompositeTransform<double, Dimension>;
101
102 CompositeType::Pointer compositeTransform = CompositeType::New();
103
104 //create an affine transform
105 using AffineTransformType = itk::AffineTransform<double, Dimension>;
106 AffineTransformType::Pointer affineTransform = AffineTransformType::New();
107 affineTransform->SetIdentity();
108 std::cout <<" affineTransform params prior to optimization " << affineTransform->GetParameters() << std::endl;
109
110 using DisplacementTransformType = itk::GaussianSmoothingOnUpdateDisplacementFieldTransform< double, Dimension>;
111 DisplacementTransformType::Pointer displacementTransform = DisplacementTransformType::New();
112
113 using DisplacementFieldType = DisplacementTransformType::DisplacementFieldType;
114 DisplacementFieldType::Pointer field = DisplacementFieldType::New();
115
116 // set the field to be the same as the fixed image region, which will
117 // act by default as the virtual domain in this example.
118 field->SetRegions( fixedImage->GetLargestPossibleRegion() );
119 //make sure the field has the same spatial information as the image
120 field->CopyInformation( fixedImage );
121 std::cout << "fixedImage->GetLargestPossibleRegion(): " << fixedImage->GetLargestPossibleRegion() << std::endl;
122 field->Allocate();
123 // Fill it with 0's
124 DisplacementTransformType::OutputVectorType zeroVector;
125 zeroVector.Fill( 0 );
126 field->FillBuffer( zeroVector );
127 // Assign to transform
128 displacementTransform->SetDisplacementField( field );
129 displacementTransform->SetGaussianSmoothingVarianceForTheUpdateField( 5 );
130 displacementTransform->SetGaussianSmoothingVarianceForTheTotalField( 6 );
131
132 //identity transform for fixed image
133 using IdentityTransformType = itk::IdentityTransform<double, Dimension>;
134 IdentityTransformType::Pointer identityTransform = IdentityTransformType::New();
135 identityTransform->SetIdentity();
136
137 // The metric
138 using VirtualImageType = itk::Image< double, Dimension>;
139 using MetricTraitsType = itk::VectorImageToImageMetricTraitsv4< FixedImageType, MovingImageType, VirtualImageType, PixelType::Length>;
140 using MetricType = itk::MeanSquaresImageToImageMetricv4 < FixedImageType, MovingImageType, VirtualImageType, double, MetricTraitsType >;
141 using PointSetType = MetricType::FixedSampledPointSetType;
142 MetricType::Pointer metric = MetricType::New();
143
144 using PointType = PointSetType::PointType;
145 PointSetType::Pointer pset(PointSetType::New());
146 unsigned long ind=0,ct=0;
147 itk::ImageRegionIteratorWithIndex<FixedImageType> It(fixedImage, fixedImage->GetLargestPossibleRegion() );
148
149 for( It.GoToBegin(); !It.IsAtEnd(); ++It )
150 {
151 // take every N^th point
152 if ( ct % 2 == 0 )
153 {
154 PointType pt;
155 fixedImage->TransformIndexToPhysicalPoint( It.GetIndex(), pt);
156 pset->SetPoint(ind, pt);
157 ind++;
158 }
159 ct++;
160 }
161 std::cout << "Setting point set with " << ind << " points of " << fixedImage->GetLargestPossibleRegion().GetNumberOfPixels() << " total " << std::endl;
162 metric->SetFixedSampledPointSet( pset );
163 metric->SetUseSampledPointSet( true );
164 std::cout << "Testing metric with point set..." << std::endl;
165
166
167 // Assign images and transforms.
168 // By not setting a virtual domain image or virtual domain settings,
169 // the metric will use the fixed image for the virtual domain.
170 metric->SetFixedImage( fixedImage );
171 metric->SetMovingImage( movingImage );
172 metric->SetFixedTransform( identityTransform );
173 metric->SetMovingTransform( affineTransform );
174 const bool gaussian = false;
175 metric->SetUseMovingImageGradientFilter( gaussian );
176 metric->SetUseFixedImageGradientFilter( gaussian );
177 metric->Initialize();
178
179 using RegistrationParameterScalesFromShiftType = itk::RegistrationParameterScalesFromPhysicalShift< MetricType >;
180 RegistrationParameterScalesFromShiftType::Pointer shiftScaleEstimator = RegistrationParameterScalesFromShiftType::New();
181 shiftScaleEstimator->SetMetric(metric);
182
183 //
184 // Affine registration
185 //
186 std::cout << "First do an affine registration " << std::endl;
187 using OptimizerType = itk::GradientDescentOptimizerv4;
188 OptimizerType::Pointer optimizer = OptimizerType::New();
189 optimizer->SetMetric( metric );
190 optimizer->SetNumberOfIterations( numberOfAffineIterations );
191 optimizer->SetScalesEstimator( shiftScaleEstimator );
192 optimizer->StartOptimization();
193
194 std::cout << "Number of threads: metric: " << metric->GetNumberOfWorkUnitsUsed() << " optimizer: " << optimizer->GetNumberOfWorkUnits() << std::endl;
195 std::cout << "GetNumberOfSkippedFixedSampledPoints: " << metric->GetNumberOfSkippedFixedSampledPoints() << std::endl;
196
197 //
198 // Deformable registration
199 //
200 // now add the displacement field to the composite transform
201 compositeTransform->AddTransform( affineTransform );
202 compositeTransform->AddTransform( displacementTransform );
203 //compositeTransform->SetAllTransformsToOptimizeOn(); //Set to optimize all.
204 // Optimize only the displacement field, but still using the previously-compute affine transformation
205 compositeTransform->SetOnlyMostRecentTransformToOptimizeOn(); //set to optimize the displacement field
206 metric->SetMovingTransform( compositeTransform );
207 metric->SetUseSampledPointSet( false );
208 metric->Initialize();
209
210 // Optimizer
211 RegistrationParameterScalesFromShiftType::ScalesType displacementScales( displacementTransform->GetNumberOfLocalParameters() );
212 displacementScales.Fill(1);
213 if( false )
214 {
215 optimizer->SetScales( displacementScales );
216 }
217 else
218 {
219 optimizer->SetScalesEstimator( shiftScaleEstimator );
220 }
221 optimizer->SetMetric( metric );
222 optimizer->SetNumberOfIterations( numberOfDisplacementIterations );
223 try
224 {
225 if( numberOfDisplacementIterations > 0 )
226 {
227 std::cout << "Follow affine with deformable registration... " << std::endl;
228 optimizer->StartOptimization();
229 std::cout << "...finished. " << std::endl;
230 std::cout << "GetNumberOfSkippedFixedSampledPoints: " << metric->GetNumberOfSkippedFixedSampledPoints() << std::endl;
231 }
232 else
233 {
234 std::cout << "** SKIPPING DISPLACEMENT FIELD OPT\n";
235 }
236 }
237 catch( itk::ExceptionObject & e )
238 {
239 std::cout << "Exception thrown ! " << std::endl;
240 std::cout << "An error occurred during deformation Optimization:" << std::endl;
241 std::cout << e.GetLocation() << std::endl;
242 std::cout << e.GetDescription() << std::endl;
243 std::cout << e.what() << std::endl;
244 std::cout << "Test FAILED." << std::endl;
245 return EXIT_FAILURE;
246 }
247
248 //dump part of the displacement field for debugging
249 std::cout << "Deformation field samples: " << std::endl;
250 FixedImageType::SizeType size = fixedImage->GetBufferedRegion().GetSize();
251 for( itk::SizeValueType x = 0; x < size[0]; x+=30 )
252 {
253 std::cout << x << std::endl;
254 for( itk::SizeValueType y = 0; y < size[1]; y+=30 )
255 {
256 FixedImageType::IndexType index;
257 index[0] = x;
258 index[1] = y;
259 std::cout << field->GetPixel(index);
260 }
261 std::cout << std::endl;
262 }
263
264 //warp the image with the displacement field
265 resample->SetTransform( compositeTransform );
266 resample->SetInput( movingImageReader->GetOutput() );
267 resample->SetSize( fixedImage->GetLargestPossibleRegion().GetSize() );
268 resample->SetOutputOrigin( fixedImage->GetOrigin() );
269 resample->SetOutputSpacing( fixedImage->GetSpacing() );
270 resample->SetOutputDirection( fixedImage->GetDirection() );
271 resample->SetDefaultPixelValue( itk::NumericTraits<FixedImageType::PixelType::ValueType>::ZeroValue() );
272 resample->Update();
273
274 //write out the displacement field
275 using DisplacementWriterType = itk::ImageFileWriter< DisplacementFieldType >;
276 DisplacementWriterType::Pointer displacementwriter = DisplacementWriterType::New();
277 std::string outfilename( argv[3] );
278 std::string ext = itksys::SystemTools::GetFilenameExtension( outfilename );
279 std::string name = itksys::SystemTools::GetFilenameWithoutExtension( outfilename );
280 std::string path = itksys::SystemTools::GetFilenamePath( outfilename );
281 std::string defout = path + std::string( "/" ) + name + std::string("_def") + ext;
282 displacementwriter->SetFileName( defout.c_str() );
283 displacementwriter->SetInput( displacementTransform->GetDisplacementField() );
284 displacementwriter->Update();
285
286 //write the warped image into a file
287 using OutputPixelType = PixelType;
288 using OutputImageType = itk::Image< OutputPixelType, Dimension >;
289 using CastFilterType = itk::CastImageFilter<
290 MovingImageType,
291 OutputImageType >;
292 using WriterType = itk::ImageFileWriter< OutputImageType >;
293 WriterType::Pointer writer = WriterType::New();
294 CastFilterType::Pointer caster = CastFilterType::New();
295 writer->SetFileName( argv[3] );
296 caster->SetInput( resample->GetOutput() );
297 writer->SetInput( caster->GetOutput() );
298 writer->Update();
299
300 std::cout << "After optimization affine params are: " << affineTransform->GetParameters() << std::endl;
301 std::cout << "Test PASSED." << std::endl;
302 return EXIT_SUCCESS;
303
304 }
305