/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file optimizer_op.cc
 * \brief Optimizer operators
 * \author Junyuan Xie
 */
#include "./optimizer_op-inl.h"
#include "./elemwise_op_common.h"

namespace mxnet {
namespace op {

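// Register the optimizer parameter structs so their fields can be parsed from operator
// attributes by ParamParser and listed in the generated docs via __FIELDS__().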
DMLC_REGISTER_PARAMETER(SGDParam);
DMLC_REGISTER_PARAMETER(SGDMomParam);
DMLC_REGISTER_PARAMETER(MultiSGDParam);
DMLC_REGISTER_PARAMETER(MultiSGDMomParam);
DMLC_REGISTER_PARAMETER(FTMLParam);
DMLC_REGISTER_PARAMETER(AdamParam);
DMLC_REGISTER_PARAMETER(NAGParam);
DMLC_REGISTER_PARAMETER(NAGMomParam);
DMLC_REGISTER_PARAMETER(RMSPropParam);
DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
DMLC_REGISTER_PARAMETER(FtrlParam);
DMLC_REGISTER_PARAMETER(SignSGDParam);
DMLC_REGISTER_PARAMETER(SignumParam);
DMLC_REGISTER_PARAMETER(AdagradParam);
DMLC_REGISTER_PARAMETER(LambUpdatePhaseOneParam);
DMLC_REGISTER_PARAMETER(LambUpdatePhaseTwoParam);

NNVM_REGISTER_OP(signsgd_update)
.describe(R"code(Update function for SignSGD optimizer.

.. math::

 g_t = \nabla J(W_{t-1})\\
 W_t = W_{t-1} - \eta_t \text{sign}(g_t)

It updates the weights using::

 weight = weight - learning_rate * sign(gradient)

.. note::
   - sparse ndarray not supported for this optimizer yet.
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SignSGDParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCompute>("FCompute<cpu>", SignSGDUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_arguments(SignSGDParam::__FIELDS__());


NNVM_REGISTER_OP(signum_update)
.describe(R"code(SIGN momentUM (Signum) optimizer.

.. math::

 g_t = \nabla J(W_{t-1})\\
 m_t = \beta m_{t-1} + (1 - \beta) g_t\\
 W_t = W_{t-1} - \eta_t \text{sign}(m_t)

It updates the weights using::

 state = momentum * state + (1-momentum) * gradient
 weight = weight - learning_rate * sign(state)

Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

.. note::
   - sparse ndarray not supported for this optimizer yet.
)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SignumParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2};
  })
.set_attr<FCompute>("FCompute<cpu>", SignumUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_arguments(SignumParam::__FIELDS__());

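/*!
 * \brief Kernel for the standard (non-lazy) momentum update with a row_sparse gradient
 *        and dense weight/momentum. Each kernel call handles one weight row.
 *        `prefix_sum` is an inclusive prefix sum over per-row occupancy flags of the
 *        gradient, so `prefix_sum[i] > prefix_sum[i-1]` means row i is present in
 *        grad.indices, and `prefix_sum[i] - 1` is its row index in the compact gradient
 *        data. Rows absent from grad.indices are updated with a zero gradient, so weight
 *        decay and momentum decay are still applied to them.
 *
 *        For example, with num_rows = 5 and grad.indices = {1, 3}, the row flags are
 *        {0, 1, 0, 1, 0} and the inclusive prefix sum is {0, 1, 1, 2, 2}; row 3 is
 *        present (2 > 1) and its compact gradient row is 2 - 1 = 1.
 */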
template<int req>
struct SGDMomStdDnsRspDnsKernel<req, cpu> {
  template<typename DType, typename IType, typename RType>
  MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
    DType* mom_data, const DType* weight_data, const IType* grad_idx,
    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
    const DType momentum, const DType lr, const DType wd, const DType rescale_grad) {
    const DType rate = lr * wd;
    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
                                   : prefix_sum[i] > prefix_sum[i-1];

    const index_t row_i = i * row_length;
    const RType grad_i = (prefix_sum[i]-1) * row_length;
    for (index_t j = 0; j < row_length; j++) {
      const index_t data_i = row_i + j;
      const DType grad = non_zero ? grad_data[grad_i + j]
                                  : static_cast<DType>(0);
      if (clip_gradient >= 0.0f) {
        mom_data[data_i] = momentum * mom_data[data_i]
                - rate * weight_data[data_i]
                - lr *
                mshadow_op::clip::Map(rescale_grad * grad,
                                      clip_gradient);
      } else {
        mom_data[data_i] = momentum * mom_data[data_i]
                  - rate * weight_data[data_i]
                  - lr * rescale_grad * grad;
      }
      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]);
    }
  }
};

/*!
 * \brief standard momentum update for dense weight on cpu.
 *        state is expected to be dense, while grad is expected to be row_sparse.
 */
template<>
void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
                                       const OpContext& ctx,
                                       const TBlob& weight,
                                       const NDArray& grad,
                                       const TBlob& mom,
                                       const OpReqType& req,
                                       TBlob *out) {
  using namespace mxnet_op;
  using namespace rowsparse;
  using namespace mshadow;
  Stream<cpu>* s = ctx.get_stream<cpu>();
  if (req == kNullOp) return;
  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
  CHECK_GT(weight.shape_.Size(), 0);
  CHECK_GT(mom.shape_.Size(), 0);
  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        DType* weight_data = weight.dptr<DType>();
        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
        const DType* grad_val = grad.data().dptr<DType>();
        DType* mom_data = mom.dptr<DType>();
        DType* out_data = out->dptr<DType>();
        const nnvm::dim_t num_rows = weight.shape_[0];
        const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
        Tensor<cpu, 1, char> workspace = ctx.requested[0]
          .get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);

        nnvm::dim_t* prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
        // mark row flags
        Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
        if (grad.storage_initialized()) {
          Kernel<MarkRowFlgKernel, cpu>::Launch(s, grad.aux_shape(kIdx)[0],
            prefix_sum, grad_idx);
          // calculate inclusive prefix sum
          for (nnvm::dim_t i = 1; i < num_rows; i++) {
            prefix_sum[i] += prefix_sum[i - 1];
          }
        }
        Kernel<SGDMomStdDnsRspDnsKernel<req_type, cpu>, cpu>::Launch(s, num_rows, row_length,
          out_data, mom_data, weight_data, grad_idx, grad_val, prefix_sum,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
          static_cast<DType>(param.rescale_grad));
      });
    });
  });
}

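/*!
 * \brief Kernel for the standard (non-lazy) Adam update with a row_sparse gradient and
 *        dense weight/mean/var. As in the SGD kernel above, `prefix_sum` maps a dense
 *        row index to its row in the compact gradient data. Rows absent from
 *        grad.indices are treated as having a zero gradient, so they still receive the
 *        weight-decay contribution (grad_rescaled = weight * wd) and their moment
 *        estimates keep decaying.
 */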
template<int req>
struct AdamStdDnsRspDnsKernel<req, cpu> {
  template<typename DType, typename IType, typename RType>
  MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
    DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
    const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
    const DType beta1, const DType beta2, const DType lr, const DType wd,
    const DType epsilon, const DType rescale_grad) {
    using namespace mshadow_op;
    const bool non_zero = (i == 0) ? prefix_sum[0] > 0
                                   : prefix_sum[i] > prefix_sum[i-1];

    const index_t row_i = i * row_length;
    const RType grad_i = (prefix_sum[i]-1) * row_length;
    for (index_t j = 0; j < row_length; j++) {
      const index_t data_i = row_i + j;
      const DType grad_rescaled = non_zero ? static_cast<DType>(
                                               grad_data[grad_i + j] * rescale_grad +
                                               weight_data[data_i] * wd)
                                           : static_cast<DType>(weight_data[data_i] * wd);
      if (clip_gradient >= 0.0f) {
        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
                            clip::Map(grad_rescaled, clip_gradient);
        var_data[data_i] =  beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
                            clip::Map(grad_rescaled, clip_gradient));
      } else {
        mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
        var_data[data_i] = beta2 * var_data[data_i] +
                           (1.f - beta2) * square::Map(grad_rescaled);
      }
      KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
                    (square_root::Map(var_data[data_i]) + epsilon));
    }
  }
};


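/*!
 * \brief standard adam update for dense weight on cpu.
 *        mean and var are expected to be dense, while grad is expected to be row_sparse.
 */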
template<>
void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
                                     const OpContext& ctx,
                                     const TBlob& weight,
                                     const NDArray& grad,
                                     const TBlob& mean,
                                     const TBlob& var,
                                     const OpReqType& req,
                                     TBlob *out) {
  using namespace mxnet_op;
  using namespace rowsparse;
  using namespace mshadow;
  Stream<cpu>* s = ctx.get_stream<cpu>();
  if (req == kNullOp) return;
  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
  CHECK_GT(weight.shape_.Size(), 0);
  CHECK_GT(mean.shape_.Size(), 0);
  CHECK_GT(var.shape_.Size(), 0);

  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        const DType* weight_data = weight.dptr<DType>();
        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
        const DType* grad_val = grad.data().dptr<DType>();
        DType* mean_data = mean.dptr<DType>();
        DType* var_data = var.dptr<DType>();
        DType* out_data = out->dptr<DType>();
        nnvm::dim_t num_rows = weight.shape_[0];
        nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
        Tensor<cpu, 1, char> workspace = ctx.requested[0]
          .get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);

        nnvm::dim_t* prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
        // mark row flags
        Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
        if (grad.storage_initialized()) {
          Kernel<MarkRowFlgKernel, cpu>::Launch(s, grad.aux_shape(kIdx)[0],
            prefix_sum, grad_idx);
          // calculate inclusive prefix sum
          for (nnvm::dim_t i = 1; i < num_rows; i++) {
            prefix_sum[i] += prefix_sum[i - 1];
          }
        }

        Kernel<AdamStdDnsRspDnsKernel<req_type, cpu>, cpu>::Launch(s, num_rows, row_length,
          out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
          static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
          static_cast<DType>(param.rescale_grad));
      });
    });
  });
}

/*!
 * \brief Storage type inference function for SGD.
 */
inline bool SGDStorageType(const nnvm::NodeAttrs& attrs,
                           const int dev_mask,
                           DispatchMode* dispatch_mode,
                           std::vector<int>* in_attrs,
                           std::vector<int>* out_attrs) {
  using namespace common;
  const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const int weight_stype = in_attrs->at(0);
  const int grad_stype = in_attrs->at(1);
  bool dispatched = false;
  if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
    // dns, ... -> dns
    dispatched = storage_type_assign(out_attrs, kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);
  }
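  // row_sparse grad: dispatch to FComputeEx; the output keeps the weight's storage type,
  // so a dense weight stays dense and a row_sparse weight stays row_sparse.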
  if (!dispatched && grad_stype == kRowSparseStorage &&
      (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage)) {
    // grad's stype = rsp
    dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
                                     dispatch_mode, DispatchMode::kFComputeEx);
    // warn users if lazy_update is turned on
    if (dispatched && param.wd != 0 && param.lazy_update) LogLazyUpdate();
  }
  if (!dispatched) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  return dispatched;
}

NNVM_REGISTER_OP(multi_sgd_update)
.describe(R"code(Update function for Stochastic Gradient Descent (SGD) optimizer.

It updates the weights using::

 weight = weight - learning_rate * (gradient + wd * weight)

)code" ADD_FILELINE)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_weights * 2);
  })
.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
    const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_weights);
  })
.set_attr_parser(ParamParser<MultiSGDParam>)
.set_attr<mxnet::FInferShape>("FInferShape", MultiSGDShape<MultiSGDParam, 2>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("weight_") + std::to_string(i));
      ret.push_back(std::string("grad_") + std::to_string(i));
    }
    return ret;
  })
.set_attr<FCompute>("FCompute<cpu>", MultiSGDUpdate<cpu, type_identity, 2>)
.add_argument("data", "NDArray-or-Symbol[]", "Weights and gradients")
.add_arguments(MultiSGDParam::__FIELDS__());

NNVM_REGISTER_OP(multi_sgd_mom_update)
.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically it looks
like below:

.. math::

  v_1 = \alpha * \nabla J(W_0)\\
  v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\
  W_t = W_{t-1} + v_t

It updates the weights using::

  v = momentum * v - learning_rate * gradient
  weight += v

Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

)code" ADD_FILELINE)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_weights * 3);
  })
.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
    const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_weights);
  })
.set_attr_parser(ParamParser<MultiSGDMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", MultiSGDShape<MultiSGDMomParam, 3>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<MultiSGDMomParam>(attrs.parsed).num_weights;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("weight_") + std::to_string(i));
      ret.push_back(std::string("grad_") + std::to_string(i));
      ret.push_back(std::string("mom_") + std::to_string(i));
    }
    return ret;
  })
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    std::vector<uint32_t> ret;
    const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
    for (int i = 0; i < param.num_weights; ++i) {
      ret.push_back(i * 3 + 2);
    }
    return ret;
  })
.set_attr<FCompute>("FCompute<cpu>", MultiSGDMomUpdate<cpu, type_identity, 3>)
.add_argument("data", "NDArray-or-Symbol[]", "Weights, gradients and momentum")
.add_arguments(MultiSGDMomParam::__FIELDS__());

NNVM_REGISTER_OP(multi_mp_sgd_update)
.describe(R"code(Update function for multi-precision Stochastic Gradient Descent (SGD) optimizer.

It updates the weights using::

 weight = weight - learning_rate * (gradient + wd * weight)

)code" ADD_FILELINE)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_weights * 3);
  })
.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
    const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_weights);
  })
.set_attr_parser(ParamParser<MultiSGDParam>)
.set_attr<mxnet::FInferShape>("FInferShape", MultiSGDShape<MultiSGDParam, 3>)
.set_attr<nnvm::FInferType>("FInferType", MP_MultiSGD_InferType<MultiSGDParam, 3, 1>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("weight_") + std::to_string(i));
      ret.push_back(std::string("grad_") + std::to_string(i));
      ret.push_back(std::string("weight32_") + std::to_string(i));
    }
    return ret;
  })
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    std::vector<uint32_t> ret;
    const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
    for (int i = 0; i < param.num_weights; ++i) {
      ret.push_back(i * 3 + 2);
    }
    return ret;
  })
.set_attr<FCompute>("FCompute<cpu>", MultiSGDUpdate<cpu, single_precision, 3>)
.add_argument("data", "NDArray-or-Symbol[]", "Weights, gradients and float32 weights")
.add_arguments(MultiSGDParam::__FIELDS__());

NNVM_REGISTER_OP(multi_mp_sgd_mom_update)
.describe(R"code(Momentum update function for multi-precision Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically it looks
like below:

.. math::

  v_1 = \alpha * \nabla J(W_0)\\
  v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\
  W_t = W_{t-1} + v_t

It updates the weights using::

  v = momentum * v - learning_rate * gradient
  weight += v

Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

)code" ADD_FILELINE)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_weights * 4);
  })
.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
    const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_weights);
  })
.set_attr_parser(ParamParser<MultiSGDMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", MultiSGDShape<MultiSGDMomParam, 4>)
.set_attr<nnvm::FInferType>("FInferType", MP_MultiSGD_InferType<MultiSGDMomParam, 4, 2>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<MultiSGDMomParam>(attrs.parsed).num_weights;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("weight_") + std::to_string(i));
      ret.push_back(std::string("grad_") + std::to_string(i));
      ret.push_back(std::string("mom_") + std::to_string(i));
      ret.push_back(std::string("weight32_") + std::to_string(i));
    }
    return ret;
  })
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    std::vector<uint32_t> ret;
    const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
    for (int i = 0; i < param.num_weights; ++i) {
      ret.push_back(i * 4 + 2);
      ret.push_back(i * 4 + 3);
    }
    return ret;
  })
.set_attr<FCompute>("FCompute<cpu>", MultiSGDMomUpdate<cpu, single_precision, 4>)
.add_argument("data", "NDArray-or-Symbol[]", "Weights, gradients, momentum and float32 weights")
.add_arguments(MultiSGDMomParam::__FIELDS__());

NNVM_REGISTER_OP(sgd_update)
MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
.describe(R"code(Update function for Stochastic Gradient Descent (SGD) optimizer.

It updates the weights using::

 weight = weight - learning_rate * (gradient + wd * weight)

However, if gradient is of ``row_sparse`` storage type and ``lazy_update`` is True,
only the row slices whose indices appear in grad.indices are updated::

 for row in gradient.indices:
     weight[row] = weight[row] - learning_rate * (gradient[row] + wd * weight[row])

)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FInferStorageType>("FInferStorageType", SGDStorageType)
.set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SGDUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_arguments(SGDParam::__FIELDS__());

NNVM_REGISTER_OP(sgd_mom_update)
MXNET_ADD_SPARSE_OP_ALIAS(sgd_mom_update)
.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically it looks
like below:

.. math::

  v_1 = \alpha * \nabla J(W_0)\\
  v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\
  W_t = W_{t-1} + v_t

It updates the weights using::

  v = momentum * v - learning_rate * gradient
  weight += v

Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage
type is the same as momentum's storage type,
only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::

  for row in gradient.indices:
      v[row] = momentum * v[row] - learning_rate * gradient[row]
      weight[row] += v[row]

)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<1, SGDMomParam>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2};
  })
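// Temporary space is only requested for the sparse (FComputeEx) path, which uses it as
// the prefix-sum workspace in SGDMomStdUpdateDnsRspDnsImpl above.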
.set_attr<FResourceRequestEx>("FResourceRequestEx",
  [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
    std::vector<ResourceRequest> request;
    if (dispatch_mode == DispatchMode::kFComputeEx) {
      request.emplace_back(ResourceRequest::kTempSpace);
    }
    return request;
  })
.set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SGDMomUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_arguments(SGDMomParam::__FIELDS__());

NNVM_REGISTER_OP(mp_sgd_update)
.describe("Updater function for multi-precision SGD optimizer.")
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 3>)
.set_attr<FCompute>("FCompute<cpu>", MP_SGDUpdate<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2};
  })
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(SGDParam::__FIELDS__());

NNVM_REGISTER_OP(mp_sgd_mom_update)
.describe("Updater function for multi-precision SGD optimizer with momentum.")
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 4>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.set_attr<FCompute>("FCompute<cpu>", MP_SGDMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(SGDMomParam::__FIELDS__());

NNVM_REGISTER_OP(ftml_update)
.describe(R"code(The FTML optimizer described in
*FTML - Follow the Moving Leader in Deep Learning*,
available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.

.. math::

 g_t = \nabla J(W_{t-1})\\
 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
 d_t = \frac{ 1 - \beta_1^t }{ \eta_t } (\sqrt{ \frac{ v_t }{ 1 - \beta_2^t } } + \epsilon)\\
 \sigma_t = d_t - \beta_1 d_{t-1}\\
 z_t = \beta_1 z_{ t-1 } + (1 - \beta_1^t) g_t - \sigma_t W_{t-1}\\
 W_t = - \frac{ z_t }{ d_t }

)code" ADD_FILELINE)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<FTMLParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<5, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3, 4};
  })
.set_attr<FCompute>("FCompute<cpu>", FTMLUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("d", "NDArray-or-Symbol", "Internal state ``d_t``")
.add_argument("v", "NDArray-or-Symbol", "Internal state ``v_t``")
.add_argument("z", "NDArray-or-Symbol", "Internal state ``z_t``")
.add_arguments(FTMLParam::__FIELDS__());

NNVM_REGISTER_OP(adam_update)
MXNET_ADD_SPARSE_OP_ALIAS(adam_update)
.describe(R"code(Update function for Adam optimizer. Adam is seen as a generalization
of AdaGrad.

Adam update consists of the following steps, where g represents gradient and m, v
are 1st and 2nd order moment estimates (mean and variance).

.. math::

 g_t = \nabla J(W_{t-1})\\
 m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
 W_t = W_{t-1} - \alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon }

It updates the weights using::

 m = beta1*m + (1-beta1)*grad
 v = beta2*v + (1-beta2)*(grad**2)
 w += - learning_rate * m / (sqrt(v) + epsilon)

However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and the storage
type of weight is the same as those of m and v,
only the row slices whose indices appear in grad.indices are updated (for w, m and v)::

 for row in grad.indices:
     m[row] = beta1*m[row] + (1-beta1)*grad[row]
     v[row] = beta2*v[row] + (1-beta2)*(grad[row]**2)
     w[row] += - learning_rate * m[row] / (sqrt(v[row]) + epsilon)

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
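// The temporary space is used by the row_sparse (FComputeEx) update as the prefix-sum
// workspace (see AdamStdUpdateDnsRspDnsImpl above).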
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  })
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, AdamParam>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.set_attr<FCompute>("FCompute<cpu>", AdamUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", AdamUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(AdamParam::__FIELDS__());


NNVM_REGISTER_OP(nag_mom_update)
.describe(R"code(Update function for Nesterov Accelerated Gradient (NAG) optimizer.
It updates the weights using the following formula,

.. math::
  v_t = \gamma v_{t-1} + \eta * \nabla J(W_{t-1} - \gamma v_{t-1})\\
  W_t = W_{t-1} - v_t

Where
:math:`\eta` is the learning rate of the optimizer,
:math:`\gamma` is the decay rate of the momentum estimate,
:math:`v_t` is the update vector at time step `t`, and
:math:`W_t` is the weight vector at time step `t`.

)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NAGMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2};
  })
.set_attr<FCompute>("FCompute<cpu>", NAGMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_arguments(NAGMomParam::__FIELDS__());


NNVM_REGISTER_OP(mp_nag_mom_update)
.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient (NAG) optimizer.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NAGMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 4>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.set_attr<FCompute>("FCompute<cpu>", MP_NAGMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(NAGMomParam::__FIELDS__());


NNVM_REGISTER_OP(rmsprop_update)
.describe(R"code(Update function for `RMSProp` optimizer.

`RMSProp` is a variant of stochastic gradient descent where the gradients are
divided by a cache which grows with the sum of squares of recent gradients.

`RMSProp` is similar to `AdaGrad`, a popular variant of `SGD` which adaptively
tunes the learning rate of each parameter. `AdaGrad` lowers the learning rate for
each parameter monotonically over the course of training.
While this is analytically motivated for convex optimizations, it may not be ideal
for non-convex problems. `RMSProp` deals with this heuristically by allowing the
learning rates to rebound as the denominator decays over time.

Define the Root Mean Square (RMS) error criterion of the gradient as
:math:`RMS[g]_t = \sqrt{E[g^2]_t + \epsilon}`, where :math:`g` represents
gradient and :math:`E[g^2]_t` is the decaying average over past squared gradient.

The :math:`E[g^2]_t` is given by:

.. math::
  E[g^2]_t = \gamma * E[g^2]_{t-1} + (1-\gamma) * g_t^2

The update step is

.. math::
  \theta_{t+1} = \theta_t - \frac{\eta}{RMS[g]_t} g_t

The RMSProp code follows the version in
http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
Tieleman & Hinton, 2012.

Hinton suggests the momentum term :math:`\gamma` to be 0.9 and the learning rate
:math:`\eta` to be 0.001.

)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<RMSPropParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs &attrs) {
    return std::vector<uint32_t>{2};
  })
.set_attr<FCompute>("FCompute<cpu>", RMSPropUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("n", "NDArray-or-Symbol", "n")
.add_arguments(RMSPropParam::__FIELDS__());

NNVM_REGISTER_OP(rmspropalex_update)
.describe(R"code(Update function for RMSPropAlex optimizer.

`RMSPropAlex` is the centered version of `RMSProp`.

Define :math:`E[g^2]_t` is the decaying average over past squared gradient and
:math:`E[g]_t` is the decaying average over past gradient.

.. math::
  E[g^2]_t = \gamma_1 * E[g^2]_{t-1} + (1 - \gamma_1) * g_t^2\\
  E[g]_t = \gamma_1 * E[g]_{t-1} + (1 - \gamma_1) * g_t\\
  \Delta_t = \gamma_2 * \Delta_{t-1} - \frac{\eta}{\sqrt{E[g^2]_t - E[g]_t^2 + \epsilon}} g_t

The update step is

.. math::
  \theta_{t+1} = \theta_t + \Delta_t

The RMSPropAlex code follows the version in
http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) by Alex Graves, 2013.

Graves suggests the momentum term :math:`\gamma_1` to be 0.95, :math:`\gamma_2`
to be 0.9 and the learning rate :math:`\eta` to be 0.0001.
)code" ADD_FILELINE)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<RMSPropAlexParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<5, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3, 4};
  })
.set_attr<FCompute>("FCompute<cpu>", RMSPropAlexUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("n", "NDArray-or-Symbol", "n")
.add_argument("g", "NDArray-or-Symbol", "g")
.add_argument("delta", "NDArray-or-Symbol", "delta")
.add_arguments(RMSPropAlexParam::__FIELDS__());

NNVM_REGISTER_OP(ftrl_update)
MXNET_ADD_SPARSE_OP_ALIAS(ftrl_update)
.describe(R"code(Update function for Ftrl optimizer.
Referenced from *Ad Click Prediction: a View from the Trenches*, available at
http://dl.acm.org/citation.cfm?id=2488200.

It updates the weights using::

 rescaled_grad = clip(grad * rescale_grad, clip_gradient)
 z += rescaled_grad - (sqrt(n + rescaled_grad**2) - sqrt(n)) * weight / learning_rate
 n += rescaled_grad**2
 w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1)

If w, z and n are all of ``row_sparse`` storage type,
only the row slices whose indices appear in grad.indices are updated (for w, z and n)::

 for row in grad.indices:
     rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient)
     z[row] += rescaled_grad[row] - (sqrt(n[row] + rescaled_grad[row]**2) - sqrt(n[row])) * weight[row] / learning_rate
     n[row] += rescaled_grad[row]**2
     w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1)

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<FtrlParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<4, 1, false, true, false>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.set_attr<FCompute>("FCompute<cpu>", FtrlUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", FtrlUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("z", "NDArray-or-Symbol", "z")
.add_argument("n", "NDArray-or-Symbol", "Square of grad")
.add_arguments(FtrlParam::__FIELDS__());

NNVM_REGISTER_OP(_sparse_adagrad_update)
.describe(R"code(Update function for AdaGrad optimizer.

Referenced from *Adaptive Subgradient Methods for Online Learning and Stochastic Optimization*,
and available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.

Updates are applied by::

    rescaled_grad = clip(grad * rescale_grad, clip_gradient)
    history = history + square(rescaled_grad)
    w = w - learning_rate * rescaled_grad / sqrt(history + epsilon)

Note that non-zero values for the weight decay option are not supported.

)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdagradParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<FInferStorageType>("FInferStorageType", AdagradStorageType)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2};
  })
.set_attr<FComputeEx>("FComputeEx<cpu>", AdagradUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("history", "NDArray-or-Symbol", "History")
.add_arguments(AdagradParam::__FIELDS__());

NNVM_REGISTER_OP(lamb_update_phase1)
.describe(R"code(Phase I of lamb update. It performs the following operations and returns g:

Link to paper: https://arxiv.org/pdf/1904.00962.pdf

.. math::
    \begin{gather*}
    grad = grad * rescale_grad
    if (grad < -clip_gradient)
    then
         grad = -clip_gradient
    if (grad > clip_gradient)
    then
         grad = clip_gradient

    mean = beta1 * mean + (1 - beta1) * grad;
    variance = beta2 * variance + (1. - beta2) * grad ^ 2;

    if (bias_correction)
    then
         mean_hat = mean / (1. - beta1^t);
         var_hat = var / (1 - beta2^t);
         g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight;
    else
         g = mean / (var_data^(1/2) + epsilon) + wd * weight;
    \end{gather*}

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseOne<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(LambUpdatePhaseOneParam::__FIELDS__());

NNVM_REGISTER_OP(lamb_update_phase2)
.describe(R"code(Phase II of lamb update. It performs the following operations and updates grad.

Link to paper: https://arxiv.org/pdf/1904.00962.pdf

.. math::
    \begin{gather*}
    if (lower_bound >= 0)
    then
         r1 = max(r1, lower_bound)
    if (upper_bound >= 0)
    then
         r1 = min(r1, upper_bound)

    if (r1 == 0 or r2 == 0)
    then
         lr = lr
    else
         lr = lr * (r1/r2)
    weight = weight - lr * g
    \end{gather*}

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
.set_attr<mxnet::FInferShape>("FInferShape", LambUpdatePhaseTwoShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseTwo<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("g", "NDArray-or-Symbol", "Output of lamb_update_phase 1")
.add_argument("r1", "NDArray-or-Symbol", "r1")
.add_argument("r2", "NDArray-or-Symbol", "r2")
.add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());

NNVM_REGISTER_OP(mp_lamb_update_phase1)
.describe(R"code(Mixed Precision version of Phase I of lamb update.
It performs the following operations and returns g:

Link to paper: https://arxiv.org/pdf/1904.00962.pdf

.. math::
    \begin{gather*}
    grad32 = grad(float16) * rescale_grad
    if (grad < -clip_gradient)
    then
         grad = -clip_gradient
    if (grad > clip_gradient)
    then
         grad = clip_gradient

    mean = beta1 * mean + (1 - beta1) * grad;
    variance = beta2 * variance + (1. - beta2) * grad ^ 2;

    if (bias_correction)
    then
         mean_hat = mean / (1. - beta1^t);
         var_hat = var / (1 - beta2^t);
         g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight32;
    else
         g = mean / (var_data^(1/2) + epsilon) + wd * weight32;
    \end{gather*}

)code" ADD_FILELINE)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
.set_attr<nnvm::FInferType>("FInferType", MPLambPhaseOneType<2, 1, 5>)
.set_attr<FCompute>("FCompute<cpu>", MPLambUpdatePhaseOne<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(LambUpdatePhaseOneParam::__FIELDS__());

NNVM_REGISTER_OP(mp_lamb_update_phase2)
.describe(R"code(Mixed Precision version of Phase II of lamb update.
It performs the following operations and updates grad.

Link to paper: https://arxiv.org/pdf/1904.00962.pdf

.. math::
    \begin{gather*}
    if (lower_bound >= 0)
    then
         r1 = max(r1, lower_bound)
    if (upper_bound >= 0)
    then
         r1 = min(r1, upper_bound)

    if (r1 == 0 or r2 == 0)
    then
         lr = lr
    else
         lr = lr * (r1/r2)
    weight32 = weight32 - lr * g
    weight(float16) = weight32
    \end{gather*}

)code" ADD_FILELINE)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
.set_attr<mxnet::FInferShape>("FInferShape", MPLambUpdatePhaseTwoShape)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<1, 1, 5>)
.set_attr<FCompute>("FCompute<cpu>", MPLambUpdatePhaseTwo<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{4};
  })
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("g", "NDArray-or-Symbol", "Output of mp_lamb_update_phase 1")
.add_argument("r1", "NDArray-or-Symbol", "r1")
.add_argument("r2", "NDArray-or-Symbol", "r2")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());

}  // namespace op
}  // namespace mxnet