1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file optimizer_op.cc
22 * \brief Optimizer operators
23 * \author Junyuan Xie
24 */
25 #include "./optimizer_op-inl.h"
26 #include "./elemwise_op_common.h"
27
28 namespace mxnet {
29 namespace op {
30
// Register the parameter structs for every optimizer operator defined below,
// so their fields can be parsed from front-end keyword arguments and listed
// via ::__FIELDS__() in each op's .add_arguments() call.
DMLC_REGISTER_PARAMETER(SGDParam);
DMLC_REGISTER_PARAMETER(SGDMomParam);
DMLC_REGISTER_PARAMETER(MultiSGDParam);
DMLC_REGISTER_PARAMETER(MultiSGDMomParam);
DMLC_REGISTER_PARAMETER(FTMLParam);
DMLC_REGISTER_PARAMETER(AdamParam);
DMLC_REGISTER_PARAMETER(NAGParam);
DMLC_REGISTER_PARAMETER(NAGMomParam);
DMLC_REGISTER_PARAMETER(RMSPropParam);
DMLC_REGISTER_PARAMETER(RMSPropAlexParam);
DMLC_REGISTER_PARAMETER(FtrlParam);
DMLC_REGISTER_PARAMETER(SignSGDParam);
DMLC_REGISTER_PARAMETER(SignumParam);
DMLC_REGISTER_PARAMETER(AdagradParam);
DMLC_REGISTER_PARAMETER(LambUpdatePhaseOneParam);
DMLC_REGISTER_PARAMETER(LambUpdatePhaseTwoParam);
47
// SignSGD: takes only the sign of the (rescaled) gradient for the step.
// Inputs: weight, grad. Output: updated weight (written in place by callers).
NNVM_REGISTER_OP(signsgd_update)
.describe(R"code(Update function for SignSGD optimizer.

.. math::

 g_t = \nabla J(W_{t-1})\\
 W_t = W_{t-1} - \eta_t \text{sign}(g_t)

It updates the weights using::

 weight = weight - learning_rate * sign(gradient)

.. note::
   - sparse ndarray not supported for this optimizer yet.
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SignSGDParam>)
// weight and grad must agree in shape and dtype; output matches them.
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCompute>("FCompute<cpu>", SignSGDUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_arguments(SignSGDParam::__FIELDS__());
72
73
// Signum: SignSGD with a momentum buffer; the step uses sign(momentum).
// Inputs: weight, grad, mom (mutated in place). Output: updated weight.
NNVM_REGISTER_OP(signum_update)
.describe(R"code(SIGN momentUM (Signum) optimizer.

.. math::

 g_t = \nabla J(W_{t-1})\\
 m_t = \beta m_{t-1} + (1 - \beta) g_t\\
 W_t = W_{t-1} - \eta_t \text{sign}(m_t)

It updates the weights using::
 state = momentum * state + (1-momentum) * gradient
 weight = weight - learning_rate * sign(state)

Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

.. note::
   - sparse ndarray not supported for this optimizer yet.
)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SignumParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
// Input 2 (mom) is updated in place by the kernel.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2};
  })
.set_attr<FCompute>("FCompute<cpu>", SignumUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_arguments(SignumParam::__FIELDS__());
106
107 template<int req>
108 struct SGDMomStdDnsRspDnsKernel<req, cpu> {
109 template<typename DType, typename IType, typename RType>
Mapmxnet::op::SGDMomStdDnsRspDnsKernel110 MSHADOW_XINLINE static void Map(int i, index_t row_length, DType* out_data,
111 DType* mom_data, const DType* weight_data, const IType* grad_idx,
112 const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
113 const DType momentum, const DType lr, const DType wd, const DType rescale_grad) {
114 const DType rate = lr * wd;
115 const bool non_zero = (i == 0) ? prefix_sum[0] > 0
116 : prefix_sum[i] > prefix_sum[i-1];
117
118 const index_t row_i = i * row_length;
119 const RType grad_i = (prefix_sum[i]-1) * row_length;
120 for (index_t j = 0; j < row_length; j++) {
121 const index_t data_i = row_i + j;
122 const DType grad = non_zero ? grad_data[grad_i + j]
123 : static_cast<DType>(0);
124 if (clip_gradient >= 0.0f) {
125 mom_data[data_i] = momentum * mom_data[data_i]
126 - rate * weight_data[data_i]
127 - lr *
128 mshadow_op::clip::Map(rescale_grad * grad,
129 clip_gradient);
130 } else {
131 mom_data[data_i] = momentum * mom_data[data_i]
132 - rate * weight_data[data_i]
133 - lr * rescale_grad * grad;
134 }
135 KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]);
136 }
137 }
138 };
139
/*
 * \brief standard momentum update for dense weight on cpu.
 * state is expected to be dense, while grad is expected to be row_sparse.
 * Unlike the lazy update, every row of the weight/momentum is touched;
 * rows missing from grad are treated as having a zero gradient.
 * Requires req == kWriteInplace and a kTempSpace resource (prefix-sum buffer).
 */
template<>
void SGDMomStdUpdateDnsRspDnsImpl<cpu>(const SGDMomParam& param,
                                       const OpContext& ctx,
                                       const TBlob& weight,
                                       const NDArray& grad,
                                       const TBlob& mom,
                                       const OpReqType& req,
                                       TBlob *out) {
  using namespace mxnet_op;
  using namespace rowsparse;
  using namespace mshadow;
  Stream<cpu>* s = ctx.get_stream<cpu>();
  if (req == kNullOp) return;
  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
  CHECK_GT(weight.shape_.Size(), 0);
  CHECK_GT(mom.shape_.Size(), 0);
  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        DType* weight_data = weight.dptr<DType>();
        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
        const DType* grad_val = grad.data().dptr<DType>();
        DType* mom_data = mom.dptr<DType>();
        DType* out_data = out->dptr<DType>();
        const nnvm::dim_t num_rows = weight.shape_[0];
        const auto row_length = weight.shape_.ProdShape(1, weight.ndim());
        // Temp workspace holds one counter per weight row for the prefix sum.
        Tensor<cpu, 1, char> workspace = ctx.requested[0]
          .get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);

        nnvm::dim_t* prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
        // mark row flags: flag[r] = 1 iff row r appears in grad.indices
        Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
        if (grad.storage_initialized()) {
          Kernel<MarkRowFlgKernel, cpu>::Launch(s, grad.aux_shape(kIdx)[0],
                                                prefix_sum, grad_idx);
          // calculate inclusive prefix sum; the kernel uses it both to test
          // row presence and to locate a row inside the compacted grad values
          for (nnvm::dim_t i = 1; i < num_rows; i++) {
            prefix_sum[i] += prefix_sum[i - 1];
          }
        }
        // One kernel invocation per weight row (dense sweep).
        Kernel<SGDMomStdDnsRspDnsKernel<req_type, cpu>, cpu>::Launch(s, num_rows, row_length,
          out_data, mom_data, weight_data, grad_idx, grad_val, prefix_sum,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.momentum),
          static_cast<DType>(param.lr), static_cast<DType>(param.wd),
          static_cast<DType>(param.rescale_grad));
      });
    });
  });
}
193
194 template<int req>
195 struct AdamStdDnsRspDnsKernel<req, cpu> {
196 template<typename DType, typename IType, typename RType>
Mapmxnet::op::AdamStdDnsRspDnsKernel197 MSHADOW_XINLINE static void Map(int i, const nnvm::dim_t row_length, DType* out_data,
198 DType* mean_data, DType* var_data, const DType* weight_data, const IType* grad_idx,
199 const DType* grad_data, const RType* prefix_sum, const DType clip_gradient,
200 const DType beta1, const DType beta2, const DType lr, const DType wd,
201 const DType epsilon, const DType rescale_grad) {
202 using namespace mshadow_op;
203 const bool non_zero = (i == 0) ? prefix_sum[0] > 0
204 : prefix_sum[i] > prefix_sum[i-1];
205
206 const index_t row_i = i * row_length;
207 const RType grad_i = (prefix_sum[i]-1) * row_length;
208 for (index_t j = 0; j < row_length; j++) {
209 const index_t data_i = row_i + j;
210 const DType grad_rescaled = non_zero ? static_cast<DType>(
211 grad_data[grad_i + j] * rescale_grad +
212 weight_data[data_i] * wd)
213 : static_cast<DType>(weight_data[data_i] * wd);
214 if (clip_gradient >= 0.0f) {
215 mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) *
216 clip::Map(grad_rescaled, clip_gradient);
217 var_data[data_i] = beta2 * var_data[data_i] + (1.f - beta2) * square::Map(
218 clip::Map(grad_rescaled, clip_gradient));
219 } else {
220 mean_data[data_i] = beta1 * mean_data[data_i] + (1.f - beta1) * grad_rescaled;
221 var_data[data_i] = beta2 * var_data[data_i] +
222 (1.f - beta2) * square::Map(grad_rescaled);
223 }
224 KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] - lr * mean_data[data_i] /
225 (square_root::Map(var_data[data_i]) + epsilon));
226 }
227 }
228 };
229
230
/*
 * \brief standard (non-lazy) Adam update for dense weight on cpu.
 * mean/var state are dense, grad is row_sparse; every weight row is updated,
 * rows missing from grad are treated as having a zero gradient.
 * Requires req == kWriteInplace and a kTempSpace resource (prefix-sum buffer).
 */
template<>
void AdamStdUpdateDnsRspDnsImpl<cpu>(const AdamParam& param,
                                     const OpContext& ctx,
                                     const TBlob& weight,
                                     const NDArray& grad,
                                     const TBlob& mean,
                                     const TBlob& var,
                                     const OpReqType& req,
                                     TBlob *out) {
  using namespace mxnet_op;
  using namespace rowsparse;
  using namespace mshadow;
  Stream<cpu>* s = ctx.get_stream<cpu>();
  if (req == kNullOp) return;
  CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse adam_update";
  CHECK_GT(weight.shape_.Size(), 0);
  CHECK_GT(mean.shape_.Size(), 0);
  CHECK_GT(var.shape_.Size(), 0);

  MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, {
    MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
        const DType* weight_data = weight.dptr<DType>();
        const IType* grad_idx = grad.aux_data(kIdx).dptr<IType>();
        const DType* grad_val = grad.data().dptr<DType>();
        DType* mean_data = mean.dptr<DType>();
        DType* var_data = var.dptr<DType>();
        DType* out_data = out->dptr<DType>();
        nnvm::dim_t num_rows = weight.shape_[0];
        nnvm::dim_t row_length = weight.shape_.ProdShape(1, weight.ndim());
        // Temp workspace holds one counter per weight row for the prefix sum.
        Tensor<cpu, 1, char> workspace = ctx.requested[0]
          .get_space_typed<cpu, 1, char>(Shape1(num_rows * sizeof(nnvm::dim_t)), s);

        nnvm::dim_t* prefix_sum = reinterpret_cast<nnvm::dim_t*>(workspace.dptr_);
        // mark row flags: flag[r] = 1 iff row r appears in grad.indices
        Kernel<set_zero, cpu>::Launch(s, num_rows, prefix_sum);
        if (grad.storage_initialized()) {
          Kernel<MarkRowFlgKernel, cpu>::Launch(s, grad.aux_shape(kIdx)[0],
                                                prefix_sum, grad_idx);
          // calculate inclusive prefix sum; the kernel uses it both to test
          // row presence and to locate a row inside the compacted grad values
          for (nnvm::dim_t i = 1; i < num_rows; i++) {
            prefix_sum[i] += prefix_sum[i - 1];
          }
        }

        // One kernel invocation per weight row (dense sweep).
        Kernel<AdamStdDnsRspDnsKernel<req_type, cpu>, cpu>::Launch(s, num_rows, row_length,
          out_data, mean_data, var_data, weight_data, grad_idx, grad_val, prefix_sum,
          static_cast<DType>(param.clip_gradient), static_cast<DType>(param.beta1),
          static_cast<DType>(param.beta2), static_cast<DType>(param.lr),
          static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
          static_cast<DType>(param.rescale_grad));
      });
    });
  });
}
286
287 /*!
288 * \brief Storge type inference function for SGD.
289 */
SGDStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)290 inline bool SGDStorageType(const nnvm::NodeAttrs& attrs,
291 const int dev_mask,
292 DispatchMode* dispatch_mode,
293 std::vector<int>* in_attrs,
294 std::vector<int>* out_attrs) {
295 using namespace common;
296 const SGDParam& param = nnvm::get<SGDParam>(attrs.parsed);
297 CHECK_EQ(in_attrs->size(), 2U);
298 CHECK_EQ(out_attrs->size(), 1U);
299 const int weight_stype = in_attrs->at(0);
300 const int grad_stype = in_attrs->at(1);
301 bool dispatched = false;
302 if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
303 // dns, ... -> dns
304 dispatched = storage_type_assign(out_attrs, kDefaultStorage,
305 dispatch_mode, DispatchMode::kFCompute);
306 }
307 if (!dispatched && grad_stype == kRowSparseStorage &&
308 (weight_stype == kRowSparseStorage || weight_stype == kDefaultStorage)) {
309 // grad's stype = rsp
310 dispatched = storage_type_assign(out_attrs, static_cast<NDArrayStorageType>(weight_stype),
311 dispatch_mode, DispatchMode::kFComputeEx);
312 // warn users if lazy_update is turned on
313 if (dispatched && param.wd != 0 && param.lazy_update) LogLazyUpdate();
314 }
315 if (!dispatched) {
316 dispatched = dispatch_fallback(out_attrs, dispatch_mode);
317 }
318 return dispatched;
319 }
320
321 NNVM_REGISTER_OP(multi_sgd_update)
322 .describe(R"code(Update function for Stochastic Gradient Descent (SDG) optimizer.
323
324 It updates the weights using::
325
326 weight = weight - learning_rate * (gradient + wd * weight)
327
328 )code" ADD_FILELINE)
__anonfeecde190202(const nnvm::NodeAttrs& attrs) 329 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
330 const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
331 return static_cast<uint32_t>(param.num_weights * 2);
332 })
__anonfeecde190302(const nnvm::NodeAttrs& attrs) 333 .set_num_outputs([](const nnvm::NodeAttrs& attrs) {
334 const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
335 return static_cast<uint32_t>(param.num_weights);
336 })
337 .set_attr_parser(ParamParser<MultiSGDParam>)
338 .set_attr<mxnet::FInferShape>("FInferShape", MultiSGDShape<MultiSGDParam, 2>)
339 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
340 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anonfeecde190402(const NodeAttrs& attrs) 341 [](const NodeAttrs& attrs) {
342 uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights;
343 std::vector<std::string> ret;
344 for (uint32_t i = 0; i < num_args; ++i) {
345 ret.push_back(std::string("weight_") + std::to_string(i));
346 ret.push_back(std::string("grad_") + std::to_string(i));
347 }
348 return ret;
349 })
350 .set_attr<FCompute>("FCompute<cpu>", MultiSGDUpdate<cpu, type_identity, 2>)
351 .add_argument("data", "NDArray-or-Symbol[]", "Weights")
352 .add_arguments(MultiSGDParam::__FIELDS__());
353
354 NNVM_REGISTER_OP(multi_sgd_mom_update)
355 .describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer.
356
357 Momentum update has better convergence rates on neural networks. Mathematically it looks
358 like below:
359
360 .. math::
361
362 v_1 = \alpha * \nabla J(W_0)\\
363 v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\
364 W_t = W_{t-1} + v_t
365
366 It updates the weights using::
367
368 v = momentum * v - learning_rate * gradient
369 weight += v
370
371 Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
372
373 )code" ADD_FILELINE)
__anonfeecde190502(const nnvm::NodeAttrs& attrs) 374 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
375 const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
376 return static_cast<uint32_t>(param.num_weights * 3);
377 })
__anonfeecde190602(const nnvm::NodeAttrs& attrs) 378 .set_num_outputs([](const nnvm::NodeAttrs& attrs) {
379 const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
380 return static_cast<uint32_t>(param.num_weights);
381 })
382 .set_attr_parser(ParamParser<MultiSGDMomParam>)
383 .set_attr<mxnet::FInferShape>("FInferShape", MultiSGDShape<MultiSGDMomParam, 3>)
384 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, -1>)
385 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anonfeecde190702(const NodeAttrs& attrs) 386 [](const NodeAttrs& attrs) {
387 uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights;
388 std::vector<std::string> ret;
389 for (uint32_t i = 0; i < num_args; ++i) {
390 ret.push_back(std::string("weight_") + std::to_string(i));
391 ret.push_back(std::string("grad_") + std::to_string(i));
392 ret.push_back(std::string("mom_") + std::to_string(i));
393 }
394 return ret;
395 })
396 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
__anonfeecde190802(const nnvm::NodeAttrs& attrs) 397 [](const nnvm::NodeAttrs& attrs) {
398 std::vector<uint32_t> ret;
399 const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
400 for (int i = 0; i < param.num_weights; ++i) {
401 ret.push_back(i * 3 + 2);
402 }
403 return ret;
404 })
405 .set_attr<FCompute>("FCompute<cpu>", MultiSGDMomUpdate<cpu, type_identity, 3>)
406 .add_argument("data", "NDArray-or-Symbol[]", "Weights, gradients and momentum")
407 .add_arguments(MultiSGDMomParam::__FIELDS__());
408
409 NNVM_REGISTER_OP(multi_mp_sgd_update)
410 .describe(R"code(Update function for multi-precision Stochastic Gradient Descent (SDG) optimizer.
411
412 It updates the weights using::
413
414 weight = weight - learning_rate * (gradient + wd * weight)
415
416 )code" ADD_FILELINE)
__anonfeecde190902(const nnvm::NodeAttrs& attrs) 417 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
418 const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
419 return static_cast<uint32_t>(param.num_weights * 3);
420 })
__anonfeecde190a02(const nnvm::NodeAttrs& attrs) 421 .set_num_outputs([](const nnvm::NodeAttrs& attrs) {
422 const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
423 return static_cast<uint32_t>(param.num_weights);
424 })
425 .set_attr_parser(ParamParser<MultiSGDParam>)
426 .set_attr<mxnet::FInferShape>("FInferShape", MultiSGDShape<MultiSGDParam, 3>)
427 .set_attr<nnvm::FInferType>("FInferType", MP_MultiSGD_InferType<MultiSGDParam, 3, 1>)
428 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anonfeecde190b02(const NodeAttrs& attrs) 429 [](const NodeAttrs& attrs) {
430 uint32_t num_args = dmlc::get<MultiSGDParam>(attrs.parsed).num_weights;
431 std::vector<std::string> ret;
432 for (uint32_t i = 0; i < num_args; ++i) {
433 ret.push_back(std::string("weight_") + std::to_string(i));
434 ret.push_back(std::string("grad_") + std::to_string(i));
435 ret.push_back(std::string("weight32_") + std::to_string(i));
436 }
437 return ret;
438 })
439 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
__anonfeecde190c02(const nnvm::NodeAttrs& attrs) 440 [](const nnvm::NodeAttrs& attrs) {
441 std::vector<uint32_t> ret;
442 const MultiSGDParam& param = dmlc::get<MultiSGDParam>(attrs.parsed);
443 for (int i = 0; i < param.num_weights; ++i) {
444 ret.push_back(i * 3 + 2);
445 }
446 return ret;
447 })
448 .set_attr<FCompute>("FCompute<cpu>", MultiSGDUpdate<cpu, single_precision, 3>)
449 .add_argument("data", "NDArray-or-Symbol[]", "Weights")
450 .add_arguments(MultiSGDParam::__FIELDS__());
451
// Fused multi-precision momentum SGD over num_weights quadruples.
// Inputs are interleaved as weight_0, grad_0, mom_0, weight32_0, weight_1, ...
NNVM_REGISTER_OP(multi_mp_sgd_mom_update)
.describe(R"code(Momentum update function for multi-precision Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically it looks
like below:

.. math::

  v_1 = \alpha * \nabla J(W_0)\\
  v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\
  W_t = W_{t-1} + v_t

It updates the weights using::

  v = momentum * v - learning_rate * gradient
  weight += v

Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

)code" ADD_FILELINE)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
    const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
    // Four inputs per weight: (weight_i, grad_i, mom_i, weight32_i).
    return static_cast<uint32_t>(param.num_weights * 4);
  })
.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
    const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
    return static_cast<uint32_t>(param.num_weights);
  })
.set_attr_parser(ParamParser<MultiSGDMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", MultiSGDShape<MultiSGDMomParam, 4>)
.set_attr<nnvm::FInferType>("FInferType", MP_MultiSGD_InferType<MultiSGDMomParam, 4, 2>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    uint32_t num_args = dmlc::get<MultiSGDMomParam>(attrs.parsed).num_weights;
    std::vector<std::string> ret;
    for (uint32_t i = 0; i < num_args; ++i) {
      ret.push_back(std::string("weight_") + std::to_string(i));
      ret.push_back(std::string("grad_") + std::to_string(i));
      ret.push_back(std::string("mom_") + std::to_string(i));
      ret.push_back(std::string("weight32_") + std::to_string(i));
    }
    return ret;
  })
// Both mom_i (i*4 + 2) and weight32_i (i*4 + 3) are updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    std::vector<uint32_t> ret;
    const MultiSGDMomParam& param = dmlc::get<MultiSGDMomParam>(attrs.parsed);
    for (int i = 0; i < param.num_weights; ++i) {
      ret.push_back(i * 4 + 2);
      ret.push_back(i * 4 + 3);
    }
    return ret;
  })
.set_attr<FCompute>("FCompute<cpu>", MultiSGDMomUpdate<cpu, single_precision, 4>)
.add_argument("data", "NDArray-or-Symbol[]", "Weights")
.add_arguments(MultiSGDMomParam::__FIELDS__());
508
// Plain SGD with dense and row_sparse support; storage dispatch is decided
// by SGDStorageType defined above.
NNVM_REGISTER_OP(sgd_update)
MXNET_ADD_SPARSE_OP_ALIAS(sgd_update)
.describe(R"code(Update function for Stochastic Gradient Descent (SGD) optimizer.

It updates the weights using::

 weight = weight - learning_rate * (gradient + wd * weight)

However, if gradient is of ``row_sparse`` storage type and ``lazy_update`` is True,
only the row slices whose indices appear in grad.indices are updated::

 for row in gradient.indices:
     weight[row] = weight[row] - learning_rate * (gradient[row] + wd * weight[row])

)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FInferStorageType>("FInferStorageType", SGDStorageType)
.set_attr<FCompute>("FCompute<cpu>", SGDUpdate<cpu>)
// FComputeEx handles the row_sparse gradient path.
.set_attr<FComputeEx>("FComputeEx<cpu>", SGDUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_arguments(SGDParam::__FIELDS__());
535
// Momentum SGD with dense and row_sparse support. The sparse non-lazy path
// needs temp space for the row prefix sum (see SGDMomStdUpdateDnsRspDnsImpl).
NNVM_REGISTER_OP(sgd_mom_update)
MXNET_ADD_SPARSE_OP_ALIAS(sgd_mom_update)
.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer.

Momentum update has better convergence rates on neural networks. Mathematically it looks
like below:

.. math::

  v_1 = \alpha * \nabla J(W_0)\\
  v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\
  W_t = W_{t-1} + v_t

It updates the weights using::

  v = momentum * v - learning_rate * gradient
  weight += v

Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.

However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage
type is the same as momentum's storage type,
only the row slices whose indices appear in grad.indices are updated (for both weight and momentum)::

  for row in gradient.indices:
      v[row] = momentum[row] * v[row] - learning_rate * gradient[row]
      weight[row] += v[row]

)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<1, SGDMomParam>)
// Input 2 (mom) is updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2};
  })
// The FComputeEx (sparse) path requests temp space for the prefix-sum buffer.
.set_attr<FResourceRequestEx>("FResourceRequestEx",
  [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
    std::vector<ResourceRequest> request;
    if (dispatch_mode == DispatchMode::kFComputeEx) {
      request.emplace_back(ResourceRequest::kTempSpace);
    }
    return request;
  })
.set_attr<FCompute>("FCompute<cpu>", SGDMomUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SGDMomUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_arguments(SGDMomParam::__FIELDS__());
589
// Multi-precision SGD: weight32 is the fp32 master copy, mutated in place.
NNVM_REGISTER_OP(mp_sgd_update)
.describe("Updater function for multi-precision sgd optimizer")
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 3>)
.set_attr<FCompute>("FCompute<cpu>", MP_SGDUpdate<cpu>)
// Input 2 (weight32) is updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2};
  })
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "gradient")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(SGDParam::__FIELDS__());
606
// Multi-precision momentum SGD: mom and the fp32 master weight32 are
// both mutated in place.
NNVM_REGISTER_OP(mp_sgd_mom_update)
.describe("Updater function for multi-precision sgd optimizer")
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SGDMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 4>)
// Inputs 2 (mom) and 3 (weight32) are updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.set_attr<FCompute>("FCompute<cpu>", MP_SGDMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(SGDMomParam::__FIELDS__());
624
// FTML optimizer: maintains three internal states (d, v, z) per weight,
// all mutated in place.
NNVM_REGISTER_OP(ftml_update)
.describe(R"code(The FTML optimizer described in
*FTML - Follow the Moving Leader in Deep Learning*,
available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.

.. math::

 g_t = \nabla J(W_{t-1})\\
 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
 d_t = \frac{ 1 - \beta_1^t }{ \eta_t } (\sqrt{ \frac{ v_t }{ 1 - \beta_2^t } } + \epsilon)
 \sigma_t = d_t - \beta_1 d_{t-1}
 z_t = \beta_1 z_{ t-1 } + (1 - \beta_1^t) g_t - \sigma_t W_{t-1}
 W_t = - \frac{ z_t }{ d_t }

)code" ADD_FILELINE)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<FTMLParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<5, 1>)
// Inputs 2-4 (d, v, z states) are updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3, 4};
  })
.set_attr<FCompute>("FCompute<cpu>", FTMLUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("d", "NDArray-or-Symbol", "Internal state ``d_t``")
.add_argument("v", "NDArray-or-Symbol", "Internal state ``v_t``")
.add_argument("z", "NDArray-or-Symbol", "Internal state ``z_t``")
.add_arguments(FTMLParam::__FIELDS__());
656
// Adam optimizer with dense and row_sparse support; mean and var moment
// states are mutated in place. The sparse non-lazy path uses temp space
// for the row prefix sum (see AdamStdUpdateDnsRspDnsImpl).
NNVM_REGISTER_OP(adam_update)
MXNET_ADD_SPARSE_OP_ALIAS(adam_update)
.describe(R"code(Update function for Adam optimizer. Adam is seen as a generalization
of AdaGrad.

Adam update consists of the following steps, where g represents gradient and m, v
are 1st and 2nd order moment estimates (mean and variance).

.. math::

 g_t = \nabla J(W_{t-1})\\
 m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
 W_t = W_{t-1} - \alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon }

It updates the weights using::

 m = beta1*m + (1-beta1)*grad
 v = beta2*v + (1-beta2)*(grad**2)
 w += - learning_rate * m / (sqrt(v) + epsilon)

However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and the storage
type of weight is the same as those of m and v,
only the row slices whose indices appear in grad.indices are updated (for w, m and v)::

 for row in grad.indices:
     m[row] = beta1*m[row] + (1-beta1)*grad[row]
     v[row] = beta2*v[row] + (1-beta2)*(grad[row]**2)
     w[row] += - learning_rate * m[row] / (sqrt(v[row]) + epsilon)

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
// Temp space is used by the sparse (non-lazy) update's prefix-sum buffer.
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  })
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FInferStorageType>("FInferStorageType", StdOptStorageType<2, AdamParam>)
// Inputs 2 (mean) and 3 (var) are updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.set_attr<FCompute>("FCompute<cpu>", AdamUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", AdamUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(AdamParam::__FIELDS__());
709
710
// Nesterov Accelerated Gradient with momentum; mom is mutated in place.
NNVM_REGISTER_OP(nag_mom_update)
.describe(R"code(Update function for Nesterov Accelerated Gradient( NAG) optimizer.
It updates the weights using the following formula,

.. math::
  v_t = \gamma v_{t-1} + \eta * \nabla J(W_{t-1} - \gamma v_{t-1})\\
  W_t = W_{t-1} - v_t

Where
:math:`\eta` is the learning rate of the optimizer
:math:`\gamma` is the decay rate of the momentum estimate
:math:`\v_t` is the update vector at time step `t`
:math:`\W_t` is the weight vector at time step `t`

)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NAGMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
// Input 2 (mom) is updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2};
  })
.set_attr<FCompute>("FCompute<cpu>", NAGMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_arguments(NAGMomParam::__FIELDS__());
740
741
// Multi-precision NAG: mom and the fp32 master weight32 are mutated in place.
NNVM_REGISTER_OP(mp_nag_mom_update)
.describe(R"code(Update function for multi-precision Nesterov Accelerated Gradient( NAG) optimizer.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NAGMomParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", MP_InferType<2, 1, 4>)
// Inputs 2 (mom) and 3 (weight32) are updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.set_attr<FCompute>("FCompute<cpu>", MP_NAGMomUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mom", "NDArray-or-Symbol", "Momentum")
.add_argument("weight32", "NDArray-or-Symbol", "Weight32")
.add_arguments(NAGMomParam::__FIELDS__());
760
761
762 NNVM_REGISTER_OP(rmsprop_update)
763 .describe(R"code(Update function for `RMSProp` optimizer.
764
765 `RMSprop` is a variant of stochastic gradient descent where the gradients are
766 divided by a cache which grows with the sum of squares of recent gradients?
767
768 `RMSProp` is similar to `AdaGrad`, a popular variant of `SGD` which adaptively
769 tunes the learning rate of each parameter. `AdaGrad` lowers the learning rate for
770 each parameter monotonically over the course of training.
771 While this is analytically motivated for convex optimizations, it may not be ideal
772 for non-convex problems. `RMSProp` deals with this heuristically by allowing the
773 learning rates to rebound as the denominator decays over time.
774
775 Define the Root Mean Square (RMS) error criterion of the gradient as
776 :math:`RMS[g]_t = \sqrt{E[g^2]_t + \epsilon}`, where :math:`g` represents
777 gradient and :math:`E[g^2]_t` is the decaying average over past squared gradient.
778
779 The :math:`E[g^2]_t` is given by:
780
781 .. math::
782 E[g^2]_t = \gamma * E[g^2]_{t-1} + (1-\gamma) * g_t^2
783
784 The update step is
785
786 .. math::
787 \theta_{t+1} = \theta_t - \frac{\eta}{RMS[g]_t} g_t
788
789 The RMSProp code follows the version in
790 http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
791 Tieleman & Hinton, 2012.
792
793 Hinton suggests the momentum term :math:`\gamma` to be 0.9 and the learning rate
794 :math:`\eta` to be 0.001.
795
796 )code" ADD_FILELINE)
797 .set_num_inputs(3)
798 .set_num_outputs(1)
799 .set_attr_parser(ParamParser<RMSPropParam>)
800 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
801 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
802 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
__anonfeecde191a02(const nnvm::NodeAttrs &attrs) 803 [](const nnvm::NodeAttrs &attrs) {
804 return std::vector<uint32_t>{2};
805 })
806 .set_attr<FCompute>("FCompute<cpu>", RMSPropUpdate<cpu>)
807 .add_argument("weight", "NDArray-or-Symbol", "Weight")
808 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
809 .add_argument("n", "NDArray-or-Symbol", "n")
810 .add_arguments(RMSPropParam::__FIELDS__());
811
// Non-centered RMSProp variant following Graves (2013), Eq(38)-(45).
// Inputs: weight (0), grad (1), n (2), g (3), delta (4); the three state
// buffers n, g, delta are mutated in place (see FMutateInputs below).
NNVM_REGISTER_OP(rmspropalex_update)
.describe(R"code(Update function for RMSPropAlex optimizer.

`RMSPropAlex` is non-centered version of `RMSProp`.

Define :math:`E[g^2]_t` is the decaying average over past squared gradient and
:math:`E[g]_t` is the decaying average over past gradient.

.. math::
  E[g^2]_t = \gamma_1 * E[g^2]_{t-1} + (1 - \gamma_1) * g_t^2\\
  E[g]_t = \gamma_1 * E[g]_{t-1} + (1 - \gamma_1) * g_t\\
  \Delta_t = \gamma_2 * \Delta_{t-1} - \frac{\eta}{\sqrt{E[g^2]_t - E[g]_t^2 + \epsilon}} g_t\\

The update step is

.. math::
  \theta_{t+1} = \theta_t + \Delta_t

The RMSPropAlex code follows the version in
http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) by Alex Graves, 2013.

Graves suggests the momentum term :math:`\gamma_1` to be 0.95, :math:`\gamma_2`
to be 0.9 and the learning rate :math:`\eta` to be 0.0001.
)code" ADD_FILELINE)
.set_num_inputs(5)
.set_num_outputs(1)
.set_attr_parser(ParamParser<RMSPropAlexParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<5, 1>)
// Inputs 2..4 (n, g, delta) are state buffers updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3, 4};
  })
.set_attr<FCompute>("FCompute<cpu>", RMSPropAlexUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("n", "NDArray-or-Symbol", "n")
.add_argument("g", "NDArray-or-Symbol", "g")
.add_argument("delta", "NDArray-or-Symbol", "delta")
.add_arguments(RMSPropAlexParam::__FIELDS__());
852
// FTRL ("Follow The Regularized Leader") update with optional row_sparse
// support (both a dense FCompute and a sparse FComputeEx are registered).
// Inputs: weight (0), grad (1), z (2), n (3); z and n are mutated in place.
NNVM_REGISTER_OP(ftrl_update)
MXNET_ADD_SPARSE_OP_ALIAS(ftrl_update)
.describe(R"code(Update function for Ftrl optimizer.
Referenced from *Ad Click Prediction: a View from the Trenches*, available at
http://dl.acm.org/citation.cfm?id=2488200.

It updates the weights using::

 rescaled_grad = clip(grad * rescale_grad, clip_gradient)
 z += rescaled_grad - (sqrt(n + rescaled_grad**2) - sqrt(n)) * weight / learning_rate
 n += rescaled_grad**2
 w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1)

If w, z and n are all of ``row_sparse`` storage type,
only the row slices whose indices appear in grad.indices are updated (for w, z and n)::

 for row in grad.indices:
     rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient)
     z[row] += rescaled_grad[row] - (sqrt(n[row] + rescaled_grad[row]**2) - sqrt(n[row])) * weight[row] / learning_rate
     n[row] += rescaled_grad[row]**2
     w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1)

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<FtrlParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<4, 1, false, true, false>)
// Inputs 2..3 (z, n) are state buffers updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2, 3};
  })
.set_attr<FCompute>("FCompute<cpu>", FtrlUpdate<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", FtrlUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("z", "NDArray-or-Symbol", "z")
.add_argument("n", "NDArray-or-Symbol", "Square of grad")
.add_arguments(FtrlParam::__FIELDS__());
893
// Sparse AdaGrad update. Inputs: weight (0), grad (1), history (2); history
// (accumulated squared gradients) is mutated in place. Note: only FComputeEx
// is registered (no dense FCompute) — dispatch goes through AdagradUpdateEx
// with storage types chosen by AdagradStorageType.
NNVM_REGISTER_OP(_sparse_adagrad_update)
.describe(R"code(Update function for AdaGrad optimizer.

Referenced from *Adaptive Subgradient Methods for Online Learning and Stochastic Optimization*,
and available at http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.

Updates are applied by::

    rescaled_grad = clip(grad * rescale_grad, clip_gradient)
    history = history + square(rescaled_grad)
    w = w - learning_rate * rescaled_grad / sqrt(history + epsilon)

Note that non-zero values for the weight decay option are not supported.

)code" ADD_FILELINE)
.set_num_inputs(3)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdagradParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<FInferStorageType>("FInferStorageType", AdagradStorageType)
// Input 2 (history) is a state buffer updated in place.
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    return std::vector<uint32_t>{2};
  })
.set_attr<FComputeEx>("FComputeEx<cpu>", AdagradUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("history", "NDArray-or-Symbol", "History")
.add_arguments(AdagradParam::__FIELDS__());
924
925 NNVM_REGISTER_OP(lamb_update_phase1)
926 .describe(R"code(Phase I of lamb update it performs the following operations and returns g:.
927
928 Link to paper: https://arxiv.org/pdf/1904.00962.pdf
929
930 .. math::
931 \begin{gather*}
932 grad = grad * rescale_grad
933 if (grad < -clip_gradient)
934 then
935 grad = -clip_gradient
936 if (grad > clip_gradient)
937 then
938 grad = clip_gradient
939
940 mean = beta1 * mean + (1 - beta1) * grad;
941 variance = beta2 * variance + (1. - beta2) * grad ^ 2;
942
943 if (bias_correction)
944 then
945 mean_hat = mean / (1. - beta1^t);
946 var_hat = var / (1 - beta2^t);
947 g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight;
948 else
949 g = mean / (var_data^(1/2) + epsilon) + wd * weight;
950 \end{gather*}
951
952 )code" ADD_FILELINE)
953 .set_num_inputs(4)
954 .set_num_outputs(1)
955 .set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
956 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
957 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
958 .set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseOne<cpu>)
959 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
__anonfeecde191e02(const nnvm::NodeAttrs& attrs) 960 [](const nnvm::NodeAttrs& attrs) {
961 return std::vector<uint32_t>{2, 3};
962 })
963 .add_argument("weight", "NDArray-or-Symbol", "Weight")
964 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
965 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
966 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
967 .add_arguments(LambUpdatePhaseOneParam::__FIELDS__());
968
969 NNVM_REGISTER_OP(lamb_update_phase2)
970 .describe(R"code(Phase II of lamb update it performs the following operations and updates grad.
971
972 Link to paper: https://arxiv.org/pdf/1904.00962.pdf
973
974 .. math::
975 \begin{gather*}
976 if (lower_bound >= 0)
977 then
978 r1 = max(r1, lower_bound)
979 if (upper_bound >= 0)
980 then
981 r1 = max(r1, upper_bound)
982
983 if (r1 == 0 or r2 == 0)
984 then
985 lr = lr
986 else
987 lr = lr * (r1/r2)
988 weight = weight - lr * g
989 \end{gather*}
990
991 )code" ADD_FILELINE)
992 .set_num_inputs(4)
993 .set_num_outputs(1)
994 .set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
995 .set_attr<mxnet::FInferShape>("FInferShape", LambUpdatePhaseTwoShape)
996 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
997 .set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseTwo<cpu>)
998 .add_argument("weight", "NDArray-or-Symbol", "Weight")
999 .add_argument("g", "NDArray-or-Symbol", "Output of lamb_update_phase 1")
1000 .add_argument("r1", "NDArray-or-Symbol", "r1")
1001 .add_argument("r2", "NDArray-or-Symbol", "r2")
1002 .add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());
1003
1004 NNVM_REGISTER_OP(mp_lamb_update_phase1)
1005 .describe(R"code(Mixed Precision version of Phase I of lamb update
1006 it performs the following operations and returns g:.
1007
1008 Link to paper: https://arxiv.org/pdf/1904.00962.pdf
1009
1010 .. math::
1011 \begin{gather*}
1012 grad32 = grad(float16) * rescale_grad
1013 if (grad < -clip_gradient)
1014 then
1015 grad = -clip_gradient
1016 if (grad > clip_gradient)
1017 then
1018 grad = clip_gradient
1019
1020 mean = beta1 * mean + (1 - beta1) * grad;
1021 variance = beta2 * variance + (1. - beta2) * grad ^ 2;
1022
1023 if (bias_correction)
1024 then
1025 mean_hat = mean / (1. - beta1^t);
1026 var_hat = var / (1 - beta2^t);
1027 g = mean_hat / (var_hat^(1/2) + epsilon) + wd * weight32;
1028 else
1029 g = mean / (var_data^(1/2) + epsilon) + wd * weight32;
1030 \end{gather*}
1031
1032 )code" ADD_FILELINE)
1033 .set_num_inputs(5)
1034 .set_num_outputs(1)
1035 .set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
1036 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<5, 1>)
1037 .set_attr<nnvm::FInferType>("FInferType", MPLambPhaseOneType<2, 1, 5>)
1038 .set_attr<FCompute>("FCompute<cpu>", MPLambUpdatePhaseOne<cpu>)
1039 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
__anonfeecde191f02(const nnvm::NodeAttrs& attrs) 1040 [](const nnvm::NodeAttrs& attrs) {
1041 return std::vector<uint32_t>{2, 3};
1042 })
1043 .add_argument("weight", "NDArray-or-Symbol", "Weight")
1044 .add_argument("grad", "NDArray-or-Symbol", "Gradient")
1045 .add_argument("mean", "NDArray-or-Symbol", "Moving mean")
1046 .add_argument("var", "NDArray-or-Symbol", "Moving variance")
1047 .add_argument("weight32", "NDArray-or-Symbol", "Weight32")
1048 .add_arguments(LambUpdatePhaseOneParam::__FIELDS__());
1049
1050 NNVM_REGISTER_OP(mp_lamb_update_phase2)
1051 .describe(R"code(Mixed Precision version Phase II of lamb update
1052 it performs the following operations and updates grad.
1053
1054 Link to paper: https://arxiv.org/pdf/1904.00962.pdf
1055
1056 .. math::
1057 \begin{gather*}
1058 if (lower_bound >= 0)
1059 then
1060 r1 = max(r1, lower_bound)
1061 if (upper_bound >= 0)
1062 then
1063 r1 = max(r1, upper_bound)
1064
1065 if (r1 == 0 or r2 == 0)
1066 then
1067 lr = lr
1068 else
1069 lr = lr * (r1/r2)
1070 weight32 = weight32 - lr * g
1071 weight(float16) = weight32
1072 \end{gather*}
1073
1074 )code" ADD_FILELINE)
1075 .set_num_inputs(5)
1076 .set_num_outputs(1)
1077 .set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
1078 .set_attr<mxnet::FInferShape>("FInferShape", MPLambUpdatePhaseTwoShape)
1079 .set_attr<nnvm::FInferType>("FInferType", MP_InferType<1, 1, 5>)
1080 .set_attr<FCompute>("FCompute<cpu>", MPLambUpdatePhaseTwo<cpu>)
1081 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
__anonfeecde192002(const nnvm::NodeAttrs& attrs) 1082 [](const nnvm::NodeAttrs& attrs) {
1083 return std::vector<uint32_t>{4};
1084 })
1085 .add_argument("weight", "NDArray-or-Symbol", "Weight")
1086 .add_argument("g", "NDArray-or-Symbol", "Output of mp_lamb_update_phase 1")
1087 .add_argument("r1", "NDArray-or-Symbol", "r1")
1088 .add_argument("r2", "NDArray-or-Symbol", "r2")
1089 .add_argument("weight32", "NDArray-or-Symbol", "Weight32")
1090 .add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());
1091
1092 } // namespace op
1093 } // namespace mxnet
1094