1 /** 2 * @file swats.hpp 3 * @author Marcus Edel 4 * 5 * SWATS is a simple strategy which switches from Adam to SGD when a triggering 6 * condition is satisfied. 7 * 8 * ensmallen is free software; you may redistribute it and/or modify it under 9 * the terms of the 3-clause BSD license. You should have received a copy of 10 * the 3-clause BSD license along with ensmallen. If not, see 11 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 12 */ 13 #ifndef ENSMALLEN_SWATS_SWATS_HPP 14 #define ENSMALLEN_SWATS_SWATS_HPP 15 16 #include <ensmallen_bits/sgd/sgd.hpp> 17 #include "swats_update.hpp" 18 19 namespace ens { 20 21 /** 22 * SWATS is a simple strategy which switches from Adam to SGD when a triggering 23 * condition is satisfied. The condition relates to the projection of Adam steps 24 * on the gradient subspace. 25 * 26 * For more information, see the following. 27 * 28 * @code 29 * @article{Keskar2017, 30 * author = {Nitish Shirish Keskar and Richard Socher}, 31 * title = {Improving Generalization Performance by Switching from Adam to 32 * {SGD}}, 33 * journal = {CoRR}, 34 * volume = {abs/1712.07628}, 35 * year = {2017} 36 * url = {http://arxiv.org/abs/1712.07628}, 37 * } 38 * @endcode 39 * 40 * SWATS can optimize differentiable separable functions. For more details, see 41 * the documentation on function types include with this distribution or on the 42 * ensmallen website. 43 */ 44 class SWATS 45 { 46 public: 47 /** 48 * Construct the SWATS optimizer with the given function and parameters. The 49 * defaults here are not necessarily good for the given problem, so it is 50 * suggested that the values used be tailored to the task at hand. The 51 * maximum number of iterations refers to the maximum number of points that 52 * are processed (i.e., one iteration equals one point; one iteration does not 53 * equal one pass over the dataset). 54 * 55 * @param stepSize Step size for each iteration. 56 * @param batchSize Number of points to process in a single step. 57 * @param beta1 Exponential decay rate for the first moment estimates. 58 * @param beta2 Exponential decay rate for the weighted infinity norm 59 estimates. 60 * @param epsilon Value used to initialise the mean squared gradient 61 * parameter. 62 * @param maxIterations Maximum number of iterations allowed (0 means no 63 * limit). 64 * @param tolerance Maximum absolute tolerance to terminate algorithm. 65 * @param shuffle If true, the function order is shuffled; otherwise, each 66 * function is visited in linear order. 67 * @param resetPolicy If true, parameters are reset before every Optimize 68 * call; otherwise, their values are retained. 69 * @param exactObjective Calculate the exact objective (Default: estimate the 70 * final objective obtained on the last pass over the data). 71 */ 72 SWATS(const double stepSize = 0.001, 73 const size_t batchSize = 32, 74 const double beta1 = 0.9, 75 const double beta2 = 0.999, 76 const double epsilon = 1e-8, 77 const size_t maxIterations = 100000, 78 const double tolerance = 1e-5, 79 const bool shuffle = true, 80 const bool resetPolicy = true, 81 const bool exactObjective = false); 82 83 /** 84 * Optimize the given function using SWATS. The given starting point will 85 * be modified to store the finishing point of the algorithm, and the final 86 * objective value is returned. 87 * 88 * @tparam SeparableFunctionType Type of the function to be optimized. 89 * @tparam MatType Type of matrix to optimize with. 90 * @tparam GradType Type of matrix to use to represent function gradients. 91 * @tparam CallbackTypes Types of callback functions. 92 * @param function Function to optimize. 93 * @param iterate Starting point (will be modified). 94 * @param callbacks Callback functions. 95 * @return Objective value of the final point. 96 */ 97 template<typename SeparableFunctionType, 98 typename MatType, 99 typename GradType, 100 typename... CallbackTypes> 101 typename std::enable_if<IsArmaType<GradType>::value, 102 typename MatType::elem_type>::type Optimize(SeparableFunctionType & function,MatType & iterate,CallbackTypes &&...callbacks)103 Optimize(SeparableFunctionType& function, 104 MatType& iterate, 105 CallbackTypes&&... callbacks) 106 { 107 return optimizer.Optimize<SeparableFunctionType, MatType, GradType, 108 CallbackTypes...>(function, iterate, 109 std::forward<CallbackTypes>(callbacks)...); 110 } 111 112 //! Forward the MatType as GradType. 113 template<typename SeparableFunctionType, 114 typename MatType, 115 typename... CallbackTypes> Optimize(SeparableFunctionType & function,MatType & iterate,CallbackTypes &&...callbacks)116 typename MatType::elem_type Optimize(SeparableFunctionType& function, 117 MatType& iterate, 118 CallbackTypes&&... callbacks) 119 { 120 return Optimize<SeparableFunctionType, MatType, MatType, 121 CallbackTypes...>(function, iterate, 122 std::forward<CallbackTypes>(callbacks)...); 123 } 124 125 //! Get the step size. StepSize() const126 double StepSize() const { return optimizer.StepSize(); } 127 //! Modify the step size. StepSize()128 double& StepSize() { return optimizer.StepSize(); } 129 130 //! Get the batch size. BatchSize() const131 size_t BatchSize() const { return optimizer.BatchSize(); } 132 //! Modify the batch size. BatchSize()133 size_t& BatchSize() { return optimizer.BatchSize(); } 134 135 //! Get the smoothing parameter. Beta1() const136 double Beta1() const { return optimizer.UpdatePolicy().Beta1(); } 137 //! Modify the smoothing parameter. Beta1()138 double& Beta1() { return optimizer.UpdatePolicy().Beta1(); } 139 140 //! Get the second moment coefficient. Beta2() const141 double Beta2() const { return optimizer.UpdatePolicy().Beta2(); } 142 //! Modify the second moment coefficient. Beta2()143 double& Beta2() { return optimizer.UpdatePolicy().Beta2(); } 144 145 //! Get the value used to initialise the mean squared gradient parameter. Epsilon() const146 double Epsilon() const { return optimizer.UpdatePolicy().Epsilon(); } 147 //! Modify the value used to initialise the mean squared gradient parameter. Epsilon()148 double& Epsilon() { return optimizer.UpdatePolicy().Epsilon(); } 149 150 //! Get the maximum number of iterations (0 indicates no limit). MaxIterations() const151 size_t MaxIterations() const { return optimizer.MaxIterations(); } 152 //! Modify the maximum number of iterations (0 indicates no limit). MaxIterations()153 size_t& MaxIterations() { return optimizer.MaxIterations(); } 154 155 //! Get the tolerance for termination. Tolerance() const156 double Tolerance() const { return optimizer.Tolerance(); } 157 //! Modify the tolerance for termination. Tolerance()158 double& Tolerance() { return optimizer.Tolerance(); } 159 160 //! Get whether or not the individual functions are shuffled. Shuffle() const161 bool Shuffle() const { return optimizer.Shuffle(); } 162 //! Modify whether or not the individual functions are shuffled. Shuffle()163 bool& Shuffle() { return optimizer.Shuffle(); } 164 165 //! Get whether or not the actual objective is calculated. ExactObjective() const166 bool ExactObjective() const { return optimizer.ExactObjective(); } 167 //! Modify whether or not the actual objective is calculated. ExactObjective()168 bool& ExactObjective() { return optimizer.ExactObjective(); } 169 170 //! Get whether or not the update policy parameters 171 //! are reset before Optimize call. ResetPolicy() const172 bool ResetPolicy() const { return optimizer.ResetPolicy(); } 173 //! Modify whether or not the update policy parameters 174 //! are reset before Optimize call. ResetPolicy()175 bool& ResetPolicy() { return optimizer.ResetPolicy(); } 176 177 private: 178 //! The SWATS update policy. 179 SGD<SWATSUpdate> optimizer; 180 }; 181 182 } // namespace ens 183 184 // Include implementation. 185 #include "swats_impl.hpp" 186 187 #endif 188