1 /*=========================================================================
2 *
3 *  Copyright Insight Software Consortium
4 *
5 *  Licensed under the Apache License, Version 2.0 (the "License");
6 *  you may not use this file except in compliance with the License.
7 *  You may obtain a copy of the License at
8 *
9 *         http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 *  Unless required by applicable law or agreed to in writing, software
12 *  distributed under the License is distributed on an "AS IS" BASIS,
13 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 *  See the License for the specific language governing permissions and
15 *  limitations under the License.
16 *
17 *=========================================================================*/
18 
19 /**
20  * Test program for MeanSquaresImageToImageMetricv4 and
21  * GradientDescentOptimizerv4 classes with vector (color)
22  * images.
23  *
24  * Perform a registration using user-supplied images.
25  * No numerical verification is performed. Test passes as long
26  * as no exception occurs.
27  */
#include "itkMeanSquaresImageToImageMetricv4.h"
#include "itkGradientDescentOptimizerv4.h"
#include "itkRegistrationParameterScalesFromPhysicalShift.h"
#include "itkVectorImageToImageMetricTraitsv4.h"

#include "itkAffineTransform.h"
#include "itkCompositeTransform.h"
#include "itkGaussianSmoothingOnUpdateDisplacementFieldTransform.h"
#include "itkIdentityTransform.h"

#include "itkCastImageFilter.h"
#include "itkResampleImageFilter.h"

#include "itkCommand.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkImageRegionIteratorWithIndex.h"
#include "itksys/SystemTools.hxx"

#include <exception>
#include <iomanip>
#include <string>
#include "itkTestingMacros.h"
43 
itkMeanSquaresImageToImageMetricv4VectorRegistrationTest(int argc,char * argv[])44 int itkMeanSquaresImageToImageMetricv4VectorRegistrationTest(int argc, char *argv[])
45 {
46 
47   if( argc < 4 )
48     {
49     std::cerr << "Missing Parameters " << std::endl;
50     std::cerr << "Usage: " << itkNameOfTestExecutableMacro(argv);
51     std::cerr << " fixedImageFile movingImageFile ";
52     std::cerr << " outputImageFile ";
53     std::cerr << " [numberOfAffineIterations numberOfDisplacementIterations] ";
54     std::cerr << std::endl;
55     return EXIT_FAILURE;
56     }
57 
58   std::cout << argc << std::endl;
59   unsigned int numberOfAffineIterations = 2;
60   unsigned int numberOfDisplacementIterations = 2;
61   if( argc >= 5 )
62     {
63     numberOfAffineIterations = std::stoi( argv[4] );
64     }
65   if( argc >= 6 )
66     {
67     numberOfDisplacementIterations = std::stoi( argv[5] );
68     }
69   std::cout << " affine iterations "<< numberOfAffineIterations << " displacementIterations " << numberOfDisplacementIterations << std::endl;
70 
71   constexpr unsigned int Dimension = 2;
72 
73   // RGBPixel type is not supported by GradientRecursiveGaussianFilter at this point.
74   //using PixelType = itk::RGBPixel<FloatType>;
75   using PixelType = itk::Vector<double, 3>;
76 
77   using FixedImageType = itk::Image< PixelType, Dimension >;
78   using MovingImageType = itk::Image< PixelType, Dimension >;
79 
80   using FixedImageReaderType = itk::ImageFileReader< FixedImageType  >;
81   using MovingImageReaderType = itk::ImageFileReader< MovingImageType >;
82 
83   FixedImageReaderType::Pointer fixedImageReader   = FixedImageReaderType::New();
84   MovingImageReaderType::Pointer movingImageReader = MovingImageReaderType::New();
85 
86   fixedImageReader->SetFileName( argv[1] );
87   movingImageReader->SetFileName( argv[2] );
88 
89   //get the images
90   fixedImageReader->Update();
91   FixedImageType::Pointer  fixedImage = fixedImageReader->GetOutput();
92   movingImageReader->Update();
93   MovingImageType::Pointer movingImage = movingImageReader->GetOutput();
94 
95   /** define a resample filter that will ultimately be used to deform the image */
96   using ResampleFilterType = itk::ResampleImageFilter< MovingImageType, FixedImageType >;
97   ResampleFilterType::Pointer resample = ResampleFilterType::New();
98 
99   /** create a composite transform holder for other transforms  */
100   using CompositeType = itk::CompositeTransform<double, Dimension>;
101 
102   CompositeType::Pointer compositeTransform = CompositeType::New();
103 
104   //create an affine transform
105   using AffineTransformType = itk::AffineTransform<double, Dimension>;
106   AffineTransformType::Pointer affineTransform = AffineTransformType::New();
107   affineTransform->SetIdentity();
108   std::cout <<" affineTransform params prior to optimization " << affineTransform->GetParameters() << std::endl;
109 
110   using DisplacementTransformType = itk::GaussianSmoothingOnUpdateDisplacementFieldTransform< double, Dimension>;
111   DisplacementTransformType::Pointer displacementTransform = DisplacementTransformType::New();
112 
113   using DisplacementFieldType = DisplacementTransformType::DisplacementFieldType;
114   DisplacementFieldType::Pointer field = DisplacementFieldType::New();
115 
116   // set the field to be the same as the fixed image region, which will
117   // act by default as the virtual domain in this example.
118   field->SetRegions( fixedImage->GetLargestPossibleRegion() );
119   //make sure the field has the same spatial information as the image
120   field->CopyInformation( fixedImage );
121   std::cout << "fixedImage->GetLargestPossibleRegion(): " << fixedImage->GetLargestPossibleRegion() << std::endl;
122   field->Allocate();
123   // Fill it with 0's
124   DisplacementTransformType::OutputVectorType zeroVector;
125   zeroVector.Fill( 0 );
126   field->FillBuffer( zeroVector );
127   // Assign to transform
128   displacementTransform->SetDisplacementField( field );
129   displacementTransform->SetGaussianSmoothingVarianceForTheUpdateField( 5 );
130   displacementTransform->SetGaussianSmoothingVarianceForTheTotalField( 6 );
131 
132   //identity transform for fixed image
133   using IdentityTransformType = itk::IdentityTransform<double, Dimension>;
134   IdentityTransformType::Pointer identityTransform = IdentityTransformType::New();
135   identityTransform->SetIdentity();
136 
137   // The metric
138   using VirtualImageType = itk::Image< double, Dimension>;
139   using MetricTraitsType = itk::VectorImageToImageMetricTraitsv4< FixedImageType, MovingImageType, VirtualImageType, PixelType::Length>;
140   using MetricType = itk::MeanSquaresImageToImageMetricv4 < FixedImageType, MovingImageType, VirtualImageType, double, MetricTraitsType >;
141   using PointSetType = MetricType::FixedSampledPointSetType;
142   MetricType::Pointer metric = MetricType::New();
143 
144   using PointType = PointSetType::PointType;
145   PointSetType::Pointer               pset(PointSetType::New());
146   unsigned long ind=0,ct=0;
147   itk::ImageRegionIteratorWithIndex<FixedImageType> It(fixedImage, fixedImage->GetLargestPossibleRegion() );
148 
149   for( It.GoToBegin(); !It.IsAtEnd(); ++It )
150     {
151     // take every N^th point
152     if ( ct % 2 == 0  )
153       {
154         PointType pt;
155         fixedImage->TransformIndexToPhysicalPoint( It.GetIndex(), pt);
156         pset->SetPoint(ind, pt);
157         ind++;
158       }
159       ct++;
160     }
161   std::cout << "Setting point set with " << ind << " points of " << fixedImage->GetLargestPossibleRegion().GetNumberOfPixels() << " total " << std::endl;
162   metric->SetFixedSampledPointSet( pset );
163   metric->SetUseSampledPointSet( true );
164   std::cout << "Testing metric with point set..." << std::endl;
165 
166 
167   // Assign images and transforms.
168   // By not setting a virtual domain image or virtual domain settings,
169   // the metric will use the fixed image for the virtual domain.
170   metric->SetFixedImage( fixedImage );
171   metric->SetMovingImage( movingImage );
172   metric->SetFixedTransform( identityTransform );
173   metric->SetMovingTransform( affineTransform );
174   const bool gaussian = false;
175   metric->SetUseMovingImageGradientFilter( gaussian );
176   metric->SetUseFixedImageGradientFilter( gaussian );
177   metric->Initialize();
178 
179   using RegistrationParameterScalesFromShiftType = itk::RegistrationParameterScalesFromPhysicalShift< MetricType >;
180   RegistrationParameterScalesFromShiftType::Pointer shiftScaleEstimator = RegistrationParameterScalesFromShiftType::New();
181   shiftScaleEstimator->SetMetric(metric);
182 
183   //
184   // Affine registration
185   //
186   std::cout << "First do an affine registration " << std::endl;
187   using OptimizerType = itk::GradientDescentOptimizerv4;
188   OptimizerType::Pointer  optimizer = OptimizerType::New();
189   optimizer->SetMetric( metric );
190   optimizer->SetNumberOfIterations( numberOfAffineIterations );
191   optimizer->SetScalesEstimator( shiftScaleEstimator );
192   optimizer->StartOptimization();
193 
194   std::cout << "Number of threads: metric: " << metric->GetNumberOfWorkUnitsUsed() << " optimizer: " << optimizer->GetNumberOfWorkUnits() << std::endl;
195   std::cout << "GetNumberOfSkippedFixedSampledPoints: " << metric->GetNumberOfSkippedFixedSampledPoints() << std::endl;
196 
197   //
198   // Deformable registration
199   //
200   // now add the displacement field to the composite transform
201   compositeTransform->AddTransform( affineTransform );
202   compositeTransform->AddTransform( displacementTransform );
203   //compositeTransform->SetAllTransformsToOptimizeOn(); //Set to optimize all.
204   // Optimize only the displacement field, but still using the previously-compute affine transformation
205   compositeTransform->SetOnlyMostRecentTransformToOptimizeOn(); //set to optimize the displacement field
206   metric->SetMovingTransform( compositeTransform );
207   metric->SetUseSampledPointSet( false );
208   metric->Initialize();
209 
210   // Optimizer
211   RegistrationParameterScalesFromShiftType::ScalesType displacementScales( displacementTransform->GetNumberOfLocalParameters() );
212   displacementScales.Fill(1);
213   if( false )
214     {
215     optimizer->SetScales( displacementScales );
216     }
217   else
218     {
219     optimizer->SetScalesEstimator( shiftScaleEstimator );
220     }
221   optimizer->SetMetric( metric );
222   optimizer->SetNumberOfIterations( numberOfDisplacementIterations );
223   try
224     {
225     if( numberOfDisplacementIterations > 0 )
226       {
227       std::cout << "Follow affine with deformable registration... " << std::endl;
228       optimizer->StartOptimization();
229       std::cout << "...finished. " << std::endl;
230       std::cout << "GetNumberOfSkippedFixedSampledPoints: " << metric->GetNumberOfSkippedFixedSampledPoints() << std::endl;
231       }
232     else
233       {
234       std::cout << "** SKIPPING DISPLACEMENT FIELD OPT\n";
235       }
236     }
237   catch( itk::ExceptionObject & e )
238     {
239     std::cout << "Exception thrown ! " << std::endl;
240     std::cout << "An error occurred during deformation Optimization:" << std::endl;
241     std::cout << e.GetLocation() << std::endl;
242     std::cout << e.GetDescription() << std::endl;
243     std::cout << e.what()    << std::endl;
244     std::cout << "Test FAILED." << std::endl;
245     return EXIT_FAILURE;
246     }
247 
248   //dump part of the displacement field for debugging
249   std::cout << "Deformation field samples: " << std::endl;
250   FixedImageType::SizeType size = fixedImage->GetBufferedRegion().GetSize();
251   for( itk::SizeValueType x = 0; x < size[0]; x+=30 )
252     {
253     std::cout << x << std::endl;
254     for( itk::SizeValueType y = 0; y < size[1]; y+=30 )
255       {
256       FixedImageType::IndexType index;
257       index[0] = x;
258       index[1] = y;
259       std::cout << field->GetPixel(index);
260       }
261     std::cout << std::endl;
262     }
263 
264   //warp the image with the displacement field
265   resample->SetTransform( compositeTransform );
266   resample->SetInput( movingImageReader->GetOutput() );
267   resample->SetSize(    fixedImage->GetLargestPossibleRegion().GetSize() );
268   resample->SetOutputOrigin(  fixedImage->GetOrigin() );
269   resample->SetOutputSpacing( fixedImage->GetSpacing() );
270   resample->SetOutputDirection( fixedImage->GetDirection() );
271   resample->SetDefaultPixelValue( itk::NumericTraits<FixedImageType::PixelType::ValueType>::ZeroValue() );
272   resample->Update();
273 
274   //write out the displacement field
275   using DisplacementWriterType = itk::ImageFileWriter< DisplacementFieldType >;
276   DisplacementWriterType::Pointer      displacementwriter =  DisplacementWriterType::New();
277   std::string outfilename( argv[3] );
278   std::string  ext = itksys::SystemTools::GetFilenameExtension( outfilename );
279   std::string name = itksys::SystemTools::GetFilenameWithoutExtension( outfilename );
280   std::string path = itksys::SystemTools::GetFilenamePath( outfilename );
281   std::string defout = path + std::string( "/" ) + name + std::string("_def") + ext;
282   displacementwriter->SetFileName( defout.c_str() );
283   displacementwriter->SetInput( displacementTransform->GetDisplacementField() );
284   displacementwriter->Update();
285 
286   //write the warped image into a file
287   using OutputPixelType = PixelType;
288   using OutputImageType = itk::Image< OutputPixelType, Dimension >;
289   using CastFilterType = itk::CastImageFilter<
290                         MovingImageType,
291                         OutputImageType >;
292   using WriterType = itk::ImageFileWriter< OutputImageType >;
293   WriterType::Pointer      writer =  WriterType::New();
294   CastFilterType::Pointer  caster =  CastFilterType::New();
295   writer->SetFileName( argv[3] );
296   caster->SetInput( resample->GetOutput() );
297   writer->SetInput( caster->GetOutput() );
298   writer->Update();
299 
300   std::cout << "After optimization affine params are: " <<  affineTransform->GetParameters() << std::endl;
301   std::cout << "Test PASSED." << std::endl;
302   return EXIT_SUCCESS;
303 
304 }
305