1 /*========================================================================= 2 * 3 * Copyright Insight Software Consortium 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0.txt 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 *=========================================================================*/ 18 #ifndef itkPointSetToPointSetMetricv4_h 19 #define itkPointSetToPointSetMetricv4_h 20 21 #include "itkObjectToObjectMetric.h" 22 23 #include "itkFixedArray.h" 24 #include "itkPointsLocator.h" 25 #include "itkPointSet.h" 26 27 namespace itk 28 { 29 /** \class PointSetToPointSetMetricv4 30 * \brief Computes similarity between two point sets. 31 * 32 * This class is templated over the type of the two point-sets. It 33 * expects a Transform to be plugged in for each of fixed and moving 34 * point sets. The transforms default to IdenityTransform types. This particular 35 * class is the base class for a hierarchy of point-set to point-set metrics. 36 * 37 * This class computes a value that measures the similarity between the fixed 38 * point-set and the moving point-set in the moving domain. The fixed point set 39 * is transformed into the virtual domain by computing the inverse of the 40 * fixed transform, then transformed into the moving domain using the 41 * moving transform. 42 * 43 * Since the \c PointSet class permits each \c Point to be associated with a 44 * \c PixelType, there are potential applications which could make use of 45 * this additional information. For example, the derived \c LabeledPointSetToPointSetMetric 46 * class uses the \c PixelType as a \c LabelType for estimating total metric values 47 * and gradients from the individual label-wise point subset metric and derivatives 48 * 49 * If a virtual domain is not defined by the user, one of two things happens: 50 * 1) If the moving transform is a global type, then the virtual domain is 51 * left undefined and every point is considered to be within the virtual domain. 52 * 2) If the moving transform is a local-support type, then the virtual domain 53 * is taken during initialization from the moving transform displacement field, 54 * and all fixed points are verified to be within the virtual domain after 55 * transformation by the inverse fixed transform. Points outside the virtual 56 * domain are not used. See GetNumberOfValidPoints() to verify how many fixed 57 * points were used during evaluation. 58 * 59 * See ObjectToObjectMetric documentation for more discussion on the virutal domain. 60 * 61 * \note When used with an RegistrationParameterScalesEstimator estimator, a VirtualDomainPointSet 62 * must be defined and assigned to the estimator, for use in shift estimation. 63 * The virtual domain point set can be retrieved from the metric using the 64 * GetVirtualTransformedPointSet() method. 65 * 66 * \ingroup ITKMetricsv4 67 */ 68 69 template<typename TFixedPointSet, typename TMovingPointSet, 70 class TInternalComputationValueType = double> 71 class ITK_TEMPLATE_EXPORT PointSetToPointSetMetricv4 72 : public ObjectToObjectMetric<TFixedPointSet::PointDimension, TMovingPointSet::PointDimension, 73 Image<TInternalComputationValueType, TFixedPointSet::PointDimension>, TInternalComputationValueType> 74 { 75 public: 76 ITK_DISALLOW_COPY_AND_ASSIGN(PointSetToPointSetMetricv4); 77 78 /** Standard class type aliases. */ 79 using Self = PointSetToPointSetMetricv4; 80 using Superclass = ObjectToObjectMetric<TFixedPointSet::PointDimension, 81 TMovingPointSet::PointDimension, 82 Image<TInternalComputationValueType, TFixedPointSet::PointDimension>, 83 TInternalComputationValueType>; 84 using Pointer = SmartPointer<Self>; 85 using ConstPointer = SmartPointer<const Self>; 86 87 /** Run-time type information (and related methods). */ 88 itkTypeMacro( PointSetToPointSetMetricv4, ObjectToObjectMetric ); 89 90 /** Type of the measure. */ 91 using MeasureType = typename Superclass::MeasureType; 92 93 /** Type of the parameters. */ 94 using ParametersType = typename Superclass::ParametersType; 95 using ParametersValueType = typename Superclass::ParametersValueType; 96 using NumberOfParametersType = typename Superclass::NumberOfParametersType; 97 98 /** Type of the derivative. */ 99 using DerivativeType = typename Superclass::DerivativeType; 100 101 /** Transform types from Superclass*/ 102 using FixedTransformType = typename Superclass::FixedTransformType; 103 using FixedTransformPointer = typename Superclass::FixedTransformPointer; 104 using FixedInputPointType = typename Superclass::FixedInputPointType; 105 using FixedOutputPointType = typename Superclass::FixedOutputPointType; 106 using FixedTransformParametersType = typename Superclass::FixedTransformParametersType; 107 108 using MovingTransformType = typename Superclass::MovingTransformType; 109 using MovingTransformPointer = typename Superclass::MovingTransformPointer; 110 using MovingInputPointType = typename Superclass::MovingInputPointType; 111 using MovingOutputPointType = typename Superclass::MovingOutputPointType; 112 using MovingTransformParametersType = typename Superclass::MovingTransformParametersType; 113 114 using JacobianType = typename Superclass::JacobianType; 115 using FixedTransformJacobianType = typename Superclass::FixedTransformJacobianType; 116 using MovingTransformJacobianType = typename Superclass::MovingTransformJacobianType; 117 118 using DisplacementFieldTransformType = typename Superclass::MovingDisplacementFieldTransformType; 119 120 using ObjectType = typename Superclass::ObjectType; 121 122 /** Dimension type */ 123 using DimensionType = typename Superclass::DimensionType; 124 125 /** Type of the fixed point set. */ 126 using FixedPointSetType = TFixedPointSet; 127 using FixedPointType = typename TFixedPointSet::PointType; 128 using FixedPixelType = typename TFixedPointSet::PixelType; 129 using FixedPointsContainer = typename TFixedPointSet::PointsContainer; 130 131 static constexpr DimensionType FixedPointDimension = Superclass::FixedDimension; 132 133 /** Type of the moving point set. */ 134 using MovingPointSetType = TMovingPointSet; 135 using MovingPointType = typename TMovingPointSet::PointType; 136 using MovingPixelType = typename TMovingPointSet::PixelType; 137 using MovingPointsContainer = typename TMovingPointSet::PointsContainer; 138 139 static constexpr DimensionType MovingPointDimension = Superclass::MovingDimension; 140 141 /** 142 * typedefs for the data types used in the point set metric calculations. 143 * It is assumed that the constants of the fixed point set, such as the 144 * point dimension, are the same for the "common space" in which the metric 145 * calculation occurs. 146 */ 147 static constexpr DimensionType PointDimension = Superclass::FixedDimension; 148 149 using PointType = FixedPointType; 150 using PixelType = FixedPixelType; 151 using CoordRepType = typename PointType::CoordRepType; 152 using PointsContainer = FixedPointsContainer; 153 using PointsConstIterator = typename PointsContainer::ConstIterator; 154 using PointIdentifier = typename PointsContainer::ElementIdentifier; 155 156 /** Typedef for points locator class to speed up finding neighboring points */ 157 using PointsLocatorType = PointsLocator< PointsContainer>; 158 using NeighborsIdentifierType = typename PointsLocatorType::NeighborsIdentifierType; 159 160 using FixedTransformedPointSetType = PointSet<FixedPixelType, Self::PointDimension >; 161 using MovingTransformedPointSetType = PointSet<MovingPixelType, Self::PointDimension >; 162 163 using DerivativeValueType = typename DerivativeType::ValueType; 164 using LocalDerivativeType = FixedArray<DerivativeValueType, Self::PointDimension >; 165 166 /** Types for the virtual domain */ 167 using VirtualImageType = typename Superclass::VirtualImageType; 168 using VirtualImagePointer = typename Superclass::VirtualImagePointer; 169 using VirtualPixelType = typename Superclass::VirtualPixelType; 170 using VirtualRegionType = typename Superclass::VirtualRegionType; 171 using VirtualSizeType = typename Superclass::VirtualSizeType; 172 using VirtualSpacingType = typename Superclass::VirtualSpacingType; 173 using VirtualOriginType = typename Superclass::VirtualPointType; 174 using VirtualPointType = typename Superclass::VirtualPointType; 175 using VirtualDirectionType = typename Superclass::VirtualDirectionType; 176 using VirtualRadiusType = typename Superclass::VirtualSizeType; 177 using VirtualIndexType = typename Superclass::VirtualIndexType; 178 using VirtualPointSetType = typename Superclass::VirtualPointSetType; 179 using VirtualPointSetPointer = typename Superclass::VirtualPointSetPointer; 180 181 /** Set fixed point set*/ SetFixedObject(const ObjectType * object)182 void SetFixedObject( const ObjectType *object ) override 183 { 184 auto * pointSet = dynamic_cast<FixedPointSetType *>( const_cast<ObjectType *>( object ) ); 185 if( pointSet != nullptr ) 186 { 187 this->SetFixedPointSet( pointSet ); 188 } 189 else 190 { 191 itkExceptionMacro( "Incorrect object type. Should be a point set." ) 192 } 193 } 194 195 /** Set moving point set*/ SetMovingObject(const ObjectType * object)196 void SetMovingObject( const ObjectType *object ) override 197 { 198 auto * pointSet = dynamic_cast<MovingPointSetType *>( const_cast<ObjectType *>( object ) ); 199 if( pointSet != nullptr ) 200 { 201 this->SetMovingPointSet( pointSet ); 202 } 203 else 204 { 205 itkExceptionMacro( "Incorrect object type. Should be a point set." ) 206 } 207 } 208 209 /** Get/Set the fixed pointset. */ 210 itkSetConstObjectMacro( FixedPointSet, FixedPointSetType ); 211 itkGetConstObjectMacro( FixedPointSet, FixedPointSetType ); 212 213 /** Get the fixed transformed point set. */ 214 itkGetModifiableObjectMacro( FixedTransformedPointSet, FixedTransformedPointSetType ); 215 216 /** Get/Set the moving point set. */ 217 itkSetConstObjectMacro( MovingPointSet, MovingPointSetType ); 218 itkGetConstObjectMacro( MovingPointSet, MovingPointSetType ); 219 220 /** Get the moving transformed point set. */ 221 itkGetModifiableObjectMacro( MovingTransformedPointSet, MovingTransformedPointSetType ); 222 223 /** 224 * For now return the number of points used in the value/derivative calculations. 225 */ 226 SizeValueType GetNumberOfComponents() const; 227 228 /** 229 * This method returns the value of the metric based on the current 230 * transformation(s). This function can be redefined in derived classes 231 * but many point set metrics follow the same structure---one iterates 232 * through the points and, for each point a metric value is calculated. 233 * The summation of these individual point metric values gives the total 234 * value of the metric. Note that this might not be applicable to all 235 * point set metrics. For those cases, the developer will have to redefine 236 * the GetValue() function. 237 */ 238 MeasureType GetValue() const override; 239 240 /** 241 * This method returns the derivative based on the current 242 * transformation(s). This function can be redefined in derived classes 243 * but many point set metrics follow the same structure---one iterates 244 * through the points and, for each point a derivative is calculated. 245 * The set of all these local derivatives constitutes the total derivative. 246 * Note that this might not be applicable to all point set metrics. For 247 * those cases, the developer will have to redefine the GetDerivative() 248 * function. 249 */ 250 void GetDerivative( DerivativeType & ) const override; 251 252 /** 253 * This method returns the derivative and value based on the current 254 * transformation(s). This function can be redefined in derived classes 255 * but many point set metrics follow the same structure---one iterates 256 * through the points and, for each point a derivative and value is calculated. 257 * The set of all these local derivatives/values constitutes the total 258 * derivative and value. Note that this might not be applicable to all 259 * point set metrics. For those cases, the developer will have to redefine 260 * the GetValue() and GetDerivative() functions. 261 */ 262 void GetValueAndDerivative( MeasureType &, DerivativeType & ) const override; 263 264 /** 265 * Function to be defined in the appropriate derived classes. Calculates 266 * the local metric value for a single point. The \c PixelType may or 267 * may not be used. See class description for further explanation. 268 */ 269 virtual MeasureType GetLocalNeighborhoodValue( const PointType &, const PixelType & pixel ) const = 0; 270 271 /** 272 * Calculates the local derivative for a single point. The \c PixelType may or 273 * may not be used. See class description for further explanation. 274 */ 275 virtual LocalDerivativeType GetLocalNeighborhoodDerivative( const PointType &, const PixelType & pixel ) const; 276 277 /** 278 * Calculates the local value/derivative for a single point. The \c PixelType may or 279 * may not be used. See class description for further explanation. 280 */ 281 virtual void GetLocalNeighborhoodValueAndDerivative( const PointType &, 282 MeasureType &, LocalDerivativeType &, const PixelType & pixel ) const = 0; 283 284 /** 285 * Get the virtual point set, derived from the fixed point set. 286 * If the virtual point set has not yet been derived, it will be 287 * in this call. */ 288 const VirtualPointSetType * GetVirtualTransformedPointSet() const; 289 290 /** 291 * Initialize the metric by making sure that all the components 292 * are present and plugged together correctly. 293 */ 294 void Initialize() override; 295 SupportsArbitraryVirtualDomainSamples()296 bool SupportsArbitraryVirtualDomainSamples() const override 297 { 298 /* An arbitrary point in the virtual domain will not always 299 * correspond to a point within either point set. */ 300 return false; 301 } 302 303 /** 304 * By default, the point set metric derivative for a displacement field transform 305 * is stored by saving the gradient for every voxel in the displacement field (see 306 * the function StorePointDerivative()). Since the "fixed points" will typically 307 * constitute a sparse set, this means that the field will have zero gradient values 308 * at every voxel that doesn't have a corresponding point. This might cause additional 309 * computation time for certain transforms (e.g. B-spline SyN). To avoid this, this 310 * option permits storing the point derivative only at the fixed point locations. 311 * If this variable is set to false, then the derivative array will be of length 312 * = PointDimension * m_FixedPointSet->GetNumberOfPoints(). 313 */ 314 itkSetMacro( StoreDerivativeAsSparseFieldForLocalSupportTransforms, bool ); 315 itkGetConstMacro( StoreDerivativeAsSparseFieldForLocalSupportTransforms, bool ); 316 itkBooleanMacro( StoreDerivativeAsSparseFieldForLocalSupportTransforms ); 317 318 /** 319 * 320 */ 321 itkSetMacro( CalculateValueAndDerivativeInTangentSpace, bool ); 322 itkGetConstMacro( CalculateValueAndDerivativeInTangentSpace, bool ); 323 itkBooleanMacro( CalculateValueAndDerivativeInTangentSpace ); 324 325 protected: 326 PointSetToPointSetMetricv4(); 327 ~PointSetToPointSetMetricv4() override = default; 328 void PrintSelf( std::ostream & os, Indent indent ) const override; 329 330 typename FixedPointSetType::ConstPointer m_FixedPointSet; 331 mutable typename FixedTransformedPointSetType::Pointer m_FixedTransformedPointSet; 332 333 mutable typename PointsLocatorType::Pointer m_FixedTransformedPointsLocator; 334 335 typename MovingPointSetType::ConstPointer m_MovingPointSet; 336 mutable typename MovingTransformedPointSetType::Pointer m_MovingTransformedPointSet; 337 338 mutable typename PointsLocatorType::Pointer m_MovingTransformedPointsLocator; 339 340 /** Holds the fixed points after transformation into virtual domain. */ 341 mutable VirtualPointSetPointer m_VirtualTransformedPointSet; 342 343 /** 344 * Bool set by derived classes on whether the point set data (i.e. \c PixelType) 345 * should be used. Default = false. 346 */ 347 bool m_UsePointSetData; 348 349 /** 350 * Flag to calculate value and/or derivative at tangent space. This is needed 351 * for the diffeomorphic registration methods. The fixed and moving points are 352 * warped to the virtual domain where the metric is calculated. Derived point 353 * set metrics might have associated gradient information which will need to be 354 * warped if this flag is true. Default = false. 355 */ 356 bool m_CalculateValueAndDerivativeInTangentSpace; 357 358 /** 359 * Prepare point sets for use. */ 360 virtual void InitializePointSets() const; 361 362 /** 363 * Initialize to prepare for a particular iteration, generally 364 * an iteration of optimization. Distinct from Initialize() 365 * which is a one-time initialization. */ 366 virtual void InitializeForIteration() const; 367 368 /** 369 * Determine the number of valid fixed points. A fixed point 370 * is valid if, when transformed into the virtual domain using 371 * the inverse of the FixedTransform, it is within the defined 372 * virtual domain bounds. */ 373 virtual SizeValueType CalculateNumberOfValidFixedPoints() const; 374 375 /** Helper method allows for code reuse while skipping the metric value 376 * calculation when appropriate */ 377 void CalculateValueAndDerivative( MeasureType & value, DerivativeType & derivative, bool calculateValue ) const; 378 379 /** 380 * Warp the fixed point set into the moving domain based on the fixed transform, 381 * passing through the virtual domain and storing a virtual domain set. 382 * Note that the warped moving point set is of type FixedPointSetType since the transform 383 * takes the points from the fixed to the moving domain. 384 */ 385 void TransformFixedAndCreateVirtualPointSet() const; 386 387 /** 388 * Warp the moving point set based on the moving transform. Note that the 389 * warped moving point set is of type FixedPointSetType since the transform 390 * takes the points from the moving to the fixed domain. 391 * FIXME: needs update. 392 */ 393 void TransformMovingPointSet() const; 394 395 /** 396 * Build point locators for the fixed and moving point sets to speed up 397 * derivative and value calculations. 398 */ 399 void InitializePointsLocators() const; 400 401 /** 402 * Store a derivative from a single point in a field. 403 * Only relevant when active transform has local support. 404 */ 405 void StorePointDerivative( const VirtualPointType &, const DerivativeType &, DerivativeType & ) const; 406 407 using MetricCategoryType = typename Superclass::MetricCategoryType; 408 409 /** Get metric category */ GetMetricCategory()410 MetricCategoryType GetMetricCategory() const override 411 { 412 return Superclass::POINT_SET_METRIC; 413 } 414 415 416 private: 417 mutable bool m_MovingTransformPointLocatorsNeedInitialization; 418 mutable bool m_FixedTransformPointLocatorsNeedInitialization; 419 420 // Flag to keep track of whether a warning has already been issued 421 // regarding the number of valid points. 422 mutable bool m_HaveWarnedAboutNumberOfValidPoints; 423 424 // Flag to store derivatives at fixed point locations with the rest being zero gradient 425 // (default = true). 426 bool m_StoreDerivativeAsSparseFieldForLocalSupportTransforms; 427 428 mutable ModifiedTimeType m_MovingTransformedPointSetTime; 429 mutable ModifiedTimeType m_FixedTransformedPointSetTime; 430 }; 431 } // end namespace itk 432 433 #ifndef ITK_MANUAL_INSTANTIATION 434 #include "itkPointSetToPointSetMetricv4.hxx" 435 #endif 436 437 #endif 438