1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18 #ifndef itkImageToImageMetricv4GetValueAndDerivativeThreaderBase_hxx
19 #define itkImageToImageMetricv4GetValueAndDerivativeThreaderBase_hxx
20
21 #include "itkImageToImageMetricv4GetValueAndDerivativeThreaderBase.h"
22 #include "itkNumericTraits.h"
23
24 namespace itk
25 {
26
27 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
28 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
ImageToImageMetricv4GetValueAndDerivativeThreaderBase()29 ::ImageToImageMetricv4GetValueAndDerivativeThreaderBase():
30 m_GetValueAndDerivativePerThreadVariables( nullptr ),
31 m_CachedNumberOfParameters( 0 ),
32 m_CachedNumberOfLocalParameters( 0 )
33 {
34 }
35
36 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
37 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
~ImageToImageMetricv4GetValueAndDerivativeThreaderBase()38 ::~ImageToImageMetricv4GetValueAndDerivativeThreaderBase()
39 {
40 delete[] m_GetValueAndDerivativePerThreadVariables;
41 }
42
43 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
44 void
45 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
BeforeThreadedExecution()46 ::BeforeThreadedExecution()
47 {
48 //---------------------------------------------------------------
49 // Resize the per thread memory objects.
50 //-----------------------------------------------------------------
51 // Cache some values
52 this->m_CachedNumberOfParameters = this->m_Associate->GetNumberOfParameters();
53 this->m_CachedNumberOfLocalParameters = this->m_Associate->GetNumberOfLocalParameters();
54
55 /* Per-thread results */
56 const ThreadIdType numThreadsUsed = this->GetNumberOfWorkUnitsUsed();
57 delete[] m_GetValueAndDerivativePerThreadVariables;
58 this->m_GetValueAndDerivativePerThreadVariables = new AlignedGetValueAndDerivativePerThreadStruct[ numThreadsUsed ];
59
60 if( this->m_Associate->GetComputeDerivative() )
61 {
62 for (ThreadIdType i = 0; i < numThreadsUsed; ++i)
63 {
64 /* Allocate intermediary per-thread storage used to get results from
65 * derived classes */
66 this->m_GetValueAndDerivativePerThreadVariables[i].LocalDerivatives.SetSize( this->m_CachedNumberOfLocalParameters );
67 this->m_GetValueAndDerivativePerThreadVariables[i].MovingTransformJacobian.SetSize(
68 this->m_Associate->VirtualImageDimension, this->m_CachedNumberOfLocalParameters );
69 // Not pre-allocated since it may not be used
70 //this->m_GetValueAndDerivativePerThreadVariables[i].MovingTransformJacobianPositional
71 if ( this->m_Associate->m_MovingTransform->GetTransformCategory() == MovingTransformType::DisplacementField )
72 {
73 /* For transforms with local support, e.g. displacement field,
74 * use a single derivative container that's updated by region
75 * in multiple threads.
76 * Initialization to zero is done in main class. */
77 itkDebugMacro( "ImageToImageMetricv4::Initialize: transform HAS local support\n" );
78 /* Set each per-thread object to point to m_DerivativeResult for efficiency. */
79 this->m_GetValueAndDerivativePerThreadVariables[i].Derivatives.SetData( this->m_Associate->m_DerivativeResult->data_block(),
80 this->m_Associate->m_DerivativeResult->Size(), false );
81 }
82 else
83 {
84 itkDebugMacro("ImageToImageMetricv4::Initialize: transform does NOT have local support\n");
85 /* This size always comes from the moving image */
86 const NumberOfParametersType globalDerivativeSize = this->m_CachedNumberOfParameters;
87 /* Global transforms get a separate derivatives container for each thread
88 * that holds the result over a particular image region.
89 * Use a CompensatedSummation value to provide for better consistency between
90 * different number of threads. */
91 this->m_GetValueAndDerivativePerThreadVariables[i].CompensatedDerivatives.resize( globalDerivativeSize );
92 }
93 }
94 }
95
96 //---------------------------------------------------------------
97 // Set initial values.
98 for (ThreadIdType thread = 0; thread < numThreadsUsed; ++thread)
99 {
100 this->m_GetValueAndDerivativePerThreadVariables[thread].NumberOfValidPoints = NumericTraits< SizeValueType >::ZeroValue();
101 this->m_GetValueAndDerivativePerThreadVariables[thread].Measure = NumericTraits< InternalComputationValueType >::ZeroValue();
102 if( this->m_Associate->GetComputeDerivative() )
103 {
104 if ( this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::DisplacementField )
105 {
106 /* Be sure to init to 0 here, because the threader may not use
107 * all the threads if the region is better split into fewer
108 * subregions. */
109 for( NumberOfParametersType p = 0; p < this->m_CachedNumberOfParameters; p++ )
110 {
111 this->m_GetValueAndDerivativePerThreadVariables[thread].CompensatedDerivatives[p].ResetToZero();
112 }
113 }
114 }
115 }
116 }
117
118 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
119 void
120 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
AfterThreadedExecution()121 ::AfterThreadedExecution()
122 {
123 const ThreadIdType numThreadsUsed = this->GetNumberOfWorkUnitsUsed();
124 /* Store the number of valid points the enclosing class \c
125 * m_NumberOfValidPoints by collecting the valid points per thread. */
126 this->m_Associate->m_NumberOfValidPoints = NumericTraits< SizeValueType >::ZeroValue();
127 for (ThreadIdType i = 0; i < numThreadsUsed; ++i)
128 {
129 this->m_Associate->m_NumberOfValidPoints += this->m_GetValueAndDerivativePerThreadVariables[i].NumberOfValidPoints;
130 }
131 itkDebugMacro( "ImageToImageMetricv4: NumberOfValidPoints: " << this->m_Associate->m_NumberOfValidPoints );
132
133 /* For global transforms, sum the derivatives from each region. */
134 if( this->m_Associate->GetComputeDerivative() )
135 {
136 if ( this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::DisplacementField )
137 {
138 for (NumberOfParametersType p = 0; p < this->m_Associate->GetNumberOfParameters(); p++ )
139 {
140 /* Use a compensated sum to be ready for when there is a very large number of threads */
141 CompensatedDerivativeValueType sum;
142 sum.ResetToZero();
143 for (ThreadIdType i=0; i<numThreadsUsed; i++)
144 {
145 sum += this->m_GetValueAndDerivativePerThreadVariables[i].CompensatedDerivatives[p].GetSum();
146 }
147 (*(this->m_Associate->m_DerivativeResult))[p] += sum.GetSum();
148 }
149 }
150 }
151
152 /* Check the number of valid points. If there aren't enough,
153 * m_Value and m_DerivativeResult will get appropriate values assigned,
154 * and a warning will be output. */
155 if( this->m_Associate->VerifyNumberOfValidPoints( this->m_Associate->m_Value, *(this->m_Associate->m_DerivativeResult) ) )
156 {
157 this->m_Associate->m_Value = NumericTraits<MeasureType>::ZeroValue();
158 /* Accumulate the metric value from threads and store the average. */
159 for(ThreadIdType threadId = 0; threadId < numThreadsUsed; ++threadId )
160 {
161 this->m_Associate->m_Value += this->m_GetValueAndDerivativePerThreadVariables[threadId].Measure;
162 }
163 this->m_Associate->m_Value /= this->m_Associate->m_NumberOfValidPoints;
164
165 /* For global transforms, calculate the average values */
166 if( this->m_Associate->GetComputeDerivative() )
167 {
168 if ( this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::DisplacementField )
169 {
170 *(this->m_Associate->m_DerivativeResult) /= this->m_Associate->m_NumberOfValidPoints;
171 }
172 }
173 }
174 }
175
176 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
177 bool
178 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
ProcessVirtualPoint(const VirtualIndexType & virtualIndex,const VirtualPointType & virtualPoint,const ThreadIdType threadId)179 ::ProcessVirtualPoint( const VirtualIndexType & virtualIndex,
180 const VirtualPointType & virtualPoint,
181 const ThreadIdType threadId )
182 {
183 FixedImagePointType mappedFixedPoint;
184 FixedImagePixelType mappedFixedPixelValue;
185 FixedImageGradientType mappedFixedImageGradient;
186 MovingImagePointType mappedMovingPoint;
187 MovingImagePixelType mappedMovingPixelValue;
188 MovingImageGradientType mappedMovingImageGradient;
189 bool pointIsValid = false;
190 MeasureType metricValueResult;
191
192 /* Transform the point into fixed and moving spaces, and evaluate.
193 * Do this in a try block to catch exceptions and print more useful info
194 * then we otherwise get when exceptions are caught in MultiThreaderBase. */
195 try
196 {
197 pointIsValid = this->m_Associate->TransformAndEvaluateFixedPoint( virtualPoint, mappedFixedPoint, mappedFixedPixelValue);
198 if( pointIsValid &&
199 this->m_Associate->GetComputeDerivative() &&
200 this->m_Associate->GetGradientSourceIncludesFixed() )
201 {
202 this->m_Associate->ComputeFixedImageGradientAtPoint( mappedFixedPoint, mappedFixedImageGradient );
203 }
204 }
205 catch( ExceptionObject & exc )
206 {
207 //NOTE: there must be a cleaner way to do this:
208 std::string msg("Caught exception: \n");
209 msg += exc.what();
210 ExceptionObject err(__FILE__, __LINE__, msg);
211 throw err;
212 }
213 if( !pointIsValid )
214 {
215 return pointIsValid;
216 }
217
218 try
219 {
220 pointIsValid = this->m_Associate->TransformAndEvaluateMovingPoint( virtualPoint, mappedMovingPoint, mappedMovingPixelValue );
221 if( pointIsValid &&
222 this->m_Associate->GetComputeDerivative() &&
223 this->m_Associate->GetGradientSourceIncludesMoving() )
224 {
225 this->m_Associate->ComputeMovingImageGradientAtPoint( mappedMovingPoint, mappedMovingImageGradient );
226 }
227 }
228 catch( ExceptionObject & exc )
229 {
230 std::string msg("Caught exception: \n");
231 msg += exc.what();
232 ExceptionObject err(__FILE__, __LINE__, msg);
233 throw err;
234 }
235 if( !pointIsValid )
236 {
237 return pointIsValid;
238 }
239
240 /* Call the user method in derived classes to do the specific
241 * calculations for value and derivative. */
242 try
243 {
244 pointIsValid = this->ProcessPoint(
245 virtualIndex,
246 virtualPoint,
247 mappedFixedPoint, mappedFixedPixelValue,
248 mappedFixedImageGradient,
249 mappedMovingPoint, mappedMovingPixelValue,
250 mappedMovingImageGradient,
251 metricValueResult,
252 this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives,
253 threadId );
254 }
255 catch( ExceptionObject & exc )
256 {
257 //NOTE: there must be a cleaner way to do this:
258 std::string msg("Exception in GetValueAndDerivativeProcessPoint:\n");
259 msg += exc.what();
260 ExceptionObject err(__FILE__, __LINE__, msg);
261 throw err;
262 }
263 if( pointIsValid )
264 {
265 this->m_GetValueAndDerivativePerThreadVariables[threadId].NumberOfValidPoints++;
266 this->m_GetValueAndDerivativePerThreadVariables[threadId].Measure += metricValueResult;
267 if( this->m_Associate->GetComputeDerivative() )
268 {
269 this->StorePointDerivativeResult( virtualIndex, threadId );
270 }
271 }
272
273 return pointIsValid;
274 }
275
276 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
277 void
278 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
StorePointDerivativeResult(const VirtualIndexType & virtualIndex,const ThreadIdType threadId)279 ::StorePointDerivativeResult( const VirtualIndexType & virtualIndex, const ThreadIdType threadId )
280 {
281 if ( this->m_Associate->m_MovingTransform->GetTransformCategory() != MovingTransformType::DisplacementField )
282 {
283 /* Global support */
284 if ( this->m_Associate->GetUseFloatingPointCorrection() )
285 {
286 DerivativeValueType correctionResolution = this->m_Associate->GetFloatingPointCorrectionResolution();
287 for (NumberOfParametersType p = 0; p < this->m_CachedNumberOfParameters; p++ )
288 {
289 auto test = static_cast< intmax_t >(
290 this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[p] * correctionResolution
291 );
292 this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[p] = static_cast<DerivativeValueType>( test / correctionResolution );
293 }
294 }
295 for (NumberOfParametersType p = 0; p < this->m_CachedNumberOfParameters; p++ )
296 {
297 this->m_GetValueAndDerivativePerThreadVariables[threadId].CompensatedDerivatives[p] += this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[p];
298 }
299 }
300 else
301 {
302 // Update derivative at some index
303 // this requires the moving image displacement field to be
304 // same size as virtual image, and that VirtualImage PixelType
305 // is scalar (which is verified during Metric initialization).
306 try
307 {
308 OffsetValueType offset = this->m_Associate->ComputeParameterOffsetFromVirtualIndex( virtualIndex, this->m_CachedNumberOfLocalParameters );
309 for (NumberOfParametersType i=0; i < this->m_CachedNumberOfLocalParameters; i++)
310 {
311 /* Be sure to *add* here and not assign. Required for proper behavior
312 * with multi-variate metric. */
313 this->m_GetValueAndDerivativePerThreadVariables[threadId].Derivatives[offset+i] += this->m_GetValueAndDerivativePerThreadVariables[threadId].LocalDerivatives[i];
314 }
315 }
316 catch( ExceptionObject & exc )
317 {
318 std::string msg("Caught exception: \n");
319 msg += exc.what();
320 ExceptionObject err(__FILE__, __LINE__, msg);
321 throw err;
322 }
323 }
324 }
325
326 template< typename TDomainPartitioner, typename TImageToImageMetricv4 >
327 bool
328 ImageToImageMetricv4GetValueAndDerivativeThreaderBase< TDomainPartitioner, TImageToImageMetricv4 >
GetComputeDerivative() const329 ::GetComputeDerivative() const
330 {
331 return this->m_Associate->GetComputeDerivative();
332 }
333
334 } // end namespace itk
335
336 #endif
337