1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 
#include "itkFEMRegistrationFilter.h"
#include "itkImageFileWriter.h"
#include "itkTestingMacros.h"

#include <fstream>
#include <iostream>
#include <string>
24 
25 
// Type aliases used for registration

// itkFEMRegistrationFilterTest2.cxx tests the itk::FEMRegistrationFilter
// class on 3D images
30 constexpr unsigned int ImageDimension = 3;
31 
32 using InputImagePixelType = unsigned char;
33 using DeformationFieldPixelType = float;
34 
35 using InputImageType = itk::Image< InputImagePixelType, ImageDimension >;
36 using DeformationFieldVectorType = itk::Vector< DeformationFieldPixelType, ImageDimension >;
37 using DeformationFieldImageType = itk::Image< DeformationFieldVectorType, ImageDimension >;
38 
39 using ElementType = itk::fem::Element3DC0LinearHexahedronMembrane;
40 
41 
42 // Template function to fill in an image with a value.
43 template< typename TImage >
FillImage(TImage * image,typename TImage::PixelType value)44 void FillImage( TImage * image, typename TImage::PixelType value )
45 {
46   using Iterator = itk::ImageRegionIteratorWithIndex< TImage >;
47   Iterator it( image, image->GetBufferedRegion() );
48 
49   for( it.GoToBegin(); !it.IsAtEnd(); ++it )
50     {
51     it.Set( value );
52     }
53 }
54 
55 // Template function to fill in an image with a circle.
56 template< typename TImage >
FillWithCircle(TImage * image,double * center,double radius,typename TImage::PixelType foregnd,typename TImage::PixelType backgnd)57 void FillWithCircle( TImage * image, double * center, double radius,
58   typename TImage::PixelType foregnd, typename TImage::PixelType backgnd )
59 {
60   using Iterator = itk::ImageRegionIteratorWithIndex< TImage >;
61   Iterator it( image, image->GetBufferedRegion() );
62 
63   typename TImage::IndexType index;
64   double r2 = itk::Math::sqr( radius );
65   for( it.GoToBegin(); !it.IsAtEnd(); ++it )
66     {
67     index = it.GetIndex();
68     double distance = 0;
69     for( unsigned int j = 0; j < TImage::ImageDimension; ++j )
70       {
71       distance += itk::Math::sqr( (double) index[j] - center[j] );
72       }
73     if( distance <= r2 )
74       {
75       it.Set( foregnd );
76       }
77     else
78       {
79       it.Set( backgnd );
80       }
81     }
82 }
83 
itkFEMRegistrationFilterTest2(int argc,char * argv[])84 int itkFEMRegistrationFilterTest2( int argc, char *argv[] )
85 {
86 
87   using IndexType = InputImageType::IndexType;
88   using SizeType = InputImageType::SizeType;
89   using RegionType = InputImageType::RegionType;
90   using SpacingType = InputImageType::SpacingType;
91   using PointType = InputImageType::PointType;
92 
93 
94   // Generate input images and initial deformation field
95 
96   PointType imageOrigin;
97   imageOrigin[0] = 100.0;
98   imageOrigin[1] =  50.0;
99   imageOrigin[2] = 200.0;
100 
101   SpacingType spacing;
102   InputImageType::SizeValueType sizeArray[ImageDimension];
103   for( unsigned int i = 0; i < ImageDimension; ++i )
104     {
105     sizeArray[i] = 32;
106     spacing[i] = 2.0;
107     }
108 
109   SizeType size;
110   size.SetSize( sizeArray );
111 
112   IndexType index;
113   index.Fill( 0 );
114 
115   RegionType region;
116   region.SetSize( size );
117   region.SetIndex( index );
118 
119   InputImageType::Pointer movingImage = InputImageType::New();
120   InputImageType::Pointer fixedImage = InputImageType::New();
121 
122   DeformationFieldImageType::Pointer initField = DeformationFieldImageType::New();
123 
124   movingImage->SetLargestPossibleRegion( region );
125   movingImage->SetBufferedRegion( region );
126   movingImage->SetOrigin( imageOrigin );
127   movingImage->SetSpacing( spacing );
128   movingImage->Allocate();
129 
130   fixedImage->SetLargestPossibleRegion( region );
131   fixedImage->SetBufferedRegion( region );
132   fixedImage->SetOrigin( imageOrigin );
133   fixedImage->SetSpacing( spacing );
134   fixedImage->Allocate();
135 
136   initField->SetLargestPossibleRegion( region );
137   initField->SetBufferedRegion( region );
138   initField->SetOrigin( imageOrigin );
139   initField->SetSpacing( spacing );
140   initField->Allocate();
141 
142   double center[ImageDimension];
143   double radius;
144   InputImagePixelType fgnd = 250;
145   InputImagePixelType bgnd = 15;
146 
147   // Set the circle center
148   for(double & i : center)
149     {
150     i = 16;
151     }
152 
153   // Fill fixed image with a circle
154   radius = 8;
155   FillWithCircle< InputImageType >( fixedImage, center, radius, fgnd, bgnd );
156 
157   // Fill moving image with a circle
158   radius = 5;
159   FillWithCircle< InputImageType >( movingImage, center, radius, fgnd, bgnd );
160 
161   // Fill initial deformation with zero vectors
162   DeformationFieldVectorType zeroVec;
163   zeroVec.Fill( 0.0 );
164   FillImage< DeformationFieldImageType >( initField, zeroVec );
165 
166 
167   using FEMObjectType = itk::fem::FEMObject< ImageDimension >;
168   using RegistrationType = itk::fem::FEMRegistrationFilter< InputImageType,
169                                           InputImageType,
170                                           FEMObjectType >;
171 
172 
173   // Run registration and warp moving
174   for( unsigned int met = 0; met < 4; ++met )
175     {
176     RegistrationType::Pointer registrator = RegistrationType::New();
177 
178     EXERCISE_BASIC_OBJECT_METHODS( registrator, FEMRegistrationFilter,
179       ImageToImageFilter );
180 
181     registrator->SetFixedImage( fixedImage );
182     TEST_SET_GET_VALUE( fixedImage, registrator->GetFixedImage() );
183 
184     registrator->SetMovingImage( movingImage );
185     TEST_SET_GET_VALUE( movingImage, registrator->GetMovingImage() );
186 
187     unsigned int maxLevel = 2;
188     registrator->SetMaxLevel( maxLevel );
189     TEST_SET_GET_VALUE( maxLevel, registrator->GetMaxLevel() );
190 
191     bool useNormalizedGradient = true;
192     TEST_SET_GET_BOOLEAN( registrator, UseNormalizedGradient, useNormalizedGradient );
193 
194     registrator->ChooseMetric( met );
195 
196     unsigned int numberOfMaxIterations = 5;
197     registrator->SetMaximumIterations( numberOfMaxIterations, 0 );
198     registrator->SetMaximumIterations( numberOfMaxIterations, 1 );
199 
200     RegistrationType::Float elasticity = 10;
201     registrator->SetElasticity( elasticity, 0 );
202     registrator->SetElasticity( elasticity, 1 );
203     TEST_SET_GET_VALUE( elasticity, registrator->GetElasticity() );
204 
205     RegistrationType::Float rho = 1;
206     registrator->SetRho( rho, 0 );
207     registrator->SetRho( rho, 1 );
208     //TEST_SET_GET_VALUE( rho, registrator->GetRho() );
209 
210     RegistrationType::Float gamma = 1.;
211     registrator->SetGamma( gamma, 0 );
212     registrator->SetGamma( gamma, 1 );
213     //TEST_SET_GET_VALUE( gamma, registrator->GetGamma() );
214 
215     RegistrationType::Float alpha = 1.;
216     registrator->SetAlpha( alpha );
217     TEST_SET_GET_VALUE( alpha, registrator->GetAlpha() );
218 
219     registrator->SetMeshPixelsPerElementAtEachResolution( 8, 0 );
220     registrator->SetMeshPixelsPerElementAtEachResolution( 4, 1 );
221 
222     unsigned int widthOfMetricRegion;
223     if( met == 0 || met == 3 )
224       {
225       widthOfMetricRegion = 0;
226       }
227     else
228       {
229       widthOfMetricRegion = 1;
230       }
231     registrator->SetWidthOfMetricRegion( widthOfMetricRegion, 0 );
232     TEST_SET_GET_VALUE( widthOfMetricRegion,
233       registrator->GetWidthOfMetricRegion() );
234 
235 
236     registrator->SetNumberOfIntegrationPoints( 2, 0 );
237     registrator->SetNumberOfIntegrationPoints( 2, 1 );
238 
239     RegistrationType::Float timeStep = 1.;
240     registrator->SetTimeStep( timeStep );
241     TEST_SET_GET_VALUE( timeStep, registrator->GetTimeStep() );
242 
243     unsigned int doLineSearchOnImageEnergy;
244     unsigned int employRegridding;
245     if( met == 0 )
246       {
247       doLineSearchOnImageEnergy = 2;
248       employRegridding = true;
249       }
250     else
251       {
252       doLineSearchOnImageEnergy = 0;
253       employRegridding = false;
254       }
255 
256     registrator->SetDoLineSearchOnImageEnergy( doLineSearchOnImageEnergy );
257     TEST_SET_GET_VALUE( doLineSearchOnImageEnergy,
258       registrator->GetDoLineSearchOnImageEnergy() );
259 
260     registrator->SetEmployRegridding( employRegridding );
261     TEST_SET_GET_VALUE( employRegridding, registrator->GetEmployRegridding() );
262 
263     bool useLandmarks = true;
264     TEST_SET_GET_BOOLEAN( registrator, UseLandmarks, useLandmarks );
265 
266     bool useMassMatrix = true;
267     TEST_SET_GET_BOOLEAN( registrator, UseMassMatrix, useMassMatrix );
268 
269     RegistrationType::Float energyReductionFactor = 0.0;
270     registrator->SetEnergyReductionFactor( energyReductionFactor );
271     TEST_SET_GET_VALUE( energyReductionFactor,
272       registrator->GetEnergyReductionFactor() );
273 
274     unsigned int lineSearchMaximumIterations = 100;
275     registrator->SetLineSearchMaximumIterations( lineSearchMaximumIterations );
276     TEST_SET_GET_VALUE( lineSearchMaximumIterations,
277       registrator->GetLineSearchMaximumIterations() );
278 
279     bool createMeshFromImage = true;
280     TEST_SET_GET_BOOLEAN( registrator, CreateMeshFromImage, createMeshFromImage );
281 
282     double standardDeviation = 0.5;
283     registrator->SetStandardDeviations( standardDeviation );
284     //TEST_SET_GET_VALUE( standardDeviations, registrator->GetStandardDeviations() );
285 
286     standardDeviation = 1.0;
287     RegistrationType::StandardDeviationsType standardDeviations;
288     standardDeviations.Fill( standardDeviation );
289     registrator->SetStandardDeviations( standardDeviations );
290     //TEST_SET_GET_VALUE( standardDeviations, registrator->GetStandardDeviations() );
291 
292     unsigned int maximumKernelWidth = 30;
293     registrator->SetMaximumKernelWidth( maximumKernelWidth );
294     TEST_SET_GET_VALUE( maximumKernelWidth, registrator->GetMaximumKernelWidth() );
295 
296     double maximumError = 0.1;
297     registrator->SetMaximumError( maximumError );
298     TEST_SET_GET_VALUE( maximumError, registrator->GetMaximumError() );
299 
300 
301     itk::fem::MaterialLinearElasticity::Pointer material =
302       itk::fem::MaterialLinearElasticity::New();
303     material->SetGlobalNumber( 0 );
304     material->SetYoungsModulus( registrator->GetElasticity() );
305     material->SetCrossSectionalArea( 1.0 );
306     material->SetThickness( 1.0 );
307     material->SetMomentOfInertia( 1.0 );
308     material->SetPoissonsRatio( 0. ); // DON'T CHOOSE 1.0!!
309     material->SetDensityHeatProduct( 1.0 );
310 
311     // Create the element type
312     ElementType::Pointer element1 = ElementType::New();
313     element1->SetMaterial( dynamic_cast< itk::fem::MaterialLinearElasticity * >( &*material ) );
314     registrator->SetElement( &*element1 );
315     registrator->SetMaterial( material );
316 
317     try
318       {
319       // Register the images
320       registrator->RunRegistration();
321       }
322     catch( ::itk::ExceptionObject & err )
323       {
324         std::cerr << "ITK exception detected: "  << err;
325         std::cout << "Test failed!" << std::endl;
326         return EXIT_FAILURE;
327       }
328     catch( ... )
329       {
330       // fixme - changes to femparray cause it to fail : old version works
331       std::cout << "Caught an exception: " << std::endl;
332       return EXIT_FAILURE;
333       // std::cout << err << std::endl;
334       // throw err;
335       }
336 
337     if( argc == 2 )
338       {
339       std::string outFileName = argv[1];
340       std::stringstream ss;
341       ss << met;
342       outFileName += ss.str();
343       outFileName += ".mhd";
344       using ImageWriterType = itk::ImageFileWriter< RegistrationType::FieldType >;
345       ImageWriterType::Pointer writer = ImageWriterType::New();
346       writer->SetFileName( outFileName );
347       writer->SetInput( registrator->GetDisplacementField() );
348       writer->Update();
349       }
350 
351     if( argc == 3 )
352       {
353       std::string outFileName = argv[2];
354       std::stringstream ss;
355       ss << met;
356       outFileName += ss.str();
357       outFileName += ".mhd";
358       using ImageWriterType = itk::ImageFileWriter< InputImageType >;
359       ImageWriterType::Pointer writer = ImageWriterType::New();
360       writer->SetFileName( outFileName );
361       writer->SetInput( registrator->GetWarpedImage() );
362       writer->Update();
363       }
364     }
365 
366   /*
367   // get warped reference image
368   // ---------------------------------------------------------
369   std::cout << "Compare warped moving and fixed." << std::endl;
370 
371   // compare the warp and fixed images
372   itk::ImageRegionIterator<ImageType> fixedIter( fixed,
373       fixed->GetBufferedRegion() );
374   itk::ImageRegionIterator<ImageType> warpedIter( registrator->GetWarpedImage(),
375       fixed->GetBufferedRegion() );
376 
377   unsigned int numPixelsDifferent = 0;
378   while( !fixedIter.IsAtEnd() )
379     {
380     if( fixedIter.Get() != warpedIter.Get() )
381       {
382       numPixelsDifferent++;
383       }
384     ++fixedIter;
385     ++warpedIter;
386     }
387 
388   std::cout << "Number of pixels different: " << numPixelsDifferent;
389   std::cout << std::endl;
390 
391   if( numPixelsDifferent > 400 )
392     {
393     std::cout << "Test failed - too many pixels different." << std::endl;
394     return EXIT_FAILURE;
395     }
396 
397   std::cout << "Test passed" << std::endl;
398   */
399 
400   std::cout << "Test finished." << std::endl;
401   return EXIT_SUCCESS;
402 }
403