1 /**
 * @file spsa_impl.hpp
3  * @author N Rajiv Vaidyanathan
4  * @author Marcus Edel
5  *
 * Implementation of the SPSA (simultaneous perturbation stochastic
 * approximation) optimizer.
8  *
9  * ensmallen is free software; you may redistribute it and/or modify it under
10  * the terms of the 3-clause BSD license.  You should have received a copy of
11  * the 3-clause BSD license along with ensmallen.  If not, see
12  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
13  */
14 #ifndef ENSMALLEN_SPSA_SPSA_IMPL_HPP
15 #define ENSMALLEN_SPSA_SPSA_IMPL_HPP
16 
// In case it hasn't been included yet.
#include "spsa.hpp"

#include <limits>

#include <ensmallen_bits/function.hpp>
21 
22 namespace ens {
23 
SPSA(const double alpha,const double gamma,const double stepSize,const double evaluationStepSize,const size_t maxIterations,const double tolerance)24 inline SPSA::SPSA(const double alpha,
25                   const double gamma,
26                   const double stepSize,
27                   const double evaluationStepSize,
28                   const size_t maxIterations,
29                   const double tolerance) :
30     alpha(alpha),
31     gamma(gamma),
32     stepSize(stepSize),
33     evaluationStepSize(evaluationStepSize),
34     ak(0.001 * maxIterations),
35     maxIterations(maxIterations),
36     tolerance(tolerance)
37 { /* Nothing to do. */ }
38 
39 template<typename ArbitraryFunctionType,
40          typename MatType,
41          typename... CallbackTypes>
Optimize(ArbitraryFunctionType & function,MatType & iterate,CallbackTypes &&...callbacks)42 typename MatType::elem_type SPSA::Optimize(ArbitraryFunctionType& function,
43                                            MatType& iterate,
44                                            CallbackTypes&&... callbacks)
45 {
46   // Convenience typedefs.
47   typedef typename MatType::elem_type ElemType;
48   typedef typename MatTypeTraits<MatType>::BaseMatType BaseMatType;
49 
50   // Make sure that we have the methods that we need.
51   traits::CheckArbitraryFunctionTypeAPI<ArbitraryFunctionType,
52       MatType>();
53   RequireFloatingPointType<MatType>();
54 
55   BaseMatType gradient(iterate.n_rows, iterate.n_cols);
56   arma::Mat<ElemType> spVector(iterate.n_rows, iterate.n_cols);
57 
58   // To keep track of where we are and how things are going.
59   ElemType overallObjective = 0;
60   ElemType lastObjective = DBL_MAX;
61 
62   // Controls early termination of the optimization process.
63   bool terminate = false;
64 
65   terminate |= Callback::BeginOptimization(*this, function, iterate,
66       callbacks...);
67   for (size_t k = 0; k < maxIterations && !terminate; ++k)
68   {
69     // Output current objective function.
70     Info << "SPSA: iteration " << k << ", objective " << overallObjective
71         << "." << std::endl;
72 
73     if (std::isnan(overallObjective) || std::isinf(overallObjective))
74     {
75       Warn << "SPSA: converged to " << overallObjective << "; terminating"
76           << " with failure.  Try a smaller step size?" << std::endl;
77 
78       Callback::EndOptimization(*this, function, iterate, callbacks...);
79       return overallObjective;
80     }
81 
82     if (std::abs(lastObjective - overallObjective) < tolerance)
83     {
84       Warn << "SPSA: minimized within tolerance " << tolerance << "; "
85           << "terminating optimization." << std::endl;
86       Callback::EndOptimization(*this, function, iterate, callbacks...);
87       return overallObjective;
88     }
89 
90     // Reset the counter variables.
91     lastObjective = overallObjective;
92 
93     // Gain sequences.
94     const double akLocal = stepSize / std::pow(k + 1 + ak, alpha);
95     const double ck = evaluationStepSize / std::pow(k + 1, gamma);
96 
97     // Choose stochastic directions.
98     spVector = arma::conv_to<arma::Mat<ElemType>>::from(
99         arma::randi(iterate.n_rows, iterate.n_cols,
100         arma::distr_param(0, 1))) * 2 - 1;
101 
102     iterate += ck * spVector;
103     const double fPlus = function.Evaluate(iterate);
104     Callback::Evaluate(*this, function, iterate, fPlus, callbacks...);
105 
106     iterate -= 2 * ck * spVector;
107     const double fMinus = function.Evaluate(iterate);
108     Callback::Evaluate(*this, function, iterate, fMinus, callbacks...);
109 
110     iterate += ck * spVector;
111 
112     gradient = (fPlus - fMinus) * (1 / (2 * ck * spVector));
113     iterate -= akLocal * gradient;
114 
115     terminate |= Callback::StepTaken(*this, function, iterate, callbacks...);
116 
117     overallObjective = function.Evaluate(iterate);
118     Callback::Evaluate(*this, function, iterate, overallObjective,
119         callbacks...);
120   }
121 
122   // Calculate final objective.
123   const ElemType objective = function.Evaluate(iterate);
124   Callback::Evaluate(*this, function, iterate, objective, callbacks...);
125 
126   Callback::EndOptimization(*this, function, iterate, callbacks...);
127   return objective;
128 }
129 
130 } // namespace ens
131 
132 #endif
133