/*=========================================================================
 *
 *  Copyright Insight Software Consortium
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/

#include "itkAffineTransform.h"
#include "itkBSplineSmoothingOnUpdateDisplacementFieldTransformParametersAdaptor.h"
#include "itkBSplineSyNImageRegistrationMethod.h"
#include "itkEuclideanDistancePointSetToPointSetMetricv4.h"
#include "itkShrinkImageFilter.h"

#include <cmath>
#include <cstdlib>
#include <iostream>

itkBSplineSyNPointSetRegistrationTest(int itkNotUsed (argc),char * itkNotUsed (argv)[])24 int itkBSplineSyNPointSetRegistrationTest( int itkNotUsed( argc ), char * itkNotUsed( argv )[] )
25 {
26 constexpr unsigned int Dimension = 2;
27
28 using PointSetType = itk::PointSet<unsigned int, Dimension>;
29
30 using PointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4<PointSetType>;
31 PointSetMetricType::Pointer metric = PointSetMetricType::New();
32
33 using PointSetType = PointSetMetricType::FixedPointSetType;
34 using PointType = PointSetType::PointType;
35
36 using PixelType = double;
37 using FixedImageType = itk::Image<PixelType, Dimension>;
38 using MovingImageType = itk::Image<PixelType, Dimension>;
39
40
41 PointSetType::Pointer fixedPoints = PointSetType::New();
42 fixedPoints->Initialize();
43
44 PointSetType::Pointer movingPoints = PointSetType::New();
45 movingPoints->Initialize();
46
47 // two circles with a small offset
48 PointType offset;
49 for( unsigned int d=0; d < PointSetType::PointDimension; d++ )
50 {
51 offset[d] = 2.0;
52 }
53 unsigned long count = 0;
54 for( float theta = 0; theta < 2.0 * itk::Math::pi; theta += 0.1 )
55 {
56 auto label = static_cast<unsigned int>( 1.5 + count / 100 );
57
58 PointType fixedPoint;
59 float radius = 100.0;
60 fixedPoint[0] = radius * std::cos( theta );
61 fixedPoint[1] = radius * std::sin( theta );
62 if( PointSetType::PointDimension > 2 )
63 {
64 fixedPoint[2] = radius * std::sin( theta );
65 }
66 fixedPoints->SetPoint( count, fixedPoint );
67 fixedPoints->SetPointData( count, label );
68
69 PointType movingPoint;
70 movingPoint[0] = fixedPoint[0] + offset[0];
71 movingPoint[1] = fixedPoint[1] + offset[1];
72 if( PointSetType::PointDimension > 2 )
73 {
74 movingPoint[2] = fixedPoint[2] + offset[2];
75 }
76 movingPoints->SetPoint( count, movingPoint );
77 movingPoints->SetPointData( count, label );
78
79 count++;
80 }
81
82 // virtual image domain is [-110,-110] [110,110]
83
84 FixedImageType::SizeType fixedImageSize;
85 FixedImageType::PointType fixedImageOrigin;
86 FixedImageType::DirectionType fixedImageDirection;
87 FixedImageType::SpacingType fixedImageSpacing;
88
89 fixedImageSize.Fill( 221 );
90 fixedImageOrigin.Fill( -110 );
91 fixedImageDirection.SetIdentity();
92 fixedImageSpacing.Fill( 1 );
93
94 FixedImageType::Pointer fixedImage = FixedImageType::New();
95 fixedImage->SetRegions( fixedImageSize );
96 fixedImage->SetOrigin( fixedImageOrigin );
97 fixedImage->SetDirection( fixedImageDirection );
98 fixedImage->SetSpacing( fixedImageSpacing );
99 fixedImage->Allocate();
100
101 using AffineTransformType = itk::AffineTransform<double, PointSetType::PointDimension>;
102 AffineTransformType::Pointer transform = AffineTransformType::New();
103 transform->SetIdentity();
104
105 metric->SetFixedPointSet( fixedPoints );
106 metric->SetMovingPointSet( movingPoints );
107 metric->SetVirtualDomainFromImage( fixedImage );
108 metric->SetMovingTransform( transform );
109 metric->Initialize();
110
111 // Create the SyN deformable registration method
112
113 using VectorType = itk::Vector<double, Dimension>;
114 VectorType zeroVector( 0.0 );
115
116 using DisplacementFieldType = itk::Image<VectorType, Dimension>;
117 DisplacementFieldType::Pointer displacementField = DisplacementFieldType::New();
118 displacementField->CopyInformation( fixedImage );
119 displacementField->SetRegions( fixedImage->GetBufferedRegion() );
120 displacementField->Allocate();
121 displacementField->FillBuffer( zeroVector );
122
123 DisplacementFieldType::Pointer inverseDisplacementField = DisplacementFieldType::New();
124 inverseDisplacementField->CopyInformation( fixedImage );
125 inverseDisplacementField->SetRegions( fixedImage->GetBufferedRegion() );
126 inverseDisplacementField->Allocate();
127 inverseDisplacementField->FillBuffer( zeroVector );
128
129 using DisplacementFieldRegistrationType = itk::BSplineSyNImageRegistrationMethod<FixedImageType, MovingImageType>;
130 DisplacementFieldRegistrationType::Pointer displacementFieldRegistration = DisplacementFieldRegistrationType::New();
131
132 using OutputTransformType = DisplacementFieldRegistrationType::OutputTransformType;
133 OutputTransformType::Pointer outputTransform = OutputTransformType::New();
134 outputTransform->SetDisplacementField( displacementField );
135 outputTransform->SetInverseDisplacementField( inverseDisplacementField );
136
137 displacementFieldRegistration->SetInitialTransform( outputTransform );
138 displacementFieldRegistration->InPlaceOn();
139
140 using DisplacementFieldTransformAdaptorType = itk::BSplineSmoothingOnUpdateDisplacementFieldTransformParametersAdaptor<OutputTransformType>;
141 DisplacementFieldRegistrationType::TransformParametersAdaptorsContainerType adaptors;
142
143 OutputTransformType::ArrayType updateMeshSize;
144 OutputTransformType::ArrayType totalMeshSize;
145 for( unsigned int d = 0; d < Dimension; d++ )
146 {
147 updateMeshSize[d] = 10;
148 totalMeshSize[d] = 0;
149 }
150
151 // Create the transform adaptors
152 // For the gaussian displacement field, the specified variances are in image spacing terms
153 // and, in normal practice, we typically don't change these values at each level. However,
154 // if the user wishes to add that option, they can use the class
155 // GaussianSmoothingOnUpdateDisplacementFieldTransformAdaptor
156
157 unsigned int numberOfLevels = 3;
158
159 DisplacementFieldRegistrationType::NumberOfIterationsArrayType numberOfIterationsPerLevel;
160 numberOfIterationsPerLevel.SetSize( 3 );
161 numberOfIterationsPerLevel[0] = 1;
162 numberOfIterationsPerLevel[1] = 1;
163 numberOfIterationsPerLevel[2] = 50;
164
165 DisplacementFieldRegistrationType::ShrinkFactorsArrayType shrinkFactorsPerLevel;
166 shrinkFactorsPerLevel.SetSize( 3 );
167 shrinkFactorsPerLevel.Fill( 1 );
168
169 DisplacementFieldRegistrationType::SmoothingSigmasArrayType smoothingSigmasPerLevel;
170 smoothingSigmasPerLevel.SetSize( 3 );
171 smoothingSigmasPerLevel.Fill( 0 );
172
173 for( unsigned int level = 0; level < numberOfLevels; level++ )
174 {
175 // We use the shrink image filter to calculate the fixed parameters of the virtual
176 // domain at each level. To speed up calculation and avoid unnecessary memory
177 // usage, we could calculate these fixed parameters directly.
178
179 using ShrinkFilterType = itk::ShrinkImageFilter<DisplacementFieldType, DisplacementFieldType>;
180 ShrinkFilterType::Pointer shrinkFilter = ShrinkFilterType::New();
181 shrinkFilter->SetShrinkFactors( shrinkFactorsPerLevel[level] );
182 shrinkFilter->SetInput( displacementField );
183 shrinkFilter->Update();
184
185 DisplacementFieldTransformAdaptorType::Pointer fieldTransformAdaptor = DisplacementFieldTransformAdaptorType::New();
186 fieldTransformAdaptor->SetRequiredSpacing( shrinkFilter->GetOutput()->GetSpacing() );
187 fieldTransformAdaptor->SetRequiredSize( shrinkFilter->GetOutput()->GetBufferedRegion().GetSize() );
188 fieldTransformAdaptor->SetRequiredDirection( shrinkFilter->GetOutput()->GetDirection() );
189 fieldTransformAdaptor->SetRequiredOrigin( shrinkFilter->GetOutput()->GetOrigin() );
190 fieldTransformAdaptor->SetTransform( outputTransform );
191
192 // A good heuristic is to double the b-spline mesh resolution at each level
193 OutputTransformType::ArrayType newUpdateMeshSize = updateMeshSize;
194 OutputTransformType::ArrayType newTotalMeshSize = totalMeshSize;
195 for( unsigned int d = 0; d < Dimension; d++ )
196 {
197 newUpdateMeshSize[d] = newUpdateMeshSize[d] << ( level );
198 newTotalMeshSize[d] = newTotalMeshSize[d] << ( level );
199 }
200 fieldTransformAdaptor->SetMeshSizeForTheUpdateField( newUpdateMeshSize );
201 fieldTransformAdaptor->SetMeshSizeForTheTotalField( newTotalMeshSize );
202
203 adaptors.push_back( fieldTransformAdaptor );
204 }
205
206 displacementFieldRegistration->SetFixedPointSet( fixedPoints );
207 displacementFieldRegistration->SetMovingPointSet( movingPoints );
208 displacementFieldRegistration->SetNumberOfLevels( 3 );
209 displacementFieldRegistration->SetMovingInitialTransform( transform );
210 displacementFieldRegistration->SetShrinkFactorsPerLevel( shrinkFactorsPerLevel );
211 displacementFieldRegistration->SetSmoothingSigmasPerLevel( smoothingSigmasPerLevel );
212 displacementFieldRegistration->SetMetric( metric );
213 displacementFieldRegistration->SetLearningRate( 0.25 );
214 displacementFieldRegistration->SetNumberOfIterationsPerLevel( numberOfIterationsPerLevel );
215 displacementFieldRegistration->SetTransformParametersAdaptorsPerLevel( adaptors );
216
217 outputTransform->SetDisplacementField( displacementField );
218 outputTransform->SetInverseDisplacementField( inverseDisplacementField );
219 displacementFieldRegistration->SetInitialTransform( outputTransform );
220 displacementFieldRegistration->InPlaceOn();
221
222 try
223 {
224 std::cout << "B-spline SyN point set registration" << std::endl;
225 displacementFieldRegistration->Update();
226 }
227 catch( itk::ExceptionObject &e )
228 {
229 std::cerr << "Exception caught: " << e << std::endl;
230 return EXIT_FAILURE;
231 }
232
233 // applying the resultant transform to moving points and verify result
234 std::cout << "Fixed\tMoving\tMovingTransformed\tFixedTransformed\tDiff" << std::endl;
235 PointType::ValueType tolerance = 0.01;
236
237 float averageError = 0.0;
238 for( unsigned int n = 0; n < movingPoints->GetNumberOfPoints(); n++ )
239 {
240 // compare the points in virtual domain
241 PointType transformedMovingPoint =
242 displacementFieldRegistration->GetModifiableTransform()->GetInverseTransform()->TransformPoint( movingPoints->GetPoint( n ) );
243 PointType fixedPoint = fixedPoints->GetPoint( n );
244 PointType transformedFixedPoint = displacementFieldRegistration->GetModifiableTransform()->TransformPoint( fixedPoints->GetPoint( n ) );
245 PointType difference;
246 difference[0] = transformedMovingPoint[0] - fixedPoint[0];
247 difference[1] = transformedMovingPoint[1] - fixedPoint[1];
248 std::cout << fixedPoints->GetPoint( n ) << "\t" << movingPoints->GetPoint( n )
249 << "\t" << transformedMovingPoint << "\t" << transformedFixedPoint << "\t" << difference << std::endl;
250
251 averageError += ( ( difference.GetVectorFromOrigin() ).GetSquaredNorm() );
252 }
253
254 unsigned int numberOfPoints = movingPoints->GetNumberOfPoints();
255 if( numberOfPoints > 0 )
256 {
257 averageError /= static_cast<float>( numberOfPoints );
258 std::cout << "Average error: " << averageError << std::endl;
259 if( averageError > tolerance )
260 {
261 std::cerr << "Results do not match truth within tolerance." << std::endl;
262 return EXIT_FAILURE;
263 }
264 }
265 else
266 {
267 std::cerr << "No points." << std::endl;
268 return EXIT_FAILURE;
269 }
270
271 return EXIT_SUCCESS;
272 }
273