1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkImageToImageMetricv4GetValueAndDerivativeThreaderBase_hxx
19 #define itkImageToImageMetricv4GetValueAndDerivativeThreaderBase_hxx
20 
#include "itkImageToImageMetricv4GetValueAndDerivativeThreaderBase.h"
#include "itkNumericTraits.h"

#include <cmath>
23 
24 namespace itk
25 {
26 
27 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
28 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
ImageToImageMetricv4GetValueAndDerivativeThreaderBase()29 ::ImageToImageMetricv4GetValueAndDerivativeThreaderBase():
30   m_GetValueAndDerivativePerThreadVariables( nullptr ),
31   m_CachedNumberOfParameters( 0 ),
32   m_CachedNumberOfLocalParameters( 0 )
33 {
34 }
35 
36 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
37 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
~ImageToImageMetricv4GetValueAndDerivativeThreaderBase()38 ::~ImageToImageMetricv4GetValueAndDerivativeThreaderBase()
39 {
40   delete[] m_GetValueAndDerivativePerThreadVariables;
41 }
42 
43 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
44 void
45 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
BeforeThreadedExecution()46 ::BeforeThreadedExecution()
47 {
48   //---------------------------------------------------------------
49   // Resize the per thread memory objects.
50   //-----------------------------------------------------------------
51   // Cache some values
52   this->m_CachedNumberOfParameters      = this->m_Associate->GetNumberOfParameters();
53   this->m_CachedNumberOfLocalParameters = this->m_Associate->GetNumberOfLocalParameters();
54 
55   /* Per-thread results */
56   const ThreadIdType numThreadsUsed = this->GetNumberOfWorkUnitsUsed();
57   delete[] m_GetValueAndDerivativePerThreadVariables;
58   this->m_GetValueAndDerivativePerThreadVariables = new AlignedGetValueAndDerivativePerThreadStruct[ numThreadsUsed ];
59 
60   if( this->m_Associate->GetComputeDerivative() )
61     {
62     for (ThreadIdType i = 0; i < numThreadsUsed; ++i)
63       {
64       /* Allocate intermediary per-thread storage used to get results from
65        * derived classes */
66       this->m_GetValueAndDerivativePerThreadVariables[i].LocalDerivatives.SetSize( this->m_CachedNumberOfLocalParameters );
67       this->m_GetValueAndDerivativePerThreadVariables[i].MovingTransformJacobian.SetSize(
68         this->m_Associate->VirtualImageDimension, this->m_CachedNumberOfLocalParameters );
69       // Not pre-allocated since it may not be used
70       //this->m_GetValueAndDerivativePerThreadVariables[i].MovingTransformJacobianPositional
71       if ( this->m_Associate->m_MovingTransform->GetTransformCategory() == MovingTransformType::DisplacementField )
72         {
73         /* For transforms with local support, e.g. displacement field,
74          * use a single derivative container that's updated by region
75          * in multiple threads.
76          * Initialization to zero is done in main class. */
77         itkDebugMacro( "ImageToImageMetricv4::Initialize: transform HAS local support\n" );
78         /* Set each per-thread object to point to m_DerivativeResult for efficiency. */
79         this->m_GetValueAndDerivativePerThreadVariables[i].Derivatives.SetData( this->m_Associate->m_DerivativeResult->data_block(),
80           this->m_Associate->m_DerivativeResult->Size(), false );
81         }
82       else
83         {
84         itkDebugMacro("ImageToImageMetricv4::Initialize: transform does NOT have local support\n");
85         /* This size always comes from the moving image */
86         const NumberOfParametersType globalDerivativeSize = this->m_CachedNumberOfParameters;
87         /* Global transforms get a separate derivatives container for each thread
88          * that holds the result over a particular image region.
89          * Use a CompensatedSummation value to provide for better consistency between
90          * different number of threads. */
91         this->m_GetValueAndDerivativePerThreadVariables[i].CompensatedDerivatives.resize( globalDerivativeSize );
92         }
93       }
94     }
95 
96   //---------------------------------------------------------------
97   // Set initial values.
98   for (ThreadIdType thread = 0; thread < numThreadsUsed; ++thread)
99     {
100     this->m_GetValueAndDerivativePerThreadVariables[thread].NumberOfValidPoints = NumericTraits< SizeValueType >::ZeroValue();
101     this->m_GetValueAndDerivativePerThreadVariables[thread].Measure = NumericTraits< InternalComputationValueType >::ZeroValue();
102     if( this->m_Associate->GetComputeDerivative() )
103       {
104       if ( this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::DisplacementField )
105         {
106         /* Be sure to init to 0 here, because the threader may not use
107          * all the threads if the region is better split into fewer
108          * subregions. */
109         for( NumberOfParametersType p = 0; p < this->m_CachedNumberOfParameters; p++ )
110           {
111           this->m_GetValueAndDerivativePerThreadVariables[thread].CompensatedDerivatives[p].ResetToZero();
112           }
113         }
114       }
115     }
116 }
117 
118 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
119 void
120 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
AfterThreadedExecution()121 ::AfterThreadedExecution()
122 {
123   const ThreadIdType numThreadsUsed = this->GetNumberOfWorkUnitsUsed();
124   /* Store the number of valid points the enclosing class \c
125    * m_NumberOfValidPoints by collecting the valid points per thread. */
126   this->m_Associate->m_NumberOfValidPoints = NumericTraits< SizeValueType >::ZeroValue();
127   for (ThreadIdType i = 0; i < numThreadsUsed; ++i)
128     {
129     this->m_Associate->m_NumberOfValidPoints += this->m_GetValueAndDerivativePerThreadVariables[i].NumberOfValidPoints;
130     }
131   itkDebugMacro( "ImageToImageMetricv4: NumberOfValidPoints: " << this->m_Associate->m_NumberOfValidPoints );
132 
133   /* For global transforms, sum the derivatives from each region. */
134   if( this->m_Associate->GetComputeDerivative() )
135     {
136     if ( this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::DisplacementField )
137       {
138       for (NumberOfParametersType p = 0; p < this->m_Associate->GetNumberOfParameters(); p++ )
139         {
140         /* Use a compensated sum to be ready for when there is a very large number of threads */
141         CompensatedDerivativeValueType sum;
142         sum.ResetToZero();
143         for (ThreadIdType i=0; i<numThreadsUsed; i++)
144           {
145           sum += this->m_GetValueAndDerivativePerThreadVariables[i].CompensatedDerivatives[p].GetSum();
146           }
147         (*(this->m_Associate->m_DerivativeResult))[p] += sum.GetSum();
148         }
149       }
150     }
151 
152   /* Check the number of valid points. If there aren't enough,
153    * m_Value and m_DerivativeResult will get appropriate values assigned,
154    * and a warning will be output. */
155   if( this->m_Associate->VerifyNumberOfValidPoints( this->m_Associate->m_Value, *(this->m_Associate->m_DerivativeResult) ) )
156     {
157     this->m_Associate->m_Value = NumericTraits<MeasureType>::ZeroValue();
158     /* Accumulate the metric value from threads and store the average. */
159     for(ThreadIdType threadId = 0; threadId < numThreadsUsed; ++threadId )
160       {
161       this->m_Associate->m_Value += this->m_GetValueAndDerivativePerThreadVariables[threadId].Measure;
162       }
163     this->m_Associate->m_Value /= this->m_Associate->m_NumberOfValidPoints;
164 
165     /* For global transforms, calculate the average values */
166     if( this->m_Associate->GetComputeDerivative() )
167       {
168       if ( this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::DisplacementField )
169         {
170         *(this->m_Associate->m_DerivativeResult) /= this->m_Associate->m_NumberOfValidPoints;
171         }
172       }
173     }
174 }
175 
176 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
177 bool
178 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
ProcessVirtualPoint(const VirtualIndexType & virtualIndex,const VirtualPointType & virtualPoint,const ThreadIdType threadId)179 ::ProcessVirtualPoint( const VirtualIndexType & virtualIndex,
180                        const VirtualPointType & virtualPoint,
181                        const ThreadIdType threadId )
182 {
183   FixedImagePointType         mappedFixedPoint;
184   FixedImagePixelType         mappedFixedPixelValue;
185   FixedImageGradientType      mappedFixedImageGradient;
186   MovingImagePointType        mappedMovingPoint;
187   MovingImagePixelType        mappedMovingPixelValue;
188   MovingImageGradientType     mappedMovingImageGradient;
189   bool                        pointIsValid = false;
190   MeasureType                 metricValueResult;
191 
192   /* Transform the point into fixed and moving spaces, and evaluate.
193    * Do this in a try block to catch exceptions and print more useful info
194    * then we otherwise get when exceptions are caught in MultiThreaderBase. */
195   try
196     {
197     pointIsValid = this->m_Associate->TransformAndEvaluateFixedPoint( virtualPoint, mappedFixedPoint, mappedFixedPixelValue);
198     if( pointIsValid &&
199         this->m_Associate->GetComputeDerivative() &&
200         this->m_Associate->GetGradientSourceIncludesFixed() )
201       {
202       this->m_Associate->ComputeFixedImageGradientAtPoint( mappedFixedPoint, mappedFixedImageGradient );
203       }
204     }
205   catch( ExceptionObject & exc )
206     {
207     //NOTE: there must be a cleaner way to do this:
208     std::string msg("Caught exception: \n");
209     msg += exc.what();
210     ExceptionObject err(__FILE__, __LINE__, msg);
211     throw err;
212     }
213   if( !pointIsValid )
214     {
215     return pointIsValid;
216     }
217 
218   try
219     {
220     pointIsValid = this->m_Associate->TransformAndEvaluateMovingPoint( virtualPoint, mappedMovingPoint, mappedMovingPixelValue );
221     if( pointIsValid &&
222         this->m_Associate->GetComputeDerivative() &&
223         this->m_Associate->GetGradientSourceIncludesMoving() )
224       {
225       this->m_Associate->ComputeMovingImageGradientAtPoint( mappedMovingPoint, mappedMovingImageGradient );
226       }
227     }
228   catch( ExceptionObject & exc )
229     {
230     std::string msg("Caught exception: \n");
231     msg += exc.what();
232     ExceptionObject err(__FILE__, __LINE__, msg);
233     throw err;
234     }
235   if( !pointIsValid )
236     {
237     return pointIsValid;
238     }
239 
240   /* Call the user method in derived classes to do the specific
241    * calculations for value and derivative. */
242   try
243     {
244     pointIsValid = this->ProcessPoint(
245                                    virtualIndex,
246                                    virtualPoint,
247                                    mappedFixedPoint, mappedFixedPixelValue,
248                                    mappedFixedImageGradient,
249                                    mappedMovingPoint, mappedMovingPixelValue,
250                                    mappedMovingImageGradient,
251                                    metricValueResult,
252                                    this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives,
253                                    threadId );
254     }
255   catch( ExceptionObject & exc )
256     {
257     //NOTE: there must be a cleaner way to do this:
258     std::string msg("Exception in GetValueAndDerivativeProcessPoint:\n");
259     msg += exc.what();
260     ExceptionObject err(__FILE__, __LINE__, msg);
261     throw err;
262     }
263   if( pointIsValid )
264     {
265     this->m_GetValueAndDerivativePerThreadVariables[threadId].NumberOfValidPoints++;
266     this->m_GetValueAndDerivativePerThreadVariables[threadId].Measure += metricValueResult;
267     if( this->m_Associate->GetComputeDerivative() )
268       {
269       this->StorePointDerivativeResult( virtualIndex, threadId );
270       }
271     }
272 
273   return pointIsValid;
274 }
275 
276 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
277 void
278 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
StorePointDerivativeResult(const VirtualIndexType & virtualIndex,const ThreadIdType threadId)279 ::StorePointDerivativeResult( const VirtualIndexType & virtualIndex, const ThreadIdType threadId )
280 {
281   if ( this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::DisplacementField )
282     {
283     /* Global support */
284     if ( this->m_Associate->GetUseFloatingPointCorrection() )
285       {
286       DerivativeValueType correctionResolution = this->m_Associate->GetFloatingPointCorrectionResolution();
287       for (NumberOfParametersType p = 0; p < this->m_CachedNumberOfParameters; p++ )
288         {
289         auto test = static_cast< intmax_t >(
290           this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[p] * correctionResolution
291         );
292         this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[p] = static_cast<DerivativeValueType>( test / correctionResolution );
293         }
294       }
295     for (NumberOfParametersType p = 0; p < this->m_CachedNumberOfParameters; p++ )
296       {
297       this->m_GetValueAndDerivativePerThreadVariables[threadId].CompensatedDerivatives[p] += this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[p];
298       }
299     }
300   else
301     {
302     // Update derivative at some index
303     // this requires the moving image displacement field to be
304     // same size as virtual image, and that VirtualImage PixelType
305     // is scalar (which is verified during Metric initialization).
306     try
307       {
308       OffsetValueType offset = this->m_Associate->ComputeParameterOffsetFromVirtualIndex( virtualIndex, this->m_CachedNumberOfLocalParameters );
309       for (NumberOfParametersType i=0; i < this->m_CachedNumberOfLocalParameters; i++)
310         {
311         /* Be sure to *add* here and not assign. Required for proper behavior
312          * with multi-variate metric. */
313         this->m_GetValueAndDerivativePerThreadVariables[threadId].Derivatives[offset+i] += this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[i];
314         }
315       }
316     catch( ExceptionObject & exc )
317       {
318       std::string msg("Caught exception: \n");
319       msg += exc.what();
320       ExceptionObject err(__FILE__, __LINE__, msg);
321       throw err;
322       }
323     }
324 }
325 
326 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
327 bool
328 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
GetComputeDerivative() const329 ::GetComputeDerivative() const
330 {
331   return this->m_Associate->GetComputeDerivative();
332 }
333 
334 } // end namespace itk
335 
336 #endif
337