1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18 #ifndef itkESMDemonsRegistrationFunction_hxx
19 #define itkESMDemonsRegistrationFunction_hxx
20
#include "itkESMDemonsRegistrationFunction.h"
#include "itkExceptionObject.h"
#include "itkMath.h"

#include <mutex>
24
25 namespace itk
26 {
27 /**
28 * Default constructor
29 */
30 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
31 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
ESMDemonsRegistrationFunction()32 ::ESMDemonsRegistrationFunction()
33 {
34 RadiusType r;
35 unsigned int j;
36
37 for ( j = 0; j < ImageDimension; j++ )
38 {
39 r[j] = 0;
40 }
41 this->SetRadius(r);
42
43 m_TimeStep = 1.0;
44 m_DenominatorThreshold = 1e-9;
45 m_IntensityDifferenceThreshold = 0.001;
46 m_MaximumUpdateStepLength = 0.5;
47
48 this->SetMovingImage(nullptr);
49 this->SetFixedImage(nullptr);
50 m_FixedImageSpacing.Fill(1.0);
51 m_FixedImageOrigin.Fill(0.0);
52 m_FixedImageDirection.SetIdentity();
53 m_Normalizer = 0.0;
54 m_FixedImageGradientCalculator = GradientCalculatorType::New();
55 // Gradient orientation will be taken care of explicitely
56 m_FixedImageGradientCalculator->UseImageDirectionOff();
57 m_MappedMovingImageGradientCalculator = MovingImageGradientCalculatorType::New();
58 // Gradient orientation will be taken care of explicitely
59 m_MappedMovingImageGradientCalculator->UseImageDirectionOff();
60
61 this->m_UseGradientType = Symmetric;
62
63 typename DefaultInterpolatorType::Pointer interp =
64 DefaultInterpolatorType::New();
65
66 m_MovingImageInterpolator = itkDynamicCastInDebugMode< InterpolatorType * >
67 ( interp.GetPointer() );
68
69 m_MovingImageWarper = WarperType::New();
70 m_MovingImageWarper->SetInterpolator(m_MovingImageInterpolator);
71 m_MovingImageWarper->SetEdgePaddingValue( NumericTraits< MovingPixelType >::max() );
72
73 m_MovingImageWarperOutput = nullptr;
74
75 m_Metric = NumericTraits< double >::max();
76 m_SumOfSquaredDifference = 0.0;
77 m_NumberOfPixelsProcessed = 0L;
78 m_RMSChange = NumericTraits< double >::max();
79 m_SumOfSquaredChange = 0.0;
80 }
81
82 /*
83 * Standard "PrintSelf" method.
84 */
85 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
86 void
87 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
PrintSelf(std::ostream & os,Indent indent) const88 ::PrintSelf(std::ostream & os, Indent indent) const
89 {
90 Superclass::PrintSelf(os, indent);
91
92 os << indent << "UseGradientType: ";
93 os << m_UseGradientType << std::endl;
94 os << indent << "MaximumUpdateStepLength: ";
95 os << m_MaximumUpdateStepLength << std::endl;
96
97 os << indent << "MovingImageIterpolator: ";
98 os << m_MovingImageInterpolator.GetPointer() << std::endl;
99 os << indent << "FixedImageGradientCalculator: ";
100 os << m_FixedImageGradientCalculator.GetPointer() << std::endl;
101 os << indent << "MappedMovingImageGradientCalculator: ";
102 os << m_MappedMovingImageGradientCalculator.GetPointer() << std::endl;
103 os << indent << "DenominatorThreshold: ";
104 os << m_DenominatorThreshold << std::endl;
105 os << indent << "IntensityDifferenceThreshold: ";
106 os << m_IntensityDifferenceThreshold << std::endl;
107
108 os << indent << "Metric: ";
109 os << m_Metric << std::endl;
110 os << indent << "SumOfSquaredDifference: ";
111 os << m_SumOfSquaredDifference << std::endl;
112 os << indent << "NumberOfPixelsProcessed: ";
113 os << m_NumberOfPixelsProcessed << std::endl;
114 os << indent << "RMSChange: ";
115 os << m_RMSChange << std::endl;
116 os << indent << "SumOfSquaredChange: ";
117 os << m_SumOfSquaredChange << std::endl;
118 }
119
120 /**
121 *
122 */
123 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
124 void
125 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
SetIntensityDifferenceThreshold(double threshold)126 ::SetIntensityDifferenceThreshold(double threshold)
127 {
128 m_IntensityDifferenceThreshold = threshold;
129 }
130
131 /**
132 *
133 */
134 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
135 double
136 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
GetIntensityDifferenceThreshold() const137 ::GetIntensityDifferenceThreshold() const
138 {
139 return m_IntensityDifferenceThreshold;
140 }
141
142 /**
143 * Set the function state values before each iteration
144 */
145 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
146 void
147 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
InitializeIteration()148 ::InitializeIteration()
149 {
150 if ( !this->GetMovingImage() || !this->GetFixedImage()
151 || !m_MovingImageInterpolator )
152 {
153 itkExceptionMacro(
154 << "MovingImage, FixedImage and/or Interpolator not set");
155 }
156
157 // cache fixed image information
158 m_FixedImageOrigin = this->GetFixedImage()->GetOrigin();
159 m_FixedImageSpacing = this->GetFixedImage()->GetSpacing();
160 m_FixedImageDirection = this->GetFixedImage()->GetDirection();
161
162 // compute the normalizer
163 if ( m_MaximumUpdateStepLength > 0.0 )
164 {
165 m_Normalizer = 0.0;
166 for ( unsigned int k = 0; k < ImageDimension; k++ )
167 {
168 m_Normalizer += m_FixedImageSpacing[k] * m_FixedImageSpacing[k];
169 }
170 m_Normalizer *= m_MaximumUpdateStepLength * m_MaximumUpdateStepLength
171 / static_cast< double >( ImageDimension );
172 }
173 else
174 {
175 // set it to minus one to denote a special case
176 // ( unrestricted update length )
177 m_Normalizer = -1.0;
178 }
179
180 // setup gradient calculator
181 m_FixedImageGradientCalculator->SetInputImage( this->GetFixedImage() );
182 m_MappedMovingImageGradientCalculator->SetInputImage( this->GetMovingImage() );
183
184 // Compute warped moving image
185 m_MovingImageWarper->SetOutputOrigin(this->m_FixedImageOrigin);
186 m_MovingImageWarper->SetOutputSpacing(this->m_FixedImageSpacing);
187 m_MovingImageWarper->SetOutputDirection(this->m_FixedImageDirection);
188 m_MovingImageWarper->SetInput( this->GetMovingImage() );
189 m_MovingImageWarper->SetDisplacementField( this->GetDisplacementField() );
190 m_MovingImageWarper->GetOutput()->SetRequestedRegion( this->GetDisplacementField()->GetRequestedRegion() );
191 m_MovingImageWarper->Update();
192 this->m_MovingImageWarperOutput =
193 this->m_MovingImageWarper->GetOutput();
194 // setup moving image interpolator for further access
195 m_MovingImageInterpolator->SetInputImage( this->GetMovingImage() );
196
197 // initialize metric computation variables
198 m_SumOfSquaredDifference = 0.0;
199 m_NumberOfPixelsProcessed = 0L;
200 m_SumOfSquaredChange = 0.0;
201 }
202
203 /**
204 * Compute update at a non boundary neighbourhood
205 */
206 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
207 typename ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
208 ::PixelType
209 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
ComputeUpdate(const NeighborhoodType & it,void * gd,const FloatOffsetType & itkNotUsed (offset))210 ::ComputeUpdate( const NeighborhoodType & it, void *gd,
211 const FloatOffsetType & itkNotUsed(offset) )
212 {
213 auto * globalData = (GlobalDataStruct *)gd;
214 PixelType update;
215 IndexType FirstIndex = this->GetFixedImage()->GetLargestPossibleRegion().GetIndex();
216 IndexType LastIndex = this->GetFixedImage()->GetLargestPossibleRegion().GetIndex()
217 + this->GetFixedImage()->GetLargestPossibleRegion().GetSize();
218
219 const IndexType index = it.GetIndex();
220
221 // Get fixed image related information
222 // Note: no need to check if the index is within
223 // fixed image buffer. This is done by the external filter.
224 const auto fixedValue = static_cast< double >( this->GetFixedImage()->GetPixel(index) );
225
226 // Get moving image related information
227 // check if the point was mapped outside of the moving image using
228 // the "special value" NumericTraits<MovingPixelType>::max()
229 MovingPixelType movingPixValue =
230 m_MovingImageWarperOutput->GetPixel(index);
231
232 if ( movingPixValue == NumericTraits< MovingPixelType >::max() )
233 {
234 update.Fill(0.0);
235 return update;
236 }
237
238 const auto movingValue = static_cast< double >( movingPixValue );
239
240 // We compute the gradient more or less by hand.
241 // We first start by ignoring the image orientation and introduce it
242 // afterwards
243 CovariantVectorType usedOrientFreeGradientTimes2;
244
245 if ( ( this->m_UseGradientType == Symmetric )
246 || ( this->m_UseGradientType == WarpedMoving ) )
247 {
248 // we don't use a CentralDifferenceImageFunction here to be able to
249 // check for NumericTraits<MovingPixelType>::max()
250 CovariantVectorType warpedMovingGradient;
251 IndexType tmpIndex = index;
252 for ( unsigned int dim = 0; dim < ImageDimension; dim++ )
253 {
254 // bounds checking
255 if ( FirstIndex[dim] == LastIndex[dim]
256 || index[dim] < FirstIndex[dim]
257 || index[dim] >= LastIndex[dim] )
258 {
259 warpedMovingGradient[dim] = 0.0;
260 continue;
261 }
262 else if ( index[dim] == FirstIndex[dim] )
263 {
264 // compute derivative
265 tmpIndex[dim] += 1;
266 movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex);
267 if ( movingPixValue == NumericTraits< MovingPixelType >::max() )
268 {
269 // weird crunched border case
270 warpedMovingGradient[dim] = 0.0;
271 }
272 else
273 {
274 // forward difference
275 warpedMovingGradient[dim] = static_cast< double >( movingPixValue ) - movingValue;
276 warpedMovingGradient[dim] /= m_FixedImageSpacing[dim];
277 }
278 tmpIndex[dim] -= 1;
279 continue;
280 }
281 else if ( index[dim] == ( LastIndex[dim] - 1 ) )
282 {
283 // compute derivative
284 tmpIndex[dim] -= 1;
285 movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex);
286 if ( movingPixValue == NumericTraits< MovingPixelType >::max() )
287 {
288 // weird crunched border case
289 warpedMovingGradient[dim] = 0.0;
290 }
291 else
292 {
293 // backward difference
294 warpedMovingGradient[dim] = movingValue - static_cast< double >( movingPixValue );
295 warpedMovingGradient[dim] /= m_FixedImageSpacing[dim];
296 }
297 tmpIndex[dim] += 1;
298 continue;
299 }
300
301 // compute derivative
302 tmpIndex[dim] += 1;
303 movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex);
304 if ( movingPixValue == NumericTraits
305 < MovingPixelType >::max() )
306 {
307 // backward difference
308 warpedMovingGradient[dim] = movingValue;
309
310 tmpIndex[dim] -= 2;
311 movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex);
312 if ( movingPixValue == NumericTraits< MovingPixelType >::max() )
313 {
314 // weird crunched border case
315 warpedMovingGradient[dim] = 0.0;
316 }
317 else
318 {
319 // backward difference
320 warpedMovingGradient[dim] -= static_cast< double >(
321 m_MovingImageWarperOutput->GetPixel(tmpIndex) );
322
323 warpedMovingGradient[dim] /= m_FixedImageSpacing[dim];
324 }
325 }
326 else
327 {
328 warpedMovingGradient[dim] = static_cast< double >( movingPixValue );
329
330 tmpIndex[dim] -= 2;
331 movingPixValue = m_MovingImageWarperOutput->GetPixel(tmpIndex);
332 if ( movingPixValue == NumericTraits< MovingPixelType >::max() )
333 {
334 // forward difference
335 warpedMovingGradient[dim] -= movingValue;
336 warpedMovingGradient[dim] /= m_FixedImageSpacing[dim];
337 }
338 else
339 {
340 // normal case, central difference
341 warpedMovingGradient[dim] -= static_cast< double >( movingPixValue );
342 warpedMovingGradient[dim] *= 0.5 / m_FixedImageSpacing[dim];
343 }
344 }
345 tmpIndex[dim] += 1;
346 }
347
348 if ( this->m_UseGradientType == Symmetric )
349 {
350 // Compute orientation-free gradient with calculator
351 const CovariantVectorType fixedGradient =
352 m_FixedImageGradientCalculator->EvaluateAtIndex(index);
353
354 usedOrientFreeGradientTimes2 = fixedGradient + warpedMovingGradient;
355 }
356 else if ( this->m_UseGradientType == WarpedMoving )
357 {
358 usedOrientFreeGradientTimes2 = warpedMovingGradient + warpedMovingGradient;
359 }
360 else
361 {
362 itkExceptionMacro(<< "Unknown gradient type");
363 }
364 }
365 else if ( this->m_UseGradientType == Fixed )
366 {
367 // Compute orientation-free gradient with calculator
368 const CovariantVectorType fixedGradient =
369 m_FixedImageGradientCalculator->EvaluateAtIndex(index);
370
371 usedOrientFreeGradientTimes2 = fixedGradient + fixedGradient;
372 }
373 else if ( this->m_UseGradientType == MappedMoving )
374 {
375 PointType mappedPoint;
376 this->GetFixedImage()->TransformIndexToPhysicalPoint(index, mappedPoint);
377 for ( unsigned int j = 0; j < ImageDimension; j++ )
378 {
379 mappedPoint[j] += it.GetCenterPixel()[j];
380 }
381
382 const CovariantVectorType mappedMovingGradient =
383 m_MappedMovingImageGradientCalculator->Evaluate(mappedPoint);
384
385 usedOrientFreeGradientTimes2 = mappedMovingGradient + mappedMovingGradient;
386 }
387 else
388 {
389 itkExceptionMacro(<< "Unknown gradient type");
390 }
391
392 CovariantVectorType usedGradientTimes2;
393 this->GetFixedImage()->TransformLocalVectorToPhysicalVector(
394 usedOrientFreeGradientTimes2, usedGradientTimes2);
395
396 /**
397 * Compute Update.
398 * We avoid the mismatch in units between the two terms.
399 * and avoid large step using a normalization term.
400 */
401
402 const double usedGradientTimes2SquaredMagnitude =
403 usedGradientTimes2.GetSquaredNorm();
404
405 const double speedValue = fixedValue - movingValue;
406 if ( itk::Math::abs(speedValue) < m_IntensityDifferenceThreshold )
407 {
408 update.Fill(0.0);
409 }
410 else
411 {
412 double denom;
413 if ( m_Normalizer > 0.0 )
414 {
415 // "ITK-Thirion" normalization
416 denom = usedGradientTimes2SquaredMagnitude + ( itk::Math::sqr(speedValue) / m_Normalizer );
417 }
418 else
419 {
420 // least square solution of the system
421 denom = usedGradientTimes2SquaredMagnitude;
422 }
423
424 if ( denom < m_DenominatorThreshold )
425 {
426 update.Fill(0.0);
427 }
428 else
429 {
430 const double factor = 2.0 * speedValue / denom;
431
432 for ( unsigned int j = 0; j < ImageDimension; j++ )
433 {
434 update[j] = factor * usedGradientTimes2[j];
435 }
436 }
437 }
438
439 // WARNING!! We compute the global data without taking into account the
440 // current update step.
441 // There are several reasons for that: If an exponential, a smoothing or any
442 // other operation
443 // is applied on the update field, we cannot compute the newMappedCenterPoint
444 // here; and even
445 // if we could, this would be an often unnecessary time-consuming task.
446 if ( globalData )
447 {
448 globalData->m_SumOfSquaredDifference += itk::Math::sqr(speedValue);
449 globalData->m_NumberOfPixelsProcessed += 1;
450 globalData->m_SumOfSquaredChange += update.GetSquaredNorm();
451 }
452
453 return update;
454 }
455
456 /**
457 * Update the metric and release the per-thread-global data.
458 */
459 template< typename TFixedImage, typename TMovingImage, typename TDisplacementField >
460 void
461 ESMDemonsRegistrationFunction< TFixedImage, TMovingImage, TDisplacementField >
ReleaseGlobalDataPointer(void * gd) const462 ::ReleaseGlobalDataPointer(void *gd) const
463 {
464 auto * globalData = (GlobalDataStruct *)gd;
465
466 m_MetricCalculationLock.lock();
467 m_SumOfSquaredDifference += globalData->m_SumOfSquaredDifference;
468 m_NumberOfPixelsProcessed += globalData->m_NumberOfPixelsProcessed;
469 m_SumOfSquaredChange += globalData->m_SumOfSquaredChange;
470 if ( m_NumberOfPixelsProcessed )
471 {
472 m_Metric = m_SumOfSquaredDifference
473 / static_cast< double >( m_NumberOfPixelsProcessed );
474 m_RMSChange = std::sqrt( m_SumOfSquaredChange
475 / static_cast< double >( m_NumberOfPixelsProcessed ) );
476 }
477 m_MetricCalculationLock.unlock();
478
479 delete globalData;
480 }
481 } // end namespace itk
482
483 #endif
484