1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef itkDisplacementFieldTransform_h
19 #define itkDisplacementFieldTransform_h
20 
21 #include "itkTransform.h"
22 
23 #include "itkImage.h"
24 #include "itkMatrixOffsetTransformBase.h"
25 #include "itkImageVectorOptimizerParametersHelper.h"
26 #include "itkVectorInterpolateImageFunction.h"
27 
28 namespace itk
29 {
30 
31 /** \class DisplacementFieldTransform
 * \brief Provides local/dense/high-dimensionality transformation via a
 * displacement field.
34  *
35  * The displacement field stores vectors of displacements, with
36  * dimension \c NDimensions. Transformation is performed at a given
37  * point by adding the displacement at that point to the input point.
38  *
39  * T(x, p), x is the position, p is the local parameter at position x.
40  * For a 2D example:
41  *
42  *  x = (x0, x1), p = (p0, p1)
43  *
44  * then T(x, p) is defined as:
45  *
46  *    T(x, p) = (T0, T1) = (x0+p0, x1+p1)
47  *
48  * During transformation, out-of-bounds input points are returned
49  * with zero displacement.
50  *
51  * The displacement field is defined using an itkImage, and must be set
52  * before use by the user, using \c SetDisplacementField. The image has
53  * the same dimensionality as the input and output spaces, defined by
54  * template parameter \c NDimensions, and is an image of vectors of
55  * type \c OutputVectorType, with dimensionality NDimensions as well.
56  *
57  * An interpolator of type \c VectorInterpolateImageFunction is used with
58  * the displacement field image. By default,
59  * VectorLinearInterpolateImageFunction is used, and the user can override
60  * using SetInterpolator.
61  *
62  * The displacement field data is stored using the common
63  * \c OptimizerParameters type
64  * in conjunction with the \c ImageVectorOptimizerParametersHelper class. This
 * allows access to the displacement field image as if it were an itkArray,
66  * allowing transparent use with other classes.
67  * \warning The \c SetParameters
68  * method will copy the passed parameters, which can be costly since
69  * displacement fields are dense and thus potentially very large.
70  *
71  * The \c UpdateTransformParameters method simply adds the provided
72  * update array, applying the usual optional scaling factor. Derived
73  * classes may provide different behavior.
74  *
75  * Because this is a local transform, methods that have a version that takes
76  * a point must be used, such as \c TransformVector,
77  * \c TransformCovariantVector, and \c TransformDiffusionTensor. Also,
78  * \c ComputeJacobianWithRespectToParameters simply returns
79  * an identity matrix (see method documentation),
80  * and \c ComputeJacobianWithRespectToPosition should be used.
81  *
82  *
83  * \ingroup ITKDisplacementField
84  */
85 template
86 <typename TParametersValueType, unsigned int NDimensions>
87 class ITK_TEMPLATE_EXPORT DisplacementFieldTransform :
88   public Transform<TParametersValueType, NDimensions, NDimensions>
89 {
90 public:
91   ITK_DISALLOW_COPY_AND_ASSIGN(DisplacementFieldTransform);
92 
93   /** Standard class type aliases. */
94   using Self = DisplacementFieldTransform;
95   using Superclass = Transform<TParametersValueType, NDimensions, NDimensions>;
96   using Pointer = SmartPointer<Self>;
97   using ConstPointer = SmartPointer<const Self>;
98 
99   /** Run-time type information (and related methods). */
100   itkTypeMacro( DisplacementFieldTransform, Transform );
101 
102   /** New macro for creation of through a Smart Pointer */
103   itkNewMacro( Self );
104 
105   /** InverseTransform type. */
106   using InverseTransformBasePointer = typename Superclass::InverseTransformBasePointer;
107 
108   /** Scalar type. */
109   using ScalarType = typename Superclass::ScalarType;
110 
111   /** Type of the input parameters. */
112   using FixedParametersType = typename Superclass::FixedParametersType;
113   using FixedParametersValueType = typename Superclass::FixedParametersValueType;
114   using ParametersType = typename Superclass::ParametersType;
115   using ParametersValueType = typename Superclass::ParametersValueType;
116 
117   /** Jacobian types. */
118   using JacobianType = typename Superclass::JacobianType;
119   using JacobianPositionType = typename Superclass::JacobianPositionType;
120   using InverseJacobianPositionType = typename Superclass::InverseJacobianPositionType;
121 
122   /** Transform category type. */
123   using TransformCategoryType = typename Superclass::TransformCategoryType;
124 
125   /** The number of parameters defininig this transform. */
126   using NumberOfParametersType = typename Superclass::NumberOfParametersType;
127 
128   /** Standard coordinate point type for this class. */
129   using InputPointType = typename Superclass::InputPointType;
130   using OutputPointType = typename Superclass::OutputPointType;
131 
132   /** Standard vector type for this class. */
133   using InputVectorType = typename Superclass::InputVectorType;
134   using OutputVectorType = typename Superclass::OutputVectorType;
135 
136   using InputVectorPixelType = typename Superclass::InputVectorPixelType;
137   using OutputVectorPixelType = typename Superclass::OutputVectorPixelType;
138 
139   /** Standard covariant vector type for this class */
140   using InputCovariantVectorType = typename Superclass::InputCovariantVectorType;
141   using OutputCovariantVectorType = typename Superclass::OutputCovariantVectorType;
142 
143   /** Standard vnl_vector type for this class. */
144   using InputVnlVectorType = typename Superclass::InputVnlVectorType;
145   using OutputVnlVectorType = typename Superclass::OutputVnlVectorType;
146 
147   /** Standard diffusion tensor type for this class */
148   using InputDiffusionTensor3DType = typename Superclass::InputDiffusionTensor3DType;
149   using OutputDiffusionTensor3DType = typename Superclass::OutputDiffusionTensor3DType;
150 
151   /** Standard tensor type for this class */
152   using InputTensorEigenVectorType =
153       CovariantVector<ScalarType, InputDiffusionTensor3DType::Dimension>;
154   using OutputTensorEigenVectorType =
155       CovariantVector<ScalarType, OutputDiffusionTensor3DType::Dimension>;
156   /** Derivative type */
157   using DerivativeType = typename Superclass::DerivativeType;
158 
159   /** Dimension of the domain spaces. */
160   static constexpr unsigned int Dimension = NDimensions;
161 
162   /** Define the displacement field type and corresponding interpolator type. */
163   using DisplacementFieldType = Image<OutputVectorType,  Dimension>;
164   using DisplacementFieldPointer = typename DisplacementFieldType::Pointer;
165   using DisplacementFieldConstPointer = typename DisplacementFieldType::ConstPointer;
166 
167   using InterpolatorType = VectorInterpolateImageFunction
168     <DisplacementFieldType, ScalarType>;
169 
170   /** Standard types for the displacement Field */
171   using IndexType = typename DisplacementFieldType::IndexType;
172   using RegionType = typename DisplacementFieldType::RegionType;
173   using SizeType = typename DisplacementFieldType::SizeType;
174   using SpacingType = typename DisplacementFieldType::SpacingType;
175   using DirectionType = typename DisplacementFieldType::DirectionType;
176   using PointType = typename DisplacementFieldType::PointType;
177   using PixelType = typename DisplacementFieldType::PixelType;
178 
179   /** Define the internal parameter helper used to access the field */
180   using OptimizerParametersHelperType = ImageVectorOptimizerParametersHelper<
181     ScalarType,
182     OutputVectorType::Dimension,
183     Dimension>;
184 
185   /** Get/Set the displacement field.
186    * Set the displacement field. Create special set accessor to update
187    * interpolator and assign displacement field to transform parameters
188    * container. */
189   virtual void SetDisplacementField( DisplacementFieldType* field );
190   itkGetModifiableObjectMacro(DisplacementField, DisplacementFieldType );
191 
192   /** Get/Set the inverse displacement field. This must be supplied by the user for
193    * GetInverse() to work. */
194   virtual void SetInverseDisplacementField( DisplacementFieldType * inverseDisplacementField );
195   itkGetModifiableObjectMacro(InverseDisplacementField, DisplacementFieldType );
196 
197   /** Get/Set the interpolator.
198    * Create out own set accessor that assigns the displacement field. */
199   virtual void SetInterpolator( InterpolatorType* interpolator );
200   itkGetModifiableObjectMacro( Interpolator, InterpolatorType );
201 
202   /** Get/Set the interpolator for the inverse field.
203    * Create out own set accessor that assigns the displacement field. */
204   virtual void SetInverseInterpolator( InterpolatorType* interpolator );
205   itkGetModifiableObjectMacro(InverseInterpolator, InterpolatorType );
206 
207   /** Get the modification time of displacement field. */
208   itkGetConstReferenceMacro( DisplacementFieldSetTime, ModifiedTimeType );
209 
210   /**  Method to transform a point. Out-of-bounds points will
211    * be returned with zero displacemnt. */
212   OutputPointType TransformPoint( const InputPointType& thisPoint ) const override;
213 
214   /**  Method to transform a vector. */
215   using Superclass::TransformVector;
TransformVector(const InputVectorType &)216   OutputVectorType TransformVector(const InputVectorType &) const override
217   {
218     itkExceptionMacro( "TransformVector(Vector) unimplemented, use "
219                        "TransformVector(Vector,Point)" );
220   }
221 
TransformVector(const InputVectorPixelType &)222   OutputVectorPixelType TransformVector(const InputVectorPixelType &) const override
223   {
224     itkExceptionMacro( "TransformVector(Vector) unimplemented, use "
225                        "TransformVector(Vector,Point)" );
226   }
227 
TransformVector(const InputVnlVectorType &)228   OutputVnlVectorType TransformVector(const InputVnlVectorType &) const override
229   {
230     itkExceptionMacro( "TransformVector(Vector) unimplemented, use "
231                        "TransformVector(Vector,Point)" );
232   }
233 
234   /** Method to transform a tensor. */
235   using Superclass::TransformDiffusionTensor3D;
TransformDiffusionTensor(const InputDiffusionTensor3DType &)236   OutputDiffusionTensor3DType TransformDiffusionTensor(
237     const InputDiffusionTensor3DType & ) const
238   {
239     itkExceptionMacro( "TransformDiffusionTensor(Tensor) unimplemented, use "
240                        "TransformDiffusionTensor(Tensor,Point)" );
241   }
242 
TransformDiffusionTensor(const InputVectorPixelType &)243   OutputVectorPixelType TransformDiffusionTensor(const InputVectorPixelType & )
244   const
245   {
246     itkExceptionMacro( "TransformDiffusionTensor(Tensor) unimplemented, use "
247                        "TransformDiffusionTensor(Tensor,Point)" );
248   }
249 
250   /** Method to transform a CovariantVector. */
251   using Superclass::TransformCovariantVector;
TransformCovariantVector(const InputCovariantVectorType &)252   OutputCovariantVectorType TransformCovariantVector(
253     const InputCovariantVectorType &) const override
254   {
255     itkExceptionMacro( "TransformCovariantVector(CovariantVector) "
256                        "unimplemented, use TransformCovariantVector(CovariantVector,Point)" );
257   }
258 
TransformCovariantVector(const InputVectorPixelType &)259   OutputVectorPixelType TransformCovariantVector(
260     const InputVectorPixelType &) const override
261   {
262     itkExceptionMacro( "TransformCovariantVector(CovariantVector) "
263                        "unimplemented, use TransformCovariantVector(CovariantVector,Point)" );
264   }
265 
266   /** Set the transformation parameters. This sets the displacement
267    * field image directly. */
SetParameters(const ParametersType & params)268   void SetParameters(const ParametersType & params) override
269   {
270     if( &(this->m_Parameters) != &params )
271       {
272       if( params.Size() != this->m_Parameters.Size() )
273         {
274         itkExceptionMacro("Input parameters size (" << params.Size()
275                                                     << ") does not match internal size ("
276                                                     << this->m_Parameters.Size() << ").");
277         }
278       // Copy into existing object
279       this->m_Parameters = params;
280       this->Modified();
281       }
282   }
283 
284   /**
285    * This method sets the fixed parameters of the transform.
286    * For a displacement field transform, the fixed parameters are the
287    * following: field size, field origin, field spacing, and field direction.
288    *
289    * Note: If a displacement field already exists, this function
290    * creates a new one with zero displacement (identity transform). If
291    * an inverse displacement field exists, a new one is also created.
292    */
293   void SetFixedParameters( const FixedParametersType & ) override;
294 
295   /**
296    * Compute the jacobian with respect to the parameters at a point.
297    * Simply returns identity matrix, sized [NDimensions, NDimensions].
298    *
299    * T(x, p), x is the position, p is the local parameter at position x.
300    * Take a 2D example, x = (x0, x1), p = (p0, p1) and T(x, p) is defined as:
301    *
302    *    T(x, p) = (T0, T1) = (x0+p0, x1+p1)
303    *
304    * Each local deformation is defined as a translation transform.
305    * So the Jacobian w.r.t parameters are
306    *
307    * dT/dp =
308    *    [ dT0/dp0, dT0/dp1;
309    *      dT1/dp0, dT1/dp1 ];
310    *
311    *    = [1, 0;
312    *       0, 1];
313    *
314    * TODO: format the above for doxygen formula.
315    */
ComputeJacobianWithRespectToParameters(const InputPointType &,JacobianType & j)316   void ComputeJacobianWithRespectToParameters(const InputPointType &,
317                                                       JacobianType & j) const override
318   {
319     j = this->m_IdentityJacobian;
320   }
321 
322   /**
323    * Compute the jacobian with respect to the parameters at an index.
324    * Simply returns identity matrix, sized [NDimensions, NDimensions].
325    * See \c ComputeJacobianWithRespectToParameters( InputPointType, ... )
326    * for rationale.
327    */
ComputeJacobianWithRespectToParameters(const IndexType &,JacobianType & j)328   virtual void ComputeJacobianWithRespectToParameters(const IndexType &,
329                                                       JacobianType & j) const
330   {
331     j = this->m_IdentityJacobian;
332   }
333 
334   /**
335    * Compute the jacobian with respect to the position, by point.
336    * \c j will be resized as needed.
337    */
338   void ComputeJacobianWithRespectToPosition(const InputPointType  & x, JacobianPositionType & j ) const override;
339   using Superclass::ComputeJacobianWithRespectToPosition;
340 
341   /**
342    * Compute the jacobian with respect to the position, by point.
343    * \c j will be resized as needed.
344    */
345   void ComputeInverseJacobianWithRespectToPosition(const InputPointType  & x,
346                                                    InverseJacobianPositionType & j ) const override;
347   using Superclass::ComputeInverseJacobianWithRespectToPosition;
348 
349   /**
350    * Compute the jacobian with respect to the position, by index.
351    * \c j will be resized as needed.
352    */
353   virtual void ComputeJacobianWithRespectToPosition(const IndexType  & x, JacobianPositionType & j ) const;
354 
355   /**
356    * Compute the inverse jacobian of the forward displacement field with
357    * respect to the position, by point. Note that this is different than
358    * the jacobian of the inverse displacement field. This takes advantage
359    * of the ability to compute the inverse jacobian of a displacement field
360    * by simply reversing the sign of the forward jacobian.
361    * However, a more accurate method for computing the inverse
362    * jacobian is to take the inverse of the jacobian matrix. This
363    * method is more computationally expensive and may be used by
364    * setting \c useSVD to true
365    */
366   virtual void GetInverseJacobianOfForwardFieldWithRespectToPosition(const InputPointType & point,
367                                                                      JacobianPositionType & jacobian,
368                                                                      bool useSVD = false ) const;
369 
370   /**
371    * Compute the inverse jacobian of the forward displacement field with
372    * respect to the position, by index.Note that this is different than
373    * the jacobian of the inverse displacement field. This takes advantage
374    * of the ability to compute the inverse jacobian of a displacement field
375    * by simply reversing the sign of the forward jacobian.
376    * However, a more accurate method for computing the inverse
377    * jacobian is to take the inverse of the jacobian matrix. This
378    * method is more computationally expensive and may be used by
379    * setting \c useSVD to true
380    */
381   virtual void GetInverseJacobianOfForwardFieldWithRespectToPosition(const IndexType & index,
382                                                                      JacobianPositionType & jacobian,
383                                                                      bool useSVD = false ) const;
384 
385   void UpdateTransformParameters( const DerivativeType & update, ScalarType factor = 1.0 ) override;
386 
387   /** Return an inverse of this transform.
388    * Note that the inverse displacement field must be set by the user. */
389   bool GetInverse( Self *inverse ) const;
390 
391   /** Return an inverse of this transform.
392    * Note that the inverse displacement field must be set by the user. */
393   InverseTransformBasePointer GetInverseTransform() const override;
394 
395   virtual void SetIdentity();
396 
397   /** This transform is not linear. */
GetTransformCategory()398   TransformCategoryType GetTransformCategory() const override
399   {
400     return Self::DisplacementField;
401   }
402 
GetNumberOfLocalParameters()403   NumberOfParametersType GetNumberOfLocalParameters() const override
404   {
405     return Dimension;
406   }
407 
408   /** Set/Get the coordinate tolerance.
409    *  This tolerance is used when comparing the space defined
410    *  by deformation fields and it's inverse to ensure they occupy the
411    *  same physical space.
412    *
413    * \sa ImageToImageFilterCommon::SetGlobalDefaultCoordinateTolerance
414    */
415   itkSetMacro(CoordinateTolerance, double);
416   itkGetConstMacro(CoordinateTolerance, double);
417 
418   /** Set/Get the direction tolerance.
419    *  This tolerance is used to when comparing the orientation of the
420    *  deformation fields and it's inverse to ensure they occupy the
421    *  same physical space.
422    *
423    * \sa ImageToImageFilterCommon::SetGlobalDefaultDirectionTolerance
424    */
425   itkSetMacro(DirectionTolerance, double);
426   itkGetConstMacro(DirectionTolerance, double);
427 
428 protected:
429 
430   DisplacementFieldTransform();
431   ~DisplacementFieldTransform() override = default;
432   void PrintSelf( std::ostream& os, Indent indent ) const override;
433 
434   /** The displacement field and its inverse (if it exists). */
435   typename DisplacementFieldType::Pointer      m_DisplacementField;
436   typename DisplacementFieldType::Pointer      m_InverseDisplacementField;
437 
438   /** The interpolator. */
439   typename InterpolatorType::Pointer          m_Interpolator;
440   typename InterpolatorType::Pointer          m_InverseInterpolator;
441 
442   /** Track when the displacement field was last set/assigned, as
443    * distinct from when it may have had its contents modified. */
444   ModifiedTimeType m_DisplacementFieldSetTime{ 0 };
445 
446   /** Create an identity jacobian for use in
447    * ComputeJacobianWithRespectToParameters. */
448   JacobianType m_IdentityJacobian;
449 
450 private:
451   /** Internal method for calculating either forward or inverse jacobian,
452    * depending on state of \c doInverseJacobian. Used by
453    * public methods \c ComputeJacobianWithRespectToPosition and
454    * \c GetInverseJacobianOfForwardFieldWithRespectToPosition to
455    * perform actual work.
456    * \c doInverseJacobian indicates that the inverse jacobian
457    * should be returned
458    */
459   virtual void ComputeJacobianWithRespectToPositionInternal(const IndexType & index,
460                                                             JacobianPositionType & jacobian,
461                                                             bool doInverseJacobian) const;
462 
463   /**
464    * Internal method to check that the inverse and forward displacement fields have the
465    * same fixed parameters.
466    */
467   virtual void VerifyFixedParametersInformation();
468 
469   /**
470    * Convenience method which reads the information from the current
471    * displacement field into m_FixedParameters.
472    */
473   virtual void SetFixedParametersFromDisplacementField() const;
474 
475   double m_CoordinateTolerance;
476   double m_DirectionTolerance;
477 
478 };
479 
480 } // end namespace itk
481 
482 #ifndef ITK_MANUAL_INSTANTIATION
483 #include "itkDisplacementFieldTransform.hxx"
484 #endif
485 
486 #endif // itkDisplacementFieldTransform_h
487