1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 
#include "itkJensenHavrdaCharvatTsallisPointSetToPointSetMetricv4.h"
#include "itkGradientDescentOptimizerv4.h"
#include "itkTransform.h"
#include "itkAffineTransform.h"
#include "itkRegistrationParameterScalesFromPhysicalShift.h"
#include "itkCommand.h"

#include <cmath>
#include <fstream>
27 
28 template<typename TFilter>
29 class itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate : public itk::Command
30 {
31 public:
32   using Self = itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate;
33 
34   using Superclass = itk::Command;
35   using Pointer = itk::SmartPointer<Self>;
36   itkNewMacro( Self );
37 
38 protected:
39   itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate() = default;
40 
41 public:
42 
Execute(itk::Object * caller,const itk::EventObject & event)43   void Execute(itk::Object *caller, const itk::EventObject & event) override
44     {
45     Execute( (const itk::Object *) caller, event);
46     }
47 
Execute(const itk::Object * object,const itk::EventObject & event)48   void Execute(const itk::Object * object, const itk::EventObject & event) override
49     {
50     if( typeid( event ) != typeid( itk::IterationEvent ) )
51       {
52       return;
53       }
54     const auto * optimizer = dynamic_cast< const TFilter * >( object );
55 
56     if( !optimizer )
57       {
58       itkGenericExceptionMacro( "Error dynamic_cast failed" );
59       }
60     std::cout << "It: " << optimizer->GetCurrentIteration() << " metric value: " << optimizer->GetCurrentMetricValue();
61     std::cout << std::endl;
62     }
63 };
64 
itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTest(int argc,char * argv[])65 int itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTest( int argc, char *argv[] )
66 {
67   constexpr unsigned int Dimension = 2;
68 
69   unsigned int numberOfIterations = 10;
70   if( argc > 1 )
71     {
72     numberOfIterations = std::stoi( argv[1] );
73     }
74 
75   using PointSetType = itk::PointSet<unsigned char, Dimension>;
76 
77   using PointType = PointSetType::PointType;
78 
79   PointSetType::Pointer fixedPoints = PointSetType::New();
80   fixedPoints->Initialize();
81 
82   PointSetType::Pointer movingPoints = PointSetType::New();
83   movingPoints->Initialize();
84 
85 
86   // two ellipses, one rotated slightly
87 /*
88   // Having trouble with these, as soon as there's a slight rotation added.
89   unsigned long count = 0;
90   for( float theta = 0; theta < 2.0 * itk::Math::pi; theta += 0.1 )
91     {
92     float radius = 100.0;
93     PointType fixedPoint;
94     fixedPoint[0] = 2 * radius * std::cos( theta );
95     fixedPoint[1] = radius * std::sin( theta );
96     fixedPoints->SetPoint( count, fixedPoint );
97 
98     PointType movingPoint;
99     movingPoint[0] = 2 * radius * std::cos( theta + (0.02 * itk::Math::pi) ) + 2.0;
100     movingPoint[1] = radius * std::sin( theta + (0.02 * itk::Math::pi) ) + 2.0;
101     movingPoints->SetPoint( count, movingPoint );
102 
103     count++;
104     }
105 */
106 
107   // two circles with a small offset
108   PointType offset;
109   for( unsigned int d=0; d < Dimension; d++ )
110     {
111     offset[d] = 2.0;
112     }
113   unsigned long count = 0;
114   for( float theta = 0; theta < 2.0 * itk::Math::pi; theta += 0.1 )
115     {
116     PointType fixedPoint;
117     float radius = 100.0;
118     fixedPoint[0] = radius * std::cos( theta );
119     fixedPoint[1] = radius * std::sin( theta );
120     if( Dimension > 2 )
121       {
122       fixedPoint[2] = radius * std::sin( theta );
123       }
124     fixedPoints->SetPoint( count, fixedPoint );
125 
126     PointType movingPoint;
127     movingPoint[0] = fixedPoint[0] + offset[0];
128     movingPoint[1] = fixedPoint[1] + offset[1];
129     if( Dimension > 2 )
130       {
131       movingPoint[2] = fixedPoint[2] + offset[2];
132       }
133     movingPoints->SetPoint( count, movingPoint );
134 
135     count++;
136     }
137 
138   using AffineTransformType = itk::AffineTransform<double, Dimension>;
139   AffineTransformType::Pointer transform = AffineTransformType::New();
140   transform->SetIdentity();
141 
142   // Instantiate the metric
143   using PointSetMetricType = itk::JensenHavrdaCharvatTsallisPointSetToPointSetMetricv4<PointSetType>;
144   PointSetMetricType::Pointer metric = PointSetMetricType::New();
145   metric->SetFixedPointSet( fixedPoints );
146   metric->SetMovingPointSet( movingPoints );
147   metric->SetPointSetSigma( 1.0 );
148   metric->SetKernelSigma( 10.0 );
149   metric->SetUseAnisotropicCovariances( false );
150   metric->SetCovarianceKNeighborhood( 5 );
151   metric->SetEvaluationKNeighborhood( 10 );
152   metric->SetMovingTransform( transform );
153   metric->SetAlpha( 1.1 );
154   metric->Initialize();
155 
156   // scales estimator
157   using RegistrationParameterScalesFromShiftType = itk::RegistrationParameterScalesFromPhysicalShift< PointSetMetricType >;
158   RegistrationParameterScalesFromShiftType::Pointer shiftScaleEstimator = RegistrationParameterScalesFromShiftType::New();
159   shiftScaleEstimator->SetMetric( metric );
160   // needed with pointset metrics
161   shiftScaleEstimator->SetVirtualDomainPointSet( metric->GetVirtualTransformedPointSet() );
162 
163   // optimizer
164   using OptimizerType = itk::GradientDescentOptimizerv4;
165   OptimizerType::Pointer  optimizer = OptimizerType::New();
166   optimizer->SetMetric( metric );
167   optimizer->SetNumberOfIterations( numberOfIterations );
168   optimizer->SetScalesEstimator( shiftScaleEstimator );
169   optimizer->SetMaximumStepSizeInPhysicalUnits( 3.0 );
170 
171   using CommandType = itkJensenHavrdaCharvatTsallisPointSetMetricRegistrationTestCommandIterationUpdate<OptimizerType>;
172   CommandType::Pointer observer = CommandType::New();
173   optimizer->AddObserver( itk::IterationEvent(), observer );
174 
175   optimizer->SetMinimumConvergenceValue( 0.0 );
176   optimizer->SetConvergenceWindowSize( 10 );
177   optimizer->StartOptimization();
178 
179   std::cout << "numberOfIterations: " << numberOfIterations << std::endl;
180   std::cout << "Moving-source final value: " << optimizer->GetCurrentMetricValue() << std::endl;
181   std::cout << "Moving-source final position: " << optimizer->GetCurrentPosition() << std::endl;
182   std::cout << "Optimizer scales: " << optimizer->GetScales() << std::endl;
183   std::cout << "Optimizer learning rate: " << optimizer->GetLearningRate() << std::endl;
184 
185   // applying the resultant transform to moving points and verify result
186   std::cout << "Fixed\tMoving\tMovingTransformed\tFixedTransformed\tDiff" << std::endl;
187   bool passed = true;
188   PointType::ValueType tolerance = 1e-2;
189   AffineTransformType::InverseTransformBasePointer movingInverse = metric->GetMovingTransform()->GetInverseTransform();
190   AffineTransformType::InverseTransformBasePointer fixedInverse = metric->GetFixedTransform()->GetInverseTransform();
191   for( unsigned int n=0; n < metric->GetNumberOfComponents(); n++ )
192     {
193     // compare the points in virtual domain
194     PointType transformedMovingPoint = movingInverse->TransformPoint( movingPoints->GetPoint( n ) );
195     PointType transformedFixedPoint = fixedInverse->TransformPoint( fixedPoints->GetPoint( n ) );
196     PointType difference;
197     difference[0] = transformedMovingPoint[0] - transformedFixedPoint[0];
198     difference[1] = transformedMovingPoint[1] - transformedFixedPoint[1];
199     std::cout << fixedPoints->GetPoint( n ) << "\t" << movingPoints->GetPoint( n )
200           << "\t" << transformedMovingPoint << "\t" << transformedFixedPoint << "\t" << difference << std::endl;
201     if( fabs( difference[0] ) > tolerance || fabs( difference[1] ) > tolerance )
202       {
203       passed = false;
204       }
205     }
206   if( ! passed )
207     {
208     std::cerr << "Results do not match truth within tolerance." << std::endl;
209     return EXIT_FAILURE;
210     }
211 
212 
213   return EXIT_SUCCESS;
214 }
215