1 /**
 * @file spsa_impl.hpp
3 * @author N Rajiv Vaidyanathan
4 * @author Marcus Edel
5 *
 * Implementation of the SPSA (Simultaneous Perturbation Stochastic
 * Approximation) optimizer.
8 *
9 * ensmallen is free software; you may redistribute it and/or modify it under
10 * the terms of the 3-clause BSD license. You should have received a copy of
11 * the 3-clause BSD license along with ensmallen. If not, see
12 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
13 */
14 #ifndef ENSMALLEN_SPSA_SPSA_IMPL_HPP
15 #define ENSMALLEN_SPSA_SPSA_IMPL_HPP
16
// In case it hasn't been included yet.
#include "spsa.hpp"

#include <limits>

#include <ensmallen_bits/function.hpp>
21
22 namespace ens {
23
SPSA(const double alpha,const double gamma,const double stepSize,const double evaluationStepSize,const size_t maxIterations,const double tolerance)24 inline SPSA::SPSA(const double alpha,
25 const double gamma,
26 const double stepSize,
27 const double evaluationStepSize,
28 const size_t maxIterations,
29 const double tolerance) :
30 alpha(alpha),
31 gamma(gamma),
32 stepSize(stepSize),
33 evaluationStepSize(evaluationStepSize),
34 ak(0.001 * maxIterations),
35 maxIterations(maxIterations),
36 tolerance(tolerance)
37 { /* Nothing to do. */ }
38
39 template<typename ArbitraryFunctionType,
40 typename MatType,
41 typename... CallbackTypes>
Optimize(ArbitraryFunctionType & function,MatType & iterate,CallbackTypes &&...callbacks)42 typename MatType::elem_type SPSA::Optimize(ArbitraryFunctionType& function,
43 MatType& iterate,
44 CallbackTypes&&... callbacks)
45 {
46 // Convenience typedefs.
47 typedef typename MatType::elem_type ElemType;
48 typedef typename MatTypeTraits<MatType>::BaseMatType BaseMatType;
49
50 // Make sure that we have the methods that we need.
51 traits::CheckArbitraryFunctionTypeAPI<ArbitraryFunctionType,
52 MatType>();
53 RequireFloatingPointType<MatType>();
54
55 BaseMatType gradient(iterate.n_rows, iterate.n_cols);
56 arma::Mat<ElemType> spVector(iterate.n_rows, iterate.n_cols);
57
58 // To keep track of where we are and how things are going.
59 ElemType overallObjective = 0;
60 ElemType lastObjective = DBL_MAX;
61
62 // Controls early termination of the optimization process.
63 bool terminate = false;
64
65 terminate |= Callback::BeginOptimization(*this, function, iterate,
66 callbacks...);
67 for (size_t k = 0; k < maxIterations && !terminate; ++k)
68 {
69 // Output current objective function.
70 Info << "SPSA: iteration " << k << ", objective " << overallObjective
71 << "." << std::endl;
72
73 if (std::isnan(overallObjective) || std::isinf(overallObjective))
74 {
75 Warn << "SPSA: converged to " << overallObjective << "; terminating"
76 << " with failure. Try a smaller step size?" << std::endl;
77
78 Callback::EndOptimization(*this, function, iterate, callbacks...);
79 return overallObjective;
80 }
81
82 if (std::abs(lastObjective - overallObjective) < tolerance)
83 {
84 Warn << "SPSA: minimized within tolerance " << tolerance << "; "
85 << "terminating optimization." << std::endl;
86 Callback::EndOptimization(*this, function, iterate, callbacks...);
87 return overallObjective;
88 }
89
90 // Reset the counter variables.
91 lastObjective = overallObjective;
92
93 // Gain sequences.
94 const double akLocal = stepSize / std::pow(k + 1 + ak, alpha);
95 const double ck = evaluationStepSize / std::pow(k + 1, gamma);
96
97 // Choose stochastic directions.
98 spVector = arma::conv_to<arma::Mat<ElemType>>::from(
99 arma::randi(iterate.n_rows, iterate.n_cols,
100 arma::distr_param(0, 1))) * 2 - 1;
101
102 iterate += ck * spVector;
103 const double fPlus = function.Evaluate(iterate);
104 Callback::Evaluate(*this, function, iterate, fPlus, callbacks...);
105
106 iterate -= 2 * ck * spVector;
107 const double fMinus = function.Evaluate(iterate);
108 Callback::Evaluate(*this, function, iterate, fMinus, callbacks...);
109
110 iterate += ck * spVector;
111
112 gradient = (fPlus - fMinus) * (1 / (2 * ck * spVector));
113 iterate -= akLocal * gradient;
114
115 terminate |= Callback::StepTaken(*this, function, iterate, callbacks...);
116
117 overallObjective = function.Evaluate(iterate);
118 Callback::Evaluate(*this, function, iterate, overallObjective,
119 callbacks...);
120 }
121
122 // Calculate final objective.
123 const ElemType objective = function.Evaluate(iterate);
124 Callback::Evaluate(*this, function, iterate, objective, callbacks...);
125
126 Callback::EndOptimization(*this, function, iterate, callbacks...);
127 return objective;
128 }
129
130 } // namespace ens
131
132 #endif
133