/*=========================================================================
 *
 *  Copyright Insight Software Consortium
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/
#ifndef itkMultiGradientOptimizerv4_hxx
#define itkMultiGradientOptimizerv4_hxx

#include "itkMultiGradientOptimizerv4.h"

namespace itk
{

//-------------------------------------------------------------------
template<typename TInternalComputationValueType>
MultiGradientOptimizerv4Template<TInternalComputationValueType>
::MultiGradientOptimizerv4Template()
{
  this->m_NumberOfIterations = static_cast<SizeValueType>(0);
  this->m_StopCondition      = Superclass::MAXIMUM_NUMBER_OF_ITERATIONS;
  this->m_StopConditionDescription << this->GetNameOfClass() << ": ";

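  /* Initialize the best (minimum) metric value to the largest representable
   * value so that any metric value observed during optimization improves on it. */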
  this->m_MaximumMetricValue = NumericTraits<MeasureType>::max();
  this->m_MinimumMetricValue = this->m_MaximumMetricValue;
}

//-------------------------------------------------------------------
template<typename TInternalComputationValueType>
void
MultiGradientOptimizerv4Template<TInternalComputationValueType>
::PrintSelf(std::ostream & os, Indent indent) const
{
  Superclass::PrintSelf(os, indent);
  os << indent << "Stop condition: " << this->m_StopCondition << std::endl;
  os << indent << "Stop condition description: " << this->m_StopConditionDescription.str() << std::endl;
}

//-------------------------------------------------------------------
template<typename TInternalComputationValueType>
typename MultiGradientOptimizerv4Template<TInternalComputationValueType>::OptimizersListType &
MultiGradientOptimizerv4Template<TInternalComputationValueType>
::GetOptimizersList()
{
  return this->m_OptimizersList;
}

/** Set the list of optimizers to use in the multiple gradient descent. */
template<typename TInternalComputationValueType>
void
MultiGradientOptimizerv4Template<TInternalComputationValueType>
::SetOptimizersList(typename MultiGradientOptimizerv4Template::OptimizersListType & p)
{
  if( p != this->m_OptimizersList )
    {
    this->m_OptimizersList = p;
    this->Modified();
    }
}

/** Get the list of metric values produced by the multi-gradient optimization. */
template<typename TInternalComputationValueType>
const typename MultiGradientOptimizerv4Template<TInternalComputationValueType>::MetricValuesListType &
MultiGradientOptimizerv4Template<TInternalComputationValueType>
::GetMetricValuesList() const
{
  return this->m_MetricValuesList;
}

//-------------------------------------------------------------------
template<typename TInternalComputationValueType>
const typename MultiGradientOptimizerv4Template<TInternalComputationValueType>::StopConditionReturnStringType
MultiGradientOptimizerv4Template<TInternalComputationValueType>
::GetStopConditionDescription() const
{
  return this->m_StopConditionDescription.str();
}

//-------------------------------------------------------------------
template<typename TInternalComputationValueType>
void
MultiGradientOptimizerv4Template<TInternalComputationValueType>
::StopOptimization()
{
  itkDebugMacro( "StopOptimization called with a description - "
                 << this->GetStopConditionDescription() );
  this->m_Stop = true;

  // FIXME
  // this->m_Metric->SetParameters( this->m_OptimizersList[ this->m_BestParametersIndex ] );
  this->InvokeEvent( EndEvent() );
}

/**
 * Start and run the optimization.
 */
template<typename TInternalComputationValueType>
void
MultiGradientOptimizerv4Template<TInternalComputationValueType>
::StartOptimization( bool doOnlyInitialization )
{
  itkDebugMacro("StartOptimization");
  auto maxOpt = static_cast<SizeValueType>( this->m_OptimizersList.size() );
  if ( maxOpt == NumericTraits<SizeValueType>::ZeroValue() )
    {
    itkExceptionMacro("No optimizers are set.");
    }
  if ( ! this->m_Metric )
    {
    this->m_Metric = this->m_OptimizersList[0]->GetModifiableMetric();
    }
  this->m_MetricValuesList.clear();
  this->m_MinimumMetricValue = this->m_MaximumMetricValue;
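  /* Keep a reference to the first optimizer's parameter object; the loop below
   * verifies that every optimizer operates on this same shared object. */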
  const ParametersType & testParamsAreTheSameObject = this->m_OptimizersList[0]->GetCurrentPosition();
  this->m_MetricValuesList.push_back( this->m_MaximumMetricValue );
  /* Initialize the optimizer, but don't run it. */
  this->m_OptimizersList[0]->StartOptimization( true /* doOnlyInitialization */ );

  for ( SizeValueType whichOptimizer = 1; whichOptimizer < maxOpt; whichOptimizer++ )
    {
    this->m_MetricValuesList.push_back(this->m_MaximumMetricValue);
    const ParametersType & compareParams = this->m_OptimizersList[whichOptimizer]->GetCurrentPosition();
    if ( &compareParams != &testParamsAreTheSameObject )
      {
      itkExceptionMacro("Parameter objects are not identical across all optimizers/metrics.");
      }
    /* Initialize the optimizer, but don't run it. */
    this->m_OptimizersList[whichOptimizer]->StartOptimization( true /* doOnlyInitialization */ );
    }

  this->m_CurrentIteration = static_cast<SizeValueType>(0);

  /* Must call the superclass version for basic validation and setup,
   * and to start the optimization loop. */
  if ( this->m_NumberOfIterations > static_cast<SizeValueType>(0) )
    {
    Superclass::StartOptimization( doOnlyInitialization );
    }
}

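/* A minimal usage sketch (illustrative only, not part of the implementation):
 * it assumes two GradientDescentOptimizerv4 instances, "gd1" and "gd2", have
 * already been configured with metrics that share the same transform, so that
 * their parameter objects are identical as required by StartOptimization().
 *
 *   using MultiOptimizerType = itk::MultiGradientOptimizerv4;
 *   MultiOptimizerType::Pointer multiOptimizer = MultiOptimizerType::New();
 *   MultiOptimizerType::OptimizersListType optimizers = multiOptimizer->GetOptimizersList();
 *   optimizers.push_back( gd1 );
 *   optimizers.push_back( gd2 );
 *   multiOptimizer->SetOptimizersList( optimizers );
 *   multiOptimizer->SetNumberOfIterations( 50 );
 *   multiOptimizer->StartOptimization();
 */
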
/**
 * Resume optimization.
 */
template<typename TInternalComputationValueType>
void
MultiGradientOptimizerv4Template<TInternalComputationValueType>
::ResumeOptimization()
{
  this->m_StopConditionDescription.str("");
  this->m_StopConditionDescription << this->GetNameOfClass() << ": ";
  this->InvokeEvent( StartEvent() );
  itkDebugMacro(" start ");
  this->m_Stop = false;
  while( ! this->m_Stop )
    {
    /* Compute metric value/derivative. */

    auto maxOpt = static_cast<SizeValueType>( this->m_OptimizersList.size() );
    /* We rely on the learning rate or parameter scale estimator to do the weighting. */
    TInternalComputationValueType combinefunction = NumericTraits<TInternalComputationValueType>::OneValue() / static_cast<TInternalComputationValueType>(maxOpt);
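    /* With N sub-optimizers, combinefunction equals 1/N, so the combined
     * gradient accumulated below is the average of the per-optimizer gradients,
     * each modified by its own scales and learning rate before being added. */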
    itkDebugMacro(" nopt " << maxOpt);

    for ( SizeValueType whichOptimizer = 0; whichOptimizer < maxOpt; whichOptimizer++ )
      {
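      /* The const_casts allow each sub-optimizer's stored metric value and
       * gradient to be updated in place by its metric. */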
      this->m_OptimizersList[whichOptimizer]->GetMetric()->GetValueAndDerivative(
        const_cast<MeasureType &>( this->m_OptimizersList[whichOptimizer]->GetCurrentMetricValue() ),
        const_cast<DerivativeType &>( this->m_OptimizersList[whichOptimizer]->GetGradient() ) );
      itkDebugMacro(" got-deriv " << whichOptimizer);
      if ( this->m_Gradient.Size() != this->m_OptimizersList[whichOptimizer]->GetGradient().Size() )
        {
        this->m_Gradient.SetSize( this->m_OptimizersList[whichOptimizer]->GetGradient().Size() );
        itkDebugMacro(" resized ");
        }

      /* Modify the gradient by scales, weights and learning rate. */
      this->m_OptimizersList[whichOptimizer]->ModifyGradientByScales();
      this->m_OptimizersList[whichOptimizer]->EstimateLearningRate();
      this->m_OptimizersList[whichOptimizer]->ModifyGradientByLearningRate();

      itkDebugMacro(" mod-grad ");
      /* Combine the gradients. */
      if ( whichOptimizer == 0 )
        {
        this->m_Gradient.Fill(0);
        }
      this->m_Gradient = this->m_Gradient + this->m_OptimizersList[whichOptimizer]->GetGradient() * combinefunction;
      itkDebugMacro(" add-grad ");
      this->m_MetricValuesList[whichOptimizer] = this->m_OptimizersList[whichOptimizer]->GetCurrentMetricValue();
      } // end for

    /* Check whether optimization has been stopped externally.
     * (Presumably this could happen from a multi-threaded client app.) */
    if ( this->m_Stop )
      {
      this->m_StopConditionDescription << "StopOptimization() called";
      break;
      }
    try
      {
      /* Pass combined gradient to transforms and let them update. */
      itkDebugMacro(" combine-grad ");
      this->m_OptimizersList[0]->GetModifiableMetric()->UpdateTransformParameters( this->m_Gradient );
      }
    catch ( ExceptionObject & )
      {
      this->m_StopCondition = Superclass::UPDATE_PARAMETERS_ERROR;
      this->m_StopConditionDescription << "UpdateTransformParameters error";
      this->StopOptimization();
      // Rethrow to pass the exception to the caller without slicing its type
      throw;
      }
    this->InvokeEvent( IterationEvent() );
    /* Update and check the iteration count. */
    this->m_CurrentIteration++;
    if ( this->m_CurrentIteration >= this->m_NumberOfIterations )
      {
      this->m_StopConditionDescription << "Maximum number of iterations (" << this->m_NumberOfIterations << ") exceeded.";
      this->m_StopCondition = Superclass::MAXIMUM_NUMBER_OF_ITERATIONS;
      this->StopOptimization();
      break;
      }
    } // end while (!m_Stop)
}

} // namespace itk

#endif