1# coding: utf-8 2# Licensed to the Apache Software Foundation (ASF) under one 3# or more contributor license agreements. See the NOTICE file 4# distributed with this work for additional information 5# regarding copyright ownership. The ASF licenses this file 6# to you under the Apache License, Version 2.0 (the 7# "License"); you may not use this file except in compliance 8# with the License. You may obtain a copy of the License at 9# 10# http://www.apache.org/licenses/LICENSE-2.0 11# 12# Unless required by applicable law or agreed to in writing, 13# software distributed under the License is distributed on an 14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15# KIND, either express or implied. See the License for the 16# specific language governing permissions and limitations 17# under the License. 18 19# pylint: disable=too-many-lines 20"""Weight updating functions.""" 21import logging 22import math 23import pickle 24import warnings 25import os 26import numpy 27from ..base import py_str 28from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply, 29 multi_sum_sq, multi_lars, norm as NDnorm) 30from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, 31 mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, 32 signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update, 33 multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update, 34 multi_mp_sgd_mom_update, preloaded_multi_sgd_update, 35 preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update, 36 preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2, 37 mp_lamb_update_phase1, mp_lamb_update_phase2) 38from ..ndarray.contrib import (multi_lamb_update, multi_mp_lamb_update) 39from ..ndarray import sparse 40from ..random import normal 41from ..util import is_np_array 42 43__all__ = [ 44 'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD', 45 'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB', 46 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register' 47] 48 49def _flatten_list(nested_list): 50 return [item for sublist in nested_list for item in sublist] 51 52class Optimizer(object): 53 """The base class inherited by all optimizers. 54 55 Parameters 56 ---------- 57 rescale_grad : float, optional, default 1.0 58 Multiply the gradient with `rescale_grad` before updating. Often 59 choose to be ``1.0/batch_size``. 60 61 param_idx2name : dict from int to string, optional, default None 62 A dictionary that maps int index to string name. 63 64 clip_gradient : float, optional, default None 65 Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``. 66 67 learning_rate : float, optional, default None 68 The initial learning rate. If None, the optimization will use the 69 learning rate from ``lr_scheduler``. If not None, it will overwrite 70 the learning rate in ``lr_scheduler``. If None and ``lr_scheduler`` 71 is also None, then it will be set to 0.01 by default. 72 73 lr_scheduler : LRScheduler, optional, default None 74 The learning rate scheduler. 75 76 wd : float, optional, default 0.0 77 The weight decay (or L2 regularization) coefficient. Modifies objective 78 by adding a penalty for having large weights. 79 80 sym: Symbol, optional, default None 81 The Symbol this optimizer is applying to. 82 83 begin_num_update : int, optional, default 0 84 The initial number of updates. 85 86 multi_precision : bool, optional, default False 87 Flag to control the internal precision of the optimizer. 88 False: results in using the same precision as the weights (default), 89 True: makes internal 32-bit copy of the weights and applies gradients 90 in 32-bit precision even if actual weights used in the model have lower precision. 91 Turning this on can improve convergence and accuracy when training with float16. 92 93 param_dict : dict of int -> gluon.Parameter, default None 94 Dictionary of parameter index to gluon.Parameter, used to lookup parameter attributes 95 such as lr_mult, wd_mult, etc. param_dict shall not be deep copied. 96 97 Properties 98 ---------- 99 learning_rate : float 100 The current learning rate of the optimizer. Given an Optimizer object 101 optimizer, its learning rate can be accessed as optimizer.learning_rate. 102 """ 103 def __init__(self, rescale_grad=1., param_idx2name=None, wd=0., 104 clip_gradient=None, learning_rate=None, 105 lr_scheduler=None, sym=None, begin_num_update=0, 106 multi_precision=False, param_dict=None): 107 self.rescale_grad = rescale_grad 108 self.lr_scheduler = lr_scheduler 109 if self.lr_scheduler is None and learning_rate is None: 110 learning_rate = 0.01 111 self.lr = learning_rate 112 if self.lr_scheduler is not None and learning_rate is not None: 113 if self.lr_scheduler.base_lr != learning_rate: 114 print(UserWarning("learning rate from ``lr_scheduler`` has been " 115 "overwritten by ``learning_rate`` in optimizer.")) 116 self.lr_scheduler.base_lr = learning_rate 117 118 self.wd = wd 119 self.lr_mult = {} 120 self.wd_mult = {} 121 self.begin_num_update = begin_num_update 122 self.num_update = begin_num_update 123 self._all_index_update_counts = {0 : {}} 124 self._index_update_count = self._all_index_update_counts[0] 125 self.clip_gradient = clip_gradient 126 self.multi_precision = multi_precision 127 self.aggregate_num = 0 128 129 if param_idx2name is None: 130 param_idx2name = {} 131 assert isinstance(param_idx2name, dict), \ 132 'param_idx2name should be a dict of param indexes to names.' 133 self.idx2name = param_idx2name.copy() 134 self.sym_info = (sym.attr_dict(), sym.list_arguments()) if sym is not None else () 135 self.param_dict = param_dict if param_dict else {} 136 self.allow_np_array = is_np_array() 137 138 self.set_lr_mult({}) 139 self.set_wd_mult({}) 140 141 opt_registry = {} 142 143 @staticmethod 144 def register(klass): 145 """Registers a new optimizer. 146 147 Once an optimizer is registered, we can create an instance of this 148 optimizer with `create_optimizer` later. 149 150 Examples 151 -------- 152 153 >>> @mx.optimizer.Optimizer.register 154 ... class MyOptimizer(mx.optimizer.Optimizer): 155 ... pass 156 >>> optim = mx.optimizer.Optimizer.create_optimizer('MyOptimizer') 157 >>> print(type(optim)) 158 <class '__main__.MyOptimizer'> 159 """ 160 assert(isinstance(klass, type)) 161 name = klass.__name__.lower() 162 if name in Optimizer.opt_registry: 163 warnings.warn('WARNING: New optimizer %s.%s is overriding ' 164 'existing optimizer %s.%s' % 165 (klass.__module__, klass.__name__, 166 Optimizer.opt_registry[name].__module__, 167 Optimizer.opt_registry[name].__name__)) 168 Optimizer.opt_registry[name] = klass 169 return klass 170 171 @staticmethod 172 def create_optimizer(name, **kwargs): 173 """Instantiates an optimizer with a given name and kwargs. 174 175 .. note:: We can use the alias `create` for ``Optimizer.create_optimizer``. 176 177 Parameters 178 ---------- 179 name: str 180 Name of the optimizer. Should be the name 181 of a subclass of Optimizer. Case insensitive. 182 183 kwargs: dict 184 Parameters for the optimizer. 185 186 Returns 187 ------- 188 Optimizer 189 An instantiated optimizer. 190 191 Examples 192 -------- 193 >>> sgd = mx.optimizer.Optimizer.create_optimizer('sgd') 194 >>> type(sgd) 195 <class 'mxnet.optimizer.SGD'> 196 >>> adam = mx.optimizer.create('adam', learning_rate=.1) 197 >>> type(adam) 198 <class 'mxnet.optimizer.Adam'> 199 """ 200 if name.lower() in Optimizer.opt_registry: 201 return Optimizer.opt_registry[name.lower()](**kwargs) 202 else: 203 raise ValueError('Cannot find optimizer %s' % name) 204 205 @property 206 def learning_rate(self): 207 if self.lr_scheduler is not None: 208 return self.lr_scheduler(self.num_update) 209 else: 210 return self.lr 211 212 def create_state(self, index, weight): 213 """Creates auxiliary state for a given weight. 214 215 Some optimizers require additional states, e.g. as momentum, in addition 216 to gradients in order to update weights. This function creates state 217 for a given weight which will be used in `update`. This function is 218 called only once for each weight. 219 220 Parameters 221 ---------- 222 index : int 223 An unique index to identify the weight. 224 weight : NDArray 225 The weight. 226 227 Returns 228 ------- 229 state : any obj 230 The state associated with the weight. 231 """ 232 233 def create_state_multi_precision(self, index, weight): 234 """Creates auxiliary state for a given weight, including FP32 high 235 precision copy if original weight is FP16. 236 237 This method is provided to perform automatic mixed precision training 238 for optimizers that do not support it themselves. 239 240 Parameters 241 ---------- 242 index : int 243 An unique index to identify the weight. 244 weight : NDArray 245 The weight. 246 247 Returns 248 ------- 249 state : any obj 250 The state associated with the weight. 251 """ 252 weight_master_copy = None 253 if self.multi_precision and weight.dtype == numpy.float16: 254 weight_master_copy = weight.astype(numpy.float32) 255 return (weight_master_copy,) + (self.create_state(index, weight_master_copy),) 256 if weight.dtype == numpy.float16 and not self.multi_precision: 257 warnings.warn("Accumulating with float16 in optimizer can lead to " 258 "poor accuracy or slow convergence. " 259 "Consider using multi_precision=True option of the " 260 "optimizer") 261 return self.create_state(index, weight) 262 263 def update(self, index, weight, grad, state): 264 """Updates the given parameter using the corresponding gradient and state. 265 266 Parameters 267 ---------- 268 index : int 269 The unique index of the parameter into the individual learning 270 rates and weight decays. Learning rates and weight decay 271 may be set via `set_lr_mult()` and `set_wd_mult()`, respectively. 272 weight : NDArray 273 The parameter to be updated. 274 grad : NDArray 275 The gradient of the objective with respect to this parameter. 276 state : any obj 277 The state returned by `create_state()`. 278 """ 279 raise NotImplementedError() 280 281 def update_multi_precision(self, index, weight, grad, state): 282 """Updates the given parameter using the corresponding gradient and state. 283 Mixed precision version. 284 285 Parameters 286 ---------- 287 index : int 288 The unique index of the parameter into the individual learning 289 rates and weight decays. Learning rates and weight decay 290 may be set via `set_lr_mult()` and `set_wd_mult()`, respectively. 291 weight : NDArray 292 The parameter to be updated. 293 grad : NDArray 294 The gradient of the objective with respect to this parameter. 295 state : any obj 296 The state returned by `create_state()`. 297 """ 298 if self.multi_precision and weight.dtype == numpy.float16: 299 # Wrapper for mixed precision 300 weight_master_copy = state[0] 301 original_state = state[1] 302 grad32 = grad.astype(numpy.float32) 303 self.update(index, weight_master_copy, grad32, original_state) 304 cast(weight_master_copy, dtype=weight.dtype, out=weight) 305 else: 306 self.update(index, weight, grad, state) 307 308 def set_learning_rate(self, lr): 309 """Sets a new learning rate of the optimizer. 310 311 Parameters 312 ---------- 313 lr : float 314 The new learning rate of the optimizer. 315 """ 316 if self.lr_scheduler is not None: # pylint: disable=no-else-raise 317 raise UserWarning("LRScheduler of the optimizer has already been " 318 "defined. Note that set_learning_rate can mutate " 319 "the value of the learning rate of the optimizer " 320 "only when the LRScheduler of the optimizer is " 321 "undefined.") 322 else: 323 self.lr = lr 324 325 def set_lr_scale(self, args_lrscale): # pylint: disable=unused-argument 326 """[DEPRECATED] Sets lr scale. Use set_lr_mult instead.""" 327 raise DeprecationWarning 328 329 def set_lr_mult(self, args_lr_mult): 330 """Sets an individual learning rate multiplier for each parameter. 331 332 If you specify a learning rate multiplier for a parameter, then 333 the learning rate for the parameter will be set as the product of 334 the global learning rate `self.lr` and its multiplier. 335 336 .. note:: The default learning rate multiplier of a `Variable` 337 can be set with `lr_mult` argument in the constructor. 338 339 Parameters 340 ---------- 341 args_lr_mult : dict of str/int to float 342 For each of its key-value entries, the learning rate multipler for the 343 parameter specified in the key will be set as the given value. 344 345 You can specify the parameter with either its name or its index. 346 If you use the name, you should pass `sym` in the constructor, 347 and the name you specified in the key of `args_lr_mult` should match 348 the name of the parameter in `sym`. If you use the index, it should 349 correspond to the index of the parameter used in the `update` method. 350 351 Specifying a parameter by its index is only supported for backward 352 compatibility, and we recommend to use the name instead. 353 """ 354 self.lr_mult = {} 355 if self.sym_info: 356 attr, arg_names = self.sym_info 357 for name in arg_names: 358 if name in attr and '__lr_mult__' in attr[name]: 359 self.lr_mult[name] = float(attr[name]['__lr_mult__']) 360 self.lr_mult.update(args_lr_mult) 361 362 def set_wd_mult(self, args_wd_mult): 363 """Sets an individual weight decay multiplier for each parameter. 364 365 By default, if `param_idx2name` was provided in the 366 constructor, the weight decay multipler is set as 0 for all 367 parameters whose name don't end with ``_weight`` or 368 ``_gamma``. 369 370 .. note:: The default weight decay multiplier for a `Variable` 371 can be set with its `wd_mult` argument in the constructor. 372 373 Parameters 374 ---------- 375 args_wd_mult : dict of string/int to float 376 For each of its key-value entries, the weight decay multipler for the 377 parameter specified in the key will be set as the given value. 378 379 You can specify the parameter with either its name or its index. 380 If you use the name, you should pass `sym` in the constructor, 381 and the name you specified in the key of `args_lr_mult` should match 382 the name of the parameter in `sym`. If you use the index, it should 383 correspond to the index of the parameter used in the `update` method. 384 385 Specifying a parameter by its index is only supported for backward 386 compatibility, and we recommend to use the name instead. 387 """ 388 self.wd_mult = {} 389 for n in self.idx2name.values(): 390 if not (n.endswith('_weight') or n.endswith('_gamma')): 391 self.wd_mult[n] = 0.0 392 if self.sym_info: 393 attr, arg_names = self.sym_info 394 for name in arg_names: 395 if name in attr and '__wd_mult__' in attr[name]: 396 self.wd_mult[name] = float(attr[name]['__wd_mult__']) 397 self.wd_mult.update(args_wd_mult) 398 399 def _set_current_context(self, device_id): 400 """Sets the number of the currently handled device. 401 402 Parameters 403 ---------- 404 device_id : int 405 The number of current device. 406 """ 407 if device_id not in self._all_index_update_counts: 408 self._all_index_update_counts[device_id] = {} 409 self._index_update_count = self._all_index_update_counts[device_id] 410 411 def _update_count(self, index): 412 """Updates num_update. 413 414 Parameters 415 ---------- 416 index : int or list of int 417 The index to be updated. 418 """ 419 if not isinstance(index, (list, tuple)): 420 index = [index] 421 for idx in index: 422 if idx not in self._index_update_count: 423 self._index_update_count[idx] = self.begin_num_update 424 self._index_update_count[idx] += 1 425 self.num_update = max(self._index_update_count[idx], self.num_update) 426 427 def _get_lrs(self, indices): 428 """Gets the learning rates given the indices of the weights. 429 430 Parameters 431 ---------- 432 indices : list of int 433 Indices corresponding to weights. 434 435 Returns 436 ------- 437 lrs : list of float 438 Learning rates for those indices. 439 """ 440 if self.lr_scheduler is not None: 441 lr = self.lr_scheduler(self.num_update) 442 else: 443 lr = self.lr 444 445 lrs = [lr for _ in indices] 446 for i, index in enumerate(indices): 447 if index in self.param_dict: 448 lrs[i] *= self.param_dict[index].lr_mult 449 elif index in self.lr_mult: 450 lrs[i] *= self.lr_mult[index] 451 elif index in self.idx2name: 452 lrs[i] *= self.lr_mult.get(self.idx2name[index], 1.0) 453 return lrs 454 455 def _get_lr(self, index): 456 """Gets the learning rate given the index of the weight. 457 458 Parameters 459 ---------- 460 index : int 461 The index corresponding to the weight. 462 463 Returns 464 ------- 465 lr : float 466 Learning rate for this index. 467 """ 468 return self._get_lrs([index])[0] 469 470 def _get_wds(self, indices): 471 """Gets weight decays for indices. 472 Returns 0 for non-weights if the name of weights are provided for `__init__`. 473 474 Parameters 475 ---------- 476 indices : list of int 477 Indices of weights. 478 479 Returns 480 ------- 481 wds : list of float 482 Weight decays for those indices. 483 """ 484 wds = [self.wd for _ in indices] 485 for i, index in enumerate(indices): 486 if index in self.param_dict: 487 wds[i] *= self.param_dict[index].wd_mult 488 elif index in self.wd_mult: 489 wds[i] *= self.wd_mult[index] 490 elif index in self.idx2name: 491 wds[i] *= self.wd_mult.get(self.idx2name[index], 1.0) 492 return wds 493 494 def _get_wd(self, index): 495 """Gets weight decay for index. 496 Returns 0 for non-weights if the name of weights are provided for `__init__`. 497 498 Parameters 499 ---------- 500 index : int 501 The index of weight. 502 503 Returns 504 ------- 505 wd : float 506 Weight decay for this index. 507 """ 508 return self._get_wds([index])[0] 509 510 def __getstate__(self): 511 ret = self.__dict__.copy() 512 # do not include param_dict in the state 513 del ret['param_dict'] 514 return ret 515 516 def __setstate__(self, state): 517 self.__dict__ = state 518 # param_dict needs to be explicitly set by the trainer 519 self.param_dict = {} 520 521# convenience wrapper for Optimizer.Register 522register = Optimizer.register # pylint: disable=invalid-name 523 524# pylint: disable=line-too-long 525@register 526class SGD(Optimizer): 527 """The SGD optimizer with momentum and weight decay. 528 529 If the storage types of grad is ``row_sparse`` and ``lazy_update`` is True, \ 530 **lazy updates** are applied by:: 531 532 for row in grad.indices: 533 rescaled_grad[row] = lr * (rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row]) 534 state[row] = momentum[row] * state[row] + rescaled_grad[row] 535 weight[row] = weight[row] - state[row] 536 537 The sparse update only updates the momentum for the weights whose row_sparse 538 gradient indices appear in the current batch, rather than updating it for all 539 indices. Compared with the original update, it can provide large 540 improvements in model training throughput for some applications. However, it 541 provides slightly different semantics than the original update, and 542 may lead to different empirical results. 543 544 In the case when ``update_on_kvstore`` is set to False (either globally via 545 MXNET_UPDATE_ON_KVSTORE=0 environment variable or as a parameter in 546 :class:`~mxnet.gluon.Trainer`) SGD optimizer can perform aggregated update 547 of parameters, which may lead to improved performance. The aggregation size 548 is controlled by MXNET_OPTIMIZER_AGGREGATION_SIZE environment variable and 549 defaults to 4. 550 551 Otherwise, **standard updates** are applied by:: 552 553 rescaled_grad = lr * (rescale_grad * clip(grad, clip_gradient) + wd * weight) 554 state = momentum * state + rescaled_grad 555 weight = weight - state 556 557 For details of the update algorithm see 558 :class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`. 559 560 This optimizer accepts the following parameters in addition to those accepted 561 by :class:`.Optimizer`. 562 563 Parameters 564 ---------- 565 momentum : float, optional 566 The momentum value. 567 lazy_update : bool, optional 568 Default is True. If True, lazy updates are applied \ 569 if the storage types of weight and grad are both ``row_sparse``. 570 multi_precision: bool, optional 571 Flag to control the internal precision of the optimizer. 572 False: results in using the same precision as the weights (default), 573 True: makes internal 32-bit copy of the weights and applies gradients 574 in 32-bit precision even if actual weights used in the model have lower precision. 575 Turning this on can improve convergence and accuracy when training with float16. 576 """ 577 def __init__(self, momentum=0.0, lazy_update=True, **kwargs): 578 super(SGD, self).__init__(**kwargs) 579 self.momentum = momentum 580 self.lazy_update = lazy_update 581 self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4")) 582 583 def create_state_multi_precision(self, index, weight): 584 weight_master_copy = None 585 if self.multi_precision and weight.dtype == numpy.float16: 586 weight_master_copy = weight.astype(numpy.float32) 587 return (self.create_state(index, weight_master_copy), weight_master_copy) 588 if weight.dtype == numpy.float16 and not self.multi_precision: 589 warnings.warn("Accumulating with float16 in optimizer can lead to " 590 "poor accuracy or slow convergence. " 591 "Consider using multi_precision=True option of the " 592 "SGD optimizer") 593 return self.create_state(index, weight) 594 595 def create_state(self, index, weight): 596 momentum = None 597 if self.momentum != 0.0: 598 stype = weight.stype if self.lazy_update else 'default' 599 momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype) 600 return momentum 601 602 def _update_impl(self, indices, weights, grads, states, multi_precision=False): 603 aggregate = True 604 if not isinstance(indices, (tuple, list)): 605 indices = [indices] 606 weights = [weights] 607 grads = [grads] 608 states = [states] 609 for weight, grad in zip(weights, grads): 610 assert(isinstance(weight, NDArray)) 611 assert(isinstance(grad, NDArray)) 612 aggregate = (aggregate and 613 weight.stype == 'default' and 614 grad.stype == 'default') 615 self._update_count(indices) 616 lrs = self._get_lrs(indices) 617 wds = self._get_wds(indices) 618 619 kwargs = {'rescale_grad': self.rescale_grad} 620 if self.momentum > 0: 621 kwargs['momentum'] = self.momentum 622 if self.clip_gradient: 623 kwargs['clip_gradient'] = self.clip_gradient 624 625 if aggregate: 626 if not multi_precision: 627 if self.momentum > 0: 628 multi_sgd_mom_update(*_flatten_list(zip(weights, grads, states)), out=weights, 629 num_weights=len(weights), lrs=lrs, wds=wds, **kwargs) 630 else: 631 multi_sgd_update(*_flatten_list(zip(weights, grads)), out=weights, 632 num_weights=len(weights), lrs=lrs, wds=wds, **kwargs) 633 else: 634 if self.momentum > 0: 635 multi_mp_sgd_mom_update(*_flatten_list(zip(weights, grads, *zip(*states))), 636 out=weights, num_weights=len(weights), 637 lrs=lrs, wds=wds, **kwargs) 638 else: 639 multi_mp_sgd_update(*_flatten_list(zip(weights, grads, 640 list(zip(*states))[1])), 641 out=weights, num_weights=len(weights), 642 lrs=lrs, wds=wds, **kwargs) 643 else: 644 for weight, grad, state, lr, wd in zip(weights, grads, states, lrs, wds): 645 if not multi_precision: 646 if state is not None: 647 sgd_mom_update(weight, grad, state, out=weight, 648 lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs) 649 else: 650 sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update, 651 lr=lr, wd=wd, **kwargs) 652 else: 653 if state[0] is not None: 654 mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, 655 lr=lr, wd=wd, **kwargs) 656 else: 657 mp_sgd_update(weight, grad, state[1], out=weight, 658 lr=lr, wd=wd, **kwargs) 659 660 def update(self, index, weight, grad, state): 661 self._update_impl(index, weight, grad, state, multi_precision=False) 662 663 def update_multi_precision(self, index, weight, grad, state): 664 if not isinstance(index, (tuple, list)): 665 use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 666 else: 667 use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16 668 self._update_impl(index, weight, grad, state, 669 multi_precision=use_multi_precision) 670 671@register 672class Signum(Optimizer): 673 r"""The Signum optimizer that takes the sign of gradient or momentum. 674 675 The optimizer updates the weight by:: 676 677 rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight 678 state = momentum * state + (1-momentum)*rescaled_grad 679 weight = (1 - lr * wd_lh) * weight - lr * sign(state) 680 681 References 682 ---------- 683 Jeremy Bernstein, Yu-Xiang Wang, Kamyar Azizzadenesheli & Anima Anandkumar. (2018). 684 signSGD: Compressed Optimisation for Non-Convex Problems. In ICML'18. 685 686 See: https://arxiv.org/abs/1802.04434 687 688 For details of the update algorithm see 689 :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`. 690 691 This optimizer accepts the following parameters in addition to those accepted 692 by :class:`.Optimizer`. 693 694 Parameters 695 ---------- 696 momentum : float, optional 697 The momentum value. 698 wd_lh : float, optional 699 The amount of decoupled weight decay regularization, see details in the original paper at:\ 700 https://arxiv.org/abs/1711.05101 701 """ 702 def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs): 703 super(Signum, self).__init__(learning_rate=learning_rate, **kwargs) 704 self.momentum = momentum 705 self.wd_lh = wd_lh 706 707 def create_state(self, index, weight): 708 momentum = None 709 if self.momentum != 0.0: 710 momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype) 711 return momentum 712 713 def _update_impl(self, index, weight, grad, state): 714 assert(isinstance(weight, NDArray)) 715 assert(isinstance(grad, NDArray)) 716 self._update_count(index) 717 lr = self._get_lr(index) 718 wd = self._get_wd(index) 719 720 kwargs = {'rescale_grad': self.rescale_grad} 721 if self.momentum > 0: 722 kwargs['momentum'] = self.momentum 723 if self.clip_gradient: 724 kwargs['clip_gradient'] = self.clip_gradient 725 if self.wd_lh: 726 kwargs['wd_lh'] = self.wd_lh 727 728 if state is not None: 729 signum_update(weight, grad, state, out=weight, 730 lr=lr, wd=wd, **kwargs) 731 else: 732 signsgd_update(weight, grad, out=weight, 733 lr=lr, wd=wd, **kwargs) 734 735 def update(self, index, weight, grad, state): 736 self._update_impl(index, weight, grad, state) 737 738@register 739class FTML(Optimizer): 740 """The FTML optimizer. 741 742 This class implements the optimizer described in 743 *FTML - Follow the Moving Leader in Deep Learning*, 744 available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf. 745 746 Denote time step by t. The optimizer updates the weight by:: 747 748 rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient) 749 v = beta2 * v + (1 - beta2) * square(rescaled_grad) 750 d_t = (1 - power(beta1, t)) / lr * square_root(v / (1 - power(beta2, t))) + epsilon) 751 z = beta1 * z + (1 - beta1) * rescaled_grad - (d_t - beta1 * d_(t-1)) * weight 752 weight = - z / d_t 753 754 For details of the update algorithm, see :class:`~mxnet.ndarray.ftml_update`. 755 756 This optimizer accepts the following parameters in addition to those accepted 757 by :class:`.Optimizer`. 758 759 Parameters 760 ---------- 761 beta1 : float, optional 762 0 < beta1 < 1. Generally close to 0.5. 763 beta2 : float, optional 764 0 < beta2 < 1. Generally close to 1. 765 epsilon : float, optional 766 Small value to avoid division by 0. 767 """ 768 def __init__(self, beta1=0.6, beta2=0.999, epsilon=1e-8, **kwargs): 769 super(FTML, self).__init__(**kwargs) 770 self.beta1 = beta1 771 self.beta2 = beta2 772 self.epsilon = epsilon 773 774 def create_state(self, index, weight): 775 return (zeros(weight.shape, weight.context, dtype=weight.dtype), # d_0 776 zeros(weight.shape, weight.context, dtype=weight.dtype), # v_0 777 zeros(weight.shape, weight.context, dtype=weight.dtype)) # z_0 778 779 def update(self, index, weight, grad, state): 780 assert(isinstance(weight, NDArray)) 781 assert(isinstance(grad, NDArray)) 782 self._update_count(index) 783 lr = self._get_lr(index) 784 wd = self._get_wd(index) 785 t = self._index_update_count[index] 786 787 kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, 788 'rescale_grad': self.rescale_grad, 't': t} 789 if self.clip_gradient: 790 kwargs['clip_grad'] = self.clip_gradient 791 792 prev_d, prev_v, prev_z = state 793 ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight, 794 lr=lr, wd=wd, **kwargs) 795 796@register 797class LARS(Optimizer): 798 """the LARS optimizer from 'Large Batch Training of Convolution Networks' \ 799 (https://arxiv.org/abs/1708.03888) 800 801 Behave mostly like SGD with momentum and weight decay but is scaling \ 802 adaptively the learning for each layer (except bias and batch norm parameters): 803 w_norm = L2norm(weights) 804 g_norm = L2norm(gradients) 805 if w_norm > 0 and g_norm > 0: 806 lr_layer = lr * lr_mult * eta * w_norm / (g_norm + weight_decay * w_norm + eps) 807 else: 808 lr_layer = lr * lr_mult 809 810 Parameters 811 ---------- 812 momentum : float, optional 813 The momentum value. 814 lazy_update : bool, optional 815 Default is True. If True, lazy updates are applied \ 816 if the storage types of weight and grad are both ``row_sparse``. 817 lars_eta : float, optional 818 LARS coefficient used to scale the learning rate. Default set to 0.001. 819 lars_epsilon : float, optional 820 Optional epsilon in case of very small gradients. Default set to 0. 821 momentum_correction : bool, optional 822 If True scale momentum w.r.t global learning rate change (with an lr_scheduler) \ 823 as indicated in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour` \ 824 (https://arxiv.org/pdf/1706.02677.pdf) 825 Default set to True. 826 """ 827 def __init__(self, momentum=0.0, lazy_update=True, eta=0.001, eps=0, 828 momentum_correction=True, **kwargs): 829 super(LARS, self).__init__(**kwargs) 830 self.momentum = momentum 831 self.momentum_correction = momentum_correction 832 self.lazy_update = lazy_update 833 self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4")) 834 self.eta = eta 835 self.eps = eps 836 self.skip = 0 837 self.last_lr = None 838 self.cur_lr = None 839 840 841 def _get_lrs(self, indices): 842 """Gets the learning rates given the indices of the weights. 843 844 Parameters 845 ---------- 846 indices : list of int 847 Indices corresponding to weights. 848 849 Returns 850 ------- 851 lrs : list of float 852 Learning rates for those indices. 853 """ 854 if self.cur_lr is not None: 855 self.last_lr = self.cur_lr 856 857 if self.lr_scheduler is not None: 858 lr = self.lr_scheduler(self.num_update) 859 else: 860 lr = self.lr 861 862 if self.cur_lr is None: 863 self.last_lr = lr 864 self.cur_lr = lr 865 866 lrs = [lr for _ in indices] 867 for i, index in enumerate(indices): 868 if index in self.param_dict: 869 lrs[i] *= self.param_dict[index].lr_mult 870 elif index in self.lr_mult: 871 lrs[i] *= self.lr_mult[index] 872 elif index in self.idx2name: 873 lrs[i] *= self.lr_mult.get(self.idx2name[index], 1.0) 874 return lrs 875 876 def set_wd_mult(self, args_wd_mult): 877 self.wd_mult = {} 878 for n in self.idx2name.values(): 879 is_weight = n.endswith('_weight') 880 881 if not is_weight: 882 self.wd_mult[n] = 0.0 883 884 if self.sym_info: 885 attr, arg_names = self.sym_info 886 for name in arg_names: 887 if name in attr and '__wd_mult__' in attr[name]: 888 self.wd_mult[name] = float(attr[name]['__wd_mult__']) 889 self.wd_mult.update(args_wd_mult) 890 891 def create_state_multi_precision(self, index, weight): 892 weight_master_copy = None 893 if self.multi_precision and weight.dtype == numpy.float16: 894 weight_master_copy = weight.astype(numpy.float32) 895 return (self.create_state(index, weight_master_copy), weight_master_copy) 896 if weight.dtype == numpy.float16 and not self.multi_precision: 897 warnings.warn("Accumulating with float16 in optimizer can lead to " 898 "poor accuracy or slow convergence. " 899 "Consider using multi_precision=True option of the " 900 "SGD optimizer") 901 return self.create_state(index, weight) 902 903 def create_state(self, index, weight): 904 momentum = None 905 if self.momentum != 0.0: 906 stype = weight.stype if self.lazy_update else 'default' 907 momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype) 908 return momentum 909 910 def _l2norm(self, v, rescale=False): 911 """L2 Norm implementation""" 912 v = v.astype('float32') 913 if rescale: 914 v *= self.rescale_grad 915 norm = NDnorm(v).asnumpy()[0] 916 return norm 917 918 def _get_lars(self, i, weight, g, lr, wd): 919 """Returns a scaling factor for the learning rate for this layer""" 920 name = self.idx2name[i] if i in self.idx2name else str(i) 921 if name.endswith('gamma') or name.endswith('beta') or name.endswith('bias'): 922 return lr 923 924 w_norm = self._l2norm(weight) 925 g_norm = self._l2norm(g, rescale=True) 926 927 if w_norm > 0.0 and g_norm > 0.0: 928 lars = self.eta * w_norm/(g_norm + wd * w_norm + self.eps) 929 else: 930 lars = 1.0 931 return lars * lr 932 933 def _update_impl(self, indices, weights, grads, states, multi_precision=False): 934 aggregate = True 935 if not isinstance(indices, (tuple, list)): 936 indices = [indices] 937 weights = [weights] 938 grads = [grads] 939 states = [states] 940 for weight, grad in zip(weights, grads): 941 assert(isinstance(weight, NDArray)) 942 assert(isinstance(grad, NDArray)) 943 aggregate = (aggregate and 944 weight.stype == 'default' and 945 grad.stype == 'default') 946 self._update_count(indices) 947 lrs = self._get_lrs(indices) 948 wds = self._get_wds(indices) 949 950 kwargs = {'rescale_grad': self.rescale_grad} 951 if self.momentum > 0: 952 kwargs['momentum'] = (self.momentum * (self.cur_lr / self.last_lr)) \ 953 if (self.momentum_correction and self.last_lr != 0) else \ 954 self.momentum 955 956 if self.clip_gradient: 957 kwargs['clip_gradient'] = self.clip_gradient 958 959 if aggregate: 960 nb_params = len(indices) 961 names = [self.idx2name[i] if i in self.idx2name else str(i) for i in indices] 962 lars_idx = [i for i in range(nb_params) if 963 not(names[i].endswith('gamma') or names[i].endswith('beta') or 964 names[i].endswith('bias'))] 965 nb_lars = len(lars_idx) 966 no_lars_idx = [i for i in range(nb_params) if 967 (names[i].endswith('gamma') or names[i].endswith('beta') or 968 names[i].endswith('bias'))] 969 cur_ctx = weights[0].context 970 full_idx = lars_idx + no_lars_idx 971 new_lrs = array([lrs[i] for i in full_idx], ctx=cur_ctx, dtype='float32') 972 new_wds = array([wds[i] for i in full_idx], ctx=cur_ctx, dtype='float32') 973 new_weights = [weights[i] for i in full_idx] 974 new_grads = [grads[i] for i in full_idx] 975 new_states = [states[i] for i in full_idx] 976 if nb_lars > 0: 977 w_sum_sq = multi_sum_sq(*new_weights[:nb_lars], num_arrays=nb_lars) 978 g_sum_sq = multi_sum_sq(*new_grads[:nb_lars], num_arrays=nb_lars) 979 multi_lars(new_lrs[:nb_lars], w_sum_sq, g_sum_sq, new_wds[:nb_lars], 980 eta=self.eta, eps=self.eps, rescale_grad=self.rescale_grad, 981 out=new_lrs[:nb_lars]) 982 # Same than usual using preloaded sgd functions 983 sidx = 0 984 while sidx < len(indices): 985 eidx = sidx + len(new_weights[sidx:sidx+self.aggregate_num]) 986 if not multi_precision: 987 if self.momentum > 0: 988 preloaded_multi_sgd_mom_update( 989 *(_flatten_list(zip(new_weights[sidx:eidx], 990 new_grads[sidx:eidx], 991 new_states[sidx:eidx])) + 992 [new_lrs[sidx:eidx], new_wds[sidx:eidx]]), 993 out=new_weights[sidx:eidx], 994 num_weights=len(new_weights[sidx:eidx]), 995 **kwargs) 996 else: 997 preloaded_multi_sgd_update( 998 *(_flatten_list(zip(new_weights[sidx:eidx], 999 new_grads[sidx:eidx])) + 1000 [new_lrs[sidx:eidx], new_wds[sidx:eidx]]), 1001 out=new_weights[sidx:eidx], 1002 num_weights=len(new_weights[sidx:eidx]), 1003 **kwargs) 1004 else: 1005 if self.momentum > 0: 1006 preloaded_multi_mp_sgd_mom_update( 1007 *(_flatten_list(zip(new_weights[sidx:eidx], 1008 new_grads[sidx:eidx], 1009 *zip(*new_states[sidx:eidx]))) + 1010 [new_lrs[sidx:eidx], new_wds[sidx:eidx]]), 1011 out=new_weights[sidx:eidx], 1012 num_weights=len(new_weights[sidx:eidx]), 1013 **kwargs) 1014 else: 1015 preloaded_multi_mp_sgd_update( 1016 *(_flatten_list(zip(new_weights[sidx:eidx], 1017 new_grads[sidx:eidx], 1018 list(zip(*new_states[sidx:eidx]))[1])) + 1019 [new_lrs[sidx:eidx], new_wds[sidx:eidx]]), 1020 out=new_weights[sidx:eidx], 1021 num_weights=len(new_weights[sidx:eidx]), 1022 **kwargs) 1023 sidx += self.aggregate_num 1024 else: 1025 lrs = [self._get_lars(i, w, g, lr, wd) for (i, w, g, lr, wd) in 1026 zip(indices, weights, grads, lrs, wds)] 1027 1028 for weight, grad, state, lr, wd in zip(weights, grads, states, lrs, wds): 1029 if not multi_precision: 1030 if state is not None: 1031 sgd_mom_update(weight, grad, state, out=weight, 1032 lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs) 1033 else: 1034 sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update, 1035 lr=lr, wd=wd, **kwargs) 1036 else: 1037 if state[0] is not None: 1038 mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, 1039 lr=lr, wd=wd, **kwargs) 1040 else: 1041 mp_sgd_update(weight, grad, state[1], out=weight, 1042 lr=lr, wd=wd, **kwargs) 1043 1044 def update(self, index, weight, grad, state): 1045 self._update_impl(index, weight, grad, state, multi_precision=False) 1046 1047 def update_multi_precision(self, index, weight, grad, state): 1048 if not isinstance(index, (tuple, list)): 1049 use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 1050 else: 1051 use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16 1052 self._update_impl(index, weight, grad, state, 1053 multi_precision=use_multi_precision) 1054 1055# 1056@register 1057class LBSGD(Optimizer): 1058 """The Large Batch SGD optimizer with momentum and weight decay. 1059 1060 The optimizer updates the weight by:: 1061 1062 state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight 1063 weight = weight - state 1064 1065 For details of the update algorithm see :class:`~mxnet.ndarray.sgd_update` 1066 and :class:`~mxnet.ndarray.sgd_mom_update`. 1067 In addition to the SGD updates the LBSGD optimizer uses the LARS, Layer-wise 1068 Adaptive Rate Scaling, algorithm to have a separate learning rate for each 1069 layer of the network, which leads to better stability over large batch sizes. 1070 1071 This optimizer accepts the following parameters in addition to those accepted 1072 by :class:`.Optimizer`. 1073 1074 Parameters 1075 ---------- 1076 momentum : float, optional 1077 The momentum value. 1078 multi_precision: bool, optional 1079 Flag to control the internal precision of the optimizer. 1080 False: results in using the same precision as the weights (default), 1081 True: makes internal 32-bit copy of the weights and applies gradients 1082 in 32-bit precision even if actual weights used in the model have lower precision. 1083 Turning this on can improve convergence and accuracy when training with float16. 1084 1085 warmup_strategy: string ('linear', 'power2', 'sqrt'. , 'lars' default : 'linear') 1086 warmup_epochs: unsigned, default: 5 1087 batch_scale: unsigned, default: 1 (same as batch size * numworkers) 1088 updates_per_epoch: updates_per_epoch (default: 32, Default might not reflect true number batches per epoch. Used for warmup.) 1089 begin_epoch: unsigned, default 0, starting epoch. 1090 """ 1091 def __init__(self, momentum=0.0, multi_precision=False, warmup_strategy='linear', 1092 warmup_epochs=5, batch_scale=1, updates_per_epoch=32, begin_epoch=0, num_epochs=60, 1093 **kwargs): 1094 super(LBSGD, self).__init__(**kwargs) 1095 logging.info('Running Large-Batch SGD Algorithm') 1096 logging.info('(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)', 1097 batch_scale, warmup_epochs, warmup_strategy, updates_per_epoch) 1098 self.momentum = momentum 1099 self.multi_precision = multi_precision 1100 # new user parameters for large batch 1101 self.warmup_strategy = warmup_strategy 1102 self.warmup_epochs = warmup_epochs 1103 self.batch_scale = batch_scale 1104 self.updates_per_epoch = updates_per_epoch 1105 self.init_updates = begin_epoch * updates_per_epoch 1106 self.num_epochs = num_epochs 1107 # addl internal usage parameters and storage 1108 self.lbmult = 1 1109 self.cumgrads = {} 1110 # for adaptive lr 1111 self.adaptive = False 1112 self.admult = 1 # adaptation constant 1113 1114 def create_state(self, index, weight): 1115 momentum = None 1116 weight_master_copy = None 1117 if self.multi_precision and weight.dtype == numpy.float16: 1118 weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32) 1119 if self.momentum != 0.0: 1120 momentum = zeros(weight.shape, weight.context, dtype=numpy.float32, 1121 stype=weight.stype) 1122 return (momentum, weight_master_copy) 1123 if weight.dtype == numpy.float16 and not self.multi_precision: 1124 warnings.warn("Accumulating with float16 in optimizer can lead to " 1125 "poor accuracy or slow convergence. " 1126 "Consider using multi_precision=True option of the " 1127 "SGD optimizer") 1128 if self.momentum != 0.0: 1129 momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype) 1130 return momentum 1131 1132 def _get_lbmult(self, nup): 1133 """Returns lr scaling factor for large batch according to warmup schedule 1134 (to be implemented) 1135 """ 1136 nwup = self.warmup_epochs * self.updates_per_epoch 1137 strategy = self.warmup_strategy 1138 maxmult = float(self.batch_scale) 1139 if nup >= nwup: 1140 mult = maxmult 1141 elif nwup <= 1: 1142 mult = 1.0 1143 else: 1144 if (strategy == 'linear'): 1145 mult = 1.0 + (maxmult - 1) * nup / nwup 1146 elif (strategy == 'power2'): 1147 mult = 1.0 + (maxmult-1) * (nup*nup)/(nwup*nwup) 1148 elif (strategy == 'sqrt'): 1149 mult = 1.0 + (maxmult - 1) * math.sqrt(float(nup) / nwup) 1150 else: 1151 mult = 1.0 1152 return mult 1153 1154 def _get_lars(self, weight, g, wd): 1155 """Returns a scaling factor for the learning rate for this layer 1156 default is 1 1157 """ 1158 weight2 = self._l2norm(weight) 1159 grad2 = self._l2norm(g) 1160 lars = math.sqrt(weight2 / (grad2 + wd * weight2 + 1e-18)) 1161 if lars < 0.01: 1162 lars = 0.01 1163 elif lars > 100: 1164 lars = 100 1165 return lars 1166 1167 def _l2norm(self, v): 1168 "inner product implementation" 1169 norm = multiply(v, v).asnumpy().sum() 1170 return norm 1171 1172 def _reset_cum_gradient(self, index): 1173 "called every macro-batch to reset cumulated gradients to 0 for a given index" 1174 self.cumgrads[index]['cum_grad'] = 0 1175 1176 def _get_cum_gradient(self, index): 1177 "get the cumulated gradient for index" 1178 if index in self.cumgrads: 1179 return self.cumgrads[index] 1180 else: 1181 return {} 1182 1183 def _put_cum_gradient(self, index, cgrad): 1184 "store cumulated gradient for index" 1185 self.cumgrads[index] = cgrad 1186 1187 def _cumulate_gradient(self, grad, index): 1188 "Cumulate gradients for large-batch emulation. Cumulated by index (layer)" 1189 cgrad = self._get_cum_gradient(index) 1190 if cgrad: 1191 num_cums = cgrad['num_cums'] 1192 if num_cums > 0: 1193 cum_grad = cgrad['cum_grad'] + grad 1194 num_cums += 1 1195 else: 1196 cum_grad = grad 1197 num_cums = self.init_updates + 1 1198 else: 1199 cum_grad = grad 1200 num_cums = self.init_updates + 1 1201 cgrad = {'cum_grad': cum_grad, 'num_cums': num_cums} 1202 self._put_cum_gradient(index, cgrad) 1203 return cgrad 1204 1205 def update(self, index, weight, grad, state): 1206 assert (isinstance(weight, NDArray)) 1207 assert (isinstance(grad, NDArray)) 1208 1209 lr = self._get_lr(index) 1210 wd = self._get_wd(index) 1211 self._update_count(index) 1212 1213 # new stuff for large batch 1214 cgrad = self._cumulate_gradient(grad, index) 1215 if (cgrad['num_cums'] % self.batch_scale) == 0: 1216 grad = cgrad['cum_grad'] / self.batch_scale 1217 if self.warmup_strategy == 'lars': 1218 lbmult = self._get_lars(weight, grad, wd) 1219 else: 1220 lbmult = self._get_lbmult(cgrad['num_cums']) 1221 lr = lr * lbmult 1222 # do the regular sgd update flow 1223 kwargs = {'rescale_grad': self.rescale_grad} 1224 if self.momentum > 0: 1225 kwargs['momentum'] = self.momentum 1226 if self.clip_gradient: 1227 kwargs['clip_gradient'] = self.clip_gradient 1228 use_multi_precision = isinstance(state, (list, tuple)) 1229 1230 if not use_multi_precision: 1231 if state is not None: 1232 sgd_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs) 1233 else: 1234 sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) 1235 else: 1236 if state[0] is not None: 1237 mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, lr=lr, wd=wd, 1238 **kwargs) 1239 else: 1240 mp_sgd_update(weight, grad, state[1], out=weight, lr=lr, wd=wd, **kwargs) 1241 # reset update count and cumulated gradient per large batch 1242 self._reset_cum_gradient(index) 1243 else: 1244 lr = 0.0 1245 kwargs = {} 1246 sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) 1247 1248 1249@register 1250class LAMB(Optimizer): 1251 """LAMB Optimizer. 1252 """ 1253 def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6, 1254 lower_bound=None, upper_bound=None, bias_correction=True, **kwargs): 1255 super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs) 1256 self.beta1 = beta1 1257 self.beta2 = beta2 1258 self.epsilon = epsilon 1259 self.lower_bound = lower_bound 1260 self.upper_bound = upper_bound 1261 self.bias_correction = bias_correction 1262 self.aggregate_num = max(1, min(45, int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "45")))) 1263 1264 def create_state(self, index, weight): 1265 stype = weight.stype 1266 dtype = weight.dtype 1267 return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype), 1268 zeros(weight.shape, weight.context, dtype=dtype, stype=stype)) 1269 1270 def _update_impl(self, index, weight, grad, state, multi_precision=False): 1271 kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, 1272 'bias_correction': self.bias_correction, 1273 'rescale_grad': self.rescale_grad} 1274 1275 if self.aggregate_num <= 1 or not isinstance(index, (tuple, list)): 1276 if isinstance(index, (tuple, list)): 1277 assert(len(index) == self.aggregate_num) 1278 index, weight, grad, state = index[0], weight[0], grad[0], state[0] 1279 assert(isinstance(weight, NDArray)) 1280 assert(isinstance(grad, NDArray)) 1281 self._update_count(index) 1282 lr = self._get_lr(index) 1283 wd = self._get_wd(index) 1284 t = self._index_update_count[index] 1285 weight_ptr = weight 1286 grad_ptr = grad 1287 if multi_precision: 1288 mean, var = state[1] 1289 weight32 = state[0] 1290 else: 1291 mean, var = state 1292 kwargs['t'] = t 1293 if self.clip_gradient: 1294 kwargs['clip_gradient'] = self.clip_gradient 1295 1296 if multi_precision: 1297 g = mp_lamb_update_phase1(weight_ptr, grad_ptr, mean, var, weight32, wd=wd, **kwargs) 1298 kwargs = {} 1299 if self.lower_bound: 1300 kwargs['lower_bound'] = self.lower_bound 1301 if self.upper_bound: 1302 kwargs['upper_bound'] = self.upper_bound 1303 r_1 = weight32.norm() 1304 r_2 = g.norm() 1305 mp_lamb_update_phase2(weight_ptr, g, r_1, r_2, weight32, lr=lr, out=weight_ptr, **kwargs) 1306 else: 1307 g = lamb_update_phase1(weight_ptr, grad_ptr, mean, var, wd=wd, **kwargs) 1308 kwargs = {} 1309 if self.lower_bound: 1310 kwargs['lower_bound'] = self.lower_bound 1311 if self.upper_bound: 1312 kwargs['upper_bound'] = self.upper_bound 1313 r_1 = weight_ptr.norm() 1314 r_2 = g.norm() 1315 lamb_update_phase2(weight_ptr, g, r_1, r_2, lr=lr, out=weight_ptr, **kwargs) 1316 else: 1317 if self.clip_gradient: 1318 kwargs['clip_gradient'] = self.clip_gradient 1319 if self.lower_bound: 1320 kwargs['lower_bound'] = self.lower_bound 1321 if self.upper_bound: 1322 kwargs['upper_bound'] = self.upper_bound 1323 1324 step_count, lrs, wds = [], [], [] 1325 for i, w_i, g_i in zip(index, weight, grad): 1326 assert(isinstance(w_i, NDArray)) 1327 assert(isinstance(g_i, NDArray)) 1328 self._update_count(i) 1329 step_count.append(self._index_update_count[i]) 1330 lrs.append(self._get_lr(i)) 1331 wds.append(self._get_wd(i)) 1332 1333 updated_tensors = 0 1334 while updated_tensors < len(weight): 1335 sidx = updated_tensors 1336 eidx = min(updated_tensors + self.aggregate_num, len(weight)) 1337 if not multi_precision: 1338 mean, var = list(zip(*state[sidx:eidx])) 1339 multi_lamb_update(weight[sidx:eidx], 1340 grad[sidx:eidx], 1341 mean, var, 1342 out=weight[sidx:eidx], 1343 step_count=step_count[sidx:eidx], 1344 lrs=lrs[sidx:eidx], 1345 wds=wds[sidx:eidx], 1346 **kwargs) 1347 else: 1348 mean_var = list(zip(*state[sidx:eidx]))[1] 1349 temp = list(zip(*mean_var)) 1350 mean = temp[0] 1351 var = temp[1] 1352 multi_mp_lamb_update(weight[sidx:eidx], 1353 grad[sidx:eidx], 1354 mean, var, 1355 list(zip(*state[sidx:eidx]))[0], 1356 out=weight[sidx:eidx], 1357 step_count=step_count[sidx:eidx], 1358 lrs=lrs[sidx:eidx], 1359 wds=wds[sidx:eidx], 1360 **kwargs) 1361 updated_tensors += self.aggregate_num 1362 1363 def update(self, index, weight, grad, state): 1364 self._update_impl(index, weight, grad, state, multi_precision=False) 1365 1366 def update_multi_precision(self, index, weight, grad, state): 1367 if not isinstance(index, (tuple, list)): 1368 use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 1369 else: 1370 use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16 1371 self._update_impl(index, weight, grad, state, 1372 multi_precision=use_multi_precision) 1373 1374# pylint: enable=line-too-long 1375@register 1376class DCASGD(Optimizer): 1377 """The DCASGD optimizer. 1378 1379 This class implements the optimizer described in *Asynchronous Stochastic Gradient Descent 1380 with Delay Compensation for Distributed Deep Learning*, 1381 available at https://arxiv.org/abs/1609.08326. 1382 1383 This optimizer accepts the following parameters in addition to those accepted 1384 by :class:`.Optimizer`. 1385 1386 Parameters 1387 ---------- 1388 momentum : float, optional 1389 The momentum value. 1390 1391 lamda : float, optional 1392 Scale DC value. 1393 """ 1394 def __init__(self, momentum=0.0, lamda=0.04, **kwargs): 1395 super(DCASGD, self).__init__(**kwargs) 1396 self.momentum = momentum 1397 self.weight_previous = {} 1398 self.lamda = lamda 1399 1400 def create_state(self, index, weight): 1401 if self.momentum == 0.0: 1402 return (None, 1403 weight.copy()) # previous weight 1404 else: 1405 return (zeros(weight.shape, weight.context, dtype=weight.dtype), # momentum 1406 weight.copy()) # previous weight 1407 1408 def update(self, index, weight, grad, state): 1409 assert(isinstance(weight, NDArray)) 1410 assert(isinstance(grad, NDArray)) 1411 self._update_count(index) 1412 lr = self._get_lr(index) 1413 wd = self._get_wd(index) 1414 1415 grad = grad * self.rescale_grad 1416 if self.clip_gradient is not None: 1417 grad = clip(grad, -self.clip_gradient, self.clip_gradient) 1418 1419 mom, previous_weight = state 1420 if mom: 1421 mom[:] *= self.momentum 1422 mom[:] += -lr * (grad + wd * weight + self.lamda \ 1423 * grad * grad * (weight - previous_weight)) 1424 else: 1425 assert(self.momentum == 0.0) 1426 mom = -lr * (grad + wd * weight + self.lamda \ 1427 * grad * grad * (weight - previous_weight)) 1428 previous_weight[:] = weight 1429 weight[:] += mom 1430 1431@register 1432class NAG(Optimizer): 1433 """Nesterov accelerated gradient. 1434 1435 This optimizer updates each weight by:: 1436 1437 state = momentum * state + grad + wd * weight 1438 weight = weight - (lr * (grad + momentum * state)) 1439 1440 Parameters 1441 ---------- 1442 momentum : float, optional 1443 The momentum value. 1444 multi_precision: bool, optional 1445 Flag to control the internal precision of the optimizer. 1446 False: results in using the same precision as the weights (default), 1447 True: makes internal 32-bit copy of the weights and applies gradients 1448 in 32-bit precision even if actual weights used in the model have lower precision. 1449 Turning this on can improve convergence and accuracy when training with float16. 1450 """ 1451 def __init__(self, momentum=0.0, **kwargs): 1452 super(NAG, self).__init__(**kwargs) 1453 self.momentum = momentum 1454 1455 def create_state_multi_precision(self, index, weight): 1456 weight_master_copy = None 1457 if self.multi_precision and weight.dtype == numpy.float16: 1458 weight_master_copy = weight.astype(numpy.float32) 1459 return (self.create_state(index, weight_master_copy), weight_master_copy) 1460 if weight.dtype == numpy.float16 and not self.multi_precision: 1461 warnings.warn("Accumulating with float16 in optimizer can lead to " 1462 "poor accuracy or slow convergence. " 1463 "Consider using multi_precision=True option of the " 1464 "NAG optimizer") 1465 return self.create_state(index, weight) 1466 1467 def create_state(self, index, weight): 1468 momentum = None 1469 if self.momentum != 0.0: 1470 momentum = zeros(weight.shape, weight.context, dtype=weight.dtype) 1471 return momentum 1472 1473 def _update_impl(self, index, weight, grad, state, multi_precision=False): 1474 assert(isinstance(weight, NDArray)) 1475 assert(isinstance(grad, NDArray)) 1476 self._update_count(index) 1477 lr = self._get_lr(index) 1478 wd = self._get_wd(index) 1479 1480 kwargs = {'rescale_grad': self.rescale_grad} 1481 if self.momentum > 0: 1482 kwargs['momentum'] = self.momentum 1483 if self.clip_gradient: 1484 kwargs['clip_gradient'] = self.clip_gradient 1485 1486 if not multi_precision: 1487 if state is not None: 1488 nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs) 1489 else: 1490 sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs) 1491 else: 1492 if state[0] is not None: 1493 mp_nag_mom_update(weight, grad, state[0], state[1], out=weight, 1494 lr=lr, wd=wd, **kwargs) 1495 else: 1496 mp_sgd_update(weight, grad, state[1], out=weight, 1497 lr=lr, wd=wd, **kwargs) 1498 1499 def update(self, index, weight, grad, state): 1500 self._update_impl(index, weight, grad, state, multi_precision=False) 1501 1502 def update_multi_precision(self, index, weight, grad, state): 1503 use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 \ 1504 and isinstance(state, (tuple, list)) 1505 self._update_impl(index, weight, grad, state, 1506 multi_precision=use_multi_precision) 1507 1508 1509@register 1510class SGLD(Optimizer): 1511 """Stochastic Gradient Riemannian Langevin Dynamics. 1512 1513 This class implements the optimizer described in the paper *Stochastic Gradient 1514 Riemannian Langevin Dynamics on the Probability Simplex*, available at 1515 https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf. 1516 1517 """ 1518 def __init__(self, **kwargs): 1519 super(SGLD, self).__init__(**kwargs) 1520 1521 def create_state(self, index, weight): 1522 return None 1523 1524 def update(self, index, weight, grad, state): 1525 assert(isinstance(weight, NDArray)) 1526 assert(isinstance(grad, NDArray)) 1527 self._update_count(index) 1528 lr = self._get_lr(index) 1529 wd = self._get_wd(index) 1530 1531 grad = grad * self.rescale_grad 1532 if self.clip_gradient is not None: 1533 grad = clip(grad, -self.clip_gradient, self.clip_gradient) 1534 weight[:] += - lr/2 * (grad + wd * weight) 1535 weight[:] += normal(0, math.sqrt(lr), shape=weight.shape, 1536 dtype=weight.dtype, ctx=weight.context) 1537 1538 1539 1540@register # pylint: disable=invalid-name 1541class ccSGD(SGD): 1542 """[DEPRECATED] Same as `SGD`. Left here for backward compatibility.""" 1543 def __init__(self, *args, **kwargs): 1544 super(ccSGD, self).__init__(*args, **kwargs) 1545 1546@register 1547class Adam(Optimizer): 1548 """The Adam optimizer. 1549 1550 This class implements the optimizer described in *Adam: A Method for 1551 Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980. 1552 1553 If the storage types of grad is ``row_sparse``, and ``lazy_update`` is True, \ 1554 **lazy updates** at step t are applied by:: 1555 1556 for row in grad.indices: 1557 rescaled_grad[row] = clip(grad[row] * rescale_grad + wd * weight[row], clip_gradient) 1558 m[row] = beta1 * m[row] + (1 - beta1) * rescaled_grad[row] 1559 v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2) 1560 lr = learning_rate * sqrt(1 - beta1**t) / (1 - beta2**t) 1561 w[row] = w[row] - lr * m[row] / (sqrt(v[row]) + epsilon) 1562 1563 The lazy update only updates the mean and var for the weights whose row_sparse 1564 gradient indices appear in the current batch, rather than updating it for all indices. 1565 Compared with the original update, it can provide large improvements in model training 1566 throughput for some applications. However, it provides slightly different semantics than 1567 the original update, and may lead to different empirical results. 1568 1569 Otherwise, **standard updates** at step t are applied by:: 1570 1571 rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient) 1572 m = beta1 * m + (1 - beta1) * rescaled_grad 1573 v = beta2 * v + (1 - beta2) * (rescaled_grad**2) 1574 lr = learning_rate * sqrt(1 - beta1**t) / (1 - beta2**t) 1575 w = w - lr * m / (sqrt(v) + epsilon) 1576 1577 This optimizer accepts the following parameters in addition to those accepted 1578 by :class:`.Optimizer`. 1579 1580 For details of the update algorithm, see :class:`~mxnet.ndarray.adam_update`. 1581 1582 Parameters 1583 ---------- 1584 beta1 : float, optional 1585 Exponential decay rate for the first moment estimates. 1586 beta2 : float, optional 1587 Exponential decay rate for the second moment estimates. 1588 epsilon : float, optional 1589 Small value to avoid division by 0. 1590 lazy_update : bool, optional 1591 Default is True. If True, lazy updates are applied \ 1592 if the storage types of weight and grad are both ``row_sparse``. 1593 """ 1594 def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, 1595 lazy_update=True, **kwargs): 1596 super(Adam, self).__init__(learning_rate=learning_rate, **kwargs) 1597 self.beta1 = beta1 1598 self.beta2 = beta2 1599 self.epsilon = epsilon 1600 self.lazy_update = lazy_update 1601 1602 def create_state(self, index, weight): 1603 stype = weight.stype if self.lazy_update else 'default' 1604 return (zeros(weight.shape, weight.context, dtype=weight.dtype, 1605 stype=stype), # mean 1606 zeros(weight.shape, weight.context, dtype=weight.dtype, 1607 stype=stype)) # variance 1608 1609 def update(self, index, weight, grad, state): 1610 assert(isinstance(weight, NDArray)) 1611 assert(isinstance(grad, NDArray)) 1612 self._update_count(index) 1613 lr = self._get_lr(index) 1614 wd = self._get_wd(index) 1615 1616 t = self._index_update_count[index] 1617 coef1 = 1. - self.beta1**t 1618 coef2 = 1. - self.beta2**t 1619 lr *= math.sqrt(coef2)/coef1 1620 1621 kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, 1622 'rescale_grad': self.rescale_grad} 1623 if self.clip_gradient: 1624 kwargs['clip_gradient'] = self.clip_gradient 1625 1626 mean, var = state 1627 adam_update(weight, grad, mean, var, out=weight, 1628 lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs) 1629 1630@register 1631class AdaGrad(Optimizer): 1632 """AdaGrad optimizer. 1633 1634 This class implements the AdaGrad optimizer described in *Adaptive Subgradient 1635 Methods for Online Learning and Stochastic Optimization*, and available at 1636 http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf. 1637 1638 This optimizer updates each weight by:: 1639 1640 grad = clip(grad * rescale_grad, clip_gradient) 1641 history += square(grad) 1642 div = grad / sqrt(history + float_stable_eps) 1643 weight += (div + weight * wd) * -lr 1644 1645 This optimizer accepts the following parameters in addition to those accepted 1646 by :class:`.Optimizer`. 1647 1648 See Also 1649 ---------- 1650 :meth:`mxnet.ndarray.sparse.adagrad_update`. 1651 1652 Parameters 1653 ---------- 1654 eps: float, optional 1655 Initial value of the history accumulator. Avoids division by 0. 1656 1657 """ 1658 def __init__(self, eps=1e-7, **kwargs): 1659 super(AdaGrad, self).__init__(**kwargs) 1660 self.float_stable_eps = eps 1661 1662 def create_state(self, index, weight): 1663 return zeros(weight.shape, weight.context, stype=weight.stype) # history 1664 1665 def update(self, index, weight, grad, state): 1666 assert(isinstance(weight, NDArray)) 1667 assert(isinstance(grad, NDArray)) 1668 self._update_count(index) 1669 lr = self._get_lr(index) 1670 wd = self._get_wd(index) 1671 1672 is_sparse = grad.stype == 'row_sparse' 1673 history = state 1674 1675 if is_sparse: 1676 kwargs = {'epsilon': self.float_stable_eps, 1677 'rescale_grad': self.rescale_grad} 1678 if self.clip_gradient: 1679 kwargs['clip_gradient'] = self.clip_gradient 1680 sparse.adagrad_update(weight, grad, history, out=weight, lr=lr, wd=wd, **kwargs) 1681 else: 1682 grad = grad * self.rescale_grad 1683 if self.clip_gradient is not None: 1684 grad = clip(grad, -self.clip_gradient, self.clip_gradient) 1685 history[:] += square(grad) 1686 div = grad / sqrt(history + self.float_stable_eps) 1687 weight[:] += (div + weight * wd) * -lr 1688 1689@register 1690class RMSProp(Optimizer): 1691 """The RMSProp optimizer. 1692 1693 Two versions of RMSProp are implemented: 1694 1695 If ``centered=False``, we follow 1696 http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf by 1697 Tieleman & Hinton, 2012. 1698 For details of the update algorithm see :class:`~mxnet.ndarray.rmsprop_update`. 1699 1700 If ``centered=True``, we follow http://arxiv.org/pdf/1308.0850v5.pdf (38)-(45) 1701 by Alex Graves, 2013. 1702 For details of the update algorithm see :class:`~mxnet.ndarray.rmspropalex_update`. 1703 1704 This optimizer accepts the following parameters in addition to those accepted 1705 by :class:`.Optimizer`. 1706 1707 Parameters 1708 ---------- 1709 gamma1: float, optional 1710 A decay factor of moving average over past squared gradient. 1711 gamma2: float, optional 1712 A "momentum" factor. Only used if `centered`=``True``. 1713 epsilon : float, optional 1714 Small value to avoid division by 0. 1715 centered : bool, optional 1716 Flag to control which version of RMSProp to use.:: 1717 1718 True: will use Graves's version of `RMSProp`, 1719 False: will use Tieleman & Hinton's version of `RMSProp`. 1720 1721 clip_weights : float, optional 1722 Clips weights into range ``[-clip_weights, clip_weights]``. 1723 """ 1724 def __init__(self, learning_rate=0.001, gamma1=0.9, gamma2=0.9, 1725 epsilon=1e-8, centered=False, clip_weights=None, **kwargs): 1726 super(RMSProp, self).__init__(learning_rate=learning_rate, **kwargs) 1727 self.gamma1 = gamma1 1728 self.gamma2 = gamma2 1729 self.centered = centered 1730 self.epsilon = epsilon 1731 self.clip_weights = clip_weights 1732 1733 def create_state(self, index, weight): 1734 if self.centered: 1735 return ( 1736 zeros(weight.shape, weight.context, stype=weight.stype), # n 1737 zeros(weight.shape, weight.context, stype=weight.stype), # g 1738 zeros(weight.shape, weight.context, stype=weight.stype)) # delta 1739 else: 1740 return (zeros(weight.shape, weight.context, stype=weight.stype),) # n 1741 1742 def update(self, index, weight, grad, state): 1743 assert(isinstance(weight, NDArray)) 1744 assert(isinstance(grad, NDArray)) 1745 self._update_count(index) 1746 lr = self._get_lr(index) 1747 wd = self._get_wd(index) 1748 1749 kwargs = {'gamma1': self.gamma1, 'epsilon': self.epsilon, 1750 'rescale_grad': self.rescale_grad} 1751 if self.centered: 1752 kwargs['gamma2'] = self.gamma2 1753 if self.clip_gradient: 1754 kwargs['clip_gradient'] = self.clip_gradient 1755 if self.clip_weights: 1756 kwargs['clip_weights'] = self.clip_weights 1757 1758 if not self.centered: 1759 (n, ) = state 1760 rmsprop_update( 1761 weight, grad, n, out=weight, lr=lr, wd=wd, **kwargs) 1762 else: 1763 n, g, delta = state 1764 rmspropalex_update(weight, grad, n, g, delta, out=weight, 1765 lr=lr, wd=wd, **kwargs) 1766 1767@register 1768class AdaDelta(Optimizer): 1769 """The AdaDelta optimizer. 1770 1771 This class implements AdaDelta, an optimizer described in *ADADELTA: An adaptive 1772 learning rate method*, available at https://arxiv.org/abs/1212.5701. 1773 1774 This optimizer updates each weight by:: 1775 1776 grad = clip(grad * rescale_grad + wd * weight, clip_gradient) 1777 acc_grad = rho * acc_grad + (1. - rho) * grad * grad 1778 delta = sqrt(acc_delta + epsilon) / sqrt(acc_grad + epsilon) * grad 1779 acc_delta = rho * acc_delta + (1. - rho) * delta * delta 1780 weight -= (delta + wd * weight) 1781 1782 This optimizer accepts the following parameters in addition to those accepted 1783 by :class:`.Optimizer`. 1784 1785 Parameters 1786 ---------- 1787 rho: float 1788 Decay rate for both squared gradients and delta. 1789 epsilon : float 1790 Small value to avoid division by 0. 1791 """ 1792 def __init__(self, rho=0.90, epsilon=1e-5, **kwargs): 1793 super(AdaDelta, self).__init__(**kwargs) 1794 self.rho = rho 1795 self.epsilon = epsilon 1796 1797 def create_state(self, index, weight): 1798 return (zeros(weight.shape, weight.context), # accumulated g 1799 zeros(weight.shape, weight.context)) # accumulated delta 1800 1801 def update(self, index, weight, grad, state): 1802 assert(isinstance(weight, NDArray)) 1803 assert(isinstance(grad, NDArray)) 1804 wd = self._get_wd(index) 1805 self._update_count(index) 1806 1807 # preprocess grad 1808 grad *= self.rescale_grad 1809 if self.clip_gradient is not None: 1810 grad = clip(grad, - self.clip_gradient, self.clip_gradient) 1811 1812 # accumulated g and delta initlization 1813 acc_g, acc_delta = state 1814 1815 # update g, delta 1816 acc_g[:] *= self.rho 1817 acc_g[:] += (1. - self.rho) * grad * grad 1818 current_delta = sqrt(acc_delta + self.epsilon) / sqrt(acc_g + self.epsilon) * grad 1819 acc_delta[:] *= self.rho 1820 acc_delta[:] += (1. - self.rho) * current_delta * current_delta 1821 1822 # update weight 1823 weight[:] -= current_delta + wd * weight 1824 1825#pylint: disable=invalid-name 1826#pylint: disable=line-too-long 1827@register 1828class Ftrl(Optimizer): 1829 """The Ftrl optimizer. 1830 1831 Referenced from *Ad Click Prediction: a View from the Trenches*, available at 1832 http://dl.acm.org/citation.cfm?id=2488200. 1833 1834 eta : 1835 .. math:: 1836 \\eta_{t,i} = \\frac{learningrate}{\\beta+\\sqrt{\\sum_{s=1}^tg_{s,i}^2}} 1837 1838 The optimizer updates the weight by:: 1839 1840 rescaled_grad = clip(grad * rescale_grad, clip_gradient) 1841 z += rescaled_grad - (sqrt(n + rescaled_grad**2) - sqrt(n)) * weight / learning_rate 1842 n += rescaled_grad**2 1843 w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1) 1844 1845 If the storage types of weight, state and grad are all ``row_sparse``, \ 1846 **sparse updates** are applied by:: 1847 1848 for row in grad.indices: 1849 rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient) 1850 z[row] += rescaled_grad[row] - (sqrt(n[row] + rescaled_grad[row]**2) - sqrt(n[row])) * weight[row] / learning_rate 1851 n[row] += rescaled_grad[row]**2 1852 w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1) 1853 1854 The sparse update only updates the z and n for the weights whose row_sparse 1855 gradient indices appear in the current batch, rather than updating it for all 1856 indices. Compared with the original update, it can provide large 1857 improvements in model training throughput for some applications. However, it 1858 provides slightly different semantics than the original update, and 1859 may lead to different empirical results. 1860 1861 For details of the update algorithm, see :class:`~mxnet.ndarray.ftrl_update`. 1862 1863 This optimizer accepts the following parameters in addition to those accepted 1864 by :class:`.Optimizer`. 1865 1866 Parameters 1867 ---------- 1868 lamda1 : float, optional 1869 L1 regularization coefficient. 1870 learning_rate : float, optional 1871 The initial learning rate. 1872 beta : float, optional 1873 Per-coordinate learning rate correlation parameter. 1874 """ 1875 1876 def __init__(self, lamda1=0.01, learning_rate=0.1, beta=1, **kwargs): 1877 super(Ftrl, self).__init__(**kwargs) 1878 self.lamda1 = lamda1 1879 self.beta = beta 1880 self.lr = learning_rate 1881 1882 def create_state(self, index, weight): 1883 return (zeros(weight.shape, weight.context, stype=weight.stype), # z 1884 zeros(weight.shape, weight.context, stype=weight.stype)) # n 1885 1886 def update(self, index, weight, grad, state): 1887 assert(isinstance(weight, NDArray)) 1888 assert(isinstance(grad, NDArray)) 1889 self._update_count(index) 1890 wd = self._get_wd(index) 1891 lr = self._get_lr(index) 1892 1893 kwargs = {'lamda1': self.lamda1, 'beta': self.beta, 'rescale_grad': self.rescale_grad} 1894 if self.clip_gradient: 1895 kwargs['clip_gradient'] = self.clip_gradient 1896 1897 # accumulated g and delta initialization 1898 z, n = state 1899 ftrl_update(weight, grad, z, n, out=weight, 1900 lr=lr, wd=wd, **kwargs) 1901 1902# pylint: enable=line-too-long 1903@register 1904class Adamax(Optimizer): 1905 """The AdaMax optimizer. 1906 1907 It is a variant of Adam based on the infinity norm 1908 available at http://arxiv.org/abs/1412.6980 Section 7. 1909 1910 The optimizer updates the weight by:: 1911 1912 grad = clip(grad * rescale_grad + wd * weight, clip_gradient) 1913 m = beta1 * m_t + (1 - beta1) * grad 1914 u = maximum(beta2 * u, abs(grad)) 1915 weight -= lr / (1 - beta1**t) * m / u 1916 1917 This optimizer accepts the following parameters in addition to those accepted 1918 by :class:`.Optimizer`. 1919 1920 Parameters 1921 ---------- 1922 beta1 : float, optional 1923 Exponential decay rate for the first moment estimates. 1924 beta2 : float, optional 1925 Exponential decay rate for the second moment estimates. 1926 """ 1927 def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999, **kwargs): 1928 super(Adamax, self).__init__(learning_rate=learning_rate, **kwargs) 1929 self.beta1 = beta1 1930 self.beta2 = beta2 1931 1932 def create_state(self, index, weight): 1933 return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean 1934 zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance 1935 1936 def update(self, index, weight, grad, state): 1937 assert(isinstance(weight, NDArray)) 1938 assert(isinstance(grad, NDArray)) 1939 self._update_count(index) 1940 lr = self._get_lr(index) 1941 wd = self._get_wd(index) 1942 1943 t = self._index_update_count[index] 1944 lr /= (1. - self.beta1**t) 1945 1946 # preprocess grad 1947 grad = grad * self.rescale_grad + wd * weight 1948 if self.clip_gradient is not None: 1949 grad = clip(grad, -self.clip_gradient, self.clip_gradient) 1950 1951 # update m_t and u_t 1952 m_t, u_t = state 1953 m_t[:] *= self.beta1 1954 m_t[:] += (1. - self.beta1) * grad 1955 u_t[:] = maximum(self.beta2 * u_t, NDabs(grad)) 1956 1957 # update weight 1958 weight[:] -= lr * m_t / u_t 1959 1960@register 1961class Nadam(Optimizer): 1962 """The Nesterov Adam optimizer. 1963 1964 Much like Adam is essentially RMSprop with momentum, 1965 Nadam is Adam RMSprop with Nesterov momentum available 1966 at http://cs229.stanford.edu/proj2015/054_report.pdf. 1967 1968 This optimizer accepts the following parameters in addition to those accepted 1969 by :class:`.Optimizer`. 1970 1971 Parameters 1972 ---------- 1973 beta1 : float, optional 1974 Exponential decay rate for the first moment estimates. 1975 beta2 : float, optional 1976 Exponential decay rate for the second moment estimates. 1977 epsilon : float, optional 1978 Small value to avoid division by 0. 1979 schedule_decay : float, optional 1980 Exponential decay rate for the momentum schedule 1981 """ 1982 def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, 1983 schedule_decay=0.004, **kwargs): 1984 super(Nadam, self).__init__(learning_rate=learning_rate, **kwargs) 1985 self.beta1 = beta1 1986 self.beta2 = beta2 1987 self.epsilon = epsilon 1988 self.schedule_decay = schedule_decay 1989 self.m_schedule = 1. 1990 1991 def create_state(self, index, weight): 1992 return (zeros(weight.shape, weight.context, dtype=weight.dtype), # mean 1993 zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance 1994 1995 def update(self, index, weight, grad, state): 1996 assert(isinstance(weight, NDArray)) 1997 assert(isinstance(grad, NDArray)) 1998 self._update_count(index) 1999 lr = self._get_lr(index) 2000 wd = self._get_wd(index) 2001 2002 t = self._index_update_count[index] 2003 2004 # preprocess grad 2005 grad = grad * self.rescale_grad + wd * weight 2006 if self.clip_gradient is not None: 2007 grad = clip(grad, -self.clip_gradient, self.clip_gradient) 2008 2009 # warming momentum schedule 2010 momentum_t = self.beta1 * (1. - 0.5 * (pow(0.96, t * self.schedule_decay))) 2011 momentum_t_1 = self.beta1 * (1. - 0.5 * (pow(0.96, (t + 1) * self.schedule_decay))) 2012 self.m_schedule = self.m_schedule * momentum_t 2013 m_schedule_next = self.m_schedule * momentum_t_1 2014 2015 # update m_t and v_t 2016 m_t, v_t = state 2017 m_t[:] *= self.beta1 2018 m_t[:] += (1. - self.beta1) * grad 2019 v_t[:] *= self.beta2 2020 v_t[:] += (1. - self.beta2) * grad * grad 2021 2022 grad_prime = grad / (1. - self.m_schedule) 2023 m_t_prime = m_t / (1. - m_schedule_next) 2024 v_t_prime = v_t / (1. - pow(self.beta2, t)) 2025 m_t_bar = (1. - momentum_t) * grad_prime + momentum_t_1 * m_t_prime 2026 2027 # update weight 2028 weight[:] -= lr * m_t_bar / (sqrt(v_t_prime) + self.epsilon) 2029 2030@register 2031class Test(Optimizer): 2032 """The Test optimizer""" 2033 def __init__(self, **kwargs): 2034 super(Test, self).__init__(**kwargs) 2035 2036 def create_state(self, index, weight): 2037 """Creates a state to duplicate weight.""" 2038 return zeros(weight.shape, weight.context) 2039 2040 def update(self, index, weight, grad, state): 2041 """Performs w += rescale_grad * grad.""" 2042 weight[:] += grad * self.rescale_grad 2043 state[:] = weight 2044 2045# backward compatibility wrapper for Optimizer.CreateOptimizer 2046create = Optimizer.create_optimizer # pylint: disable=invalid-name 2047 2048 2049def _as_classic(a, allow_np): 2050 # TODO(junwu): This is a temp solution for allowing converting 2051 # np.ndarray to mx.nd.NDArray to be fed into the optimizer since 2052 # users may have custom optimizers implemented using mx.nd.NDArray ops. 2053 from ..numpy import ndarray as np_ndarray 2054 if isinstance(a, (tuple, list)): 2055 if any(isinstance(x, np_ndarray) for x in a): 2056 if allow_np: 2057 return [x.as_nd_ndarray() for x in a] 2058 else: 2059 raise ValueError('Converting np.ndarray to mx.nd.NDArray is not allowed') 2060 else: 2061 if isinstance(a, np_ndarray): 2062 if allow_np: 2063 return a.as_nd_ndarray() 2064 else: 2065 raise ValueError('Converting np.ndarray to mx.nd.NDArray is not allowed') 2066 return a 2067 2068 2069 2070class Updater(object): 2071 """Updater for kvstore.""" 2072 def __init__(self, optimizer): 2073 self.optimizer = optimizer 2074 self.states = {} 2075 self.states_synced = {} 2076 self.aggregate_updates = optimizer.aggregate_num > 0 2077 2078 def __call__(self, index, grad, weight): 2079 """Updates weight given gradient and index.""" 2080 allow_np = self.optimizer.allow_np_array if hasattr(self.optimizer, "allow_np_array") else is_np_array() 2081 if not isinstance(index, (list, tuple)): 2082 indices = [index] 2083 grads = [_as_classic(grad, allow_np)] 2084 weights = [_as_classic(weight, allow_np)] 2085 else: 2086 indices = index 2087 grads = _as_classic(grad, allow_np) 2088 weights = _as_classic(weight, allow_np) 2089 if weights: 2090 self.optimizer._set_current_context(weights[0].context.device_id) 2091 for i, idx in enumerate(indices): 2092 # convert ctypes.char_p.value back to python str if needed 2093 if isinstance(idx, bytes): 2094 indices[i] = py_str(idx) 2095 idx = indices[i] 2096 if idx not in self.states: 2097 self.states[idx] = self.optimizer.create_state_multi_precision(idx, weights[i]) 2098 self.states_synced[idx] = True 2099 elif not self.states_synced[idx]: 2100 self.states[idx] = \ 2101 self.sync_state_context(self.states[idx], weights[i].context) 2102 self.states_synced[idx] = True 2103 if self.aggregate_updates: 2104 # segregate values based on type 2105 type_map = {} 2106 for i, w, g in zip(indices, weights, grads): 2107 if w.dtype in type_map: 2108 type_map[w.dtype].append((i, w, g)) 2109 else: 2110 type_map[w.dtype] = [(i, w, g)] 2111 for idx in type_map: 2112 current_index = 0 2113 indices, weights, grads = zip(*type_map[idx]) 2114 while current_index < len(indices): 2115 states = [] 2116 step = min(self.optimizer.aggregate_num, len(indices) - current_index) 2117 for j in range(step): 2118 states.append(self.states[indices[current_index + j]]) 2119 self.optimizer.update_multi_precision( 2120 indices[current_index:current_index + self.optimizer.aggregate_num], 2121 weights[current_index:current_index + self.optimizer.aggregate_num], 2122 grads[current_index:current_index + self.optimizer.aggregate_num], 2123 states) 2124 current_index += self.optimizer.aggregate_num 2125 else: 2126 for i, w, g in zip(indices, weights, grads): 2127 self.optimizer.update_multi_precision(i, w, g, self.states[i]) 2128 2129 def sync_state_context(self, state, context): 2130 """sync state context.""" 2131 if isinstance(state, NDArray): 2132 return state.as_in_context(context) 2133 elif isinstance(state, (tuple, list)): 2134 synced_state = (self.sync_state_context(i, context) for i in state) 2135 if isinstance(state, tuple): 2136 return tuple(synced_state) 2137 else: 2138 return list(synced_state) 2139 else: 2140 return state 2141 2142 def set_states(self, states): 2143 """Sets updater states.""" 2144 states = pickle.loads(states) 2145 if isinstance(states, tuple) and len(states) == 2: 2146 self.states, self.optimizer = states 2147 else: 2148 self.states = states 2149 self.states_synced = dict.fromkeys(self.states.keys(), False) 2150 2151 def get_states(self, dump_optimizer=False): 2152 """Gets updater states. 2153 2154 Parameters 2155 ---------- 2156 dump_optimizer : bool, default False 2157 Whether to also save the optimizer itself. This would also save optimizer 2158 information such as learning rate and weight decay schedules. 2159 """ 2160 return pickle.dumps((self.states, self.optimizer) if dump_optimizer else self.states) 2161 2162def get_updater(optimizer): 2163 """Returns a closure of the updater needed for kvstore. 2164 2165 Parameters 2166 ---------- 2167 optimizer: Optimizer 2168 The optimizer. 2169 2170 Returns 2171 ------- 2172 updater: function 2173 The closure of the updater. 2174 """ 2175 return Updater(optimizer) 2176