1 /** 2 * @file katyusha.hpp 3 * @author Marcus Edel 4 * 5 * Katyusha a direct, primal-only stochastic gradient method. 6 * 7 * ensmallen is free software; you may redistribute it and/or modify it under 8 * the terms of the 3-clause BSD license. You should have received a copy of 9 * the 3-clause BSD license along with ensmallen. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 #ifndef ENSMALLEN_KATYUSHA_KATYUSHA_HPP 13 #define ENSMALLEN_KATYUSHA_KATYUSHA_HPP 14 15 namespace ens { 16 17 /** 18 * Katyusha is a direct, primal-only stochastic gradient method which uses a 19 * "negative momentum" on top of Nesterov’s momentum. 20 * 21 * For more information, see the following. 22 * 23 * @code 24 * @inproceedings{Allen-Zhu2016, 25 * author = {{Allen-Zhu}, Z.}, 26 * title = {Katyusha: The First Direct Acceleration of Stochastic Gradient 27 * Methods}, 28 * booktitle = {Proceedings of the 49th Annual ACM SIGACT Symposium on Theory 29 * of Computing}, 30 * pages = {1200--1205}, 31 * publisher = {ACM}, 32 * year = {2017}, 33 * series = {STOC 2017}, 34 * } 35 * @endcode 36 * 37 * Katyusha can optimize differentiable separable functions. For more details, 38 * see the documentation on function types included with this distribution or on 39 * the ensmallen website. 40 * 41 * @tparam proximal Whether the proximal update should be used or not. 42 */ 43 template<bool Proximal = false> 44 class KatyushaType 45 { 46 public: 47 /** 48 * Construct the Katyusha optimizer with the given function and parameters. 49 * The defaults here are not necessarily good for the given problem, so it is 50 * suggested that the values used be tailored to the task at hand. The 51 * maximum number of iterations refers to the maximum number of points that 52 * are processed (i.e., one iteration equals one point; one iteration does not 53 * equal one pass over the dataset). 54 * 55 * @param convexity The regularization parameter. 56 * @param lipschitz The Lipschitz constant. 57 * @param batchSize Batch size to use for each step. 58 * @param maxIterations Maximum number of iterations allowed (0 means no 59 * limit). 60 * @param innerIterations The number of inner iterations allowed (0 means 61 * n / batchSize). Note that the full gradient is only calculated in 62 * the outer iteration. 63 * @param tolerance Maximum absolute tolerance to terminate algorithm. 64 * @param shuffle If true, the function order is shuffled; otherwise, each 65 * function is visited in linear order. 66 * @param exactObjective Calculate the exact objective (Default: estimate the 67 * final objective obtained on the last pass over the data). 68 */ 69 KatyushaType(const double convexity = 1.0, 70 const double lipschitz = 10.0, 71 const size_t batchSize = 32, 72 const size_t maxIterations = 1000, 73 const size_t innerIterations = 0, 74 const double tolerance = 1e-5, 75 const bool shuffle = true, 76 const bool exactObjective = false); 77 78 /** 79 * Optimize the given function using Katyusha. The given starting point will 80 * be modified to store the finishing point of the algorithm, and the final 81 * objective value is returned. 82 * 83 * @tparam SeparableFunctionType Type of the function to be optimized. 84 * @tparam MatType Type of matrix to optimize with. 85 * @tparam GradType Type of matrix to use to represent function gradients. 86 * @tparam CallbackTypes Types of callback functions. 87 * @param function Function to optimize. 88 * @param iterate Starting point (will be modified). 89 * @param callbacks Callback functions. 90 * @return Objective value of the final point. 91 */ 92 template<typename SeparableFunctionType, 93 typename MatType, 94 typename GradType, 95 typename... CallbackTypes> 96 typename std::enable_if<IsArmaType<GradType>::value, 97 typename MatType::elem_type>::type 98 Optimize(SeparableFunctionType& function, 99 MatType& iterate, 100 CallbackTypes&&... callbacks); 101 102 //! Forward the MatType as GradType. 103 template<typename SeparableFunctionType, 104 typename MatType, 105 typename... CallbackTypes> Optimize(SeparableFunctionType & function,MatType & iterate,CallbackTypes &&...callbacks)106 typename MatType::elem_type Optimize(SeparableFunctionType& function, 107 MatType& iterate, 108 CallbackTypes&&... callbacks) 109 { 110 return Optimize<SeparableFunctionType, MatType, MatType, 111 CallbackTypes...>(function, iterate, 112 std::forward<CallbackTypes>(callbacks)...); 113 } 114 115 //! Get the convexity parameter. Convexity() const116 double Convexity() const { return convexity; } 117 //! Modify the convexity parameter. Convexity()118 double& Convexity() { return convexity; } 119 120 //! Get the lipschitz parameter. Lipschitz() const121 double Lipschitz() const { return lipschitz; } 122 //! Modify the lipschitz parameter. Lipschitz()123 double& Lipschitz() { return lipschitz; } 124 125 //! Get the batch size. BatchSize() const126 size_t BatchSize() const { return batchSize; } 127 //! Modify the batch size. BatchSize()128 size_t& BatchSize() { return batchSize; } 129 130 //! Get the maximum number of iterations (0 indicates no limit). MaxIterations() const131 size_t MaxIterations() const { return maxIterations; } 132 //! Modify the maximum number of iterations (0 indicates no limit). MaxIterations()133 size_t& MaxIterations() { return maxIterations; } 134 135 //! Get the maximum number of iterations (0 indicates default n / b). InnerIterations() const136 size_t InnerIterations() const { return innerIterations; } 137 //! Modify the maximum number of iterations (0 indicates default n / b). InnerIterations()138 size_t& InnerIterations() { return innerIterations; } 139 140 //! Get the tolerance for termination. Tolerance() const141 double Tolerance() const { return tolerance; } 142 //! Modify the tolerance for termination. Tolerance()143 double& Tolerance() { return tolerance; } 144 145 //! Get whether or not the individual functions are shuffled. Shuffle() const146 bool Shuffle() const { return shuffle; } 147 //! Modify whether or not the individual functions are shuffled. Shuffle()148 bool& Shuffle() { return shuffle; } 149 150 //! Get whether or not the actual objective is calculated. ExactObjective() const151 bool ExactObjective() const { return exactObjective; } 152 //! Modify whether or not the actual objective is calculated. ExactObjective()153 bool& ExactObjective() { return exactObjective; } 154 155 private: 156 //! The convexity regularization term. 157 double convexity; 158 159 //! The lipschitz constant. 160 double lipschitz; 161 162 //! The batch size for processing. 163 size_t batchSize; 164 165 //! The maximum number of allowed iterations. 166 size_t maxIterations; 167 168 //! The maximum number of allowed inner iterations per epoch. 169 size_t innerIterations; 170 171 //! The tolerance for termination. 172 double tolerance; 173 174 //! Controls whether or not the individual functions are shuffled when 175 //! iterating. 176 bool shuffle; 177 178 //! Controls whether or not the actual Objective value is calculated. 179 bool exactObjective; 180 }; 181 182 // Convenience typedefs. 183 184 /** 185 * Katyusha using the standard update step. 186 */ 187 using Katyusha = KatyushaType<false>; 188 189 /** 190 * Katyusha using the proximal update step. 191 */ 192 using KatyushaProximal = KatyushaType<true>; 193 194 } // namespace ens 195 196 // Include implementation. 197 #include "katyusha_impl.hpp" 198 199 #endif 200