/**
 * @file katyusha.hpp
 * @author Marcus Edel
 *
 * Katyusha: a direct, primal-only stochastic gradient method.
 *
 * ensmallen is free software; you may redistribute it and/or modify it under
 * the terms of the 3-clause BSD license.  You should have received a copy of
 * the 3-clause BSD license along with ensmallen.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef ENSMALLEN_KATYUSHA_KATYUSHA_HPP
#define ENSMALLEN_KATYUSHA_KATYUSHA_HPP

namespace ens {

/**
 * Katyusha is a direct, primal-only stochastic gradient method which uses a
 * "negative momentum" on top of Nesterov's momentum.
 *
 * For more information, see the following.
 *
 * @code
 * @inproceedings{Allen-Zhu2016,
 *   author    = {{Allen-Zhu}, Z.},
 *   title     = {Katyusha: The First Direct Acceleration of Stochastic Gradient
 *                Methods},
 *   booktitle = {Proceedings of the 49th Annual ACM SIGACT Symposium on Theory
 *                of Computing},
 *   pages     = {1200--1205},
 *   publisher = {ACM},
 *   year      = {2017},
 *   series    = {STOC 2017},
 * }
 * @endcode
 *
 * Katyusha can optimize differentiable separable functions.  For more details,
 * see the documentation on function types included with this distribution or on
 * the ensmallen website.
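 *
 * A minimal usage sketch (the function type below is an illustrative
 * assumption; any class implementing the differentiable separable function
 * API, i.e. providing NumFunctions(), Shuffle(), and batched Evaluate() /
 * Gradient(), will work):
 *
 * @code
 * // Hypothetical differentiable separable objective.
 * LogisticRegressionFunction f(data, responses);
 *
 * Katyusha optimizer(1.0, 10.0, 32, 100000, 0, 1e-5, true, false);
 * arma::mat coordinates = f.GetInitialPoint();
 * const double objective = optimizer.Optimize(f, coordinates);
 * @endcode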
 *
 * @tparam Proximal Whether the proximal update should be used or not.
 */
template<bool Proximal = false>
class KatyushaType
{
 public:
  /**
   * Construct the Katyusha optimizer with the given function and parameters.
   * The defaults here are not necessarily good for the given problem, so it is
   * suggested that the values used be tailored to the task at hand.  The
   * maximum number of iterations refers to the maximum number of points that
   * are processed (i.e., one iteration equals one point; one iteration does not
   * equal one pass over the dataset).
   *
   * @param convexity The strong convexity (regularization) parameter.
   * @param lipschitz The Lipschitz constant.
   * @param batchSize Batch size to use for each step.
   * @param maxIterations Maximum number of iterations allowed (0 means no
   *    limit).
   * @param innerIterations The number of inner iterations allowed (0 means
   *    n / batchSize). Note that the full gradient is only calculated in
   *    the outer iteration.
   * @param tolerance Maximum absolute tolerance to terminate algorithm.
   * @param shuffle If true, the function order is shuffled; otherwise, each
   *    function is visited in linear order.
   * @param exactObjective Calculate the exact objective (default: estimate the
   *    final objective obtained on the last pass over the data).
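   *
   * As a sketch (the values here are illustrative, not recommendations): for
   * a dataset of 10000 points, maxIterations = 150000 permits roughly 15
   * passes over the data.
   *
   * @code
   * Katyusha optimizer(1.0, 10.0, 32, 150000, 0, 1e-5, true, false);
   * @endcode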
   */
  KatyushaType(const double convexity = 1.0,
               const double lipschitz = 10.0,
               const size_t batchSize = 32,
               const size_t maxIterations = 1000,
               const size_t innerIterations = 0,
               const double tolerance = 1e-5,
               const bool shuffle = true,
               const bool exactObjective = false);

  /**
   * Optimize the given function using Katyusha.  The given starting point will
   * be modified to store the finishing point of the algorithm, and the final
   * objective value is returned.
   *
   * @tparam SeparableFunctionType Type of the function to be optimized.
   * @tparam MatType Type of matrix to optimize with.
   * @tparam GradType Type of matrix to use to represent function gradients.
   * @tparam CallbackTypes Types of callback functions.
   * @param function Function to optimize.
   * @param iterate Starting point (will be modified).
   * @param callbacks Callback functions.
   * @return Objective value of the final point.
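   *
   * A sketch of a call that also passes a callback (ProgressBar is one of the
   * callbacks shipped with ensmallen; function and iterate are as above):
   *
   * @code
   * Katyusha optimizer;
   * const double objective = optimizer.Optimize(function, iterate,
   *     ProgressBar());
   * @endcode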
   */
  template<typename SeparableFunctionType,
           typename MatType,
           typename GradType,
           typename... CallbackTypes>
  typename std::enable_if<IsArmaType<GradType>::value,
      typename MatType::elem_type>::type
  Optimize(SeparableFunctionType& function,
           MatType& iterate,
           CallbackTypes&&... callbacks);

  //! Forward the MatType as GradType.
  template<typename SeparableFunctionType,
           typename MatType,
           typename... CallbackTypes>
  typename MatType::elem_type Optimize(SeparableFunctionType& function,
                                       MatType& iterate,
                                       CallbackTypes&&... callbacks)
  {
    return Optimize<SeparableFunctionType, MatType, MatType,
        CallbackTypes...>(function, iterate,
        std::forward<CallbackTypes>(callbacks)...);
  }

  //! Get the convexity parameter.
  double Convexity() const { return convexity; }
  //! Modify the convexity parameter.
  double& Convexity() { return convexity; }

  //! Get the Lipschitz parameter.
  double Lipschitz() const { return lipschitz; }
  //! Modify the Lipschitz parameter.
  double& Lipschitz() { return lipschitz; }

  //! Get the batch size.
  size_t BatchSize() const { return batchSize; }
  //! Modify the batch size.
  size_t& BatchSize() { return batchSize; }

  //! Get the maximum number of iterations (0 indicates no limit).
  size_t MaxIterations() const { return maxIterations; }
  //! Modify the maximum number of iterations (0 indicates no limit).
  size_t& MaxIterations() { return maxIterations; }

  //! Get the maximum number of inner iterations (0 indicates default n / b).
  size_t InnerIterations() const { return innerIterations; }
  //! Modify the maximum number of inner iterations (0 indicates default n / b).
  size_t& InnerIterations() { return innerIterations; }

  //! Get the tolerance for termination.
  double Tolerance() const { return tolerance; }
  //! Modify the tolerance for termination.
  double& Tolerance() { return tolerance; }

  //! Get whether or not the individual functions are shuffled.
  bool Shuffle() const { return shuffle; }
  //! Modify whether or not the individual functions are shuffled.
  bool& Shuffle() { return shuffle; }

  //! Get whether or not the actual objective is calculated.
  bool ExactObjective() const { return exactObjective; }
  //! Modify whether or not the actual objective is calculated.
  bool& ExactObjective() { return exactObjective; }

 private:
  //! The convexity regularization term.
  double convexity;

  //! The Lipschitz constant.
  double lipschitz;

  //! The batch size for processing.
  size_t batchSize;

  //! The maximum number of allowed iterations.
  size_t maxIterations;

  //! The maximum number of allowed inner iterations per epoch.
  size_t innerIterations;

  //! The tolerance for termination.
  double tolerance;

  //! Controls whether or not the individual functions are shuffled when
  //! iterating.
  bool shuffle;

  //! Controls whether or not the actual objective value is calculated.
  bool exactObjective;
};

// Convenience typedefs.

/**
 * Katyusha using the standard update step.
 */
using Katyusha = KatyushaType<false>;

/**
 * Katyusha using the proximal update step.
 */
using KatyushaProximal = KatyushaType<true>;
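
// Both typedefs share the same constructor and Optimize() interface; only the
// update step differs.  A sketch (f and iterate as in the examples above):
//
//   Katyusha smooth;            // standard update step
//   KatyushaProximal proximal;  // proximal update step
//   smooth.Optimize(f, iterate);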

} // namespace ens

// Include implementation.
#include "katyusha_impl.hpp"

#endif