1 // -*- C++ -*-
2 /**
3 * @brief Ceres solver
4 *
5 * Copyright 2005-2021 Airbus-EDF-IMACS-ONERA-Phimeca
6 *
7 * This library is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU Lesser General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * This library is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 * GNU Lesser General Public License for more details.
16 *
17 * You should have received a copy of the GNU Lesser General Public License
18 * along with this library. If not, see <http://www.gnu.org/licenses/>.
19 *
20 */
21 #include "openturns/Ceres.hxx"
22 #include "openturns/Point.hxx"
23 #include "openturns/PersistentObjectFactory.hxx"
24 #include "openturns/Log.hxx"
25 #include "openturns/SpecFunc.hxx"
26 #ifdef OPENTURNS_HAVE_CERES
27 #include <ceres/ceres.h>
28 #endif
29
30 BEGIN_NAMESPACE_OPENTURNS
31
32 CLASSNAMEINIT(Ceres)
33
34 static const Factory<Ceres> Factory_Ceres;
35
GetAlgorithmNames()36 Description Ceres::GetAlgorithmNames()
37 {
38 static Description AlgorithmNames;
39 if (!AlgorithmNames.getSize())
40 {
41 // trust-region methods, not for general optimization
42 AlgorithmNames.add("LEVENBERG_MARQUARDT");// default nlls method, list it first
43 AlgorithmNames.add("DOGLEG");
44
45 // line search methods, available for both least-squares and general optimization
46 AlgorithmNames.add("STEEPEST_DESCENT");
47 AlgorithmNames.add("NONLINEAR_CONJUGATE_GRADIENT");
48 AlgorithmNames.add("LBFGS");
49 AlgorithmNames.add("BFGS");
50 }
51 return AlgorithmNames;
52 }
53
54
55 /* Default constructor */
Ceres(const String & algoName)56 Ceres::Ceres(const String & algoName)
57 : OptimizationAlgorithmImplementation()
58 , algoName_(algoName)
59 {
60 if (!GetAlgorithmNames().contains(algoName))
61 throw InvalidArgumentException(HERE) << "Unknown algorithm name, should be one of " << GetAlgorithmNames();
62 }
63
Ceres(const OptimizationProblem & problem,const String & algoName)64 Ceres::Ceres(const OptimizationProblem & problem,
65 const String & algoName)
66 : OptimizationAlgorithmImplementation(problem)
67 , algoName_(algoName)
68 {
69 if (!GetAlgorithmNames().contains(algoName))
70 throw InvalidArgumentException(HERE) << "Unknown algorithm name, should be one of " << GetAlgorithmNames();
71 checkProblem(problem);
72 }
73
74 /* Virtual constructor */
clone() const75 Ceres * Ceres::clone() const
76 {
77 return new Ceres(*this);
78 }
79
80 /* Check whether this problem can be solved by this solver. Must be overloaded by the actual optimisation algorithm */
checkProblem(const OptimizationProblem & problem) const81 void Ceres::checkProblem(const OptimizationProblem & problem) const
82 {
83 if (problem.hasMultipleObjective())
84 throw InvalidArgumentException(HERE) << "Error: " << getClassName() << " does not support multi-objective optimization";
85
86 if (problem.hasLevelFunction())
87 throw InvalidArgumentException(HERE) << "Error: " << getClassName() << " does not support nearest-point problems";
88
89 if (problem.hasBounds() && (algoName_ != "LEVENBERG_MARQUARDT" && algoName_ != "DOGLEG"))
90 throw InvalidArgumentException(HERE) << "Error: " << getClassName() << " line search algorithms do not support bound constraints";
91
92 if (!problem.hasResidualFunction() && (algoName_ == "LEVENBERG_MARQUARDT" || algoName_ == "DOGLEG"))
93 throw InvalidArgumentException(HERE) << "Error: " << getClassName() << " trust-region algorithms do not support general optimization";
94
95 if (problem.hasInequalityConstraint())
96 throw InvalidArgumentException(HERE) << "Error: " << getClassName() << " does not support inequality constraints";
97
98 if (problem.hasEqualityConstraint())
99 throw InvalidArgumentException(HERE) << "Error: " << getClassName() << " does not support equality constraints";
100
101 if (!problem.isContinuous())
102 throw InvalidArgumentException(HERE) << "Error: " << getClassName() << " does not support non continuous problems";
103 }
104
105 #ifdef OPENTURNS_HAVE_CERES
106 class CostFunctionInterface : public ceres::CostFunction
107 {
108 public:
CostFunctionInterface(Ceres & algorithm)109 explicit CostFunctionInterface(Ceres & algorithm)
110 : ceres::CostFunction()
111 , algorithm_(algorithm)
112 {
113 const OptimizationProblem problem(algorithm_.getProblem());
114 *mutable_parameter_block_sizes() = std::vector<int32_t>(1, problem.getDimension());
115 set_num_residuals(problem.getResidualFunction().getOutputDimension());
116 }
117
Evaluate(double const * const * parameters,double * residuals,double ** jacobians) const118 bool Evaluate(double const* const* parameters,
119 double* residuals,
120 double** jacobians) const override
121 {
122 const OptimizationProblem problem(algorithm_.getProblem());
123 const UnsignedInteger n = problem.getDimension();
124 const UnsignedInteger m = problem.getResidualFunction().getOutputDimension();
125 Point inP(n);
126 const double * x = parameters[0];
127 std::copy(x, x + n, inP.begin());
128
129 // evaluation
130 const Point outP(problem.getResidualFunction()(inP));
131 std::copy(outP.begin(), outP.end(), residuals);
132 algorithm_.evaluationInputHistory_.add(inP);
133 algorithm_.evaluationOutputHistory_.add(Point(1, 0.5 * outP.normSquare()));
134
135 // gradient
136 if (jacobians)
137 {
138 const Matrix gradient(problem.getResidualFunction().gradient(inP));
139 std::copy(&gradient(0, 0), &gradient(n - 1, m - 1) + 1, jacobians[0]);
140 }
141 return true;
142 }
143
144 protected:
145 Ceres & algorithm_;
146 };
147
148
149 class FirstOrderFunctionInterface : public ceres::FirstOrderFunction
150 {
151 public:
FirstOrderFunctionInterface(Ceres & algorithm)152 explicit FirstOrderFunctionInterface(Ceres & algorithm)
153 : ceres::FirstOrderFunction()
154 , algorithm_(algorithm) {}
155
NumParameters() const156 int NumParameters() const override
157 {
158 return algorithm_.getProblem().getDimension();
159 }
160
Evaluate(const double * const x,double * cost,double * jacobian) const161 bool Evaluate(const double * const x,
162 double * cost,
163 double * jacobian) const override
164 {
165 const OptimizationProblem problem(algorithm_.getProblem());
166 const UnsignedInteger n = problem.getDimension();
167 Point inP(n);
168 std::copy(x, x + n, inP.begin());
169
170 // evaluation
171 const Point outP(problem.getObjective()(inP));
172 *cost = problem.isMinimization() ? outP[0] : -outP[0];
173 algorithm_.evaluationInputHistory_.add(inP);
174 algorithm_.evaluationOutputHistory_.add(outP);
175
176 // update result
177 algorithm_.result_.setEvaluationNumber(algorithm_.evaluationInputHistory_.getSize());
178 algorithm_.result_.store(inP, outP, 0.0, 0.0, 0.0, 0.0);
179
180 // gradient
181 if (jacobian)
182 {
183 const Matrix gradient(problem.isMinimization() ? problem.getObjective().gradient(inP) : -1.0 * problem.getObjective().gradient(inP));
184 std::copy(&gradient(0, 0), &gradient(n - 1, 0) + 1, jacobian);
185 }
186 return true;
187 }
188
189 protected:
190 Ceres & algorithm_;
191 };
192
193
194 class IterationCallbackInterface : public ceres::IterationCallback
195 {
196 public:
IterationCallbackInterface(Ceres & algorithm)197 explicit IterationCallbackInterface(Ceres & algorithm)
198 : ceres::IterationCallback()
199 , algorithm_(algorithm) {}
200
operator ()(const ceres::IterationSummary & summary)201 virtual ceres::CallbackReturnType operator()(const ceres::IterationSummary & summary)
202 {
203 if (algorithm_.progressCallback_.first)
204 algorithm_.progressCallback_.first(100.0 * summary.iteration / algorithm_.getMaximumIterationNumber(), algorithm_.progressCallback_.second);
205 if (algorithm_.stopCallback_.first && algorithm_.stopCallback_.first(algorithm_.stopCallback_.second))
206 return ceres::SOLVER_ABORT;
207 else
208 return ceres::SOLVER_CONTINUE;
209 }
210
211 protected:
212 const Ceres & algorithm_;
213 };
214
215 #endif
216
217 /* Performs the actual computation by calling the Ceres library
218 */
run()219 void Ceres::run()
220 {
221 #ifdef OPENTURNS_HAVE_CERES
222 const UnsignedInteger dimension = getProblem().getDimension();
223 Point x(getStartingPoint());
224 if (x.getDimension() != dimension)
225 throw InvalidArgumentException(HERE) << "Invalid starting point dimension (" << x.getDimension() << "), expected " << dimension;
226
227 // initialize history
228 evaluationInputHistory_ = Sample(0, dimension);
229 evaluationOutputHistory_ = Sample(0, 1);
230 result_ = OptimizationResult(getProblem());
231
232 double optimalValue = 0.0;
233 UnsignedInteger iterationNumber = 0;
234
235 if (getProblem().hasResidualFunction())
236 {
237 // Build the problem.
238 ceres::Problem problem;
239 ceres::CostFunction* cost_function = new CostFunctionInterface(*this);
240 problem.AddResidualBlock(cost_function, NULL, &x[0]);
241
242 if (getProblem().hasBounds())
243 {
244 Interval bounds(getProblem().getBounds());
245 if (!bounds.contains(x))
246 LOGWARN(OSS() << "Starting point is not inside bounds x=" << x.__str__() << " bounds=" << bounds);
247 Interval::BoolCollection finiteLowerBound(bounds.getFiniteLowerBound());
248 Interval::BoolCollection finiteUpperBound(bounds.getFiniteUpperBound());
249 Point lowerBound(bounds.getLowerBound());
250 Point upperBound(bounds.getUpperBound());
251 for (UnsignedInteger i = 0; i < dimension; ++ i)
252 {
253 if (finiteLowerBound[i]) problem.SetParameterLowerBound(&(*x.begin()), i, lowerBound[i]);
254 if (finiteUpperBound[i]) problem.SetParameterUpperBound(&(*x.begin()), i, upperBound[i]);
255 }
256 }
257
258 // Run the solver!
259 ceres::Solver::Options options;
260
261 // Switch trust region / line search depending on algoName as it's the union
262 if (ceres::StringToTrustRegionStrategyType(algoName_, &options.trust_region_strategy_type))
263 options.minimizer_type = ceres::TRUST_REGION;
264 else if (ceres::StringToLineSearchDirectionType(algoName_, &options.line_search_direction_type))
265 options.minimizer_type = ceres::LINE_SEARCH;
266 else
267 throw InvalidArgumentException(HERE) << "Could not set minimizer_type";
268
269 options.max_num_iterations = getMaximumIterationNumber();
270 options.function_tolerance = getMaximumResidualError();
271 options.parameter_tolerance = getMaximumRelativeError();
272
273 // Set remaining options from ResourceMap
274 if (ResourceMap::HasKey("Ceres-line_search_type") && !ceres::StringToLineSearchType(ResourceMap::Get("Ceres-line_search_type"), &options.line_search_type))
275 throw InvalidArgumentException(HERE) << "Invalid value for line_search_type";
276 if (ResourceMap::HasKey("Ceres-nonlinear_conjugate_gradient_type") && !ceres::StringToNonlinearConjugateGradientType(ResourceMap::Get("Ceres-nonlinear_conjugate_gradient_type"), &options.nonlinear_conjugate_gradient_type))
277 throw InvalidArgumentException(HERE) << "Invalid value for nonlinear_conjugate_gradient_type";
278 if (ResourceMap::HasKey("Ceres-max_lbfgs_rank"))
279 options.max_lbfgs_rank = ResourceMap::GetAsUnsignedInteger("Ceres-max_lbfgs_rank");
280 if (ResourceMap::HasKey("Ceres-use_approximate_eigenvalue_bfgs_scaling"))
281 options.use_approximate_eigenvalue_bfgs_scaling = ResourceMap::GetAsBool("Ceres-use_approximate_eigenvalue_bfgs_scaling");
282 if (ResourceMap::HasKey("Ceres-line_search_interpolation_type") && !ceres::StringToLineSearchInterpolationType(ResourceMap::Get("Ceres-line_search_interpolation_type"), &options.line_search_interpolation_type))
283 throw InvalidArgumentException(HERE) << "Invalid value for line_search_interpolation_type";
284 if (ResourceMap::HasKey("Ceres-min_line_search_step_size"))
285 options.min_line_search_step_size = ResourceMap::GetAsScalar("Ceres-min_line_search_step_size");
286 if (ResourceMap::HasKey("Ceres-line_search_sufficient_function_decrease"))
287 options.line_search_sufficient_function_decrease = ResourceMap::GetAsScalar("Ceres-line_search_sufficient_function_decrease");
288 if (ResourceMap::HasKey("Ceres-max_line_search_step_contraction"))
289 options.max_line_search_step_contraction = ResourceMap::GetAsScalar("Ceres-max_line_search_step_contraction");
290 if (ResourceMap::HasKey("Ceres-min_line_search_step_contraction"))
291 options.min_line_search_step_contraction = ResourceMap::GetAsScalar("Ceres-min_line_search_step_contraction");
292 if (ResourceMap::HasKey("Ceres-max_num_line_search_step_size_iterations"))
293 options.max_num_line_search_step_size_iterations = ResourceMap::GetAsUnsignedInteger("Ceres-max_num_line_search_step_size_iterations");
294 if (ResourceMap::HasKey("Ceres-max_num_line_search_direction_restarts"))
295 options.max_num_line_search_direction_restarts = ResourceMap::GetAsUnsignedInteger("Ceres-max_num_line_search_direction_restarts");
296 if (ResourceMap::HasKey("Ceres-line_search_sufficient_curvature_decrease"))
297 options.line_search_sufficient_curvature_decrease = ResourceMap::GetAsScalar("Ceres-line_search_sufficient_curvature_decrease");
298 if (ResourceMap::HasKey("Ceres-max_line_search_step_expansion"))
299 options.max_line_search_step_expansion = ResourceMap::GetAsScalar("Ceres-max_line_search_step_expansion");
300 if (ResourceMap::HasKey("Ceres-dogleg_type") && !ceres::StringToDoglegType(ResourceMap::Get("Ceres-dogleg_type"), &options.dogleg_type))
301 throw InvalidArgumentException(HERE) << "Invalid value for dogleg_type";
302 if (ResourceMap::HasKey("Ceres-use_nonmonotonic_steps"))
303 options.use_nonmonotonic_steps = ResourceMap::GetAsBool("Ceres-use_nonmonotonic_steps");
304 if (ResourceMap::HasKey("Ceres-max_consecutive_nonmonotonic_steps"))
305 options.max_consecutive_nonmonotonic_steps = ResourceMap::GetAsUnsignedInteger("Ceres-max_consecutive_nonmonotonic_steps");
306 if (ResourceMap::HasKey("Ceres-max_num_iterations"))
307 options.max_num_iterations = ResourceMap::GetAsUnsignedInteger("Ceres-max_num_iterations");
308 if (ResourceMap::HasKey("Ceres-max_solver_time_in_seconds"))
309 options.max_solver_time_in_seconds = ResourceMap::GetAsScalar("Ceres-max_solver_time_in_seconds");
310 if (ResourceMap::HasKey("Ceres-num_threads"))
311 options.num_threads = ResourceMap::GetAsUnsignedInteger("Ceres-num_threads");
312 if (ResourceMap::HasKey("Ceres-initial_trust_region_radius"))
313 options.initial_trust_region_radius = ResourceMap::GetAsScalar("Ceres-initial_trust_region_radius");
314 if (ResourceMap::HasKey("Ceres-max_trust_region_radius"))
315 options.max_trust_region_radius = ResourceMap::GetAsScalar("Ceres-max_trust_region_radius");
316 if (ResourceMap::HasKey("Ceres-min_trust_region_radius"))
317 options.min_trust_region_radius = ResourceMap::GetAsScalar("Ceres-min_trust_region_radius");
318 if (ResourceMap::HasKey("Ceres-min_relative_decrease"))
319 options.min_relative_decrease = ResourceMap::GetAsScalar("Ceres-min_relative_decrease");
320 if (ResourceMap::HasKey("Ceres-min_lm_diagonal"))
321 options.min_lm_diagonal = ResourceMap::GetAsScalar("Ceres-min_lm_diagonal");
322 if (ResourceMap::HasKey("Ceres-max_lm_diagonal"))
323 options.max_lm_diagonal = ResourceMap::GetAsScalar("Ceres-max_lm_diagonal");
324 if (ResourceMap::HasKey("Ceres-max_num_consecutive_invalid_steps"))
325 options.max_num_consecutive_invalid_steps = ResourceMap::GetAsUnsignedInteger("Ceres-max_num_consecutive_invalid_steps");
326 if (ResourceMap::HasKey("Ceres-function_tolerance"))
327 options.function_tolerance = ResourceMap::GetAsScalar("Ceres-function_tolerance");
328 if (ResourceMap::HasKey("Ceres-gradient_tolerance"))
329 options.gradient_tolerance = ResourceMap::GetAsScalar("Ceres-gradient_tolerance");
330 if (ResourceMap::HasKey("Ceres-parameter_tolerance"))
331 options.parameter_tolerance = ResourceMap::GetAsScalar("Ceres-parameter_tolerance");
332 if (ResourceMap::HasKey("Ceres-linear_solver_type") && !ceres::StringToLinearSolverType(ResourceMap::Get("Ceres-linear_solver_type"), &options.linear_solver_type))
333 throw InvalidArgumentException(HERE) << "Invalid value for linear_solver_type";
334 if (ResourceMap::HasKey("Ceres-preconditioner_type") && !ceres::StringToPreconditionerType(ResourceMap::Get("Ceres-preconditioner_type"), &options.preconditioner_type))
335 throw InvalidArgumentException(HERE) << "Invalid value for preconditioner_type";
336 if (ResourceMap::HasKey("Ceres-visibility_clustering_type") && !ceres::StringToVisibilityClusteringType(ResourceMap::Get("Ceres-visibility_clustering_type"), &options.visibility_clustering_type))
337 throw InvalidArgumentException(HERE) << "Invalid value for visibility_clustering_type";
338 if (ResourceMap::HasKey("Ceres-dense_linear_algebra_library_type") && !ceres::StringToDenseLinearAlgebraLibraryType(ResourceMap::Get("Ceres-dense_linear_algebra_library_type"), &options.dense_linear_algebra_library_type))
339 throw InvalidArgumentException(HERE) << "Invalid value for dense_linear_algebra_library_type";
340 if (ResourceMap::HasKey("Ceres-sparse_linear_algebra_library_type") && !ceres::StringToSparseLinearAlgebraLibraryType(ResourceMap::Get("Ceres-sparse_linear_algebra_library_type"), &options.sparse_linear_algebra_library_type))
341 throw InvalidArgumentException(HERE) << "Invalid value for sparse_linear_algebra_library_type";
342 if (ResourceMap::HasKey("Ceres-use_explicit_schur_complement"))
343 options.use_explicit_schur_complement = ResourceMap::GetAsBool("Ceres-use_explicit_schur_complement");
344 if (ResourceMap::HasKey("Ceres-use_postordering"))
345 options.use_postordering = ResourceMap::GetAsBool("Ceres-use_postordering");
346 if (ResourceMap::HasKey("Ceres-dynamic_sparsity"))
347 options.dynamic_sparsity = ResourceMap::GetAsBool("Ceres-dynamic_sparsity");
348 if (ResourceMap::HasKey("Ceres-min_linear_solver_iterations"))
349 options.min_linear_solver_iterations = ResourceMap::GetAsUnsignedInteger("Ceres-min_linear_solver_iterations");
350 if (ResourceMap::HasKey("Ceres-max_linear_solver_iterations"))
351 options.max_linear_solver_iterations = ResourceMap::GetAsUnsignedInteger("Ceres-max_linear_solver_iterations");
352 if (ResourceMap::HasKey("Ceres-eta"))
353 options.eta = ResourceMap::GetAsScalar("Ceres-eta");
354 if (ResourceMap::HasKey("Ceres-jacobi_scaling"))
355 options.jacobi_scaling = ResourceMap::GetAsBool("Ceres-jacobi_scaling");
356 if (ResourceMap::HasKey("Ceres-use_inner_iterations"))
357 options.use_inner_iterations = ResourceMap::GetAsBool("Ceres-use_inner_iterations");
358 if (ResourceMap::HasKey("Ceres-inner_iteration_tolerance"))
359 options.inner_iteration_tolerance = ResourceMap::GetAsScalar("Ceres-inner_iteration_tolerance");
360 // logging_type: https://github.com/ceres-solver/ceres-solver/issues/470
361 options.logging_type = ceres::SILENT;
362 if (ResourceMap::HasKey("Ceres-minimizer_progress_to_stdout"))
363 options.minimizer_progress_to_stdout = ResourceMap::GetAsBool("Ceres-minimizer_progress_to_stdout");
364 // trust_region_problem_dump_directory/trust_region_problem_dump_format_type: https://github.com/ceres-solver/ceres-solver/issues/470
365 if (ResourceMap::HasKey("Ceres-check_gradients"))
366 options.check_gradients = ResourceMap::GetAsBool("Ceres-check_gradients");
367 if (ResourceMap::HasKey("Ceres-gradient_check_relative_precision"))
368 options.gradient_check_relative_precision = ResourceMap::GetAsScalar("Ceres-gradient_check_relative_precision");
369 if (ResourceMap::HasKey("Ceres-gradient_check_numeric_derivative_relative_step_size"))
370 options.gradient_check_numeric_derivative_relative_step_size = ResourceMap::GetAsScalar("Ceres-gradient_check_numeric_derivative_relative_step_size");
371 if (ResourceMap::HasKey("Ceres-update_state_every_iteration"))
372 options.update_state_every_iteration = ResourceMap::GetAsBool("Ceres-update_state_every_iteration");
373
374 Pointer<IterationCallbackInterface> p_iterationCallbackInterface = new IterationCallbackInterface(*this);
375 options.callbacks.push_back(p_iterationCallbackInterface.get());
376 ceres::Solver::Summary summary;
377 ceres::Solve(options, &problem, &summary);
378 LOGINFO(OSS() << summary.BriefReport());
379 if (summary.termination_type == ceres::FAILURE)
380 throw InternalException(HERE) << "Ceres terminated with failure.";
381 else if (summary.termination_type != ceres::CONVERGENCE)
382 LOGWARN(OSS() << "Ceres terminated with " << ceres::TerminationTypeToString(summary.termination_type));
383
384 optimalValue = summary.final_cost;
385 iterationNumber = summary.iterations.size();
386
387 }
388 else
389 {
390 // general optimization
391
392 ceres::GradientProblemSolver::Options options;
393 // check that algoName is a line search method
394 if (!ceres::StringToLineSearchDirectionType(algoName_, &options.line_search_direction_type))
395 throw InvalidArgumentException(HERE) << "Unconstrained optimization only allows line search methods";
396
397 options.max_num_iterations = getMaximumIterationNumber();
398 options.function_tolerance = getMaximumResidualError();
399 options.parameter_tolerance = getMaximumRelativeError();
400
401 // Set remaining options from ResourceMap
402 if (ResourceMap::HasKey("Ceres-line_search_type") && !ceres::StringToLineSearchType(ResourceMap::Get("Ceres-line_search_type"), &options.line_search_type))
403 throw InvalidArgumentException(HERE) << "Invalid value for line_search_type";
404 if (ResourceMap::HasKey("Ceres-nonlinear_conjugate_gradient_type") && !ceres::StringToNonlinearConjugateGradientType(ResourceMap::Get("Ceres-nonlinear_conjugate_gradient_type"), &options.nonlinear_conjugate_gradient_type))
405 throw InvalidArgumentException(HERE) << "Invalid value for nonlinear_conjugate_gradient_type";
406 if (ResourceMap::HasKey("Ceres-max_lbfgs_rank"))
407 options.max_lbfgs_rank = ResourceMap::GetAsUnsignedInteger("Ceres-max_lbfgs_rank");
408 if (ResourceMap::HasKey("Ceres-use_approximate_eigenvalue_bfgs_scaling"))
409 options.use_approximate_eigenvalue_bfgs_scaling = ResourceMap::GetAsBool("Ceres-use_approximate_eigenvalue_bfgs_scaling");
410 if (ResourceMap::HasKey("Ceres-line_search_interpolation_type") && !ceres::StringToLineSearchInterpolationType(ResourceMap::Get("Ceres-line_search_interpolation_type"), &options.line_search_interpolation_type))
411 throw InvalidArgumentException(HERE) << "Invalid value for line_search_interpolation_type";
412
413 if (ResourceMap::HasKey("Ceres-min_line_search_step_size"))
414 options.min_line_search_step_size = ResourceMap::GetAsScalar("Ceres-min_line_search_step_size");
415 if (ResourceMap::HasKey("Ceres-line_search_sufficient_function_decrease"))
416 options.line_search_sufficient_function_decrease = ResourceMap::GetAsScalar("Ceres-line_search_sufficient_function_decrease");
417 if (ResourceMap::HasKey("Ceres-max_line_search_step_contraction"))
418 options.max_line_search_step_contraction = ResourceMap::GetAsScalar("Ceres-max_line_search_step_contraction");
419 if (ResourceMap::HasKey("Ceres-min_line_search_step_contraction"))
420 options.min_line_search_step_contraction = ResourceMap::GetAsScalar("Ceres-min_line_search_step_contraction");
421 if (ResourceMap::HasKey("Ceres-max_num_line_search_step_size_iterations"))
422 options.max_num_line_search_step_size_iterations = ResourceMap::GetAsUnsignedInteger("Ceres-max_num_line_search_step_size_iterations");
423 if (ResourceMap::HasKey("Ceres-max_num_line_search_direction_restarts"))
424 options.max_num_line_search_direction_restarts = ResourceMap::GetAsUnsignedInteger("Ceres-max_num_line_search_direction_restarts");
425 if (ResourceMap::HasKey("Ceres-line_search_sufficient_curvature_decrease"))
426 options.line_search_sufficient_curvature_decrease = ResourceMap::GetAsScalar("Ceres-line_search_sufficient_curvature_decrease");
427 if (ResourceMap::HasKey("Ceres-max_line_search_step_expansion"))
428 options.max_line_search_step_expansion = ResourceMap::GetAsScalar("Ceres-max_line_search_step_expansion");
429 if (ResourceMap::HasKey("Ceres-max_num_iterations"))
430 options.max_num_iterations = ResourceMap::GetAsUnsignedInteger("Ceres-max_num_iterations");
431 if (ResourceMap::HasKey("Ceres-max_solver_time_in_seconds"))
432 options.max_solver_time_in_seconds = ResourceMap::GetAsScalar("Ceres-max_solver_time_in_seconds");
433 if (ResourceMap::HasKey("Ceres-function_tolerance"))
434 options.function_tolerance = ResourceMap::GetAsScalar("Ceres-function_tolerance");
435 if (ResourceMap::HasKey("Ceres-gradient_tolerance"))
436 options.gradient_tolerance = ResourceMap::GetAsScalar("Ceres-gradient_tolerance");
437 if (ResourceMap::HasKey("Ceres-parameter_tolerance"))
438 options.parameter_tolerance = ResourceMap::GetAsScalar("Ceres-parameter_tolerance");
439 // logging_type: https://github.com/ceres-solver/ceres-solver/issues/470
440 options.logging_type = ceres::SILENT;
441 if (ResourceMap::HasKey("Ceres-minimizer_progress_to_stdout"))
442 options.minimizer_progress_to_stdout = ResourceMap::GetAsBool("Ceres-minimizer_progress_to_stdout");
443
444 Pointer<IterationCallbackInterface> p_iterationCallbackInterface = new IterationCallbackInterface(*this);
445 options.callbacks.push_back(p_iterationCallbackInterface.get());
446 ceres::GradientProblemSolver::Summary summary;
447 ceres::GradientProblem problem(new FirstOrderFunctionInterface(*this));
448 ceres::Solve(options, problem, &x[0], &summary);
449
450 LOGINFO(OSS() << summary.BriefReport());
451 if (summary.termination_type != ceres::CONVERGENCE)
452 LOGWARN(OSS() << "Ceres terminated with " << ceres::TerminationTypeToString(summary.termination_type));
453
454 optimalValue = getProblem().isMinimization() ? summary.final_cost : -summary.final_cost;
455 iterationNumber = summary.iterations.size();
456 }
457
458 OptimizationResult result(getProblem());
459
460 const UnsignedInteger size = evaluationInputHistory_.getSize();
461
462 Scalar absoluteError = -1.0;
463 Scalar relativeError = -1.0;
464 Scalar residualError = -1.0;
465 Scalar constraintError = -1.0;
466
467 for (UnsignedInteger i = 0; i < size; ++ i)
468 {
469 const Point inP(evaluationInputHistory_[i]);
470 const Point outP(evaluationOutputHistory_[i]);
471 constraintError = 0.0;
472 if (getProblem().hasBounds())
473 {
474 const Interval bounds(getProblem().getBounds());
475 for (UnsignedInteger j = 0; j < dimension; ++ j)
476 {
477 if (bounds.getFiniteLowerBound()[j])
478 constraintError = std::max(constraintError, bounds.getLowerBound()[j] - inP[j]);
479 if (bounds.getFiniteUpperBound()[j])
480 constraintError = std::max(constraintError, inP[j] - bounds.getUpperBound()[j]);
481 }
482 }
483 if (i > 0)
484 {
485 const Point inPM(evaluationInputHistory_[i - 1]);
486 const Point outPM(evaluationOutputHistory_[i - 1]);
487 absoluteError = (inP - inPM).normInf();
488 relativeError = (inP.normInf() > 0.0) ? (absoluteError / inP.normInf()) : -1.0;
489 residualError = (std::abs(outP[0]) > 0.0) ? (std::abs(outP[0] - outPM[0]) / std::abs(outP[0])) : -1.0;
490 }
491 result.store(inP, outP, absoluteError, relativeError, residualError, constraintError);
492 }
493
494 result.setEvaluationNumber(size);
495 result.setIterationNumber(iterationNumber);
496 result.setOptimalPoint(x);
497 result.setOptimalValue(Point(1, optimalValue));
498 setResult(result);
499 #else
500 throw NotYetImplementedException(HERE) << "No Ceres support";
501 #endif
502 }
503
504 /* String converter */
__repr__() const505 String Ceres::__repr__() const
506 {
507 OSS oss;
508 oss << "class=" << getClassName()
509 << " " << OptimizationAlgorithmImplementation::__repr__();
510 return oss;
511 }
512
513 /* String converter */
__str__(const String &) const514 String Ceres::__str__(const String & ) const
515 {
516 OSS oss(false);
517 oss << "class=" << getClassName();
518 return oss;
519 }
520
setAlgorithmName(const String algoName)521 void Ceres::setAlgorithmName(const String algoName)
522 {
523 algoName_ = algoName;
524 }
525
getAlgorithmName() const526 String Ceres::getAlgorithmName() const
527 {
528 return algoName_;
529 }
530
531 /* Method save() stores the object through the StorageManager */
save(Advocate & adv) const532 void Ceres::save(Advocate & adv) const
533 {
534 OptimizationAlgorithmImplementation::save(adv);
535 adv.saveAttribute("algoName_", algoName_);
536 }
537
538 /* Method load() reloads the object from the StorageManager */
load(Advocate & adv)539 void Ceres::load(Advocate & adv)
540 {
541 OptimizationAlgorithmImplementation::load(adv);
542 adv.loadAttribute("algoName_", algoName_);
543 }
544
IsAvailable()545 Bool Ceres::IsAvailable()
546 {
547 #ifdef OPENTURNS_HAVE_CERES
548 return true;
549 #else
550 return false;
551 #endif
552 }
553
Initialize()554 void Ceres::Initialize()
555 {
556 #ifdef OPENTURNS_HAVE_CERES
557 google::InitGoogleLogging("openturns");
558 #endif
559 }
560
561 END_NAMESPACE_OPENTURNS
562
563