1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18
19 #include "itkObjectToObjectMultiMetricv4.h"
20 #include "itkMeanSquaresImageToImageMetricv4.h"
21 #include "itkMattesMutualInformationImageToImageMetricv4.h"
22 #include "itkJointHistogramMutualInformationImageToImageMetricv4.h"
23 #include "itkANTSNeighborhoodCorrelationImageToImageMetricv4.h"
24 #include "itkTranslationTransform.h"
25 #include "itkLinearInterpolateImageFunction.h"
26 #include "itkImage.h"
27 #include "itkGaussianImageSource.h"
28 #include "itkShiftScaleImageFilter.h"
29 #include "itkTestingMacros.h"
30 #include "itkCompositeTransform.h"
31 #include "itkEuclideanDistancePointSetToPointSetMetricv4.h"
32 #include "itkExpectationBasedPointSetToPointSetMetricv4.h"
33 #include "itkRegistrationParameterScalesFromPhysicalShift.h"
34
35
/** This test illustrates the use of the itk::ObjectToObjectMultiMetricv4 class, which
    takes N metrics and assigns a weight to each metric's result.
 */
39
40 constexpr unsigned int ObjectToObjectMultiMetricv4TestDimension = 2;
41 using ObjectToObjectMultiMetricv4TestMultiMetricType = itk::ObjectToObjectMultiMetricv4<ObjectToObjectMultiMetricv4TestDimension,ObjectToObjectMultiMetricv4TestDimension>;
42
43 //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
44
itkObjectToObjectMultiMetricv4TestEvaluate(ObjectToObjectMultiMetricv4TestMultiMetricType::Pointer & multiVariateMetric,bool useDisplacementTransform)45 int itkObjectToObjectMultiMetricv4TestEvaluate( ObjectToObjectMultiMetricv4TestMultiMetricType::Pointer & multiVariateMetric, bool useDisplacementTransform )
46 {
47 int testStatus = EXIT_SUCCESS;
48 using MultiMetricType = ObjectToObjectMultiMetricv4TestMultiMetricType;
49
50 // Setup weights
51 MultiMetricType::WeightsArrayType origMetricWeights( multiVariateMetric->GetNumberOfMetrics() );
52 MultiMetricType::WeightValueType weightSum = 0;
53 for( itk::SizeValueType n = 0; n < multiVariateMetric->GetNumberOfMetrics(); n++ )
54 {
55 origMetricWeights[n] = static_cast<MultiMetricType::WeightValueType>( n + 1 );
56 weightSum += origMetricWeights[n];
57 }
58 multiVariateMetric->SetMetricWeights( origMetricWeights );
59
60 // Initialize. This initializes all the component metrics.
61 std::cout << "Initialize" << std::endl;
62 multiVariateMetric->Initialize();
63
64 // Print out metric value and derivative.
65 using MeasureType = MultiMetricType::MeasureType;
66 MeasureType measure = 0;
67 MultiMetricType::DerivativeType DerivResultOfGetValueAndDerivative;
68 std::cout << "GetValueAndDerivative" << std::endl;
69 try
70 {
71 multiVariateMetric->GetValueAndDerivative( measure, DerivResultOfGetValueAndDerivative );
72 }
73 catch (itk::ExceptionObject& exp)
74 {
75 std::cerr << "Exception caught during call to GetValueAndDerivative:" << std::endl;
76 std::cerr << exp << std::endl;
77 testStatus = EXIT_FAILURE;
78 }
79 std::cout << "Multivariate measure: " << measure << std::endl;
80 if( ! useDisplacementTransform )
81 {
82 std::cout << " Derivative : " << DerivResultOfGetValueAndDerivative << std::endl << std::endl;
83 }
84
85 // Test GetDerivative
86 MultiMetricType::DerivativeType ResultOfGetDerivative;
87 multiVariateMetric->GetDerivative( ResultOfGetDerivative );
88 for( MultiMetricType::NumberOfParametersType p = 0; p < multiVariateMetric->GetNumberOfParameters(); p++ )
89 {
90 //When accumulation is done accross multiple threads, the accumulations can be done
91 //in different orders resulting in slightly different numerical results.
92 //The FloatAlmostEqual is used to address the multi-threaded accumulation differences
93 if( !itk::Math::FloatAlmostEqual( ResultOfGetDerivative[p], DerivResultOfGetValueAndDerivative[p], 8, 1e-15 ) )
94 {
95 std::cerr << "Results do not match between GetValueAndDerivative and GetDerivative." << std::endl;
96 std::cout << ResultOfGetDerivative << " != " << DerivResultOfGetValueAndDerivative << std::endl;
97 std::cout << "DIFF: " << ResultOfGetDerivative - DerivResultOfGetValueAndDerivative << std::endl;
98 testStatus = EXIT_FAILURE;
99 }
100 }
101
102 // Test GetValue method
103 MeasureType measure2 = 0;
104 std::cout << "GetValue" << std::endl;
105 try
106 {
107 measure2 = multiVariateMetric->GetValue();
108 }
109 catch (itk::ExceptionObject& exp)
110 {
111 std::cerr << "Exception caught during call to GetValue:" << std::endl;
112 std::cerr << exp << std::endl;
113 testStatus = EXIT_FAILURE;
114 }
115 if( ! itk::Math::FloatAlmostEqual( measure2, measure ) )
116 {
117 std::cerr << "measure does not match between calls to GetValue and GetValueAndDerivative: "
118 << "measure: " << measure << " measure2: " << measure2 << std::endl;
119 testStatus = EXIT_FAILURE;
120 }
121
122 // Evaluate individually
123 MeasureType metricValue = itk::NumericTraits<MeasureType>::ZeroValue();
124 MeasureType weightedMetricValue = itk::NumericTraits<MeasureType>::ZeroValue();
125 MultiMetricType::DerivativeType metricDerivative;
126 MultiMetricType::DerivativeType DerivResultOfGetValueAndDerivativeTruth( multiVariateMetric->GetNumberOfParameters() );
127 DerivResultOfGetValueAndDerivativeTruth.Fill( itk::NumericTraits<MultiMetricType::DerivativeValueType>::ZeroValue() );
128 MultiMetricType::DerivativeValueType totalMagnitude = itk::NumericTraits<MultiMetricType::DerivativeValueType>::ZeroValue();
129
130 for (itk::SizeValueType i = 0; i < multiVariateMetric->GetNumberOfMetrics(); i++)
131 {
132 std::cout << "GetValueAndDerivative on component metrics" << std::endl;
133 multiVariateMetric->GetMetricQueue()[i]->GetValueAndDerivative( metricValue, metricDerivative );
134 std::cout << " Metric " << i << " value : " << metricValue << std::endl;
135 if( ! useDisplacementTransform )
136 {
137 std::cout << " Metric " << i << " derivative : " << metricDerivative << std::endl << std::endl;
138 }
139 if( ! itk::Math::FloatAlmostEqual( metricValue, multiVariateMetric->GetValueArray()[i] ) )
140 {
141 std::cerr << "Individual metric value " << metricValue
142 << " does not match that returned from multi-variate metric: " << multiVariateMetric->GetValueArray()[i]
143 << std::endl;
144 testStatus = EXIT_FAILURE;
145 }
146 weightedMetricValue += metricValue * origMetricWeights[i] / weightSum;
147 for( MultiMetricType::NumberOfParametersType p = 0; p < multiVariateMetric->GetNumberOfParameters(); p++ )
148 {
149 DerivResultOfGetValueAndDerivativeTruth[p] += metricDerivative[p] * ( origMetricWeights[i] / weightSum ) / metricDerivative.magnitude();
150 }
151 totalMagnitude += metricDerivative.magnitude();
152 }
153 totalMagnitude /= multiVariateMetric->GetNumberOfMetrics();
154 for( MultiMetricType::NumberOfParametersType p = 0; p < multiVariateMetric->GetNumberOfParameters(); p++ )
155 {
156 DerivResultOfGetValueAndDerivativeTruth[p] *= totalMagnitude;
157 }
158
159 if( std::fabs( weightedMetricValue - multiVariateMetric->GetWeightedValue() ) > 1e-6 )
160 {
161 std::cerr << "Computed weighted metric value " << weightedMetricValue << " does match returned value "
162 << multiVariateMetric->GetWeightedValue() << std::endl;
163 testStatus = EXIT_FAILURE;
164 }
165
166 for( MultiMetricType::NumberOfParametersType p = 0; p < multiVariateMetric->GetNumberOfParameters(); p++ )
167 {
168 auto tolerance = static_cast<MultiMetricType::DerivativeValueType> (1e-6);
169 if( std::fabs(DerivResultOfGetValueAndDerivativeTruth[p] - DerivResultOfGetValueAndDerivative[p]) > tolerance )
170 {
171 std::cerr << "Error: DerivResultOfGetValueAndDerivative does not match expected result." << std::endl;
172 if( useDisplacementTransform )
173 {
174 std::cerr << " DerivResultOfGetValueAndDerivative[" << p << "]: " << DerivResultOfGetValueAndDerivative[p] << std::endl
175 << " DerivResultOfGetValueAndDerivativeTruth[" << p << "]: " << DerivResultOfGetValueAndDerivativeTruth[p] << std::endl;
176 }
177 else
178 {
179 std::cerr << " DerivResultOfGetValueAndDerivative: " << DerivResultOfGetValueAndDerivative << std::endl
180 << " DerivResultOfGetValueAndDerivativeTruth: " << DerivResultOfGetValueAndDerivativeTruth << std::endl;
181 }
182 testStatus = EXIT_FAILURE;
183 }
184 }
185
186 return testStatus;
187 }
188
189 ////////////////////////////////////////////////////////////
190
itkObjectToObjectMultiMetricv4TestRun(bool useDisplacementTransform)191 int itkObjectToObjectMultiMetricv4TestRun(bool useDisplacementTransform )
192 {
193 // Create two simple images
194 const unsigned int Dimension = ObjectToObjectMultiMetricv4TestDimension;
195 using PixelType = double;
196 using CoordinateRepresentationType = double;
197
198 // Allocate Images
199 using FixedImageType = itk::Image<PixelType,Dimension>;
200 using MovingImageType = itk::Image<PixelType,Dimension>;
201
202 // Declare Gaussian Sources
203 using FixedImageSourceType = itk::GaussianImageSource< FixedImageType >;
204
205 // Note: the following declarations are classical arrays
206 FixedImageType::SizeValueType fixedImageSize[] = { 100, 100 };
207 FixedImageType::SpacingValueType fixedImageSpacing[] = { 1.0f, 1.0f };
208 FixedImageType::PointValueType fixedImageOrigin[] = { 0.0f, 0.0f };
209 FixedImageSourceType::Pointer fixedImageSource = FixedImageSourceType::New();
210
211 fixedImageSource->SetSize( fixedImageSize );
212 fixedImageSource->SetOrigin( fixedImageOrigin );
213 fixedImageSource->SetSpacing( fixedImageSpacing );
214 fixedImageSource->SetNormalized( false );
215 fixedImageSource->SetScale( 1.0f );
216 fixedImageSource->Update(); // Force the filter to run
217 FixedImageType::Pointer fixedImage = fixedImageSource->GetOutput();
218
219 using ShiftScaleFilterType = itk::ShiftScaleImageFilter<FixedImageType, MovingImageType>;
220 ShiftScaleFilterType::Pointer shiftFilter = ShiftScaleFilterType::New();
221 shiftFilter->SetInput( fixedImage );
222 shiftFilter->SetShift( 2.0 );
223 shiftFilter->Update();
224 MovingImageType::Pointer movingImage = shiftFilter->GetOutput();
225
226 // Set up the metric.
227 using MultiMetricType = ObjectToObjectMultiMetricv4TestMultiMetricType;
228 MultiMetricType::Pointer multiVariateMetric = MultiMetricType::New();
229
230 // Instantiate and Add metrics to the queue
231 using JointHistorgramMetrictype = itk::JointHistogramMutualInformationImageToImageMetricv4<FixedImageType,MovingImageType>;
232 using MeanSquaresMetricType = itk::MeanSquaresImageToImageMetricv4<FixedImageType,MovingImageType>;
233 using MattesMutualInformationMetricType = itk::MattesMutualInformationImageToImageMetricv4 <FixedImageType,MovingImageType>;
234 using ANTSNCMetricType = itk::ANTSNeighborhoodCorrelationImageToImageMetricv4<FixedImageType,MovingImageType>;
235
236 MeanSquaresMetricType::Pointer m1 = MeanSquaresMetricType::New();
237 MattesMutualInformationMetricType::Pointer m2 = MattesMutualInformationMetricType::New();
238 JointHistorgramMetrictype::Pointer m3 = JointHistorgramMetrictype::New();
239 ANTSNCMetricType::Pointer m4 = ANTSNCMetricType::New();
240
241 // Set up a transform
242 using TransformType = itk::Transform<CoordinateRepresentationType, Dimension, Dimension>;
243 using DisplacementTransformType = itk::DisplacementFieldTransform<double, Dimension>;
244 using TranslationTransformType = itk::TranslationTransform<CoordinateRepresentationType,Dimension>;
245 TransformType::Pointer transform;
246
247 if( useDisplacementTransform )
248 {
249 using FieldType = DisplacementTransformType::DisplacementFieldType;
250 using VectorType = itk::Vector<double, Dimension>;
251
252 VectorType zero;
253 zero.Fill(0.0);
254
255 FieldType::Pointer field = FieldType::New();
256 field->SetRegions( fixedImage->GetBufferedRegion() );
257 field->SetSpacing( fixedImage->GetSpacing() );
258 field->SetOrigin( fixedImage->GetOrigin() );
259 field->Allocate();
260 field->FillBuffer(zero);
261
262 DisplacementTransformType::Pointer displacementTransform = DisplacementTransformType::New();
263 displacementTransform->SetDisplacementField(field);
264 transform = displacementTransform;
265 }
266 else
267 {
268 TranslationTransformType::Pointer translationTransform = TranslationTransformType::New();
269 translationTransform->SetIdentity();
270 transform = translationTransform;
271 }
272
273 // Plug the images and transform into the metrics
274 std::cout << "Setup metrics" << std::endl;
275 m1->SetFixedImage(fixedImage);
276 m1->SetMovingImage(movingImage);
277 m1->SetMovingTransform( transform );
278 m2->SetFixedImage(fixedImage);
279 m2->SetMovingImage(movingImage);
280 m2->SetMovingTransform( transform );
281 m3->SetFixedImage(fixedImage);
282 m3->SetMovingImage(movingImage);
283 m3->SetMovingTransform( transform );
284 m4->SetFixedImage(fixedImage);
285 m4->SetMovingImage(movingImage);
286 m4->SetMovingTransform( transform );
287
288 // Add the component metrics
289 std::cout << "Add component metrics" << std::endl;
290 multiVariateMetric->AddMetric(m1);
291 multiVariateMetric->AddMetric(m2);
292 multiVariateMetric->AddMetric(m3);
293 multiVariateMetric->AddMetric(m4);
294
295 if( multiVariateMetric->GetMetricQueue()[0] != m1 || multiVariateMetric->GetMetricQueue()[3] != m4 )
296 {
297 std::cerr << "AddMetric or GetMetricQueue failed." << std::endl;
298 return EXIT_FAILURE;
299 }
300
301 // Expect return true because all image metrics
302 if( multiVariateMetric->SupportsArbitraryVirtualDomainSamples() == false )
303 {
304 std::cerr << "Expected SupportsArbitraryVirtualDomainSamples() to return false, but got true. " << std::endl;
305 return EXIT_FAILURE;
306 }
307
308 // Test Set/Get Transform mechanics
309 multiVariateMetric->Initialize();
310 if( multiVariateMetric->GetMovingTransform() != transform.GetPointer() )
311 {
312 std::cerr << "Automatic transform assignment failed. transform: " << transform.GetPointer() << " GetMovingTranform: " << multiVariateMetric->GetMovingTransform() << std::endl;
313 return EXIT_FAILURE;
314 }
315 multiVariateMetric->SetMovingTransform( nullptr );
316 for( itk::SizeValueType n = 0; n < multiVariateMetric->GetNumberOfMetrics(); n++ )
317 {
318 if( multiVariateMetric->GetMovingTransform() != nullptr || multiVariateMetric->GetMetricQueue()[n]->GetMovingTransform() != nullptr )
319 {
320 std::cerr << "Assignment of null transform failed. multiVariateMetric->GetMovingTransform(): " << multiVariateMetric->GetMovingTransform()
321 << " multiVariateMetric->GetMetricQueue()[" << n << "]->GetMovingTransform(): "
322 << multiVariateMetric->GetMetricQueue()[n]->GetMovingTransform() << std::endl;
323 return EXIT_FAILURE;
324 }
325 }
326 multiVariateMetric->SetMovingTransform( transform );
327 for( itk::SizeValueType n = 0; n < multiVariateMetric->GetNumberOfMetrics(); n++ )
328 {
329 if( multiVariateMetric->GetMovingTransform() != transform.GetPointer() ||
330 multiVariateMetric->GetMetricQueue()[0]->GetMovingTransform() != transform.GetPointer() )
331 {
332 std::cerr << "Assignment of transform failed." << std::endl;
333 return EXIT_FAILURE;
334 }
335 }
336 if( multiVariateMetric->GetMovingTransform() != transform.GetPointer() )
337 {
338 std::cerr << "Retrieval of transform failed." << std::endl;
339 }
340
341 // Test with images
342 std::cout << "*** Test image metrics *** " << std::endl;
343 if( itkObjectToObjectMultiMetricv4TestEvaluate( multiVariateMetric, useDisplacementTransform ) != EXIT_SUCCESS )
344 {
345 return EXIT_FAILURE;
346 }
347
348 std::cout << "*** Test with mismatched transforms *** " << std::endl;
349 TranslationTransformType::Pointer transform2 = TranslationTransformType::New();
350 m4->SetMovingTransform( transform2 );
351 TRY_EXPECT_EXCEPTION( multiVariateMetric->Initialize() );
352 m4->SetMovingTransform( transform );
353
354 std::cout << "*** Test with proper CompositeTransform ***" << std::endl;
355 using CompositeTransformType = itk::CompositeTransform<CoordinateRepresentationType,Dimension>;
356 CompositeTransformType::Pointer compositeTransform = CompositeTransformType::New();
357 compositeTransform->AddTransform( transform2 );
358 compositeTransform->AddTransform( transform );
359 compositeTransform->SetOnlyMostRecentTransformToOptimizeOn();
360 m4->SetMovingTransform( compositeTransform );
361 if( itkObjectToObjectMultiMetricv4TestEvaluate( multiVariateMetric, useDisplacementTransform ) != EXIT_SUCCESS )
362 {
363 std::cerr << "Failed with proper CompositeTransform." << std::endl;
364 return EXIT_FAILURE;
365 }
366
367 std::cout << "*** Test with CompositeTransform - too many active transforms ***" << std::endl;
368 compositeTransform->SetAllTransformsToOptimizeOn();
369 TRY_EXPECT_EXCEPTION( multiVariateMetric->Initialize() );
370
371 std::cout << "*** Test with CompositeTransform - one active transform, but wrong one ***" << std::endl;
372 compositeTransform->SetAllTransformsToOptimizeOff();
373 compositeTransform->SetNthTransformToOptimizeOn( 0 );
374 TRY_EXPECT_EXCEPTION( multiVariateMetric->Initialize() );
375
376 // Reset transform
377 m4->SetMovingTransform( transform );
378
379 //
380 // Test with adding point set metrics
381 //
382 using PointSetType = itk::PointSet<float, Dimension>;
383 PointSetType::Pointer fixedPoints = PointSetType::New();
384 PointSetType::Pointer movingPoints = PointSetType::New();
385 fixedPoints->Initialize();
386 movingPoints->Initialize();
387
388 PointSetType::PointType point;
389 for( itk::SizeValueType n = 0; n < 100; n++ )
390 {
391 point[0] = n * 1.0;
392 point[1] = n * 2.0;
393 fixedPoints->SetPoint( n, point );
394 point[0] += 0.5;
395 point[1] += 0.5;
396 movingPoints->SetPoint( n, point );
397 }
398
399 using ExpectationPointSetMetricType = itk::ExpectationBasedPointSetToPointSetMetricv4<PointSetType>;
400 using EuclideanPointSetMetricType = itk::EuclideanDistancePointSetToPointSetMetricv4<PointSetType>;
401 ExpectationPointSetMetricType::Pointer expectationPointSetMetric = ExpectationPointSetMetricType::New();
402 EuclideanPointSetMetricType::Pointer euclideanPointSetMetric = EuclideanPointSetMetricType::New();
403
404 expectationPointSetMetric->SetFixedPointSet( fixedPoints );
405 expectationPointSetMetric->SetMovingPointSet( movingPoints );
406 expectationPointSetMetric->SetMovingTransform( transform );
407 euclideanPointSetMetric->SetFixedPointSet( fixedPoints );
408 euclideanPointSetMetric->SetMovingPointSet( movingPoints );
409 euclideanPointSetMetric->SetMovingTransform( transform );
410
411 multiVariateMetric->AddMetric( expectationPointSetMetric );
412 multiVariateMetric->AddMetric( euclideanPointSetMetric );
413
414
415 // Expect return false because of point set metrics
416 if( multiVariateMetric->SupportsArbitraryVirtualDomainSamples() == true )
417 {
418 std::cerr << "Expected SupportsArbitraryVirtualDomainSamples() to return true, but got false. " << std::endl;
419 return EXIT_FAILURE;
420 }
421
422 // Test
423 std::cout << "*** Test with PointSet metrics and Image metrics *** " << std::endl;
424 if( itkObjectToObjectMultiMetricv4TestEvaluate( multiVariateMetric, useDisplacementTransform ) != EXIT_SUCCESS )
425 {
426 return EXIT_FAILURE;
427 }
428
429 //
430 // Exercise basic operation with a scales estimator
431 //
432 using ScalesEstimatorMultiType = itk::RegistrationParameterScalesFromPhysicalShift< MultiMetricType >;
433 ScalesEstimatorMultiType::Pointer shiftScaleEstimator = ScalesEstimatorMultiType::New();
434 shiftScaleEstimator->SetMetric(multiVariateMetric);
435 // Have to assign virtual domain sampling points when using a point set with scales estimator
436 shiftScaleEstimator->SetVirtualDomainPointSet( expectationPointSetMetric->GetVirtualTransformedPointSet() );
437
438 ScalesEstimatorMultiType::ScalesType scales;
439 shiftScaleEstimator->EstimateScales( scales );
440 std::cout << "Estimated scales: " << scales << std::endl;
441
442 ScalesEstimatorMultiType::FloatType stepScale;
443 ScalesEstimatorMultiType::ParametersType step;
444 step.SetSize( multiVariateMetric->GetNumberOfParameters() );
445 step.Fill( itk::NumericTraits<ScalesEstimatorMultiType::ParametersType::ValueType>::OneValue() );
446 stepScale = shiftScaleEstimator->EstimateStepScale( step );
447 std::cout << "Estimated stepScale: " << stepScale << std::endl;
448
449 //
450 // Test that we get the same scales/step estimation
451 // with a single metric and the same metric twice in a multimetric
452 //
453 ScalesEstimatorMultiType::ScalesType singleScales, multiSingleScales, multiDoubleScales;
454 ScalesEstimatorMultiType::FloatType singleStep, multiSingleStep, multiDoubleStep;
455 step.SetSize( m1->GetNumberOfParameters() );
456 step.Fill( itk::NumericTraits<ScalesEstimatorMultiType::ParametersType::ValueType>::OneValue() );
457
458 using ScalesEstimatorMeanSquaresType = itk::RegistrationParameterScalesFromPhysicalShift<MeanSquaresMetricType>;
459 ScalesEstimatorMeanSquaresType::Pointer singleShiftScaleEstimator = ScalesEstimatorMeanSquaresType::New();
460 singleShiftScaleEstimator->SetMetric(m1);
461 m1->Initialize();
462 singleShiftScaleEstimator->EstimateScales( singleScales );
463 std::cout << "Single metric estimated scales: " << singleScales << std::endl;
464 singleStep = singleShiftScaleEstimator->EstimateStepScale( step );
465 std::cout << "Single metric estimated stepScale: " << singleStep << std::endl;
466
467 MultiMetricType::Pointer multiSingleMetric = MultiMetricType::New();
468 multiSingleMetric->AddMetric( m1 );
469 multiSingleMetric->Initialize();
470 shiftScaleEstimator->SetMetric( multiSingleMetric );
471 shiftScaleEstimator->EstimateScales( multiSingleScales );
472 std::cout << "multi-single estimated scales: " << multiSingleScales << std::endl;
473 multiSingleStep = shiftScaleEstimator->EstimateStepScale( step );
474 std::cout << "multi-single estimated stepScale: " << multiSingleStep << std::endl;
475
476 MultiMetricType::Pointer multiDoubleMetric = MultiMetricType::New();
477 multiDoubleMetric->AddMetric( m1 );
478 multiDoubleMetric->AddMetric( m1 );
479 multiDoubleMetric->Initialize();
480 shiftScaleEstimator->SetMetric( multiDoubleMetric );
481 shiftScaleEstimator->EstimateScales( multiDoubleScales );
482 std::cout << "multi-double estimated scales: " << multiDoubleScales << std::endl;
483 multiDoubleStep = shiftScaleEstimator->EstimateStepScale( step );
484 std::cout << "multi-double estimated stepScale: " << multiDoubleStep << std::endl;
485
486 // Check that results are the same for all three estimations
487 bool passedEstimation = true;
488 auto tolerance = static_cast<ScalesEstimatorMultiType::FloatType>(1e-6);
489 if( std::fabs(singleStep - multiSingleStep) > tolerance || std::fabs(singleStep - multiDoubleStep) > tolerance )
490 {
491 std::cerr << "Steps do not match as expected between estimation on same metric." << std::endl;
492 passedEstimation = false;
493 }
494 if( std::fabs(singleScales[0] - multiSingleScales[0] ) > tolerance ||
495 std::fabs(singleScales[1] - multiSingleScales[1] ) > tolerance ||
496 std::fabs(singleScales[0] - multiDoubleScales[0] ) > tolerance ||
497 std::fabs(singleScales[1] - multiDoubleScales[1] ) > tolerance )
498 {
499 std::cerr << "Scales do not match as expected between estimation on same metric." << std::endl;
500 passedEstimation = false;
501 }
502 if( ! passedEstimation )
503 {
504 return EXIT_FAILURE;
505 }
506
507 if( ! useDisplacementTransform )
508 {
509 // Exercising the Print function
510 std::cout << "Print: " << std::endl;
511 multiVariateMetric->Print(std::cout);
512
513 // Test ClearMetricQueue
514 multiVariateMetric->ClearMetricQueue();
515 if( multiVariateMetric->GetNumberOfMetrics() != 0 )
516 {
517 std::cerr << "ClearMetricQueue() failed. Number of metrics is not zero." << std::endl;
518 return EXIT_FAILURE;
519 }
520 }
521
522 return EXIT_SUCCESS;
523 }
524
itkObjectToObjectMultiMetricv4Test(int,char * [])525 int itkObjectToObjectMultiMetricv4Test (int , char *[])
526 {
527 std::cout << "XXX Test with TranslationTransform XXX" << std::endl << std::endl;
528 int result = itkObjectToObjectMultiMetricv4TestRun( false );
529 if( result == EXIT_FAILURE )
530 {
531 std::cerr << "Failed test with translation transform. See message above." << std::endl;
532 return EXIT_FAILURE;
533 }
534
535 std::cout << std::endl << std::endl << "XXX Test with DisplacementFieldTransform XXX" << std::endl << std::endl;
536 result = itkObjectToObjectMultiMetricv4TestRun( true );
537 if( result == EXIT_FAILURE )
538 {
539 std::cerr << "Failed test with displacement field transform. See message above." << std::endl;
540 return EXIT_FAILURE;
541 }
542
543 return EXIT_SUCCESS;
544 }
545