1 /*=========================================================================
2  *
3  *  Copyright Insight Software Consortium
4  *
5  *  Licensed under the Apache License, Version 2.0 (the "License");
6  *  you may not use this file except in compliance with the License.
7  *  You may obtain a copy of the License at
8  *
9  *         http://www.apache.org/licenses/LICENSE-2.0.txt
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  *
17  *=========================================================================*/
18 #ifndef _itkLBFGSOptimizer_hxx
19 #define _itkLBFGSOptimizer_hxx
20 
21 #include "itkLBFGSOptimizer.h"
22 #include "itkMath.h"
23 
24 namespace itk
25 {
26 /**
27  * Constructor
28  */
29 LBFGSOptimizer
LBFGSOptimizer()30 ::LBFGSOptimizer()
31 {
32   m_OptimizerInitialized    = false;
33   m_VnlOptimizer            = nullptr;
34   m_Trace                              = false;
35   m_MaximumNumberOfFunctionEvaluations = 2000;
36   m_GradientConvergenceTolerance       = 1e-5;
37   m_LineSearchAccuracy                 = 0.9;
38   m_DefaultStepLength                  = 1.0;
39 }
40 
41 /**
42  * Destructor
43  */
44 LBFGSOptimizer
~LBFGSOptimizer()45 ::~LBFGSOptimizer()
46 {
47   delete m_VnlOptimizer;
48 }
49 
50 /**
51  * PrintSelf
52  */
53 void
54 LBFGSOptimizer
PrintSelf(std::ostream & os,Indent indent) const55 ::PrintSelf(std::ostream & os, Indent indent) const
56 {
57   Superclass::PrintSelf(os, indent);
58   os << indent << "Trace: ";
59   if ( m_Trace )
60     {
61     os << "On";
62     }
63   else
64     {
65     os << "Off";
66     }
67   os << std::endl;
68   os << indent << "MaximumNumberOfFunctionEvaluations: "
69      << m_MaximumNumberOfFunctionEvaluations << std::endl;
70   os << indent << "GradientConvergenceTolerance: "
71      << m_GradientConvergenceTolerance << std::endl;
72   os << indent << "LineSearchAccuracy: "
73      << m_LineSearchAccuracy << std::endl;
74   os << indent << "DefaultStepLength: "
75      << m_DefaultStepLength << std::endl;
76 }
77 
78 /**
79  * Set the optimizer trace flag
80  */
81 void
82 LBFGSOptimizer
SetTrace(bool flag)83 ::SetTrace(bool flag)
84 {
85   if ( flag == m_Trace )
86     {
87     return;
88     }
89 
90   m_Trace = flag;
91   if ( m_OptimizerInitialized )
92     {
93     m_VnlOptimizer->set_trace(m_Trace);
94     }
95 
96   this->Modified();
97 }
98 
99 /**
100  * Set the maximum number of function evalutions
101  */
102 void
103 LBFGSOptimizer
SetMaximumNumberOfFunctionEvaluations(unsigned int n)104 ::SetMaximumNumberOfFunctionEvaluations(unsigned int n)
105 {
106   if ( n == m_MaximumNumberOfFunctionEvaluations )
107     {
108     return;
109     }
110 
111   m_MaximumNumberOfFunctionEvaluations = n;
112   if ( m_OptimizerInitialized )
113     {
114     m_VnlOptimizer->set_max_function_evals(
115       static_cast< int >( m_MaximumNumberOfFunctionEvaluations ) );
116     }
117 
118   this->Modified();
119 }
120 
121 /**
122  * Set the gradient convergence tolerance
123  */
124 void
125 LBFGSOptimizer
SetGradientConvergenceTolerance(double f)126 ::SetGradientConvergenceTolerance(double f)
127 {
128   if ( Math::ExactlyEquals(f, m_GradientConvergenceTolerance) )
129     {
130     return;
131     }
132 
133   m_GradientConvergenceTolerance = f;
134   if ( m_OptimizerInitialized )
135     {
136     m_VnlOptimizer->set_g_tolerance(m_GradientConvergenceTolerance);
137     }
138 
139   this->Modified();
140 }
141 
142 /**
143  * Set the line search accuracy
144  */
145 void
146 LBFGSOptimizer
SetLineSearchAccuracy(double f)147 ::SetLineSearchAccuracy(double f)
148 {
149   if ( Math::ExactlyEquals(f, m_LineSearchAccuracy) )
150     {
151     return;
152     }
153 
154   m_LineSearchAccuracy = f;
155   if ( m_OptimizerInitialized )
156     {
157     m_VnlOptimizer->line_search_accuracy = m_LineSearchAccuracy;
158     }
159 
160   this->Modified();
161 }
162 
163 /**
164  * Set the default step length
165  */
166 void
167 LBFGSOptimizer
SetDefaultStepLength(double f)168 ::SetDefaultStepLength(double f)
169 {
170   if ( Math::ExactlyEquals(f, m_DefaultStepLength) )
171     {
172     return;
173     }
174 
175   m_DefaultStepLength = f;
176   if ( m_OptimizerInitialized )
177     {
178     m_VnlOptimizer->default_step_length = m_DefaultStepLength;
179     }
180 
181   this->Modified();
182 }
183 
184 /** Return Current Value */
185 LBFGSOptimizer::MeasureType
186 LBFGSOptimizer
GetValue() const187 ::GetValue() const
188 {
189   return this->GetCachedValue();
190 }
191 
192 /**
193  * Connect a Cost Function
194  */
195 void
196 LBFGSOptimizer
SetCostFunction(SingleValuedCostFunction * costFunction)197 ::SetCostFunction(SingleValuedCostFunction *costFunction)
198 {
199   const unsigned int numberOfParameters =
200     costFunction->GetNumberOfParameters();
201 
202   auto * adaptor = new CostFunctionAdaptorType(numberOfParameters);
203 
204   adaptor->SetCostFunction(costFunction);
205 
206   if ( m_OptimizerInitialized )
207     {
208     delete m_VnlOptimizer;
209     }
210 
211   this->SetCostFunctionAdaptor(adaptor);
212 
213   m_VnlOptimizer = new vnl_lbfgs(*adaptor);
214 
215   // set the optimizer parameters
216   m_VnlOptimizer->set_trace(m_Trace);
217   m_VnlOptimizer->set_max_function_evals(
218     static_cast< int >( m_MaximumNumberOfFunctionEvaluations ) );
219   m_VnlOptimizer->set_g_tolerance(m_GradientConvergenceTolerance);
220   m_VnlOptimizer->line_search_accuracy = m_LineSearchAccuracy;
221   m_VnlOptimizer->default_step_length  = m_DefaultStepLength;
222 
223   m_OptimizerInitialized = true;
224 
225   this->Modified();
226 }
227 
228 /**
229  * Start the optimization
230  */
231 void
232 LBFGSOptimizer
StartOptimization()233 ::StartOptimization()
234 {
235   this->InvokeEvent( StartEvent() );
236 
237   if ( this->GetMaximize() )
238     {
239     this->GetNonConstCostFunctionAdaptor()->NegateCostFunctionOn();
240     }
241 
242   ParametersType currentPositionInternalValue = this->GetInitialPosition();
243 
244   // We also scale the initial vnlCompatibleParameters up if scales are defined.
245   // This compensates for later scaling them down in the cost function adaptor
246   // and at the end of this function.
247   InternalParametersType vnlCompatibleParameters(currentPositionInternalValue.size());
248   const ScalesType & scales = this->GetScales();
249   if ( m_ScalesInitialized )
250     {
251     this->GetNonConstCostFunctionAdaptor()->SetScales(scales);
252     }
253   for ( unsigned int i = 0; i < vnlCompatibleParameters.size(); ++i )
254     {
255     vnlCompatibleParameters[i] = (m_ScalesInitialized)
256       ? currentPositionInternalValue[i] * scales[i]
257       : currentPositionInternalValue[i];
258     }
259 
260   // vnl optimizers return the solution by reference
261   // in the variable provided as initial position
262   m_VnlOptimizer->minimize(vnlCompatibleParameters);
263 
264   if ( vnlCompatibleParameters.size() != currentPositionInternalValue.size() )
265     {
266     // set current position to initial position and throw an exception
267     this->SetCurrentPosition(currentPositionInternalValue);
268     itkExceptionMacro(<< "Error occurred in optimization");
269     }
270 
271   // we scale the vnlCompatibleParameters down if scales are defined
272   const ScalesType & invScales = this->GetInverseScales();
273   for ( unsigned int i = 0; i < vnlCompatibleParameters.size(); ++i )
274     {
275     currentPositionInternalValue[i] = (m_ScalesInitialized)
276       ? vnlCompatibleParameters[i] * invScales[i]
277       : vnlCompatibleParameters[i];
278     }
279 
280   this->SetCurrentPosition(currentPositionInternalValue);
281   this->InvokeEvent( EndEvent() );
282 }
283 
284 /**
285  * Get the Optimizer
286  */
287 vnl_lbfgs *
288 LBFGSOptimizer
GetOptimizer()289 ::GetOptimizer()
290 {
291   return m_VnlOptimizer;
292 }
293 
294 const std::string
GetStopConditionDescription() const295 LBFGSOptimizer::GetStopConditionDescription() const
296 {
297   m_StopConditionDescription.str("");
298   m_StopConditionDescription << this->GetNameOfClass() << ": ";
299   if ( m_VnlOptimizer )
300     {
301     switch ( m_VnlOptimizer->get_failure_code() )
302       {
303       case vnl_nonlinear_minimizer::ERROR_FAILURE:
304         m_StopConditionDescription << "Failure";
305         break;
306       case vnl_nonlinear_minimizer::ERROR_DODGY_INPUT:
307         m_StopConditionDescription << "Dodgy input";
308         break;
309       case vnl_nonlinear_minimizer::CONVERGED_FTOL:
310         m_StopConditionDescription << "Function tolerance reached";
311         break;
312       case vnl_nonlinear_minimizer::CONVERGED_XTOL:
313         m_StopConditionDescription << "Solution tolerance reached";
314         break;
315       case vnl_nonlinear_minimizer::CONVERGED_XFTOL:
316         m_StopConditionDescription << "Solution and Function tolerance both reached";
317         break;
318       case vnl_nonlinear_minimizer::CONVERGED_GTOL:
319         m_StopConditionDescription << "Gradient tolerance reached";
320         break;
321       case vnl_nonlinear_minimizer::FAILED_TOO_MANY_ITERATIONS:
322         m_StopConditionDescription << "Too many function evaluations. Function evaluations  = "
323                                    << m_MaximumNumberOfFunctionEvaluations;
324         break;
325       case vnl_nonlinear_minimizer::FAILED_FTOL_TOO_SMALL:
326         m_StopConditionDescription << "Function tolerance too small";
327         break;
328       case vnl_nonlinear_minimizer::FAILED_XTOL_TOO_SMALL:
329         m_StopConditionDescription << "Solution tolerance too small";
330         break;
331       case vnl_nonlinear_minimizer::FAILED_GTOL_TOO_SMALL:
332         m_StopConditionDescription << "Gradient tolerance too small";
333         break;
334       case vnl_nonlinear_minimizer::FAILED_USER_REQUEST:
335         m_StopConditionDescription << "User requested";
336         break;
337       }
338     return m_StopConditionDescription.str();
339     }
340   else
341     {
342     return std::string("");
343     }
344 }
345 } // end namespace itk
346 
347 #endif
348