1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkDemonsRegistrationFunction_hxx
19 #define itkDemonsRegistrationFunction_hxx
20 
#include "itkDemonsRegistrationFunction.h"
#include "itkMacro.h"
#include "itkMath.h"
#include <mutex>
24 
25 namespace itk
26 {
27 /**
28  * Default constructor
29  */
30 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
31 DemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
DemonsRegistrationFunction()32 ::DemonsRegistrationFunction()
33 {
34   RadiusType   r;
35   unsigned int j;
36 
37   for ( j = 0; j < ImageDimension; j++ )
38     {
39     r[j] = 0;
40     }
41   this->SetRadius(r);
42 
43   m_TimeStep = 1.0;
44   m_DenominatorThreshold = 1e-9;
45   m_IntensityDifferenceThreshold = 0.001;
46   this->SetMovingImage(nullptr);
47   this->SetFixedImage(nullptr);
48   //m_FixedImageSpacing.Fill( 1.0 );
49   //m_FixedImageOrigin.Fill( 0.0 );
50   m_Normalizer = 1.0;
51   m_FixedImageGradientCalculator = GradientCalculatorType::New();
52 
53   typename DefaultInterpolatorType::Pointer interp =
54     DefaultInterpolatorType::New();
55 
56   m_MovingImageInterpolator = static_cast< InterpolatorType * >(
57     interp.GetPointer() );
58 
59   m_Metric = NumericTraits< double >::max();
60   m_SumOfSquaredDifference = 0.0;
61   m_NumberOfPixelsProcessed = 0L;
62   m_RMSChange = NumericTraits< double >::max();
63   m_SumOfSquaredChange = 0.0;
64 
65   m_MovingImageGradientCalculator = MovingImageGradientCalculatorType::New();
66   m_UseMovingImageGradient = false;
67 }
68 
69 /**
70  * Standard "PrintSelf" method.
71  */
72 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
73 void
74 DemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
PrintSelf(std::ostream & os,Indent indent) const75 ::PrintSelf(std::ostream & os, Indent indent) const
76 {
77   Superclass::PrintSelf(os, indent);
78 
79   os << indent << "MovingImageIterpolator: ";
80   os << m_MovingImageInterpolator.GetPointer() << std::endl;
81   os << indent << "FixedImageGradientCalculator: ";
82   os << m_FixedImageGradientCalculator.GetPointer() << std::endl;
83   os << indent << "DenominatorThreshold: ";
84   os << m_DenominatorThreshold << std::endl;
85   os << indent << "IntensityDifferenceThreshold: ";
86   os << m_IntensityDifferenceThreshold << std::endl;
87 
88   os << indent << "UseMovingImageGradient: ";
89   os << m_UseMovingImageGradient << std::endl;
90 
91   os << indent << "Metric: ";
92   os << m_Metric << std::endl;
93   os << indent << "SumOfSquaredDifference: ";
94   os << m_SumOfSquaredDifference << std::endl;
95   os << indent << "NumberOfPixelsProcessed: ";
96   os << m_NumberOfPixelsProcessed << std::endl;
97   os << indent << "RMSChange: ";
98   os << m_RMSChange << std::endl;
99   os << indent << "SumOfSquaredChange: ";
100   os << m_SumOfSquaredChange << std::endl;
101 }
102 
103 /**
104  *
105  */
106 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
107 void
108 DemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
SetIntensityDifferenceThreshold(double threshold)109 ::SetIntensityDifferenceThreshold(double threshold)
110 {
111   m_IntensityDifferenceThreshold = threshold;
112 }
113 
114 /**
115  *
116  */
117 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
118 double
119 DemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
GetIntensityDifferenceThreshold() const120 ::GetIntensityDifferenceThreshold() const
121 {
122   return m_IntensityDifferenceThreshold;
123 }
124 
125 /**
126  * Set the function state values before each iteration
127  */
128 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
129 void
130 DemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
InitializeIteration()131 ::InitializeIteration()
132 {
133   if ( !this->GetMovingImage() || !this->GetFixedImage() || !m_MovingImageInterpolator )
134     {
135     itkExceptionMacro(<< "MovingImage, FixedImage and/or Interpolator not set");
136     }
137 
138   // cache fixed image information
139   SpacingType fixedImageSpacing    = this->GetFixedImage()->GetSpacing();
140   m_ZeroUpdateReturn.Fill(0.0);
141 
142   // compute the normalizer
143   m_Normalizer      = 0.0;
144   for ( unsigned int k = 0; k < ImageDimension; k++ )
145     {
146     m_Normalizer += fixedImageSpacing[k] * fixedImageSpacing[k];
147     }
148   m_Normalizer /= static_cast< double >( ImageDimension );
149 
150   // setup gradient calculator
151   m_FixedImageGradientCalculator->SetInputImage( this->GetFixedImage() );
152   m_MovingImageGradientCalculator->SetInputImage( this->GetMovingImage() );
153 
154   // setup moving image interpolator
155   m_MovingImageInterpolator->SetInputImage( this->GetMovingImage() );
156 
157   // initialize metric computation variables
158   m_SumOfSquaredDifference  = 0.0;
159   m_NumberOfPixelsProcessed = 0L;
160   m_SumOfSquaredChange      = 0.0;
161 }
162 
163 /**
164  * Compute update at a specify neighbourhood
165  */
166 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
167 typename DemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
168 ::PixelType
169 DemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
ComputeUpdate(const NeighborhoodType & it,void * gd,const FloatOffsetType & itkNotUsed (offset))170 ::ComputeUpdate( const NeighborhoodType & it, void *gd,
171                  const FloatOffsetType & itkNotUsed(offset) )
172 {
173   // Get fixed image related information
174   // Note: no need to check the index is within
175   // fixed image buffer. This is done by the external filter.
176   const IndexType index = it.GetIndex();
177   const auto fixedValue = (double)this->GetFixedImage()->GetPixel(index);
178 
179   // Get moving image related information
180   PointType mappedPoint;
181 
182   this->GetFixedImage()->TransformIndexToPhysicalPoint(index, mappedPoint);
183   for ( unsigned int j = 0; j < ImageDimension; j++ )
184     {
185     mappedPoint[j] += it.GetCenterPixel()[j];
186     }
187 
188   double movingValue;
189   if ( m_MovingImageInterpolator->IsInsideBuffer(mappedPoint) )
190     {
191     movingValue = m_MovingImageInterpolator->Evaluate(mappedPoint);
192     }
193   else
194     {
195     return m_ZeroUpdateReturn;
196     }
197 
198   CovariantVectorType gradient;
199   // Compute the gradient of either fixed or moving image
200   if ( !m_UseMovingImageGradient )
201     {
202     gradient = m_FixedImageGradientCalculator->EvaluateAtIndex(index);
203     }
204   else
205     {
206     gradient = m_MovingImageGradientCalculator->Evaluate(mappedPoint);
207     }
208 
209   double gradientSquaredMagnitude = 0;
210   for ( unsigned int j = 0; j < ImageDimension; j++ )
211     {
212     gradientSquaredMagnitude += itk::Math::sqr(gradient[j]);
213     }
214 
215   /**
216    * Compute Update.
217    * In the original equation the denominator is defined as (g-f)^2 + grad_mag^2.
218    * However there is a mismatch in units between the two terms.
219    * The units for the second term is intensity^2/mm^2 while the
220    * units for the first term is intensity^2. This mismatch is particularly
221    * problematic when the fixed image does not have unit spacing.
222    * In this implementation, we normalize the first term by a factor K,
223    * such that denominator = (g-f)^2/K + grad_mag^2
224    * where K = mean square spacing to compensate for the mismatch in units.
225    */
226   const double speedValue = fixedValue - movingValue;
227   const double sqr_speedValue = itk::Math::sqr(speedValue);
228 
229   // update the metric
230   auto * globalData = (GlobalDataStruct *)gd;
231   if ( globalData )
232     {
233     globalData->m_SumOfSquaredDifference += sqr_speedValue;
234     globalData->m_NumberOfPixelsProcessed += 1;
235     }
236 
237   const double denominator = sqr_speedValue / m_Normalizer
238                              + gradientSquaredMagnitude;
239 
240   if ( itk::Math::abs(speedValue) < m_IntensityDifferenceThreshold
241        || denominator < m_DenominatorThreshold )
242     {
243     return m_ZeroUpdateReturn;
244     }
245 
246   PixelType update;
247   for ( unsigned int j = 0; j < ImageDimension; j++ )
248     {
249     update[j] = speedValue * gradient[j] / denominator;
250     if ( globalData )
251       {
252       globalData->m_SumOfSquaredChange += itk::Math::sqr(update[j]);
253       }
254     }
255   return update;
256 }
257 
258 /**
259  * Update the metric and release the per-thread-global data.
260  */
261 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
262 void
263 DemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
ReleaseGlobalDataPointer(void * gd) const264 ::ReleaseGlobalDataPointer(void *gd) const
265 {
266   auto * globalData = (GlobalDataStruct *)gd;
267 
268   m_MetricCalculationLock.lock();
269   m_SumOfSquaredDifference += globalData->m_SumOfSquaredDifference;
270   m_NumberOfPixelsProcessed += globalData->m_NumberOfPixelsProcessed;
271   m_SumOfSquaredChange += globalData->m_SumOfSquaredChange;
272   if ( m_NumberOfPixelsProcessed )
273     {
274     m_Metric = m_SumOfSquaredDifference
275                / static_cast< double >( m_NumberOfPixelsProcessed );
276     m_RMSChange = std::sqrt( m_SumOfSquaredChange
277                             / static_cast< double >( m_NumberOfPixelsProcessed ) );
278     }
279   m_MetricCalculationLock.unlock();
280 
281   delete globalData;
282 }
283 } // end namespace itk
284 
285 #endif
286