1  /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkCorrelationImageToImageMetricv4GetValueAndDerivativeThreader_hxx
19 #define itkCorrelationImageToImageMetricv4GetValueAndDerivativeThreader_hxx
20 
21 #include "itkCorrelationImageToImageMetricv4GetValueAndDerivativeThreader.h"
22 
23 namespace itk
24 {
25 
26 template<typename TDomainPartitioner, typename TImageToImageMetric, typename TCorrelationMetric>
27 CorrelationImageToImageMetricv4GetValueAndDerivativeThreader< TDomainPartitioner, TImageToImageMetric, TCorrelationMetric>
CorrelationImageToImageMetricv4GetValueAndDerivativeThreader()28 ::CorrelationImageToImageMetricv4GetValueAndDerivativeThreader() :
29   m_CorrelationMetricValueDerivativePerThreadVariables( nullptr ),
30   m_CorrelationAssociate( nullptr )
31 {}
32 
33 
34 template<typename TDomainPartitioner, typename TImageToImageMetric, typename TCorrelationMetric>
35 CorrelationImageToImageMetricv4GetValueAndDerivativeThreader< TDomainPartitioner, TImageToImageMetric, TCorrelationMetric>
~CorrelationImageToImageMetricv4GetValueAndDerivativeThreader()36 ::~CorrelationImageToImageMetricv4GetValueAndDerivativeThreader()
37 {
38   delete[] m_CorrelationMetricValueDerivativePerThreadVariables;
39 }
40 
41 
42 template<typename TDomainPartitioner, typename TImageToImageMetric, typename TCorrelationMetric>
43 void
44 CorrelationImageToImageMetricv4GetValueAndDerivativeThreader< TDomainPartitioner, TImageToImageMetric, TCorrelationMetric>
BeforeThreadedExecution()45 ::BeforeThreadedExecution()
46 {
47   Superclass::BeforeThreadedExecution();
48 
49   /* Store the casted pointer to avoid dynamic casting in tight loops. */
50   this->m_CorrelationAssociate = dynamic_cast<TCorrelationMetric *>(this->m_Associate);
51   if( this->m_CorrelationAssociate == nullptr )
52     {
53     itkExceptionMacro("Dynamic casting of associate pointer failed.");
54     }
55 
56   /* This size always comes from the moving image */
57   const NumberOfParametersType globalDerivativeSize = this->GetCachedNumberOfParameters();
58 
59   const ThreadIdType numThreadsUsed = this->GetNumberOfWorkUnitsUsed();
60   // set size
61   delete[] m_CorrelationMetricValueDerivativePerThreadVariables;
62   m_CorrelationMetricValueDerivativePerThreadVariables = new AlignedCorrelationMetricValueDerivativePerThreadStruct[ numThreadsUsed ];
63   for (ThreadIdType i = 0; i < numThreadsUsed; ++i)
64     {
65     this->m_CorrelationMetricValueDerivativePerThreadVariables[i].fdm.SetSize(globalDerivativeSize);
66     this->m_CorrelationMetricValueDerivativePerThreadVariables[i].mdm.SetSize(globalDerivativeSize);
67     }
68 
69   //---------------------------------------------------------------
70   // Set initial values.
71   for (ThreadIdType i = 0; i < numThreadsUsed; ++i)
72     {
73     m_CorrelationMetricValueDerivativePerThreadVariables[i].fm = NumericTraits<InternalComputationValueType>::ZeroValue();
74     m_CorrelationMetricValueDerivativePerThreadVariables[i].f2 = NumericTraits<InternalComputationValueType>::ZeroValue();
75     m_CorrelationMetricValueDerivativePerThreadVariables[i].m2 = NumericTraits<InternalComputationValueType>::ZeroValue();
76     m_CorrelationMetricValueDerivativePerThreadVariables[i].f = NumericTraits<InternalComputationValueType>::ZeroValue();
77     m_CorrelationMetricValueDerivativePerThreadVariables[i].m = NumericTraits<InternalComputationValueType>::ZeroValue();
78 
79     this->m_CorrelationMetricValueDerivativePerThreadVariables[i].mdm.Fill(NumericTraits<DerivativeValueType>::ZeroValue());
80     this->m_CorrelationMetricValueDerivativePerThreadVariables[i].fdm.Fill(NumericTraits<DerivativeValueType>::ZeroValue());
81     }
82 
83 }
84 
85 template<typename TDomainPartitioner, typename TImageToImageMetric, typename TCorrelationMetric>
86 void
87 CorrelationImageToImageMetricv4GetValueAndDerivativeThreader<TDomainPartitioner, TImageToImageMetric, TCorrelationMetric>
AfterThreadedExecution()88 ::AfterThreadedExecution()
89 {
90 
91   /* This size always comes from the moving image */
92   const NumberOfParametersType globalDerivativeSize = this->GetCachedNumberOfParameters();
93   const ThreadIdType numThreadsUsed = this->GetNumberOfWorkUnitsUsed();
94 
95   /* Store the number of valid points the enclosing class \c
96    * m_NumberOfValidPoints by collecting the valid points per thread. */
97   this->m_CorrelationAssociate->m_NumberOfValidPoints = NumericTraits<SizeValueType>::ZeroValue();
98   for (ThreadIdType i = 0; i < numThreadsUsed; ++i)
99     {
100     this->m_CorrelationAssociate->m_NumberOfValidPoints += this->m_GetValueAndDerivativePerThreadVariables[i].NumberOfValidPoints;
101     }
102 
103   /* Check the number of valid points meets the default minimum.
104    * If not, parameters will hold default return values for this case */
105   if( ! this->m_CorrelationAssociate->VerifyNumberOfValidPoints( this->m_CorrelationAssociate->m_Value, *(this->m_CorrelationAssociate->m_DerivativeResult) ) )
106     {
107     return;
108     }
109 
110   itkDebugMacro("CorrelationImageToImageMetricv4: NumberOfValidPoints: " << this->m_CorrelationAssociate->m_NumberOfValidPoints);
111 
112   /* Accumulate the metric value from threads and store */
113   this->m_CorrelationAssociate->m_Value = NumericTraits<InternalComputationValueType>::ZeroValue();
114   InternalComputationValueType fm = NumericTraits<InternalComputationValueType>::ZeroValue();
115   InternalComputationValueType f2 = NumericTraits<InternalComputationValueType>::ZeroValue();
116   InternalComputationValueType m2 = NumericTraits<InternalComputationValueType>::ZeroValue();
117   for (ThreadIdType threadId = 0; threadId < numThreadsUsed; ++threadId)
118     {
119     fm += this->m_CorrelationMetricValueDerivativePerThreadVariables[threadId].fm;
120     m2 += this->m_CorrelationMetricValueDerivativePerThreadVariables[threadId].m2;
121     f2 += this->m_CorrelationMetricValueDerivativePerThreadVariables[threadId].f2;
122     }
123 
124   InternalComputationValueType m2f2 = m2 * f2;
125   if ( m2f2 <= NumericTraits<InternalComputationValueType>::epsilon() )
126     {
127     itkDebugMacro( "CorrelationImageToImageMetricv4: m2 * f2 <= epsilon");
128     return;
129     }
130 
131   this->m_CorrelationAssociate->m_Value = -1.0 * fm * fm / (m2f2);
132 
133   /* For global transforms, compute the derivatives by combining values from each region. */
134   if( this->m_CorrelationAssociate->GetComputeDerivative() )
135     {
136     DerivativeType fdm, mdm;
137     fdm.SetSize(globalDerivativeSize);
138     mdm.SetSize(globalDerivativeSize);
139 
140     fdm.Fill(NumericTraits<DerivativeValueType>::ZeroValue());
141     mdm.Fill(NumericTraits<DerivativeValueType>::ZeroValue());
142 
143     const auto fc = static_cast<InternalComputationValueType>( 2.0 );
144 
145     for (ThreadIdType i = 0; i < numThreadsUsed; ++i)
146       {
147       fdm += this->m_CorrelationMetricValueDerivativePerThreadVariables[i].fdm;
148       mdm += this->m_CorrelationMetricValueDerivativePerThreadVariables[i].mdm;
149       }
150 
151     /** There should be a minus sign of \frac{d}{dp} mathematically, which
152      *  is not in the implementation to match the requirement of the metricv4
153      *  optimization framework.
154      *
155      *  We use += instead of assignment here because for multi-variate vector,
156      *  we will want to always add to the values in m_DerivativeResult so they
157      *  can be efficiently accumulated between multiple metrics.
158      */
159     *(this->m_CorrelationAssociate->m_DerivativeResult) += fc *fm/(f2*m2)*(fdm - fm/m2*mdm);
160     }
161 
162 }
163 
164 template<typename TDomainPartitioner, typename TImageToImageMetric, typename TCorrelationMetric>
165 bool
166 CorrelationImageToImageMetricv4GetValueAndDerivativeThreader<TDomainPartitioner, TImageToImageMetric, TCorrelationMetric>
ProcessVirtualPoint(const VirtualIndexType & virtualIndex,const VirtualPointType & virtualPoint,const ThreadIdType threadId)167 ::ProcessVirtualPoint( const VirtualIndexType & virtualIndex, const VirtualPointType & virtualPoint, const ThreadIdType threadId )
168 {
169   FixedImagePointType         mappedFixedPoint;
170   FixedImagePixelType         mappedFixedPixelValue;
171   FixedImageGradientType      mappedFixedImageGradient;
172   MovingImagePointType        mappedMovingPoint;
173   MovingImagePixelType        mappedMovingPixelValue;
174   MovingImageGradientType     mappedMovingImageGradient;
175   bool                        pointIsValid = false;
176   MeasureType                 metricValueResult;
177 
178   /* Transform the point into fixed and moving spaces, and evaluate.
179    * Different behavior with pre-warping enabled is handled transparently.
180    * Do this in a try block to catch exceptions and print more useful info
181    * then we otherwise get when exceptions are caught in MultiThreaderBase. */
182   try
183     {
184     pointIsValid = this->m_CorrelationAssociate->TransformAndEvaluateFixedPoint( virtualPoint, mappedFixedPoint, mappedFixedPixelValue );
185     if( pointIsValid &&
186         this->m_CorrelationAssociate->GetComputeDerivative() &&
187         this->m_CorrelationAssociate->GetGradientSourceIncludesFixed() )
188       {
189       this->m_CorrelationAssociate->ComputeFixedImageGradientAtPoint( mappedFixedPoint, mappedFixedImageGradient );
190       }
191     }
192   catch( ExceptionObject & exc )
193     {
194     //NOTE: there must be a cleaner way to do this:
195     std::string msg("Caught exception: \n");
196     msg += exc.what();
197     ExceptionObject err(__FILE__, __LINE__, msg);
198     throw err;
199     }
200   if( !pointIsValid )
201     {
202     return pointIsValid;
203     }
204 
205   try
206     {
207     pointIsValid = this->m_CorrelationAssociate->TransformAndEvaluateMovingPoint( virtualPoint, mappedMovingPoint, mappedMovingPixelValue );
208     if( pointIsValid &&
209         this->m_CorrelationAssociate->GetComputeDerivative() &&
210         this->m_CorrelationAssociate->GetGradientSourceIncludesMoving() )
211       {
212       this->m_CorrelationAssociate->ComputeMovingImageGradientAtPoint( mappedMovingPoint, mappedMovingImageGradient );
213       }
214     }
215   catch( ExceptionObject & exc )
216     {
217     std::string msg("Caught exception: \n");
218     msg += exc.what();
219     ExceptionObject err(__FILE__, __LINE__, msg);
220     throw err;
221     }
222   if( !pointIsValid )
223     {
224     return pointIsValid;
225     }
226 
227   /* Call the user method in derived classes to do the specific
228    * calculations for value and derivative. */
229   try
230     {
231     pointIsValid = this->ProcessPoint(
232                                    virtualIndex,
233                                    virtualPoint,
234                                    mappedFixedPoint, mappedFixedPixelValue,
235                                    mappedFixedImageGradient,
236                                    mappedMovingPoint, mappedMovingPixelValue,
237                                    mappedMovingImageGradient,
238                                    metricValueResult, this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives,
239                                    threadId );
240     }
241   catch( ExceptionObject & exc )
242     {
243     //NOTE: there must be a cleaner way to do this:
244     std::string msg("Exception in GetValueAndDerivativeProcessPoint:\n");
245     msg += exc.what();
246     ExceptionObject err(__FILE__, __LINE__, msg);
247     throw err;
248     }
249   if( pointIsValid )
250     {
251     this->m_GetValueAndDerivativePerThreadVariables[threadId].NumberOfValidPoints++;
252     }
253 
254   return pointIsValid;
255 }
256 
257 template<typename TDomainPartitioner, typename TImageToImageMetric, typename TCorrelationMetric>
258 bool
259 CorrelationImageToImageMetricv4GetValueAndDerivativeThreader<TDomainPartitioner, TImageToImageMetric, TCorrelationMetric>
ProcessPoint(const VirtualIndexType & itkNotUsed (virtualIndex),const VirtualPointType & virtualPoint,const FixedImagePointType & itkNotUsed (mappedFixedPoint),const FixedImagePixelType & fixedImageValue,const FixedImageGradientType & itkNotUsed (mappedFixedImageGradient),const MovingImagePointType & itkNotUsed (mappedMovingPoint),const MovingImagePixelType & movingImageValue,const MovingImageGradientType & movingImageGradient,MeasureType & itkNotUsed (metricValueReturn),DerivativeType & itkNotUsed (localDerivativeReturn),const ThreadIdType threadId) const260 ::ProcessPoint( const VirtualIndexType &           itkNotUsed(virtualIndex),
261                 const VirtualPointType &           virtualPoint,
262                 const FixedImagePointType &        itkNotUsed(mappedFixedPoint),
263                 const FixedImagePixelType &        fixedImageValue,
264                 const FixedImageGradientType &     itkNotUsed(mappedFixedImageGradient),
265                 const MovingImagePointType &       itkNotUsed(mappedMovingPoint),
266                 const MovingImagePixelType &       movingImageValue,
267                 const MovingImageGradientType &    movingImageGradient,
268                 MeasureType &                      itkNotUsed(metricValueReturn),
269                 DerivativeType &                   itkNotUsed(localDerivativeReturn),
270                 const ThreadIdType                 threadId) const
271 {
272 
273   /*
274    * metricValueReturn and localDerivativeReturn will not be computed here.
275    * Instead, m_CorrelationMetricValueDerivativePerThreadVariables will store temporary results for each thread
276    * and finally compute metric and derivative in overloaded AfterThreadedExecution
277    */
278 
279   /* subtract the average of pixels (computed during InitializeIteration) */
280   const InternalComputationValueType & f1 = fixedImageValue - this->m_CorrelationAssociate->m_AverageFix;
281   const InternalComputationValueType & m1 = movingImageValue - this->m_CorrelationAssociate->m_AverageMov;
282 
283   AlignedCorrelationMetricValueDerivativePerThreadStruct & cumsum = this->m_CorrelationMetricValueDerivativePerThreadVariables[threadId];
284   cumsum.f += f1;
285   cumsum.m += m1;
286   cumsum.f2 += f1 * f1;
287   cumsum.m2 += m1 * m1;
288   cumsum.fm += f1 * m1;
289 
290   if( this->m_CorrelationAssociate->GetComputeDerivative() )
291     {
292     /* Use a pre-allocated jacobian object for efficiency */
293     using JacobianReferenceType = typename TImageToImageMetric::JacobianType &;
294     JacobianReferenceType jacobian = this->m_GetValueAndDerivativePerThreadVariables[threadId].MovingTransformJacobian;
295     JacobianReferenceType jacobianPositional = this->m_GetValueAndDerivativePerThreadVariables[threadId].MovingTransformJacobianPositional;
296 
297     /** For dense transforms, this returns identity */
298     this->m_CorrelationAssociate->GetMovingTransform()->
299       ComputeJacobianWithRespectToParametersCachedTemporaries(virtualPoint,
300                                                               jacobian,
301                                                               jacobianPositional);
302 
303     for (unsigned int par = 0; par < this->m_CorrelationAssociate->GetNumberOfLocalParameters(); par++)
304       {
305       InternalComputationValueType sum = NumericTraits< InternalComputationValueType >::ZeroValue();
306       for (SizeValueType dim = 0; dim < ImageToImageMetricv4Type::MovingImageDimension; dim++)
307         {
308         sum += movingImageGradient[dim] * jacobian(dim, par);
309         }
310 
311       cumsum.fdm[par] += f1 * sum;
312       cumsum.mdm[par] += m1 * sum;
313       }
314     }
315 
316   return true;
317 }
318 
319 } // end namespace itk
320 
321 #endif
322