1 /*========================================================================= 2 * 3 * Copyright Insight Software Consortium 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0.txt 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 *=========================================================================*/ 18 #ifndef itkImageRegistrationMethodv4_h 19 #define itkImageRegistrationMethodv4_h 20 21 #include "itkProcessObject.h" 22 23 #include "itkCompositeTransform.h" 24 #include "itkDataObjectDecorator.h" 25 #include "itkObjectToObjectMetricBase.h" 26 #include "itkObjectToObjectMultiMetricv4.h" 27 #include "itkObjectToObjectOptimizerBase.h" 28 #include "itkImageToImageMetricv4.h" 29 #include "itkPointSetToPointSetMetricv4.h" 30 #include "itkShrinkImageFilter.h" 31 #include "itkIdentityTransform.h" 32 #include "itkTransformParametersAdaptorBase.h" 33 34 #include <vector> 35 36 namespace itk 37 { 38 39 /** \class ImageRegistrationMethodv4 40 * \brief Interface method for the current registration framework. 41 * 42 * This interface method class encapsulates typical registration 43 * usage by incorporating all the necessary elements for performing a 44 * simple image registration between two images. This method also 45 * allows for multistage registration whereby each stage is 46 * characterize by possibly different transforms of and different 47 * image metrics. For example, many users will want to perform 48 * a linear registration followed by deformable registration where 49 * both stages are performed in multiple levels. Each level can be 50 * characterized by: 51 * 52 * \li the resolution of the virtual domain image (see below) 53 * \li smoothing of the fixed and moving images 54 * \li the coarseness of the current transform via transform adaptors 55 * (see below) 56 * 57 * Multiple stages are handled by linking multiple instantiations of 58 * this class where the output transform is added to the optional 59 * composite transform input. 60 * 61 * Transform adaptors: To accommodate new changes to the current ITK 62 * registration framework, we introduced the concept of transform adaptors. 63 * Whereas each stage is associated with a moving and, possibly, fixed 64 * transform, each level of each stage is defined by a transform adaptor 65 * which describes how to adapt the transform to the current level. For 66 * example, if one were to use the B-spline transform during a deformable 67 * registration stage, common practice is to increase the resolution of 68 * the B-spline mesh (or, analogously, the control point grid size) at 69 * each level. At each level, one would define the parameters of the 70 * B-spline transform adaptor at that level which increases the resolution 71 * from the previous level. For many transforms, such as affine, this 72 * concept of an adaptor may be nonsensical. For this reason, the base 73 * transform adaptor class does not do anything to the transform but merely 74 * passes it through. Each level of each stage must define a transform 75 * adaptor but, by default, the base adaptor class is assigned which, again, 76 * does not do anything to the transform. A special mention should be made 77 * of the transform adaptor at level 0 of any stage. Most likely, the user 78 * will not want to do anything to the transform as it enters into the 79 * given stage so typical use will be to assign the base adaptor class to 80 * level 0 of all stages but we leave that open to the user. 81 * 82 * Output: The output is the updated transform. 83 * 84 * \author Nick Tustison 85 * \author Brian Avants 86 * 87 * \ingroup ITKRegistrationMethodsv4 88 */ 89 template<typename TFixedImage, 90 typename TMovingImage, 91 typename TOutputTransform = Transform<double, TFixedImage::ImageDimension, TFixedImage::ImageDimension>, 92 typename TVirtualImage = TFixedImage, 93 typename TPointSet = PointSet<unsigned int, TFixedImage::ImageDimension> > 94 class ITK_TEMPLATE_EXPORT ImageRegistrationMethodv4 95 :public ProcessObject 96 { 97 public: 98 ITK_DISALLOW_COPY_AND_ASSIGN(ImageRegistrationMethodv4); 99 100 /** Standard class type aliases. */ 101 using Self = ImageRegistrationMethodv4; 102 using Superclass = ProcessObject; 103 using Pointer = SmartPointer<Self>; 104 using ConstPointer = SmartPointer<const Self>; 105 106 /** Method for creation through the object factory. */ 107 itkNewMacro( Self ); 108 109 /** ImageDimension constants */ 110 static constexpr unsigned int ImageDimension = TFixedImage::ImageDimension; 111 112 /** Run-time type information (and related methods). */ 113 itkTypeMacro( ImageRegistrationMethodv4, ProcessObject ); 114 115 /** Input type alias for the images and transforms. */ 116 using FixedImageType = TFixedImage; 117 using FixedImagePointer = typename FixedImageType::Pointer; 118 using FixedImageConstPointer = typename FixedImageType::ConstPointer; 119 using FixedImagesContainerType = std::vector<FixedImageConstPointer>; 120 using MovingImageType = TMovingImage; 121 using MovingImagePointer = typename MovingImageType::Pointer; 122 using MovingImageConstPointer = typename MovingImageType::ConstPointer; 123 using MovingImagesContainerType = std::vector<MovingImageConstPointer>; 124 125 using PointSetType = TPointSet; 126 using PointSetConstPointer = typename PointSetType::ConstPointer; 127 using PointSetsContainerType = std::vector<PointSetConstPointer>; 128 129 /** Metric and transform type alias */ 130 using OutputTransformType = TOutputTransform; 131 using OutputTransformPointer = typename OutputTransformType::Pointer; 132 using RealType = typename OutputTransformType::ScalarType; 133 using DerivativeType = typename OutputTransformType::DerivativeType; 134 using DerivativeValueType = typename DerivativeType::ValueType; 135 136 using InitialTransformType = Transform<RealType, ImageDimension, ImageDimension>; 137 using InitialTransformPointer = typename InitialTransformType::Pointer; 138 139 using CompositeTransformType = CompositeTransform<RealType, ImageDimension>; 140 using CompositeTransformPointer = typename CompositeTransformType::Pointer; 141 142 using MetricType = ObjectToObjectMetricBaseTemplate<RealType>; 143 using MetricPointer = typename MetricType::Pointer; 144 145 using VectorType = Vector<RealType, ImageDimension>; 146 147 using VirtualImageType = TVirtualImage; 148 using VirtualImagePointer = typename VirtualImageType::Pointer; 149 using VirtualImageBaseType = ImageBase<ImageDimension>; 150 using VirtualImageBaseConstPointer = typename VirtualImageBaseType::ConstPointer; 151 152 using MultiMetricType = ObjectToObjectMultiMetricv4<ImageDimension, ImageDimension, VirtualImageType, RealType>; 153 using ImageMetricType = ImageToImageMetricv4<FixedImageType, MovingImageType, VirtualImageType, RealType>; 154 using PointSetMetricType = PointSetToPointSetMetricv4<PointSetType, PointSetType, RealType>; 155 156 using FixedImageMaskType = typename ImageMetricType::FixedImageMaskType; 157 using FixedImageMaskConstPointer = typename FixedImageMaskType::ConstPointer; 158 using FixedImageMasksContainerType = std::vector<FixedImageMaskConstPointer>; 159 using MovingImageMaskType = typename ImageMetricType::MovingImageMaskType; 160 using MovingImageMaskConstPointer = typename MovingImageMaskType::ConstPointer; 161 using MovingImageMasksContainerType = std::vector<MovingImageMaskConstPointer>; 162 163 /** 164 * Type for the output: Using Decorator pattern for enabling the transform to be 165 * passed in the data pipeline 166 */ 167 using DecoratedOutputTransformType = DataObjectDecorator<OutputTransformType>; 168 using DecoratedOutputTransformPointer = typename DecoratedOutputTransformType::Pointer; 169 using DecoratedInitialTransformType = DataObjectDecorator<InitialTransformType>; 170 using DecoratedInitialTransformPointer = typename DecoratedInitialTransformType::Pointer; 171 172 using ShrinkFilterType = ShrinkImageFilter<FixedImageType, VirtualImageType>; 173 using ShrinkFactorsPerDimensionContainerType = typename ShrinkFilterType::ShrinkFactorsType; 174 175 using ShrinkFactorsArrayType = Array<SizeValueType>; 176 177 using SmoothingSigmasArrayType = Array<RealType>; 178 using MetricSamplingPercentageArrayType = Array<RealType>; 179 180 /** Transform adaptor type alias */ 181 using TransformParametersAdaptorType = TransformParametersAdaptorBase<InitialTransformType>; 182 using TransformParametersAdaptorPointer = typename TransformParametersAdaptorType::Pointer; 183 using TransformParametersAdaptorsContainerType = std::vector<TransformParametersAdaptorPointer>; 184 185 /** Type of the optimizer. */ 186 using OptimizerType = ObjectToObjectOptimizerBaseTemplate<RealType>; 187 using OptimizerPointer = typename OptimizerType::Pointer; 188 189 /** Weights type for the optimizer. */ 190 using OptimizerWeightsType = typename OptimizerType::ScalesType; 191 192 /** enum type for metric sampling strategy */ 193 enum MetricSamplingStrategyType { NONE, REGULAR, RANDOM }; 194 195 using MetricSamplePointSetType = typename ImageMetricType::FixedSampledPointSetType; 196 197 /** Set/get the fixed images. */ SetFixedImage(const FixedImageType * image)198 virtual void SetFixedImage( const FixedImageType *image ) 199 { 200 this->SetFixedImage( 0, image ); 201 } GetFixedImage()202 virtual const FixedImageType * GetFixedImage() const 203 { 204 return this->GetFixedImage( 0 ); 205 } 206 virtual void SetFixedImage( SizeValueType, const FixedImageType * ); 207 virtual const FixedImageType * GetFixedImage( SizeValueType ) const; 208 209 /** Set the moving images. */ SetMovingImage(const MovingImageType * image)210 virtual void SetMovingImage( const MovingImageType *image ) 211 { 212 this->SetMovingImage( 0, image ); 213 } GetMovingImage()214 virtual const MovingImageType * GetMovingImage() const 215 { 216 return this->GetMovingImage( 0 ); 217 } 218 virtual void SetMovingImage( SizeValueType, const MovingImageType * ); 219 virtual const MovingImageType * GetMovingImage( SizeValueType ) const; 220 221 /** Set/get the fixed point sets. */ SetFixedPointSet(const PointSetType * pointSet)222 virtual void SetFixedPointSet( const PointSetType *pointSet ) 223 { 224 this->SetFixedPointSet( 0, pointSet ); 225 } GetFixedPointSet()226 virtual const PointSetType * GetFixedPointSet() const 227 { 228 return this->GetFixedPointSet( 0 ); 229 } 230 virtual void SetFixedPointSet( SizeValueType, const PointSetType * ); 231 virtual const PointSetType * GetFixedPointSet( SizeValueType ) const; 232 233 /** Set the moving point sets. */ SetMovingPointSet(const PointSetType * pointSet)234 virtual void SetMovingPointSet( const PointSetType *pointSet ) 235 { 236 this->SetMovingPointSet( 0, pointSet ); 237 } GetMovingPointSet()238 virtual const PointSetType * GetMovingPointSet() const 239 { 240 return this->GetMovingPointSet( 0 ); 241 } 242 virtual void SetMovingPointSet( SizeValueType, const PointSetType * ); 243 virtual const PointSetType * GetMovingPointSet( SizeValueType ) const; 244 245 /** Set/Get the optimizer. */ 246 itkSetObjectMacro( Optimizer, OptimizerType ); 247 itkGetModifiableObjectMacro( Optimizer, OptimizerType ); 248 249 /** 250 * Set/Get the optimizer weights. Allows setting of a per-local-parameter 251 * weighting array. If unset, the weights are treated as identity. Weights 252 * are used to mask out a particular parameter during optimzation to hold 253 * it constant. Or they may be used to apply another kind of prior knowledge. 254 * The size of the weights must be equal to the number of the local transformation 255 * parameters. 256 */ 257 void SetOptimizerWeights( OptimizerWeightsType & ); 258 itkGetConstMacro( OptimizerWeights, OptimizerWeightsType ); 259 260 /** Set/Get the metric. */ 261 itkSetObjectMacro( Metric, MetricType ); 262 itkGetModifiableObjectMacro( Metric, MetricType ); 263 264 /** Set/Get the metric sampling strategy. */ 265 itkSetMacro( MetricSamplingStrategy, MetricSamplingStrategyType ); 266 itkGetConstMacro( MetricSamplingStrategy, MetricSamplingStrategyType ); 267 268 /** Reinitialize the seed for the random number generators that 269 * select the samples for some metric sampling strategies. 270 * 271 * By initializing the random number generator seed to a value the 272 * same deterministic sampling will be used each Update 273 * execution. On the other hand, calling the method 274 * ReinitializeSeed() without arguments will use the wall clock in 275 * order to have psuedo-random initialization of the seeds. This 276 * will indeed increase the non-deterministic behavior of the 277 * metric. 278 */ 279 void MetricSamplingReinitializeSeed(); 280 void MetricSamplingReinitializeSeed(int seed); 281 282 /** Set the metric sampling percentage. Valid values are in (0.0, 1.0] */ 283 void SetMetricSamplingPercentage( const RealType ); 284 285 /** Set the metric sampling percentage. Valid values are in (0.0,1.0]. */ 286 virtual void SetMetricSamplingPercentagePerLevel( const MetricSamplingPercentageArrayType &samplingPercentages ); 287 itkGetConstMacro( MetricSamplingPercentagePerLevel, MetricSamplingPercentageArrayType ); 288 289 /** Set/Get the initial fixed transform. */ 290 itkSetGetDecoratedObjectInputMacro( FixedInitialTransform, InitialTransformType ); 291 292 /** Set/Get the initial moving transform. */ 293 itkSetGetDecoratedObjectInputMacro( MovingInitialTransform, InitialTransformType ); 294 295 /** Set/Get the initial transform to be optimized 296 * 297 * This transform is composed with the MovingInitialTransform to 298 * specify the initial transformation from the moving image to 299 * the virtual image. It is used for the default parameters, and can 300 * be use to specify the transform type. 301 * 302 * If the filter has "InPlace" set then this transform will be the 303 * output transform object or "grafted" to the output. Otherwise, 304 * this InitialTransform will be deep copied or "cloned" to the 305 * output. 306 * 307 * If this parameter is not set then a default constructed output 308 * transform is used. 309 */ 310 itkSetGetDecoratedObjectInputMacro(InitialTransform, InitialTransformType); 311 312 /** Set/Get the transform adaptors. */ 313 void SetTransformParametersAdaptorsPerLevel( TransformParametersAdaptorsContainerType & ); 314 const TransformParametersAdaptorsContainerType & GetTransformParametersAdaptorsPerLevel() const; 315 316 /** 317 * Set/Get the number of multi-resolution levels. In setting the number of 318 * levels we need to set the following for each level: 319 * \li shrink factors for the virtual domain 320 * \li sigma smoothing parameter 321 * \li transform adaptor with specific parameters for the specified level 322 */ 323 void SetNumberOfLevels( const SizeValueType ); 324 itkGetConstMacro( NumberOfLevels, SizeValueType ); 325 326 /** 327 * Set the shrink factors for each level where each level has a constant 328 * shrink factor for each dimension. For example, input to the function 329 * of factors = [4,2,1] will shrink the image in every dimension by 4 330 * the first level, then by 2 at the second level, then the original resolution 331 * for the final level (uses the \c itkShrinkImageFilter). 332 */ SetShrinkFactorsPerLevel(ShrinkFactorsArrayType factors)333 void SetShrinkFactorsPerLevel( ShrinkFactorsArrayType factors ) 334 { 335 for( unsigned int level = 0; level < factors.Size(); ++level ) 336 { 337 ShrinkFactorsPerDimensionContainerType shrinkFactors; 338 shrinkFactors.Fill( factors[level] ); 339 this->SetShrinkFactorsPerDimension( level, shrinkFactors ); 340 } 341 } 342 343 /** 344 * Get the shrink factors for a specific level. 345 */ GetShrinkFactorsPerDimension(const unsigned int level)346 ShrinkFactorsPerDimensionContainerType GetShrinkFactorsPerDimension( const unsigned int level ) const 347 { 348 if( level >= this->m_ShrinkFactorsPerLevel.size() ) 349 { 350 itkExceptionMacro( "Requesting level greater than the number of levels." ); 351 } 352 return this->m_ShrinkFactorsPerLevel[level]; 353 } 354 355 /** 356 * Set the shrink factors for a specific level for each dimension. 357 */ SetShrinkFactorsPerDimension(unsigned int level,ShrinkFactorsPerDimensionContainerType factors)358 void SetShrinkFactorsPerDimension( unsigned int level, ShrinkFactorsPerDimensionContainerType factors ) 359 { 360 if( level >= this->m_ShrinkFactorsPerLevel.size() ) 361 { 362 this->m_ShrinkFactorsPerLevel.resize( level + 1 ); 363 } 364 this->m_ShrinkFactorsPerLevel[level] = factors; 365 this->Modified(); 366 } 367 368 /** 369 * Set/Get the smoothing sigmas for each level. At each resolution level, a gaussian smoothing 370 * filter (specifically, the \c itkDiscreteGaussianImageFilter) is applied. Sigma values are 371 * specified according to the option \c m_SmoothingSigmasAreSpecifiedInPhysicalUnits. 372 */ 373 itkSetMacro( SmoothingSigmasPerLevel, SmoothingSigmasArrayType ); 374 itkGetConstMacro( SmoothingSigmasPerLevel, SmoothingSigmasArrayType ); 375 376 /** 377 * Set/Get whether to specify the smoothing sigmas for each level in physical units 378 * (default) or in terms of voxels. 379 */ 380 itkSetMacro( SmoothingSigmasAreSpecifiedInPhysicalUnits, bool ); 381 itkGetConstMacro( SmoothingSigmasAreSpecifiedInPhysicalUnits, bool ); 382 itkBooleanMacro( SmoothingSigmasAreSpecifiedInPhysicalUnits ); 383 384 /** Make a DataObject of the correct type to be used as the specified output. */ 385 using DataObjectPointerArraySizeType = ProcessObject::DataObjectPointerArraySizeType; 386 using Superclass::MakeOutput; 387 DataObjectPointer MakeOutput( DataObjectPointerArraySizeType ) override; 388 389 /** Returns the transform resulting from the registration process */ 390 virtual DecoratedOutputTransformType * GetOutput(); 391 virtual const DecoratedOutputTransformType * GetOutput() const; 392 GetTransformOutput()393 virtual DecoratedOutputTransformType * GetTransformOutput() { return this->GetOutput(); } GetTransformOutput()394 virtual const DecoratedOutputTransformType * GetTransformOutput() const { return this->GetOutput(); } 395 396 virtual OutputTransformType * GetModifiableTransform(); 397 virtual const OutputTransformType * GetTransform() const; 398 399 /** Get the current level. This is a helper function for reporting observations. */ 400 itkGetConstMacro( CurrentLevel, SizeValueType ); 401 402 /** Get the current iteration. This is a helper function for reporting observations. */ 403 itkGetConstReferenceMacro( CurrentIteration, SizeValueType ); 404 405 /* Get the current metric value. This is a helper function for reporting observations. */ 406 itkGetConstReferenceMacro( CurrentMetricValue, RealType ); 407 408 /** Get the current convergence value. This is a helper function for reporting observations. */ 409 itkGetConstReferenceMacro( CurrentConvergenceValue, RealType ); 410 411 /** Get the current convergence state per level. This is a helper function for reporting observations. */ 412 itkGetConstReferenceMacro( IsConverged, bool ); 413 414 /** Request that the InitialTransform be grafted onto the output, 415 * there by not creating a copy. 416 */ 417 itkSetMacro( InPlace, bool ); 418 itkGetConstMacro( InPlace, bool ); 419 itkBooleanMacro( InPlace ); 420 421 /** 422 * Initialize the current linear transform to be optimized with the center of the 423 * previous transform in the queue. This provides a much better initialization than 424 * the default origin. 425 */ 426 itkBooleanMacro( InitializeCenterOfLinearOutputTransform ); 427 itkSetMacro( InitializeCenterOfLinearOutputTransform, bool ); 428 itkGetConstMacro( InitializeCenterOfLinearOutputTransform, bool ); 429 430 /** 431 * We try to initialize the center of a linear transform (specifically those 432 * derived from itk::MatrixOffsetTransformBase). There are a number of 433 * checks that we need to make to account for all possible scenarios: 434 * 1) we check to make sure the m_OutputTransform is of the appropriate type 435 * such that it makes sense to try to center the transform. Local transforms 436 * such as SyN and B-spline do not need to be "centered", 437 * 2) we check to make sure the composite transform (to which we'll add the 438 * m_OutputTransform) is not empty, 439 * 3) we look for the first previous transform which has a center parameter, 440 * (which, presumably, been optimized beforehand), and 441 */ 442 void InitializeCenterOfLinearOutputTransform(); 443 444 protected: 445 ImageRegistrationMethodv4(); 446 ~ImageRegistrationMethodv4() override = default; 447 void PrintSelf( std::ostream & os, Indent indent ) const override; 448 449 /** Perform the registration. */ 450 void GenerateData() override; 451 452 virtual void AllocateOutputs(); 453 454 /** Initialize by setting the interconnects between the components. */ 455 virtual void InitializeRegistrationAtEachLevel( const SizeValueType ); 456 457 /** Get the virtual domain image from the metric(s) */ 458 virtual VirtualImageBaseConstPointer GetCurrentLevelVirtualDomainImage(); 459 460 /** Get metric samples. */ 461 virtual void SetMetricSamplePoints(); 462 463 SizeValueType m_CurrentLevel; 464 SizeValueType m_NumberOfLevels; 465 SizeValueType m_CurrentIteration; 466 RealType m_CurrentMetricValue; 467 RealType m_CurrentConvergenceValue; 468 bool m_IsConverged; 469 470 FixedImagesContainerType m_FixedSmoothImages; 471 MovingImagesContainerType m_MovingSmoothImages; 472 FixedImageMasksContainerType m_FixedImageMasks; 473 MovingImageMasksContainerType m_MovingImageMasks; 474 VirtualImagePointer m_VirtualDomainImage; 475 PointSetsContainerType m_FixedPointSets; 476 PointSetsContainerType m_MovingPointSets; 477 SizeValueType m_NumberOfFixedObjects; 478 SizeValueType m_NumberOfMovingObjects; 479 480 OptimizerPointer m_Optimizer; 481 OptimizerWeightsType m_OptimizerWeights; 482 bool m_OptimizerWeightsAreIdentity; 483 484 MetricPointer m_Metric; 485 MetricSamplingStrategyType m_MetricSamplingStrategy; 486 MetricSamplingPercentageArrayType m_MetricSamplingPercentagePerLevel; 487 SizeValueType m_NumberOfMetrics; 488 int m_FirstImageMetricIndex; 489 std::vector<ShrinkFactorsPerDimensionContainerType> m_ShrinkFactorsPerLevel; 490 SmoothingSigmasArrayType m_SmoothingSigmasPerLevel; 491 bool m_SmoothingSigmasAreSpecifiedInPhysicalUnits; 492 493 bool m_ReseedIterator; 494 int m_RandomSeed; 495 int m_CurrentRandomSeed; 496 497 498 TransformParametersAdaptorsContainerType m_TransformParametersAdaptorsPerLevel; 499 500 CompositeTransformPointer m_CompositeTransform; 501 502 //TODO: m_OutputTransform should be removed and replaced with a named input parameter for 503 // the pipeline 504 OutputTransformPointer m_OutputTransform; 505 506 507 private: 508 bool m_InPlace; 509 510 bool m_InitializeCenterOfLinearOutputTransform; 511 512 // helper function to create the right kind of concrete transform 513 template<typename TTransform> MakeOutputTransform(SmartPointer<TTransform> & ptr)514 static void MakeOutputTransform(SmartPointer<TTransform> &ptr) 515 { 516 ptr = TTransform::New(); 517 } 518 MakeOutputTransform(SmartPointer<InitialTransformType> & ptr)519 static void MakeOutputTransform(SmartPointer<InitialTransformType> &ptr) 520 { 521 ptr = IdentityTransform<RealType, ImageDimension>::New().GetPointer(); 522 } 523 524 }; 525 } // end namespace itk 526 527 #ifndef ITK_MANUAL_INSTANTIATION 528 #include "itkImageRegistrationMethodv4.hxx" 529 #endif 530 531 #endif 532