1 /*=========================================================================
2 *
3 *  Copyright Insight Software Consortium
4 *
5 *  Licensed under the Apache License, Version 2.0 (the "License");
6 *  you may not use this file except in compliance with the License.
7 *  You may obtain a copy of the License at
8 *
9 *         http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 *  Unless required by applicable law or agreed to in writing, software
12 *  distributed under the License is distributed on an "AS IS" BASIS,
13 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 *  See the License for the specific language governing permissions and
15 *  limitations under the License.
16 *
17 *=========================================================================*/
18 
19 #include "itkAffineTransform.h"
20 #include "itkBSplineSmoothingOnUpdateDisplacementFieldTransformParametersAdaptor.h"
21 #include "itkBSplineSyNImageRegistrationMethod.h"
22 #include "itkEuclideanDistancePointSetToPointSetMetricv4.h"
23 
itkBSplineSyNPointSetRegistrationTest(int itkNotUsed (argc),char * itkNotUsed (argv)[])24 int itkBSplineSyNPointSetRegistrationTest( int itkNotUsed( argc ), char * itkNotUsed( argv )[] )
25 {
26   constexpr unsigned int Dimension = 2;
27 
28   using PointSetType = itk::PointSet<unsigned int, Dimension>;
29 
30   using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4<PointSetType>;
31   PointSetMetricType::Pointer metric = PointSetMetricType::New();
32 
33   using PointSetType = PointSetMetricType::FixedPointSetType;
34   using PointType = PointSetType::PointType;
35 
36   using PixelType = double;
37   using FixedImageType = itk::Image<PixelType, Dimension>;
38   using MovingImageType = itk::Image<PixelType, Dimension>;
39 
40 
41   PointSetType::Pointer fixedPoints = PointSetType::New();
42   fixedPoints->Initialize();
43 
44   PointSetType::Pointer movingPoints = PointSetType::New();
45   movingPoints->Initialize();
46 
47   // two circles with a small offset
48   PointType offset;
49   for( unsigned int d=0; d < PointSetType::PointDimension; d++ )
50     {
51     offset[d] = 2.0;
52     }
53   unsigned long count = 0;
54   for( float theta = 0; theta < 2.0 * itk::Math::pi; theta += 0.1 )
55     {
56     auto label = static_cast<unsigned int>( 1.5 + count / 100 );
57 
58     PointType fixedPoint;
59     float radius = 100.0;
60     fixedPoint[0] = radius * std::cos( theta );
61     fixedPoint[1] = radius * std::sin( theta );
62     if( PointSetType::PointDimension > 2 )
63       {
64       fixedPoint[2] = radius * std::sin( theta );
65       }
66     fixedPoints->SetPoint( count, fixedPoint );
67     fixedPoints->SetPointData( count, label );
68 
69     PointType movingPoint;
70     movingPoint[0] = fixedPoint[0] + offset[0];
71     movingPoint[1] = fixedPoint[1] + offset[1];
72     if( PointSetType::PointDimension > 2 )
73       {
74       movingPoint[2] = fixedPoint[2] + offset[2];
75       }
76     movingPoints->SetPoint( count, movingPoint );
77     movingPoints->SetPointData( count, label );
78 
79     count++;
80     }
81 
82   // virtual image domain is [-110,-110]  [110,110]
83 
84   FixedImageType::SizeType fixedImageSize;
85   FixedImageType::PointType fixedImageOrigin;
86   FixedImageType::DirectionType fixedImageDirection;
87   FixedImageType::SpacingType fixedImageSpacing;
88 
89   fixedImageSize.Fill( 221 );
90   fixedImageOrigin.Fill( -110 );
91   fixedImageDirection.SetIdentity();
92   fixedImageSpacing.Fill( 1 );
93 
94   FixedImageType::Pointer fixedImage = FixedImageType::New();
95   fixedImage->SetRegions( fixedImageSize );
96   fixedImage->SetOrigin( fixedImageOrigin );
97   fixedImage->SetDirection( fixedImageDirection );
98   fixedImage->SetSpacing( fixedImageSpacing );
99   fixedImage->Allocate();
100 
101   using AffineTransformType = itk::AffineTransform<double, PointSetType::PointDimension>;
102   AffineTransformType::Pointer transform = AffineTransformType::New();
103   transform->SetIdentity();
104 
105   metric->SetFixedPointSet( fixedPoints );
106   metric->SetMovingPointSet( movingPoints );
107   metric->SetVirtualDomainFromImage( fixedImage );
108   metric->SetMovingTransform( transform );
109   metric->Initialize();
110 
111   // Create the SyN deformable registration method
112 
113   using VectorType = itk::Vector<double, Dimension>;
114   VectorType zeroVector( 0.0 );
115 
116   using DisplacementFieldType = itk::Image<VectorType, Dimension>;
117   DisplacementFieldType::Pointer displacementField = DisplacementFieldType::New();
118   displacementField->CopyInformation( fixedImage );
119   displacementField->SetRegions( fixedImage->GetBufferedRegion() );
120   displacementField->Allocate();
121   displacementField->FillBuffer( zeroVector );
122 
123   DisplacementFieldType::Pointer inverseDisplacementField = DisplacementFieldType::New();
124   inverseDisplacementField->CopyInformation( fixedImage );
125   inverseDisplacementField->SetRegions( fixedImage->GetBufferedRegion() );
126   inverseDisplacementField->Allocate();
127   inverseDisplacementField->FillBuffer( zeroVector );
128 
129   using DisplacementFieldRegistrationType = itk::BSplineSyNImageRegistrationMethod<FixedImageType, MovingImageType>;
130   DisplacementFieldRegistrationType::Pointer displacementFieldRegistration = DisplacementFieldRegistrationType::New();
131 
132   using OutputTransformType = DisplacementFieldRegistrationType::OutputTransformType;
133   OutputTransformType::Pointer outputTransform = OutputTransformType::New();
134   outputTransform->SetDisplacementField( displacementField );
135   outputTransform->SetInverseDisplacementField( inverseDisplacementField );
136 
137   displacementFieldRegistration->SetInitialTransform( outputTransform );
138   displacementFieldRegistration->InPlaceOn();
139 
140   using DisplacementFieldTransformAdaptorType = itk::BSplineSmoothingOnUpdateDisplacementFieldTransformParametersAdaptor<OutputTransformType>;
141   DisplacementFieldRegistrationType::TransformParametersAdaptorsContainerType adaptors;
142 
143   OutputTransformType::ArrayType updateMeshSize;
144   OutputTransformType::ArrayType totalMeshSize;
145   for( unsigned int d = 0; d < Dimension; d++ )
146     {
147     updateMeshSize[d] = 10;
148     totalMeshSize[d] = 0;
149     }
150 
151   // Create the transform adaptors
152   // For the gaussian displacement field, the specified variances are in image spacing terms
153   // and, in normal practice, we typically don't change these values at each level.  However,
154   // if the user wishes to add that option, they can use the class
155   // GaussianSmoothingOnUpdateDisplacementFieldTransformAdaptor
156 
157   unsigned int numberOfLevels = 3;
158 
159   DisplacementFieldRegistrationType::NumberOfIterationsArrayType numberOfIterationsPerLevel;
160   numberOfIterationsPerLevel.SetSize( 3 );
161   numberOfIterationsPerLevel[0] = 1;
162   numberOfIterationsPerLevel[1] = 1;
163   numberOfIterationsPerLevel[2] = 50;
164 
165   DisplacementFieldRegistrationType::ShrinkFactorsArrayType shrinkFactorsPerLevel;
166   shrinkFactorsPerLevel.SetSize( 3 );
167   shrinkFactorsPerLevel.Fill( 1 );
168 
169   DisplacementFieldRegistrationType::SmoothingSigmasArrayType smoothingSigmasPerLevel;
170   smoothingSigmasPerLevel.SetSize( 3 );
171   smoothingSigmasPerLevel.Fill( 0 );
172 
173   for( unsigned int level = 0; level < numberOfLevels; level++ )
174     {
175     // We use the shrink image filter to calculate the fixed parameters of the virtual
176     // domain at each level.  To speed up calculation and avoid unnecessary memory
177     // usage, we could calculate these fixed parameters directly.
178 
179     using ShrinkFilterType = itk::ShrinkImageFilter<DisplacementFieldType, DisplacementFieldType>;
180     ShrinkFilterType::Pointer shrinkFilter = ShrinkFilterType::New();
181     shrinkFilter->SetShrinkFactors( shrinkFactorsPerLevel[level] );
182     shrinkFilter->SetInput( displacementField );
183     shrinkFilter->Update();
184 
185     DisplacementFieldTransformAdaptorType::Pointer fieldTransformAdaptor = DisplacementFieldTransformAdaptorType::New();
186     fieldTransformAdaptor->SetRequiredSpacing( shrinkFilter->GetOutput()->GetSpacing() );
187     fieldTransformAdaptor->SetRequiredSize( shrinkFilter->GetOutput()->GetBufferedRegion().GetSize() );
188     fieldTransformAdaptor->SetRequiredDirection( shrinkFilter->GetOutput()->GetDirection() );
189     fieldTransformAdaptor->SetRequiredOrigin( shrinkFilter->GetOutput()->GetOrigin() );
190     fieldTransformAdaptor->SetTransform( outputTransform );
191 
192     // A good heuristic is to double the b-spline mesh resolution at each level
193     OutputTransformType::ArrayType newUpdateMeshSize = updateMeshSize;
194     OutputTransformType::ArrayType newTotalMeshSize = totalMeshSize;
195     for( unsigned int d = 0; d < Dimension; d++ )
196       {
197       newUpdateMeshSize[d] = newUpdateMeshSize[d] << ( level );
198       newTotalMeshSize[d] = newTotalMeshSize[d] << ( level );
199       }
200     fieldTransformAdaptor->SetMeshSizeForTheUpdateField( newUpdateMeshSize );
201     fieldTransformAdaptor->SetMeshSizeForTheTotalField( newTotalMeshSize );
202 
203     adaptors.push_back( fieldTransformAdaptor );
204     }
205 
206   displacementFieldRegistration->SetFixedPointSet( fixedPoints );
207   displacementFieldRegistration->SetMovingPointSet( movingPoints );
208   displacementFieldRegistration->SetNumberOfLevels( 3 );
209   displacementFieldRegistration->SetMovingInitialTransform( transform );
210   displacementFieldRegistration->SetShrinkFactorsPerLevel( shrinkFactorsPerLevel );
211   displacementFieldRegistration->SetSmoothingSigmasPerLevel( smoothingSigmasPerLevel );
212   displacementFieldRegistration->SetMetric( metric );
213   displacementFieldRegistration->SetLearningRate( 0.25 );
214   displacementFieldRegistration->SetNumberOfIterationsPerLevel( numberOfIterationsPerLevel );
215   displacementFieldRegistration->SetTransformParametersAdaptorsPerLevel( adaptors );
216 
217   outputTransform->SetDisplacementField( displacementField );
218   outputTransform->SetInverseDisplacementField( inverseDisplacementField );
219   displacementFieldRegistration->SetInitialTransform( outputTransform );
220   displacementFieldRegistration->InPlaceOn();
221 
222   try
223     {
224     std::cout << "B-spline SyN point set registration" << std::endl;
225     displacementFieldRegistration->Update();
226     }
227   catch( itk::ExceptionObject &e )
228     {
229     std::cerr << "Exception caught: " << e << std::endl;
230     return EXIT_FAILURE;
231     }
232 
233   // applying the resultant transform to moving points and verify result
234   std::cout << "Fixed\tMoving\tMovingTransformed\tFixedTransformed\tDiff" << std::endl;
235   PointType::ValueType tolerance = 0.01;
236 
237   float averageError = 0.0;
238   for( unsigned int n = 0; n < movingPoints->GetNumberOfPoints(); n++ )
239     {
240     // compare the points in virtual domain
241     PointType transformedMovingPoint =
242       displacementFieldRegistration->GetModifiableTransform()->GetInverseTransform()->TransformPoint( movingPoints->GetPoint( n ) );
243     PointType fixedPoint = fixedPoints->GetPoint( n );
244     PointType transformedFixedPoint = displacementFieldRegistration->GetModifiableTransform()->TransformPoint( fixedPoints->GetPoint( n ) );
245     PointType difference;
246     difference[0] = transformedMovingPoint[0] - fixedPoint[0];
247     difference[1] = transformedMovingPoint[1] - fixedPoint[1];
248     std::cout << fixedPoints->GetPoint( n ) << "\t" << movingPoints->GetPoint( n )
249           << "\t" << transformedMovingPoint << "\t" << transformedFixedPoint << "\t" << difference << std::endl;
250 
251     averageError += ( ( difference.GetVectorFromOrigin() ).GetSquaredNorm() );
252     }
253 
254   unsigned int numberOfPoints = movingPoints->GetNumberOfPoints();
255   if( numberOfPoints > 0 )
256     {
257     averageError /= static_cast<float>( numberOfPoints );
258     std::cout << "Average error: " << averageError << std::endl;
259     if( averageError > tolerance )
260       {
261       std::cerr << "Results do not match truth within tolerance." << std::endl;
262       return EXIT_FAILURE;
263       }
264     }
265   else
266     {
267     std::cerr << "No points." << std::endl;
268     return EXIT_FAILURE;
269     }
270 
271   return EXIT_SUCCESS;
272 }
273