1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file optimizer_op-inl.h
22 * \brief Optimizer operators
23 * \author Junyuan Xie
24 */
25 #ifndef MXNET_OPERATOR_OPTIMIZER_OP_INL_H_
26 #define MXNET_OPERATOR_OPTIMIZER_OP_INL_H_
27 #include <dmlc/parameter.h>
28 #include <mxnet/operator.h>
29 #include <mxnet/operator_util.h>
30 #include <mxnet/op_attr_types.h>
31 #include <mshadow/base.h>
32 #include <nnvm/op.h>
33 #include <nnvm/op_attr_types.h>
34 #include <vector>
35 #include "./operator_common.h"
36 #include "./mshadow_op.h"
37 #include "./elemwise_op_common.h"
38 #include "mxnet_op.h"
39 #include "./tensor/init_op.h"
40 #include "./tensor/util/tensor_util-inl.h"
41
42 namespace mxnet {
43 namespace op {
44
45 /*
46 * \brief log message for optimizers with lazy update.
47 */
LogLazyUpdate()48 inline void LogLazyUpdate() {
49 common::LogOnce("Optimizer with lazy_update = True detected. "
50 "Be aware that lazy update with row_sparse gradient is different from "
51 "standard update, and may lead to different empirical results. See "
52 "https://mxnet.apache.org/api/python/optimization/optimization.html "
53 "for more details.");
54 }
55
/*!
 * \brief Parameters of the single-tensor SGD update operators.
 *
 * Read from attrs.parsed by SGDUpdate, SGDUpdateEx and MP_SGDUpdate in this file.
 *
 * NOTE(review): the kernels in this file enable clipping when clip_gradient >= 0,
 * whereas the description string below states that clipping is off for
 * clip_gradient <= 0. The two disagree for the value 0 -- confirm which is intended.
 */
struct SGDParam : public dmlc::Parameter<SGDParam> {
  // learning rate; declared without a default value
  float lr;
  // weight decay coefficient, default 0
  float wd;
  // factor the gradient is multiplied by before use, default 1
  float rescale_grad;
  // clipping bound for the rescaled gradient, default -1
  float clip_gradient;
  // selects lazy vs. standard update for row_sparse gradients, default true
  bool lazy_update;
  DMLC_DECLARE_PARAMETER(SGDParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
    DMLC_DECLARE_FIELD(lazy_update)
    .set_default(true)
    .describe("If true, lazy updates are applied if gradient's stype is row_sparse.");
  }
};
83
/*!
 * \brief Parameters of the multi-tensor SGD update operators (no momentum).
 *
 * lrs and wds carry one entry per weight; MultiSGDShape checks that both have
 * exactly num_weights entries. rescale_grad and clip_gradient are single scalars
 * shared by every weight.
 *
 * NOTE(review): MultiSGDKernel enables clipping when clip_gradient >= 0, whereas the
 * description string below states that clipping is off for clip_gradient <= 0.
 * The two disagree for the value 0 -- confirm which is intended.
 */
struct MultiSGDParam : public dmlc::Parameter<MultiSGDParam> {
  // per-weight learning rates; declared without a default value
  mxnet::Tuple<float> lrs;
  // per-weight weight decay coefficients; declared without a default value
  mxnet::Tuple<float> wds;
  // factor the gradient is multiplied by before use, default 1
  float rescale_grad;
  // clipping bound for the rescaled gradient, default -1
  float clip_gradient;
  // number of weights updated by one operator call, default 1
  int num_weights;
  DMLC_DECLARE_PARAMETER(MultiSGDParam) {
    DMLC_DECLARE_FIELD(lrs)
    .describe("Learning rates.");
    DMLC_DECLARE_FIELD(wds)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
    DMLC_DECLARE_FIELD(num_weights)
    .set_default(1)
    .describe("Number of updated weights.");
  }
};
110
/*!
 * \brief Parameters of the multi-tensor SGD update operators with momentum.
 *
 * Same fields as MultiSGDParam plus a single momentum coefficient shared by all
 * weights. lrs and wds carry one entry per weight; MultiSGDShape checks that both
 * have exactly num_weights entries.
 *
 * NOTE(review): MultiSGDKernel enables clipping when clip_gradient >= 0, whereas the
 * description string below states that clipping is off for clip_gradient <= 0.
 * The two disagree for the value 0 -- confirm which is intended.
 */
struct MultiSGDMomParam : public dmlc::Parameter<MultiSGDMomParam> {
  // per-weight learning rates; declared without a default value
  mxnet::Tuple<float> lrs;
  // per-weight weight decay coefficients; declared without a default value
  mxnet::Tuple<float> wds;
  // momentum coefficient, default 0
  float momentum;
  // factor the gradient is multiplied by before use, default 1
  float rescale_grad;
  // clipping bound for the rescaled gradient, default -1
  float clip_gradient;
  // number of weights updated by one operator call, default 1
  int num_weights;
  DMLC_DECLARE_PARAMETER(MultiSGDMomParam) {
    DMLC_DECLARE_FIELD(lrs)
    .describe("Learning rates.");
    DMLC_DECLARE_FIELD(wds)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(momentum)
    .set_default(0.0f)
    .describe("The decay rate of momentum estimates at each epoch.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
    DMLC_DECLARE_FIELD(num_weights)
    .set_default(1)
    .describe("Number of updated weights.");
  }
};
141
142
143 template<typename ParamType, int input_stride>
MultiSGDShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_attrs,mxnet::ShapeVector * out_attrs)144 inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs,
145 mxnet::ShapeVector *in_attrs,
146 mxnet::ShapeVector *out_attrs) {
147 const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
148 CHECK_EQ(in_attrs->size(), input_stride * param.num_weights);
149 CHECK_EQ(out_attrs->size(), param.num_weights);
150
151 bool all_inferred = true;
152 auto& input_shapes = *in_attrs;
153 auto& output_shapes = *out_attrs;
154 // Learning rates
155 CHECK_EQ(param.lrs.ndim(), param.num_weights)
156 << "Number of learning rates is inconsistent with num_weights "
157 << "parameter passed. Expected number of learning rates: "
158 << param.num_weights << ", and got " << param.lrs.ndim();
159 // Weight decays
160 CHECK_EQ(param.wds.ndim(), param.num_weights)
161 << "Number of weight decays is inconsistent with num_weights "
162 << "parameter passed. Expected number of weight decays: "
163 << param.num_weights << ", and got " << param.wds.ndim();
164 // Weights and gradients
165 for (int i = 0; i < param.num_weights; ++i) {
166 mxnet::ShapeVector input_vec;
167 mxnet::ShapeVector output_vec({output_shapes[i]});
168 for (int j = 0; j < input_stride; ++j) {
169 input_vec.push_back(input_shapes[i * input_stride + j]);
170 }
171 all_inferred = all_inferred && ElemwiseShape<input_stride, 1>(attrs, &input_vec, &output_vec);
172 }
173 return all_inferred;
174 }
175
176 template <typename ParamType, int input_stride, int num_fp32_inputs>
MP_MultiSGD_InferType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)177 inline bool MP_MultiSGD_InferType(const nnvm::NodeAttrs& attrs,
178 std::vector<int> *in_attrs,
179 std::vector<int> *out_attrs) {
180 const ParamType& param = dmlc::get<ParamType>(attrs.parsed);
181 CHECK_EQ(in_attrs->size(), input_stride * param.num_weights);
182 CHECK_EQ(out_attrs->size(), param.num_weights);
183
184 bool all_inferred = true;
185 auto& input_types = *in_attrs;
186 auto& output_types = *out_attrs;
187 // Weights and gradients
188 for (int i = 0; i < param.num_weights; ++i) {
189 std::vector<int> input_vec;
190 std::vector<int> output_vec({output_types[i]});
191 for (int j = 0; j < input_stride - num_fp32_inputs; ++j) {
192 input_vec.push_back(input_types[i * input_stride + j]);
193 }
194 all_inferred = all_inferred &&
195 ElemwiseType<input_stride - num_fp32_inputs, 1>(attrs, &input_vec, &output_vec);
196 }
197 // master copies of weights
198 for (int i = 0; i < param.num_weights; ++i) {
199 for (int j = 0; j < num_fp32_inputs; ++j) {
200 TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j, mshadow::kFloat32);
201 }
202 }
203 return all_inferred;
204 }
205
/*!
 * \brief Argument bundle handed by value to MultiSGDKernel.
 *
 * Filled by FillMultiSGDKernelParam (and FillMultiSGDMomKernelParam for the momentum
 * fields). All per-tensor arrays have a fixed capacity of N entries, of which the
 * first `count` are meaningful.
 *
 * \tparam DType element type of weights, gradients and outputs
 * \tparam MPDType type the update arithmetic is carried out in
 */
template<typename DType, typename MPDType>
struct MultiSGDKernelParam {
  // capacity of every per-tensor array below
  static const int N = 60;
  // number of tensors actually described (set from num_weights)
  int count;
  // largest entry of sizes[]; used as the kernel launch size
  size_t max_size;
  // element count of each weight tensor
  size_t sizes[N];
  // per-tensor weight pointers
  DType * weights[N];
  // per-tensor gradient pointers
  DType * grads[N];
  // per-tensor momentum pointers; only set by FillMultiSGDMomKernelParam
  MPDType * mom[N];
  // per-tensor master-weight pointers; only set when DType and MPDType differ
  MPDType * weights32[N];
  // per-tensor output pointers
  DType * out_data[N];
  // per-tensor learning rates
  MPDType lrs[N];
  // per-tensor weight decay coefficients
  MPDType wds[N];
  // clipping bound shared by all tensors; the kernel clips when it is >= 0
  MPDType clip_gradient;
  // gradient scale factor shared by all tensors
  MPDType rescale_grad;
  // momentum coefficient shared by all tensors (0 for the momentum-free variant)
  MPDType momentum;
};
223
224 template <typename MPDType, bool has_momentum, bool has_mixed_precision>
225 struct MultiSGDKernel {
226 template<typename DType>
MapMultiSGDKernel227 MSHADOW_XINLINE static void Map(index_t i, const MultiSGDKernelParam<DType, MPDType>& param,
228 const OpReqType req) {
229 for (int index = 0; index < param.count; ++index) {
230 if (i < static_cast<index_t>(param.sizes[index])) {
231 MPDType w = has_mixed_precision ? param.weights32[index][i] :
232 MPDType(param.weights[index][i]);
233 MPDType mom = has_momentum ? param.mom[index][i] : MPDType(0);
234 if (param.clip_gradient >= 0.0f) {
235 mom = param.momentum*mom
236 - param.lrs[index]*param.wds[index]*w
237 - param.lrs[index]
238 *mshadow_op::clip::Map(param.rescale_grad *
239 static_cast<MPDType>(param.grads[index][i]),
240 param.clip_gradient);
241 } else {
242 mom = param.momentum*mom
243 - param.lrs[index]*param.wds[index]*w
244 - param.lrs[index]*param.rescale_grad*static_cast<MPDType>(param.grads[index][i]);
245 }
246 if (has_momentum) {
247 param.mom[index][i] = mom;
248 }
249 w = w + mom;
250 if (has_mixed_precision) {
251 param.weights32[index][i] = w;
252 }
253 KERNEL_ASSIGN(param.out_data[index][i], req, w);
254 }
255 }
256 }
257 };
258
259 template<typename xpu,
260 typename DType,
261 typename MPDType,
262 typename ParamType = MultiSGDParam,
263 int input_stride = 2>
FillMultiSGDKernelParam(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<TBlob> & outputs)264 MultiSGDKernelParam<DType, MPDType> FillMultiSGDKernelParam(const nnvm::NodeAttrs& attrs,
265 const OpContext &ctx,
266 const std::vector<TBlob> &inputs,
267 const std::vector<TBlob> &outputs) {
268 using namespace mxnet_op;
269 const ParamType& p = nnvm::get<ParamType>(attrs.parsed);
270 Stream<xpu>* s = ctx.get_stream<xpu>();
271 MultiSGDKernelParam<DType, MPDType> param;
272 param.clip_gradient = p.clip_gradient;
273 param.rescale_grad = p.rescale_grad;
274 param.momentum = 0;
275 param.count = p.num_weights;
276 param.max_size = 0;
277 for (int i = 0; i < param.count; ++i) {
278 param.sizes[i] = inputs[i * input_stride].shape_.Size();
279 if (param.max_size < param.sizes[i]) {
280 param.max_size = param.sizes[i];
281 }
282 param.weights[i] = inputs[i * input_stride].FlatTo2D<xpu, DType>(s).dptr_;
283 param.grads[i] = inputs[i * input_stride + 1].FlatTo2D<xpu, DType>(s).dptr_;
284 // if mixed precision, then the last input in a set
285 // is 32-bit master copy of the weights
286 if (!std::is_same<DType, MPDType>::value) {
287 param.weights32[i] = inputs[i * input_stride + input_stride - 1]
288 .FlatTo2D<xpu, MPDType>(s).dptr_;
289 }
290 param.out_data[i] = outputs[i].FlatTo2D<xpu, DType>(s).dptr_;
291 param.lrs[i] = p.lrs[i];
292 param.wds[i] = p.wds[i];
293 }
294
295 return param;
296 }
297
298
299 template<typename xpu,
300 typename DType,
301 typename MPDType,
302 int input_stride = 3>
FillMultiSGDMomKernelParam(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<TBlob> & outputs)303 MultiSGDKernelParam<DType, MPDType> FillMultiSGDMomKernelParam(const nnvm::NodeAttrs& attrs,
304 const OpContext &ctx,
305 const std::vector<TBlob> &inputs,
306 const std::vector<TBlob> &outputs) {
307 using namespace mxnet_op;
308 const MultiSGDMomParam& p = nnvm::get<MultiSGDMomParam>(attrs.parsed);
309 Stream<xpu>* s = ctx.get_stream<xpu>();
310 MultiSGDKernelParam<DType, MPDType> param =
311 FillMultiSGDKernelParam<xpu,
312 DType,
313 MPDType,
314 MultiSGDMomParam,
315 input_stride>(attrs, ctx, inputs, outputs);
316 param.momentum = p.momentum;
317 for (int i = 0; i < param.count; ++i) {
318 param.mom[i] = inputs[i * input_stride + 2].FlatTo2D<xpu, MPDType>(s).dptr_;
319 }
320
321 return param;
322 }
323
/*! \brief Type chooser whose nested `type` is the template argument itself. */
template<typename T>
struct type_identity {
  using type = T;
};
329
/*! \brief Type chooser whose nested `type` is float regardless of the template argument. */
template<typename T>
struct single_precision {
  using type = float;
};
335
336 template<typename xpu, template<typename> class MPTypeChooser, int input_stride>
MultiSGDUpdate(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)337 inline void MultiSGDUpdate(const nnvm::NodeAttrs& attrs,
338 const OpContext &ctx,
339 const std::vector<TBlob> &inputs,
340 const std::vector<OpReqType> &req,
341 const std::vector<TBlob> &outputs) {
342 using namespace mxnet_op;
343 Stream<xpu>* s = ctx.get_stream<xpu>();
344 MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
345 using MPDType = typename MPTypeChooser<DType>::type;
346 MultiSGDKernelParam<DType, MPDType> param =
347 FillMultiSGDKernelParam<xpu,
348 DType,
349 MPDType,
350 MultiSGDParam,
351 input_stride>(attrs, ctx, inputs, outputs);
352 Kernel<MultiSGDKernel<MPDType,
353 false,
354 !std::is_same<DType, MPDType>::value>,
355 xpu>::Launch(s, param.max_size, param, req[0]);
356 });
357 }
358
359 template<typename xpu, template<typename> class MPTypeChooser, int input_stride>
MultiSGDMomUpdate(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)360 inline void MultiSGDMomUpdate(const nnvm::NodeAttrs& attrs,
361 const OpContext &ctx,
362 const std::vector<TBlob> &inputs,
363 const std::vector<OpReqType> &req,
364 const std::vector<TBlob> &outputs) {
365 using namespace mxnet_op;
366 Stream<xpu>* s = ctx.get_stream<xpu>();
367 MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
368 using MPDType = typename MPTypeChooser<DType>::type;
369 MultiSGDKernelParam<DType, MPDType> param =
370 FillMultiSGDMomKernelParam<xpu,
371 DType,
372 MPDType,
373 input_stride>(attrs, ctx, inputs, outputs);
374 Kernel<MultiSGDKernel<MPDType,
375 true,
376 !std::is_same<DType, MPDType>::value>,
377 xpu>::Launch(s, param.max_size, param, req[0]);
378 });
379 }
380
381 struct SGDKernel {
382 template<typename DType>
MapSGDKernel383 MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* weight_data,
384 const DType* grad_data, const DType param_clip_gradient,
385 const DType param_lr, const DType param_wd, const DType param_rescale_grad,
386 const OpReqType req) {
387 if (param_clip_gradient >= 0.0f) {
388 KERNEL_ASSIGN(out_data[i], req,
389 (1.f-param_lr*param_wd)*weight_data[i]
390 - (param_lr)
391 * mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient));
392 } else {
393 KERNEL_ASSIGN(out_data[i], req,
394 (1.f-param_lr*param_wd)*weight_data[i]
395 - (param_lr*param_rescale_grad)*grad_data[i]);
396 }
397 }
398 };
399
400 template<typename xpu>
SGDUpdate(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)401 inline void SGDUpdate(const nnvm::NodeAttrs& attrs,
402 const OpContext &ctx,
403 const std::vector<TBlob> &inputs,
404 const std::vector<OpReqType> &req,
405 const std::vector<TBlob> &outputs) {
406 using namespace mxnet_op;
407 const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
408 Stream<xpu>* s = ctx.get_stream<xpu>();
409 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
410 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
411 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
412 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
413 Kernel<SGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
414 grad.dptr_, static_cast<DType>(param.clip_gradient),
415 static_cast<DType>(param.lr), static_cast<DType>(param.wd),
416 static_cast<DType>(param.rescale_grad), req[0]);
417 });
418 }
419
/*!
 * \brief Kernel for SGD with a row_sparse gradient.
 *
 * Primary template, declared only; Map is provided by the gpu and cpu
 * specializations that follow.
 * \tparam req write request type handed to KERNEL_ASSIGN
 * \tparam xpu device type selecting the specialization
 */
template<int req, typename xpu>
struct SGDDnsRspKernel;
424
425 template<int req>
426 struct SGDDnsRspKernel<req, gpu> {
427 // DType is the output data type
428 // IType is row sparse idx type
429 // i is the ith element in row sparse gradient
430 template<typename DType, typename IType>
431 MSHADOW_XINLINE static void Map(index_t i, const index_t row_length, DType* out,
432 const DType* weight, const IType* grad_idx,
433 const DType *grad_val, const DType clip_gradient, const DType lr,
434 const DType wd, const DType rescale_grad) {
435 using nnvm::dim_t;
436 using namespace mshadow_op;
437 const dim_t row_id = i / row_length;
438 const dim_t col_id = i % row_length;
439 const dim_t row_offset = grad_idx[row_id] * row_length;
440 const dim_t data_i = row_offset + col_id;
441 if (clip_gradient >= 0.0f) {
442 KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] -
443 (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[i], clip_gradient));
444 } else {
445 KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] -
446 (lr * rescale_grad) * grad_val[i]);
447 }
448 }
449 };
450
451 /*! \brief kernel for sparse sgd
452 */
453 template<int req>
454 struct SGDDnsRspKernel<req, cpu> {
455 // DType is the output data type
456 // IType is row sparse idx type
457 // i is the ith row in row sparse gradient
458 template<typename DType, typename IType>
459 MSHADOW_XINLINE static void Map(index_t i, const index_t row_length, DType* out,
460 const DType* weight, const IType* grad_idx,
461 const DType *grad_val, const DType clip_gradient, const DType lr,
462 const DType wd, const DType rescale_grad) {
463 for (index_t j = 0; j < row_length; j++) {
464 index_t data_i = grad_idx[i] * row_length + j;
465 index_t grad_i = i * row_length + j;
466 if (clip_gradient >= 0.0f) {
467 KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] -
468 (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[grad_i], clip_gradient));
469 } else {
470 KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] -
471 (lr * rescale_grad) * grad_val[grad_i]);
472 }
473 }
474 }
475 };
476
477 /*
478 * \brief SGD update implementation for dense weight and row_sparse grad.
479 * Both standard update and lazy update are supported.
480 */
481 template<typename xpu>
482 inline void SGDUpdateDnsRspImpl(const SGDParam& param,
483 const OpContext &ctx,
484 const TBlob& weight,
485 const NDArray& grad,
486 const OpReqType& req,
487 TBlob *out) {
488 using namespace mshadow;
489 using namespace mshadow::expr;
490 using namespace mshadow_op;
491 using namespace mxnet_op;
492 Stream<xpu>* s = ctx.get_stream<xpu>();
493 CHECK_EQ(grad.storage_type(), kRowSparseStorage);
494 // if gradients are zeros, no weights are updated
495 if (req == kNullOp) return;
496 CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
497 CHECK_GT(weight.shape_.Size(), 0);
498
499 MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
500 MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
501 MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
502 DType* weight_data = weight.dptr<DType>();
503 float wd = param.wd;
504 // apply standard weight decay if not lazy update
505 if (!param.lazy_update) {
506 Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(s, weight.Size(),
507 weight_data, weight_data, static_cast<DType>(1 - param.lr * param.wd));
508 wd = 0;
509 }
510 if (!grad.storage_initialized()) return;
511 const IType* grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
512 const DType* grad_val = grad.data().dptr<DType>();
513 const nnvm::dim_t num_rows = grad.aux_shape(rowsparse::kIdx)[0];
514 const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
515 size_t num_threads = num_rows;
516 if (std::is_same<xpu, gpu>::value) {
517 num_threads = num_rows * row_length;
518 }
519 Kernel<SGDDnsRspKernel<req_type, xpu>, xpu>::Launch(s, num_threads, row_length,
520 out->dptr<DType>(), weight_data, grad_idx, grad_val,
521 static_cast<DType>(param.clip_gradient),
522 static_cast<DType>(param.lr), static_cast<DType>(wd),
523 static_cast<DType>(param.rescale_grad));
524 });
525 });
526 });
527 }
528
529 /*
530 * \brief SGD update implementation for row_sparse grad.
531 * Both standard update and lazy update are supported.
532 */
533 template<typename xpu>
534 inline void SGDUpdateRspImpl(const SGDParam& param,
535 const OpContext& ctx,
536 const NDArray& weight,
537 const NDArray& grad,
538 const OpReqType& req,
539 NDArray *out) {
540 CheckAllRowsPresent(weight, "SGDUpdate", "weights");
541 // reuse dns rsp implementation when storage_shape == shape
542 TBlob out_blob = out->data();
543 SGDUpdateDnsRspImpl<xpu>(param, ctx, weight.data(), grad, req, &out_blob);
544 }
545
546 template<typename xpu>
547 inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs,
548 const OpContext &ctx,
549 const std::vector<NDArray> &inputs,
550 const std::vector<OpReqType> &req,
551 const std::vector<NDArray> &outputs) {
552 const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
553 const auto w_stype = inputs[0].storage_type();
554 const auto g_stype = inputs[1].storage_type();
555 const auto o_stype = outputs[0].storage_type();
556 if (o_stype == w_stype && g_stype == kRowSparseStorage &&
557 (w_stype == kDefaultStorage || w_stype == kRowSparseStorage)) {
558 NDArray out = outputs[0];
559 // std update and lazy update with rsp grad
560 SGDUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], &out);
561 } else {
562 LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
563 }
564 }
565
/*!
 * \brief Parameters of the single-tensor SGD update operators with momentum.
 *
 * Read from attrs.parsed by SGDMomUpdate and MP_SGDMomUpdate in this file.
 *
 * NOTE(review): the kernels in this file enable clipping when clip_gradient >= 0,
 * whereas the description string below states that clipping is off for
 * clip_gradient <= 0. The two disagree for the value 0 -- confirm which is intended.
 */
struct SGDMomParam : public dmlc::Parameter<SGDMomParam> {
  // learning rate; declared without a default value
  float lr;
  // momentum coefficient, default 0
  float momentum;
  // weight decay coefficient, default 0
  float wd;
  // factor the gradient is multiplied by before use, default 1
  float rescale_grad;
  // clipping bound for the rescaled gradient, default -1
  float clip_gradient;
  // selects lazy vs. standard update for row_sparse gradients, default true
  bool lazy_update;
  DMLC_DECLARE_PARAMETER(SGDMomParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(momentum)
    .set_default(0.0f)
    .describe("The decay rate of momentum estimates at each epoch.");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
    DMLC_DECLARE_FIELD(lazy_update)
    .set_default(true)
    .describe("If true, lazy updates are applied if gradient's stype is row_sparse "
              "and both weight and momentum have the same stype");
  }
};
598
599
600 struct SGDMomKernel {
601 template<typename DType>
602 MSHADOW_XINLINE static void Map(index_t i, DType* out_data, DType* mom_data,
603 const DType* weight_data, const DType* grad_data,
604 const DType param_clip_gradient, const DType param_momentum,
605 const DType param_lr, const DType param_wd,
606 const DType param_rescale_grad, const OpReqType req) {
607 if (param_clip_gradient >= 0.0f) {
608 mom_data[i] = param_momentum*mom_data[i]
609 - param_lr*param_wd*weight_data[i]
610 - param_lr
611 *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
612 } else {
613 mom_data[i] = param_momentum*mom_data[i]
614 - param_lr*param_wd*weight_data[i]
615 - param_lr*param_rescale_grad*grad_data[i];
616 }
617 KERNEL_ASSIGN(out_data[i], req, weight_data[i] + mom_data[i]);
618 }
619 };
620
621 template<typename xpu>
622 inline void SGDMomUpdate(const nnvm::NodeAttrs& attrs,
623 const OpContext &ctx,
624 const std::vector<TBlob> &inputs,
625 const std::vector<OpReqType> &req,
626 const std::vector<TBlob> &outputs) {
627 using namespace mxnet_op;
628 SGDMomParam param = nnvm::get<SGDMomParam>(attrs.parsed);
629 Stream<xpu>* s = ctx.get_stream<xpu>();
630 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
631 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
632 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
633 Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
634 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
635 Kernel<SGDMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
636 grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
637 static_cast<DType>(param.lr), static_cast<DType>(param.wd),
638 static_cast<DType>(param.rescale_grad), req[0]);
639 });
640 }
641
642 template<int n_in, int n_out, int total_in>
643 inline bool MP_InferType(const nnvm::NodeAttrs& attrs,
644 std::vector<int> *in_attrs,
645 std::vector<int> *out_attrs) {
646 CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
647 CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
648 for (int i = n_in; i < total_in; ++i) {
649 TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
650 }
651 return ElemwiseAttr<int, type_is_none, type_assign, true, type_string, n_in, n_out>(
652 attrs, in_attrs, out_attrs, -1);
653 }
654
655 struct MP_SGDKernel {
656 template<typename DType>
657 MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* weight_data,
658 const DType* grad_data, float* weight32, const float param_clip_gradient,
659 const float param_lr, const float param_wd, const float param_rescale_grad,
660 const OpReqType req) {
661 if (param_clip_gradient >= 0.0f) {
662 float w = weight32[i];
663 w = (1.f - param_lr*param_wd)*w -
664 (param_lr) * mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
665 param_clip_gradient);
666 weight32[i] = w;
667 KERNEL_ASSIGN(out_data[i], req, (DType)w);
668 } else {
669 float w = weight32[i];
670 w = (1.f-param_lr*param_wd)*w
671 - (param_lr*param_rescale_grad)*static_cast<float>(grad_data[i]);
672 weight32[i] = w;
673 KERNEL_ASSIGN(out_data[i], req, (DType)w);
674 }
675 }
676 };
677
678 template<typename xpu>
679 inline void MP_SGDUpdate(const nnvm::NodeAttrs& attrs,
680 const OpContext &ctx,
681 const std::vector<TBlob> &inputs,
682 const std::vector<OpReqType> &req,
683 const std::vector<TBlob> &outputs) {
684 using namespace mxnet_op;
685 const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
686 Stream<xpu>* s = ctx.get_stream<xpu>();
687 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
688 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
689 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
690 Tensor<xpu, 2, float> weight32 = inputs[2].FlatTo2D<xpu, float>(s);
691 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
692 Kernel<MP_SGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
693 grad.dptr_, weight32.dptr_, param.clip_gradient,
694 param.lr, param.wd,
695 param.rescale_grad, req[0]);
696 });
697 }
698
699 struct MP_SGDMomKernel {
700 template<typename DType>
701 MSHADOW_XINLINE static void Map(index_t i, DType* out_data, float* mom_data,
702 const DType* weight_data, const DType* grad_data, float* weight32,
703 const float param_clip_gradient, const float param_momentum, const float param_lr,
704 const float param_wd, const float param_rescale_grad, const OpReqType req) {
705 float w = weight32[i];
706 float mom = mom_data[i];
707 if (param_clip_gradient >= 0.0f) {
708 mom = param_momentum*mom
709 - param_lr*param_wd*w
710 - param_lr
711 *mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
712 param_clip_gradient);
713 } else {
714 mom = param_momentum*mom
715 - param_lr*param_wd*w
716 - param_lr*param_rescale_grad*static_cast<float>(grad_data[i]);
717 }
718 mom_data[i] = mom;
719 w = w + mom;
720 weight32[i] = w;
721 KERNEL_ASSIGN(out_data[i], req, w);
722 }
723 };
724
725 template<typename xpu>
726 inline void MP_SGDMomUpdate(const nnvm::NodeAttrs& attrs,
727 const OpContext &ctx,
728 const std::vector<TBlob> &inputs,
729 const std::vector<OpReqType> &req,
730 const std::vector<TBlob> &outputs) {
731 using namespace mxnet_op;
732 SGDMomParam param = nnvm::get<SGDMomParam>(attrs.parsed);
733 Stream<xpu>* s = ctx.get_stream<xpu>();
734 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
735 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
736 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
737 Tensor<xpu, 2, float> mom = inputs[2].FlatTo2D<xpu, float>(s);
738 Tensor<xpu, 2, float> weight32 = inputs[3].FlatTo2D<xpu, float>(s);
739 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
740 Kernel<MP_SGDMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_,
741 weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.momentum,
742 param.lr, param.wd, param.rescale_grad, req[0]);
743 });
744 }
745
/*!
 * \brief Kernel for SGD with momentum, dense weight, row_sparse gradient and dense state.
 *
 * Primary template, declared only; Map is provided by the cpu and gpu
 * specializations that follow.
 * \tparam req write request type handed to KERNEL_ASSIGN
 * \tparam xpu device type selecting the specialization
 */
template<int req, typename xpu>
struct SGDMomDnsRspDnsKernel;
748
749 template<int req>
750 struct SGDMomDnsRspDnsKernel<req, cpu> {
751 template<typename DType, typename IType>
752 MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
753 DType* mom_data, const DType* weight_data, const IType* grad_idx,
754 const DType* grad_data, const DType clip_gradient, const DType momentum,
755 const DType lr, const DType wd, const DType rescale_grad) {
756 const DType rate = lr * wd;
757 for (index_t j = 0; j < row_length; j++) {
758 index_t data_i = grad_idx[i] * row_length + j;
759 index_t grad_i = i * row_length + j;
760 if (clip_gradient >= 0.0f) {
761 mom_data[data_i] = momentum * mom_data[data_i]
762 - rate * weight_data[data_i]
763 - lr *
764 mshadow_op::clip::Map(rescale_grad * grad_data[grad_i],
765 clip_gradient);
766 } else {
767 mom_data[data_i] = momentum * mom_data[data_i]
768 - rate * weight_data[data_i]
769 - lr * rescale_grad * grad_data[grad_i];
770 }
771 KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]);
772 }
773 }
774 };
775
776 template<int req>
777 struct SGDMomDnsRspDnsKernel<req, gpu> {
778 template<typename DType, typename IType>
779 MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
780 DType* mom_data, const DType* weight_data, const IType* grad_idx,
781 const DType* grad_data, const DType clip_gradient, const DType momentum,
782 const DType lr, const DType wd, const DType rescale_grad) {
783 using nnvm::dim_t;
784 const DType rate = lr * wd;
785 const dim_t row_id = i / row_length;
786 const dim_t col_id = i % row_length;
787 const dim_t data_i = grad_idx[row_id] * row_length + col_id;
788 if (clip_gradient >= 0.0f) {
789 mom_data[data_i] = momentum * mom_data[data_i]
790 - rate * weight_data[data_i]
791 - lr *
792 mshadow_op::clip::Map(rescale_grad * grad_data[i],
793 clip_gradient);
794 } else {
795 mom_data[data_i] = momentum * mom_data[data_i]
796 - rate * weight_data[data_i]
797 - lr * rescale_grad * grad_data[i];
798 }
799 KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]);
800 }
801 };
802
803 /*
804 * \brief sgd mom lazy update for dense weight, row_sparse grad, dense state.
805 */
806 template<typename xpu>
807 inline void SGDMomLazyUpdateDnsRspDnsImpl(const SGDMomParam& param,
808 const OpContext& ctx,
809 const TBlob& weight,
810 const NDArray& grad,
811 const TBlob& mom,
812 const OpReqType& req,
813 TBlob *out) {
814 using namespace mxnet_op;
815 using namespace rowsparse;
816 Stream<xpu>* s = ctx.get_stream<xpu>();
817 if (!grad.storage_initialized() || req == kNullOp) return;
818 CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
819 CHECK_GT(weight.shape_.Size(), 0);
820 CHECK_GT(mom.shape_.Size(), 0);
821
822 MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
823 MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
824 MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
825 DType* weight_data = weight.dptr<DType>();
826 IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
827 DType* grad_val = grad.data().dptr<DType>();
828 DType* mom_data = mom.dptr<DType>();
829 DType* out_data = out->dptr<DType>();
830 index_t num_rows = grad.aux_shape(kIdx)[0];
831 auto row_length = weight.shape_.ProdShape(1, weight.ndim());
832 size_t num_threads = num_rows;
833 if (std::is_same<xpu, gpu>::value) {
834 num_threads = num_rows * row_length;
835 }
836 Kernel<SGDMomDnsRspDnsKernel<req_type, xpu>, xpu>::Launch(s, num_threads, row_length,
837 out_data, mom_data, weight_data, grad_idx, grad_val,
838 static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
839 static_cast<DType>(param.lr), static_cast<DType>(param.wd),
840 static_cast<DType>(param.rescale_grad));
841 });
842 });
843 });
844 }
845
846 /*
847 * \brief sgd momentum lazy update for row_sparse grad.
848 */
849 template<typename xpu>
850 inline void SGDMomLazyUpdateRspImpl(const SGDMomParam& param,
851 const OpContext& ctx,
852 const NDArray& weight,
853 const NDArray& grad,
854 const NDArray& mom,
855 const OpReqType& req,
856 NDArray *out) {
857 using namespace mxnet_op;
858 using namespace rowsparse;
859 CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
860 Stream<xpu>* s = ctx.get_stream<xpu>();
861 // fill mom with zero values (if it's in rsp storage)
862 // in order to reuse the sgd mom dns impl
863 if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
864 NDArray mom_zeros = mom;
865 FillDnsZerosRspImpl(s, &mom_zeros);
866 }
867 TBlob out_blob = out->data();
868 // reuse dns rsp implementation when storage_shape == shape
869 SGDMomLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
870 mom.data(), req, &out_blob);
871 }
872
873 /*!
874 * \brief Storge type inference function for optimizers which support both
875 * lazy update and standard update, with states (e.g. 2nd order moment)
876 * \param num_states The number of states that could be row_sparse or dense
877 */
878 template<size_t num_states, typename ParamType>
879 inline bool StdOptStorageType(const nnvm::NodeAttrs& attrs,
880 const int dev_mask,
881 DispatchMode* dispatch_mode,
882 std::vector<int>* in_attrs,
883 std::vector<int>* out_attrs) {
884 using namespace common;
885 const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
886 // weight, grad, state 0, state 1, ... -> weight
887 CHECK_EQ(in_attrs->size(), 2 + num_states);
888 CHECK_EQ(out_attrs->size(), 1U);
889 const int weight_stype = in_attrs->at(0);
890 const int grad_stype = in_attrs->at(1);
891 const int state_stype = in_attrs->at(2);
892 // the storage type of all states should be the same
893 for (size_t i = 3; i < 2 + num_states; i++) {
894 CHECK_EQ(state_stype, in_attrs->at(i))
895 << "Inconsistent storage types detected in state " << i;
896 }
897 bool dispatched = false;
898 if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
899 // dns, ... -> dns
900 dispatched = storage_type_assign(out_attrs, kDefaultStorage,
901 dispatch_mode, DispatchMode::kFCompute);
902 }
903 if (!dispatched && grad_stype == kRowSparseStorage &&
904 (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
905 state_stype == weight_stype) {
906 // weight and state share stype, grad's stype = rsp
907 dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
908 dispatch_mode, DispatchMode::kFComputeEx);
909 // warn users if lazy_update is turned on
910 if (dispatched && param.lazy_update) LogLazyUpdate();
911 }
912 if (!dispatched && grad_stype == kRowSparseStorage &&
913 weight_stype == kRowSparseStorage && state_stype == kDefaultStorage) {
914 // weight, grad, state, ... -> weight
915 // rsp, rsp, dns, ... -> rsp, standard update
916 dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
917 dispatch_mode, DispatchMode::kFComputeEx);
918 }
919 if (!dispatched) {
920 dispatched = dispatch_fallback(out_attrs, dispatch_mode);
921 }
922 return dispatched;
923 }
924
/*!
 * \brief kernel for standard momentum update for dense weight, sparse grad and dense state.
 *
 * Only the primary template is declared here.
 * NOTE(review): the per-device specializations are presumably defined in the
 * matching source files - confirm.
 *
 * \tparam req write request type the kernel is instantiated for
 * \tparam xpu device type
 */
template<int req, typename xpu>
struct SGDMomStdDnsRspDnsKernel;
930
931
/*!
 * \brief standard momentum update for dense weight, row_sparse grad and dense states.
 *
 * Declaration only; the definition is not part of this section of the file.
 *
 * \param param  SGD momentum hyper-parameters
 * \param ctx    operator context
 * \param weight dense weight blob
 * \param grad   row_sparse gradient
 * \param mom    dense momentum blob
 * \param req    write request for the output
 * \param out    blob receiving the updated weight
 */
template<typename xpu>
void SGDMomStdUpdateDnsRspDnsImpl(const SGDMomParam& param,
                                  const OpContext& ctx,
                                  const TBlob& weight,
                                  const NDArray& grad,
                                  const TBlob& mom,
                                  const OpReqType& req,
                                  TBlob *out);
943
944 /*
945 * \brief standard momentum update for row_sparse grad.
946 * both row_sparse and dense weight are supported.
947 */
948 template<typename xpu>
949 inline void SGDMomStdUpdateRspImpl(const SGDMomParam& param,
950 const OpContext& ctx,
951 const NDArray& weight,
952 const NDArray& grad,
953 const NDArray& mom,
954 const OpReqType& req,
955 NDArray *out) {
956 using namespace mxnet_op;
957 using namespace rowsparse;
958 CheckAllRowsPresent(weight, "SGDMomUpdate", "weights");
959 Stream<xpu>* s = ctx.get_stream<xpu>();
960 // fill mom with zero values (if it's in rsp storage)
961 // in order to reuse the sgd mom dns impl
962 if (mom.storage_type() == kRowSparseStorage && !mom.storage_initialized()) {
963 NDArray mom_zeros = mom;
964 FillDnsZerosRspImpl(s, &mom_zeros);
965 }
966 TBlob out_blob = out->data();
967 SGDMomStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
968 mom.data(), req, &out_blob);
969 }
970
971 template<typename xpu>
972 inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs,
973 const OpContext &ctx,
974 const std::vector<NDArray> &inputs,
975 const std::vector<OpReqType> &req,
976 const std::vector<NDArray> &outputs) {
977 using namespace mxnet_op;
978 const SGDMomParam& param = nnvm::get<SGDMomParam>(attrs.parsed);
979 auto &weight = inputs[0];
980 auto &grad = inputs[1];
981 auto &mom = inputs[2];
982 const auto w_stype = weight.storage_type();
983 const auto m_stype = mom.storage_type();
984 const auto out_stype = outputs[0].storage_type();
985 NDArray out = outputs[0];
986 const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage;
987 const bool valid_grad = grad.storage_type() == kRowSparseStorage;
988 const bool lazy_update = param.lazy_update;
989 CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype";
990 if (valid_weight && valid_grad && m_stype == w_stype) {
991 if (lazy_update) {
992 // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update
993 SGDMomLazyUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
994 } else {
995 // rsp grad && m.stype = w.stype && lazy_update = false -> std update
996 SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
997 }
998 } else if (w_stype == kRowSparseStorage && valid_grad && m_stype == kDefaultStorage) {
999 // rsp weight, rsp grad, dns state -> std update
1000 SGDMomStdUpdateRspImpl<xpu>(param, ctx, weight, grad, mom, req[0], &out);
1001 } else {
1002 LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
1003 }
1004 }
1005
1006
/*!
 * \brief Parameter set with learning rate, weight decay, gradient rescaling and
 *        gradient clipping, declared through dmlc::Parameter.
 *        NOTE(review): judging by the name this belongs to a NAG operator without
 *        momentum; no consumer is visible in this section of the file - confirm.
 */
struct NAGParam : public dmlc::Parameter<NAGParam> {
  float lr;             // learning rate; no default is declared
  float wd;             // weight decay coefficient, default 0
  float rescale_grad;   // factor applied to the gradient, default 1
  float clip_gradient;  // clipping bound, default -1
  DMLC_DECLARE_PARAMETER(NAGParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude "
              "of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};
1031
/*!
 * \brief Parameters read by NAGMomUpdate and MP_NAGMomUpdate: learning rate,
 *        momentum, weight decay, gradient rescaling and gradient clipping.
 *
 * NOTE(review): NAGMomKernel and MP_NAGMomKernel clip whenever clip_gradient >= 0,
 * so a value of exactly 0 still clips, although the description string below says
 * clipping is off for clip_gradient <= 0.
 */
struct NAGMomParam : public dmlc::Parameter<NAGMomParam> {
  float lr;             // learning rate; no default is declared
  float momentum;       // momentum factor, default 0
  float wd;             // weight decay coefficient, default 0
  float rescale_grad;   // factor applied to the gradient, default 1
  float clip_gradient;  // clipping bound, default -1
  DMLC_DECLARE_PARAMETER(NAGMomParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(momentum)
    .set_default(0.0f)
    .describe("The decay rate of momentum estimates at each epoch.");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude "
              "of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};
1060
1061 struct NAGMomKernel {
1062 template<typename DType>
1063 MSHADOW_XINLINE static void Map(index_t i, DType* out_data, DType* mom_data,
1064 const DType* weight_data, const DType* grad_data,
1065 const DType param_clip_gradient, const DType param_momentum,
1066 const DType param_lr, const DType param_wd,
1067 const DType param_rescale_grad, const OpReqType req) {
1068 if (param_clip_gradient >= 0.0f) {
1069 mom_data[i] = param_momentum*mom_data[i];
1070 KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1)
1071 *(mom_data[i]-(param_lr*(mshadow_op::clip::Map(param_rescale_grad
1072 *grad_data[i], param_clip_gradient)+(param_wd*weight_data[i])))));
1073 mom_data[i] = mom_data[i] - (param_lr*((mshadow_op::clip::Map(param_rescale_grad*grad_data[i],
1074 param_clip_gradient))+(param_wd*weight_data[i])));
1075 } else {
1076 mom_data[i] = param_momentum*mom_data[i];
1077 KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]+(param_momentum+1)
1078 *(mom_data[i]-(param_lr*(param_rescale_grad*grad_data[i]+param_wd*weight_data[i]))));
1079 mom_data[i] = mom_data[i] - param_lr*((param_rescale_grad*grad_data[i])
1080 +(param_wd*weight_data[i]));
1081 }
1082 }
1083 };
1084
1085 template<typename xpu>
1086 inline void NAGMomUpdate(const nnvm::NodeAttrs& attrs,
1087 const OpContext &ctx,
1088 const std::vector<TBlob> &inputs,
1089 const std::vector<OpReqType> &req,
1090 const std::vector<TBlob> &outputs) {
1091 using namespace mxnet_op;
1092 NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed);
1093 Stream<xpu>* s = ctx.get_stream<xpu>();
1094 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1095 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1096 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
1097 Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
1098 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1099 Kernel<NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_,
1100 mom.dptr_, weight.dptr_, grad.dptr_,
1101 static_cast<DType>(param.clip_gradient),
1102 static_cast<DType>(param.momentum), static_cast<DType>(param.lr),
1103 static_cast<DType>(param.wd), static_cast<DType>(param.rescale_grad),
1104 req[0]);
1105 });
1106 }
1107
1108 struct MP_NAGMomKernel {
1109 template<typename DType>
1110 MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1111 float* mom_data, const DType* weight_data,
1112 const DType* grad_data, float* weight32,
1113 const float param_clip_gradient,
1114 const float param_momentum, const float param_lr,
1115 const float param_wd, const float param_rescale_grad,
1116 const OpReqType req) {
1117 float w = weight32[i];
1118 if (param_clip_gradient >= 0.0f) {
1119 mom_data[i] = param_momentum*mom_data[i];
1120 w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr
1121 *(mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
1122 param_clip_gradient)+(param_wd*w)));
1123 mom_data[i] = mom_data[i] - param_lr
1124 *((mshadow_op::clip::Map(param_rescale_grad*static_cast<float>(grad_data[i]),
1125 param_clip_gradient))+(param_wd*w));
1126 weight32[i] = w;
1127 KERNEL_ASSIGN(out_data[i], req, w);
1128 } else {
1129 mom_data[i] = param_momentum*mom_data[i];
1130 w = w-mom_data[i]+(param_momentum+1)*(mom_data[i]-param_lr
1131 *(param_rescale_grad*static_cast<float>(grad_data[i])+(param_wd*w)));
1132 mom_data[i] = mom_data[i] - param_lr
1133 *((param_rescale_grad*static_cast<float>(grad_data[i]))+(param_wd*w));
1134 weight32[i] = w;
1135 KERNEL_ASSIGN(out_data[i], req, w);
1136 }
1137 }
1138 };
1139
1140 template<typename xpu>
1141 inline void MP_NAGMomUpdate(const nnvm::NodeAttrs& attrs,
1142 const OpContext &ctx,
1143 const std::vector<TBlob> &inputs,
1144 const std::vector<OpReqType> &req,
1145 const std::vector<TBlob> &outputs) {
1146 using namespace mxnet_op;
1147 NAGMomParam param = nnvm::get<NAGMomParam>(attrs.parsed);
1148 Stream<xpu>* s = ctx.get_stream<xpu>();
1149 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1150 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1151 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
1152 Tensor<xpu, 2, float> mom = inputs[2].FlatTo2D<xpu, float>(s);
1153 Tensor<xpu, 2, float> weight32 = inputs[3].FlatTo2D<xpu, float>(s);
1154 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1155 Kernel<MP_NAGMomKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_,
1156 mom.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_,
1157 param.clip_gradient, param.momentum, param.lr, param.wd,
1158 param.rescale_grad, req[0]);
1159 });
1160 }
1161
1162
/*!
 * \brief Parameters read by FTMLUpdate.
 *
 * NOTE(review): the clip_grad description string refers to "clip_gradient", and
 * FTMLKernel clips whenever clip_grad >= 0, so a value of exactly 0 still clips
 * although the string says clipping is off for values <= 0.
 */
struct FTMLParam : public dmlc::Parameter<FTMLParam> {
  float lr;            // learning rate; no default is declared
  float beta1;         // default 0.6, range [0, 1] set below
  float beta2;         // default 0.999, range [0, 1] set below
  double epsilon;      // declared double, default given as the float literal 1e-8f
  int t;               // number of updates; no default is declared
  float wd;            // weight decay coefficient, default 0
  float rescale_grad;  // factor applied to the gradient, default 1
  float clip_grad;     // clipping bound, default -1
  DMLC_DECLARE_PARAMETER(FTMLParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate.");
    DMLC_DECLARE_FIELD(beta1)
    .set_default(0.6f)
    .set_range(0.0f, 1.0f)
    .describe("Generally close to 0.5.");
    DMLC_DECLARE_FIELD(beta2)
    .set_default(0.999f)
    .set_range(0.0f, 1.0f)
    .describe("Generally close to 1.");
    DMLC_DECLARE_FIELD(epsilon)
    .set_default(1e-8f)
    .describe("Epsilon to prevent div 0.");
    DMLC_DECLARE_FIELD(t)
    .describe("Number of update.");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_grad)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};
1203
1204
1205 struct FTMLKernel {
1206 template<typename DType>
1207 MSHADOW_XINLINE static void Map(index_t i, DType* out, DType* weight, DType* grad,
1208 DType* d, DType* v, DType* z, const DType lr, const DType beta1,
1209 const DType beta2, const DType epsilon, const DType t,
1210 const DType wd, const DType rescale_grad, const DType clip_grad,
1211 const OpReqType req) {
1212 using namespace mshadow_op;
1213 const DType grad_i = clip_grad >= 0.0f
1214 ? clip::Map(rescale_grad * grad[i] + wd * weight[i], clip_grad)
1215 : (rescale_grad * grad[i] + wd * weight[i]);
1216 v[i] = beta2 * v[i] + (1 - beta2) * square::Map(grad_i);
1217 const DType d_t = (1 - power::Map(beta1, t)) / lr *
1218 (square_root::Map(v[i] / (1 - power::Map(beta2, t))) + epsilon);
1219 z[i] = beta1 * z[i] + (1 - beta1) * grad_i - (d_t - beta1 * d[i]) * weight[i];
1220 d[i] = d_t;
1221 KERNEL_ASSIGN(out[i], req, - z[i] / d_t);
1222 }
1223 };
1224
1225
1226 template<typename xpu>
1227 inline void FTMLUpdate(const nnvm::NodeAttrs& attrs,
1228 const OpContext &ctx,
1229 const std::vector<TBlob> &inputs,
1230 const std::vector<OpReqType> &req,
1231 const std::vector<TBlob> &outputs) {
1232 using namespace mxnet_op;
1233 FTMLParam param = nnvm::get<FTMLParam>(attrs.parsed);
1234 Stream<xpu>* s = ctx.get_stream<xpu>();
1235 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1236 DType* weight_data = inputs[0].dptr<DType>();
1237 DType* grad_data = inputs[1].dptr<DType>();
1238 DType* d_data = inputs[2].dptr<DType>();
1239 DType* v_data = inputs[3].dptr<DType>();
1240 DType* z_data = inputs[4].dptr<DType>();
1241 DType* out_data = outputs[0].dptr<DType>();
1242 Kernel<FTMLKernel, xpu>::Launch(s, inputs[0].shape_.Size(), out_data,
1243 weight_data, grad_data, d_data, v_data, z_data, static_cast<DType>(param.lr),
1244 static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
1245 static_cast<DType>(param.epsilon), static_cast<DType>(param.t), static_cast<DType>(param.wd),
1246 static_cast<DType>(param.rescale_grad), static_cast<DType>(param.clip_grad),
1247 req[0]);
1248 });
1249 }
1250
/*!
 * \brief Parameters read by the Adam update operators (AdamUpdate, AdamUpdateEx and
 *        the row_sparse implementations they call).
 *
 * NOTE(review): the Adam kernels clip whenever clip_gradient >= 0, so a value of
 * exactly 0 still clips although the description string says clipping is off for
 * clip_gradient <= 0.
 */
struct AdamParam : public dmlc::Parameter<AdamParam> {
  float lr;             // learning rate; no default is declared
  float beta1;          // decay rate of the 1st moment estimate, default 0.9
  float beta2;          // decay rate of the 2nd moment estimate, default 0.999
  float epsilon;        // added to sqrt(var) in the denominator, default 1e-8
  float wd;             // weight decay coefficient, default 0
  float rescale_grad;   // factor applied to the gradient, default 1
  float clip_gradient;  // clipping bound, default -1
  bool lazy_update;     // selects the lazy path in AdamUpdateEx, default true
  DMLC_DECLARE_PARAMETER(AdamParam) {
    DMLC_DECLARE_FIELD(lr)
    .describe("Learning rate");
    DMLC_DECLARE_FIELD(beta1)
    .set_default(0.9f)
    .describe("The decay rate for the 1st moment estimates.");
    DMLC_DECLARE_FIELD(beta2)
    .set_default(0.999f)
    .describe("The decay rate for the 2nd moment estimates.");
    DMLC_DECLARE_FIELD(epsilon)
    .set_default(1e-8f)
    .describe("A small constant for numerical stability.");
    DMLC_DECLARE_FIELD(wd)
    .set_default(0.0f)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
    DMLC_DECLARE_FIELD(lazy_update)
    .set_default(true)
    .describe("If true, lazy updates are applied if gradient's stype is row_sparse "
              "and all of w, m and v have the same stype");
  }
};
1291
1292 struct AdamUpdateKernel {
1293 template<typename DType>
1294 MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1295 DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
1296 const DType clip_gradient, const DType rescale_grad,
1297 const DType beta1, const DType beta2,
1298 const DType lr, const DType wd,
1299 const DType epsilon, const OpReqType req) {
1300 using namespace mshadow_op;
1301
1302 DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
1303 if (clip_gradient >= 0.f) {
1304 grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
1305 }
1306
1307 mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
1308 var_data[i] = beta2 * var_data[i] +
1309 (1.f - beta2) * grad_rescaled * grad_rescaled;
1310
1311 KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * mean_data[i] /
1312 (square_root::Map(var_data[i]) + epsilon));
1313 }
1314 };
1315
1316 template<typename xpu>
1317 inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
1318 const OpContext &ctx,
1319 const std::vector<TBlob> &inputs,
1320 const std::vector<OpReqType> &req,
1321 const std::vector<TBlob> &outputs) {
1322 using namespace mxnet_op;
1323 const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
1324 Stream<xpu>* s = ctx.get_stream<xpu>();
1325 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1326 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1327 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
1328 Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
1329 Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
1330 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1331
1332 Kernel<AdamUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
1333 out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
1334 static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
1335 static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
1336 static_cast<DType>(param.lr), static_cast<DType>(param.wd),
1337 static_cast<DType>(param.epsilon), req[0]);
1338 });
1339 }
1340
/*!
 * \brief Primary template of the lazy Adam kernel for dense weight/mean/var and a
 *        row_sparse gradient; specialized per device type below.
 * \tparam req write request type the kernel is instantiated for
 * \tparam xpu device type
 */
template<int req, typename xpu>
struct AdamDnsRspDnsKernel;
1343
1344 /*!
1345 * Note: this kernel performs sparse adam update. For each row-slice in row_sparse
1346 * gradient, it finds the corresponding elements in weight, mean and var and performs
1347 * the update.
1348 * The kernel assumes dense weight/mean/var, and row_sparse gradient
1349 */
1350 template<int req>
1351 struct AdamDnsRspDnsKernel<req, cpu> {
1352 template<typename DType, typename IType>
1353 MSHADOW_XINLINE static void Map(index_t i, const nnvm::dim_t row_length, DType* out_data,
1354 DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
1355 const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
1356 const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
1357 using nnvm::dim_t;
1358 using namespace mshadow_op;
1359 const dim_t row_offset = grad_idx[i] * row_length;
1360 for (dim_t j = 0; j < row_length; j++) {
1361 // index in data/mean/var
1362 const dim_t data_i = row_offset + j;
1363 // index in grad
1364 const dim_t grad_i = i * row_length + j;
1365 const DType grad_rescaled = grad_data[grad_i] * rescale_grad + weight_data[data_i] * wd;
1366 if (clip_gradient >= 0.0f) {
1367 mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
1368 clip::Map(grad_rescaled, clip_gradient);
1369 var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
1370 clip::Map(grad_rescaled, clip_gradient));
1371 } else {
1372 mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
1373 var_data[data_i] = beta2 * var_data[data_i] +
1374 (1.f - beta2) * grad_rescaled * grad_rescaled;
1375 }
1376 KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
1377 (square_root::Map(var_data[data_i]) + epsilon));
1378 }
1379 }
1380 };
1381
1382
1383 template<int req>
1384 struct AdamDnsRspDnsKernel<req, gpu> {
1385 template<typename DType, typename IType>
1386 MSHADOW_XINLINE static void Map(index_t i, const nnvm::dim_t row_length, DType* out_data,
1387 DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
1388 const DType* grad_data, const DType clip_gradient, const DType beta1, const DType beta2,
1389 const DType lr, const DType wd, const DType epsilon, const DType rescale_grad) {
1390 using nnvm::dim_t;
1391 using namespace mshadow_op;
1392 const dim_t row_id = i / row_length;
1393 const dim_t col_id = i % row_length;
1394 const dim_t row_offset = grad_idx[row_id] * row_length;
1395 // index in data/mean/var
1396 const dim_t data_i = row_offset + col_id;
1397 // index in grad
1398 DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[data_i] * wd;
1399 if (clip_gradient >= 0.0f) {
1400 grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
1401 }
1402 mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
1403 var_data[data_i] = beta2 * var_data[data_i] +
1404 (1.f - beta2) * grad_rescaled * grad_rescaled;
1405 KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
1406 (square_root::Map(var_data[data_i]) + epsilon));
1407 }
1408 };
1409
1410 /*
1411 * \brief lazy adam update for dense weight, dense states and rsp grad.
1412 */
1413 template<typename xpu>
1414 inline void AdamLazyUpdateDnsRspDnsImpl(const AdamParam& param,
1415 const OpContext& ctx,
1416 const TBlob& weight,
1417 const NDArray& grad,
1418 const TBlob& mean,
1419 const TBlob& var,
1420 const OpReqType& req,
1421 TBlob *out) {
1422 using namespace mxnet_op;
1423 using namespace rowsparse;
1424 Stream<xpu>* s = ctx.get_stream<xpu>();
1425 if (!grad.storage_initialized() || req == kNullOp) return;
1426 CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
1427 CHECK_GT(weight.shape_.Size(), 0);
1428 CHECK_GT(mean.shape_.Size(), 0);
1429 CHECK_GT(var.shape_.Size(), 0);
1430
1431 MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
1432 MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
1433 MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
1434 const DType* weight_data = weight.dptr<DType>();
1435 const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
1436 const DType* grad_val = grad.data().dptr<DType>();
1437 DType* mean_data = mean.dptr<DType>();
1438 DType* var_data = var.dptr<DType>();
1439 DType* out_data = out->dptr<DType>();
1440 nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
1441 const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
1442 size_t num_threads = num_rows;
1443 if (std::is_same<xpu, gpu>::value) {
1444 num_threads = num_rows * row_length;
1445 }
1446 Kernel<AdamDnsRspDnsKernel<req_type, xpu>, xpu>::Launch(s, num_threads,
1447 row_length, out_data, mean_data, var_data, weight_data, grad_idx, grad_val,
1448 static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
1449 static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
1450 static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
1451 static_cast<DType>(param.rescale_grad));
1452 });
1453 });
1454 });
1455 }
1456
1457 /*
1458 * \brief lazy adam update for both row_sparse and dense weight.
1459 * grad is expected to be row_sparse.
1460 */
1461 template<typename xpu>
1462 inline void AdamLazyUpdateRspImpl(const AdamParam& param,
1463 const OpContext& ctx,
1464 const NDArray& weight,
1465 const NDArray& grad,
1466 const NDArray& mean,
1467 const NDArray& var,
1468 const OpReqType& req,
1469 NDArray *out) {
1470 using namespace mxnet_op;
1471 using namespace rowsparse;
1472 CheckAllRowsPresent(weight, "AdamUpdate", "weights");
1473 Stream<xpu>* s = ctx.get_stream<xpu>();
1474 // fill mean and variance with zero values in order to reuse the sgd mom dns impl
1475 if (mean.storage_type() == kRowSparseStorage && !mean.storage_initialized()) {
1476 NDArray mean_zeros = mean;
1477 FillDnsZerosRspImpl(s, &mean_zeros);
1478 }
1479 if (var.storage_type() == kRowSparseStorage && !var.storage_initialized()) {
1480 NDArray var_zeros = var;
1481 FillDnsZerosRspImpl(s, &var_zeros);
1482 }
1483 TBlob out_blob = out->data();
1484 // reuse dns rsp implementation when storage_shape == shape
1485 AdamLazyUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
1486 var.data(), req, &out_blob);
1487 }
1488
/*!
 * \brief kernel for standard adam update for dense weight, row_sparse grad and dense states.
 *
 * Only the primary template is declared here.
 * NOTE(review): the per-device specializations are presumably defined in the
 * matching source files - confirm.
 *
 * \tparam req write request type the kernel is instantiated for
 * \tparam xpu device type
 */
template<int req, typename xpu>
struct AdamStdDnsRspDnsKernel;
1494
/*!
 * \brief standard adam update for dense weight, row_sparse grad and dense states.
 *
 * Declaration only; the definition is not part of this section of the file.
 *
 * \param param  Adam hyper-parameters
 * \param ctx    operator context
 * \param weight dense weight blob
 * \param grad   row_sparse gradient
 * \param mean   dense 1st moment blob
 * \param var    dense 2nd moment blob
 * \param req    write request for the output
 * \param out    blob receiving the updated weight
 */
template<typename xpu>
void AdamStdUpdateDnsRspDnsImpl(const AdamParam& param,
                                const OpContext& ctx,
                                const TBlob& weight,
                                const NDArray& grad,
                                const TBlob& mean,
                                const TBlob& var,
                                const OpReqType& req,
                                TBlob *out);
1507
1508 /*
1509 * \brief standard adam update for both row_sparse and dense weight.
1510 * states are expected to be dense, while grad is expected to be row_sparse.
1511 */
1512 template<typename xpu>
1513 inline void AdamStdUpdateRspImpl(const AdamParam& param,
1514 const OpContext& ctx,
1515 const NDArray& weight,
1516 const NDArray& grad,
1517 const NDArray& mean,
1518 const NDArray& var,
1519 const OpReqType& req,
1520 NDArray *out) {
1521 using namespace mxnet_op;
1522 using namespace rowsparse;
1523 CheckAllRowsPresent(weight, "AdamStdUpdate", "weights");
1524 TBlob out_blob = out->data();
1525 // reuse dns rsp implementation when storage_shape == shape
1526 AdamStdUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, mean.data(),
1527 var.data(), req, &out_blob);
1528 }
1529
1530 template<typename xpu>
1531 inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
1532 const OpContext &ctx,
1533 const std::vector<NDArray> &inputs,
1534 const std::vector<OpReqType> &req,
1535 const std::vector<NDArray> &outputs) {
1536 const AdamParam& param = nnvm::get<AdamParam>(attrs.parsed);
1537 const auto w_stype = inputs[0].storage_type();
1538 const auto g_stype = inputs[1].storage_type();
1539 const auto m_stype = inputs[2].storage_type();
1540 const auto v_stype = inputs[3].storage_type();
1541 const auto out_stype = outputs[0].storage_type();
1542 NDArray out = outputs[0];
1543 const bool valid_weight = w_stype == kDefaultStorage || w_stype == kRowSparseStorage;
1544 CHECK(w_stype == out_stype) << "Inconsistent weight stype and output stype";
1545 CHECK(m_stype == v_stype) << "Inconsistent mean stype and var stype";
1546 if (valid_weight && g_stype == kRowSparseStorage && m_stype == w_stype) {
1547 if (param.lazy_update) {
1548 // rsp grad && m.stype = w.stype && lazy_update = true -> lazy update
1549 AdamLazyUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
1550 inputs[3], req[0], &out);
1551 } else {
1552 // rsp grad && m.stype = w.stype && lazy_update = false -> std update
1553 AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
1554 inputs[3], req[0], &out);
1555 }
1556 } else if (w_stype == kRowSparseStorage && g_stype == kRowSparseStorage &&
1557 m_stype == kDefaultStorage) {
1558 // rsp, rsp, dns, dns -> rsp, standard update
1559 AdamStdUpdateRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
1560 inputs[3], req[0], &out);
1561 } else {
1562 LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
1563 }
1564 }
1565
/*!
 * \brief Parameters of the first phase of the LAMB update.
 *
 * NOTE(review): LambUpdatePhaseOneKernel clips whenever clip_gradient >= 0, so a
 * value of exactly 0 still clips although the description string says clipping is
 * off for clip_gradient <= 0.
 */
struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
  float beta1;           // decay rate of the 1st moment estimate, default 0.9
  float beta2;           // decay rate of the 2nd moment estimate, default 0.999
  float epsilon;         // numerical stability constant, default 1e-6
  int t;                 // update count; no default is declared
  bool bias_correction;  // default true
  float wd;              // weight decay coefficient; no default is declared
  float rescale_grad;    // factor applied to the gradient, default 1
  float clip_gradient;   // clipping bound, default -1
  DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
    DMLC_DECLARE_FIELD(beta1)
    .set_default(0.9f)
    .describe("The decay rate for the 1st moment estimates.");
    DMLC_DECLARE_FIELD(beta2)
    .set_default(0.999f)
    .describe("The decay rate for the 2nd moment estimates.");
    DMLC_DECLARE_FIELD(epsilon)
    .set_default(1e-6f)
    .describe("A small constant for numerical stability.");
    DMLC_DECLARE_FIELD(t)
    .describe("Index update count.");
    DMLC_DECLARE_FIELD(bias_correction)
    .set_default(true)
    .describe("Whether to use bias correction.");
    DMLC_DECLARE_FIELD(wd)
    .describe("Weight decay augments the objective function with a "
              "regularization term that penalizes large weights. "
              "The penalty scales with the square of the magnitude of each weight.");
    DMLC_DECLARE_FIELD(rescale_grad)
    .set_default(1.0f)
    .describe("Rescale gradient to grad = rescale_grad*grad.");
    DMLC_DECLARE_FIELD(clip_gradient)
    .set_default(-1.0f)
    .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
              "If clip_gradient <= 0, gradient clipping is turned off. "
              "grad = max(min(grad, clip_gradient), -clip_gradient).");
  }
};
1604
1605 struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
1606 float lr;
1607 float lower_bound;
1608 float upper_bound;
1609 DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
1610 DMLC_DECLARE_FIELD(lr)
1611 .describe("Learning rate");
1612 DMLC_DECLARE_FIELD(lower_bound)
1613 .set_default(-1.0f)
1614 .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set");
1615 DMLC_DECLARE_FIELD(upper_bound)
1616 .set_default(-1.0f)
1617 .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set");
1618 }
1619 };
1620
1621 struct LambUpdatePhaseOneKernel {
1622 template<typename DType>
1623 MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1624 DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
1625 const DType clip_gradient, const DType rescale_grad,
1626 const DType beta1, const DType beta1_t, const DType beta2, const DType beta2_t,
1627 const DType wd, const DType epsilon, const int t,
1628 bool bias_correction, const OpReqType req) {
1629 using namespace mshadow_op;
1630
1631 DType grad_rescaled = grad_data[i] * rescale_grad;
1632 if (clip_gradient >= 0.f) {
1633 grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
1634 }
1635
1636 mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
1637 var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
1638
1639 DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];
1640
1641 if (bias_correction) {
1642 DType mean_hat = mean_data[i] / (1. - beta1_t);
1643 DType var_hat = var_data[i] / (1 - beta2_t);
1644 g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
1645 }
1646 KERNEL_ASSIGN(out_data[i], req, g);
1647 }
1648 };
1649
1650 template<typename xpu>
1651 inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
1652 const OpContext &ctx,
1653 const std::vector<TBlob> &inputs,
1654 const std::vector<OpReqType> &req,
1655 const std::vector<TBlob> &outputs) {
1656 using namespace mxnet_op;
1657 const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
1658 Stream<xpu>* s = ctx.get_stream<xpu>();
1659 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1660 DType beta1_t = std::pow(param.beta1, param.t);
1661 DType beta2_t = std::pow(param.beta2, param.t);
1662 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1663 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
1664 Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
1665 Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
1666 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1667
1668 Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
1669 out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
1670 static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
1671 static_cast<DType>(param.beta1), beta1_t, static_cast<DType>(param.beta2), beta2_t,
1672 static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
1673 static_cast<int>(param.t), static_cast<bool>(param.bias_correction), req[0]);
1674 });
1675 }
1676
1677 inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
1678 mxnet::ShapeVector* in_attrs,
1679 mxnet::ShapeVector* out_attrs) {
1680 CHECK_EQ(in_attrs->size(), 4U);
1681 CHECK_EQ(out_attrs->size(), 1U);
1682
1683 mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
1684
1685 mxnet::TShape& weight_shape = in_attrs->at(0);
1686 mxnet::TShape& g_shape = in_attrs->at(1);
1687 CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
1688 << "total no. of dimensions for weights and g must match";
1689 for (int i=0; i < weight_shape.ndim(); ++i) {
1690 CHECK_EQ(weight_shape[i], g_shape[i])
1691 << "weight and g dimension size mismatch at " << i << "-th index";
1692 }
1693 mxnet::TShape& r1_shape = in_attrs->at(2);
1694 mxnet::TShape& r2_shape = in_attrs->at(3);
1695 CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
1696 CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
1697 for (int i=0; i < expected_out.ndim(); ++i) {
1698 expected_out[i] = weight_shape[i];
1699 }
1700
1701 SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
1702 return shape_is_known(expected_out);
1703 }
1704
1705 struct LambUpdatePhaseTwoKernel {
1706 template<typename DType>
1707 MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1708 const DType* weight_data, const DType* g,
1709 const DType* r1, const DType* r2,
1710 DType lr, const DType lower_bound,
1711 const DType upper_bound, const OpReqType req) {
1712 using namespace mshadow_op;
1713
1714 DType new_r1 = r1[0];
1715 if (lower_bound >= 0) {
1716 new_r1 = maximum::Map(new_r1, lower_bound);
1717 }
1718 if (upper_bound >= 0) {
1719 new_r1 = minimum::Map(new_r1, upper_bound);
1720 }
1721 if (new_r1 == 0.0f || r2[0] == 0.0f) {
1722 lr = lr * 1.0f;
1723 } else {
1724 lr = lr * new_r1 / r2[0];
1725 }
1726
1727 KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
1728 }
1729 };
1730
1731 template<typename xpu>
1732 inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
1733 const OpContext &ctx,
1734 const std::vector<TBlob> &inputs,
1735 const std::vector<OpReqType> &req,
1736 const std::vector<TBlob> &outputs) {
1737 using namespace mxnet_op;
1738 const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
1739 Stream<xpu>* s = ctx.get_stream<xpu>();
1740 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1741 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1742 Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
1743 Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
1744 Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
1745 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1746
1747 Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
1748 out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
1749 static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
1750 static_cast<DType>(param.upper_bound), req[0]);
1751 });
1752 }
1753
1754 template<int n_in, int n_out, int total_in>
1755 inline bool MPLambPhaseOneType(const nnvm::NodeAttrs& attrs,
1756 std::vector<int> *in_attrs,
1757 std::vector<int> *out_attrs) {
1758 CHECK_EQ(in_attrs->size(), static_cast<size_t>(total_in)) << " in operator " << attrs.name;
1759 CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
1760 for (int i = 0; i < n_in; ++i) {
1761 TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat16);
1762 }
1763 for (int i = n_in; i < total_in; ++i) {
1764 TYPE_ASSIGN_CHECK(*in_attrs, i, mshadow::kFloat32);
1765 }
1766 for (int i = 0; i < n_out; ++i) {
1767 TYPE_ASSIGN_CHECK(*out_attrs, i, mshadow::kFloat32);
1768 }
1769 return true;
1770 }
1771
1772 struct MPLambUpdatePhaseOneKernel {
1773 template<typename DType>
1774 MSHADOW_XINLINE static void Map(index_t i, float* out_data,
1775 float* mean_data, float* var_data, const DType* weight_data,
1776 const DType* grad_data, const float* weight32_data,
1777 const float clip_gradient, const float rescale_grad,
1778 const float beta1_t, const float beta1,
1779 const float beta2_t, const float beta2,
1780 const float wd, const float epsilon, const int t,
1781 bool bias_correction, const OpReqType req) {
1782 using namespace mshadow_op;
1783
1784 float grad_rescaled = grad_data[i] * rescale_grad;
1785 if (clip_gradient >= 0.f) {
1786 grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
1787 }
1788
1789 mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
1790 var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;
1791
1792 float g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight32_data[i];
1793
1794 if (bias_correction) {
1795 float mean_hat = mean_data[i] / (1. - beta1_t);
1796 float var_hat = var_data[i] / (1 - beta2_t);
1797 g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight32_data[i];
1798 }
1799 KERNEL_ASSIGN(out_data[i], req, g);
1800 }
1801 };
1802
1803 template<typename xpu>
1804 inline void MPLambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
1805 const OpContext &ctx,
1806 const std::vector<TBlob> &inputs,
1807 const std::vector<OpReqType> &req,
1808 const std::vector<TBlob> &outputs) {
1809 using namespace mxnet_op;
1810 const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
1811 Stream<xpu>* s = ctx.get_stream<xpu>();
1812 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1813 float beta1_t = std::pow(param.beta1, param.t);
1814 float beta2_t = std::pow(param.beta2, param.t);
1815 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1816 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
1817 Tensor<xpu, 2, float> mean = inputs[2].FlatTo2D<xpu, float>(s);
1818 Tensor<xpu, 2, float> var = inputs[3].FlatTo2D<xpu, float>(s);
1819 Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
1820 Tensor<xpu, 2, float> out = outputs[0].FlatTo2D<xpu, float>(s);
1821
1822 Kernel<MPLambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
1823 out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_,
1824 param.clip_gradient, param.rescale_grad, beta1_t, param.beta1, beta2_t, param.beta2,
1825 param.wd, param.epsilon, param.t, param.bias_correction, req[0]);
1826 });
1827 }
1828
1829 inline bool MPLambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
1830 mxnet::ShapeVector* in_attrs,
1831 mxnet::ShapeVector* out_attrs) {
1832 CHECK_EQ(in_attrs->size(), 5U);
1833 CHECK_EQ(out_attrs->size(), 1U);
1834
1835 mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
1836
1837 mxnet::TShape& weight_shape = in_attrs->at(0);
1838 mxnet::TShape& g_shape = in_attrs->at(1);
1839 mxnet::TShape& weight32_shape = in_attrs->at(4);
1840 CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
1841 << "total no. of dimensions for weights and g must match";
1842 CHECK_EQ(weight_shape.ndim(), weight32_shape.ndim())
1843 << "total no. of dimensions for weights and g must match";
1844 for (int i=0; i < weight_shape.ndim(); ++i) {
1845 CHECK_EQ(weight_shape[i], g_shape[i])
1846 << "weight and g dimension size mismatch at " << i << "-th index";
1847 CHECK_EQ(weight_shape[i], weight32_shape[i])
1848 << "weight and g dimension size mismatch at " << i << "-th index";
1849 }
1850 mxnet::TShape& r1_shape = in_attrs->at(2);
1851 mxnet::TShape& r2_shape = in_attrs->at(3);
1852 CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
1853 CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
1854 for (int i=0; i < expected_out.ndim(); ++i) {
1855 expected_out[i] = weight_shape[i];
1856 }
1857
1858 SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
1859 return shape_is_known(expected_out);
1860 }
1861
1862 struct MPLambUpdatePhaseTwoKernel {
1863 template<typename DType>
1864 MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1865 const DType* weight_data, const float* g,
1866 const float* r1, const float* r2, const float* weight32_data,
1867 float lr, const float lower_bound,
1868 const float upper_bound, const OpReqType req) {
1869 using namespace mshadow_op;
1870
1871 float new_r1 = r1[0];
1872 if (lower_bound >= 0) {
1873 new_r1 = maximum::Map(new_r1, lower_bound);
1874 }
1875 if (upper_bound >= 0) {
1876 new_r1 = minimum::Map(new_r1, upper_bound);
1877 }
1878 if (new_r1 == 0.0f || r2[0] == 0.0f) {
1879 lr = lr * 1.0f;
1880 } else {
1881 lr = lr * new_r1 / r2[0];
1882 }
1883
1884 KERNEL_ASSIGN(out_data[i], req, weight32_data[i] - lr * g[i]);
1885 }
1886 };
1887
1888 template<typename xpu>
1889 inline void MPLambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
1890 const OpContext &ctx,
1891 const std::vector<TBlob> &inputs,
1892 const std::vector<OpReqType> &req,
1893 const std::vector<TBlob> &outputs) {
1894 using namespace mxnet_op;
1895 const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
1896 Stream<xpu>* s = ctx.get_stream<xpu>();
1897 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1898 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
1899 Tensor<xpu, 2, float> g = inputs[1].FlatTo2D<xpu, float>(s);
1900 Tensor<xpu, 2, float> r1 = inputs[2].FlatTo2D<xpu, float>(s);
1901 Tensor<xpu, 2, float> r2 = inputs[3].FlatTo2D<xpu, float>(s);
1902 Tensor<xpu, 2, float> weight32 = inputs[4].FlatTo2D<xpu, float>(s);
1903 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
1904
1905 Kernel<MPLambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
1906 out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_, weight32.dptr_,
1907 param.lr, param.lower_bound,
1908 param.upper_bound, req[0]);
1909 });
1910 }
1911
1912 // This RMSProp code follows the version in
1913 // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
1914 // by Alex Graves, 2013.
1915 struct RMSPropAlexParam : public dmlc::Parameter<RMSPropAlexParam> {
1916 float lr;
1917 float gamma1;
1918 float gamma2;
1919 float epsilon;
1920 float wd;
1921 float rescale_grad;
1922 float clip_gradient;
1923 float clip_weights;
1924 DMLC_DECLARE_PARAMETER(RMSPropAlexParam) {
1925 DMLC_DECLARE_FIELD(lr)
1926 .describe("Learning rate");
1927 DMLC_DECLARE_FIELD(gamma1).set_default(0.95f)
1928 .describe("Decay rate.");
1929 DMLC_DECLARE_FIELD(gamma2).set_default(0.9f)
1930 .describe("Decay rate.");
1931 DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f)
1932 .describe("A small constant for numerical stability.");
1933 DMLC_DECLARE_FIELD(wd).set_default(0.0f)
1934 .describe("Weight decay augments the objective function with a "
1935 "regularization term that penalizes large weights. "
1936 "The penalty scales with the square of the magnitude of each weight.");
1937 DMLC_DECLARE_FIELD(rescale_grad)
1938 .set_default(1.0f)
1939 .describe("Rescale gradient to grad = rescale_grad*grad.");
1940 DMLC_DECLARE_FIELD(clip_gradient)
1941 .set_default(-1.0f)
1942 .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
1943 "If clip_gradient <= 0, gradient clipping is turned off. "
1944 "grad = max(min(grad, clip_gradient), -clip_gradient).");
1945 DMLC_DECLARE_FIELD(clip_weights)
1946 .set_default(-1.0f)
1947 .describe("Clip weights to the range of [-clip_weights, clip_weights] "
1948 "If clip_weights <= 0, weight clipping is turned off. "
1949 "weights = max(min(weights, clip_weights), -clip_weights).");
1950 }
1951 };
1952
1953 struct RMSPropAlexUpdateKernel {
1954 template<typename DType>
1955 MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
1956 DType* state_n_data, DType* state_g_data, DType* delta_data,
1957 const DType* weight_data, const DType* grad_data,
1958 const DType clip_gradient, const DType rescale_grad,
1959 const DType gamma1, const DType gamma2,
1960 const DType lr, const DType wd,
1961 const DType clip_weights, const DType epsilon,
1962 const OpReqType req) {
1963 using namespace mshadow_op;
1964
1965 DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i];
1966 if (clip_gradient >= 0.0f) {
1967 grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
1968 }
1969
1970 state_n_data[i] = (1.f - gamma1) * grad_rescaled * grad_rescaled +
1971 gamma1 * state_n_data[i];
1972 state_g_data[i] = (1.f - gamma1) * grad_rescaled +
1973 gamma1 * state_g_data[i];
1974 delta_data[i] = gamma2 * delta_data[i] -
1975 (lr * (grad_rescaled) /
1976 (square_root::Map(state_n_data[i] -
1977 state_g_data[i] * state_g_data[i] + epsilon)));
1978
1979 if (clip_weights >= 0.0f) {
1980 const DType clipped_weight = clip::Map(weight_data[i] + delta_data[i], clip_weights);
1981 KERNEL_ASSIGN(out_data[i], req, clipped_weight);
1982 } else {
1983 KERNEL_ASSIGN(out_data[i], req, weight_data[i] + delta_data[i]);
1984 }
1985 }
1986 };
1987
1988 template <typename xpu>
1989 inline void RMSPropAlexUpdate(const nnvm::NodeAttrs &attrs,
1990 const OpContext &ctx,
1991 const std::vector<TBlob> &inputs,
1992 const std::vector<OpReqType> &req,
1993 const std::vector<TBlob> &outputs) {
1994 using namespace mxnet_op;
1995 const RMSPropAlexParam ¶m = nnvm::get<RMSPropAlexParam>(attrs.parsed);
1996 Stream<xpu> *s = ctx.get_stream<xpu>();
1997 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
1998 DType* weight_data = inputs[0].dptr<DType>();
1999 DType* grad_data = inputs[1].dptr<DType>();
2000 DType* state_n_data = inputs[2].dptr<DType>();
2001 DType* state_g_data = inputs[3].dptr<DType>();
2002 DType* delta_data = inputs[4].dptr<DType>();
2003 DType* out_data = outputs[0].dptr<DType>();
2004
2005 Kernel<RMSPropAlexUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(),
2006 out_data, state_n_data, state_g_data, delta_data, weight_data, grad_data,
2007 static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
2008 static_cast<DType>(param.gamma1), static_cast<DType>(param.gamma2),
2009 static_cast<DType>(param.lr), static_cast<DType>(param.wd),
2010 static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]);
2011 });
2012 }
2013
2014 // This RMSProp code follows the version in
2015 // http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
2016 // by Tieleman & Hinton, 2012
2017 struct RMSPropParam : public dmlc::Parameter<RMSPropParam> {
2018 float lr;
2019 float gamma1;
2020 float epsilon;
2021 float wd;
2022 float rescale_grad;
2023 float clip_gradient;
2024 float clip_weights;
2025 DMLC_DECLARE_PARAMETER(RMSPropParam) {
2026 DMLC_DECLARE_FIELD(lr)
2027 .describe("Learning rate");
2028 DMLC_DECLARE_FIELD(gamma1).set_default(0.95f)
2029 .describe("The decay rate of momentum estimates.");
2030 DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f)
2031 .describe("A small constant for numerical stability.");
2032 DMLC_DECLARE_FIELD(wd).set_default(0.0f)
2033 .describe("Weight decay augments the objective function with a "
2034 "regularization term that penalizes large weights. "
2035 "The penalty scales with the square of the magnitude of each weight.");
2036 DMLC_DECLARE_FIELD(rescale_grad)
2037 .set_default(1.0f)
2038 .describe("Rescale gradient to grad = rescale_grad*grad.");
2039 DMLC_DECLARE_FIELD(clip_gradient)
2040 .set_default(-1.0f)
2041 .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
2042 "If clip_gradient <= 0, gradient clipping is turned off. "
2043 "grad = max(min(grad, clip_gradient), -clip_gradient).");
2044 DMLC_DECLARE_FIELD(clip_weights)
2045 .set_default(-1.0f)
2046 .describe("Clip weights to the range of [-clip_weights, clip_weights] "
2047 "If clip_weights <= 0, weight clipping is turned off. "
2048 "weights = max(min(weights, clip_weights), -clip_weights).");
2049 }
2050 };
2051
2052 struct RMSPropUpdateKernel {
2053 template<typename DType>
2054 MSHADOW_XINLINE static void Map(index_t i,
2055 DType* out_data, DType* state_n_data,
2056 const DType* weight_data, const DType* grad_data,
2057 const DType clip_gradient, const DType rescale_grad,
2058 const DType gamma1, const DType lr, const DType wd,
2059 const DType clip_weights, const DType epsilon,
2060 const OpReqType req) {
2061 using namespace mshadow_op;
2062
2063 DType grad_rescaled = rescale_grad * grad_data[i] + wd * weight_data[i];
2064 if (clip_gradient >= 0.0f) {
2065 grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
2066 }
2067
2068 state_n_data[i] = (1.f - gamma1) * (grad_rescaled * grad_rescaled) + gamma1 * state_n_data[i];
2069
2070 DType weight = weight_data[i] -
2071 lr * (grad_rescaled / square_root::Map(state_n_data[i] + epsilon));
2072 if (clip_weights >= 0.0f) {
2073 weight = clip::Map(weight, clip_weights);
2074 }
2075 KERNEL_ASSIGN(out_data[i], req, weight);
2076 }
2077 };
2078
2079 template <typename xpu>
2080 inline void RMSPropUpdate(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
2081 const std::vector<TBlob> &inputs,
2082 const std::vector<OpReqType> &req,
2083 const std::vector<TBlob> &outputs) {
2084 using namespace mxnet_op;
2085 const RMSPropParam ¶m = nnvm::get<RMSPropParam>(attrs.parsed);
2086 Stream<xpu> *s = ctx.get_stream<xpu>();
2087 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
2088 DType* weight_data = inputs[0].dptr<DType>();
2089 DType* grad_data = inputs[1].dptr<DType>();
2090 DType* state_n_data = inputs[2].dptr<DType>();
2091 DType* out_data = outputs[0].dptr<DType>();
2092
2093 Kernel<RMSPropUpdateKernel, xpu>::Launch(s, inputs[0].shape_.Size(),
2094 out_data, state_n_data, weight_data, grad_data,
2095 static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
2096 static_cast<DType>(param.gamma1), static_cast<DType>(param.lr), static_cast<DType>(param.wd),
2097 static_cast<DType>(param.clip_weights), static_cast<DType>(param.epsilon), req[0]);
2098 });
2099 }
2100
2101 struct FtrlParam : public dmlc::Parameter<FtrlParam> {
2102 float lr;
2103 float lamda1;
2104 float beta;
2105 float wd;
2106 float rescale_grad;
2107 float clip_gradient;
2108 DMLC_DECLARE_PARAMETER(FtrlParam) {
2109 DMLC_DECLARE_FIELD(lr)
2110 .describe("Learning rate");
2111 DMLC_DECLARE_FIELD(lamda1)
2112 .set_default(0.01f)
2113 .describe("The L1 regularization coefficient.");
2114 DMLC_DECLARE_FIELD(beta)
2115 .set_default(1.0f)
2116 .describe("Per-Coordinate Learning Rate beta.");
2117 DMLC_DECLARE_FIELD(wd)
2118 .set_default(0.0f)
2119 .describe("Weight decay augments the objective function with a "
2120 "regularization term that penalizes large weights. "
2121 "The penalty scales with the square of the magnitude of each weight.");
2122 DMLC_DECLARE_FIELD(rescale_grad)
2123 .set_default(1.0f)
2124 .describe("Rescale gradient to grad = rescale_grad*grad.");
2125 DMLC_DECLARE_FIELD(clip_gradient)
2126 .set_default(-1.0f)
2127 .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
2128 "If clip_gradient <= 0, gradient clipping is turned off. "
2129 "grad = max(min(grad, clip_gradient), -clip_gradient).");
2130 }
2131 };
2132
2133 struct FtrlUpdateKernel {
2134 template<typename DType>
2135 MSHADOW_XINLINE static void Map(index_t i, DType* out_data,
2136 DType* n_data, DType* z_data, const DType* weight_data, const DType* grad_data,
2137 const DType clip_gradient, const DType rescale_grad,
2138 const DType beta, const DType lamda1,
2139 const DType lr, const DType wd,
2140 const OpReqType req) {
2141 using namespace mshadow_op;
2142
2143 DType grad_rescaled = grad_data[i] * rescale_grad;
2144 if (clip_gradient >= 0.0f) {
2145 grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
2146 }
2147
2148 z_data[i] += grad_rescaled - (square_root::Map(n_data[i] +
2149 square::Map(grad_rescaled)) - square_root::Map(n_data[i])) *
2150 weight_data[i] / lr;
2151 n_data[i] += square::Map(grad_rescaled);
2152
2153 KERNEL_ASSIGN(out_data[i], req,
2154 (sign::Map(z_data[i]) * lamda1 - z_data[i]) /
2155 ((beta + square_root::Map(n_data[i])) / lr + wd) *
2156 gt::Map(abs::Map(z_data[i]), lamda1));
2157 }
2158 };
2159
2160 template<typename xpu>
2161 inline void FtrlUpdate(const nnvm::NodeAttrs& attrs,
2162 const OpContext &ctx,
2163 const std::vector<TBlob> &inputs,
2164 const std::vector<OpReqType> &req,
2165 const std::vector<TBlob> &outputs) {
2166 using namespace mxnet_op;
2167
2168 const FtrlParam& param = nnvm::get<FtrlParam>(attrs.parsed);
2169 Stream<xpu>* s = ctx.get_stream<xpu>();
2170 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
2171 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
2172 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
2173 Tensor<xpu, 2, DType> z = inputs[2].FlatTo2D<xpu, DType>(s);
2174 Tensor<xpu, 2, DType> n = inputs[3].FlatTo2D<xpu, DType>(s);
2175 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
2176
2177 Kernel<FtrlUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
2178 out.dptr_, n.dptr_, z.dptr_, weight.dptr_, grad.dptr_,
2179 static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
2180 static_cast<DType>(param.beta), static_cast<DType>(param.lamda1),
2181 static_cast<DType>(param.lr), static_cast<DType>(param.wd), req[0]);
2182 });
2183 }
2184
2185 template<int req>
2186 struct FtrlDnsRspDnsKernel {
2187 template<typename DType, typename IType>
2188 MSHADOW_XINLINE static void Map(index_t i, const nnvm::dim_t row_length, DType* out_data,
2189 DType* z_data, DType* n_data, const DType* weight_data, const IType* grad_idx,
2190 const DType* grad_data, const DType clip_gradient, const DType lamda1, const DType beta,
2191 const DType lr, const DType wd, const DType rescale_grad) {
2192 using nnvm::dim_t;
2193 using namespace mshadow_op;
2194 const dim_t row_offset = grad_idx[i] * row_length;
2195 for (dim_t j = 0; j < row_length; j++) {
2196 // index in data/z/n
2197 const dim_t data_i = row_offset + j;
2198 // index in grad
2199 const dim_t grad_i = i * row_length + j;
2200 const DType grad_rescaled = grad_data[grad_i] * rescale_grad;
2201 if (clip_gradient >= 0.0f) {
2202 z_data[data_i] += clip::Map(grad_rescaled, clip_gradient) -
2203 (square_root::Map(n_data[data_i] +
2204 square::Map(clip::Map(grad_rescaled, clip_gradient))) -
2205 square_root::Map(n_data[data_i])) * weight_data[data_i] / lr;
2206 n_data[data_i] += square::Map(clip::Map(grad_rescaled, clip_gradient));
2207 } else {
2208 z_data[data_i] += grad_rescaled - (square_root::Map(n_data[data_i] +
2209 square::Map(grad_rescaled)) - square_root::Map(n_data[data_i])) *
2210 weight_data[data_i] / lr;
2211 n_data[data_i] += square::Map(grad_rescaled);
2212 }
2213 KERNEL_ASSIGN(out_data[data_i], req,
2214 (sign::Map(z_data[data_i]) * lamda1 - z_data[data_i]) /
2215 ((beta + square_root::Map(n_data[data_i])) / lr + wd) *
2216 gt::Map(abs::Map(z_data[data_i]), lamda1));
2217 }
2218 }
2219 };
2220
2221
2222 template<typename xpu>
2223 inline void FtrlUpdateDnsRspDnsImpl(const FtrlParam& param,
2224 const OpContext& ctx,
2225 const TBlob& weight,
2226 const NDArray& grad,
2227 const TBlob& z,
2228 const TBlob& n,
2229 const OpReqType& req,
2230 TBlob *out) {
2231 using namespace mxnet_op;
2232 using namespace rowsparse;
2233 Stream<xpu>* s = ctx.get_stream<xpu>();
2234 if (!grad.storage_initialized() || req == kNullOp) return;
2235 CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse ftrl_update";
2236 CHECK_GT(weight.shape_.Size(), 0);
2237 CHECK_GT(z.shape_.Size(), 0);
2238 CHECK_GT(n.shape_.Size(), 0);
2239
2240 MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
2241 MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
2242 MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
2243 const DType* weight_data = weight.dptr<DType>();
2244 const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
2245 const DType* grad_val = grad.data().dptr<DType>();
2246 DType* z_data = z.dptr<DType>();
2247 DType* n_data = n.dptr<DType>();
2248 DType* out_data = out->dptr<DType>();
2249 nnvm::dim_t num_rows = grad.aux_shape(kIdx)[0];
2250 const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
2251 Kernel<FtrlDnsRspDnsKernel<req_type>, xpu>::Launch(s, num_rows, row_length,
2252 out_data, z_data, n_data, weight_data, grad_idx, grad_val,
2253 static_cast<DType>(param.clip_gradient), static_cast<DType>(param.lamda1),
2254 static_cast<DType>(param.beta), static_cast<DType>(param.lr),
2255 static_cast<DType>(param.wd), static_cast<DType>(param.rescale_grad));
2256 });
2257 });
2258 });
2259 }
2260
2261 template<typename xpu>
2262 inline void FtrlUpdateRspRspRspImpl(const FtrlParam& param,
2263 const OpContext& ctx,
2264 const NDArray& weight,
2265 const NDArray& grad,
2266 const NDArray& z,
2267 const NDArray& n,
2268 const OpReqType& req,
2269 NDArray *out) {
2270 using namespace mshadow;
2271 using namespace mshadow::expr;
2272 using namespace mxnet_op;
2273 using namespace rowsparse;
2274 CheckAllRowsPresent(weight, "FtrlUpdate", "weights");
2275 Stream<xpu>* s = ctx.get_stream<xpu>();
2276 // fill z and n with zero values in order to reuse the sgd mom dns impl
2277 if (!z.storage_initialized()) {
2278 NDArray z_zeros = z;
2279 FillDnsZerosRspImpl(s, &z_zeros);
2280 }
2281 if (!n.storage_initialized()) {
2282 NDArray n_zeros = n;
2283 FillDnsZerosRspImpl(s, &n_zeros);
2284 }
2285 TBlob out_blob = out->data();
2286 // reuse dns rsp implementation when storage_shape == shape
2287 FtrlUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad, z.data(),
2288 n.data(), req, &out_blob);
2289 }
2290
2291 template<typename xpu>
2292 inline void FtrlUpdateEx(const nnvm::NodeAttrs& attrs,
2293 const OpContext &ctx,
2294 const std::vector<NDArray> &inputs,
2295 const std::vector<OpReqType> &req,
2296 const std::vector<NDArray> &outputs) {
2297 const FtrlParam& param = nnvm::get<FtrlParam>(attrs.parsed);
2298 const auto weight_stype = inputs[0].storage_type();
2299 const auto z_stype = inputs[2].storage_type();
2300 const auto n_stype = inputs[3].storage_type();
2301
2302 const auto out_stype = outputs[0].storage_type();
2303 CHECK_EQ(z_stype, weight_stype) << "Inconsistent storage type detected between "
2304 << " z.stype = " << z_stype << " and weight.stype = " << weight_stype;
2305 CHECK_EQ(n_stype, weight_stype) << "Inconsistent storage type detected between "
2306 << " n.stype = " << n_stype << " and weight.stype = " << weight_stype;
2307 if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) && out_stype == kRowSparseStorage) {
2308 NDArray out = outputs[0];
2309 FtrlUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
2310 inputs[3], req[0], &out);
2311 } else {
2312 LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
2313 }
2314 }
2315
2316
2317 // Implementation for signSGD and Signum
2318
2319 struct SignSGDParam : public dmlc::Parameter<SignSGDParam> {
2320 float lr;
2321 float wd;
2322 float rescale_grad;
2323 float clip_gradient;
2324 DMLC_DECLARE_PARAMETER(SignSGDParam) {
2325 DMLC_DECLARE_FIELD(lr)
2326 .describe("Learning rate");
2327 DMLC_DECLARE_FIELD(wd)
2328 .set_default(0.0f)
2329 .describe("Weight decay augments the objective function with a "
2330 "regularization term that penalizes large weights. "
2331 "The penalty scales with the square of the magnitude of each weight.");
2332 DMLC_DECLARE_FIELD(rescale_grad)
2333 .set_default(1.0f)
2334 .describe("Rescale gradient to grad = rescale_grad*grad.");
2335 DMLC_DECLARE_FIELD(clip_gradient)
2336 .set_default(-1.0f)
2337 .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
2338 "If clip_gradient <= 0, gradient clipping is turned off. "
2339 "grad = max(min(grad, clip_gradient), -clip_gradient).");
2340 }
2341 };
2342
2343
2344 struct SignSGDKernel {
2345 template<typename DType>
2346 MSHADOW_XINLINE static void Map(index_t i, DType* out_data, const DType* weight_data,
2347 const DType* grad_data, const DType param_clip_gradient,
2348 const DType param_lr, const DType param_wd, const DType param_rescale_grad,
2349 const OpReqType req) {
2350
2351 // param_clip_gradient has no effect for SignSGD
2352 KERNEL_ASSIGN(out_data[i], req,
2353 (1.f-param_lr*param_wd)*weight_data[i]
2354 - (param_lr)*((grad_data[i] > 0) - (grad_data[i] < 0)));
2355 }
2356 };
2357
2358 template<typename xpu>
2359 inline void SignSGDUpdate(const nnvm::NodeAttrs& attrs,
2360 const OpContext &ctx,
2361 const std::vector<TBlob> &inputs,
2362 const std::vector<OpReqType> &req,
2363 const std::vector<TBlob> &outputs) {
2364 using namespace mxnet_op;
2365 const SignSGDParam& param = nnvm::get<SignSGDParam>(attrs.parsed);
2366 Stream<xpu>* s = ctx.get_stream<xpu>();
2367 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
2368 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
2369 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
2370 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
2371 Kernel<SignSGDKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, weight.dptr_,
2372 grad.dptr_, static_cast<DType>(param.clip_gradient),
2373 static_cast<DType>(param.lr), static_cast<DType>(param.wd),
2374 static_cast<DType>(param.rescale_grad), req[0]);
2375 });
2376 }
2377
2378
2379 struct SignumParam : public dmlc::Parameter<SignumParam> {
2380 float lr;
2381 float momentum;
2382 float wd;
2383 float rescale_grad;
2384 float clip_gradient;
2385 float wd_lh; // the amount of algorithmic weight decay by Loshchilov and Frank Hutter
2386 DMLC_DECLARE_PARAMETER(SignumParam) {
2387 DMLC_DECLARE_FIELD(lr)
2388 .describe("Learning rate");
2389 DMLC_DECLARE_FIELD(momentum)
2390 .set_default(0.0f)
2391 .describe("The decay rate of momentum estimates at each epoch.");
2392 DMLC_DECLARE_FIELD(wd)
2393 .set_default(0.0f)
2394 .describe("Weight decay augments the objective function with a "
2395 "regularization term that penalizes large weights. "
2396 "The penalty scales with the square of the magnitude of each weight.");
2397 DMLC_DECLARE_FIELD(rescale_grad)
2398 .set_default(1.0f)
2399 .describe("Rescale gradient to grad = rescale_grad*grad.");
2400 DMLC_DECLARE_FIELD(clip_gradient)
2401 .set_default(-1.0f)
2402 .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
2403 "If clip_gradient <= 0, gradient clipping is turned off. "
2404 "grad = max(min(grad, clip_gradient), -clip_gradient).");
2405 DMLC_DECLARE_FIELD(wd_lh)
2406 .set_default(0.0f)
2407 .describe("The amount of weight decay that does not go into gradient/momentum calculations"
2408 "otherwise do weight decay algorithmically only.");
2409 }
2410 };
2411
2412 struct SignumKernel {
2413 template<typename DType>
2414 MSHADOW_XINLINE static void Map(index_t i, DType* out_data, DType* mom_data,
2415 const DType* weight_data, const DType* grad_data,
2416 const DType param_clip_gradient, const DType param_momentum,
2417 const DType param_lr, const DType param_wd,
2418 const DType param_rescale_grad, const DType param_wd_lh,
2419 const OpReqType req) {
2420 if (param_clip_gradient >= 0.0f) {
2421 mom_data[i] = param_momentum*mom_data[i]
2422 - (1-param_momentum)*param_wd*weight_data[i]
2423 - (1-param_momentum)
2424 *mshadow_op::clip::Map(param_rescale_grad*grad_data[i], param_clip_gradient);
2425 } else {
2426 mom_data[i] = param_momentum*mom_data[i]
2427 - (1-param_momentum)*param_wd*weight_data[i]
2428 - (1-param_momentum)*param_rescale_grad*grad_data[i];
2429 }
2430 KERNEL_ASSIGN(out_data[i], req, (1.f-param_lr*param_wd_lh)*weight_data[i]
2431 + (param_lr)*((mom_data[i] > 0) - (mom_data[i] < 0)));
2432 }
2433 };
2434
2435 template<typename xpu>
2436 inline void SignumUpdate(const nnvm::NodeAttrs& attrs,
2437 const OpContext &ctx,
2438 const std::vector<TBlob> &inputs,
2439 const std::vector<OpReqType> &req,
2440 const std::vector<TBlob> &outputs) {
2441 using namespace mxnet_op;
2442 SignumParam param = nnvm::get<SignumParam>(attrs.parsed);
2443 Stream<xpu>* s = ctx.get_stream<xpu>();
2444 MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
2445 Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
2446 Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
2447 Tensor<xpu, 2, DType> mom = inputs[2].FlatTo2D<xpu, DType>(s);
2448 Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
2449 Kernel<SignumKernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mom.dptr_, weight.dptr_,
2450 grad.dptr_, static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
2451 static_cast<DType>(param.lr), static_cast<DType>(param.wd),
2452 static_cast<DType>(param.rescale_grad), static_cast<DType>(param.wd_lh), req[0]);
2453 });
2454 }
2455
2456 struct AdagradParam : public dmlc::Parameter<AdagradParam> {
2457 float lr;
2458 float epsilon;
2459 float rescale_grad;
2460 float clip_gradient;
2461 float wd;
2462 DMLC_DECLARE_PARAMETER(AdagradParam) {
2463 DMLC_DECLARE_FIELD(lr)
2464 .describe("Learning rate");
2465 DMLC_DECLARE_FIELD(epsilon)
2466 .set_default(1.0e-7)
2467 .describe("epsilon");
2468 DMLC_DECLARE_FIELD(wd)
2469 .set_default(0.0f)
2470 .describe("weight decay");
2471 DMLC_DECLARE_FIELD(rescale_grad)
2472 .set_default(1.0f)
2473 .describe("Rescale gradient to grad = rescale_grad*grad.");
2474 DMLC_DECLARE_FIELD(clip_gradient)
2475 .set_default(-1.0f)
2476 .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
2477 "If clip_gradient <= 0, gradient clipping is turned off. "
2478 "grad = max(min(grad, clip_gradient), -clip_gradient).");
2479 }
2480 };
2481
2482 inline bool AdagradStorageType(const nnvm::NodeAttrs& attrs,
2483 const int dev_mask,
2484 DispatchMode* dispatch_mode,
2485 std::vector<int>* in_attrs,
2486 std::vector<int>* out_attrs) {
2487 const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
2488 CHECK_EQ(in_attrs->size(), 3U);
2489 CHECK_EQ(out_attrs->size(), 1U);
2490 const int weight_stype = in_attrs->at(0);
2491 const int grad_stype = in_attrs->at(1);
2492 const int state_stype = in_attrs->at(2);
2493 bool dispatched = false;
2494 if (!dispatched && grad_stype == kRowSparseStorage &&
2495 (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage) &&
2496 state_stype == weight_stype && param.wd == 0.0f) {
2497 // weight and state share stype, grad's stype = rsp
2498 dispatched = storage_type_assign(
2499 out_attrs, static_cast<NDArrayStorageType>(weight_stype), dispatch_mode,
2500 DispatchMode::kFComputeEx);
2501 }
2502 return dispatched;
2503 }
2504
2505 template<typename xpu>
2506 struct AdagradDnsRspDnsKernel;
2507
2508 template<>
2509 struct AdagradDnsRspDnsKernel<cpu> {
2510 template<typename DType, typename IType>
2511 MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
2512 DType* state_data, const DType* weight_data, const IType* grad_idx,
2513 const DType* grad_data, const DType clip_gradient, const DType epsilon,
2514 const DType lr, const DType rescale_grad) {
2515 using nnvm::dim_t;
2516 using namespace mshadow_op;
2517 const dim_t data_i = grad_idx[i] * row_length;
2518 const dim_t grad_i = i * row_length;
2519 for (dim_t j = 0; j < row_length; j++) {
2520 const dim_t data_j = data_i + j;
2521 const dim_t grad_j = grad_i + j;
2522 DType grad_rescaled = grad_data[grad_j] * rescale_grad;
2523 if (clip_gradient >= 0.0f) {
2524 grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
2525 }
2526 const DType grad_squared = grad_rescaled * grad_rescaled;
2527 state_data[data_j] += grad_squared;
2528 const DType div = grad_rescaled / square_root::Map(state_data[data_j] + epsilon);
2529 // No need to use KERNEL_ASSIGN, as we already checked req is kWriteInplace
2530 out_data[data_j] = weight_data[data_j] - div * lr;
2531 }
2532 }
2533 };
2534
2535 template<>
2536 struct AdagradDnsRspDnsKernel<gpu> {
2537 template<typename DType, typename IType>
2538 MSHADOW_XINLINE static void Map(index_t i, index_t row_length, DType* out_data,
2539 DType* state_data, const DType* weight_data, const IType* grad_idx,
2540 const DType* grad_data, const DType clip_gradient, const DType epsilon,
2541 const DType lr, const DType rescale_grad) {
2542 using nnvm::dim_t;
2543 using namespace mshadow_op;
2544 const dim_t row_id = i / row_length;
2545 const dim_t col_id = i % row_length;
2546 const dim_t data_i = grad_idx[row_id] * row_length + col_id;
2547 DType grad_rescaled = grad_data[i] * rescale_grad;
2548 if (clip_gradient >= 0.0f) {
2549 grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
2550 }
2551 const DType grad_squared = grad_rescaled * grad_rescaled;
2552 state_data[data_i] += grad_squared;
2553 const DType div = grad_rescaled / square_root::Map(state_data[data_i] + epsilon);
2554 // No need to use KERNEL_ASSIGN, as we already checked req is kWriteInplace
2555 out_data[data_i] = weight_data[data_i] - div * lr;
2556 }
2557 };
2558
2559 template<typename xpu>
2560 void AdagradUpdateDnsRspDnsImpl(const AdagradParam& param,
2561 const OpContext& ctx,
2562 const TBlob& weight,
2563 const NDArray& grad,
2564 const TBlob& state,
2565 const OpReqType& req,
2566 TBlob *out) {
2567 using namespace mxnet_op;
2568 using namespace rowsparse;
2569 using namespace mshadow;
2570 Stream<xpu>* s = ctx.get_stream<xpu>();
2571 CHECK_EQ(param.wd, 0.0f)
2572 << "sparse adagrad_update does not support wd.";
2573 if (req == kNullOp || !grad.storage_initialized()) return;
2574 CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adagrad_update";
2575 CHECK_GT(weight.shape_.Size(), 0);
2576 CHECK_GT(state.shape_.Size(), 0);
2577 MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
2578 MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
2579 const DType* weight_data = weight.dptr<DType>();
2580 const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
2581 const DType* grad_val = grad.data().dptr<DType>();
2582 DType* state_data = state.dptr<DType>();
2583 DType* out_data = out->dptr<DType>();
2584 const nnvm::dim_t nnr = grad.storage_shape()[0];
2585 const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
2586 size_t num_threads = nnr;
2587 if (std::is_same<xpu, gpu>::value) {
2588 num_threads = nnr * row_length;
2589 }
2590 Kernel<AdagradDnsRspDnsKernel<xpu>, xpu>::Launch(s, num_threads, row_length,
2591 out_data, state_data, weight_data, grad_idx, grad_val,
2592 static_cast<DType>(param.clip_gradient), static_cast<DType>(param.epsilon),
2593 static_cast<DType>(param.lr), static_cast<DType>(param.rescale_grad));
2594 });
2595 });
2596 }
2597
2598 template<typename xpu>
2599 inline void AdagradUpdateRspRspRspImpl(const AdagradParam& param,
2600 const OpContext& ctx,
2601 const NDArray& weight,
2602 const NDArray& grad,
2603 const NDArray& state,
2604 const OpReqType& req,
2605 NDArray *out) {
2606 using namespace mshadow;
2607 using namespace mxnet_op;
2608 using namespace rowsparse;
2609 CheckAllRowsPresent(weight, "AdagradUpdate", "weights");
2610 Stream<xpu>* s = ctx.get_stream<xpu>();
2611 // fill history with zero values
2612 if (!state.storage_initialized()) {
2613 NDArray state_zeros = state;
2614 FillDnsZerosRspImpl(s, &state_zeros);
2615 }
2616 TBlob out_blob = out->data();
2617 // reuse dns rsp implementation when storage_shape == shape
2618 AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, weight.data(), grad,
2619 state.data(), req, &out_blob);
2620 }
2621
2622 template<typename xpu>
2623 inline void AdagradUpdateEx(const nnvm::NodeAttrs& attrs,
2624 const OpContext &ctx,
2625 const std::vector<NDArray> &inputs,
2626 const std::vector<OpReqType> &req,
2627 const std::vector<NDArray> &outputs) {
2628 using namespace mxnet_op;
2629 const AdagradParam& param = nnvm::get<AdagradParam>(attrs.parsed);
2630
2631 const auto weight_stype = inputs[0].storage_type();
2632 const auto grad_stype = inputs[1].storage_type();
2633 const auto state_stype = inputs[2].storage_type();
2634 const auto output_stype = outputs[0].storage_type();
2635
2636 if (common::ContainsOnlyStorage(inputs, kRowSparseStorage) &&
2637 common::ContainsOnlyStorage(outputs, kRowSparseStorage)) {
2638 NDArray out = outputs[0];
2639 AdagradUpdateRspRspRspImpl<xpu>(param, ctx, inputs[0], inputs[1], inputs[2],
2640 req[0], &out);
2641 } else if (state_stype == weight_stype && output_stype == weight_stype &&
2642 weight_stype == kDefaultStorage &&
2643 grad_stype == kRowSparseStorage) {
2644 TBlob out_blob = outputs[0].data();
2645 AdagradUpdateDnsRspDnsImpl<xpu>(param, ctx, inputs[0].data(), inputs[1],
2646 inputs[2].data(), req[0],
2647 &out_blob);
2648 } else {
2649 LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
2650 }
2651 }
2652
2653 } // namespace op
2654 } // namespace mxnet
2655
2656 #endif // MXNET_OPERATOR_OPTIMIZER_OP_INL_H_
2657