1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkANTSNeighborhoodCorrelationImageToImageMetricv4GetValueAndDerivativeThreader_h
19 #define itkANTSNeighborhoodCorrelationImageToImageMetricv4GetValueAndDerivativeThreader_h
20 
21 #include "itkImageToImageMetricv4GetValueAndDerivativeThreader.h"
22 #include "itkThreadedImageRegionPartitioner.h"
23 #include "itkThreadedIndexedContainerPartitioner.h"
24 #include "itkConstNeighborhoodIterator.h"
25 
26 #include <deque>
27 
28 namespace itk
29 {
30 
31 /**
32  *  A template struct to identify different input type arguments. This is used
33  *  for function overloading by different threaders. Refer to the comments below.
34  */
/**
 * Empty tag type that carries a type \c T as a template argument.
 *
 * An instance holds no data; it exists only so that a function can be
 * overloaded on \c T by taking an \c IdentityHelper<T> parameter.
 * The wrapped type is available as \c MyType.
 */
template< typename T >
struct IdentityHelper
{
  /** The type this tag stands for. */
  using MyType = T;
};
40 
41 /** \class ANTSNeighborhoodCorrelationImageToImageMetricv4GetValueAndDerivativeThreader
42  * \brief Threading implementation for ANTS CC metric \c ANTSNeighborhoodCorrelationImageToImageMetricv4 .
 * Supports both dense and sparse threading. The dense threader iterates over the whole image domain
 * in order and uses a neighborhood scanning window to compute the local cross correlation metric and
 * its derivative incrementally inside the window. The sparse threader uses a sampled point set partitioner to
 * compute local cross correlation only at the sampled positions.
47  *
48  * This threader class is designed to host the dense and sparse threader under the same name so most computation
49  * routine functions and interior member variables can be shared. This eliminates the need to duplicate codes
50  * for two threaders. This is made by using function overloading and a helper class to identify different types of domain
51  * partitioners.
52  *
53  *
54  * \ingroup ITKMetricsv4
55  */
56 template< typename TDomainPartitioner, typename TImageToImageMetric, typename TNeighborhoodCorrelationMetric >
57 class ITK_TEMPLATE_EXPORT ANTSNeighborhoodCorrelationImageToImageMetricv4GetValueAndDerivativeThreader
58   : public ImageToImageMetricv4GetValueAndDerivativeThreader< TDomainPartitioner, TImageToImageMetric >
59 {
60 public:
61   ITK_DISALLOW_COPY_AND_ASSIGN(ANTSNeighborhoodCorrelationImageToImageMetricv4GetValueAndDerivativeThreader);
62 
63   /** Standard class type aliases. */
64   using Self = ANTSNeighborhoodCorrelationImageToImageMetricv4GetValueAndDerivativeThreader;
65   using Superclass =
66       ImageToImageMetricv4GetValueAndDerivativeThreader< TDomainPartitioner, TImageToImageMetric >;
67   using Pointer = SmartPointer< Self >;
68   using ConstPointer = SmartPointer< const Self >;
69 
70   itkTypeMacro( ANTSNeighborhoodCorrelationImageToImageMetricv4GetValueAndDerivativeThreader, ImageToImageMetricv4GetValueAndDerivativeThreader );
71 
72   itkNewMacro( Self );
73 
74   using DomainType = typename Superclass::DomainType;
75   using AssociateType = typename Superclass::AssociateType;
76 
77   using VirtualImageType = typename Superclass::VirtualImageType;
78   using VirtualPointType = typename Superclass::VirtualPointType;
79   using VirtualIndexType = typename Superclass::VirtualIndexType;
80   using FixedImagePointType = typename Superclass::FixedImagePointType;
81   using FixedImagePixelType = typename Superclass::FixedImagePixelType;
82   using FixedImageGradientType = typename Superclass::FixedImageGradientType;
83   using MovingImagePointType = typename Superclass::MovingImagePointType;
84   using MovingImagePixelType = typename Superclass::MovingImagePixelType;
85   using MovingImageGradientType = typename Superclass::MovingImageGradientType;
86   using MeasureType = typename Superclass::MeasureType;
87   using DerivativeType = typename Superclass::DerivativeType;
88   using DerivativeValueType = typename Superclass::DerivativeValueType;
89 
90   using NeighborhoodCorrelationMetricType = TNeighborhoodCorrelationMetric;
91 
92   using ImageRegionType = typename NeighborhoodCorrelationMetricType::ImageRegionType;
93   using InternalComputationValueType = typename NeighborhoodCorrelationMetricType::InternalComputationValueType;
94   using ImageDimensionType = typename NeighborhoodCorrelationMetricType::ImageDimensionType;
95   using JacobianType = typename NeighborhoodCorrelationMetricType::JacobianType;
96   using NumberOfParametersType = typename NeighborhoodCorrelationMetricType::NumberOfParametersType;
97   using FixedImageType = typename NeighborhoodCorrelationMetricType::FixedImageType;
98   using MovingImageType = typename NeighborhoodCorrelationMetricType::MovingImageType;
99   using RadiusType = typename NeighborhoodCorrelationMetricType::RadiusType;
100 
101   // interested values here updated during scanning
102   using QueueRealType = InternalComputationValueType;
103   using SumQueueType = std::deque<QueueRealType>;
104   using ScanIteratorType = ConstNeighborhoodIterator<VirtualImageType>;
105 
106   // one ScanMemType for each thread
107   typedef struct ScanMemType {
108     // queues used in the scanning
109     // sum of the fixed value squared
110     SumQueueType QsumFixed2;
111     // sum of the moving value squared
112     SumQueueType QsumMoving2;
113     SumQueueType QsumFixed;
114     SumQueueType QsumMoving;
115     SumQueueType QsumFixedMoving;
116     SumQueueType Qcount;
117 
118     QueueRealType fixedA;
119     QueueRealType movingA;
120     QueueRealType sFixedMoving;
121     QueueRealType sFixedFixed;
122     QueueRealType sMovingMoving;
123 
124     FixedImageGradientType  fixedImageGradient;
125     MovingImageGradientType movingImageGradient;
126 
127     FixedImagePointType     mappedFixedPoint;
128     MovingImagePointType    mappedMovingPoint;
129     VirtualPointType        virtualPoint;
130   } ScanMemType;
131 
132   // For dense scan over one image region
133   typedef struct ScanParametersType {
134     // const values during scanning
135     ImageRegionType scanRegion;
136     SizeValueType   numberOfFillZero; // for each queue
137     SizeValueType   windowLength; // number of voxels in the scanning window
138     IndexValueType  scanRegionBeginIndexDim0;
139 
140     typename FixedImageType::ConstPointer   fixedImage;
141     typename MovingImageType::ConstPointer  movingImage;
142     typename VirtualImageType::ConstPointer virtualImage;
143     RadiusType radius;
144 
145   } ScanParametersType;
146 
147 protected:
ANTSNeighborhoodCorrelationImageToImageMetricv4GetValueAndDerivativeThreader()148   ANTSNeighborhoodCorrelationImageToImageMetricv4GetValueAndDerivativeThreader() :
149     m_ANTSAssociate(nullptr)
150   {}
151 
152   /**
153    * Dense threader and sparse threader invoke different in multi-threading. This class uses overloaded
154    * implementations of \c ProcessVirtualPoint_impl and \c ThreadExecution_impl in order to handle the
155    * dense and sparse cases differently. The helper class IdentityHelper allows for correct overloading
156    * these methods when substituting different type of the threaded partitioner
157    *
158    * 1) Dense threader: through its own \c ThreadedExecution. \c ProcessVirtualPoint and
159    * \c ProcessPoint of the base class are thus not used.
160    *
161    * 2) Sparse threader: through its own \c ProcessVirtualPoint. \c ThreadedExecution still invokes (mostly)
162    * from the base class.
163    *
164    * In order to invoke different \c ThreadedExecution by different threader, we use function overloading
165    * techniques to resolve which version of  \c ThreadedExecution and \c ProcessVirtualPoint by
166    * the type of the domain partitioner.
167    *
168    * Specifically, a helper class \c IdentityHelper is used as a function parameter, with the sole purpose
169    * to differentiate different types of domain partitioners: \c ThreadedIndexedContainerPartitioner for sparse
170    * or \c ThreadedImageRegionPartitioner for dense. \c IdentityHelper is simply a class template, ie. a struct
171    * wrapper of type template arguments.
172    *
173    * This technique takes advantage of SFINAE (Substitution Failure Is Not An Error) in specializing function
174    * templates. The helper class \c IdentityHelper is used to overload w.r.t different partitioner types.
175    * More discussion can be found at:
176    * http://stackoverflow.com/questions/3052579/explicit-specialization-in-non-namespace-scope
177    *
178    * */
179 
180   /** Method called by the threaders to process the given virtual point.  This
181    * in turn calls \c TransformAndEvaluateFixedPoint, \c
182    * TransformAndEvaluateMovingPoint, and \c ProcessPoint.
183    * And adds entries to m_MeasurePerThread and m_LocalDerivativesPerThread,
184    * m_NumberOfValidPointsPerThread. */
ProcessVirtualPoint(const VirtualIndexType & virtualIndex,const VirtualPointType & virtualPoint,const ThreadIdType threadId)185   bool ProcessVirtualPoint( const VirtualIndexType & virtualIndex,
186                                     const VirtualPointType & virtualPoint,
187                                     const ThreadIdType threadId ) override {
188     return ProcessVirtualPoint_impl(IdentityHelper<TDomainPartitioner>(), virtualIndex, virtualPoint, threadId );
189   }
190 
191   /* specific overloading for sparse CC metric */
192   bool ProcessVirtualPoint_impl(
193                              IdentityHelper<ThreadedIndexedContainerPartitioner> itkNotUsed(self),
194                              const VirtualIndexType & virtualIndex,
195                              const VirtualPointType & virtualPoint,
196                              const ThreadIdType threadId );
197 
198   /* for other default case */
199   template<typename T>
ProcessVirtualPoint_impl(IdentityHelper<T> itkNotUsed (self),const VirtualIndexType & virtualIndex,const VirtualPointType & virtualPoint,const ThreadIdType threadId)200   bool ProcessVirtualPoint_impl(
201                              IdentityHelper<T> itkNotUsed(self),
202                              const VirtualIndexType & virtualIndex,
203                              const VirtualPointType & virtualPoint,
204                              const ThreadIdType threadId ) {
205     return Superclass::ProcessVirtualPoint(virtualIndex, virtualPoint, threadId);
206   }
207 
208 
209   /** \c ProcessPoint() must be overloaded since it is a pure virtual function.
210    * It is not used for either sparse or dense threader.
211    * */
ProcessPoint(const VirtualIndexType & itkNotUsed (virtualIndex),const VirtualPointType & itkNotUsed (virtualPoint),const FixedImagePointType & itkNotUsed (mappedFixedPoint),const FixedImagePixelType & itkNotUsed (mappedFixedPixelValue),const FixedImageGradientType & itkNotUsed (mappedFixedImageGradient),const MovingImagePointType & itkNotUsed (mappedMovingPoint),const MovingImagePixelType & itkNotUsed (mappedMovingPixelValue),const MovingImageGradientType & itkNotUsed (mappedMovingImageGradient),MeasureType & itkNotUsed (metricValueReturn),DerivativeType & itkNotUsed (localDerivativeReturn),const ThreadIdType itkNotUsed (threadId))212   bool ProcessPoint(
213          const VirtualIndexType &          itkNotUsed(virtualIndex),
214          const VirtualPointType &          itkNotUsed(virtualPoint),
215          const FixedImagePointType &       itkNotUsed(mappedFixedPoint),
216          const FixedImagePixelType &       itkNotUsed(mappedFixedPixelValue),
217          const FixedImageGradientType &    itkNotUsed(mappedFixedImageGradient),
218          const MovingImagePointType &      itkNotUsed(mappedMovingPoint),
219          const MovingImagePixelType &      itkNotUsed(mappedMovingPixelValue),
220          const MovingImageGradientType &   itkNotUsed(mappedMovingImageGradient),
221          MeasureType &                     itkNotUsed(metricValueReturn),
222          DerivativeType &                  itkNotUsed(localDerivativeReturn),
223          const ThreadIdType                itkNotUsed(threadId) ) const override
224      {
225         itkExceptionMacro("ProcessPoint should never be reached in ANTS CC metric threader class.");
226      }
227 
ThreadedExecution(const DomainType & domain,const ThreadIdType threadId)228   void ThreadedExecution( const DomainType& domain,
229                                     const ThreadIdType threadId ) override
230     {
231     ThreadedExecution_impl(IdentityHelper<TDomainPartitioner>(), domain, threadId );
232     }
233 
234   /* specific overloading for dense threader only based CC metric */
235   void ThreadedExecution_impl(
236                              IdentityHelper<ThreadedImageRegionPartitioner<TImageToImageMetric::VirtualImageDimension> > itkNotUsed(self),
237                              const DomainType& domain,
238                              const ThreadIdType threadId );
239 
240   /* for other default case */
241   template<typename T>
242   void ThreadedExecution_impl(
243                              IdentityHelper<T> itkNotUsed(self),
244                              const DomainType& domain,
245                              const ThreadIdType threadId );
246 
247   /** Common functions for computing correlation over scanning windows **/
248 
249   /** Create an iterator over the virtual sub region */
250   void InitializeScanning(const ImageRegionType &scanRegion,
251     ScanIteratorType &scanIt, ScanMemType &scanMem,
252     ScanParametersType &scanParameters ) const;
253 
254   /** Update the queues for the next point.  Calls either \c
255    * UpdateQueuesAtBeginningOfLine or \c UpdateQueuesToNextScanWindow. */
256   void UpdateQueues(const ScanIteratorType &scanIt,
257     ScanMemType &scanMem, const ScanParametersType &scanParameters,
258     const ThreadIdType threadId) const;
259 
260   void UpdateQueuesAtBeginningOfLine(
261     const ScanIteratorType &scanIt, ScanMemType &scanMem,
262     const ScanParametersType &scanParameters,
263     const ThreadIdType threadId) const;
264 
265   /** Increment the iterator and check to see if we're at the end of the
266    * line.  If so, go to the next line.  Otherwise, add the
267    * the values for the next hyperplane. */
268   void UpdateQueuesToNextScanWindow(
269     const ScanIteratorType &scanIt, ScanMemType &scanMem,
270     const ScanParametersType &scanParameters,
271     const ThreadIdType threadId) const;
272 
273   /** Test to see if there are any voxels we need to handle in the current
274    * window. */
275   bool ComputeInformationFromQueues(
276     const ScanIteratorType &scanIt, ScanMemType &scanMem,
277     const ScanParametersType &scanParameters,
278     const ThreadIdType threadId) const;
279 
280   void ComputeMovingTransformDerivative(
281     const ScanIteratorType &scanIt, ScanMemType &scanMem,
282     const ScanParametersType &scanParameters, DerivativeType &deriv,
283     MeasureType &local_cc, const ThreadIdType threadId) const;
284 
285 private:
286   /** Internal pointer to the metric object in use by this threader.
287    *  This will avoid costly dynamic casting in tight loops. */
288   TNeighborhoodCorrelationMetric * m_ANTSAssociate;
289 };
290 
291 
292 } // end namespace itk
293 
294 #ifndef ITK_MANUAL_INSTANTIATION
295 #include "itkANTSNeighborhoodCorrelationImageToImageMetricv4GetValueAndDerivativeThreader.hxx"
296 #endif
297 
298 #endif
299