1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkESMDemonsRegistrationFunction_hxx
19 #define itkESMDemonsRegistrationFunction_hxx
20 
21 #include "itkESMDemonsRegistrationFunction.h"
22 #include "itkExceptionObject.h"
23 #include "itkMath.h"
24 
25 namespace itk
26 {
27 /**
28  * Default constructor
29  */
30 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
31 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
ESMDemonsRegistrationFunction()32 ::ESMDemonsRegistrationFunction()
33 {
34   RadiusType   r;
35   unsigned int j;
36 
37   for ( j = 0; j < ImageDimension; j++ )
38     {
39     r[j] = 0;
40     }
41   this->SetRadius(r);
42 
43   m_TimeStep = 1.0;
44   m_DenominatorThreshold = 1e-9;
45   m_IntensityDifferenceThreshold = 0.001;
46   m_MaximumUpdateStepLength = 0.5;
47 
48   this->SetMovingImage(nullptr);
49   this->SetFixedImage(nullptr);
50   m_FixedImageSpacing.Fill(1.0);
51   m_FixedImageOrigin.Fill(0.0);
52   m_FixedImageDirection.SetIdentity();
53   m_Normalizer = 0.0;
54   m_FixedImageGradientCalculator = GradientCalculatorType::New();
55   // Gradient orientation will be taken care of explicitely
56   m_FixedImageGradientCalculator->UseImageDirectionOff();
57   m_MappedMovingImageGradientCalculator = MovingImageGradientCalculatorType::New();
58   // Gradient orientation will be taken care of explicitely
59   m_MappedMovingImageGradientCalculator->UseImageDirectionOff();
60 
61   this->m_UseGradientType = Symmetric;
62 
63   typename DefaultInterpolatorType::Pointer interp =
64     DefaultInterpolatorType::New();
65 
66   m_MovingImageInterpolator = itkDynamicCastInDebugMode< InterpolatorType * >
67     ( interp.GetPointer() );
68 
69   m_MovingImageWarper = WarperType::New();
70   m_MovingImageWarper->SetInterpolator(m_MovingImageInterpolator);
71   m_MovingImageWarper->SetEdgePaddingValue( NumericTraits< MovingPixelType >::max() );
72 
73   m_MovingImageWarperOutput = nullptr;
74 
75   m_Metric = NumericTraits< double >::max();
76   m_SumOfSquaredDifference = 0.0;
77   m_NumberOfPixelsProcessed = 0L;
78   m_RMSChange = NumericTraits< double >::max();
79   m_SumOfSquaredChange = 0.0;
80 }
81 
82 /*
83  * Standard "PrintSelf" method.
84  */
85 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
86 void
87 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
PrintSelf(std::ostream & os,Indent indent) const88 ::PrintSelf(std::ostream & os, Indent indent) const
89 {
90   Superclass::PrintSelf(os, indent);
91 
92   os << indent << "UseGradientType: ";
93   os << m_UseGradientType << std::endl;
94   os << indent << "MaximumUpdateStepLength: ";
95   os << m_MaximumUpdateStepLength << std::endl;
96 
97   os << indent << "MovingImageIterpolator: ";
98   os << m_MovingImageInterpolator.GetPointer() << std::endl;
99   os << indent << "FixedImageGradientCalculator: ";
100   os << m_FixedImageGradientCalculator.GetPointer() << std::endl;
101   os << indent << "MappedMovingImageGradientCalculator: ";
102   os << m_MappedMovingImageGradientCalculator.GetPointer() << std::endl;
103   os << indent << "DenominatorThreshold: ";
104   os << m_DenominatorThreshold << std::endl;
105   os << indent << "IntensityDifferenceThreshold: ";
106   os << m_IntensityDifferenceThreshold << std::endl;
107 
108   os << indent << "Metric: ";
109   os << m_Metric << std::endl;
110   os << indent << "SumOfSquaredDifference: ";
111   os << m_SumOfSquaredDifference << std::endl;
112   os << indent << "NumberOfPixelsProcessed: ";
113   os << m_NumberOfPixelsProcessed << std::endl;
114   os << indent << "RMSChange: ";
115   os << m_RMSChange << std::endl;
116   os << indent << "SumOfSquaredChange: ";
117   os << m_SumOfSquaredChange << std::endl;
118 }
119 
120 /**
121  *
122  */
123 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
124 void
125 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
SetIntensityDifferenceThreshold(double threshold)126 ::SetIntensityDifferenceThreshold(double threshold)
127 {
128   m_IntensityDifferenceThreshold = threshold;
129 }
130 
131 /**
132  *
133  */
134 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
135 double
136 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
GetIntensityDifferenceThreshold() const137 ::GetIntensityDifferenceThreshold() const
138 {
139   return m_IntensityDifferenceThreshold;
140 }
141 
142 /**
143  * Set the function state values before each iteration
144  */
145 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
146 void
147 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
InitializeIteration()148 ::InitializeIteration()
149 {
150   if ( !this->GetMovingImage() || !this->GetFixedImage()
151        || !m_MovingImageInterpolator )
152     {
153     itkExceptionMacro(
154       << "MovingImage, FixedImage and/or Interpolator not set");
155     }
156 
157   // cache fixed image information
158   m_FixedImageOrigin  = this->GetFixedImage()->GetOrigin();
159   m_FixedImageSpacing = this->GetFixedImage()->GetSpacing();
160   m_FixedImageDirection = this->GetFixedImage()->GetDirection();
161 
162   // compute the normalizer
163   if ( m_MaximumUpdateStepLength > 0.0 )
164     {
165     m_Normalizer = 0.0;
166     for ( unsigned int k = 0; k < ImageDimension; k++ )
167       {
168       m_Normalizer += m_FixedImageSpacing[k] * m_FixedImageSpacing[k];
169       }
170     m_Normalizer *= m_MaximumUpdateStepLength * m_MaximumUpdateStepLength
171                     / static_cast< double >( ImageDimension );
172     }
173   else
174     {
175     // set it to minus one to denote a special case
176     // ( unrestricted update length )
177     m_Normalizer = -1.0;
178     }
179 
180   // setup gradient calculator
181   m_FixedImageGradientCalculator->SetInputImage( this->GetFixedImage() );
182   m_MappedMovingImageGradientCalculator->SetInputImage( this->GetMovingImage() );
183 
184   // Compute warped moving image
185   m_MovingImageWarper->SetOutputOrigin(this->m_FixedImageOrigin);
186   m_MovingImageWarper->SetOutputSpacing(this->m_FixedImageSpacing);
187   m_MovingImageWarper->SetOutputDirection(this->m_FixedImageDirection);
188   m_MovingImageWarper->SetInput( this->GetMovingImage() );
189   m_MovingImageWarper->SetDisplacementField( this->GetDisplacementField() );
190   m_MovingImageWarper->GetOutput()->SetRequestedRegion( this->GetDisplacementField()->GetRequestedRegion() );
191   m_MovingImageWarper->Update();
192   this->m_MovingImageWarperOutput =
193     this->m_MovingImageWarper->GetOutput();
194   // setup moving image interpolator for further access
195   m_MovingImageInterpolator->SetInputImage( this->GetMovingImage() );
196 
197   // initialize metric computation variables
198   m_SumOfSquaredDifference  = 0.0;
199   m_NumberOfPixelsProcessed = 0L;
200   m_SumOfSquaredChange      = 0.0;
201 }
202 
203 /**
204  * Compute update at a non boundary neighbourhood
205  */
206 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
207 typename ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
208 ::PixelType
209 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
ComputeUpdate(const NeighborhoodType & it,void * gd,const FloatOffsetType & itkNotUsed (offset))210 ::ComputeUpdate( const NeighborhoodType & it, void *gd,
211                  const FloatOffsetType & itkNotUsed(offset) )
212 {
213   auto * globalData = (GlobalDataStruct *)gd;
214   PixelType         update;
215   IndexType         FirstIndex = this->GetFixedImage()->GetLargestPossibleRegion().GetIndex();
216   IndexType         LastIndex = this->GetFixedImage()->GetLargestPossibleRegion().GetIndex()
217                                 + this->GetFixedImage()->GetLargestPossibleRegion().GetSize();
218 
219   const IndexType index = it.GetIndex();
220 
221   // Get fixed image related information
222   // Note: no need to check if the index is within
223   // fixed image buffer. This is done by the external filter.
224   const auto fixedValue = static_cast< double >( this->GetFixedImage()->GetPixel(index) );
225 
226   // Get moving image related information
227   // check if the point was mapped outside of the moving image using
228   // the "special value" NumericTraits<MovingPixelType>::max()
229   MovingPixelType movingPixValue =
230     m_MovingImageWarperOutput->GetPixel(index);
231 
232   if ( movingPixValue == NumericTraits< MovingPixelType >::max() )
233     {
234     update.Fill(0.0);
235     return update;
236     }
237 
238   const auto movingValue = static_cast< double >( movingPixValue );
239 
240   // We compute the gradient more or less by hand.
241   // We first start by ignoring the image orientation and introduce it
242   // afterwards
243   CovariantVectorType usedOrientFreeGradientTimes2;
244 
245   if ( ( this->m_UseGradientType == Symmetric )
246        || ( this->m_UseGradientType == WarpedMoving ) )
247     {
248     // we don't use a CentralDifferenceImageFunction here to be able to
249     // check for NumericTraits<MovingPixelType>::max()
250     CovariantVectorType warpedMovingGradient;
251     IndexType           tmpIndex = index;
252     for ( unsigned int dim = 0; dim < ImageDimension; dim++ )
253       {
254       // bounds checking
255       if ( FirstIndex[dim] == LastIndex[dim]
256            || index[dim] < FirstIndex[dim]
257            || index[dim] >= LastIndex[dim] )
258         {
259         warpedMovingGradient[dim] = 0.0;
260         continue;
261         }
262       else if ( index[dim] == FirstIndex[dim] )
263         {
264         // compute derivative
265         tmpIndex[dim] += 1;
266         movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex);
267         if ( movingPixValue == NumericTraits< MovingPixelType >::max() )
268           {
269           // weird crunched border case
270           warpedMovingGradient[dim] = 0.0;
271           }
272         else
273           {
274           // forward difference
275           warpedMovingGradient[dim] = static_cast< double >( movingPixValue ) - movingValue;
276           warpedMovingGradient[dim] /= m_FixedImageSpacing[dim];
277           }
278         tmpIndex[dim] -= 1;
279         continue;
280         }
281       else if ( index[dim] == ( LastIndex[dim] - 1 ) )
282         {
283         // compute derivative
284         tmpIndex[dim] -= 1;
285         movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex);
286         if ( movingPixValue == NumericTraits< MovingPixelType >::max() )
287           {
288           // weird crunched border case
289           warpedMovingGradient[dim] = 0.0;
290           }
291         else
292           {
293           // backward difference
294           warpedMovingGradient[dim] = movingValue - static_cast< double >( movingPixValue );
295           warpedMovingGradient[dim] /= m_FixedImageSpacing[dim];
296           }
297         tmpIndex[dim] += 1;
298         continue;
299         }
300 
301       // compute derivative
302       tmpIndex[dim] += 1;
303       movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex);
304       if ( movingPixValue == NumericTraits
305            < MovingPixelType >::max() )
306         {
307         // backward difference
308         warpedMovingGradient[dim] = movingValue;
309 
310         tmpIndex[dim] -= 2;
311         movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex);
312         if ( movingPixValue == NumericTraits< MovingPixelType >::max() )
313           {
314           // weird crunched border case
315           warpedMovingGradient[dim] = 0.0;
316           }
317         else
318           {
319           // backward difference
320           warpedMovingGradient[dim] -= static_cast< double >(
321             m_MovingImageWarperOutput->GetPixel(tmpIndex) );
322 
323           warpedMovingGradient[dim] /= m_FixedImageSpacing[dim];
324           }
325         }
326       else
327         {
328         warpedMovingGradient[dim] = static_cast< double >( movingPixValue );
329 
330         tmpIndex[dim] -= 2;
331         movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex);
332         if ( movingPixValue == NumericTraits< MovingPixelType >::max() )
333           {
334           // forward difference
335           warpedMovingGradient[dim] -= movingValue;
336           warpedMovingGradient[dim] /= m_FixedImageSpacing[dim];
337           }
338         else
339           {
340           // normal case, central difference
341           warpedMovingGradient[dim] -= static_cast< double >( movingPixValue );
342           warpedMovingGradient[dim] *= 0.5 / m_FixedImageSpacing[dim];
343           }
344         }
345       tmpIndex[dim] += 1;
346       }
347 
348     if ( this->m_UseGradientType == Symmetric )
349       {
350       // Compute orientation-free gradient with calculator
351       const CovariantVectorType fixedGradient =
352         m_FixedImageGradientCalculator->EvaluateAtIndex(index);
353 
354       usedOrientFreeGradientTimes2 = fixedGradient + warpedMovingGradient;
355       }
356     else if ( this->m_UseGradientType == WarpedMoving )
357       {
358       usedOrientFreeGradientTimes2 = warpedMovingGradient + warpedMovingGradient;
359       }
360     else
361       {
362       itkExceptionMacro(<< "Unknown gradient type");
363       }
364     }
365   else if ( this->m_UseGradientType == Fixed )
366     {
367     // Compute orientation-free gradient with calculator
368     const CovariantVectorType fixedGradient =
369       m_FixedImageGradientCalculator->EvaluateAtIndex(index);
370 
371     usedOrientFreeGradientTimes2 = fixedGradient + fixedGradient;
372     }
373   else if ( this->m_UseGradientType == MappedMoving )
374     {
375     PointType mappedPoint;
376     this->GetFixedImage()->TransformIndexToPhysicalPoint(index, mappedPoint);
377     for ( unsigned int j = 0; j < ImageDimension; j++ )
378       {
379       mappedPoint[j] += it.GetCenterPixel()[j];
380       }
381 
382     const CovariantVectorType mappedMovingGradient =
383       m_MappedMovingImageGradientCalculator->Evaluate(mappedPoint);
384 
385     usedOrientFreeGradientTimes2 = mappedMovingGradient + mappedMovingGradient;
386     }
387   else
388     {
389     itkExceptionMacro(<< "Unknown gradient type");
390     }
391 
392   CovariantVectorType usedGradientTimes2;
393   this->GetFixedImage()->TransformLocalVectorToPhysicalVector(
394     usedOrientFreeGradientTimes2, usedGradientTimes2);
395 
396   /**
397    * Compute Update.
398    * We avoid the mismatch in units between the two terms.
399    * and avoid large step using a normalization term.
400    */
401 
402   const double usedGradientTimes2SquaredMagnitude =
403     usedGradientTimes2.GetSquaredNorm();
404 
405   const double speedValue = fixedValue - movingValue;
406   if ( itk::Math::abs(speedValue) < m_IntensityDifferenceThreshold )
407     {
408     update.Fill(0.0);
409     }
410   else
411     {
412     double denom;
413     if (  m_Normalizer > 0.0 )
414       {
415       // "ITK-Thirion" normalization
416       denom =  usedGradientTimes2SquaredMagnitude + ( itk::Math::sqr(speedValue) / m_Normalizer );
417       }
418     else
419       {
420       // least square solution of the system
421       denom =  usedGradientTimes2SquaredMagnitude;
422       }
423 
424     if ( denom < m_DenominatorThreshold )
425       {
426       update.Fill(0.0);
427       }
428     else
429       {
430       const double factor = 2.0 * speedValue / denom;
431 
432       for ( unsigned int j = 0; j < ImageDimension; j++ )
433         {
434         update[j] = factor * usedGradientTimes2[j];
435         }
436       }
437     }
438 
439   // WARNING!! We compute the global data without taking into account the
440   // current update step.
441   // There are several reasons for that: If an exponential, a smoothing or any
442   // other operation
443   // is applied on the update field, we cannot compute the newMappedCenterPoint
444   // here; and even
445   // if we could, this would be an often unnecessary time-consuming task.
446   if ( globalData )
447     {
448     globalData->m_SumOfSquaredDifference += itk::Math::sqr(speedValue);
449     globalData->m_NumberOfPixelsProcessed += 1;
450     globalData->m_SumOfSquaredChange += update.GetSquaredNorm();
451     }
452 
453   return update;
454 }
455 
456 /**
457  * Update the metric and release the per-thread-global data.
458  */
459 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
460 void
461 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
ReleaseGlobalDataPointer(void * gd) const462 ::ReleaseGlobalDataPointer(void *gd) const
463 {
464   auto * globalData = (GlobalDataStruct *)gd;
465 
466   m_MetricCalculationLock.lock();
467   m_SumOfSquaredDifference += globalData->m_SumOfSquaredDifference;
468   m_NumberOfPixelsProcessed += globalData->m_NumberOfPixelsProcessed;
469   m_SumOfSquaredChange += globalData->m_SumOfSquaredChange;
470   if ( m_NumberOfPixelsProcessed )
471     {
472     m_Metric = m_SumOfSquaredDifference
473                / static_cast< double >( m_NumberOfPixelsProcessed );
474     m_RMSChange = std::sqrt( m_SumOfSquaredChange
475                             / static_cast< double >( m_NumberOfPixelsProcessed ) );
476     }
477   m_MetricCalculationLock.unlock();
478 
479   delete globalData;
480 }
481 } // end namespace itk
482 
483 #endif
484