1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file optimizer_op-inl.h
22  * \brief Optimizer operators
23  * \author Junyuan Xie
24  */
25 #ifndef MXNET_OPERATOR_OPTIMIZER_OP_INL_H_
26 #define MXNET_OPERATOR_OPTIMIZER_OP_INL_H_
27 #include <dmlc/parameter.h>
28 #include <mxnet/operator.h>
29 #include <mxnet/operator_util.h>
30 #include <mxnet/op_attr_types.h>
31 #include <mshadow/base.h>
32 #include <nnvm/op.h>
33 #include <nnvm/op_attr_types.h>
34 #include <vector>
35 #include "./operator_common.h"
36 #include "./mshadow_op.h"
37 #include "./elemwise_op_common.h"
38 #include "mxnet_op.h"
39 #include "./tensor/init_op.h"
40 #include "./tensor/util/tensor_util-inl.h"
41 
42 namespace mxnet {
43 namespace op {
44 
45 /*
46  * \brief log message for optimizers with lazy update.
47  */
LogLazyUpdate()48 inline void LogLazyUpdate() {
49   common::LogOnce("Optimizer with lazy_update = True detected. "
50                   "Be aware that lazy update with row_sparse gradient is different from "
51                   "standard update, and may lead to different empirical results. See "
52                   "https://mxnet.apache.org/api/python/optimization/optimization.html "
53                   "for more details.");
54 }
55 
56 struct SGDParam : public dmlc::Parameter<SGDParam> {
57   float lr;
58   float wd;
59   float rescale_grad;
60   float clip_gradient;
61   bool lazy_update;
DMLC_DECLARE_PARAMETERSGDParam62   DMLC_DECLARE_PARAMETER(SGDParam) {
63     DMLC_DECLARE_FIELD(lr)
64     .describe("Learning rate");
65     DMLC_DECLARE_FIELD(wd)
66     .set_default(0.0f)
67     .describe("Weight decay augments the objective function with a "
68               "regularization term that penalizes large weights. "
69               "The penalty scales with the square of the magnitude of each weight.");
70     DMLC_DECLARE_FIELD(rescale_grad)
71     .set_default(1.0f)
72     .describe("Rescale gradient to grad = rescale_grad*grad.");
73     DMLC_DECLARE_FIELD(clip_gradient)
74     .set_default(-1.0f)
75     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
76               "If clip_gradient <= 0, gradient clipping is turned off. "
77               "grad = max(min(grad, clip_gradient), -clip_gradient).");
78     DMLC_DECLARE_FIELD(lazy_update)
79     .set_default(true)
80     .describe("If true, lazy updates are applied if gradient's stype is row_sparse.");
81   }
82 };
83 
84 struct MultiSGDParam : public dmlc::Parameter<MultiSGDParam> {
85   mxnet::Tuple<float> lrs;
86   mxnet::Tuple<float> wds;
87   float rescale_grad;
88   float clip_gradient;
89   int num_weights;
DMLC_DECLARE_PARAMETERMultiSGDParam90   DMLC_DECLARE_PARAMETER(MultiSGDParam) {
91     DMLC_DECLARE_FIELD(lrs)
92     .describe("Learning rates.");
93     DMLC_DECLARE_FIELD(wds)
94     .describe("Weight decay augments the objective function with a "
95               "regularization term that penalizes large weights. "
96               "The penalty scales with the square of the magnitude of each weight.");
97     DMLC_DECLARE_FIELD(rescale_grad)
98     .set_default(1.0f)
99     .describe("Rescale gradient to grad = rescale_grad*grad.");
100     DMLC_DECLARE_FIELD(clip_gradient)
101     .set_default(-1.0f)
102     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
103               "If clip_gradient <= 0, gradient clipping is turned off. "
104               "grad = max(min(grad, clip_gradient), -clip_gradient).");
105     DMLC_DECLARE_FIELD(num_weights)
106     .set_default(1)
107     .describe("Number of updated weights.");
108   }
109 };
110 
111 struct MultiSGDMomParam : public dmlc::Parameter<MultiSGDMomParam> {
112   mxnet::Tuple<float> lrs;
113   mxnet::Tuple<float> wds;
114   float momentum;
115   float rescale_grad;
116   float clip_gradient;
117   int num_weights;
DMLC_DECLARE_PARAMETERMultiSGDMomParam118   DMLC_DECLARE_PARAMETER(MultiSGDMomParam) {
119     DMLC_DECLARE_FIELD(lrs)
120     .describe("Learning rates.");
121     DMLC_DECLARE_FIELD(wds)
122     .describe("Weight decay augments the objective function with a "
123               "regularization term that penalizes large weights. "
124               "The penalty scales with the square of the magnitude of each weight.");
125     DMLC_DECLARE_FIELD(momentum)
126     .set_default(0.0f)
127     .describe("The decay rate of momentum estimates at each epoch.");
128     DMLC_DECLARE_FIELD(rescale_grad)
129     .set_default(1.0f)
130     .describe("Rescale gradient to grad = rescale_grad*grad.");
131     DMLC_DECLARE_FIELD(clip_gradient)
132     .set_default(-1.0f)
133     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
134               "If clip_gradient <= 0, gradient clipping is turned off. "
135               "grad = max(min(grad, clip_gradient), -clip_gradient).");
136     DMLC_DECLARE_FIELD(num_weights)
137     .set_default(1)
138     .describe("Number of updated weights.");
139   }
140 };
141 
142 
143 template<typename ParamType, int input_stride>
MultiSGDShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)144 inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs,
145                           mxnet::ShapeVector *in_attrs,
146                           mxnet::ShapeVector *out_attrs) {
147   const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
148   CHECK_EQ(in_attrs->size(), input_stride * param.num_weights);
149   CHECK_EQ(out_attrs->size(), param.num_weights);
150 
151   bool all_inferred = true;
152   auto& input_shapes = *in_attrs;
153   auto& output_shapes = *out_attrs;
154   // Learning rates
155   CHECK_EQ(param.lrs.ndim(), param.num_weights)
156     << "Number of learning rates is inconsistent with num_weights "
157     << "parameter passed. Expected number of learning rates: "
158     << param.num_weights << ", and got " << param.lrs.ndim();
159   // Weight decays
160   CHECK_EQ(param.wds.ndim(), param.num_weights)
161     << "Number of weight decays is inconsistent with num_weights "
162     << "parameter passed. Expected number of weight decays: "
163     << param.num_weights << ", and got " << param.wds.ndim();
164   // Weights and gradients
165   for (int i = 0; i < param.num_weights; ++i) {
166     mxnet::ShapeVector input_vec;
167     mxnet::ShapeVector output_vec({output_shapes[i]});
168     for (int j = 0; j < input_stride; ++j) {
169       input_vec.push_back(input_shapes[i * input_stride + j]);
170     }
171     all_inferred = all_inferred && ElemwiseShape<input_stride, 1>(attrs, &input_vec, &output_vec);
172   }
173   return all_inferred;
174 }
175 
176 template <typename ParamType, int input_stride, int num_fp32_inputs>
MP_MultiSGD_InferType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)177 inline bool MP_MultiSGD_InferType(const nnvm::NodeAttrs& attrs,
178                                   std::vector<int> *in_attrs,
179                                   std::vector<int> *out_attrs) {
180   const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
181   CHECK_EQ(in_attrs->size(), input_stride * param.num_weights);
182   CHECK_EQ(out_attrs->size(), param.num_weights);
183 
184   bool all_inferred = true;
185   auto& input_types = *in_attrs;
186   auto& output_types = *out_attrs;
187   // Weights and gradients
188   for (int i = 0; i < param.num_weights; ++i) {
189     std::vector<int> input_vec;
190     std::vector<int> output_vec({output_types[i]});
191     for (int j = 0; j < input_stride - num_fp32_inputs; ++j) {
192       input_vec.push_back(input_types[i * input_stride + j]);
193     }
194     all_inferred = all_inferred &&
195                    ElemwiseType<input_stride - num_fp32_inputs, 1>(attrs, &input_vec, &output_vec);
196   }
197   // master copies of weights
198   for (int i = 0; i < param.num_weights; ++i) {
199     for (int j = 0; j < num_fp32_inputs; ++j) {
200       TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j, mshadow::kFloat32);
201     }
202   }
203   return all_inferred;
204 }
205 
206 template<typename DType, typename MPDType>
207 struct MultiSGDKernelParam {
208   static const int N = 60;
209   int count;
210   size_t max_size;
211   size_t sizes[N];
212   DType * weights[N];
213   DType * grads[N];
214   MPDType * mom[N];
215   MPDType * weights32[N];
216   DType * out_data[N];
217   MPDType lrs[N];
218   MPDType wds[N];
219   MPDType clip_gradient;
220   MPDType rescale_grad;
221   MPDType momentum;
222 };
223 
224 template <typename MPDType, bool has_momentum, bool has_mixed_precision>
225 struct MultiSGDKernel {
226   template<typename DType>
MapMultiSGDKernel227   MSHADOW_XINLINE static void Map(index_t i, const MultiSGDKernelParam<DType, MPDType>& param,
228     const OpReqType req) {
229     for (int index = 0; index < param.count; ++index) {
230       if (i < static_cast<index_t>(param.sizes[index])) {
231         MPDType w = has_mixed_precision ? param.weights32[index][i] :
232                                           MPDType(param.weights[index][i]);
233         MPDType mom = has_momentum ? param.mom[index][i] : MPDType(0);
234         if (param.clip_gradient >= 0.0f) {
235           mom = param.momentum*mom
236                 - param.lrs[index]*param.wds[index]*w
237                 - param.lrs[index]
238                 *mshadow_op::clip::Map(param.rescale_grad *
239                                        static_cast<MPDType>(param.grads[index][i]),
240                                      param.clip_gradient);
241         } else {
242           mom = param.momentum*mom
243                 - param.lrs[index]*param.wds[index]*w
244                 - param.lrs[index]*param.rescale_grad*static_cast<MPDType>(param.grads[index][i]);
245         }
246         if (has_momentum) {
247           param.mom[index][i] = mom;
248         }
249         w = w + mom;
250         if (has_mixed_precision) {
251           param.weights32[index][i] = w;
252         }
253         KERNEL_ASSIGN(param.out_data[index][i], req, w);
254       }
255     }
256   }
257 };
258 
259 template<typename xpu,
260          typename DType,
261          typename MPDType,
262          typename ParamType = MultiSGDParam,
263          int input_stride = 2>
FillMultiSGDKernelParam(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<TBlob> & outputs)264 MultiSGDKernelParam<DType, MPDType> FillMultiSGDKernelParam(const nnvm::NodeAttrs& attrs,
265                                                             const OpContext &ctx,
266                                                             const std::vector<TBlob> &inputs,
267                                                             const std::vector<TBlob> &outputs) {
268   using namespace mxnet_op;
269   const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
270   Stream<xpu>* s = ctx.get_stream<xpu>();
271   MultiSGDKernelParam<DType, MPDType> param;
272   param.clip_gradient = p.clip_gradient;
273   param.rescale_grad = p.rescale_grad;
274   param.momentum = 0;
275   param.count = p.num_weights;
276   param.max_size = 0;
277   for (int i = 0; i < param.count; ++i) {
278     param.sizes[i] = inputs[i * input_stride].shape_.Size();
279     if (param.max_size < param.sizes[i]) {
280       param.max_size = param.sizes[i];
281     }
282     param.weights[i] = inputs[i * input_stride].FlatTo2D<xpu, DType>(s).dptr_;
283     param.grads[i] = inputs[i * input_stride + 1].FlatTo2D<xpu, DType>(s).dptr_;
284     // if mixed precision, then the last input in a set
285     // is 32-bit master copy of the weights
286     if (!std::is_same<DType, MPDType>::value) {
287       param.weights32[i] = inputs[i * input_stride + input_stride - 1]
288                            .FlatTo2D<xpu, MPDType>(s).dptr_;
289     }
290     param.out_data[i] = outputs[i].FlatTo2D<xpu, DType>(s).dptr_;
291     param.lrs[i] = p.lrs[i];
292     param.wds[i] = p.wds[i];
293   }
294 
295   return param;
296 }
297 
298 
299 template<typename xpu,
300          typename DType,
301          typename MPDType,
302          int input_stride = 3>
FillMultiSGDMomKernelParam(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<TBlob> & outputs)303 MultiSGDKernelParam<DType, MPDType> FillMultiSGDMomKernelParam(const nnvm::NodeAttrs& attrs,
304                                                             const OpContext &ctx,
305                                                             const std::vector<TBlob> &inputs,
306                                                             const std::vector<TBlob> &outputs) {
307   using namespace mxnet_op;
308   const MultiSGDMomParam& p = nnvm::get<MultiSGDMomParam>(attrs.parsed);
309   Stream<xpu>* s = ctx.get_stream<xpu>();
310   MultiSGDKernelParam<DType, MPDType> param =
311     FillMultiSGDKernelParam<xpu,
312                             DType,
313                             MPDType,
314                             MultiSGDMomParam,
315                             input_stride>(attrs, ctx, inputs, outputs);
316   param.momentum = p.momentum;
317   for (int i = 0; i < param.count; ++i) {
318     param.mom[i] = inputs[i * input_stride + 2].FlatTo2D<xpu, MPDType>(s).dptr_;
319   }
320 
321   return param;
322 }
323 
324 template<typename T>
325 class type_identity {
326  public:
327   using type = T;
328 };
329 
330 template<typename T>
331 class single_precision {
332  public:
333   using type = float;
334 };
335 
336 template<typename xpu, template<typename> class MPTypeChooser, int input_stride>
MultiSGDUpdate(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)337 inline void MultiSGDUpdate(const nnvm::NodeAttrs& attrs,
338                            const OpContext &ctx,
339                            const std::vector<TBlob> &inputs,
340                            const std::vector<OpReqType> &req,
341                            const std::vector<TBlob> &outputs) {
342   using namespace mxnet_op;
343   Stream<xpu>* s = ctx.get_stream<xpu>();
344   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
345     using MPDType = typename MPTypeChooser<DType>::type;
346     MultiSGDKernelParam<DType, MPDType> param =
347       FillMultiSGDKernelParam<xpu,
348                               DType,
349                               MPDType,
350                               MultiSGDParam,
351                               input_stride>(attrs, ctx, inputs, outputs);
352     Kernel<MultiSGDKernel<MPDType,
353                           false,
354                           !std::is_same<DType, MPDType>::value>,
355                           xpu>::Launch(s, param.max_size, param, req[0]);
356   });
357 }
358 
359 template<typename xpu, template<typename> class MPTypeChooser, int input_stride>
MultiSGDMomUpdate(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)360 inline void MultiSGDMomUpdate(const nnvm::NodeAttrs& attrs,
361                               const OpContext &ctx,
362                               const std::vector<TBlob> &inputs,
363                               const std::vector<OpReqType> &req,
364                               const std::vector<TBlob> &outputs) {
365   using namespace mxnet_op;
366   Stream<xpu>* s = ctx.get_stream<xpu>();
367   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
368     using MPDType = typename MPTypeChooser<DType>::type;
369     MultiSGDKernelParam<DType, MPDType> param =
370       FillMultiSGDMomKernelParam<xpu,
371                                  DType,
372                                  MPDType,
373                                  input_stride>(attrs, ctx, inputs, outputs);
374     Kernel<MultiSGDKernel<MPDType,
375                           true,
376                           !std::is_same<DType, MPDType>::value>,
377                           xpu>::Launch(s, param.max_size, param, req[0]);
378   });
379 }
380 
381 struct SGDKernel {
382   template<typename DType>
MapSGDKernel383   MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* weight_data,
384     const DType* grad_data, const DType param_clip_gradient,
385     const DType param_lr, const DType param_wd, const DType param_rescale_grad,
386     const OpReqType req) {
387     if (param_clip_gradient >= 0.0f) {
388       KERNEL_ASSIGN(out_data[i], req,
389              (1.f-param_lr*param_wd)*weight_data[i]
390                - (param_lr)
391                  * mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient));
392     } else {
393       KERNEL_ASSIGN(out_data[i], req,
394              (1.f-param_lr*param_wd)*weight_data[i]
395                - (param_lr*param_rescale_grad)*grad_data[i]);
396     }
397   }
398 };
399 
400 template<typename xpu>
SGDUpdate(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)401 inline void SGDUpdate(const nnvm::NodeAttrs& attrs,
402                       const OpContext &ctx,
403                       const std::vector<TBlob> &inputs,
404                       const std::vector<OpReqType> &req,
405                       const std::vector<TBlob> &outputs) {
406   using namespace mxnet_op;
407   const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
408   Stream<xpu>* s = ctx.get_stream<xpu>();
409   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
410     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
411     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
412     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
413     Kernel<SGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
414       grad.dptr_, static_cast<DType>(param.clip_gradient),
415       static_cast<DType>(param.lr), static_cast<DType>(param.wd),
416       static_cast<DType>(param.rescale_grad), req[0]);
417   });
418 }
419 
420 /*! \brief kernel for sparse sgd
421  */
422 template<int req, typename xpu>
423 struct SGDDnsRspKernel;
424 
425 template<int req>
426 struct SGDDnsRspKernel<req, gpu> {
427   // DType is the output data type
428   // IType is row sparse idx type
429   // i is the ith element in row sparse gradient
430   template<typename DType, typename IType>
431   MSHADOW_XINLINE static void Map(index_t i, const index_t row_length, DType* out,
432                                   const DType* weight, const IType* grad_idx,
433                                   const DType *grad_val, const DType clip_gradient, const DType lr,
434                                   const DType wd, const DType rescale_grad) {
435     using nnvm::dim_t;
436     using namespace mshadow_op;
437     const dim_t row_id = i / row_length;
438     const dim_t col_id = i % row_length;
439     const dim_t row_offset = grad_idx[row_id] * row_length;
440     const dim_t data_i = row_offset + col_id;
441     if (clip_gradient >= 0.0f) {
442       KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] -
443                    (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[i], clip_gradient));
444     } else {
445       KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] -
446                     (lr * rescale_grad) * grad_val[i]);
447     }
448   }
449 };
450 
451 /*! \brief kernel for sparse sgd
452  */
453 template<int req>
454 struct SGDDnsRspKernel<req, cpu> {
455   // DType is the output data type
456   // IType is row sparse idx type
457   // i is the ith row in row sparse gradient
458   template<typename DType, typename IType>
459   MSHADOW_XINLINE static void Map(index_t i, const index_t row_length, DType* out,
460                                   const DType* weight, const IType* grad_idx,
461                                   const DType *grad_val, const DType clip_gradient, const DType lr,
462                                   const DType wd, const DType rescale_grad) {
463     for (index_t j = 0; j < row_length; j++) {
464       index_t data_i = grad_idx[i] * row_length + j;
465       index_t grad_i = i * row_length + j;
466       if (clip_gradient >= 0.0f) {
467         KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] -
468                      (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[grad_i], clip_gradient));
469       } else {
470         KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] -
471                       (lr * rescale_grad) * grad_val[grad_i]);
472       }
473     }
474   }
475 };
476 
477 /*
478  * \brief SGD update implementation for dense weight and row_sparse grad.
479  *        Both standard update and lazy update are supported.
480  */
481 template<typename xpu>
482 inline void SGDUpdateDnsRspImpl(const SGDParam& param,
483                                 const OpContext &ctx,
484                                 const TBlob& weight,
485                                 const NDArray& grad,
486                                 const OpReqType& req,
487                                 TBlob *out) {
488   using namespace mshadow;
489   using namespace mshadow::expr;
490   using namespace mshadow_op;
491   using namespace mxnet_op;
492   Stream<xpu>* s = ctx.get_stream<xpu>();
493   CHECK_EQ(grad.storage_type(), kRowSparseStorage);
494   // if gradients are zeros, no weights are updated
495   if (req == kNullOp) return;
496   CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
497   CHECK_GT(weight.shape_.Size(), 0);
498 
499   MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
500     MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
501       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
502         DType* weight_data = weight.dptr<DType>();
503         float wd = param.wd;
504         // apply standard weight decay if not lazy update
505         if (!param.lazy_update) {
506           Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(s, weight.Size(),
507             weight_data, weight_data, static_cast<DType>(1 - param.lr * param.wd));
508           wd = 0;
509         }
510         if (!grad.storage_initialized()) return;
511         const IType* grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
512         const DType* grad_val = grad.data().dptr<DType>();
513         const nnvm::dim_t num_rows = grad.aux_shape(rowsparse::kIdx)[0];
514         const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
515         size_t num_threads = num_rows;
516         if (std::is_same<xpu, gpu>::value) {
517           num_threads = num_rows * row_length;
518         }
519         Kernel<SGDDnsRspKernel<req_type, xpu>, xpu>::Launch(s, num_threads, row_length,
520           out->dptr<DType>(), weight_data, grad_idx, grad_val,
521           static_cast<DType>(param.clip_gradient),
522           static_cast<DType>(param.lr), static_cast<DType>(wd),
523           static_cast<DType>(param.rescale_grad));
524       });
525     });
526   });
527 }
528 
529 /*
530  * \brief SGD update implementation for row_sparse grad.
531  *        Both standard update and lazy update are supported.
532  */
533 template<typename xpu>
534 inline void SGDUpdateRspImpl(const SGDParam& param,
535                              const OpContext& ctx,
536                              const NDArray& weight,
537                              const NDArray& grad,
538                              const OpReqType& req,
539                              NDArray *out) {
540   CheckAllRowsPresent(weight, "SGDUpdate", "weights");
541   // reuse dns rsp implementation when storage_shape == shape
542   TBlob out_blob = out->data();
543   SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, req, &out_blob);
544 }
545 
546 template<typename xpu>
547 inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
548                         const OpContext &ctx,
549                         const std::vector<NDArray> &inputs,
550                         const std::vector<OpReqType> &req,
551                         const std::vector<NDArray> &outputs) {
552   const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
553   const auto w_stype = inputs[0].storage_type();
554   const auto g_stype = inputs[1].storage_type();
555   const auto o_stype = outputs[0].storage_type();
556   if (o_stype == w_stype && g_stype == kRowSparseStorage &&
557       (w_stype == kDefaultStorage || w_stype == kRowSparseStorage)) {
558     NDArray out = outputs[0];
559     // std update and lazy update with rsp grad
560     SGDUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
561   } else {
562     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
563   }
564 }
565 
566 struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
567   float lr;
568   float momentum;
569   float wd;
570   float rescale_grad;
571   float clip_gradient;
572   bool lazy_update;
573   DMLC_DECLARE_PARAMETER(SGDMomParam) {
574     DMLC_DECLARE_FIELD(lr)
575     .describe("Learning rate");
576     DMLC_DECLARE_FIELD(momentum)
577     .set_default(0.0f)
578     .describe("The decay rate of momentum estimates at each epoch.");
579     DMLC_DECLARE_FIELD(wd)
580     .set_default(0.0f)
581     .describe("Weight decay augments the objective function with a "
582               "regularization term that penalizes large weights. "
583               "The penalty scales with the square of the magnitude of each weight.");
584     DMLC_DECLARE_FIELD(rescale_grad)
585     .set_default(1.0f)
586     .describe("Rescale gradient to grad = rescale_grad*grad.");
587     DMLC_DECLARE_FIELD(clip_gradient)
588     .set_default(-1.0f)
589     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
590               "If clip_gradient <= 0, gradient clipping is turned off. "
591               "grad = max(min(grad, clip_gradient), -clip_gradient).");
592     DMLC_DECLARE_FIELD(lazy_update)
593     .set_default(true)
594     .describe("If true, lazy updates are applied if gradient's stype is row_sparse "
595               "and both weight and momentum have the same stype");
596   }
597 };
598 
599 
600 struct SGDMomKernel {
601   template<typename DType>
602   MSHADOW_XINLINE static void Map(index_t i, DType* out_data, DType* mom_data,
603                                   const DType* weight_data, const DType* grad_data,
604                                   const DType param_clip_gradient, const DType param_momentum,
605                                   const DType param_lr, const DType param_wd,
606                                   const DType param_rescale_grad, const OpReqType req) {
607     if (param_clip_gradient >= 0.0f) {
608       mom_data[i] = param_momentum*mom_data[i]
609               - param_lr*param_wd*weight_data[i]
610               - param_lr
611               *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
612     } else {
613       mom_data[i] = param_momentum*mom_data[i]
614                 - param_lr*param_wd*weight_data[i]
615                 - param_lr*param_rescale_grad*grad_data[i];
616     }
617     KERNEL_ASSIGN(out_data[i], req, weight_data[i] + mom_data[i]);
618   }
619 };
620 
621 template<typename xpu>
622 inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs,
623                          const OpContext &ctx,
624                          const std::vector<TBlob> &inputs,
625                          const std::vector<OpReqType> &req,
626                          const std::vector<TBlob> &outputs) {
627   using namespace mxnet_op;
628   SGDMomParam param = nnvm::get<SGDMomParam>(attrs.parsed);
629   Stream<xpu>* s = ctx.get_stream<xpu>();
630   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
631     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
632     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
633     Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
634     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
635     Kernel<SGDMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
636       grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
637       static_cast<DType>(param.lr), static_cast<DType>(param.wd),
638       static_cast<DType>(param.rescale_grad), req[0]);
639     });
640 }
641 
642 template<int n_in, int n_out, int total_in>
643 inline bool MP_InferType(const nnvm::NodeAttrs& attrs,
644                              std::vector<int> *in_attrs,
645                              std::vector<int> *out_attrs) {
646   CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
647   CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
648   for (int i = n_in; i < total_in; ++i) {
649     TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
650   }
651   return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in, n_out>(
652       attrs, in_attrs, out_attrs, -1);
653 }
654 
655 struct MP_SGDKernel {
656   template<typename DType>
657   MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* weight_data,
658     const DType* grad_data, float* weight32, const float param_clip_gradient,
659     const float param_lr, const float param_wd, const float param_rescale_grad,
660     const OpReqType req) {
661     if (param_clip_gradient >= 0.0f) {
662       float w = weight32[i];
663       w = (1.f - param_lr*param_wd)*w -
664           (param_lr) * mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
665                                              param_clip_gradient);
666       weight32[i] = w;
667       KERNEL_ASSIGN(out_data[i], req, (DType)w);
668     } else {
669       float w = weight32[i];
670       w = (1.f-param_lr*param_wd)*w
671                - (param_lr*param_rescale_grad)*static_cast<float>(grad_data[i]);
672       weight32[i] = w;
673       KERNEL_ASSIGN(out_data[i], req, (DType)w);
674     }
675   }
676 };
677 
678 template<typename xpu>
679 inline void MP_SGDUpdate(const nnvm::NodeAttrs& attrs,
680                       const OpContext &ctx,
681                       const std::vector<TBlob> &inputs,
682                       const std::vector<OpReqType> &req,
683                       const std::vector<TBlob> &outputs) {
684   using namespace mxnet_op;
685   const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
686   Stream<xpu>* s = ctx.get_stream<xpu>();
687   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
688     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
689     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
690     Tensor<xpu, 2, float> weight32 = inputs[2].FlatTo2D<xpu, float>(s);
691     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
692     Kernel<MP_SGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
693       grad.dptr_, weight32.dptr_, param.clip_gradient,
694       param.lr, param.wd,
695       param.rescale_grad, req[0]);
696   });
697 }
698 
699 struct MP_SGDMomKernel {
700   template<typename DType>
701   MSHADOW_XINLINE static void Map(index_t i, DType* out_data, float* mom_data,
702     const DType* weight_data, const DType* grad_data, float* weight32,
703     const float param_clip_gradient, const float param_momentum, const float param_lr,
704     const float param_wd, const float param_rescale_grad, const OpReqType req) {
705     float w = weight32[i];
706     float mom = mom_data[i];
707     if (param_clip_gradient >= 0.0f) {
708       mom = param_momentum*mom
709               - param_lr*param_wd*w
710               - param_lr
711               *mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
712                                      param_clip_gradient);
713     } else {
714       mom = param_momentum*mom
715                 - param_lr*param_wd*w
716                 - param_lr*param_rescale_grad*static_cast<float>(grad_data[i]);
717     }
718     mom_data[i] = mom;
719     w = w + mom;
720     weight32[i] = w;
721     KERNEL_ASSIGN(out_data[i], req, w);
722   }
723 };
724 
725 template<typename xpu>
726 inline void MP_SGDMomUpdate(const nnvm::NodeAttrs& attrs,
727                          const OpContext &ctx,
728                          const std::vector<TBlob> &inputs,
729                          const std::vector<OpReqType> &req,
730                          const std::vector<TBlob> &outputs) {
731   using namespace mxnet_op;
732   SGDMomParam param = nnvm::get<SGDMomParam>(attrs.parsed);
733   Stream<xpu>* s = ctx.get_stream<xpu>();
734   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
735     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
736     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
737     Tensor<xpu, 2, float> mom = inputs[2].FlatTo2D<xpu, float>(s);
738     Tensor<xpu, 2, float> weight32 = inputs[3].FlatTo2D<xpu, float>(s);
739     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
740     Kernel<MP_SGDMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_,
741       weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.momentum,
742       param.lr, param.wd, param.rescale_grad, req[0]);
743   });
744 }
745 
746 template<int req, typename xpu>
747 struct SGDMomDnsRspDnsKernel;
748 
749 template<int req>
750 struct SGDMomDnsRspDnsKernel<req, cpu> {
751   template<typename DType, typename IType>
752   MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
753     DType* mom_data, const DType* weight_data, const IType* grad_idx,
754     const DType* grad_data, const DType clip_gradient, const DType momentum,
755     const DType lr, const DType wd, const DType rescale_grad) {
756     const DType rate = lr * wd;
757     for (index_t j = 0; j < row_length; j++) {
758       index_t data_i = grad_idx[i] * row_length + j;
759       index_t grad_i = i * row_length + j;
760       if (clip_gradient >= 0.0f) {
761         mom_data[data_i] = momentum * mom_data[data_i]
762                 - rate * weight_data[data_i]
763                 - lr *
764                 mshadow_op::clip::Map(rescale_grad * grad_data[grad_i],
765                                       clip_gradient);
766       } else {
767         mom_data[data_i] = momentum * mom_data[data_i]
768                   - rate * weight_data[data_i]
769                   - lr * rescale_grad * grad_data[grad_i];
770       }
771       KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]);
772     }
773   }
774 };
775 
776 template<int req>
777 struct SGDMomDnsRspDnsKernel<req, gpu> {
778   template<typename DType, typename IType>
779   MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
780     DType* mom_data, const DType* weight_data, const IType* grad_idx,
781     const DType* grad_data, const DType clip_gradient, const DType momentum,
782     const DType lr, const DType wd, const DType rescale_grad) {
783     using nnvm::dim_t;
784     const DType rate = lr * wd;
785     const dim_t row_id = i / row_length;
786     const dim_t col_id = i % row_length;
787     const dim_t data_i = grad_idx[row_id] * row_length + col_id;
788     if (clip_gradient >= 0.0f) {
789       mom_data[data_i] = momentum * mom_data[data_i]
790               - rate * weight_data[data_i]
791               - lr *
792               mshadow_op::clip::Map(rescale_grad * grad_data[i],
793                                     clip_gradient);
794     } else {
795       mom_data[data_i] = momentum * mom_data[data_i]
796                 - rate * weight_data[data_i]
797                 - lr * rescale_grad * grad_data[i];
798     }
799     KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]);
800   }
801 };
802 
803 /*
804  * \brief sgd mom lazy update for dense weight, row_sparse grad, dense state.
805  */
806 template<typename xpu>
807 inline void SGDMomLazyUpdateDnsRspDnsImpl(const SGDMomParam& param,
808                                           const OpContext& ctx,
809                                           const TBlob& weight,
810                                           const NDArray& grad,
811                                           const TBlob& mom,
812                                           const OpReqType& req,
813                                           TBlob *out) {
814   using namespace mxnet_op;
815   using namespace rowsparse;
816   Stream<xpu>* s = ctx.get_stream<xpu>();
817   if (!grad.storage_initialized() || req == kNullOp) return;
818   CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
819   CHECK_GT(weight.shape_.Size(), 0);
820   CHECK_GT(mom.shape_.Size(), 0);
821 
822   MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
823     MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
824       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
825         DType* weight_data = weight.dptr<DType>();
826         IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
827         DType* grad_val = grad.data().dptr<DType>();
828         DType* mom_data = mom.dptr<DType>();
829         DType* out_data = out->dptr<DType>();
830         index_t num_rows = grad.aux_shape(kIdx)[0];
831         auto row_length = weight.shape_.ProdShape(1, weight.ndim());
832         size_t num_threads = num_rows;
833         if (std::is_same<xpu, gpu>::value) {
834           num_threads = num_rows * row_length;
835         }
836         Kernel<SGDMomDnsRspDnsKernel<req_type, xpu>, xpu>::Launch(s, num_threads, row_length,
837           out_data, mom_data, weight_data, grad_idx, grad_val,
838           static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
839           static_cast<DType>(param.lr), static_cast<DType>(param.wd),
840           static_cast<DType>(param.rescale_grad));
841       });
842     });
843   });
844 }
845 
846 /*
847  * \brief sgd momentum lazy update for row_sparse grad.
848  */
849 template<typename xpu>
850 inline void SGDMomLazyUpdateRspImpl(const SGDMomParam& param,
851                                     const OpContext& ctx,
852                                     const NDArray& weight,
853                                     const NDArray& grad,
854                                     const NDArray& mom,
855                                     const OpReqType& req,
856                                     NDArray *out) {
857   using namespace mxnet_op;
858   using namespace rowsparse;
859   CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
860   Stream<xpu>* s = ctx.get_stream<xpu>();
861   // fill mom with zero values (if it's in rsp storage)
862   // in order to reuse the sgd mom dns impl
863   if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
864     NDArray mom_zeros = mom;
865     FillDnsZerosRspImpl(s, &mom_zeros);
866   }
867   TBlob out_blob = out->data();
868   // reuse dns rsp implementation when storage_shape == shape
869   SGDMomLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
870                                      mom.data(), req, &out_blob);
871 }
872 
873 /*!
874  * \brief Storge type inference function for optimizers which support both
875  *        lazy update and standard update, with states (e.g. 2nd order moment)
876  * \param num_states The number of states that could be row_sparse or dense
877  */
878 template<size_t num_states, typename ParamType>
879 inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
880                               const int dev_mask,
881                               DispatchMode* dispatch_mode,
882                               std::vector<int>* in_attrs,
883                               std::vector<int>* out_attrs) {
884   using namespace common;
885   const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
886   // weight, grad, state 0, state 1, ... -> weight
887   CHECK_EQ(in_attrs->size(), 2 + num_states);
888   CHECK_EQ(out_attrs->size(), 1U);
889   const int weight_stype = in_attrs->at(0);
890   const int grad_stype = in_attrs->at(1);
891   const int state_stype = in_attrs->at(2);
892   // the storage type of all states should be the same
893   for (size_t i = 3; i <  2 + num_states; i++) {
894     CHECK_EQ(state_stype, in_attrs->at(i))
895       << "Inconsistent storage types detected in state " << i;
896   }
897   bool dispatched = false;
898   if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
899     // dns, ... -> dns
900     dispatched = storage_type_assign(out_attrs, kDefaultStorage,
901                                      dispatch_mode, DispatchMode::kFCompute);
902   }
903   if (!dispatched && grad_stype == kRowSparseStorage &&
904       (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
905       state_stype == weight_stype) {
906     // weight and state share stype, grad's stype = rsp
907     dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
908                                      dispatch_mode, DispatchMode::kFComputeEx);
909     // warn users if lazy_update is turned on
910     if (dispatched && param.lazy_update) LogLazyUpdate();
911   }
912   if (!dispatched && grad_stype == kRowSparseStorage &&
913       weight_stype == kRowSparseStorage && state_stype == kDefaultStorage) {
914     // weight,  grad, state, ...  -> weight
915     // rsp,     rsp,  dns,   ...  -> rsp, standard update
916     dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
917                                      dispatch_mode, DispatchMode::kFComputeEx);
918   }
919   if (!dispatched) {
920     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
921   }
922   return dispatched;
923 }
924 
925 /*
926  * \brief kernel for standard momentum update for dense weight, sparse grad and dense state.
927  */
928 template<int req, typename xpu>
929 struct SGDMomStdDnsRspDnsKernel;
930 
931 
932 /*
933  * \brief standard momentum update for dense weight, row_sparse grad and dense states.
934  */
935 template<typename xpu>
936 void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
937                                   const OpContext& ctx,
938                                   const TBlob& weight,
939                                   const NDArray& grad,
940                                   const TBlob& mom,
941                                   const OpReqType& req,
942                                   TBlob *out);
943 
944 /*
945  * \brief standard momentum update for row_sparse grad.
946  *        both row_sparse and dense weight are supported.
947  */
948 template<typename xpu>
949 inline void SGDMomStdUpdateRspImpl(const SGDMomParam& param,
950                                    const OpContext& ctx,
951                                    const NDArray& weight,
952                                    const NDArray& grad,
953                                    const NDArray& mom,
954                                    const OpReqType& req,
955                                    NDArray *out) {
956   using namespace mxnet_op;
957   using namespace rowsparse;
958   CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
959   Stream<xpu>* s = ctx.get_stream<xpu>();
960   // fill mom with zero values (if it's in rsp storage)
961   // in order to reuse the sgd mom dns impl
962   if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
963     NDArray mom_zeros = mom;
964     FillDnsZerosRspImpl(s, &mom_zeros);
965   }
966   TBlob out_blob = out->data();
967   SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
968                                     mom.data(), req, &out_blob);
969 }
970 
971 template<typename xpu>
972 inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
973                            const OpContext &ctx,
974                            const std::vector<NDArray> &inputs,
975                            const std::vector<OpReqType> &req,
976                            const std::vector<NDArray> &outputs) {
977   using namespace mxnet_op;
978   const SGDMomParam& param = nnvm::get<SGDMomParam>(attrs.parsed);
979   auto &weight = inputs[0];
980   auto &grad = inputs[1];
981   auto &mom = inputs[2];
982   const auto w_stype = weight.storage_type();
983   const auto m_stype = mom.storage_type();
984   const auto out_stype = outputs[0].storage_type();
985   NDArray out = outputs[0];
986   const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage;
987   const bool valid_grad = grad.storage_type() == kRowSparseStorage;
988   const bool lazy_update = param.lazy_update;
989   CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype";
990   if (valid_weight && valid_grad && m_stype == w_stype) {
991     if (lazy_update) {
992       // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update
993       SGDMomLazyUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
994     } else {
995       // rsp grad && m.stype = w.stype && lazy_update = false -> std update
996       SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
997     }
998   } else if (w_stype == kRowSparseStorage && valid_grad && m_stype == kDefaultStorage) {
999     // rsp weight, rsp grad, dns state -> std update
1000     SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
1001   } else {
1002     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
1003   }
1004 }
1005 
1006 
1007 struct NAGParam : public dmlc::Parameter<NAGParam> {
1008   float lr;
1009   float wd;
1010   float rescale_grad;
1011   float clip_gradient;
1012   DMLC_DECLARE_PARAMETER(NAGParam) {
1013     DMLC_DECLARE_FIELD(lr)
1014     .describe("Learning rate");
1015     DMLC_DECLARE_FIELD(wd)
1016     .set_default(0.0f)
1017     .describe("Weight decay augments the objective function with a "
1018               "regularization term that penalizes large weights. "
1019               "The penalty scales with the square of the magnitude "
1020               "of each weight.");
1021     DMLC_DECLARE_FIELD(rescale_grad)
1022     .set_default(1.0f)
1023     .describe("Rescale gradient to grad = rescale_grad*grad.");
1024     DMLC_DECLARE_FIELD(clip_gradient)
1025     .set_default(-1.0f)
1026     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
1027               "If clip_gradient <= 0, gradient clipping is turned off. "
1028               "grad = max(min(grad, clip_gradient), -clip_gradient).");
1029   }
1030 };
1031 
1032 struct NAGMomParam : public dmlc::Parameter<NAGMomParam> {
1033   float lr;
1034   float momentum;
1035   float wd;
1036   float rescale_grad;
1037   float clip_gradient;
1038   DMLC_DECLARE_PARAMETER(NAGMomParam) {
1039     DMLC_DECLARE_FIELD(lr)
1040     .describe("Learning rate");
1041     DMLC_DECLARE_FIELD(momentum)
1042     .set_default(0.0f)
1043     .describe("The decay rate of momentum estimates at each epoch.");
1044     DMLC_DECLARE_FIELD(wd)
1045     .set_default(0.0f)
1046     .describe("Weight decay augments the objective function with a "
1047               "regularization term that penalizes large weights. "
1048               "The penalty scales with the square of the magnitude "
1049               "of each weight.");
1050     DMLC_DECLARE_FIELD(rescale_grad)
1051     .set_default(1.0f)
1052     .describe("Rescale gradient to grad = rescale_grad*grad.");
1053     DMLC_DECLARE_FIELD(clip_gradient)
1054     .set_default(-1.0f)
1055     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
1056               "If clip_gradient <= 0, gradient clipping is turned off. "
1057               "grad = max(min(grad, clip_gradient), -clip_gradient).");
1058   }
1059 };
1060 
1061 struct NAGMomKernel {
1062   template<typename DType>
1063   MSHADOW_XINLINE static void Map(index_t i, DType* out_data, DType* mom_data,
1064     const DType* weight_data, const DType* grad_data,
1065     const DType param_clip_gradient, const DType param_momentum,
1066     const DType param_lr, const DType param_wd,
1067     const DType param_rescale_grad, const OpReqType req) {
1068     if (param_clip_gradient >= 0.0f) {
1069       mom_data[i] = param_momentum*mom_data[i];
1070       KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1)
1071               *(mom_data[i]-(param_lr*(mshadow_op::clip::Map(param_rescale_grad
1072                               *grad_data[i], param_clip_gradient)+(param_wd*weight_data[i])))));
1073       mom_data[i] = mom_data[i] - (param_lr*((mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
1074                           param_clip_gradient))+(param_wd*weight_data[i])));
1075     } else {
1076       mom_data[i] = param_momentum*mom_data[i];
1077       KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1)
1078               *(mom_data[i]-(param_lr*(param_rescale_grad*grad_data[i]+param_wd*weight_data[i]))));
1079       mom_data[i] = mom_data[i] - param_lr*((param_rescale_grad*grad_data[i])
1080               +(param_wd*weight_data[i]));
1081     }
1082   }
1083 };
1084 
1085 template<typename xpu>
1086 inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs,
1087                          const OpContext &ctx,
1088                          const std::vector<TBlob> &inputs,
1089                          const std::vector<OpReqType> &req,
1090                          const std::vector<TBlob> &outputs) {
1091   using namespace mxnet_op;
1092   NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed);
1093   Stream<xpu>* s = ctx.get_stream<xpu>();
1094   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1095     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1096     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
1097     Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
1098     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1099     Kernel<NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_,
1100       mom.dptr_, weight.dptr_, grad.dptr_,
1101       static_cast<DType>(param.clip_gradient),
1102       static_cast<DType>(param.momentum), static_cast<DType>(param.lr),
1103       static_cast<DType>(param.wd), static_cast<DType>(param.rescale_grad),
1104       req[0]);
1105   });
1106 }
1107 
1108 struct MP_NAGMomKernel {
1109   template<typename DType>
1110   MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1111     float* mom_data, const DType* weight_data,
1112     const DType* grad_data, float* weight32,
1113     const float param_clip_gradient,
1114     const float param_momentum, const float param_lr,
1115     const float param_wd, const float param_rescale_grad,
1116     const OpReqType req) {
1117     float w = weight32[i];
1118     if (param_clip_gradient >= 0.0f) {
1119       mom_data[i] = param_momentum*mom_data[i];
1120       w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr
1121               *(mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
1122                           param_clip_gradient)+(param_wd*w)));
1123       mom_data[i] = mom_data[i] - param_lr
1124           *((mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
1125                           param_clip_gradient))+(param_wd*w));
1126       weight32[i] = w;
1127       KERNEL_ASSIGN(out_data[i], req, w);
1128     } else {
1129       mom_data[i] = param_momentum*mom_data[i];
1130       w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr
1131               *(param_rescale_grad*static_cast<float>(grad_data[i])+(param_wd*w)));
1132       mom_data[i] = mom_data[i] - param_lr
1133           *((param_rescale_grad*static_cast<float>(grad_data[i]))+(param_wd*w));
1134       weight32[i] = w;
1135       KERNEL_ASSIGN(out_data[i], req, w);
1136     }
1137   }
1138 };
1139 
1140 template<typename xpu>
1141 inline void MP_NAGMomUpdate(const nnvm::NodeAttrs& attrs,
1142                          const OpContext &ctx,
1143                          const std::vector<TBlob> &inputs,
1144                          const std::vector<OpReqType> &req,
1145                          const std::vector<TBlob> &outputs) {
1146   using namespace mxnet_op;
1147   NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed);
1148   Stream<xpu>* s = ctx.get_stream<xpu>();
1149   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1150     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1151     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
1152     Tensor<xpu, 2, float> mom = inputs[2].FlatTo2D<xpu, float>(s);
1153     Tensor<xpu, 2, float> weight32 = inputs[3].FlatTo2D<xpu, float>(s);
1154     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1155     Kernel<MP_NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_,
1156       mom.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_,
1157       param.clip_gradient, param.momentum, param.lr, param.wd,
1158       param.rescale_grad, req[0]);
1159   });
1160 }
1161 
1162 
1163 struct FTMLParam : public dmlc::Parameter<FTMLParam> {
1164   float lr;
1165   float beta1;
1166   float beta2;
1167   double epsilon;
1168   int t;
1169   float wd;
1170   float rescale_grad;
1171   float clip_grad;
1172   DMLC_DECLARE_PARAMETER(FTMLParam) {
1173     DMLC_DECLARE_FIELD(lr)
1174     .describe("Learning rate.");
1175     DMLC_DECLARE_FIELD(beta1)
1176     .set_default(0.6f)
1177     .set_range(0.0f, 1.0f)
1178     .describe("Generally close to 0.5.");
1179     DMLC_DECLARE_FIELD(beta2)
1180     .set_default(0.999f)
1181     .set_range(0.0f, 1.0f)
1182     .describe("Generally close to 1.");
1183     DMLC_DECLARE_FIELD(epsilon)
1184     .set_default(1e-8f)
1185     .describe("Epsilon to prevent div 0.");
1186     DMLC_DECLARE_FIELD(t)
1187     .describe("Number of update.");
1188     DMLC_DECLARE_FIELD(wd)
1189     .set_default(0.0f)
1190     .describe("Weight decay augments the objective function with a "
1191               "regularization term that penalizes large weights. "
1192               "The penalty scales with the square of the magnitude of each weight.");
1193     DMLC_DECLARE_FIELD(rescale_grad)
1194     .set_default(1.0f)
1195     .describe("Rescale gradient to grad = rescale_grad*grad.");
1196     DMLC_DECLARE_FIELD(clip_grad)
1197     .set_default(-1.0f)
1198     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
1199               "If clip_gradient <= 0, gradient clipping is turned off. "
1200               "grad = max(min(grad, clip_gradient), -clip_gradient).");
1201   }
1202 };
1203 
1204 
1205 struct FTMLKernel {
1206   template<typename DType>
1207   MSHADOW_XINLINE static void Map(index_t i, DType* out, DType* weight, DType* grad,
1208     DType* d, DType* v, DType* z, const DType lr, const DType beta1,
1209     const DType beta2, const DType epsilon, const DType t,
1210     const DType wd, const DType rescale_grad, const DType clip_grad,
1211     const OpReqType req) {
1212     using namespace mshadow_op;
1213     const DType grad_i = clip_grad >= 0.0f
1214         ? clip::Map(rescale_grad * grad[i] + wd * weight[i], clip_grad)
1215         : (rescale_grad * grad[i] + wd * weight[i]);
1216     v[i] = beta2 * v[i] + (1 - beta2) * square::Map(grad_i);
1217     const DType d_t = (1 - power::Map(beta1, t)) / lr *
1218         (square_root::Map(v[i] / (1 - power::Map(beta2, t))) + epsilon);
1219     z[i] = beta1 * z[i] + (1 - beta1) * grad_i - (d_t - beta1 * d[i]) * weight[i];
1220     d[i] = d_t;
1221     KERNEL_ASSIGN(out[i], req, - z[i] / d_t);
1222   }
1223 };
1224 
1225 
1226 template<typename xpu>
1227 inline void FTMLUpdate(const nnvm::NodeAttrs& attrs,
1228                        const OpContext &ctx,
1229                        const std::vector<TBlob> &inputs,
1230                        const std::vector<OpReqType> &req,
1231                        const std::vector<TBlob> &outputs) {
1232   using namespace mxnet_op;
1233   FTMLParam param = nnvm::get<FTMLParam>(attrs.parsed);
1234   Stream<xpu>* s = ctx.get_stream<xpu>();
1235   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1236     DType* weight_data = inputs[0].dptr<DType>();
1237     DType* grad_data = inputs[1].dptr<DType>();
1238     DType* d_data = inputs[2].dptr<DType>();
1239     DType* v_data = inputs[3].dptr<DType>();
1240     DType* z_data = inputs[4].dptr<DType>();
1241     DType* out_data = outputs[0].dptr<DType>();
1242     Kernel<FTMLKernel, xpu>::Launch(s, inputs[0].shape_.Size(), out_data,
1243       weight_data, grad_data, d_data, v_data, z_data, static_cast<DType>(param.lr),
1244       static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
1245       static_cast<DType>(param.epsilon), static_cast<DType>(param.t), static_cast<DType>(param.wd),
1246       static_cast<DType>(param.rescale_grad), static_cast<DType>(param.clip_grad),
1247       req[0]);
1248   });
1249 }
1250 
1251 struct AdamParam : public dmlc::Parameter<AdamParam> {
1252   float lr;
1253   float beta1;
1254   float beta2;
1255   float epsilon;
1256   float wd;
1257   float rescale_grad;
1258   float clip_gradient;
1259   bool lazy_update;
1260   DMLC_DECLARE_PARAMETER(AdamParam) {
1261     DMLC_DECLARE_FIELD(lr)
1262     .describe("Learning rate");
1263     DMLC_DECLARE_FIELD(beta1)
1264     .set_default(0.9f)
1265     .describe("The decay rate for the 1st moment estimates.");
1266     DMLC_DECLARE_FIELD(beta2)
1267     .set_default(0.999f)
1268     .describe("The decay rate for the 2nd moment estimates.");
1269     DMLC_DECLARE_FIELD(epsilon)
1270     .set_default(1e-8f)
1271     .describe("A small constant for numerical stability.");
1272     DMLC_DECLARE_FIELD(wd)
1273     .set_default(0.0f)
1274     .describe("Weight decay augments the objective function with a "
1275               "regularization term that penalizes large weights. "
1276               "The penalty scales with the square of the magnitude of each weight.");
1277     DMLC_DECLARE_FIELD(rescale_grad)
1278     .set_default(1.0f)
1279     .describe("Rescale gradient to grad = rescale_grad*grad.");
1280     DMLC_DECLARE_FIELD(clip_gradient)
1281     .set_default(-1.0f)
1282     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
1283               "If clip_gradient <= 0, gradient clipping is turned off. "
1284               "grad = max(min(grad, clip_gradient), -clip_gradient).");
1285     DMLC_DECLARE_FIELD(lazy_update)
1286     .set_default(true)
1287     .describe("If true, lazy updates are applied if gradient's stype is row_sparse "
1288               "and all of w, m and v have the same stype");
1289   }
1290 };
1291 
1292 struct AdamUpdateKernel {
1293   template<typename DType>
1294   MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1295     DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
1296     const DType clip_gradient, const DType rescale_grad,
1297     const DType beta1, const DType beta2,
1298     const DType lr, const DType wd,
1299     const DType epsilon, const OpReqType req) {
1300     using namespace mshadow_op;
1301 
1302     DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
1303     if (clip_gradient >= 0.f) {
1304       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
1305     }
1306 
1307     mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
1308     var_data[i] = beta2 * var_data[i] +
1309                         (1.f - beta2) * grad_rescaled * grad_rescaled;
1310 
1311     KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * mean_data[i] /
1312                   (square_root::Map(var_data[i]) + epsilon));
1313   }
1314 };
1315 
1316 template<typename xpu>
1317 inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
1318                        const OpContext &ctx,
1319                        const std::vector<TBlob> &inputs,
1320                        const std::vector<OpReqType> &req,
1321                        const std::vector<TBlob> &outputs) {
1322   using namespace mxnet_op;
1323   const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
1324   Stream<xpu>* s = ctx.get_stream<xpu>();
1325   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1326     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1327     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
1328     Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
1329     Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
1330     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1331 
1332     Kernel<AdamUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
1333           out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
1334           static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
1335           static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
1336           static_cast<DType>(param.lr), static_cast<DType>(param.wd),
1337           static_cast<DType>(param.epsilon), req[0]);
1338   });
1339 }
1340 
1341 template<int req, typename xpu>
1342 struct AdamDnsRspDnsKernel;
1343 
1344 /*!
1345  * Note: this kernel performs sparse adam update. For each row-slice in row_sparse
1346  * gradient, it finds the corresponding elements in weight, mean and var and performs
1347  * the update.
1348  * The kernel assumes dense weight/mean/var, and row_sparse gradient
1349  */
1350 template<int req>
1351 struct AdamDnsRspDnsKernel<req, cpu> {
1352   template<typename DType, typename IType>
1353   MSHADOW_XINLINE static void Map(index_t i, const nnvm::dim_t row_length, DType* out_data,
1354     DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
1355     const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
1356     const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
1357     using nnvm::dim_t;
1358     using namespace mshadow_op;
1359     const dim_t row_offset = grad_idx[i] * row_length;
1360     for (dim_t j = 0; j < row_length; j++) {
1361       // index in data/mean/var
1362       const dim_t data_i = row_offset + j;
1363       // index in grad
1364       const dim_t grad_i = i * row_length + j;
1365       const DType grad_rescaled = grad_data[grad_i] * rescale_grad + weight_data[data_i] * wd;
1366       if (clip_gradient >= 0.0f) {
1367         mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
1368                             clip::Map(grad_rescaled, clip_gradient);
1369         var_data[data_i] =  beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
1370                             clip::Map(grad_rescaled, clip_gradient));
1371       } else {
1372         mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
1373         var_data[data_i] = beta2 * var_data[data_i] +
1374                            (1.f - beta2) * grad_rescaled * grad_rescaled;
1375       }
1376       KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
1377                     (square_root::Map(var_data[data_i]) + epsilon));
1378     }
1379   }
1380 };
1381 
1382 
1383 template<int req>
1384 struct AdamDnsRspDnsKernel<req, gpu> {
1385   template<typename DType, typename IType>
1386   MSHADOW_XINLINE static void Map(index_t i, const nnvm::dim_t row_length, DType* out_data,
1387     DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
1388     const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
1389     const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
1390     using nnvm::dim_t;
1391     using namespace mshadow_op;
1392     const dim_t row_id = i / row_length;
1393     const dim_t col_id = i % row_length;
1394     const dim_t row_offset = grad_idx[row_id] * row_length;
1395     // index in data/mean/var
1396     const dim_t data_i = row_offset + col_id;
1397     // index in grad
1398     DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[data_i] * wd;
1399     if (clip_gradient >= 0.0f) {
1400       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
1401     }
1402     mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
1403     var_data[data_i] = beta2 * var_data[data_i] +
1404                        (1.f - beta2) * grad_rescaled * grad_rescaled;
1405     KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
1406                   (square_root::Map(var_data[data_i]) + epsilon));
1407   }
1408 };
1409 
1410 /*
1411  * \brief lazy adam update for dense weight, dense states and rsp grad.
1412  */
1413 template<typename xpu>
1414 inline void AdamLazyUpdateDnsRspDnsImpl(const AdamParam& param,
1415                                         const OpContext& ctx,
1416                                         const TBlob& weight,
1417                                         const NDArray& grad,
1418                                         const TBlob& mean,
1419                                         const TBlob& var,
1420                                         const OpReqType& req,
1421                                         TBlob *out) {
1422   using namespace mxnet_op;
1423   using namespace rowsparse;
1424   Stream<xpu>* s = ctx.get_stream<xpu>();
1425   if (!grad.storage_initialized() || req == kNullOp) return;
1426   CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
1427   CHECK_GT(weight.shape_.Size(), 0);
1428   CHECK_GT(mean.shape_.Size(), 0);
1429   CHECK_GT(var.shape_.Size(), 0);
1430 
1431   MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
1432     MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
1433       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
1434         const DType* weight_data = weight.dptr<DType>();
1435         const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
1436         const DType* grad_val = grad.data().dptr<DType>();
1437         DType* mean_data = mean.dptr<DType>();
1438         DType* var_data = var.dptr<DType>();
1439         DType* out_data = out->dptr<DType>();
1440         nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
1441         const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
1442         size_t num_threads = num_rows;
1443         if (std::is_same<xpu, gpu>::value) {
1444           num_threads = num_rows * row_length;
1445         }
1446         Kernel<AdamDnsRspDnsKernel<req_type, xpu>, xpu>::Launch(s, num_threads,
1447           row_length, out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
1448           static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
1449           static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
1450           static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
1451           static_cast<DType>(param.rescale_grad));
1452       });
1453     });
1454   });
1455 }
1456 
1457 /*
1458  * \brief lazy adam update for both row_sparse and dense weight.
1459  *        grad is expected to be row_sparse.
1460  */
1461 template<typename xpu>
1462 inline void AdamLazyUpdateRspImpl(const AdamParam& param,
1463                                   const OpContext& ctx,
1464                                   const NDArray& weight,
1465                                   const NDArray& grad,
1466                                   const NDArray& mean,
1467                                   const NDArray& var,
1468                                   const OpReqType& req,
1469                                   NDArray *out) {
1470   using namespace mxnet_op;
1471   using namespace rowsparse;
1472   CheckAllRowsPresent(weight, "AdamUpdate", "weights");
1473   Stream<xpu>* s = ctx.get_stream<xpu>();
1474   // fill mean and variance with zero values in order to reuse the sgd mom dns impl
1475   if (mean.storage_type() == kRowSparseStorage && !mean.storage_initialized()) {
1476     NDArray mean_zeros = mean;
1477     FillDnsZerosRspImpl(s, &mean_zeros);
1478   }
1479   if (var.storage_type() == kRowSparseStorage && !var.storage_initialized()) {
1480     NDArray var_zeros = var;
1481     FillDnsZerosRspImpl(s, &var_zeros);
1482   }
1483   TBlob out_blob = out->data();
1484   // reuse dns rsp implementation when storage_shape == shape
1485   AdamLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
1486                                    var.data(), req, &out_blob);
1487 }
1488 
1489 /*
1490  * \brief kernel for standard adam update for dense weight, row_sparse grad and dense states.
1491  */
1492 template<int req, typename xpu>
1493 struct AdamStdDnsRspDnsKernel;
1494 
1495 /*
1496  * \brief standard adam update for dense weight, row_sparse grad and dense states.
1497  */
1498 template<typename xpu>
1499 void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
1500                                 const OpContext& ctx,
1501                                 const TBlob& weight,
1502                                 const NDArray& grad,
1503                                 const TBlob& mean,
1504                                 const TBlob& var,
1505                                 const OpReqType& req,
1506                                 TBlob *out);
1507 
1508 /*
1509  * \brief standard adam update for both row_sparse and dense weight.
1510  *        states are expected to be dense, while grad is expected to be row_sparse.
1511  */
1512 template<typename xpu>
1513 inline void AdamStdUpdateRspImpl(const AdamParam& param,
1514                                  const OpContext& ctx,
1515                                  const NDArray& weight,
1516                                  const NDArray& grad,
1517                                  const NDArray& mean,
1518                                  const NDArray& var,
1519                                  const OpReqType& req,
1520                                  NDArray *out) {
1521   using namespace mxnet_op;
1522   using namespace rowsparse;
1523   CheckAllRowsPresent(weight, "AdamStdUpdate", "weights");
1524   TBlob out_blob = out->data();
1525   // reuse dns rsp implementation when storage_shape == shape
1526   AdamStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
1527                                   var.data(), req, &out_blob);
1528 }
1529 
1530 template<typename xpu>
1531 inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
1532                          const OpContext &ctx,
1533                          const std::vector<NDArray> &inputs,
1534                          const std::vector<OpReqType> &req,
1535                          const std::vector<NDArray> &outputs) {
1536   const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
1537   const auto w_stype = inputs[0].storage_type();
1538   const auto g_stype = inputs[1].storage_type();
1539   const auto m_stype = inputs[2].storage_type();
1540   const auto v_stype = inputs[3].storage_type();
1541   const auto out_stype = outputs[0].storage_type();
1542   NDArray out = outputs[0];
1543   const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage;
1544   CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype";
1545   CHECK(m_stype == v_stype) << "Inconsistent mean stype and var stype";
1546   if (valid_weight && g_stype == kRowSparseStorage && m_stype == w_stype) {
1547      if (param.lazy_update) {
1548        // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update
1549        AdamLazyUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
1550                                   inputs[3], req[0], &out);
1551      } else {
1552        // rsp grad && m.stype = w.stype && lazy_update = false -> std update
1553        AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
1554                                  inputs[3], req[0], &out);
1555      }
1556   } else if (w_stype == kRowSparseStorage && g_stype == kRowSparseStorage &&
1557              m_stype == kDefaultStorage) {
1558      // rsp, rsp, dns, dns -> rsp, standard update
1559      AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
1560                                inputs[3], req[0], &out);
1561   } else {
1562     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
1563   }
1564 }
1565 
1566 struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
1567     float beta1;
1568     float beta2;
1569     float epsilon;
1570     int t;
1571     bool bias_correction;
1572     float wd;
1573     float rescale_grad;
1574     float clip_gradient;
1575     DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
1576       DMLC_DECLARE_FIELD(beta1)
1577       .set_default(0.9f)
1578       .describe("The decay rate for the 1st moment estimates.");
1579       DMLC_DECLARE_FIELD(beta2)
1580       .set_default(0.999f)
1581       .describe("The decay rate for the 2nd moment estimates.");
1582       DMLC_DECLARE_FIELD(epsilon)
1583       .set_default(1e-6f)
1584       .describe("A small constant for numerical stability.");
1585       DMLC_DECLARE_FIELD(t)
1586       .describe("Index update count.");
1587       DMLC_DECLARE_FIELD(bias_correction)
1588       .set_default(true)
1589       .describe("Whether to use bias correction.");
1590       DMLC_DECLARE_FIELD(wd)
1591       .describe("Weight decay augments the objective function with a "
1592                 "regularization term that penalizes large weights. "
1593                 "The penalty scales with the square of the magnitude of each weight.");
1594       DMLC_DECLARE_FIELD(rescale_grad)
1595       .set_default(1.0f)
1596       .describe("Rescale gradient to grad = rescale_grad*grad.");
1597       DMLC_DECLARE_FIELD(clip_gradient)
1598       .set_default(-1.0f)
1599       .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
1600                 "If clip_gradient <= 0, gradient clipping is turned off. "
1601                 "grad = max(min(grad, clip_gradient), -clip_gradient).");
1602     }
1603 };
1604 
1605 struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
1606     float lr;
1607     float lower_bound;
1608     float upper_bound;
1609     DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
1610       DMLC_DECLARE_FIELD(lr)
1611       .describe("Learning rate");
1612       DMLC_DECLARE_FIELD(lower_bound)
1613       .set_default(-1.0f)
1614       .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set");
1615       DMLC_DECLARE_FIELD(upper_bound)
1616       .set_default(-1.0f)
1617       .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set");
1618     }
1619 };
1620 
1621 struct LambUpdatePhaseOneKernel {
1622   template<typename DType>
1623   MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1624     DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
1625     const DType clip_gradient, const DType rescale_grad,
1626     const DType beta1, const DType beta1_t, const DType beta2, const DType beta2_t,
1627     const DType wd, const DType epsilon, const int t,
1628     bool bias_correction, const OpReqType req) {
1629     using namespace mshadow_op;
1630 
1631     DType grad_rescaled = grad_data[i] * rescale_grad;
1632     if (clip_gradient >= 0.f) {
1633       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
1634     }
1635 
1636     mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
1637     var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
1638 
1639     DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];
1640 
1641     if (bias_correction) {
1642       DType mean_hat = mean_data[i] / (1. - beta1_t);
1643       DType var_hat = var_data[i] / (1 - beta2_t);
1644       g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
1645     }
1646     KERNEL_ASSIGN(out_data[i], req, g);
1647   }
1648 };
1649 
1650 template<typename xpu>
1651 inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
1652                        const OpContext &ctx,
1653                        const std::vector<TBlob> &inputs,
1654                        const std::vector<OpReqType> &req,
1655                        const std::vector<TBlob> &outputs) {
1656   using namespace mxnet_op;
1657   const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
1658   Stream<xpu>* s = ctx.get_stream<xpu>();
1659   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1660     DType beta1_t = std::pow(param.beta1, param.t);
1661     DType beta2_t = std::pow(param.beta2, param.t);
1662     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1663     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
1664     Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
1665     Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
1666     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1667 
1668   Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
1669     out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
1670     static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
1671     static_cast<DType>(param.beta1), beta1_t, static_cast<DType>(param.beta2), beta2_t,
1672     static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
1673     static_cast<int>(param.t), static_cast<bool>(param.bias_correction), req[0]);
1674   });
1675 }
1676 
1677 inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
1678                             mxnet::ShapeVector* in_attrs,
1679                             mxnet::ShapeVector* out_attrs) {
1680   CHECK_EQ(in_attrs->size(), 4U);
1681   CHECK_EQ(out_attrs->size(), 1U);
1682 
1683   mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
1684 
1685   mxnet::TShape& weight_shape = in_attrs->at(0);
1686   mxnet::TShape& g_shape = in_attrs->at(1);
1687   CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
1688            << "total no. of dimensions for weights and g must match";
1689   for (int i=0; i < weight_shape.ndim(); ++i) {
1690     CHECK_EQ(weight_shape[i], g_shape[i])
1691            << "weight and g dimension size mismatch at " << i << "-th index";
1692   }
1693   mxnet::TShape& r1_shape = in_attrs->at(2);
1694   mxnet::TShape& r2_shape = in_attrs->at(3);
1695   CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
1696   CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
1697   for (int i=0; i < expected_out.ndim(); ++i) {
1698     expected_out[i] = weight_shape[i];
1699   }
1700 
1701   SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
1702   return shape_is_known(expected_out);
1703 }
1704 
1705 struct LambUpdatePhaseTwoKernel {
1706   template<typename DType>
1707   MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1708     const DType* weight_data, const DType* g,
1709     const DType* r1, const DType* r2,
1710     DType lr, const DType lower_bound,
1711     const DType upper_bound, const OpReqType req) {
1712     using namespace mshadow_op;
1713 
1714     DType new_r1 = r1[0];
1715     if (lower_bound >= 0) {
1716       new_r1 = maximum::Map(new_r1, lower_bound);
1717     }
1718     if (upper_bound >= 0) {
1719       new_r1 = minimum::Map(new_r1, upper_bound);
1720     }
1721     if (new_r1 == 0.0f || r2[0] == 0.0f) {
1722       lr = lr * 1.0f;
1723     } else {
1724       lr = lr * new_r1 / r2[0];
1725     }
1726 
1727     KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
1728   }
1729 };
1730 
1731 template<typename xpu>
1732 inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
1733                        const OpContext &ctx,
1734                        const std::vector<TBlob> &inputs,
1735                        const std::vector<OpReqType> &req,
1736                        const std::vector<TBlob> &outputs) {
1737   using namespace mxnet_op;
1738   const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
1739   Stream<xpu>* s = ctx.get_stream<xpu>();
1740   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1741     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1742     Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
1743     Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
1744     Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
1745     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1746 
1747   Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
1748     out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
1749     static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
1750     static_cast<DType>(param.upper_bound), req[0]);
1751   });
1752 }
1753 
1754 template<int n_in, int n_out, int total_in>
1755 inline bool MPLambPhaseOneType(const nnvm::NodeAttrs& attrs,
1756                              std::vector<int> *in_attrs,
1757                              std::vector<int> *out_attrs) {
1758   CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
1759   CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
1760   for (int i = 0; i < n_in; ++i) {
1761     TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat16);
1762   }
1763   for (int i = n_in; i < total_in; ++i) {
1764     TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
1765   }
1766   for (int i = 0; i < n_out; ++i) {
1767     TYPE_ASSIGN_CHECK(*out_attrs, i, mshadow::kFloat32);
1768   }
1769   return true;
1770 }
1771 
1772 struct MPLambUpdatePhaseOneKernel {
1773   template<typename DType>
1774   MSHADOW_XINLINE static void Map(index_t i, float* out_data,
1775     float* mean_data, float* var_data, const DType* weight_data,
1776     const DType* grad_data, const float* weight32_data,
1777     const float clip_gradient, const float rescale_grad,
1778     const float beta1_t, const float beta1,
1779     const float beta2_t, const float beta2,
1780     const float wd, const float epsilon, const int t,
1781     bool bias_correction, const OpReqType req) {
1782     using namespace mshadow_op;
1783 
1784     float grad_rescaled = grad_data[i] * rescale_grad;
1785     if (clip_gradient >= 0.f) {
1786       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
1787     }
1788 
1789     mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
1790     var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
1791 
1792     float g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight32_data[i];
1793 
1794     if (bias_correction) {
1795       float mean_hat = mean_data[i] / (1. - beta1_t);
1796       float var_hat = var_data[i] / (1 - beta2_t);
1797       g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight32_data[i];
1798     }
1799     KERNEL_ASSIGN(out_data[i], req, g);
1800   }
1801 };
1802 
1803 template<typename xpu>
1804 inline void MPLambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
1805                        const OpContext &ctx,
1806                        const std::vector<TBlob> &inputs,
1807                        const std::vector<OpReqType> &req,
1808                        const std::vector<TBlob> &outputs) {
1809   using namespace mxnet_op;
1810   const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
1811   Stream<xpu>* s = ctx.get_stream<xpu>();
1812   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1813     float beta1_t = std::pow(param.beta1, param.t);
1814     float beta2_t = std::pow(param.beta2, param.t);
1815     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1816     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
1817     Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
1818     Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
1819     Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
1820     Tensor<xpu, 2, float> out = outputs[0].FlatTo2D<xpu, float>(s);
1821 
1822   Kernel<MPLambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
1823     out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_,
1824     param.clip_gradient, param.rescale_grad, beta1_t, param.beta1, beta2_t, param.beta2,
1825     param.wd, param.epsilon, param.t, param.bias_correction, req[0]);
1826   });
1827 }
1828 
1829 inline bool MPLambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
1830                             mxnet::ShapeVector* in_attrs,
1831                             mxnet::ShapeVector* out_attrs) {
1832   CHECK_EQ(in_attrs->size(), 5U);
1833   CHECK_EQ(out_attrs->size(), 1U);
1834 
1835   mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
1836 
1837   mxnet::TShape& weight_shape = in_attrs->at(0);
1838   mxnet::TShape& g_shape = in_attrs->at(1);
1839   mxnet::TShape& weight32_shape = in_attrs->at(4);
1840   CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
1841            << "total no. of dimensions for weights and g must match";
1842   CHECK_EQ(weight_shape.ndim(), weight32_shape.ndim())
1843            << "total no. of dimensions for weights and g must match";
1844   for (int i=0; i < weight_shape.ndim(); ++i) {
1845     CHECK_EQ(weight_shape[i], g_shape[i])
1846            << "weight and g dimension size mismatch at " << i << "-th index";
1847     CHECK_EQ(weight_shape[i], weight32_shape[i])
1848            << "weight and g dimension size mismatch at " << i << "-th index";
1849   }
1850   mxnet::TShape& r1_shape = in_attrs->at(2);
1851   mxnet::TShape& r2_shape = in_attrs->at(3);
1852   CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
1853   CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
1854   for (int i=0; i < expected_out.ndim(); ++i) {
1855     expected_out[i] = weight_shape[i];
1856   }
1857 
1858   SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
1859   return shape_is_known(expected_out);
1860 }
1861 
1862 struct MPLambUpdatePhaseTwoKernel {
1863   template<typename DType>
1864   MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1865     const DType* weight_data, const float* g,
1866     const float* r1, const float* r2, const float* weight32_data,
1867     float lr, const float lower_bound,
1868     const float upper_bound, const OpReqType req) {
1869     using namespace mshadow_op;
1870 
1871     float new_r1 = r1[0];
1872     if (lower_bound >= 0) {
1873       new_r1 = maximum::Map(new_r1, lower_bound);
1874     }
1875     if (upper_bound >= 0) {
1876       new_r1 = minimum::Map(new_r1, upper_bound);
1877     }
1878     if (new_r1 == 0.0f || r2[0] == 0.0f) {
1879       lr = lr * 1.0f;
1880     } else {
1881       lr = lr * new_r1 / r2[0];
1882     }
1883 
1884     KERNEL_ASSIGN(out_data[i], req, weight32_data[i] - lr * g[i]);
1885   }
1886 };
1887 
1888 template<typename xpu>
1889 inline void MPLambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
1890                        const OpContext &ctx,
1891                        const std::vector<TBlob> &inputs,
1892                        const std::vector<OpReqType> &req,
1893                        const std::vector<TBlob> &outputs) {
1894   using namespace mxnet_op;
1895   const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
1896   Stream<xpu>* s = ctx.get_stream<xpu>();
1897   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1898     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1899     Tensor<xpu, 2, float> g = inputs[1].FlatTo2D<xpu, float>(s);
1900     Tensor<xpu, 2, float> r1 = inputs[2].FlatTo2D<xpu, float>(s);
1901     Tensor<xpu, 2, float> r2 = inputs[3].FlatTo2D<xpu, float>(s);
1902     Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
1903     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1904 
1905   Kernel<MPLambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
1906     out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_, weight32.dptr_,
1907     param.lr, param.lower_bound,
1908     param.upper_bound, req[0]);
1909   });
1910 }
1911 
1912 // This RMSProp code follows the version in
1913 // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
1914 // by Alex Graves, 2013.
1915 struct RMSPropAlexParam : public dmlc::Parameter<RMSPropAlexParam> {
1916   float lr;
1917   float gamma1;
1918   float gamma2;
1919   float epsilon;
1920   float wd;
1921   float rescale_grad;
1922   float clip_gradient;
1923   float clip_weights;
1924   DMLC_DECLARE_PARAMETER(RMSPropAlexParam) {
1925     DMLC_DECLARE_FIELD(lr)
1926     .describe("Learning rate");
1927     DMLC_DECLARE_FIELD(gamma1).set_default(0.95f)
1928     .describe("Decay rate.");
1929     DMLC_DECLARE_FIELD(gamma2).set_default(0.9f)
1930     .describe("Decay rate.");
1931     DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f)
1932     .describe("A small constant for numerical stability.");
1933     DMLC_DECLARE_FIELD(wd).set_default(0.0f)
1934     .describe("Weight decay augments the objective function with a "
1935               "regularization term that penalizes large weights. "
1936               "The penalty scales with the square of the magnitude of each weight.");
1937     DMLC_DECLARE_FIELD(rescale_grad)
1938     .set_default(1.0f)
1939     .describe("Rescale gradient to grad = rescale_grad*grad.");
1940     DMLC_DECLARE_FIELD(clip_gradient)
1941     .set_default(-1.0f)
1942     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
1943               "If clip_gradient <= 0, gradient clipping is turned off. "
1944               "grad = max(min(grad, clip_gradient), -clip_gradient).");
1945     DMLC_DECLARE_FIELD(clip_weights)
1946     .set_default(-1.0f)
1947     .describe("Clip weights to the range of [-clip_weights, clip_weights] "
1948               "If clip_weights <= 0, weight clipping is turned off. "
1949               "weights = max(min(weights, clip_weights), -clip_weights).");
1950   }
1951 };
1952 
1953 struct RMSPropAlexUpdateKernel {
1954   template<typename DType>
1955   MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1956     DType* state_n_data, DType* state_g_data, DType* delta_data,
1957     const DType* weight_data, const DType* grad_data,
1958     const DType clip_gradient, const DType rescale_grad,
1959     const DType gamma1, const DType gamma2,
1960     const DType lr, const DType wd,
1961     const DType clip_weights, const DType epsilon,
1962     const OpReqType req) {
1963     using namespace mshadow_op;
1964 
1965     DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i];
1966     if (clip_gradient >= 0.0f) {
1967       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
1968     }
1969 
1970     state_n_data[i] = (1.f - gamma1) * grad_rescaled * grad_rescaled +
1971                       gamma1 * state_n_data[i];
1972     state_g_data[i] = (1.f - gamma1) * grad_rescaled +
1973                       gamma1 * state_g_data[i];
1974     delta_data[i] = gamma2 * delta_data[i] -
1975                     (lr * (grad_rescaled) /
1976                       (square_root::Map(state_n_data[i] -
1977                                         state_g_data[i] * state_g_data[i] + epsilon)));
1978 
1979     if (clip_weights >= 0.0f) {
1980       const DType clipped_weight = clip::Map(weight_data[i] + delta_data[i], clip_weights);
1981       KERNEL_ASSIGN(out_data[i], req, clipped_weight);
1982     } else {
1983       KERNEL_ASSIGN(out_data[i], req, weight_data[i] + delta_data[i]);
1984     }
1985   }
1986 };
1987 
1988 template <typename xpu>
1989 inline void RMSPropAlexUpdate(const nnvm::NodeAttrs &attrs,
1990                               const OpContext &ctx,
1991                               const std::vector<TBlob> &inputs,
1992                               const std::vector<OpReqType> &req,
1993                               const std::vector<TBlob> &outputs) {
1994   using namespace mxnet_op;
1995   const RMSPropAlexParam &param = nnvm::get<RMSPropAlexParam>(attrs.parsed);
1996   Stream<xpu> *s = ctx.get_stream<xpu>();
1997   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1998     DType* weight_data = inputs[0].dptr<DType>();
1999     DType* grad_data = inputs[1].dptr<DType>();
2000     DType* state_n_data = inputs[2].dptr<DType>();
2001     DType* state_g_data = inputs[3].dptr<DType>();
2002     DType* delta_data = inputs[4].dptr<DType>();
2003     DType* out_data = outputs[0].dptr<DType>();
2004 
2005     Kernel<RMSPropAlexUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(),
2006       out_data, state_n_data, state_g_data, delta_data, weight_data, grad_data,
2007       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
2008       static_cast<DType>(param.gamma1), static_cast<DType>(param.gamma2),
2009       static_cast<DType>(param.lr), static_cast<DType>(param.wd),
2010       static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]);
2011   });
2012 }
2013 
2014 // This RMSProp code follows the version in
2015 // http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
2016 // by Tieleman & Hinton, 2012
2017 struct RMSPropParam : public dmlc::Parameter<RMSPropParam> {
2018   float lr;
2019   float gamma1;
2020   float epsilon;
2021   float wd;
2022   float rescale_grad;
2023   float clip_gradient;
2024   float clip_weights;
2025   DMLC_DECLARE_PARAMETER(RMSPropParam) {
2026     DMLC_DECLARE_FIELD(lr)
2027     .describe("Learning rate");
2028     DMLC_DECLARE_FIELD(gamma1).set_default(0.95f)
2029     .describe("The decay rate of momentum estimates.");
2030     DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f)
2031     .describe("A small constant for numerical stability.");
2032     DMLC_DECLARE_FIELD(wd).set_default(0.0f)
2033     .describe("Weight decay augments the objective function with a "
2034               "regularization term that penalizes large weights. "
2035               "The penalty scales with the square of the magnitude of each weight.");
2036     DMLC_DECLARE_FIELD(rescale_grad)
2037     .set_default(1.0f)
2038     .describe("Rescale gradient to grad = rescale_grad*grad.");
2039     DMLC_DECLARE_FIELD(clip_gradient)
2040     .set_default(-1.0f)
2041     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
2042               "If clip_gradient <= 0, gradient clipping is turned off. "
2043               "grad = max(min(grad, clip_gradient), -clip_gradient).");
2044     DMLC_DECLARE_FIELD(clip_weights)
2045     .set_default(-1.0f)
2046     .describe("Clip weights to the range of [-clip_weights, clip_weights] "
2047               "If clip_weights <= 0, weight clipping is turned off. "
2048               "weights = max(min(weights, clip_weights), -clip_weights).");
2049   }
2050 };
2051 
2052 struct RMSPropUpdateKernel {
2053   template<typename DType>
2054   MSHADOW_XINLINE static void Map(index_t i,
2055     DType* out_data, DType* state_n_data,
2056     const DType* weight_data, const DType* grad_data,
2057     const DType clip_gradient, const DType rescale_grad,
2058     const DType gamma1, const DType lr, const DType wd,
2059     const DType clip_weights, const DType epsilon,
2060     const OpReqType req) {
2061     using namespace mshadow_op;
2062 
2063     DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i];
2064     if (clip_gradient >= 0.0f) {
2065       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
2066     }
2067 
2068     state_n_data[i] = (1.f - gamma1) * (grad_rescaled * grad_rescaled) + gamma1 * state_n_data[i];
2069 
2070     DType weight = weight_data[i] -
2071                    lr * (grad_rescaled / square_root::Map(state_n_data[i] + epsilon));
2072     if (clip_weights >= 0.0f) {
2073       weight = clip::Map(weight, clip_weights);
2074     }
2075     KERNEL_ASSIGN(out_data[i], req, weight);
2076   }
2077 };
2078 
2079 template <typename xpu>
2080 inline void RMSPropUpdate(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
2081                           const std::vector<TBlob> &inputs,
2082                           const std::vector<OpReqType> &req,
2083                           const std::vector<TBlob> &outputs) {
2084   using namespace mxnet_op;
2085   const RMSPropParam &param = nnvm::get<RMSPropParam>(attrs.parsed);
2086   Stream<xpu> *s = ctx.get_stream<xpu>();
2087   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
2088     DType* weight_data = inputs[0].dptr<DType>();
2089     DType* grad_data = inputs[1].dptr<DType>();
2090     DType* state_n_data = inputs[2].dptr<DType>();
2091     DType* out_data = outputs[0].dptr<DType>();
2092 
2093     Kernel<RMSPropUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(),
2094       out_data, state_n_data, weight_data, grad_data,
2095       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
2096       static_cast<DType>(param.gamma1), static_cast<DType>(param.lr), static_cast<DType>(param.wd),
2097       static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]);
2098   });
2099 }
2100 
2101 struct FtrlParam : public dmlc::Parameter<FtrlParam> {
2102   float lr;
2103   float lamda1;
2104   float beta;
2105   float wd;
2106   float rescale_grad;
2107   float clip_gradient;
2108   DMLC_DECLARE_PARAMETER(FtrlParam) {
2109     DMLC_DECLARE_FIELD(lr)
2110     .describe("Learning rate");
2111     DMLC_DECLARE_FIELD(lamda1)
2112     .set_default(0.01f)
2113     .describe("The L1 regularization coefficient.");
2114     DMLC_DECLARE_FIELD(beta)
2115     .set_default(1.0f)
2116     .describe("Per-Coordinate Learning Rate beta.");
2117     DMLC_DECLARE_FIELD(wd)
2118     .set_default(0.0f)
2119     .describe("Weight decay augments the objective function with a "
2120               "regularization term that penalizes large weights. "
2121               "The penalty scales with the square of the magnitude of each weight.");
2122     DMLC_DECLARE_FIELD(rescale_grad)
2123     .set_default(1.0f)
2124     .describe("Rescale gradient to grad = rescale_grad*grad.");
2125     DMLC_DECLARE_FIELD(clip_gradient)
2126     .set_default(-1.0f)
2127     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
2128               "If clip_gradient <= 0, gradient clipping is turned off. "
2129               "grad = max(min(grad, clip_gradient), -clip_gradient).");
2130   }
2131 };
2132 
2133 struct FtrlUpdateKernel {
2134   template<typename DType>
2135   MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
2136     DType* n_data, DType* z_data, const DType* weight_data, const DType* grad_data,
2137     const DType clip_gradient, const DType rescale_grad,
2138     const DType beta, const DType lamda1,
2139     const DType lr, const DType wd,
2140     const OpReqType req) {
2141     using namespace mshadow_op;
2142 
2143     DType grad_rescaled = grad_data[i] * rescale_grad;
2144     if (clip_gradient >= 0.0f) {
2145       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
2146     }
2147 
2148     z_data[i] += grad_rescaled - (square_root::Map(n_data[i] +
2149                       square::Map(grad_rescaled)) - square_root::Map(n_data[i])) *
2150                       weight_data[i] / lr;
2151     n_data[i] += square::Map(grad_rescaled);
2152 
2153     KERNEL_ASSIGN(out_data[i], req,
2154                   (sign::Map(z_data[i]) * lamda1 - z_data[i]) /
2155                   ((beta + square_root::Map(n_data[i])) / lr + wd) *
2156                   gt::Map(abs::Map(z_data[i]), lamda1));
2157   }
2158 };
2159 
2160 template<typename xpu>
2161 inline void FtrlUpdate(const nnvm::NodeAttrs& attrs,
2162                        const OpContext &ctx,
2163                        const std::vector<TBlob> &inputs,
2164                        const std::vector<OpReqType> &req,
2165                        const std::vector<TBlob> &outputs) {
2166   using namespace mxnet_op;
2167 
2168   const FtrlParam& param = nnvm::get<FtrlParam>(attrs.parsed);
2169   Stream<xpu>* s = ctx.get_stream<xpu>();
2170   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
2171     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
2172     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
2173     Tensor<xpu, 2, DType> z = inputs[2].FlatTo2D<xpu, DType>(s);
2174     Tensor<xpu, 2, DType> n = inputs[3].FlatTo2D<xpu, DType>(s);
2175     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
2176 
2177     Kernel<FtrlUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
2178       out.dptr_, n.dptr_, z.dptr_, weight.dptr_, grad.dptr_,
2179       static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
2180       static_cast<DType>(param.beta), static_cast<DType>(param.lamda1),
2181       static_cast<DType>(param.lr), static_cast<DType>(param.wd), req[0]);
2182   });
2183 }
2184 
2185 template<int req>
2186 struct FtrlDnsRspDnsKernel {
2187   template<typename DType, typename IType>
2188   MSHADOW_XINLINE static void Map(index_t i, const nnvm::dim_t row_length, DType* out_data,
2189     DType* z_data, DType* n_data, const DType* weight_data, const IType* grad_idx,
2190     const DType* grad_data, const DType clip_gradient, const DType lamda1, const DType beta,
2191     const DType lr, const DType wd, const DType rescale_grad) {
2192     using nnvm::dim_t;
2193     using namespace mshadow_op;
2194     const dim_t row_offset = grad_idx[i] * row_length;
2195     for (dim_t j = 0; j < row_length; j++) {
2196       // index in data/z/n
2197       const dim_t data_i = row_offset + j;
2198       // index in grad
2199       const dim_t grad_i = i * row_length + j;
2200       const DType grad_rescaled = grad_data[grad_i] * rescale_grad;
2201       if (clip_gradient >= 0.0f) {
2202         z_data[data_i] += clip::Map(grad_rescaled, clip_gradient) -
2203                           (square_root::Map(n_data[data_i] +
2204                           square::Map(clip::Map(grad_rescaled, clip_gradient))) -
2205                           square_root::Map(n_data[data_i])) * weight_data[data_i] / lr;
2206         n_data[data_i] += square::Map(clip::Map(grad_rescaled, clip_gradient));
2207       } else {
2208         z_data[data_i] += grad_rescaled - (square_root::Map(n_data[data_i] +
2209                           square::Map(grad_rescaled)) - square_root::Map(n_data[data_i])) *
2210                           weight_data[data_i] / lr;
2211         n_data[data_i] += square::Map(grad_rescaled);
2212       }
2213       KERNEL_ASSIGN(out_data[data_i], req,
2214                     (sign::Map(z_data[data_i]) * lamda1 - z_data[data_i]) /
2215                     ((beta + square_root::Map(n_data[data_i])) / lr + wd) *
2216                     gt::Map(abs::Map(z_data[data_i]), lamda1));
2217     }
2218   }
2219 };
2220 
2221 
2222 template<typename xpu>
2223 inline void FtrlUpdateDnsRspDnsImpl(const FtrlParam& param,
2224                                     const OpContext& ctx,
2225                                     const TBlob& weight,
2226                                     const NDArray& grad,
2227                                     const TBlob& z,
2228                                     const TBlob& n,
2229                                     const OpReqType& req,
2230                                     TBlob *out) {
2231   using namespace mxnet_op;
2232   using namespace rowsparse;
2233   Stream<xpu>* s = ctx.get_stream<xpu>();
2234   if (!grad.storage_initialized() || req == kNullOp) return;
2235   CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse ftrl_update";
2236   CHECK_GT(weight.shape_.Size(), 0);
2237   CHECK_GT(z.shape_.Size(), 0);
2238   CHECK_GT(n.shape_.Size(), 0);
2239 
2240   MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
2241     MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
2242       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
2243         const DType* weight_data = weight.dptr<DType>();
2244         const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
2245         const DType* grad_val = grad.data().dptr<DType>();
2246         DType* z_data = z.dptr<DType>();
2247         DType* n_data = n.dptr<DType>();
2248         DType* out_data = out->dptr<DType>();
2249         nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
2250         const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
2251         Kernel<FtrlDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, row_length,
2252           out_data, z_data, n_data, weight_data, grad_idx, grad_val,
2253           static_cast<DType>(param.clip_gradient), static_cast<DType>(param.lamda1),
2254           static_cast<DType>(param.beta), static_cast<DType>(param.lr),
2255           static_cast<DType>(param.wd), static_cast<DType>(param.rescale_grad));
2256       });
2257     });
2258   });
2259 }
2260 
2261 template<typename xpu>
2262 inline void FtrlUpdateRspRspRspImpl(const FtrlParam& param,
2263                                     const OpContext& ctx,
2264                                     const NDArray& weight,
2265                                     const NDArray& grad,
2266                                     const NDArray& z,
2267                                     const NDArray& n,
2268                                     const OpReqType& req,
2269                                     NDArray *out) {
2270   using namespace mshadow;
2271   using namespace mshadow::expr;
2272   using namespace mxnet_op;
2273   using namespace rowsparse;
2274   CheckAllRowsPresent(weight, "FtrlUpdate", "weights");
2275   Stream<xpu>* s = ctx.get_stream<xpu>();
2276   // fill z and n with zero values in order to reuse the sgd mom dns impl
2277   if (!z.storage_initialized()) {
2278     NDArray z_zeros = z;
2279     FillDnsZerosRspImpl(s, &z_zeros);
2280   }
2281   if (!n.storage_initialized()) {
2282     NDArray n_zeros = n;
2283     FillDnsZerosRspImpl(s, &n_zeros);
2284   }
2285   TBlob out_blob = out->data();
2286   // reuse dns rsp implementation when storage_shape == shape
2287   FtrlUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, z.data(),
2288                                n.data(), req, &out_blob);
2289 }
2290 
2291 template<typename xpu>
2292 inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs,
2293                          const OpContext &ctx,
2294                          const std::vector<NDArray> &inputs,
2295                          const std::vector<OpReqType> &req,
2296                          const std::vector<NDArray> &outputs) {
2297   const FtrlParam& param = nnvm::get<FtrlParam>(attrs.parsed);
2298   const auto weight_stype = inputs[0].storage_type();
2299   const auto z_stype = inputs[2].storage_type();
2300   const auto n_stype = inputs[3].storage_type();
2301 
2302   const auto out_stype = outputs[0].storage_type();
2303   CHECK_EQ(z_stype, weight_stype) << "Inconsistent storage type detected between "
2304            << " z.stype = " << z_stype << " and weight.stype = " << weight_stype;
2305   CHECK_EQ(n_stype, weight_stype) << "Inconsistent storage type detected between "
2306            << " n.stype = " << n_stype << " and weight.stype = " << weight_stype;
2307   if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) && out_stype == kRowSparseStorage) {
2308      NDArray out = outputs[0];
2309      FtrlUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
2310                                   inputs[3], req[0], &out);
2311   } else {
2312     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
2313   }
2314 }
2315 
2316 
2317 // Implementation for signSGD and Signum
2318 
2319 struct SignSGDParam : public dmlc::Parameter<SignSGDParam> {
2320   float lr;
2321   float wd;
2322   float rescale_grad;
2323   float clip_gradient;
2324   DMLC_DECLARE_PARAMETER(SignSGDParam) {
2325     DMLC_DECLARE_FIELD(lr)
2326     .describe("Learning rate");
2327     DMLC_DECLARE_FIELD(wd)
2328     .set_default(0.0f)
2329     .describe("Weight decay augments the objective function with a "
2330               "regularization term that penalizes large weights. "
2331               "The penalty scales with the square of the magnitude of each weight.");
2332     DMLC_DECLARE_FIELD(rescale_grad)
2333     .set_default(1.0f)
2334     .describe("Rescale gradient to grad = rescale_grad*grad.");
2335     DMLC_DECLARE_FIELD(clip_gradient)
2336     .set_default(-1.0f)
2337     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
2338               "If clip_gradient <= 0, gradient clipping is turned off. "
2339               "grad = max(min(grad, clip_gradient), -clip_gradient).");
2340   }
2341 };
2342 
2343 
2344 struct SignSGDKernel {
2345   template<typename DType>
2346   MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* weight_data,
2347     const DType* grad_data, const DType param_clip_gradient,
2348     const DType param_lr, const DType param_wd, const DType param_rescale_grad,
2349     const OpReqType req) {
2350 
2351     // param_clip_gradient has no effect for SignSGD
2352     KERNEL_ASSIGN(out_data[i], req,
2353              (1.f-param_lr*param_wd)*weight_data[i]
2354                - (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0)));
2355   }
2356 };
2357 
2358 template<typename xpu>
2359 inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs,
2360                       const OpContext &ctx,
2361                       const std::vector<TBlob> &inputs,
2362                       const std::vector<OpReqType> &req,
2363                       const std::vector<TBlob> &outputs) {
2364   using namespace mxnet_op;
2365   const SignSGDParam& param = nnvm::get<SignSGDParam>(attrs.parsed);
2366   Stream<xpu>* s = ctx.get_stream<xpu>();
2367   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
2368     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
2369     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
2370     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
2371     Kernel<SignSGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
2372       grad.dptr_, static_cast<DType>(param.clip_gradient),
2373       static_cast<DType>(param.lr), static_cast<DType>(param.wd),
2374       static_cast<DType>(param.rescale_grad), req[0]);
2375   });
2376 }
2377 
2378 
2379 struct SignumParam : public dmlc::Parameter<SignumParam> {
2380   float lr;
2381   float momentum;
2382   float wd;
2383   float rescale_grad;
2384   float clip_gradient;
2385   float wd_lh;  // the amount of algorithmic weight decay by Loshchilov and Frank Hutter
2386   DMLC_DECLARE_PARAMETER(SignumParam) {
2387     DMLC_DECLARE_FIELD(lr)
2388     .describe("Learning rate");
2389     DMLC_DECLARE_FIELD(momentum)
2390     .set_default(0.0f)
2391     .describe("The decay rate of momentum estimates at each epoch.");
2392     DMLC_DECLARE_FIELD(wd)
2393     .set_default(0.0f)
2394     .describe("Weight decay augments the objective function with a "
2395               "regularization term that penalizes large weights. "
2396               "The penalty scales with the square of the magnitude of each weight.");
2397     DMLC_DECLARE_FIELD(rescale_grad)
2398     .set_default(1.0f)
2399     .describe("Rescale gradient to grad = rescale_grad*grad.");
2400     DMLC_DECLARE_FIELD(clip_gradient)
2401     .set_default(-1.0f)
2402     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
2403               "If clip_gradient <= 0, gradient clipping is turned off. "
2404               "grad = max(min(grad, clip_gradient), -clip_gradient).");
2405     DMLC_DECLARE_FIELD(wd_lh)
2406     .set_default(0.0f)
2407     .describe("The amount of weight decay that does not go into gradient/momentum calculations"
2408               "otherwise do weight decay algorithmically only.");
2409   }
2410 };
2411 
2412 struct SignumKernel {
2413   template<typename DType>
2414   MSHADOW_XINLINE static void Map(index_t i, DType* out_data, DType* mom_data,
2415                                   const DType* weight_data, const DType* grad_data,
2416                                   const DType param_clip_gradient, const DType param_momentum,
2417                                   const DType param_lr, const DType param_wd,
2418                                   const DType param_rescale_grad, const DType param_wd_lh,
2419                                   const OpReqType req) {
2420     if (param_clip_gradient >= 0.0f) {
2421       mom_data[i] = param_momentum*mom_data[i]
2422               - (1-param_momentum)*param_wd*weight_data[i]
2423               - (1-param_momentum)
2424               *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
2425     } else {
2426       mom_data[i] = param_momentum*mom_data[i]
2427                 - (1-param_momentum)*param_wd*weight_data[i]
2428                 - (1-param_momentum)*param_rescale_grad*grad_data[i];
2429     }
2430     KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i]
2431       + (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0)));
2432   }
2433 };
2434 
2435 template<typename xpu>
2436 inline void SignumUpdate(const nnvm::NodeAttrs& attrs,
2437                          const OpContext &ctx,
2438                          const std::vector<TBlob> &inputs,
2439                          const std::vector<OpReqType> &req,
2440                          const std::vector<TBlob> &outputs) {
2441   using namespace mxnet_op;
2442   SignumParam param = nnvm::get<SignumParam>(attrs.parsed);
2443   Stream<xpu>* s = ctx.get_stream<xpu>();
2444   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
2445     Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
2446     Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
2447     Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
2448     Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
2449     Kernel<SignumKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
2450       grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
2451       static_cast<DType>(param.lr), static_cast<DType>(param.wd),
2452       static_cast<DType>(param.rescale_grad), static_cast<DType>(param.wd_lh), req[0]);
2453     });
2454 }
2455 
2456 struct AdagradParam : public dmlc::Parameter<AdagradParam> {
2457   float lr;
2458   float epsilon;
2459   float rescale_grad;
2460   float clip_gradient;
2461   float wd;
2462   DMLC_DECLARE_PARAMETER(AdagradParam) {
2463     DMLC_DECLARE_FIELD(lr)
2464     .describe("Learning rate");
2465     DMLC_DECLARE_FIELD(epsilon)
2466     .set_default(1.0e-7)
2467     .describe("epsilon");
2468     DMLC_DECLARE_FIELD(wd)
2469     .set_default(0.0f)
2470     .describe("weight decay");
2471     DMLC_DECLARE_FIELD(rescale_grad)
2472     .set_default(1.0f)
2473     .describe("Rescale gradient to grad = rescale_grad*grad.");
2474     DMLC_DECLARE_FIELD(clip_gradient)
2475     .set_default(-1.0f)
2476     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
2477               "If clip_gradient <= 0, gradient clipping is turned off. "
2478               "grad = max(min(grad, clip_gradient), -clip_gradient).");
2479   }
2480 };
2481 
2482 inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs,
2483                                const int dev_mask,
2484                                DispatchMode* dispatch_mode,
2485                                std::vector<int>* in_attrs,
2486                                std::vector<int>* out_attrs) {
2487   const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
2488   CHECK_EQ(in_attrs->size(), 3U);
2489   CHECK_EQ(out_attrs->size(), 1U);
2490   const int weight_stype = in_attrs->at(0);
2491   const int grad_stype = in_attrs->at(1);
2492   const int state_stype = in_attrs->at(2);
2493   bool dispatched = false;
2494   if (!dispatched && grad_stype == kRowSparseStorage &&
2495       (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
2496       state_stype == weight_stype && param.wd == 0.0f) {
2497     // weight and state share stype, grad's stype = rsp
2498     dispatched = storage_type_assign(
2499         out_attrs, static_cast<NDArrayStorageType>(weight_stype), dispatch_mode,
2500         DispatchMode::kFComputeEx);
2501   }
2502   return dispatched;
2503 }
2504 
2505 template<typename xpu>
2506 struct AdagradDnsRspDnsKernel;
2507 
2508 template<>
2509 struct AdagradDnsRspDnsKernel<cpu> {
2510   template<typename DType, typename IType>
2511   MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
2512     DType* state_data, const DType* weight_data, const IType* grad_idx,
2513     const DType* grad_data, const DType clip_gradient, const DType epsilon,
2514     const DType lr, const DType rescale_grad) {
2515     using nnvm::dim_t;
2516     using namespace mshadow_op;
2517     const dim_t data_i = grad_idx[i] * row_length;
2518     const dim_t grad_i = i * row_length;
2519     for (dim_t j = 0; j < row_length; j++) {
2520       const dim_t data_j = data_i + j;
2521       const dim_t grad_j = grad_i + j;
2522       DType grad_rescaled = grad_data[grad_j] * rescale_grad;
2523       if (clip_gradient >= 0.0f) {
2524         grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
2525       }
2526       const DType grad_squared = grad_rescaled * grad_rescaled;
2527       state_data[data_j] += grad_squared;
2528       const DType div = grad_rescaled / square_root::Map(state_data[data_j] + epsilon);
2529       // No need to use KERNEL_ASSIGN, as we already checked req is kWriteInplace
2530       out_data[data_j] = weight_data[data_j] - div * lr;
2531     }
2532   }
2533 };
2534 
2535 template<>
2536 struct AdagradDnsRspDnsKernel<gpu> {
2537   template<typename DType, typename IType>
2538   MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
2539     DType* state_data, const DType* weight_data, const IType* grad_idx,
2540     const DType* grad_data, const DType clip_gradient, const DType epsilon,
2541     const DType lr, const DType rescale_grad) {
2542     using nnvm::dim_t;
2543     using namespace mshadow_op;
2544     const dim_t row_id = i / row_length;
2545     const dim_t col_id = i % row_length;
2546     const dim_t data_i = grad_idx[row_id] * row_length + col_id;
2547     DType grad_rescaled = grad_data[i] * rescale_grad;
2548     if (clip_gradient >= 0.0f) {
2549       grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
2550     }
2551     const DType grad_squared = grad_rescaled * grad_rescaled;
2552     state_data[data_i] += grad_squared;
2553     const DType div = grad_rescaled / square_root::Map(state_data[data_i] + epsilon);
2554     // No need to use KERNEL_ASSIGN, as we already checked req is kWriteInplace
2555     out_data[data_i] = weight_data[data_i] - div * lr;
2556   }
2557 };
2558 
2559 template<typename xpu>
2560 void AdagradUpdateDnsRspDnsImpl(const AdagradParam& param,
2561                                 const OpContext& ctx,
2562                                 const TBlob& weight,
2563                                 const NDArray& grad,
2564                                 const TBlob& state,
2565                                 const OpReqType& req,
2566                                 TBlob *out) {
2567   using namespace mxnet_op;
2568   using namespace rowsparse;
2569   using namespace mshadow;
2570   Stream<xpu>* s = ctx.get_stream<xpu>();
2571   CHECK_EQ(param.wd, 0.0f)
2572     << "sparse adagrad_update does not support wd.";
2573   if (req == kNullOp || !grad.storage_initialized()) return;
2574   CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adagrad_update";
2575   CHECK_GT(weight.shape_.Size(), 0);
2576   CHECK_GT(state.shape_.Size(), 0);
2577   MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
2578     MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
2579       const DType* weight_data = weight.dptr<DType>();
2580       const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
2581       const DType* grad_val = grad.data().dptr<DType>();
2582       DType* state_data = state.dptr<DType>();
2583       DType* out_data = out->dptr<DType>();
2584       const nnvm::dim_t nnr = grad.storage_shape()[0];
2585       const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
2586       size_t num_threads = nnr;
2587       if (std::is_same<xpu, gpu>::value) {
2588         num_threads = nnr * row_length;
2589       }
2590       Kernel<AdagradDnsRspDnsKernel<xpu>, xpu>::Launch(s, num_threads, row_length,
2591         out_data, state_data, weight_data, grad_idx, grad_val,
2592         static_cast<DType>(param.clip_gradient), static_cast<DType>(param.epsilon),
2593         static_cast<DType>(param.lr), static_cast<DType>(param.rescale_grad));
2594     });
2595   });
2596 }
2597 
2598 template<typename xpu>
2599 inline void AdagradUpdateRspRspRspImpl(const AdagradParam& param,
2600                                        const OpContext& ctx,
2601                                        const NDArray& weight,
2602                                        const NDArray& grad,
2603                                        const NDArray& state,
2604                                        const OpReqType& req,
2605                                        NDArray *out) {
2606   using namespace mshadow;
2607   using namespace mxnet_op;
2608   using namespace rowsparse;
2609   CheckAllRowsPresent(weight, "AdagradUpdate", "weights");
2610   Stream<xpu>* s = ctx.get_stream<xpu>();
2611   // fill history with zero values
2612   if (!state.storage_initialized()) {
2613     NDArray state_zeros = state;
2614     FillDnsZerosRspImpl(s, &state_zeros);
2615   }
2616   TBlob out_blob = out->data();
2617   // reuse dns rsp implementation when storage_shape == shape
2618   AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
2619                                   state.data(), req, &out_blob);
2620 }
2621 
2622 template<typename xpu>
2623 inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
2624                             const OpContext &ctx,
2625                             const std::vector<NDArray> &inputs,
2626                             const std::vector<OpReqType> &req,
2627                             const std::vector<NDArray> &outputs) {
2628   using namespace mxnet_op;
2629   const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
2630 
2631   const auto weight_stype = inputs[0].storage_type();
2632   const auto grad_stype = inputs[1].storage_type();
2633   const auto state_stype = inputs[2].storage_type();
2634   const auto output_stype = outputs[0].storage_type();
2635 
2636   if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
2637       common::ContainsOnlyStorage(outputs, kRowSparseStorage)) {
2638     NDArray out = outputs[0];
2639     AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
2640                                     req[0], &out);
2641   } else if (state_stype == weight_stype && output_stype == weight_stype &&
2642              weight_stype == kDefaultStorage &&
2643              grad_stype == kRowSparseStorage) {
2644     TBlob out_blob = outputs[0].data();
2645     AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
2646                                     inputs[2].data(), req[0],
2647                                     &out_blob);
2648   } else {
2649     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
2650   }
2651 }
2652 
2653 }  // namespace op
2654 }  // namespace mxnet
2655 
2656 #endif  // MXNET_OPERATOR_OPTIMIZER_OP_INL_H_
2657