/**
 * @file swats.hpp
 * @author Marcus Edel
 *
 * SWATS is a simple strategy which switches from Adam to SGD when a triggering
 * condition is satisfied.
 *
 * ensmallen is free software; you may redistribute it and/or modify it under
 * the terms of the 3-clause BSD license.  You should have received a copy of
 * the 3-clause BSD license along with ensmallen.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef ENSMALLEN_SWATS_SWATS_HPP
#define ENSMALLEN_SWATS_SWATS_HPP

#include <ensmallen_bits/sgd/sgd.hpp>
#include "swats_update.hpp"

namespace ens {

/**
 * SWATS is a simple strategy which switches from Adam to SGD when a triggering
 * condition is satisfied. The condition relates to the projection of Adam steps
 * on the gradient subspace.
 *
 * For more information, see the following.
 *
 * @code
 * @article{Keskar2017,
 *   author  = {Nitish Shirish Keskar and Richard Socher},
 *   title   = {Improving Generalization Performance by Switching from Adam to
 *              {SGD}},
 *   journal = {CoRR},
 *   volume  = {abs/1712.07628},
 *   year    = {2017},
 *   url     = {http://arxiv.org/abs/1712.07628},
 * }
 * @endcode
 *
 * SWATS can optimize differentiable separable functions.  For more details,
 * see the documentation on function types included with this distribution or
 * on the ensmallen website.
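 *
 * A minimal usage sketch (here RosenbrockFunction stands in for any
 * differentiable separable function with the required Evaluate()/Gradient()
 * interface, and the parameter values are illustrative, not recommendations):
 *
 * @code
 * RosenbrockFunction f;
 * arma::mat coordinates = f.GetInitialPoint();
 *
 * // stepSize, batchSize, beta1, beta2, epsilon, maxIterations, tolerance.
 * SWATS optimizer(0.001, 32, 0.9, 0.999, 1e-8, 100000, 1e-5);
 * optimizer.Optimize(f, coordinates);
 * @endcode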
 */
class SWATS
{
 public:
  /**
   * Construct the SWATS optimizer with the given parameters. The defaults
   * here are not necessarily good for the given problem, so it is suggested
   * that the values used be tailored to the task at hand.  The maximum number
   * of iterations refers to the maximum number of points that are processed
   * (i.e., one iteration equals one point; one iteration does not equal one
   * pass over the dataset).
   *
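   * For instance, a tailored configuration might look like this sketch (the
   * values are illustrative only):
   *
   * @code
   * // Larger step size and batch size; process at most 500000 points.
   * SWATS optimizer(0.01, 64, 0.9, 0.999, 1e-8, 500000, 1e-5);
   * @endcode
   *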
   * @param stepSize Step size for each iteration.
   * @param batchSize Number of points to process in a single step.
   * @param beta1 Exponential decay rate for the first moment estimates.
   * @param beta2 Exponential decay rate for the second moment estimates.
   * @param epsilon Value used to initialise the mean squared gradient
   *        parameter.
   * @param maxIterations Maximum number of iterations allowed (0 means no
   *        limit).
   * @param tolerance Maximum absolute tolerance to terminate algorithm.
   * @param shuffle If true, the function order is shuffled; otherwise, each
   *        function is visited in linear order.
   * @param resetPolicy If true, parameters are reset before every Optimize
   *        call; otherwise, their values are retained.
   * @param exactObjective Calculate the exact objective (Default: estimate the
   *        final objective obtained on the last pass over the data).
   */
  SWATS(const double stepSize = 0.001,
        const size_t batchSize = 32,
        const double beta1 = 0.9,
        const double beta2 = 0.999,
        const double epsilon = 1e-8,
        const size_t maxIterations = 100000,
        const double tolerance = 1e-5,
        const bool shuffle = true,
        const bool resetPolicy = true,
        const bool exactObjective = false);

  /**
   * Optimize the given function using SWATS. The given starting point will
   * be modified to store the finishing point of the algorithm, and the final
   * objective value is returned.
   *
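   * A sketch of a typical call, assuming a separable function f and its
   * initial point coordinates (ens::ProgressBar is one of ensmallen's
   * optional callbacks):
   *
   * @code
   * SWATS optimizer;
   * double objective = optimizer.Optimize(f, coordinates, ens::ProgressBar());
   * @endcode
   *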
   * @tparam SeparableFunctionType Type of the function to be optimized.
   * @tparam MatType Type of matrix to optimize with.
   * @tparam GradType Type of matrix to use to represent function gradients.
   * @tparam CallbackTypes Types of callback functions.
   * @param function Function to optimize.
   * @param iterate Starting point (will be modified).
   * @param callbacks Callback functions.
   * @return Objective value of the final point.
   */
  template<typename SeparableFunctionType,
           typename MatType,
           typename GradType,
           typename... CallbackTypes>
  typename std::enable_if<IsArmaType<GradType>::value,
      typename MatType::elem_type>::type
  Optimize(SeparableFunctionType& function,
           MatType& iterate,
           CallbackTypes&&... callbacks)
  {
    return optimizer.Optimize<SeparableFunctionType, MatType, GradType,
        CallbackTypes...>(function, iterate,
        std::forward<CallbackTypes>(callbacks)...);
  }

  //! Forward the MatType as GradType.
  template<typename SeparableFunctionType,
           typename MatType,
           typename... CallbackTypes>
  typename MatType::elem_type Optimize(SeparableFunctionType& function,
                                       MatType& iterate,
                                       CallbackTypes&&... callbacks)
  {
    return Optimize<SeparableFunctionType, MatType, MatType,
        CallbackTypes...>(function, iterate,
        std::forward<CallbackTypes>(callbacks)...);
  }

  //! Get the step size.
  double StepSize() const { return optimizer.StepSize(); }
  //! Modify the step size.
  double& StepSize() { return optimizer.StepSize(); }

  //! Get the batch size.
  size_t BatchSize() const { return optimizer.BatchSize(); }
  //! Modify the batch size.
  size_t& BatchSize() { return optimizer.BatchSize(); }

  //! Get the smoothing parameter.
  double Beta1() const { return optimizer.UpdatePolicy().Beta1(); }
  //! Modify the smoothing parameter.
  double& Beta1() { return optimizer.UpdatePolicy().Beta1(); }

  //! Get the second moment coefficient.
  double Beta2() const { return optimizer.UpdatePolicy().Beta2(); }
  //! Modify the second moment coefficient.
  double& Beta2() { return optimizer.UpdatePolicy().Beta2(); }

  //! Get the value used to initialise the mean squared gradient parameter.
  double Epsilon() const { return optimizer.UpdatePolicy().Epsilon(); }
  //! Modify the value used to initialise the mean squared gradient parameter.
  double& Epsilon() { return optimizer.UpdatePolicy().Epsilon(); }

  //! Get the maximum number of iterations (0 indicates no limit).
  size_t MaxIterations() const { return optimizer.MaxIterations(); }
  //! Modify the maximum number of iterations (0 indicates no limit).
  size_t& MaxIterations() { return optimizer.MaxIterations(); }

  //! Get the tolerance for termination.
  double Tolerance() const { return optimizer.Tolerance(); }
  //! Modify the tolerance for termination.
  double& Tolerance() { return optimizer.Tolerance(); }

  //! Get whether or not the individual functions are shuffled.
  bool Shuffle() const { return optimizer.Shuffle(); }
  //! Modify whether or not the individual functions are shuffled.
  bool& Shuffle() { return optimizer.Shuffle(); }

  //! Get whether or not the actual objective is calculated.
  bool ExactObjective() const { return optimizer.ExactObjective(); }
  //! Modify whether or not the actual objective is calculated.
  bool& ExactObjective() { return optimizer.ExactObjective(); }

  //! Get whether or not the update policy parameters
  //! are reset before Optimize call.
  bool ResetPolicy() const { return optimizer.ResetPolicy(); }
  //! Modify whether or not the update policy parameters
  //! are reset before Optimize call.
  bool& ResetPolicy() { return optimizer.ResetPolicy(); }

 private:
  //! The Stochastic Gradient Descent object with the SWATS update policy.
  SGD<SWATSUpdate> optimizer;
};

} // namespace ens

// Include implementation.
#include "swats_impl.hpp"

#endif