1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkPointSetToPointSetMetricv4_h
19 #define itkPointSetToPointSetMetricv4_h
20 
21 #include "itkObjectToObjectMetric.h"
22 
23 #include "itkFixedArray.h"
24 #include "itkPointsLocator.h"
25 #include "itkPointSet.h"
26 
27 namespace itk
28 {
29 /** \class PointSetToPointSetMetricv4
30  * \brief Computes similarity between two point sets.
31  *
32  * This class is templated over the type of the two point-sets.  It
33  * expects a Transform to be plugged in for each of fixed and moving
 * point sets. The transforms default to IdentityTransform types. This particular
35  * class is the base class for a hierarchy of point-set to point-set metrics.
36  *
37  * This class computes a value that measures the similarity between the fixed
38  * point-set and the moving point-set in the moving domain. The fixed point set
39  * is transformed into the virtual domain by computing the inverse of the
40  * fixed transform, then transformed into the moving domain using the
41  * moving transform.
42  *
43  * Since the \c PointSet class permits each \c Point to be associated with a
44  * \c PixelType, there are potential applications which could make use of
45  * this additional information.  For example, the derived \c LabeledPointSetToPointSetMetric
46  * class uses the \c PixelType as a \c LabelType for estimating total metric values
 * and gradients from the individual label-wise point subset metric and derivatives.
48  *
49  * If a virtual domain is not defined by the user, one of two things happens:
50  * 1) If the moving transform is a global type, then the virtual domain is
51  * left undefined and every point is considered to be within the virtual domain.
52  * 2) If the moving transform is a local-support type, then the virtual domain
53  * is taken during initialization from the moving transform displacement field,
54  * and all fixed points are verified to be within the virtual domain after
55  * transformation by the inverse fixed transform. Points outside the virtual
56  * domain are not used. See GetNumberOfValidPoints() to verify how many fixed
57  * points were used during evaluation.
58  *
 * See ObjectToObjectMetric documentation for more discussion on the virtual domain.
60  *
 * \note When used with a RegistrationParameterScalesEstimator, a VirtualDomainPointSet
62  * must be defined and assigned to the estimator, for use in shift estimation.
63  * The virtual domain point set can be retrieved from the metric using the
64  * GetVirtualTransformedPointSet() method.
65  *
66  * \ingroup ITKMetricsv4
67  */
68 
69 template<typename TFixedPointSet,  typename TMovingPointSet,
70   class TInternalComputationValueType = double>
71 class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricv4
72 : public ObjectToObjectMetric<TFixedPointSet::PointDimension, TMovingPointSet::PointDimension,
73    Image<TInternalComputationValueType, TFixedPointSet::PointDimension>, TInternalComputationValueType>
74 {
75 public:
76   ITK_DISALLOW_COPY_AND_ASSIGN(PointSetToPointSetMetricv4);
77 
78   /** Standard class type aliases. */
79   using Self = PointSetToPointSetMetricv4;
80   using Superclass = ObjectToObjectMetric<TFixedPointSet::PointDimension,
81     TMovingPointSet::PointDimension,
82     Image<TInternalComputationValueType, TFixedPointSet::PointDimension>,
83     TInternalComputationValueType>;
84   using Pointer = SmartPointer<Self>;
85   using ConstPointer = SmartPointer<const Self>;
86 
87   /** Run-time type information (and related methods). */
88   itkTypeMacro( PointSetToPointSetMetricv4, ObjectToObjectMetric );
89 
90   /**  Type of the measure. */
91   using MeasureType = typename Superclass::MeasureType;
92 
93   /**  Type of the parameters. */
94   using ParametersType = typename Superclass::ParametersType;
95   using ParametersValueType = typename Superclass::ParametersValueType;
96   using NumberOfParametersType = typename Superclass::NumberOfParametersType;
97 
98   /**  Type of the derivative. */
99   using DerivativeType = typename Superclass::DerivativeType;
100 
101   /** Transform types from Superclass*/
102   using FixedTransformType = typename Superclass::FixedTransformType;
103   using FixedTransformPointer = typename Superclass::FixedTransformPointer;
104   using FixedInputPointType = typename Superclass::FixedInputPointType;
105   using FixedOutputPointType = typename Superclass::FixedOutputPointType;
106   using FixedTransformParametersType = typename Superclass::FixedTransformParametersType;
107 
108   using MovingTransformType = typename Superclass::MovingTransformType;
109   using MovingTransformPointer = typename Superclass::MovingTransformPointer;
110   using MovingInputPointType = typename Superclass::MovingInputPointType;
111   using MovingOutputPointType = typename Superclass::MovingOutputPointType;
112   using MovingTransformParametersType = typename Superclass::MovingTransformParametersType;
113 
114   using JacobianType = typename Superclass::JacobianType;
115   using FixedTransformJacobianType = typename Superclass::FixedTransformJacobianType;
116   using MovingTransformJacobianType = typename Superclass::MovingTransformJacobianType;
117 
118   using DisplacementFieldTransformType = typename Superclass::MovingDisplacementFieldTransformType;
119 
120   using ObjectType = typename Superclass::ObjectType;
121 
122   /** Dimension type */
123   using DimensionType = typename Superclass::DimensionType;
124 
125   /**  Type of the fixed point set. */
126   using FixedPointSetType = TFixedPointSet;
127   using FixedPointType = typename TFixedPointSet::PointType;
128   using FixedPixelType = typename TFixedPointSet::PixelType;
129   using FixedPointsContainer = typename TFixedPointSet::PointsContainer;
130 
131   static constexpr DimensionType FixedPointDimension = Superclass::FixedDimension;
132 
133   /**  Type of the moving point set. */
134   using MovingPointSetType = TMovingPointSet;
135   using MovingPointType = typename TMovingPointSet::PointType;
136   using MovingPixelType = typename TMovingPointSet::PixelType;
137   using MovingPointsContainer = typename TMovingPointSet::PointsContainer;
138 
139   static constexpr DimensionType MovingPointDimension = Superclass::MovingDimension;
140 
141   /**
142    * typedefs for the data types used in the point set metric calculations.
143    * It is assumed that the constants of the fixed point set, such as the
144    * point dimension, are the same for the "common space" in which the metric
145    * calculation occurs.
146    */
147   static constexpr DimensionType PointDimension = Superclass::FixedDimension;
148 
149   using PointType = FixedPointType;
150   using PixelType = FixedPixelType;
151   using CoordRepType = typename PointType::CoordRepType;
152   using PointsContainer = FixedPointsContainer;
153   using PointsConstIterator = typename PointsContainer::ConstIterator;
154   using PointIdentifier = typename PointsContainer::ElementIdentifier;
155 
156   /** Typedef for points locator class to speed up finding neighboring points */
157   using PointsLocatorType = PointsLocator< PointsContainer>;
158   using NeighborsIdentifierType = typename PointsLocatorType::NeighborsIdentifierType;
159 
160   using FixedTransformedPointSetType = PointSet<FixedPixelType, Self::PointDimension >;
161   using MovingTransformedPointSetType = PointSet<MovingPixelType, Self::PointDimension >;
162 
163   using DerivativeValueType = typename DerivativeType::ValueType;
164   using LocalDerivativeType = FixedArray<DerivativeValueType, Self::PointDimension >;
165 
166   /** Types for the virtual domain */
167   using VirtualImageType = typename Superclass::VirtualImageType;
168   using VirtualImagePointer = typename Superclass::VirtualImagePointer;
169   using VirtualPixelType = typename Superclass::VirtualPixelType;
170   using VirtualRegionType = typename Superclass::VirtualRegionType;
171   using VirtualSizeType = typename Superclass::VirtualSizeType;
172   using VirtualSpacingType = typename Superclass::VirtualSpacingType;
173   using VirtualOriginType = typename Superclass::VirtualPointType;
174   using VirtualPointType = typename Superclass::VirtualPointType;
175   using VirtualDirectionType = typename Superclass::VirtualDirectionType;
176   using VirtualRadiusType = typename Superclass::VirtualSizeType;
177   using VirtualIndexType = typename Superclass::VirtualIndexType;
178   using VirtualPointSetType = typename Superclass::VirtualPointSetType;
179   using VirtualPointSetPointer = typename Superclass::VirtualPointSetPointer;
180 
181   /** Set fixed point set*/
SetFixedObject(const ObjectType * object)182   void SetFixedObject( const ObjectType *object ) override
183     {
184     auto * pointSet = dynamic_cast<FixedPointSetType *>( const_cast<ObjectType *>( object ) );
185     if( pointSet != nullptr )
186       {
187       this->SetFixedPointSet( pointSet );
188       }
189     else
190       {
191       itkExceptionMacro( "Incorrect object type.  Should be a point set." )
192       }
193     }
194 
195   /** Set moving point set*/
SetMovingObject(const ObjectType * object)196   void SetMovingObject( const ObjectType *object ) override
197     {
198     auto * pointSet = dynamic_cast<MovingPointSetType *>( const_cast<ObjectType *>( object ) );
199     if( pointSet != nullptr )
200       {
201       this->SetMovingPointSet( pointSet );
202       }
203     else
204       {
205       itkExceptionMacro( "Incorrect object type.  Should be a point set." )
206       }
207     }
208 
209   /** Get/Set the fixed pointset.  */
210   itkSetConstObjectMacro( FixedPointSet, FixedPointSetType );
211   itkGetConstObjectMacro( FixedPointSet, FixedPointSetType );
212 
213   /** Get the fixed transformed point set.  */
214   itkGetModifiableObjectMacro( FixedTransformedPointSet, FixedTransformedPointSetType );
215 
216   /** Get/Set the moving point set.  */
217   itkSetConstObjectMacro( MovingPointSet, MovingPointSetType );
218   itkGetConstObjectMacro( MovingPointSet, MovingPointSetType );
219 
220   /** Get the moving transformed point set.  */
221   itkGetModifiableObjectMacro( MovingTransformedPointSet, MovingTransformedPointSetType );
222 
223   /**
224    * For now return the number of points used in the value/derivative calculations.
225    */
226   SizeValueType GetNumberOfComponents() const;
227 
228   /**
229    * This method returns the value of the metric based on the current
230    * transformation(s).  This function can be redefined in derived classes
231    * but many point set metrics follow the same structure---one iterates
232    * through the points and, for each point a metric value is calculated.
233    * The summation of these individual point metric values gives the total
234    * value of the metric.  Note that this might not be applicable to all
235    * point set metrics.  For those cases, the developer will have to redefine
236    * the GetValue() function.
237    */
238   MeasureType GetValue() const override;
239 
240   /**
241    * This method returns the derivative based on the current
242    * transformation(s).  This function can be redefined in derived classes
243    * but many point set metrics follow the same structure---one iterates
244    * through the points and, for each point a derivative is calculated.
245    * The set of all these local derivatives constitutes the total derivative.
246    * Note that this might not be applicable to all point set metrics.  For
247    * those cases, the developer will have to redefine the GetDerivative()
248    * function.
249    */
250   void GetDerivative( DerivativeType & ) const override;
251 
252   /**
253    * This method returns the derivative and value based on the current
254    * transformation(s).  This function can be redefined in derived classes
255    * but many point set metrics follow the same structure---one iterates
256    * through the points and, for each point a derivative and value is calculated.
257    * The set of all these local derivatives/values constitutes the total
258    * derivative and value.  Note that this might not be applicable to all
259    * point set metrics.  For those cases, the developer will have to redefine
260    * the GetValue() and GetDerivative() functions.
261    */
262   void GetValueAndDerivative( MeasureType &, DerivativeType & ) const override;
263 
264   /**
265    * Function to be defined in the appropriate derived classes.  Calculates
266    * the local metric value for a single point.  The \c PixelType may or
267    * may not be used.  See class description for further explanation.
268    */
269   virtual MeasureType GetLocalNeighborhoodValue( const PointType &, const PixelType & pixel ) const = 0;
270 
271   /**
272    * Calculates the local derivative for a single point. The \c PixelType may or
273    * may not be used.  See class description for further explanation.
274    */
275   virtual LocalDerivativeType GetLocalNeighborhoodDerivative( const PointType &, const PixelType & pixel ) const;
276 
277   /**
278    * Calculates the local value/derivative for a single point.  The \c PixelType may or
279    * may not be used.  See class description for further explanation.
280    */
281   virtual void GetLocalNeighborhoodValueAndDerivative( const PointType &,
282     MeasureType &, LocalDerivativeType &, const PixelType & pixel ) const = 0;
283 
284   /**
285    * Get the virtual point set, derived from the fixed point set.
286    * If the virtual point set has not yet been derived, it will be
287    * in this call. */
288   const VirtualPointSetType * GetVirtualTransformedPointSet() const;
289 
290   /**
291    * Initialize the metric by making sure that all the components
292    *  are present and plugged together correctly.
293    */
294   void Initialize() override;
295 
SupportsArbitraryVirtualDomainSamples()296   bool SupportsArbitraryVirtualDomainSamples() const override
297   {
298     /* An arbitrary point in the virtual domain will not always
299      * correspond to a point within either point set. */
300     return false;
301   }
302 
303   /**
304    * By default, the point set metric derivative for a displacement field transform
305    * is stored by saving the gradient for every voxel in the displacement field (see
306    * the function StorePointDerivative()).  Since the "fixed points" will typically
307    * constitute a sparse set, this means that the field will have zero gradient values
308    * at every voxel that doesn't have a corresponding point.  This might cause additional
309    * computation time for certain transforms (e.g. B-spline SyN). To avoid this, this
310    * option permits storing the point derivative only at the fixed point locations.
311    * If this variable is set to false, then the derivative array will be of length
312    * = PointDimension * m_FixedPointSet->GetNumberOfPoints().
313    */
314   itkSetMacro( StoreDerivativeAsSparseFieldForLocalSupportTransforms, bool );
315   itkGetConstMacro( StoreDerivativeAsSparseFieldForLocalSupportTransforms, bool );
316   itkBooleanMacro( StoreDerivativeAsSparseFieldForLocalSupportTransforms );
317 
318   /**
319    *
320    */
321   itkSetMacro( CalculateValueAndDerivativeInTangentSpace, bool );
322   itkGetConstMacro( CalculateValueAndDerivativeInTangentSpace, bool );
323   itkBooleanMacro( CalculateValueAndDerivativeInTangentSpace );
324 
325 protected:
326   PointSetToPointSetMetricv4();
327   ~PointSetToPointSetMetricv4() override = default;
328   void PrintSelf( std::ostream & os, Indent indent ) const override;
329 
330   typename FixedPointSetType::ConstPointer                m_FixedPointSet;
331   mutable typename FixedTransformedPointSetType::Pointer  m_FixedTransformedPointSet;
332 
333   mutable typename PointsLocatorType::Pointer             m_FixedTransformedPointsLocator;
334 
335   typename MovingPointSetType::ConstPointer               m_MovingPointSet;
336   mutable typename MovingTransformedPointSetType::Pointer m_MovingTransformedPointSet;
337 
338   mutable typename PointsLocatorType::Pointer             m_MovingTransformedPointsLocator;
339 
340   /** Holds the fixed points after transformation into virtual domain. */
341   mutable VirtualPointSetPointer                          m_VirtualTransformedPointSet;
342 
343   /**
344    * Bool set by derived classes on whether the point set data (i.e. \c PixelType)
345    * should be used.  Default = false.
346    */
347   bool m_UsePointSetData;
348 
349   /**
350    * Flag to calculate value and/or derivative at tangent space.  This is needed
351    * for the diffeomorphic registration methods.  The fixed and moving points are
352    * warped to the virtual domain where the metric is calculated.  Derived point
353    * set metrics might have associated gradient information which will need to be
354    * warped if this flag is true.  Default = false.
355    */
356   bool m_CalculateValueAndDerivativeInTangentSpace;
357 
358   /**
359    * Prepare point sets for use. */
360   virtual void InitializePointSets() const;
361 
362   /**
363    * Initialize to prepare for a particular iteration, generally
364    * an iteration of optimization. Distinct from Initialize()
365    * which is a one-time initialization. */
366   virtual void InitializeForIteration() const;
367 
368   /**
369    * Determine the number of valid fixed points. A fixed point
370    * is valid if, when transformed into the virtual domain using
371    * the inverse of the FixedTransform, it is within the defined
372    * virtual domain bounds. */
373   virtual SizeValueType CalculateNumberOfValidFixedPoints() const;
374 
375   /** Helper method allows for code reuse while skipping the metric value
376    * calculation when appropriate */
377   void CalculateValueAndDerivative( MeasureType & value, DerivativeType & derivative, bool calculateValue ) const;
378 
379   /**
380    * Warp the fixed point set into the moving domain based on the fixed transform,
381    * passing through the virtual domain and storing a virtual domain set.
382    * Note that the warped moving point set is of type FixedPointSetType since the transform
383    * takes the points from the fixed to the moving domain.
384    */
385   void TransformFixedAndCreateVirtualPointSet() const;
386 
387   /**
388    * Warp the moving point set based on the moving transform.  Note that the
389    * warped moving point set is of type FixedPointSetType since the transform
390    * takes the points from the moving to the fixed domain.
391    * FIXME: needs update.
392    */
393   void TransformMovingPointSet() const;
394 
395   /**
396    * Build point locators for the fixed and moving point sets to speed up
397    * derivative and value calculations.
398    */
399   void InitializePointsLocators() const;
400 
401   /**
402    * Store a derivative from a single point in a field.
403    * Only relevant when active transform has local support.
404    */
405   void StorePointDerivative( const VirtualPointType &, const DerivativeType &, DerivativeType & ) const;
406 
407   using MetricCategoryType = typename Superclass::MetricCategoryType;
408 
409   /** Get metric category */
GetMetricCategory()410   MetricCategoryType GetMetricCategory() const override
411     {
412     return Superclass::POINT_SET_METRIC;
413     }
414 
415 
416 private:
417   mutable bool m_MovingTransformPointLocatorsNeedInitialization;
418   mutable bool m_FixedTransformPointLocatorsNeedInitialization;
419 
420   // Flag to keep track of whether a warning has already been issued
421   // regarding the number of valid points.
422   mutable bool m_HaveWarnedAboutNumberOfValidPoints;
423 
424   // Flag to store derivatives at fixed point locations with the rest being zero gradient
425   // (default = true).
426   bool m_StoreDerivativeAsSparseFieldForLocalSupportTransforms;
427 
428   mutable ModifiedTimeType m_MovingTransformedPointSetTime;
429   mutable ModifiedTimeType m_FixedTransformedPointSetTime;
430 };
431 } // end namespace itk
432 
433 #ifndef ITK_MANUAL_INSTANTIATION
434 #include "itkPointSetToPointSetMetricv4.hxx"
435 #endif
436 
437 #endif
438