1# coding: utf-8
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements.  See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership.  The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License.  You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied.  See the License for the
16# specific language governing permissions and limitations
17# under the License.
18
19# pylint: disable=too-many-lines
20"""Weight updating functions."""
21import logging
22import math
23import pickle
24import warnings
25import os
26import numpy
27from ..base import py_str
28from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply,
29                       multi_sum_sq, multi_lars, norm as NDnorm)
30from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
31                       mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
32                       signsgd_update, signum_update, nag_mom_update, mp_nag_mom_update,
33                       multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
34                       multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
35                       preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
36                       preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2,
37                       mp_lamb_update_phase1, mp_lamb_update_phase2)
38from ..ndarray.contrib import (multi_lamb_update, multi_mp_lamb_update)
39from ..ndarray import sparse
40from ..random import normal
41from ..util import is_np_array
42
43__all__ = [
44    'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
45    'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB',
46    'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
47]
48
49def _flatten_list(nested_list):
50    return [item for sublist in nested_list for item in sublist]
51
52class Optimizer(object):
53    """The base class inherited by all optimizers.
54
55    Parameters
56    ----------
57    rescale_grad : float, optional, default 1.0
        Multiply the gradient with `rescale_grad` before updating. Often
        chosen to be ``1.0/batch_size``.
60
61    param_idx2name : dict from int to string, optional, default None
62        A dictionary that maps int index to string name.
63
64    clip_gradient : float, optional, default None
65        Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``.
66
67    learning_rate : float, optional, default None
68        The initial learning rate. If None, the optimization will use the
69        learning rate from ``lr_scheduler``. If not None, it will overwrite
70        the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
71        is also None, then it will be set to 0.01 by default.
72
73    lr_scheduler : LRScheduler, optional, default None
74        The learning rate scheduler.
75
76    wd : float, optional, default 0.0
77        The weight decay (or L2 regularization) coefficient. Modifies objective
78        by adding a penalty for having large weights.
79
    sym : Symbol, optional, default None
        The Symbol this optimizer is applied to.
82
83    begin_num_update : int, optional, default 0
84        The initial number of updates.
85
86    multi_precision : bool, optional, default False
87       Flag to control the internal precision of the optimizer.
88       False: results in using the same precision as the weights (default),
89       True: makes internal 32-bit copy of the weights and applies gradients
90       in 32-bit precision even if actual weights used in the model have lower precision.
91       Turning this on can improve convergence and accuracy when training with float16.
92
93    param_dict : dict of int -> gluon.Parameter, default None
94        Dictionary of parameter index to gluon.Parameter, used to lookup parameter attributes
95        such as lr_mult, wd_mult, etc. param_dict shall not be deep copied.
96
97    Properties
98    ----------
99    learning_rate : float
100        The current learning rate of the optimizer. Given an Optimizer object
101        optimizer, its learning rate can be accessed as optimizer.learning_rate.
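
    Examples
    --------
    A minimal sketch of constructing a registered optimizer through the
    module-level factory; the batch size of 32 used for ``rescale_grad`` is
    illustrative:

    >>> import mxnet as mx
    >>> opt = mx.optimizer.create('sgd', learning_rate=0.1, rescale_grad=1.0/32)
    >>> opt.learning_rate
    0.1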
102    """
103    def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
104                 clip_gradient=None, learning_rate=None,
105                 lr_scheduler=None, sym=None, begin_num_update=0,
106                 multi_precision=False, param_dict=None):
107        self.rescale_grad = rescale_grad
108        self.lr_scheduler = lr_scheduler
109        if self.lr_scheduler is None and learning_rate is None:
110            learning_rate = 0.01
111        self.lr = learning_rate
112        if self.lr_scheduler is not None and learning_rate is not None:
113            if self.lr_scheduler.base_lr != learning_rate:
114                print(UserWarning("learning rate from ``lr_scheduler`` has been "
115                                  "overwritten by ``learning_rate`` in optimizer."))
116                self.lr_scheduler.base_lr = learning_rate
117
118        self.wd = wd
119        self.lr_mult = {}
120        self.wd_mult = {}
121        self.begin_num_update = begin_num_update
122        self.num_update = begin_num_update
123        self._all_index_update_counts = {0 : {}}
124        self._index_update_count = self._all_index_update_counts[0]
125        self.clip_gradient = clip_gradient
126        self.multi_precision = multi_precision
127        self.aggregate_num = 0
128
129        if param_idx2name is None:
130            param_idx2name = {}
131        assert isinstance(param_idx2name, dict), \
132            'param_idx2name should be a dict of param indexes to names.'
133        self.idx2name = param_idx2name.copy()
134        self.sym_info = (sym.attr_dict(), sym.list_arguments()) if sym is not None else ()
135        self.param_dict = param_dict if param_dict else {}
136        self.allow_np_array = is_np_array()
137
138        self.set_lr_mult({})
139        self.set_wd_mult({})
140
141    opt_registry = {}
142
143    @staticmethod
144    def register(klass):
145        """Registers a new optimizer.
146
147        Once an optimizer is registered, we can create an instance of this
148        optimizer with `create_optimizer` later.
149
150        Examples
151        --------
152
153        >>> @mx.optimizer.Optimizer.register
154        ... class MyOptimizer(mx.optimizer.Optimizer):
155        ...     pass
156        >>> optim = mx.optimizer.Optimizer.create_optimizer('MyOptimizer')
157        >>> print(type(optim))
158        <class '__main__.MyOptimizer'>
159        """
160        assert(isinstance(klass, type))
161        name = klass.__name__.lower()
162        if name in Optimizer.opt_registry:
163            warnings.warn('WARNING: New optimizer %s.%s is overriding '
164                          'existing optimizer %s.%s' %
165                          (klass.__module__, klass.__name__,
166                           Optimizer.opt_registry[name].__module__,
167                           Optimizer.opt_registry[name].__name__))
168        Optimizer.opt_registry[name] = klass
169        return klass
170
171    @staticmethod
172    def create_optimizer(name, **kwargs):
173        """Instantiates an optimizer with a given name and kwargs.
174
175        .. note:: We can use the alias `create` for ``Optimizer.create_optimizer``.
176
177        Parameters
178        ----------
179        name: str
180            Name of the optimizer. Should be the name
181            of a subclass of Optimizer. Case insensitive.
182
183        kwargs: dict
184            Parameters for the optimizer.
185
186        Returns
187        -------
188        Optimizer
189            An instantiated optimizer.
190
191        Examples
192        --------
193        >>> sgd = mx.optimizer.Optimizer.create_optimizer('sgd')
194        >>> type(sgd)
195        <class 'mxnet.optimizer.SGD'>
196        >>> adam = mx.optimizer.create('adam', learning_rate=.1)
197        >>> type(adam)
198        <class 'mxnet.optimizer.Adam'>
199        """
200        if name.lower() in Optimizer.opt_registry:
201            return Optimizer.opt_registry[name.lower()](**kwargs)
202        else:
203            raise ValueError('Cannot find optimizer %s' % name)
204
205    @property
206    def learning_rate(self):
207        if self.lr_scheduler is not None:
208            return self.lr_scheduler(self.num_update)
209        else:
210            return self.lr
211
212    def create_state(self, index, weight):
213        """Creates auxiliary state for a given weight.
214
        Some optimizers require additional state, e.g. momentum, in addition
        to gradients in order to update weights. This function creates state
217        for a given weight which will be used in `update`. This function is
218        called only once for each weight.
219
220        Parameters
221        ----------
222        index : int
            A unique index to identify the weight.
224        weight : NDArray
225            The weight.
226
227        Returns
228        -------
229        state : any obj
230            The state associated with the weight.
231        """
232
233    def create_state_multi_precision(self, index, weight):
234        """Creates auxiliary state for a given weight, including FP32 high
235        precision copy if original weight is FP16.
236
237        This method is provided to perform automatic mixed precision training
238        for optimizers that do not support it themselves.
239
240        Parameters
241        ----------
242        index : int
            A unique index to identify the weight.
244        weight : NDArray
245            The weight.
246
247        Returns
248        -------
249        state : any obj
250            The state associated with the weight.
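
        Examples
        --------
        A minimal sketch of the state layout produced for a ``float16`` weight
        when ``multi_precision=True`` (the shape is illustrative):

        >>> import mxnet as mx
        >>> opt = mx.optimizer.Optimizer(multi_precision=True)
        >>> weight = mx.nd.zeros((2, 2), dtype='float16')
        >>> state = opt.create_state_multi_precision(0, weight)
        >>> # state[0] is the float32 master copy of the weight; state[1] is the
        >>> # regular state (None here, since the base class keeps no extra state)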
251        """
252        weight_master_copy = None
253        if self.multi_precision and weight.dtype == numpy.float16:
254            weight_master_copy = weight.astype(numpy.float32)
255            return (weight_master_copy,) + (self.create_state(index, weight_master_copy),)
256        if weight.dtype == numpy.float16 and not self.multi_precision:
257            warnings.warn("Accumulating with float16 in optimizer can lead to "
258                          "poor accuracy or slow convergence. "
259                          "Consider using multi_precision=True option of the "
260                          "optimizer")
261        return self.create_state(index, weight)
262
263    def update(self, index, weight, grad, state):
264        """Updates the given parameter using the corresponding gradient and state.
265
266        Parameters
267        ----------
268        index : int
269            The unique index of the parameter into the individual learning
270            rates and weight decays. Learning rates and weight decay
271            may be set via `set_lr_mult()` and `set_wd_mult()`, respectively.
272        weight : NDArray
273            The parameter to be updated.
274        grad : NDArray
275            The gradient of the objective with respect to this parameter.
276        state : any obj
277            The state returned by `create_state()`.
278        """
279        raise NotImplementedError()
280
281    def update_multi_precision(self, index, weight, grad, state):
282        """Updates the given parameter using the corresponding gradient and state.
283        Mixed precision version.
284
285        Parameters
286        ----------
287        index : int
288            The unique index of the parameter into the individual learning
289            rates and weight decays. Learning rates and weight decay
290            may be set via `set_lr_mult()` and `set_wd_mult()`, respectively.
291        weight : NDArray
292            The parameter to be updated.
293        grad : NDArray
294            The gradient of the objective with respect to this parameter.
295        state : any obj
296            The state returned by `create_state()`.
297        """
298        if self.multi_precision and weight.dtype == numpy.float16:
299            # Wrapper for mixed precision
300            weight_master_copy = state[0]
301            original_state = state[1]
302            grad32 = grad.astype(numpy.float32)
303            self.update(index, weight_master_copy, grad32, original_state)
304            cast(weight_master_copy, dtype=weight.dtype, out=weight)
305        else:
306            self.update(index, weight, grad, state)
307
308    def set_learning_rate(self, lr):
309        """Sets a new learning rate of the optimizer.
310
311        Parameters
312        ----------
313        lr : float
314            The new learning rate of the optimizer.
315        """
316        if self.lr_scheduler is not None: # pylint: disable=no-else-raise
317            raise UserWarning("LRScheduler of the optimizer has already been "
318                              "defined. Note that set_learning_rate can mutate "
319                              "the value of the learning rate of the optimizer "
320                              "only when the LRScheduler of the optimizer is "
321                              "undefined.")
322        else:
323            self.lr = lr
324
325    def set_lr_scale(self, args_lrscale): # pylint: disable=unused-argument
326        """[DEPRECATED] Sets lr scale. Use set_lr_mult instead."""
327        raise DeprecationWarning
328
329    def set_lr_mult(self, args_lr_mult):
330        """Sets an individual learning rate multiplier for each parameter.
331
332        If you specify a learning rate multiplier for a parameter, then
333        the learning rate for the parameter will be set as the product of
334        the global learning rate `self.lr` and its multiplier.
335
336        .. note:: The default learning rate multiplier of a `Variable`
            can be set with the `lr_mult` argument in the constructor.
338
339        Parameters
340        ----------
341        args_lr_mult : dict of str/int to float
            For each of its key-value entries, the learning rate multiplier for the
343            parameter specified in the key will be set as the given value.
344
345            You can specify the parameter with either its name or its index.
346            If you use the name, you should pass `sym` in the constructor,
347            and the name you specified in the key of `args_lr_mult` should match
348            the name of the parameter in `sym`. If you use the index, it should
349            correspond to the index of the parameter used in the `update` method.
350
351            Specifying a parameter by its index is only supported for backward
            compatibility, and we recommend using the name instead.
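
        Examples
        --------
        A small usage sketch; the parameter name ``fc1_weight`` is illustrative
        and must match a name known to the optimizer (via ``sym`` or
        ``param_idx2name``):

        >>> import mxnet as mx
        >>> opt = mx.optimizer.SGD(learning_rate=0.1)
        >>> opt.set_lr_mult({'fc1_weight': 0.5})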
353        """
354        self.lr_mult = {}
355        if self.sym_info:
356            attr, arg_names = self.sym_info
357            for name in arg_names:
358                if name in attr and '__lr_mult__' in attr[name]:
359                    self.lr_mult[name] = float(attr[name]['__lr_mult__'])
360        self.lr_mult.update(args_lr_mult)
361
362    def set_wd_mult(self, args_wd_mult):
363        """Sets an individual weight decay multiplier for each parameter.
364
365        By default, if `param_idx2name` was provided in the
        constructor, the weight decay multiplier is set to 0 for all
        parameters whose names do not end with ``_weight`` or
368        ``_gamma``.
369
370        .. note:: The default weight decay multiplier for a `Variable`
371            can be set with its `wd_mult` argument in the constructor.
372
373        Parameters
374        ----------
375        args_wd_mult : dict of string/int to float
            For each of its key-value entries, the weight decay multiplier for the
377            parameter specified in the key will be set as the given value.
378
379            You can specify the parameter with either its name or its index.
380            If you use the name, you should pass `sym` in the constructor,
            and the name you specified in the key of `args_wd_mult` should match
382            the name of the parameter in `sym`. If you use the index, it should
383            correspond to the index of the parameter used in the `update` method.
384
385            Specifying a parameter by its index is only supported for backward
            compatibility, and we recommend using the name instead.
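
        Examples
        --------
        A small usage sketch; ``fc1_bias`` is an illustrative parameter name
        (biases are commonly excluded from weight decay):

        >>> import mxnet as mx
        >>> opt = mx.optimizer.SGD(learning_rate=0.1, wd=1e-4)
        >>> opt.set_wd_mult({'fc1_bias': 0.0})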
387        """
388        self.wd_mult = {}
389        for n in self.idx2name.values():
390            if not (n.endswith('_weight') or n.endswith('_gamma')):
391                self.wd_mult[n] = 0.0
392        if self.sym_info:
393            attr, arg_names = self.sym_info
394            for name in arg_names:
395                if name in attr and '__wd_mult__' in attr[name]:
396                    self.wd_mult[name] = float(attr[name]['__wd_mult__'])
397        self.wd_mult.update(args_wd_mult)
398
399    def _set_current_context(self, device_id):
400        """Sets the number of the currently handled device.
401
402        Parameters
403        ----------
404        device_id : int
            The id of the current device.
406        """
407        if device_id not in self._all_index_update_counts:
408            self._all_index_update_counts[device_id] = {}
409        self._index_update_count = self._all_index_update_counts[device_id]
410
411    def _update_count(self, index):
412        """Updates num_update.
413
414        Parameters
415        ----------
416        index : int or list of int
417            The index to be updated.
418        """
419        if not isinstance(index, (list, tuple)):
420            index = [index]
421        for idx in index:
422            if idx not in self._index_update_count:
423                self._index_update_count[idx] = self.begin_num_update
424            self._index_update_count[idx] += 1
425            self.num_update = max(self._index_update_count[idx], self.num_update)
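        # For example, with ``begin_num_update=0`` the first update of a fresh
        # index sets its per-index count to 1, and ``num_update`` tracks the
        # largest per-index count seen so far across all indices.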
426
427    def _get_lrs(self, indices):
428        """Gets the learning rates given the indices of the weights.
429
430        Parameters
431        ----------
432        indices : list of int
433            Indices corresponding to weights.
434
435        Returns
436        -------
437        lrs : list of float
438            Learning rates for those indices.
439        """
440        if self.lr_scheduler is not None:
441            lr = self.lr_scheduler(self.num_update)
442        else:
443            lr = self.lr
444
445        lrs = [lr for _ in indices]
446        for i, index in enumerate(indices):
447            if index in self.param_dict:
448                lrs[i] *= self.param_dict[index].lr_mult
449            elif index in self.lr_mult:
450                lrs[i] *= self.lr_mult[index]
451            elif index in self.idx2name:
452                lrs[i] *= self.lr_mult.get(self.idx2name[index], 1.0)
453        return lrs
454
455    def _get_lr(self, index):
456        """Gets the learning rate given the index of the weight.
457
458        Parameters
459        ----------
460        index : int
461            The index corresponding to the weight.
462
463        Returns
464        -------
465        lr : float
466            Learning rate for this index.
467        """
468        return self._get_lrs([index])[0]
469
470    def _get_wds(self, indices):
471        """Gets weight decays for indices.
        Returns 0 for non-weight parameters if parameter names were provided to `__init__`.
473
474        Parameters
475        ----------
476        indices : list of int
477            Indices of weights.
478
479        Returns
480        -------
481        wds : list of float
482            Weight decays for those indices.
483        """
484        wds = [self.wd for _ in indices]
485        for i, index in enumerate(indices):
486            if index in self.param_dict:
487                wds[i] *= self.param_dict[index].wd_mult
488            elif index in self.wd_mult:
489                wds[i] *= self.wd_mult[index]
490            elif index in self.idx2name:
491                wds[i] *= self.wd_mult.get(self.idx2name[index], 1.0)
492        return wds
493
494    def _get_wd(self, index):
495        """Gets weight decay for index.
        Returns 0 for non-weight parameters if parameter names were provided to `__init__`.
497
498        Parameters
499        ----------
500        index : int
501            The index of weight.
502
503        Returns
504        -------
505        wd : float
506            Weight decay for this index.
507        """
508        return self._get_wds([index])[0]
509
510    def __getstate__(self):
511        ret = self.__dict__.copy()
512        # do not include param_dict in the state
513        del ret['param_dict']
514        return ret
515
516    def __setstate__(self, state):
517        self.__dict__ = state
518        # param_dict needs to be explicitly set by the trainer
519        self.param_dict = {}
520
# convenience wrapper for Optimizer.register
522register = Optimizer.register   # pylint: disable=invalid-name
523
524# pylint: disable=line-too-long
525@register
526class SGD(Optimizer):
527    """The SGD optimizer with momentum and weight decay.
528
    If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \
530    **lazy updates** are applied by::
531
532        for row in grad.indices:
533            rescaled_grad[row] = lr * (rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row])
            state[row] = momentum * state[row] + rescaled_grad[row]
535            weight[row] = weight[row] - state[row]
536
537    The sparse update only updates the momentum for the weights whose row_sparse
538    gradient indices appear in the current batch, rather than updating it for all
539    indices. Compared with the original update, it can provide large
540    improvements in model training throughput for some applications. However, it
541    provides slightly different semantics than the original update, and
542    may lead to different empirical results.
543
544    In the case when ``update_on_kvstore`` is set to False (either globally via
545    MXNET_UPDATE_ON_KVSTORE=0 environment variable or as a parameter in
546    :class:`~mxnet.gluon.Trainer`) SGD optimizer can perform aggregated update
547    of parameters, which may lead to improved performance. The aggregation size
548    is controlled by MXNET_OPTIMIZER_AGGREGATION_SIZE environment variable and
549    defaults to 4.
550
551    Otherwise, **standard updates** are applied by::
552
553        rescaled_grad = lr * (rescale_grad * clip(grad, clip_gradient) + wd * weight)
554        state = momentum * state + rescaled_grad
555        weight = weight - state
556
557    For details of the update algorithm see
558    :class:`~mxnet.ndarray.sgd_update` and :class:`~mxnet.ndarray.sgd_mom_update`.
559
560    This optimizer accepts the following parameters in addition to those accepted
561    by :class:`.Optimizer`.
562
563    Parameters
564    ----------
565    momentum : float, optional
566        The momentum value.
567    lazy_update : bool, optional
568        Default is True. If True, lazy updates are applied \
569        if the storage types of weight and grad are both ``row_sparse``.
570    multi_precision: bool, optional
571        Flag to control the internal precision of the optimizer.
572        False: results in using the same precision as the weights (default),
573        True: makes internal 32-bit copy of the weights and applies gradients
574        in 32-bit precision even if actual weights used in the model have lower precision.
575        Turning this on can improve convergence and accuracy when training with float16.
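
    Examples
    --------
    A minimal sketch of a single manual update step (shapes and values are
    illustrative):

    >>> import mxnet as mx
    >>> opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9, wd=0.0)
    >>> weight = mx.nd.ones((2, 2))
    >>> grad = mx.nd.ones((2, 2))
    >>> state = opt.create_state(0, weight)
    >>> opt.update(0, weight, grad, state)
    >>> # weight is updated in place: 1 - 0.1 * 1 = 0.9 in every entry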
576    """
577    def __init__(self, momentum=0.0, lazy_update=True, **kwargs):
578        super(SGD, self).__init__(**kwargs)
579        self.momentum = momentum
580        self.lazy_update = lazy_update
581        self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4"))
582
583    def create_state_multi_precision(self, index, weight):
584        weight_master_copy = None
585        if self.multi_precision and weight.dtype == numpy.float16:
586            weight_master_copy = weight.astype(numpy.float32)
587            return (self.create_state(index, weight_master_copy), weight_master_copy)
588        if weight.dtype == numpy.float16 and not self.multi_precision:
589            warnings.warn("Accumulating with float16 in optimizer can lead to "
590                          "poor accuracy or slow convergence. "
591                          "Consider using multi_precision=True option of the "
592                          "SGD optimizer")
593        return self.create_state(index, weight)
594
595    def create_state(self, index, weight):
596        momentum = None
597        if self.momentum != 0.0:
598            stype = weight.stype if self.lazy_update else 'default'
599            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
600        return momentum
601
602    def _update_impl(self, indices, weights, grads, states, multi_precision=False):
603        aggregate = True
604        if not isinstance(indices, (tuple, list)):
605            indices = [indices]
606            weights = [weights]
607            grads = [grads]
608            states = [states]
609        for weight, grad in zip(weights, grads):
610            assert(isinstance(weight, NDArray))
611            assert(isinstance(grad, NDArray))
612            aggregate = (aggregate and
613                         weight.stype == 'default' and
614                         grad.stype == 'default')
615        self._update_count(indices)
616        lrs = self._get_lrs(indices)
617        wds = self._get_wds(indices)
618
619        kwargs = {'rescale_grad': self.rescale_grad}
620        if self.momentum > 0:
621            kwargs['momentum'] = self.momentum
622        if self.clip_gradient:
623            kwargs['clip_gradient'] = self.clip_gradient
624
625        if aggregate:
626            if not multi_precision:
627                if self.momentum > 0:
628                    multi_sgd_mom_update(*_flatten_list(zip(weights, grads, states)), out=weights,
629                                         num_weights=len(weights), lrs=lrs, wds=wds, **kwargs)
630                else:
631                    multi_sgd_update(*_flatten_list(zip(weights, grads)), out=weights,
632                                     num_weights=len(weights), lrs=lrs, wds=wds, **kwargs)
633            else:
634                if self.momentum > 0:
635                    multi_mp_sgd_mom_update(*_flatten_list(zip(weights, grads, *zip(*states))),
636                                            out=weights, num_weights=len(weights),
637                                            lrs=lrs, wds=wds, **kwargs)
638                else:
639                    multi_mp_sgd_update(*_flatten_list(zip(weights, grads,
640                                                           list(zip(*states))[1])),
641                                        out=weights, num_weights=len(weights),
642                                        lrs=lrs, wds=wds, **kwargs)
643        else:
644            for weight, grad, state, lr, wd in zip(weights, grads, states, lrs, wds):
645                if not multi_precision:
646                    if state is not None:
647                        sgd_mom_update(weight, grad, state, out=weight,
648                                       lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
649                    else:
650                        sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
651                                   lr=lr, wd=wd, **kwargs)
652                else:
653                    if state[0] is not None:
654                        mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight,
655                                          lr=lr, wd=wd, **kwargs)
656                    else:
657                        mp_sgd_update(weight, grad, state[1], out=weight,
658                                      lr=lr, wd=wd, **kwargs)
659
660    def update(self, index, weight, grad, state):
661        self._update_impl(index, weight, grad, state, multi_precision=False)
662
663    def update_multi_precision(self, index, weight, grad, state):
664        if not isinstance(index, (tuple, list)):
665            use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
666        else:
667            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
668        self._update_impl(index, weight, grad, state,
669                          multi_precision=use_multi_precision)
670
671@register
672class Signum(Optimizer):
673    r"""The Signum optimizer that takes the sign of gradient or momentum.
674
675    The optimizer updates the weight by::
676
677        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
678        state = momentum * state + (1-momentum)*rescaled_grad
679        weight = (1 - lr * wd_lh) * weight - lr * sign(state)
680
681    References
682    ----------
683    Jeremy Bernstein, Yu-Xiang Wang, Kamyar Azizzadenesheli & Anima Anandkumar. (2018).
684    signSGD: Compressed Optimisation for Non-Convex Problems. In ICML'18.
685
686    See: https://arxiv.org/abs/1802.04434
687
688    For details of the update algorithm see
689    :class:`~mxnet.ndarray.signsgd_update` and :class:`~mxnet.ndarray.signum_update`.
690
691    This optimizer accepts the following parameters in addition to those accepted
692    by :class:`.Optimizer`.
693
694    Parameters
695    ----------
696    momentum : float, optional
697       The momentum value.
698    wd_lh : float, optional
699       The amount of decoupled weight decay regularization, see details in the original paper at:\
700       https://arxiv.org/abs/1711.05101
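
    Examples
    --------
    A minimal usage sketch (shapes and values are illustrative):

    >>> import mxnet as mx
    >>> opt = mx.optimizer.Signum(learning_rate=0.01, momentum=0.9)
    >>> weight = mx.nd.ones((2, 2))
    >>> grad = mx.nd.ones((2, 2)) * -2
    >>> state = opt.create_state(0, weight)
    >>> opt.update(0, weight, grad, state)
    >>> # sign(state) is -1 everywhere, so every weight entry moves by +lr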
701    """
702    def __init__(self, learning_rate=0.01, momentum=0.9, wd_lh=0.0, **kwargs):
703        super(Signum, self).__init__(learning_rate=learning_rate, **kwargs)
704        self.momentum = momentum
705        self.wd_lh = wd_lh
706
707    def create_state(self, index, weight):
708        momentum = None
709        if self.momentum != 0.0:
710            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
711        return momentum
712
713    def _update_impl(self, index, weight, grad, state):
714        assert(isinstance(weight, NDArray))
715        assert(isinstance(grad, NDArray))
716        self._update_count(index)
717        lr = self._get_lr(index)
718        wd = self._get_wd(index)
719
720        kwargs = {'rescale_grad': self.rescale_grad}
721        if self.momentum > 0:
722            kwargs['momentum'] = self.momentum
723        if self.clip_gradient:
724            kwargs['clip_gradient'] = self.clip_gradient
725        if self.wd_lh:
726            kwargs['wd_lh'] = self.wd_lh
727
728        if state is not None:
729            signum_update(weight, grad, state, out=weight,
730                          lr=lr, wd=wd, **kwargs)
731        else:
732            signsgd_update(weight, grad, out=weight,
733                           lr=lr, wd=wd, **kwargs)
734
735    def update(self, index, weight, grad, state):
736        self._update_impl(index, weight, grad, state)
737
738@register
739class FTML(Optimizer):
740    """The FTML optimizer.
741
742    This class implements the optimizer described in
743    *FTML - Follow the Moving Leader in Deep Learning*,
744    available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.
745
746    Denote time step by t. The optimizer updates the weight by::
747
748        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
749        v = beta2 * v + (1 - beta2) * square(rescaled_grad)
        d_t = (1 - power(beta1, t)) / lr * (square_root(v / (1 - power(beta2, t))) + epsilon)
751        z = beta1 * z + (1 - beta1) * rescaled_grad - (d_t - beta1 * d_(t-1)) * weight
752        weight = - z / d_t
753
754    For details of the update algorithm, see :class:`~mxnet.ndarray.ftml_update`.
755
756    This optimizer accepts the following parameters in addition to those accepted
757    by :class:`.Optimizer`.
758
759    Parameters
760    ----------
761    beta1 : float, optional
762        0 < beta1 < 1. Generally close to 0.5.
763    beta2 : float, optional
764        0 < beta2 < 1. Generally close to 1.
765    epsilon : float, optional
766        Small value to avoid division by 0.
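
    Examples
    --------
    A minimal usage sketch (shapes and values are illustrative):

    >>> import mxnet as mx
    >>> opt = mx.optimizer.FTML(learning_rate=0.001, beta1=0.6, beta2=0.999)
    >>> weight = mx.nd.ones((2, 2))
    >>> grad = mx.nd.ones((2, 2)) * 0.1
    >>> state = opt.create_state(0, weight)   # (d_0, v_0, z_0)
    >>> opt.update(0, weight, grad, state)    # updates ``weight`` in place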
767    """
768    def __init__(self, beta1=0.6, beta2=0.999, epsilon=1e-8, **kwargs):
769        super(FTML, self).__init__(**kwargs)
770        self.beta1 = beta1
771        self.beta2 = beta2
772        self.epsilon = epsilon
773
774    def create_state(self, index, weight):
775        return (zeros(weight.shape, weight.context, dtype=weight.dtype), # d_0
776                zeros(weight.shape, weight.context, dtype=weight.dtype), # v_0
777                zeros(weight.shape, weight.context, dtype=weight.dtype)) # z_0
778
779    def update(self, index, weight, grad, state):
780        assert(isinstance(weight, NDArray))
781        assert(isinstance(grad, NDArray))
782        self._update_count(index)
783        lr = self._get_lr(index)
784        wd = self._get_wd(index)
785        t = self._index_update_count[index]
786
787        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
788                  'rescale_grad': self.rescale_grad, 't': t}
789        if self.clip_gradient:
790            kwargs['clip_grad'] = self.clip_gradient
791
792        prev_d, prev_v, prev_z = state
793        ftml_update(weight, grad, prev_d, prev_v, prev_z, out=weight,
794                    lr=lr, wd=wd, **kwargs)
795
796@register
797class LARS(Optimizer):
798    """the LARS optimizer from 'Large Batch Training of Convolution Networks' \
799    (https://arxiv.org/abs/1708.03888)
800
801    Behave mostly like SGD with momentum and weight decay but is scaling \
802    adaptively the learning for each layer (except bias and batch norm parameters):
803    w_norm = L2norm(weights)
804    g_norm = L2norm(gradients)
805    if w_norm > 0 and g_norm > 0:
806        lr_layer = lr * lr_mult * eta * w_norm / (g_norm + weight_decay * w_norm + eps)
807    else:
808        lr_layer = lr * lr_mult
809
810    Parameters
811    ----------
812    momentum : float, optional
813        The momentum value.
814    lazy_update : bool, optional
815        Default is True. If True, lazy updates are applied \
816        if the storage types of weight and grad are both ``row_sparse``.
    eta : float, optional
        LARS coefficient used to scale the learning rate. Default set to 0.001.
    eps : float, optional
        Optional epsilon in case of very small gradients. Default set to 0.
    momentum_correction : bool, optional
        If True, scale the momentum w.r.t. the global learning rate change (with an
        lr_scheduler), as described in 'Accurate, Large Minibatch SGD: Training
        ImageNet in 1 Hour' (https://arxiv.org/pdf/1706.02677.pdf).
        Default set to True.
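
    Examples
    --------
    A minimal usage sketch; the parameter name in ``param_idx2name`` is
    illustrative and lets LARS skip bias/batch-norm parameters by name:

    >>> import mxnet as mx
    >>> opt = mx.optimizer.LARS(learning_rate=0.1, momentum=0.9, eta=0.001,
    ...                         param_idx2name={0: 'conv0_weight'})
    >>> weight = mx.nd.ones((2, 2))
    >>> grad = mx.nd.ones((2, 2)) * 0.1
    >>> state = opt.create_state(0, weight)
    >>> opt.update(0, weight, grad, state)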
826    """
827    def __init__(self, momentum=0.0, lazy_update=True, eta=0.001, eps=0,
828                 momentum_correction=True, **kwargs):
829        super(LARS, self).__init__(**kwargs)
830        self.momentum = momentum
831        self.momentum_correction = momentum_correction
832        self.lazy_update = lazy_update
833        self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4"))
834        self.eta = eta
835        self.eps = eps
836        self.skip = 0
837        self.last_lr = None
838        self.cur_lr = None
839
840
841    def _get_lrs(self, indices):
842        """Gets the learning rates given the indices of the weights.
843
844        Parameters
845        ----------
846        indices : list of int
847            Indices corresponding to weights.
848
849        Returns
850        -------
851        lrs : list of float
852            Learning rates for those indices.
853        """
854        if self.cur_lr is not None:
855            self.last_lr = self.cur_lr
856
857        if self.lr_scheduler is not None:
858            lr = self.lr_scheduler(self.num_update)
859        else:
860            lr = self.lr
861
862        if self.cur_lr is None:
863            self.last_lr = lr
864        self.cur_lr = lr
865
866        lrs = [lr for _ in indices]
867        for i, index in enumerate(indices):
868            if index in self.param_dict:
869                lrs[i] *= self.param_dict[index].lr_mult
870            elif index in self.lr_mult:
871                lrs[i] *= self.lr_mult[index]
872            elif index in self.idx2name:
873                lrs[i] *= self.lr_mult.get(self.idx2name[index], 1.0)
874        return lrs
875
876    def set_wd_mult(self, args_wd_mult):
877        self.wd_mult = {}
878        for n in self.idx2name.values():
879            is_weight = n.endswith('_weight')
880
881            if not is_weight:
882                self.wd_mult[n] = 0.0
883
884        if self.sym_info:
885            attr, arg_names = self.sym_info
886            for name in arg_names:
887                if name in attr and '__wd_mult__' in attr[name]:
888                    self.wd_mult[name] = float(attr[name]['__wd_mult__'])
889        self.wd_mult.update(args_wd_mult)
890
891    def create_state_multi_precision(self, index, weight):
892        weight_master_copy = None
893        if self.multi_precision and weight.dtype == numpy.float16:
894            weight_master_copy = weight.astype(numpy.float32)
895            return (self.create_state(index, weight_master_copy), weight_master_copy)
896        if weight.dtype == numpy.float16 and not self.multi_precision:
897            warnings.warn("Accumulating with float16 in optimizer can lead to "
898                          "poor accuracy or slow convergence. "
899                          "Consider using multi_precision=True option of the "
900                          "SGD optimizer")
901        return self.create_state(index, weight)
902
903    def create_state(self, index, weight):
904        momentum = None
905        if self.momentum != 0.0:
906            stype = weight.stype if self.lazy_update else 'default'
907            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype)
908        return momentum
909
910    def _l2norm(self, v, rescale=False):
911        """L2 Norm implementation"""
912        v = v.astype('float32')
913        if rescale:
914            v *= self.rescale_grad
915        norm = NDnorm(v).asnumpy()[0]
916        return norm
917
918    def _get_lars(self, i, weight, g, lr, wd):
919        """Returns a scaling factor for the learning rate for this layer"""
920        name = self.idx2name[i] if i in self.idx2name else str(i)
921        if name.endswith('gamma') or name.endswith('beta') or name.endswith('bias'):
922            return lr
923
924        w_norm = self._l2norm(weight)
925        g_norm = self._l2norm(g, rescale=True)
926
927        if w_norm > 0.0 and g_norm > 0.0:
928            lars = self.eta * w_norm/(g_norm + wd * w_norm + self.eps)
929        else:
930            lars = 1.0
931        return lars * lr
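        # Worked example with illustrative numbers: for eta=0.001, w_norm=10.0,
        # g_norm=1.0, wd=1e-4 and eps=0, lars = 0.001 * 10 / (1 + 0.001) ~= 0.01,
        # so the effective learning rate for this layer is about 0.01 * lr.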
932
933    def _update_impl(self, indices, weights, grads, states, multi_precision=False):
934        aggregate = True
935        if not isinstance(indices, (tuple, list)):
936            indices = [indices]
937            weights = [weights]
938            grads = [grads]
939            states = [states]
940        for weight, grad in zip(weights, grads):
941            assert(isinstance(weight, NDArray))
942            assert(isinstance(grad, NDArray))
943            aggregate = (aggregate and
944                         weight.stype == 'default' and
945                         grad.stype == 'default')
946        self._update_count(indices)
947        lrs = self._get_lrs(indices)
948        wds = self._get_wds(indices)
949
950        kwargs = {'rescale_grad': self.rescale_grad}
951        if self.momentum > 0:
952            kwargs['momentum'] = (self.momentum * (self.cur_lr / self.last_lr)) \
953                                 if (self.momentum_correction and self.last_lr != 0) else \
954                                 self.momentum
955
956        if self.clip_gradient:
957            kwargs['clip_gradient'] = self.clip_gradient
958
959        if aggregate:
960            nb_params = len(indices)
961            names = [self.idx2name[i] if i in self.idx2name else str(i) for i in indices]
962            lars_idx = [i for i in range(nb_params) if
963                        not(names[i].endswith('gamma') or names[i].endswith('beta') or
964                            names[i].endswith('bias'))]
965            nb_lars = len(lars_idx)
966            no_lars_idx = [i for i in range(nb_params) if
967                           (names[i].endswith('gamma') or names[i].endswith('beta') or
968                            names[i].endswith('bias'))]
969            cur_ctx = weights[0].context
970            full_idx = lars_idx + no_lars_idx
971            new_lrs = array([lrs[i] for i in full_idx], ctx=cur_ctx, dtype='float32')
972            new_wds = array([wds[i] for i in full_idx], ctx=cur_ctx, dtype='float32')
973            new_weights = [weights[i] for i in full_idx]
974            new_grads = [grads[i] for i in full_idx]
975            new_states = [states[i] for i in full_idx]
976            if nb_lars > 0:
977                w_sum_sq = multi_sum_sq(*new_weights[:nb_lars], num_arrays=nb_lars)
978                g_sum_sq = multi_sum_sq(*new_grads[:nb_lars], num_arrays=nb_lars)
979                multi_lars(new_lrs[:nb_lars], w_sum_sq, g_sum_sq, new_wds[:nb_lars],
980                           eta=self.eta, eps=self.eps, rescale_grad=self.rescale_grad,
981                           out=new_lrs[:nb_lars])
            # Same as the usual update, but using the preloaded SGD functions
983            sidx = 0
984            while sidx < len(indices):
985                eidx = sidx + len(new_weights[sidx:sidx+self.aggregate_num])
986                if not multi_precision:
987                    if self.momentum > 0:
988                        preloaded_multi_sgd_mom_update(
989                            *(_flatten_list(zip(new_weights[sidx:eidx],
990                                                new_grads[sidx:eidx],
991                                                new_states[sidx:eidx])) +
992                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
993                            out=new_weights[sidx:eidx],
994                            num_weights=len(new_weights[sidx:eidx]),
995                            **kwargs)
996                    else:
997                        preloaded_multi_sgd_update(
998                            *(_flatten_list(zip(new_weights[sidx:eidx],
999                                                new_grads[sidx:eidx])) +
1000                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
1001                            out=new_weights[sidx:eidx],
1002                            num_weights=len(new_weights[sidx:eidx]),
1003                            **kwargs)
1004                else:
1005                    if self.momentum > 0:
1006                        preloaded_multi_mp_sgd_mom_update(
1007                            *(_flatten_list(zip(new_weights[sidx:eidx],
1008                                                new_grads[sidx:eidx],
1009                                                *zip(*new_states[sidx:eidx]))) +
1010                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
1011                            out=new_weights[sidx:eidx],
1012                            num_weights=len(new_weights[sidx:eidx]),
1013                            **kwargs)
1014                    else:
1015                        preloaded_multi_mp_sgd_update(
1016                            *(_flatten_list(zip(new_weights[sidx:eidx],
1017                                                new_grads[sidx:eidx],
1018                                                list(zip(*new_states[sidx:eidx]))[1])) +
1019                              [new_lrs[sidx:eidx], new_wds[sidx:eidx]]),
1020                            out=new_weights[sidx:eidx],
1021                            num_weights=len(new_weights[sidx:eidx]),
1022                            **kwargs)
1023                sidx += self.aggregate_num
1024        else:
1025            lrs = [self._get_lars(i, w, g, lr, wd) for (i, w, g, lr, wd) in
1026                   zip(indices, weights, grads, lrs, wds)]
1027
1028            for weight, grad, state, lr, wd in zip(weights, grads, states, lrs, wds):
1029                if not multi_precision:
1030                    if state is not None:
1031                        sgd_mom_update(weight, grad, state, out=weight,
1032                                       lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
1033                    else:
1034                        sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
1035                                   lr=lr, wd=wd, **kwargs)
1036                else:
1037                    if state[0] is not None:
1038                        mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight,
1039                                          lr=lr, wd=wd, **kwargs)
1040                    else:
1041                        mp_sgd_update(weight, grad, state[1], out=weight,
1042                                      lr=lr, wd=wd, **kwargs)
1043
1044    def update(self, index, weight, grad, state):
1045        self._update_impl(index, weight, grad, state, multi_precision=False)
1046
1047    def update_multi_precision(self, index, weight, grad, state):
1048        if not isinstance(index, (tuple, list)):
1049            use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
1050        else:
1051            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
1052        self._update_impl(index, weight, grad, state,
1053                          multi_precision=use_multi_precision)
1054
1056@register
1057class LBSGD(Optimizer):
1058    """The Large Batch SGD optimizer with momentum and weight decay.
1059
1060    The optimizer updates the weight by::
1061
1062        state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
1063        weight = weight - state
1064
1065    For details of the update algorithm see :class:`~mxnet.ndarray.sgd_update`
1066    and :class:`~mxnet.ndarray.sgd_mom_update`.
1067    In addition to the SGD updates the LBSGD optimizer uses the LARS, Layer-wise
1068    Adaptive Rate Scaling, algorithm to have a separate learning rate for each
1069    layer of the network, which leads to better stability over large batch sizes.
1070
1071    This optimizer accepts the following parameters in addition to those accepted
1072    by :class:`.Optimizer`.
1073
1074    Parameters
1075    ----------
1076    momentum : float, optional
1077        The momentum value.
1078    multi_precision: bool, optional
1079        Flag to control the internal precision of the optimizer.
1080        False: results in using the same precision as the weights (default),
1081        True: makes internal 32-bit copy of the weights and applies gradients
1082        in 32-bit precision even if actual weights used in the model have lower precision.
1083        Turning this on can improve convergence and accuracy when training with float16.
1084
    warmup_strategy : str, optional
        The warmup strategy: 'linear', 'power2', 'sqrt' or 'lars'. Default is 'linear'.
    warmup_epochs : int, optional
        Number of warmup epochs. Default is 5.
    batch_scale : int, optional
        Scale of the large batch (batch size * number of workers). Default is 1.
    updates_per_epoch : int, optional
        Number of updates per epoch used for the warmup schedule. Default is 32,
        which may not reflect the true number of batches per epoch.
    begin_epoch : int, optional
        The starting epoch. Default is 0.
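
    Examples
    --------
    A minimal construction sketch (the warmup settings are illustrative):

    >>> import mxnet as mx
    >>> opt = mx.optimizer.LBSGD(learning_rate=0.1, momentum=0.9,
    ...                          warmup_strategy='linear', warmup_epochs=5,
    ...                          batch_scale=4, updates_per_epoch=100)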
1090    """
1091    def __init__(self, momentum=0.0, multi_precision=False, warmup_strategy='linear',
1092                 warmup_epochs=5, batch_scale=1, updates_per_epoch=32, begin_epoch=0, num_epochs=60,
1093                 **kwargs):
1094        super(LBSGD, self).__init__(**kwargs)
1095        logging.info('Running Large-Batch SGD Algorithm')
1096        logging.info('(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)',
1097                     batch_scale, warmup_epochs, warmup_strategy, updates_per_epoch)
1098        self.momentum = momentum
1099        self.multi_precision = multi_precision
1100        # new user parameters for large batch
1101        self.warmup_strategy = warmup_strategy
1102        self.warmup_epochs = warmup_epochs
1103        self.batch_scale = batch_scale
1104        self.updates_per_epoch = updates_per_epoch
1105        self.init_updates = begin_epoch * updates_per_epoch
1106        self.num_epochs = num_epochs
1107        # addl internal usage parameters and storage
1108        self.lbmult = 1
1109        self.cumgrads = {}
1110        # for adaptive lr
1111        self.adaptive = False
1112        self.admult = 1  # adaptation constant
1113
1114    def create_state(self, index, weight):
1115        momentum = None
1116        weight_master_copy = None
1117        if self.multi_precision and weight.dtype == numpy.float16:
1118            weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32)
1119            if self.momentum != 0.0:
1120                momentum = zeros(weight.shape, weight.context, dtype=numpy.float32,
1121                                 stype=weight.stype)
1122            return (momentum, weight_master_copy)
1123        if weight.dtype == numpy.float16 and not self.multi_precision:
1124            warnings.warn("Accumulating with float16 in optimizer can lead to "
1125                          "poor accuracy or slow convergence. "
1126                          "Consider using multi_precision=True option of the "
1127                          "SGD optimizer")
1128        if self.momentum != 0.0:
1129            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype)
1130        return momentum
1131
1132    def _get_lbmult(self, nup):
1133        """Returns lr scaling factor for large batch according to warmup schedule
1134        (to be implemented)
1135        """
1136        nwup = self.warmup_epochs * self.updates_per_epoch
1137        strategy = self.warmup_strategy
1138        maxmult = float(self.batch_scale)
1139        if nup >= nwup:
1140            mult = maxmult
1141        elif nwup <= 1:
1142            mult = 1.0
1143        else:
1144            if (strategy == 'linear'):
1145                mult = 1.0 + (maxmult - 1) * nup / nwup
1146            elif (strategy == 'power2'):
1147                mult = 1.0 + (maxmult-1) * (nup*nup)/(nwup*nwup)
1148            elif (strategy == 'sqrt'):
1149                mult = 1.0 + (maxmult - 1) * math.sqrt(float(nup) / nwup)
1150            else:
1151                mult = 1.0
1152        return mult
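        # Worked example with illustrative numbers: for warmup_epochs=5,
        # updates_per_epoch=32 and batch_scale=8, nwup = 160; after 80 updates the
        # 'linear' strategy gives mult = 1 + (8 - 1) * 80 / 160 = 4.5, and after
        # 160 updates mult reaches batch_scale = 8.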
1153
1154    def _get_lars(self, weight, g, wd):
1155        """Returns a scaling factor for the learning rate for this layer
1156        default is 1
1157        """
1158        weight2 = self._l2norm(weight)
1159        grad2 = self._l2norm(g)
1160        lars = math.sqrt(weight2 / (grad2 + wd * weight2 + 1e-18))
1161        if lars < 0.01:
1162            lars = 0.01
1163        elif lars > 100:
1164            lars = 100
1165        return lars
1166
1167    def _l2norm(self, v):
1168        "inner product implementation"
1169        norm = multiply(v, v).asnumpy().sum()
1170        return norm
1171
1172    def _reset_cum_gradient(self, index):
1173        "called every macro-batch to reset cumulated gradients to 0 for a given index"
1174        self.cumgrads[index]['cum_grad'] = 0
1175
1176    def _get_cum_gradient(self, index):
1177        "get the cumulated gradient for index"
1178        if index in self.cumgrads:
1179            return self.cumgrads[index]
1180        else:
1181            return {}
1182
1183    def _put_cum_gradient(self, index, cgrad):
1184        "store cumulated gradient for index"
1185        self.cumgrads[index] = cgrad
1186
1187    def _cumulate_gradient(self, grad, index):
1188        "Cumulate gradients for large-batch emulation. Cumulated by index (layer)"
1189        cgrad = self._get_cum_gradient(index)
1190        if cgrad:
1191            num_cums = cgrad['num_cums']
1192            if num_cums > 0:
1193                cum_grad = cgrad['cum_grad'] + grad
1194                num_cums += 1
1195            else:
1196                cum_grad = grad
1197                num_cums = self.init_updates + 1
1198        else:
1199            cum_grad = grad
1200            num_cums = self.init_updates + 1
1201        cgrad = {'cum_grad': cum_grad, 'num_cums': num_cums}
1202        self._put_cum_gradient(index, cgrad)
1203        return cgrad
1204
1205    def update(self, index, weight, grad, state):
1206        assert (isinstance(weight, NDArray))
1207        assert (isinstance(grad, NDArray))
1208
1209        lr = self._get_lr(index)
1210        wd = self._get_wd(index)
1211        self._update_count(index)
1212
        # accumulate gradients to emulate a large batch
1214        cgrad = self._cumulate_gradient(grad, index)
1215        if (cgrad['num_cums'] % self.batch_scale) == 0:
1216            grad = cgrad['cum_grad'] / self.batch_scale
1217            if self.warmup_strategy == 'lars':
1218                lbmult = self._get_lars(weight, grad, wd)
1219            else:
1220                lbmult = self._get_lbmult(cgrad['num_cums'])
1221            lr = lr * lbmult
1222            # do the regular sgd update flow
1223            kwargs = {'rescale_grad': self.rescale_grad}
1224            if self.momentum > 0:
1225                kwargs['momentum'] = self.momentum
1226            if self.clip_gradient:
1227                kwargs['clip_gradient'] = self.clip_gradient
1228            use_multi_precision = isinstance(state, (list, tuple))
1229
1230            if not use_multi_precision:
1231                if state is not None:
1232                    sgd_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
1233                else:
1234                    sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
1235            else:
1236                if state[0] is not None:
1237                    mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, lr=lr, wd=wd,
1238                                      **kwargs)
1239                else:
1240                    mp_sgd_update(weight, grad, state[1], out=weight, lr=lr, wd=wd, **kwargs)
1241            # reset update count and cumulated gradient per large batch
1242            self._reset_cum_gradient(index)
1243        else:
1244            lr = 0.0
1245            kwargs = {}
1246            sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
1247
1248
1249@register
1250class LAMB(Optimizer):
1251    """LAMB Optimizer.
1252    """
1253    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
1254                 lower_bound=None, upper_bound=None, bias_correction=True, **kwargs):
1255        super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
1256        self.beta1 = beta1
1257        self.beta2 = beta2
1258        self.epsilon = epsilon
1259        self.lower_bound = lower_bound
1260        self.upper_bound = upper_bound
1261        self.bias_correction = bias_correction
1262        self.aggregate_num = max(1, min(45, int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "45"))))
1263
1264    def create_state(self, index, weight):
1265        stype = weight.stype
1266        dtype = weight.dtype
1267        return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
1268                zeros(weight.shape, weight.context, dtype=dtype, stype=stype))
1269
1270    def _update_impl(self, index, weight, grad, state, multi_precision=False):
1271        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
1272                  'bias_correction': self.bias_correction,
1273                  'rescale_grad': self.rescale_grad}
1274
1275        if self.aggregate_num <= 1 or not isinstance(index, (tuple, list)):
1276            if isinstance(index, (tuple, list)):
1277                assert(len(index) == self.aggregate_num)
1278                index, weight, grad, state = index[0], weight[0], grad[0], state[0]
1279            assert(isinstance(weight, NDArray))
1280            assert(isinstance(grad, NDArray))
1281            self._update_count(index)
1282            lr = self._get_lr(index)
1283            wd = self._get_wd(index)
1284            t = self._index_update_count[index]
1285            weight_ptr = weight
1286            grad_ptr = grad
1287            if multi_precision:
1288                mean, var = state[1]
1289                weight32 = state[0]
1290            else:
1291                mean, var = state
1292            kwargs['t'] = t
1293            if self.clip_gradient:
1294                kwargs['clip_gradient'] = self.clip_gradient
1295
1296            if multi_precision:
1297                g = mp_lamb_update_phase1(weight_ptr, grad_ptr, mean, var, weight32, wd=wd, **kwargs)
1298                kwargs = {}
1299                if self.lower_bound:
1300                    kwargs['lower_bound'] = self.lower_bound
1301                if self.upper_bound:
1302                    kwargs['upper_bound'] = self.upper_bound
1303                r_1 = weight32.norm()
1304                r_2 = g.norm()
1305                mp_lamb_update_phase2(weight_ptr, g, r_1, r_2, weight32, lr=lr, out=weight_ptr, **kwargs)
1306            else:
1307                g = lamb_update_phase1(weight_ptr, grad_ptr, mean, var, wd=wd, **kwargs)
1308                kwargs = {}
1309                if self.lower_bound:
1310                    kwargs['lower_bound'] = self.lower_bound
1311                if self.upper_bound:
1312                    kwargs['upper_bound'] = self.upper_bound
1313                r_1 = weight_ptr.norm()
1314                r_2 = g.norm()
1315                lamb_update_phase2(weight_ptr, g, r_1, r_2, lr=lr, out=weight_ptr, **kwargs)
1316        else:
1317            if self.clip_gradient:
1318                kwargs['clip_gradient'] = self.clip_gradient
1319            if self.lower_bound:
1320                kwargs['lower_bound'] = self.lower_bound
1321            if self.upper_bound:
1322                kwargs['upper_bound'] = self.upper_bound
1323
1324            step_count, lrs, wds = [], [], []
1325            for i, w_i, g_i in zip(index, weight, grad):
1326                assert(isinstance(w_i, NDArray))
1327                assert(isinstance(g_i, NDArray))
1328                self._update_count(i)
1329                step_count.append(self._index_update_count[i])
1330                lrs.append(self._get_lr(i))
1331                wds.append(self._get_wd(i))
1332
1333            updated_tensors = 0
1334            while updated_tensors < len(weight):
1335                sidx = updated_tensors
1336                eidx = min(updated_tensors + self.aggregate_num, len(weight))
1337                if not multi_precision:
1338                    mean, var = list(zip(*state[sidx:eidx]))
1339                    multi_lamb_update(weight[sidx:eidx],
1340                                      grad[sidx:eidx],
1341                                      mean, var,
1342                                      out=weight[sidx:eidx],
1343                                      step_count=step_count[sidx:eidx],
1344                                      lrs=lrs[sidx:eidx],
1345                                      wds=wds[sidx:eidx],
1346                                      **kwargs)
1347                else:
                    # each multi-precision state is (weight32, (mean, var));
                    # regroup the per-tensor (mean, var) pairs into parallel lists
                    mean_var = list(zip(*state[sidx:eidx]))[1]
                    temp = list(zip(*mean_var))
                    mean = temp[0]
                    var = temp[1]
1352                    multi_mp_lamb_update(weight[sidx:eidx],
1353                                         grad[sidx:eidx],
1354                                         mean, var,
1355                                         list(zip(*state[sidx:eidx]))[0],
1356                                         out=weight[sidx:eidx],
1357                                         step_count=step_count[sidx:eidx],
1358                                         lrs=lrs[sidx:eidx],
1359                                         wds=wds[sidx:eidx],
1360                                         **kwargs)
1361                updated_tensors += self.aggregate_num
1362
1363    def update(self, index, weight, grad, state):
1364        self._update_impl(index, weight, grad, state, multi_precision=False)
1365
1366    def update_multi_precision(self, index, weight, grad, state):
1367        if not isinstance(index, (tuple, list)):
1368            use_multi_precision = self.multi_precision and weight.dtype == numpy.float16
1369        else:
1370            use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16
1371        self._update_impl(index, weight, grad, state,
1372                          multi_precision=use_multi_precision)
1373
1374# pylint: enable=line-too-long
1375@register
1376class DCASGD(Optimizer):
1377    """The DCASGD optimizer.
1378
1379    This class implements the optimizer described in *Asynchronous Stochastic Gradient Descent
1380    with Delay Compensation for Distributed Deep Learning*,
1381    available at https://arxiv.org/abs/1609.08326.
1382
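    This optimizer updates each weight by (a restatement of the ``update`` method
    below)::

        grad = clip(grad * rescale_grad, clip_gradient)
        mom = momentum * mom - lr * (grad + wd * weight
                                     + lamda * grad * grad * (weight - previous_weight))
        previous_weight = weight
        weight += mom
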
1383    This optimizer accepts the following parameters in addition to those accepted
1384    by :class:`.Optimizer`.
1385
1386    Parameters
1387    ----------
1388    momentum : float, optional
1389       The momentum value.
1390
1391    lamda : float, optional
       Scale of the delay-compensation (DC) term.
1393    """
1394    def __init__(self, momentum=0.0, lamda=0.04, **kwargs):
1395        super(DCASGD, self).__init__(**kwargs)
1396        self.momentum = momentum
1397        self.weight_previous = {}
1398        self.lamda = lamda
1399
1400    def create_state(self, index, weight):
1401        if self.momentum == 0.0:
1402            return (None,
1403                    weight.copy())  # previous weight
1404        else:
1405            return (zeros(weight.shape, weight.context, dtype=weight.dtype), # momentum
1406                    weight.copy())  # previous weight
1407
1408    def update(self, index, weight, grad, state):
1409        assert(isinstance(weight, NDArray))
1410        assert(isinstance(grad, NDArray))
1411        self._update_count(index)
1412        lr = self._get_lr(index)
1413        wd = self._get_wd(index)
1414
1415        grad = grad * self.rescale_grad
1416        if self.clip_gradient is not None:
1417            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
1418
1419        mom, previous_weight = state
        if mom is not None:
1421            mom[:] *= self.momentum
1422            mom[:] += -lr * (grad + wd * weight + self.lamda \
1423                             * grad * grad * (weight - previous_weight))
1424        else:
1425            assert(self.momentum == 0.0)
1426            mom = -lr * (grad + wd * weight + self.lamda \
1427                         * grad * grad * (weight - previous_weight))
1428        previous_weight[:] = weight
1429        weight[:] += mom
1430
1431@register
1432class NAG(Optimizer):
1433    """Nesterov accelerated gradient.
1434
1435    This optimizer updates each weight by::
1436
1437        state = momentum * state + grad + wd * weight
1438        weight = weight - (lr * (grad + momentum * state))
1439
1440    Parameters
1441    ----------
1442    momentum : float, optional
1443       The momentum value.
1444    multi_precision: bool, optional
1445        Flag to control the internal precision of the optimizer.
1446        False: results in using the same precision as the weights (default),
1447        True: makes internal 32-bit copy of the weights and applies gradients
1448        in 32-bit precision even if actual weights used in the model have lower precision.
1449        Turning this on can improve convergence and accuracy when training with float16.
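
    Examples
    --------
    A sketch of the multi-precision path (float16 weights with a float32 master
    copy kept in the optimizer state)::

        import mxnet as mx
        opt = mx.optimizer.NAG(learning_rate=0.1, momentum=0.9, multi_precision=True)
        weight = mx.nd.ones((2, 2), dtype='float16')
        grad = mx.nd.ones((2, 2), dtype='float16') * 0.1
        state = opt.create_state_multi_precision(0, weight)
        opt.update_multi_precision(0, weight, grad, state)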
1450    """
1451    def __init__(self, momentum=0.0, **kwargs):
1452        super(NAG, self).__init__(**kwargs)
1453        self.momentum = momentum
1454
1455    def create_state_multi_precision(self, index, weight):
1456        weight_master_copy = None
1457        if self.multi_precision and weight.dtype == numpy.float16:
1458            weight_master_copy = weight.astype(numpy.float32)
1459            return (self.create_state(index, weight_master_copy), weight_master_copy)
1460        if weight.dtype == numpy.float16 and not self.multi_precision:
1461            warnings.warn("Accumulating with float16 in optimizer can lead to "
1462                          "poor accuracy or slow convergence. "
1463                          "Consider using multi_precision=True option of the "
1464                          "NAG optimizer")
1465        return self.create_state(index, weight)
1466
1467    def create_state(self, index, weight):
1468        momentum = None
1469        if self.momentum != 0.0:
1470            momentum = zeros(weight.shape, weight.context, dtype=weight.dtype)
1471        return momentum
1472
1473    def _update_impl(self, index, weight, grad, state, multi_precision=False):
1474        assert(isinstance(weight, NDArray))
1475        assert(isinstance(grad, NDArray))
1476        self._update_count(index)
1477        lr = self._get_lr(index)
1478        wd = self._get_wd(index)
1479
1480        kwargs = {'rescale_grad': self.rescale_grad}
1481        if self.momentum > 0:
1482            kwargs['momentum'] = self.momentum
1483        if self.clip_gradient:
1484            kwargs['clip_gradient'] = self.clip_gradient
1485
1486        if not multi_precision:
1487            if state is not None:
1488                nag_mom_update(weight, grad, state, out=weight, lr=lr, wd=wd, **kwargs)
1489            else:
1490                sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
1491        else:
1492            if state[0] is not None:
1493                mp_nag_mom_update(weight, grad, state[0], state[1], out=weight,
1494                                  lr=lr, wd=wd, **kwargs)
1495            else:
1496                mp_sgd_update(weight, grad, state[1], out=weight,
1497                              lr=lr, wd=wd, **kwargs)
1498
1499    def update(self, index, weight, grad, state):
1500        self._update_impl(index, weight, grad, state, multi_precision=False)
1501
1502    def update_multi_precision(self, index, weight, grad, state):
1503        use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 \
1504                                and isinstance(state, (tuple, list))
1505        self._update_impl(index, weight, grad, state,
1506                          multi_precision=use_multi_precision)
1507
1508
1509@register
1510class SGLD(Optimizer):
1511    """Stochastic Gradient Riemannian Langevin Dynamics.
1512
1513    This class implements the optimizer described in the paper *Stochastic Gradient
1514    Riemannian Langevin Dynamics on the Probability Simplex*, available at
1515    https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf.
1516
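    The update is a plain SGD step plus Gaussian noise (a restatement of the
    ``update`` method below)::

        grad = clip(grad * rescale_grad, clip_gradient)
        weight += - lr/2 * (grad + wd * weight) + normal(0, sqrt(lr))
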
1517    """
1518    def __init__(self, **kwargs):
1519        super(SGLD, self).__init__(**kwargs)
1520
1521    def create_state(self, index, weight):
1522        return None
1523
1524    def update(self, index, weight, grad, state):
1525        assert(isinstance(weight, NDArray))
1526        assert(isinstance(grad, NDArray))
1527        self._update_count(index)
1528        lr = self._get_lr(index)
1529        wd = self._get_wd(index)
1530
1531        grad = grad * self.rescale_grad
1532        if self.clip_gradient is not None:
1533            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
1534        weight[:] += - lr/2 * (grad + wd * weight)
1535        weight[:] += normal(0, math.sqrt(lr), shape=weight.shape,
1536                            dtype=weight.dtype, ctx=weight.context)
1537
1538
1539
1540@register  # pylint: disable=invalid-name
1541class ccSGD(SGD):
1542    """[DEPRECATED] Same as `SGD`. Left here for backward compatibility."""
1543    def __init__(self, *args, **kwargs):
1544        super(ccSGD, self).__init__(*args, **kwargs)
1545
1546@register
1547class Adam(Optimizer):
1548    """The Adam optimizer.
1549
1550    This class implements the optimizer described in *Adam: A Method for
1551    Stochastic Optimization*, available at http://arxiv.org/abs/1412.6980.
1552
    If the storage type of grad is ``row_sparse`` and ``lazy_update`` is True, \
1554    **lazy updates** at step t are applied by::
1555
1556        for row in grad.indices:
1557            rescaled_grad[row] = clip(grad[row] * rescale_grad + wd * weight[row], clip_gradient)
1558            m[row] = beta1 * m[row] + (1 - beta1) * rescaled_grad[row]
1559            v[row] = beta2 * v[row] + (1 - beta2) * (rescaled_grad[row]**2)
            lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
1561            w[row] = w[row] - lr * m[row] / (sqrt(v[row]) + epsilon)
1562
    The lazy update only updates the mean and var for the weights whose row_sparse
    gradient indices appear in the current batch, rather than updating them for all indices.
1565    Compared with the original update, it can provide large improvements in model training
1566    throughput for some applications. However, it provides slightly different semantics than
1567    the original update, and may lead to different empirical results.
1568
1569    Otherwise, **standard updates** at step t are applied by::
1570
1571        rescaled_grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
1572        m = beta1 * m + (1 - beta1) * rescaled_grad
1573        v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
        lr = learning_rate * sqrt(1 - beta2**t) / (1 - beta1**t)
1575        w = w - lr * m / (sqrt(v) + epsilon)
1576
1577    This optimizer accepts the following parameters in addition to those accepted
1578    by :class:`.Optimizer`.
1579
1580    For details of the update algorithm, see :class:`~mxnet.ndarray.adam_update`.
1581
1582    Parameters
1583    ----------
1584    beta1 : float, optional
1585        Exponential decay rate for the first moment estimates.
1586    beta2 : float, optional
1587        Exponential decay rate for the second moment estimates.
1588    epsilon : float, optional
1589        Small value to avoid division by 0.
1590    lazy_update : bool, optional
1591       Default is True. If True, lazy updates are applied \
1592       if the storage types of weight and grad are both ``row_sparse``.
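
    Examples
    --------
    A minimal sketch using the ``create`` registry helper defined at the bottom of
    this module (shapes and values are arbitrary)::

        import mxnet as mx
        opt = mx.optimizer.create('adam', learning_rate=0.001)
        weight = mx.nd.ones((3, 3))
        grad = mx.nd.ones((3, 3)) * 0.1
        state = opt.create_state(0, weight)   # (mean, var)
        opt.update(0, weight, grad, state)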
1593    """
1594    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
1595                 lazy_update=True, **kwargs):
1596        super(Adam, self).__init__(learning_rate=learning_rate, **kwargs)
1597        self.beta1 = beta1
1598        self.beta2 = beta2
1599        self.epsilon = epsilon
1600        self.lazy_update = lazy_update
1601
1602    def create_state(self, index, weight):
1603        stype = weight.stype if self.lazy_update else 'default'
1604        return (zeros(weight.shape, weight.context, dtype=weight.dtype,
1605                      stype=stype),  # mean
1606                zeros(weight.shape, weight.context, dtype=weight.dtype,
1607                      stype=stype))  # variance
1608
1609    def update(self, index, weight, grad, state):
1610        assert(isinstance(weight, NDArray))
1611        assert(isinstance(grad, NDArray))
1612        self._update_count(index)
1613        lr = self._get_lr(index)
1614        wd = self._get_wd(index)
1615
1616        t = self._index_update_count[index]
1617        coef1 = 1. - self.beta1**t
1618        coef2 = 1. - self.beta2**t
1619        lr *= math.sqrt(coef2)/coef1
1620
1621        kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
1622                  'rescale_grad': self.rescale_grad}
1623        if self.clip_gradient:
1624            kwargs['clip_gradient'] = self.clip_gradient
1625
1626        mean, var = state
1627        adam_update(weight, grad, mean, var, out=weight,
1628                    lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs)
1629
1630@register
1631class AdaGrad(Optimizer):
1632    """AdaGrad optimizer.
1633
1634    This class implements the AdaGrad optimizer described in *Adaptive Subgradient
1635    Methods for Online Learning and Stochastic Optimization*, and available at
1636    http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
1637
1638    This optimizer updates each weight by::
1639
1640        grad = clip(grad * rescale_grad, clip_gradient)
1641        history += square(grad)
1642        div = grad / sqrt(history + float_stable_eps)
1643        weight += (div + weight * wd) * -lr
1644
1645    This optimizer accepts the following parameters in addition to those accepted
1646    by :class:`.Optimizer`.
1647
1648    See Also
1649    ----------
1650    :meth:`mxnet.ndarray.sparse.adagrad_update`.
1651
1652    Parameters
1653    ----------
1654    eps: float, optional
        Small constant for numerical stability; avoids division by 0 when the
        gradient history is still (near) zero.
1656
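    Examples
    --------
    One dense step can be reproduced with plain NumPy (a sketch of the pseudocode
    above, not of the MXNet kernels themselves)::

        import numpy as np
        lr, wd, eps = 0.01, 0.0, 1e-7
        weight = np.ones((2, 2))
        grad = np.full((2, 2), 0.1)
        history = np.zeros((2, 2))
        history += np.square(grad)           # accumulate squared gradients
        div = grad / np.sqrt(history + eps)  # element-wise adaptive step
        weight += (div + weight * wd) * -lr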
1657    """
1658    def __init__(self, eps=1e-7, **kwargs):
1659        super(AdaGrad, self).__init__(**kwargs)
1660        self.float_stable_eps = eps
1661
1662    def create_state(self, index, weight):
1663        return zeros(weight.shape, weight.context, stype=weight.stype)  # history
1664
1665    def update(self, index, weight, grad, state):
1666        assert(isinstance(weight, NDArray))
1667        assert(isinstance(grad, NDArray))
1668        self._update_count(index)
1669        lr = self._get_lr(index)
1670        wd = self._get_wd(index)
1671
1672        is_sparse = grad.stype == 'row_sparse'
1673        history = state
1674
1675        if is_sparse:
1676            kwargs = {'epsilon': self.float_stable_eps,
1677                      'rescale_grad': self.rescale_grad}
1678            if self.clip_gradient:
1679                kwargs['clip_gradient'] = self.clip_gradient
1680            sparse.adagrad_update(weight, grad, history, out=weight, lr=lr, wd=wd, **kwargs)
1681        else:
1682            grad = grad * self.rescale_grad
1683            if self.clip_gradient is not None:
1684                grad = clip(grad, -self.clip_gradient, self.clip_gradient)
1685            history[:] += square(grad)
1686            div = grad / sqrt(history + self.float_stable_eps)
1687            weight[:] += (div + weight * wd) * -lr
1688
1689@register
1690class RMSProp(Optimizer):
1691    """The RMSProp optimizer.
1692
1693    Two versions of RMSProp are implemented:
1694
1695    If ``centered=False``, we follow
1696    http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf by
1697    Tieleman & Hinton, 2012.
1698    For details of the update algorithm see :class:`~mxnet.ndarray.rmsprop_update`.
1699
1700    If ``centered=True``, we follow http://arxiv.org/pdf/1308.0850v5.pdf (38)-(45)
1701    by Alex Graves, 2013.
1702    For details of the update algorithm see :class:`~mxnet.ndarray.rmspropalex_update`.
1703
1704    This optimizer accepts the following parameters in addition to those accepted
1705    by :class:`.Optimizer`.
1706
1707    Parameters
1708    ----------
1709    gamma1: float, optional
1710        A decay factor of moving average over past squared gradient.
1711    gamma2: float, optional
        A "momentum" factor. Only used if ``centered`` is ``True``.
1713    epsilon : float, optional
1714        Small value to avoid division by 0.
1715    centered : bool, optional
1716        Flag to control which version of RMSProp to use.::
1717
1718            True: will use Graves's version of `RMSProp`,
1719            False: will use Tieleman & Hinton's version of `RMSProp`.
1720
1721    clip_weights : float, optional
1722        Clips weights into range ``[-clip_weights, clip_weights]``.
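
    Examples
    --------
    A minimal sketch of the centered variant; ``create_state`` returns the
    ``(n, g, delta)`` triple in that case::

        import mxnet as mx
        opt = mx.optimizer.RMSProp(learning_rate=0.001, centered=True)
        weight = mx.nd.ones((2, 2))
        grad = mx.nd.ones((2, 2)) * 0.1
        state = opt.create_state(0, weight)
        opt.update(0, weight, grad, state)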
1723    """
1724    def __init__(self, learning_rate=0.001, gamma1=0.9, gamma2=0.9,
1725                 epsilon=1e-8, centered=False, clip_weights=None, **kwargs):
1726        super(RMSProp, self).__init__(learning_rate=learning_rate, **kwargs)
1727        self.gamma1 = gamma1
1728        self.gamma2 = gamma2
1729        self.centered = centered
1730        self.epsilon = epsilon
1731        self.clip_weights = clip_weights
1732
1733    def create_state(self, index, weight):
1734        if self.centered:
1735            return (
1736                zeros(weight.shape, weight.context, stype=weight.stype),  # n
1737                zeros(weight.shape, weight.context, stype=weight.stype),  # g
1738                zeros(weight.shape, weight.context, stype=weight.stype))  # delta
1739        else:
1740            return (zeros(weight.shape, weight.context, stype=weight.stype),)  # n
1741
1742    def update(self, index, weight, grad, state):
1743        assert(isinstance(weight, NDArray))
1744        assert(isinstance(grad, NDArray))
1745        self._update_count(index)
1746        lr = self._get_lr(index)
1747        wd = self._get_wd(index)
1748
1749        kwargs = {'gamma1': self.gamma1, 'epsilon': self.epsilon,
1750                  'rescale_grad': self.rescale_grad}
1751        if self.centered:
1752            kwargs['gamma2'] = self.gamma2
1753        if self.clip_gradient:
1754            kwargs['clip_gradient'] = self.clip_gradient
1755        if self.clip_weights:
1756            kwargs['clip_weights'] = self.clip_weights
1757
1758        if not self.centered:
1759            (n, ) = state
1760            rmsprop_update(
1761                weight, grad, n, out=weight, lr=lr, wd=wd, **kwargs)
1762        else:
1763            n, g, delta = state
1764            rmspropalex_update(weight, grad, n, g, delta, out=weight,
1765                               lr=lr, wd=wd, **kwargs)
1766
1767@register
1768class AdaDelta(Optimizer):
1769    """The AdaDelta optimizer.
1770
1771    This class implements AdaDelta, an optimizer described in  *ADADELTA: An adaptive
1772    learning rate method*, available at https://arxiv.org/abs/1212.5701.
1773
1774    This optimizer updates each weight by::
1775
1776        grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
1777        acc_grad = rho * acc_grad + (1. - rho) * grad * grad
1778        delta = sqrt(acc_delta + epsilon) / sqrt(acc_grad + epsilon) * grad
1779        acc_delta = rho * acc_delta + (1. - rho) * delta * delta
1780        weight -= (delta + wd * weight)
1781
1782    This optimizer accepts the following parameters in addition to those accepted
1783    by :class:`.Optimizer`.
1784
1785    Parameters
1786    ----------
1787    rho: float
1788        Decay rate for both squared gradients and delta.
1789    epsilon : float
1790        Small value to avoid division by 0.
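
    Examples
    --------
    A minimal usage sketch; note that the ``update`` method below never reads the
    learning rate, so only ``rho``, ``epsilon``, ``wd`` and the gradient
    rescaling/clipping settings affect the step::

        import mxnet as mx
        opt = mx.optimizer.AdaDelta(rho=0.9, epsilon=1e-5)
        weight = mx.nd.ones((2, 2))
        grad = mx.nd.ones((2, 2)) * 0.1
        state = opt.create_state(0, weight)   # (acc_grad, acc_delta)
        opt.update(0, weight, grad, state)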
1791    """
1792    def __init__(self, rho=0.90, epsilon=1e-5, **kwargs):
1793        super(AdaDelta, self).__init__(**kwargs)
1794        self.rho = rho
1795        self.epsilon = epsilon
1796
1797    def create_state(self, index, weight):
1798        return (zeros(weight.shape, weight.context),  # accumulated g
1799                zeros(weight.shape, weight.context))  # accumulated delta
1800
1801    def update(self, index, weight, grad, state):
1802        assert(isinstance(weight, NDArray))
1803        assert(isinstance(grad, NDArray))
1804        wd = self._get_wd(index)
1805        self._update_count(index)
1806
1807        # preprocess grad
1808        grad *= self.rescale_grad
1809        if self.clip_gradient is not None:
1810            grad = clip(grad, - self.clip_gradient, self.clip_gradient)
1811
        # unpack accumulated squared gradient and accumulated delta
        acc_g, acc_delta = state
1814
1815        # update g, delta
1816        acc_g[:] *= self.rho
1817        acc_g[:] += (1. - self.rho) * grad * grad
1818        current_delta = sqrt(acc_delta + self.epsilon) / sqrt(acc_g + self.epsilon) * grad
1819        acc_delta[:] *= self.rho
1820        acc_delta[:] += (1. - self.rho) * current_delta * current_delta
1821
1822        # update weight
1823        weight[:] -= current_delta + wd * weight
1824
1825#pylint: disable=invalid-name
1826#pylint: disable=line-too-long
1827@register
1828class Ftrl(Optimizer):
1829    """The Ftrl optimizer.
1830
1831    Referenced from *Ad Click Prediction: a View from the Trenches*, available at
1832    http://dl.acm.org/citation.cfm?id=2488200.
1833
1834    eta :
1835        .. math::
1836           \\eta_{t,i} = \\frac{learningrate}{\\beta+\\sqrt{\\sum_{s=1}^tg_{s,i}^2}}
1837
1838    The optimizer updates the weight by::
1839
1840        rescaled_grad = clip(grad * rescale_grad, clip_gradient)
1841        z += rescaled_grad - (sqrt(n + rescaled_grad**2) - sqrt(n)) * weight / learning_rate
1842        n += rescaled_grad**2
1843        w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1)
1844
1845    If the storage types of weight, state and grad are all ``row_sparse``, \
1846    **sparse updates** are applied by::
1847
1848        for row in grad.indices:
1849            rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient)
1850            z[row] += rescaled_grad[row] - (sqrt(n[row] + rescaled_grad[row]**2) - sqrt(n[row])) * weight[row] / learning_rate
1851            n[row] += rescaled_grad[row]**2
1852            w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1)
1853
1854    The sparse update only updates the z and n for the weights whose row_sparse
1855    gradient indices appear in the current batch, rather than updating it for all
1856    indices. Compared with the original update, it can provide large
1857    improvements in model training throughput for some applications. However, it
1858    provides slightly different semantics than the original update, and
1859    may lead to different empirical results.
1860
1861    For details of the update algorithm, see :class:`~mxnet.ndarray.ftrl_update`.
1862
1863    This optimizer accepts the following parameters in addition to those accepted
1864    by :class:`.Optimizer`.
1865
1866    Parameters
1867    ----------
1868    lamda1 : float, optional
1869        L1 regularization coefficient.
1870    learning_rate : float, optional
1871        The initial learning rate.
1872    beta : float, optional
1873        Per-coordinate learning rate correlation parameter.
1874    """
1875
1876    def __init__(self, lamda1=0.01, learning_rate=0.1, beta=1, **kwargs):
1877        super(Ftrl, self).__init__(**kwargs)
1878        self.lamda1 = lamda1
1879        self.beta = beta
1880        self.lr = learning_rate
1881
1882    def create_state(self, index, weight):
1883        return (zeros(weight.shape, weight.context, stype=weight.stype),  # z
1884                zeros(weight.shape, weight.context, stype=weight.stype))  # n
1885
1886    def update(self, index, weight, grad, state):
1887        assert(isinstance(weight, NDArray))
1888        assert(isinstance(grad, NDArray))
1889        self._update_count(index)
1890        wd = self._get_wd(index)
1891        lr = self._get_lr(index)
1892
1893        kwargs = {'lamda1': self.lamda1, 'beta': self.beta, 'rescale_grad': self.rescale_grad}
1894        if self.clip_gradient:
1895            kwargs['clip_gradient'] = self.clip_gradient
1896
        # unpack accumulated z and squared-gradient sum n
        z, n = state
1899        ftrl_update(weight, grad, z, n, out=weight,
1900                    lr=lr, wd=wd, **kwargs)
1901
1902# pylint: enable=line-too-long
1903@register
1904class Adamax(Optimizer):
1905    """The AdaMax optimizer.
1906
1907    It is a variant of Adam based on the infinity norm
1908    available at http://arxiv.org/abs/1412.6980 Section 7.
1909
1910    The optimizer updates the weight by::
1911
1912        grad = clip(grad * rescale_grad + wd * weight, clip_gradient)
        m = beta1 * m + (1 - beta1) * grad
1914        u = maximum(beta2 * u, abs(grad))
1915        weight -= lr / (1 - beta1**t) * m / u
1916
1917    This optimizer accepts the following parameters in addition to those accepted
1918    by :class:`.Optimizer`.
1919
1920    Parameters
1921    ----------
1922    beta1 : float, optional
1923        Exponential decay rate for the first moment estimates.
1924    beta2 : float, optional
1925        Exponential decay rate for the second moment estimates.
1926    """
1927    def __init__(self, learning_rate=0.002, beta1=0.9, beta2=0.999, **kwargs):
1928        super(Adamax, self).__init__(learning_rate=learning_rate, **kwargs)
1929        self.beta1 = beta1
1930        self.beta2 = beta2
1931
1932    def create_state(self, index, weight):
1933        return (zeros(weight.shape, weight.context, dtype=weight.dtype),  # mean
1934                zeros(weight.shape, weight.context, dtype=weight.dtype))  # variance
1935
1936    def update(self, index, weight, grad, state):
1937        assert(isinstance(weight, NDArray))
1938        assert(isinstance(grad, NDArray))
1939        self._update_count(index)
1940        lr = self._get_lr(index)
1941        wd = self._get_wd(index)
1942
1943        t = self._index_update_count[index]
1944        lr /= (1. - self.beta1**t)
1945
1946        # preprocess grad
1947        grad = grad * self.rescale_grad + wd * weight
1948        if self.clip_gradient is not None:
1949            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
1950
1951        # update m_t and u_t
1952        m_t, u_t = state
1953        m_t[:] *= self.beta1
1954        m_t[:] += (1. - self.beta1) * grad
1955        u_t[:] = maximum(self.beta2 * u_t, NDabs(grad))
1956
1957        # update weight
1958        weight[:] -= lr * m_t / u_t
1959
1960@register
1961class Nadam(Optimizer):
1962    """The Nesterov Adam optimizer.
1963
    Much like Adam is essentially RMSProp with momentum, Nadam is Adam with
    Nesterov momentum. The approach is described in *Incorporating Nesterov
    Momentum into Adam*, available at
    http://cs229.stanford.edu/proj2015/054_report.pdf.
1967
1968    This optimizer accepts the following parameters in addition to those accepted
1969    by :class:`.Optimizer`.
1970
1971    Parameters
1972    ----------
1973    beta1 : float, optional
1974        Exponential decay rate for the first moment estimates.
1975    beta2 : float, optional
1976        Exponential decay rate for the second moment estimates.
1977    epsilon : float, optional
1978        Small value to avoid division by 0.
1979    schedule_decay : float, optional
        Exponential decay rate for the momentum schedule.
1981    """
1982    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
1983                 schedule_decay=0.004, **kwargs):
1984        super(Nadam, self).__init__(learning_rate=learning_rate, **kwargs)
1985        self.beta1 = beta1
1986        self.beta2 = beta2
1987        self.epsilon = epsilon
1988        self.schedule_decay = schedule_decay
1989        self.m_schedule = 1.
1990
1991    def create_state(self, index, weight):
1992        return (zeros(weight.shape, weight.context, dtype=weight.dtype),  # mean
1993                zeros(weight.shape, weight.context, dtype=weight.dtype))  # variance
1994
1995    def update(self, index, weight, grad, state):
1996        assert(isinstance(weight, NDArray))
1997        assert(isinstance(grad, NDArray))
1998        self._update_count(index)
1999        lr = self._get_lr(index)
2000        wd = self._get_wd(index)
2001
2002        t = self._index_update_count[index]
2003
2004        # preprocess grad
2005        grad = grad * self.rescale_grad + wd * weight
2006        if self.clip_gradient is not None:
2007            grad = clip(grad, -self.clip_gradient, self.clip_gradient)
2008
2009        # warming momentum schedule
2010        momentum_t = self.beta1 * (1. - 0.5 * (pow(0.96, t * self.schedule_decay)))
2011        momentum_t_1 = self.beta1 * (1. - 0.5 * (pow(0.96, (t + 1) * self.schedule_decay)))
2012        self.m_schedule = self.m_schedule * momentum_t
2013        m_schedule_next = self.m_schedule * momentum_t_1
2014
2015        # update m_t and v_t
2016        m_t, v_t = state
2017        m_t[:] *= self.beta1
2018        m_t[:] += (1. - self.beta1) * grad
2019        v_t[:] *= self.beta2
2020        v_t[:] += (1. - self.beta2) * grad * grad
2021
2022        grad_prime = grad / (1. - self.m_schedule)
2023        m_t_prime = m_t / (1. - m_schedule_next)
2024        v_t_prime = v_t / (1. - pow(self.beta2, t))
2025        m_t_bar = (1. - momentum_t) * grad_prime + momentum_t_1 * m_t_prime
2026
2027        # update weight
2028        weight[:] -= lr * m_t_bar / (sqrt(v_t_prime) + self.epsilon)
2029
2030@register
2031class Test(Optimizer):
2032    """The Test optimizer"""
2033    def __init__(self, **kwargs):
2034        super(Test, self).__init__(**kwargs)
2035
2036    def create_state(self, index, weight):
2037        """Creates a state to duplicate weight."""
2038        return zeros(weight.shape, weight.context)
2039
2040    def update(self, index, weight, grad, state):
2041        """Performs w += rescale_grad * grad."""
2042        weight[:] += grad * self.rescale_grad
2043        state[:] = weight
2044
2045# backward compatibility wrapper for Optimizer.CreateOptimizer
2046create = Optimizer.create_optimizer  # pylint: disable=invalid-name
2047
2048
2049def _as_classic(a, allow_np):
2050    # TODO(junwu): This is a temp solution for allowing converting
2051    # np.ndarray to mx.nd.NDArray to be fed into the optimizer since
2052    # users may have custom optimizers implemented using mx.nd.NDArray ops.
2053    from ..numpy import ndarray as np_ndarray
2054    if isinstance(a, (tuple, list)):
2055        if any(isinstance(x, np_ndarray) for x in a):
2056            if allow_np:
2057                return [x.as_nd_ndarray() for x in a]
2058            else:
2059                raise ValueError('Converting np.ndarray to mx.nd.NDArray is not allowed')
2060    else:
2061        if isinstance(a, np_ndarray):
2062            if allow_np:
2063                return a.as_nd_ndarray()
2064            else:
2065                raise ValueError('Converting np.ndarray to mx.nd.NDArray is not allowed')
2066    return a
2067
2068
2069
2070class Updater(object):
    """Updater for kvstore: wraps an :class:`Optimizer` and manages per-index optimizer states."""
2072    def __init__(self, optimizer):
2073        self.optimizer = optimizer
2074        self.states = {}
2075        self.states_synced = {}
2076        self.aggregate_updates = optimizer.aggregate_num > 0
2077
2078    def __call__(self, index, grad, weight):
2079        """Updates weight given gradient and index."""
2080        allow_np = self.optimizer.allow_np_array if hasattr(self.optimizer, "allow_np_array") else is_np_array()
2081        if not isinstance(index, (list, tuple)):
2082            indices = [index]
2083            grads = [_as_classic(grad, allow_np)]
2084            weights = [_as_classic(weight, allow_np)]
2085        else:
2086            indices = index
2087            grads = _as_classic(grad, allow_np)
2088            weights = _as_classic(weight, allow_np)
2089        if weights:
2090            self.optimizer._set_current_context(weights[0].context.device_id)
2091        for i, idx in enumerate(indices):
2092            # convert ctypes.char_p.value back to python str if needed
2093            if isinstance(idx, bytes):
2094                indices[i] = py_str(idx)
2095                idx = indices[i]
2096            if idx not in self.states:
2097                self.states[idx] = self.optimizer.create_state_multi_precision(idx, weights[i])
2098                self.states_synced[idx] = True
2099            elif not self.states_synced[idx]:
2100                self.states[idx] = \
2101                    self.sync_state_context(self.states[idx], weights[i].context)
2102                self.states_synced[idx] = True
2103        if self.aggregate_updates:
            # group (index, weight, grad) triples by weight dtype
2105            type_map = {}
2106            for i, w, g in zip(indices, weights, grads):
2107                if w.dtype in type_map:
2108                    type_map[w.dtype].append((i, w, g))
2109                else:
2110                    type_map[w.dtype] = [(i, w, g)]
2111            for idx in type_map:
2112                current_index = 0
2113                indices, weights, grads = zip(*type_map[idx])
2114                while current_index < len(indices):
2115                    states = []
2116                    step = min(self.optimizer.aggregate_num, len(indices) - current_index)
2117                    for j in range(step):
2118                        states.append(self.states[indices[current_index + j]])
2119                    self.optimizer.update_multi_precision(
2120                        indices[current_index:current_index + self.optimizer.aggregate_num],
2121                        weights[current_index:current_index + self.optimizer.aggregate_num],
2122                        grads[current_index:current_index + self.optimizer.aggregate_num],
2123                        states)
2124                    current_index += self.optimizer.aggregate_num
2125        else:
2126            for i, w, g in zip(indices, weights, grads):
2127                self.optimizer.update_multi_precision(i, w, g, self.states[i])
2128
2129    def sync_state_context(self, state, context):
        """Syncs the state NDArray(s) to the given context."""
2131        if isinstance(state, NDArray):
2132            return state.as_in_context(context)
2133        elif isinstance(state, (tuple, list)):
2134            synced_state = (self.sync_state_context(i, context) for i in state)
2135            if isinstance(state, tuple):
2136                return tuple(synced_state)
2137            else:
2138                return list(synced_state)
2139        else:
2140            return state
2141
2142    def set_states(self, states):
2143        """Sets updater states."""
2144        states = pickle.loads(states)
2145        if isinstance(states, tuple) and len(states) == 2:
2146            self.states, self.optimizer = states
2147        else:
2148            self.states = states
2149        self.states_synced = dict.fromkeys(self.states.keys(), False)
2150
2151    def get_states(self, dump_optimizer=False):
2152        """Gets updater states.
2153
2154        Parameters
2155        ----------
2156        dump_optimizer : bool, default False
2157            Whether to also save the optimizer itself. This would also save optimizer
2158            information such as learning rate and weight decay schedules.
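
        Examples
        --------
        A state blob produced here can be fed back through ``set_states`` (a
        sketch; ``updater`` stands for any :class:`Updater` instance)::

            saved = updater.get_states(dump_optimizer=False)
            updater.set_states(saved)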
2159        """
2160        return pickle.dumps((self.states, self.optimizer) if dump_optimizer else self.states)
2161
2162def get_updater(optimizer):
2163    """Returns a closure of the updater needed for kvstore.
2164
2165    Parameters
2166    ----------
2167    optimizer: Optimizer
2168         The optimizer.
2169
2170    Returns
2171    -------
2172    updater: function
2173         The closure of the updater.
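
    Examples
    --------
    A minimal sketch; any optimizer from this module works, and the returned
    callable expects ``(index, grad, weight)`` in that order::

        import mxnet as mx
        opt = mx.optimizer.SGD(learning_rate=0.1)
        updater = mx.optimizer.get_updater(opt)
        weight = mx.nd.ones((2, 2))
        grad = mx.nd.ones((2, 2)) * 0.5
        updater(0, grad, weight)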
2174    """
2175    return Updater(optimizer)
2176