1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18
#include "itkJensenHavrdaCharvatTsallisPointSetToPointSetMetricv4.h"
#include "itkGradientDescentOptimizerv4.h"
#include "itkTransform.h"
#include "itkAffineTransform.h"
#include "itkRegistrationParameterScalesFromPhysicalShift.h"
#include "itkCommand.h"

#include <cmath>
#include <fstream>
#include <iostream>
#include <string>
#include <typeinfo>
27
28 template<typename TFilter>
29 class itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate : public itk::Command
30 {
31 public:
32 using Self = itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate;
33
34 using Superclass = itk::Command;
35 using Pointer = itk::SmartPointer<Self>;
36 itkNewMacro( Self );
37
38 protected:
39 itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate() = default;
40
41 public:
42
Execute(itk::Object * caller,const itk::EventObject & event)43 void Execute(itk::Object *caller, const itk::EventObject & event) override
44 {
45 Execute( (const itk::Object *) caller, event);
46 }
47
Execute(const itk::Object * object,const itk::EventObject & event)48 void Execute(const itk::Object * object, const itk::EventObject & event) override
49 {
50 if( typeid( event ) != typeid( itk::IterationEvent ) )
51 {
52 return;
53 }
54 const auto * optimizer = dynamic_cast< const TFilter * >( object );
55
56 if( !optimizer )
57 {
58 itkGenericExceptionMacro( "Error dynamic_cast failed" );
59 }
60 std::cout << "It: " << optimizer->GetCurrentIteration() << " metric value: " << optimizer->GetCurrentMetricValue();
61 std::cout << std::endl;
62 }
63 };
64
itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTest(int argc,char * argv[])65 int itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTest( int argc, char *argv[] )
66 {
67 constexpr unsigned int Dimension = 2;
68
69 unsigned int numberOfIterations = 10;
70 if( argc > 1 )
71 {
72 numberOfIterations = std::stoi( argv[1] );
73 }
74
75 using PointSetType = itk::PointSet<unsigned char, Dimension>;
76
77 using PointType = PointSetType::PointType;
78
79 PointSetType::Pointer fixedPoints = PointSetType::New();
80 fixedPoints->Initialize();
81
82 PointSetType::Pointer movingPoints = PointSetType::New();
83 movingPoints->Initialize();
84
85
86 // two ellipses, one rotated slightly
87 /*
88 // Having trouble with these, as soon as there's a slight rotation added.
89 unsigned long count = 0;
90 for( float theta = 0; theta < 2.0 * itk::Math::pi; theta += 0.1 )
91 {
92 float radius = 100.0;
93 PointType fixedPoint;
94 fixedPoint[0] = 2 * radius * std::cos( theta );
95 fixedPoint[1] = radius * std::sin( theta );
96 fixedPoints->SetPoint( count, fixedPoint );
97
98 PointType movingPoint;
99 movingPoint[0] = 2 * radius * std::cos( theta + (0.02 * itk::Math::pi) ) + 2.0;
100 movingPoint[1] = radius * std::sin( theta + (0.02 * itk::Math::pi) ) + 2.0;
101 movingPoints->SetPoint( count, movingPoint );
102
103 count++;
104 }
105 */
106
107 // two circles with a small offset
108 PointType offset;
109 for( unsigned int d=0; d < Dimension; d++ )
110 {
111 offset[d] = 2.0;
112 }
113 unsigned long count = 0;
114 for( float theta = 0; theta < 2.0 * itk::Math::pi; theta += 0.1 )
115 {
116 PointType fixedPoint;
117 float radius = 100.0;
118 fixedPoint[0] = radius * std::cos( theta );
119 fixedPoint[1] = radius * std::sin( theta );
120 if( Dimension > 2 )
121 {
122 fixedPoint[2] = radius * std::sin( theta );
123 }
124 fixedPoints->SetPoint( count, fixedPoint );
125
126 PointType movingPoint;
127 movingPoint[0] = fixedPoint[0] + offset[0];
128 movingPoint[1] = fixedPoint[1] + offset[1];
129 if( Dimension > 2 )
130 {
131 movingPoint[2] = fixedPoint[2] + offset[2];
132 }
133 movingPoints->SetPoint( count, movingPoint );
134
135 count++;
136 }
137
138 using AffineTransformType = itk::AffineTransform<double, Dimension>;
139 AffineTransformType::Pointer transform = AffineTransformType::New();
140 transform->SetIdentity();
141
142 // Instantiate the metric
143 using PointSetMetricType = itk::JensenHavrdaCharvatTsallisPointSetToPointSetMetricv4<PointSetType>;
144 PointSetMetricType::Pointer metric = PointSetMetricType::New();
145 metric->SetFixedPointSet( fixedPoints );
146 metric->SetMovingPointSet( movingPoints );
147 metric->SetPointSetSigma( 1.0 );
148 metric->SetKernelSigma( 10.0 );
149 metric->SetUseAnisotropicCovariances( false );
150 metric->SetCovarianceKNeighborhood( 5 );
151 metric->SetEvaluationKNeighborhood( 10 );
152 metric->SetMovingTransform( transform );
153 metric->SetAlpha( 1.1 );
154 metric->Initialize();
155
156 // scales estimator
157 using RegistrationParameterScalesFromShiftType = itk::RegistrationParameterScalesFromPhysicalShift< PointSetMetricType >;
158 RegistrationParameterScalesFromShiftType::Pointer shiftScaleEstimator = RegistrationParameterScalesFromShiftType::New();
159 shiftScaleEstimator->SetMetric( metric );
160 // needed with pointset metrics
161 shiftScaleEstimator->SetVirtualDomainPointSet( metric->GetVirtualTransformedPointSet() );
162
163 // optimizer
164 using OptimizerType = itk::GradientDescentOptimizerv4;
165 OptimizerType::Pointer optimizer = OptimizerType::New();
166 optimizer->SetMetric( metric );
167 optimizer->SetNumberOfIterations( numberOfIterations );
168 optimizer->SetScalesEstimator( shiftScaleEstimator );
169 optimizer->SetMaximumStepSizeInPhysicalUnits( 3.0 );
170
171 using CommandType = itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate<OptimizerType>;
172 CommandType::Pointer observer = CommandType::New();
173 optimizer->AddObserver( itk::IterationEvent(), observer );
174
175 optimizer->SetMinimumConvergenceValue( 0.0 );
176 optimizer->SetConvergenceWindowSize( 10 );
177 optimizer->StartOptimization();
178
179 std::cout << "numberOfIterations: " << numberOfIterations << std::endl;
180 std::cout << "Moving-source final value: " << optimizer->GetCurrentMetricValue() << std::endl;
181 std::cout << "Moving-source final position: " << optimizer->GetCurrentPosition() << std::endl;
182 std::cout << "Optimizer scales: " << optimizer->GetScales() << std::endl;
183 std::cout << "Optimizer learning rate: " << optimizer->GetLearningRate() << std::endl;
184
185 // applying the resultant transform to moving points and verify result
186 std::cout << "Fixed\tMoving\tMovingTransformed\tFixedTransformed\tDiff" << std::endl;
187 bool passed = true;
188 PointType::ValueType tolerance = 1e-2;
189 AffineTransformType::InverseTransformBasePointer movingInverse = metric->GetMovingTransform()->GetInverseTransform();
190 AffineTransformType::InverseTransformBasePointer fixedInverse = metric->GetFixedTransform()->GetInverseTransform();
191 for( unsigned int n=0; n < metric->GetNumberOfComponents(); n++ )
192 {
193 // compare the points in virtual domain
194 PointType transformedMovingPoint = movingInverse->TransformPoint( movingPoints->GetPoint( n ) );
195 PointType transformedFixedPoint = fixedInverse->TransformPoint( fixedPoints->GetPoint( n ) );
196 PointType difference;
197 difference[0] = transformedMovingPoint[0] - transformedFixedPoint[0];
198 difference[1] = transformedMovingPoint[1] - transformedFixedPoint[1];
199 std::cout << fixedPoints->GetPoint( n ) << "\t" << movingPoints->GetPoint( n )
200 << "\t" << transformedMovingPoint << "\t" << transformedFixedPoint << "\t" << difference << std::endl;
201 if( fabs( difference[0] ) > tolerance || fabs( difference[1] ) > tolerance )
202 {
203 passed = false;
204 }
205 }
206 if( ! passed )
207 {
208 std::cerr << "Results do not match truth within tolerance." << std::endl;
209 return EXIT_FAILURE;
210 }
211
212
213 return EXIT_SUCCESS;
214 }
215