from __future__ import division
import math
import warnings

import numpy

import chainer
from chainer.backends import cuda
from chainer.backends import intel64
from chainer import optimizer
from chainer import types


if types.TYPE_CHECKING:
    import typing_extensions as tpe

    class AdamHyperparameter(tpe.Protocol):
        """Protocol class for hyperparameter of Adam.

        This is only for PEP 544 compliant static type checkers.
        """
        alpha = None  # type: float
        beta1 = None  # type: float
        beta2 = None  # type: float
        eps = None  # type: float
        eta = None  # type: float
        weight_decay_rate = None  # type: float
        amsgrad = None  # type: bool
        adabound = None  # type: bool
        final_lr = None  # type: float
        gamma = None  # type: float


_default_hyperparam = optimizer.Hyperparameter()  # type: AdamHyperparameter # NOQA
_default_hyperparam.alpha = 0.001
_default_hyperparam.beta1 = 0.9
_default_hyperparam.beta2 = 0.999
_default_hyperparam.eps = 1e-8
_default_hyperparam.eta = 1.0
_default_hyperparam.weight_decay_rate = 0
_default_hyperparam.amsgrad = False
_default_hyperparam.adabound = False
_default_hyperparam.final_lr = 0.1
_default_hyperparam.gamma = 1e-3


def _learning_rate(hp, t):
    if t == 0:
        raise RuntimeError(
            'Can\'t determine the learning rate of Adam optimizer '
            'because the update steps have not been started.')
    fix1 = 1. - math.pow(hp.beta1, t)
    fix2 = 1. - math.pow(hp.beta2, t)
    return hp.alpha * math.sqrt(fix2) / fix1
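# Illustrative numbers only (nothing in this module depends on them): with
# the default hyperparameters alpha=0.001, beta1=0.9 and beta2=0.999, the
# first update step above evaluates to roughly
#     0.001 * sqrt(1 - 0.999) / (1 - 0.9) ~= 3.16e-4,
# and both bias-correction factors approach 1 as t grows, so the effective
# step size converges to alpha.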


def _get_intermediate_dtype(dtype):
    # Returns the dtype for intermediate calculation.
    # For float16 input, float32 is used.
    # Otherwise the same dtype as the parameter is used.
    if dtype == numpy.float16:
        return numpy.float32
    return dtype


def _inplace_axpby(x, a, b, y):
    # in-place axpby: x = a * x + b * y
    if isinstance(x, intel64.mdarray):
        x.inplace_axpby(a, b, y)
    else:
        if a == 1:
            x += b * y
        else:
            x[...] = a * x + b * y
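# A minimal reference example of the generic (non-iDeep) branch above, using
# plain NumPy arrays; the values are illustrative only:
#
#     x = numpy.ones(3)
#     y = numpy.full(3, 2.0)
#     _inplace_axpby(x, 0.5, 0.25, y)  # x becomes 0.5 * 1.0 + 0.25 * 2.0 == 1.0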


class AdamRule(optimizer.UpdateRule):

    """Update rule of Adam optimization algorithm.

    See: `Adam: A Method for Stochastic Optimization
    <https://arxiv.org/abs/1412.6980v8>`_

    Modified for proper weight decay.

    See: `Fixing Weight Decay Regularization in Adam
    <https://openreview.net/forum?id=rk6qdGgCZ>`_

    With an option to use the AMSGrad variant of Adam.

    See: `On the Convergence of Adam and Beyond
    <https://openreview.net/forum?id=ryQu7f-RZ>`_

    With an option to use the AdaBound variant of Adam.

    See: `Adaptive Gradient Methods with Dynamic Bound of Learning Rate
    <https://openreview.net/forum?id=Bkg3g2R9FX>`_

    See :class:`~chainer.optimizers.Adam` for the default values
    of the hyperparameters.

    Args:
        parent_hyperparam (~chainer.optimizer.Hyperparameter): Hyperparameter
            that provides the default values.
        alpha (float): Coefficient of learning rate.
        beta1 (float): Exponential decay rate of the first order moment.
        beta2 (float): Exponential decay rate of the second order moment.
        eps (float): Small value for the numerical stability.
        eta (float): Schedule multiplier, can be used for warm restarts.
        weight_decay_rate (float): Weight decay rate.
        amsgrad (bool): Whether to use the AMSGrad variant of Adam.
        adabound (bool): Whether to use the AdaBound variant of Adam.
        final_lr (float): Final (SGD) learning rate in AdaBound.
        gamma (float): Convergence speed of the bound functions in AdaBound.

    """
    is_elementwise = True

    _kernel = None
    _amsgrad_kernel = None
    _adabound_kernel = None
    _amsbound_kernel = None

    # Only used in `update_core_gpu`.
    # A dummy ndarray to help ElementwiseKernel deduce generic type T as
    # `dtype`.
    # It cannot be deduced only by scalar arguments.
    _dummy = None

    def __init__(self, parent_hyperparam=None,
                 alpha=None, beta1=None, beta2=None, eps=None,
                 eta=None, weight_decay_rate=None, amsgrad=None,
                 adabound=None, final_lr=None, gamma=None):
        super(AdamRule, self).__init__(
            parent_hyperparam or _default_hyperparam)
        if alpha is not None:
            self.hyperparam.alpha = alpha
        if beta1 is not None:
            self.hyperparam.beta1 = beta1
        if beta2 is not None:
            self.hyperparam.beta2 = beta2
        if eps is not None:
            self.hyperparam.eps = eps
        if eta is not None:
            self.hyperparam.eta = eta
        if weight_decay_rate is not None:
            self.hyperparam.weight_decay_rate = weight_decay_rate
        if amsgrad is not None:
            self.hyperparam.amsgrad = amsgrad
        if adabound is not None:
            self.hyperparam.adabound = adabound
        if final_lr is not None:
            self.hyperparam.final_lr = final_lr
        if gamma is not None:
            self.hyperparam.gamma = gamma
        if self.hyperparam.adabound:
            self.initial_alpha = self.hyperparam.alpha

    def init_state(self, param):
        with chainer.using_device(param.device):
            xp = param.device.xp
            self.state['m'] = xp.zeros_like(param.data)
            self.state['v'] = xp.zeros_like(param.data)
            if self.hyperparam.amsgrad:
                self.state['vhat'] = xp.zeros_like(param.data)

        # For iDeep
        if isinstance(param.data, intel64.mdarray):
            self.state['m'] = intel64.ideep.array(
                self.state['m'], itype=intel64.ideep.wgt_array)
            self.state['v'] = intel64.ideep.array(
                self.state['v'], itype=intel64.ideep.wgt_array)
            if self.hyperparam.amsgrad:
                self.state['vhat'] = intel64.ideep.array(
                    self.state['vhat'], itype=intel64.ideep.wgt_array)

    def _check_eps(self, interm_dtype):
        # Checks that the eps does not underflow.
        hp = self.hyperparam
        eps = interm_dtype(hp.eps)
        if hp.eps != 0 and eps == 0:
            raise ValueError(
                'eps of Adam optimizer is too small for {} ({})'.format(
                    interm_dtype.name, hp.eps))
        # Note that the converted `eps` (numpy scalar) is discarded here and
        # the original `hp.eps` is used in calculation, because Python
        # scalars are faster in cupy elementwise kernels.

    def update_core_cpu(self, param):
        grad = param.grad
        if grad is None:
            return
        hp = self.hyperparam
        dtype = _get_intermediate_dtype(param.dtype.type)
        self._check_eps(dtype)
        grad = grad.astype(dtype, copy=False)

        m, v = self.state['m'], self.state['v']

        # m += (1 - beta1) * (grad - m)
        _inplace_axpby(m, 1.0, 1.0 - hp.beta1, grad - m)
        # v += (1 - beta2) * (grad * grad - v)
        _inplace_axpby(v, 1.0, 1.0 - hp.beta2, grad * grad - v)

        if hp.amsgrad:
            vhat = self.state['vhat']
            # For iDeep
            if isinstance(vhat, intel64.mdarray):
                vhat[...] = numpy.maximum(vhat, v)
            else:
                numpy.maximum(vhat, v, out=vhat)
        else:
            vhat = v
        vhat = vhat.astype(dtype, copy=False)
        step = self.alpha_t / (numpy.sqrt(vhat) + hp.eps)
        if hp.adabound:
            lower, upper = self.bounds
            step = numpy.clip(step, lower, upper)
        # param -=
        #  eta * (step * m + weight_decay_rate * param)
        _inplace_axpby(
            param.data, 1.0 - hp.eta * hp.weight_decay_rate, -hp.eta, step * m)
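        # In plain NumPy terms the method above (ignoring the iDeep branch
        # and the AdaBound clipping) amounts to the following sketch for a
        # parameter array ``p`` with gradient ``g``; the names are
        # illustrative only:
        #
        #     m += (1 - beta1) * (g - m)
        #     v += (1 - beta2) * (g * g - v)
        #     vhat = numpy.maximum(vhat, v) if amsgrad else v
        #     p -= eta * (alpha_t * m / (numpy.sqrt(vhat) + eps)
        #                 + weight_decay_rate * p)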

    def update_core_gpu(self, param):
        grad = param.grad
        if grad is None:
            return
        hp = self.hyperparam
        dtype = _get_intermediate_dtype(param.dtype.type)
        self._check_eps(dtype)

        if self._dummy is None:
            self._dummy = cuda.cupy.empty((0,), dtype=dtype)

        if hp.adabound:
            lower, upper = self.bounds
        if hp.amsgrad and hp.adabound:
            if AdamRule._amsbound_kernel is None:
                AdamRule._amsbound_kernel = cuda.elementwise(
                    'P grad, T alpha_t, T one_minus_beta1, T one_minus_beta2, '
                    'T lower, T upper, '
                    'T eps, T eta, T weight_decay_rate, raw T dummy',
                    'P param, P m, P v, P vhat',
                    '''T grad_ = static_cast<T>(grad);
                       T m_ = static_cast<T>(m);
                       T v_ = static_cast<T>(v);
                       T vhat_ = static_cast<T>(vhat);
                       m_ += one_minus_beta1 * (grad_ - m_);
                       v_ += one_minus_beta2 * (grad_ * grad_ - v_);
                       vhat_ = max(vhat_, v_);
                       vhat = static_cast<T>(vhat_);
                       m = static_cast<P>(m_);
                       v = static_cast<P>(v_);
                       param -= eta *
                           (max(min(alpha_t / (sqrt(vhat_) + eps), upper),
                                lower) * m_ + weight_decay_rate * param);''',
                    'amsbound')
            AdamRule._amsbound_kernel(
                grad, self.alpha_t, 1 - hp.beta1,
                1 - hp.beta2, lower, upper, hp.eps,
                hp.eta, hp.weight_decay_rate, self._dummy,
                param.data, self.state['m'], self.state['v'],
                self.state['vhat'])
        elif hp.adabound:
            if AdamRule._adabound_kernel is None:
                AdamRule._adabound_kernel = cuda.elementwise(
                    'P grad, T alpha_t, T one_minus_beta1, T one_minus_beta2, '
                    'T lower, T upper, '
                    'T eps, T eta, T weight_decay_rate, raw T dummy',
                    'P param, P m, P v',
                    '''T grad_ = static_cast<T>(grad);
                       T m_ = static_cast<T>(m);
                       T v_ = static_cast<T>(v);
                       m_ += one_minus_beta1 * (grad_ - m_);
                       v_ += one_minus_beta2 * (grad_ * grad_ - v_);
                       m = static_cast<P>(m_);
                       v = static_cast<P>(v_);
                       param -= eta *
                           (max(min(alpha_t / (sqrt(v_) + eps), upper),
                                lower) * m_ + weight_decay_rate * param);''',
                    'adabound')
            AdamRule._adabound_kernel(
                grad, self.alpha_t, 1 - hp.beta1,
                1 - hp.beta2, lower, upper, hp.eps,
                hp.eta, hp.weight_decay_rate, self._dummy,
                param.data, self.state['m'], self.state['v'])
        elif hp.amsgrad:
            if AdamRule._amsgrad_kernel is None:
                AdamRule._amsgrad_kernel = cuda.elementwise(
                    'P grad, T alpha_t, T one_minus_beta1, T one_minus_beta2, '
                    'T eps, T eta, T weight_decay_rate, raw T dummy',
                    'P param, P m, P v, P vhat',
                    '''T grad_ = static_cast<T>(grad);
                       T m_ = static_cast<T>(m);
                       T v_ = static_cast<T>(v);
                       T vhat_ = static_cast<T>(vhat);
                       m_ += one_minus_beta1 * (grad_ - m_);
                       v_ += one_minus_beta2 * (grad_ * grad_ - v_);
                       vhat_ = max(vhat_, v_);
                       vhat = static_cast<T>(vhat_);
                       m = static_cast<P>(m_);
                       v = static_cast<P>(v_);
                       param -= eta * (alpha_t * m_ / (sqrt(vhat_) + eps) +
                                       weight_decay_rate * param);''',
                    'adam')
            AdamRule._amsgrad_kernel(
                grad, self.alpha_t, 1 - hp.beta1,
                1 - hp.beta2, hp.eps,
                hp.eta, hp.weight_decay_rate, self._dummy,
                param.data, self.state['m'], self.state['v'],
                self.state['vhat'])
        else:
            if AdamRule._kernel is None:
                AdamRule._kernel = cuda.elementwise(
                    'P grad, T alpha_t, T one_minus_beta1, T one_minus_beta2, '
                    'T eps, T eta, T weight_decay_rate, raw T dummy',
                    'P param, P m, P v',
                    '''T grad_ = static_cast<T>(grad);
                       T m_ = static_cast<T>(m);
                       T v_ = static_cast<T>(v);
                       m_ += one_minus_beta1 * (grad_ - m_);
                       v_ += one_minus_beta2 * (grad_ * grad_ - v_);
                       m = static_cast<P>(m_);
                       v = static_cast<P>(v_);
                       param -= eta * (alpha_t * m_ / (sqrt(v_) + eps) +
                                       weight_decay_rate * param);''',
                    'adam')
            AdamRule._kernel(
                grad, self.alpha_t, 1 - hp.beta1,
                1 - hp.beta2, hp.eps,
                hp.eta, hp.weight_decay_rate, self._dummy,
                param.data, self.state['m'], self.state['v'])

    @property
    def alpha_t(self):
        return _learning_rate(self.hyperparam, self.t)

    @property
    def lr(self):
        warnings.warn(
            'AdamRule.lr has been renamed to AdamRule.alpha_t. '
            'Use of AdamRule.lr is deprecated in Chainer v6.',
            DeprecationWarning)
        return self.alpha_t

    @property
    def bounds(self):
        if self.t == 0:
            raise RuntimeError(
                'Can\'t determine the bounds of AdaBound optimizer '
                'because the update steps have not been started.')
        hp = self.hyperparam
        # Workaround to reflect changes to `alpha` (e.g. made by some of
        # `chainer.training.extensions`) in `final_lr`.
        final_lr = hp.final_lr * hp.alpha / self.initial_alpha
        lower = final_lr * (1.0 - 1.0 / (hp.gamma * self.t + 1))
        upper = final_lr * (1.0 + 1.0 / (hp.gamma * self.t))
        return lower, upper
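        # For intuition only: with the default final_lr=0.1 and gamma=1e-3
        # (and alpha unchanged), the interval starts very wide (lower ~= 1e-4,
        # upper ~= 100.1 at t=1) and both bounds tighten towards final_lr as
        # t grows, so the updates gradually shift from Adam-like steps to SGD
        # with learning rate final_lr.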


class Adam(optimizer.GradientMethod):

    """Adam optimizer.

    See: `Adam: A Method for Stochastic Optimization
    <https://arxiv.org/abs/1412.6980v8>`_

    Modified for proper weight decay (also called
    :class:`~chainer.optimizers.AdamW`).
    AdamW introduces the additional parameters ``eta``
    and ``weight_decay_rate``, which can be used to properly scale the
    learning rate and to decouple the weight decay rate from ``alpha``,
    as shown in the paper below.

    Note that with the default values ``eta = 1`` and
    ``weight_decay_rate = 0``, this implementation is identical to
    the standard Adam method.

    See: `Fixing Weight Decay Regularization in Adam
    <https://openreview.net/forum?id=rk6qdGgCZ>`_

    The ``amsgrad`` flag enables the :class:`~chainer.optimizers.AMSGrad`
    variant of Adam proposed in the paper
    `On the Convergence of Adam and Beyond
    <https://openreview.net/forum?id=ryQu7f-RZ>`_.

    The ``adabound`` flag enables the :class:`~chainer.optimizers.AdaBound`
    variant of Adam proposed in the paper
    `Adaptive Gradient Methods with Dynamic Bound of Learning Rate
    <https://openreview.net/forum?id=Bkg3g2R9FX>`_.

    If both ``amsgrad`` and ``adabound`` are ``True``, the optimizer is
    equivalent to :class:`~chainer.optimizers.AMSBound` proposed in the
    AdaBound paper.

    Args:
        alpha (float): Coefficient of learning rate.
        beta1 (float): Exponential decay rate of the first order moment.
        beta2 (float): Exponential decay rate of the second order moment.
        eps (float): Small value for the numerical stability.
        eta (float): Schedule multiplier, can be used for warm restarts.
        weight_decay_rate (float): Weight decay rate.
        amsgrad (bool): Whether to use the AMSGrad variant of Adam.
        adabound (bool): Whether to use the AdaBound variant of Adam.
        final_lr (float): Final (SGD) learning rate in AdaBound.
        gamma (float): Convergence speed of the bound functions in AdaBound.

    """

    def __init__(self,
                 alpha=_default_hyperparam.alpha,
                 beta1=_default_hyperparam.beta1,
                 beta2=_default_hyperparam.beta2,
                 eps=_default_hyperparam.eps,
                 eta=_default_hyperparam.eta,
                 weight_decay_rate=_default_hyperparam.weight_decay_rate,
                 amsgrad=_default_hyperparam.amsgrad,
                 adabound=_default_hyperparam.adabound,
                 final_lr=_default_hyperparam.final_lr,
                 gamma=_default_hyperparam.gamma):
        super(Adam, self).__init__()
        self.hyperparam.alpha = alpha
        self.hyperparam.beta1 = beta1
        self.hyperparam.beta2 = beta2
        self.hyperparam.eps = eps
        self.hyperparam.eta = eta
        self.hyperparam.weight_decay_rate = weight_decay_rate
        self.hyperparam.amsgrad = amsgrad
        self.hyperparam.adabound = adabound
        self.hyperparam.final_lr = final_lr
        self.hyperparam.gamma = gamma

    alpha = optimizer.HyperparameterProxy('alpha')
    beta1 = optimizer.HyperparameterProxy('beta1')
    beta2 = optimizer.HyperparameterProxy('beta2')
    eps = optimizer.HyperparameterProxy('eps')
    eta = optimizer.HyperparameterProxy('eta')
    weight_decay_rate = optimizer.HyperparameterProxy('weight_decay_rate')
    amsgrad = optimizer.HyperparameterProxy('amsgrad')
    adabound = optimizer.HyperparameterProxy('adabound')
    final_lr = optimizer.HyperparameterProxy('final_lr')
    gamma = optimizer.HyperparameterProxy('gamma')

    def create_update_rule(self):
        return AdamRule(self.hyperparam)

    @property
    def alpha_t(self):
        return _learning_rate(self.hyperparam, self.t)

    @property
    def lr(self):
        warnings.warn(
            'Adam.lr has been renamed to Adam.alpha_t. '
            'Use of Adam.lr is deprecated in Chainer v6.',
            DeprecationWarning)
        return self.alpha_t

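# Typical usage of the Adam class above (a minimal sketch; ``model`` is
# assumed to be a chainer.Link whose gradients have already been computed by
# a backward pass):
#
#     optimizer = Adam(alpha=0.001, weight_decay_rate=1e-4)
#     optimizer.setup(model)
#     optimizer.update()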

class AdamW(Adam):

    """AdamW optimizer.

    This class is a special case of :class:`~chainer.optimizers.Adam`.

    See: `Fixing Weight Decay Regularization in Adam
    <https://openreview.net/forum?id=rk6qdGgCZ>`_

    Args:
        alpha (float): Coefficient of learning rate.
        beta1 (float): Exponential decay rate of the first order moment.
        beta2 (float): Exponential decay rate of the second order moment.
        eps (float): Small value for the numerical stability.
        eta (float): Schedule multiplier, can be used for warm restarts.
            The default value is 1.0.
        weight_decay_rate (float): Weight decay rate.
            The default value is 0.
    """

    def __init__(self,
                 alpha=_default_hyperparam.alpha,
                 beta1=_default_hyperparam.beta1,
                 beta2=_default_hyperparam.beta2,
                 eps=_default_hyperparam.eps,
                 eta=_default_hyperparam.eta,
                 weight_decay_rate=_default_hyperparam.weight_decay_rate):
        super(AdamW, self).__init__(
            alpha=alpha, beta1=beta1, beta2=beta2, eps=eps, eta=eta,
            weight_decay_rate=weight_decay_rate)


class AMSGrad(Adam):

    """AMSGrad optimizer.

    This class is a special case of :class:`~chainer.optimizers.Adam`.

    See: `On the Convergence of Adam and Beyond
    <https://openreview.net/forum?id=ryQu7f-RZ>`_

    Args:
        alpha (float): Coefficient of learning rate.
        beta1 (float): Exponential decay rate of the first order moment.
        beta2 (float): Exponential decay rate of the second order moment.
        eps (float): Small value for the numerical stability.
        eta (float): Schedule multiplier, can be used for warm restarts.
    """

    def __init__(self,
                 alpha=_default_hyperparam.alpha,
                 beta1=_default_hyperparam.beta1,
                 beta2=_default_hyperparam.beta2,
                 eps=_default_hyperparam.eps,
                 eta=_default_hyperparam.eta):
        super(AMSGrad, self).__init__(
            alpha=alpha, beta1=beta1, beta2=beta2, eps=eps, eta=eta,
            amsgrad=True)


class AdaBound(Adam):

    """AdaBound optimizer.

    This class is a special case of :class:`~chainer.optimizers.Adam`.

    See: `Adaptive Gradient Methods with Dynamic Bound of Learning Rate
    <https://openreview.net/forum?id=Bkg3g2R9FX>`_

    Args:
        alpha (float): Coefficient of learning rate.
        beta1 (float): Exponential decay rate of the first order moment.
        beta2 (float): Exponential decay rate of the second order moment.
        final_lr (float): Final (SGD) learning rate in AdaBound.
        gamma (float): Convergence speed of the bound functions in AdaBound.
        eps (float): Small value for the numerical stability.
        eta (float): Schedule multiplier, can be used for warm restarts.
    """

    def __init__(self,
                 alpha=_default_hyperparam.alpha,
                 beta1=_default_hyperparam.beta1,
                 beta2=_default_hyperparam.beta2,
                 final_lr=_default_hyperparam.final_lr,
                 gamma=_default_hyperparam.gamma,
                 eps=_default_hyperparam.eps,
                 eta=_default_hyperparam.eta):
        super(AdaBound, self).__init__(
            alpha=alpha, beta1=beta1, beta2=beta2, eps=eps, eta=eta,
            amsgrad=False, adabound=True, final_lr=final_lr, gamma=gamma)


class AMSBound(Adam):

    """AMSBound optimizer.

    This class is a special case of :class:`~chainer.optimizers.Adam`.

    See: `Adaptive Gradient Methods with Dynamic Bound of Learning Rate
    <https://openreview.net/forum?id=Bkg3g2R9FX>`_

    Args:
        alpha (float): Coefficient of learning rate.
        beta1 (float): Exponential decay rate of the first order moment.
        beta2 (float): Exponential decay rate of the second order moment.
        final_lr (float): Final (SGD) learning rate in AdaBound.
        gamma (float): Convergence speed of the bound functions in AdaBound.
        eps (float): Small value for the numerical stability.
        eta (float): Schedule multiplier, can be used for warm restarts.
    """

    def __init__(self,
                 alpha=_default_hyperparam.alpha,
                 beta1=_default_hyperparam.beta1,
                 beta2=_default_hyperparam.beta2,
                 final_lr=_default_hyperparam.final_lr,
                 gamma=_default_hyperparam.gamma,
                 eps=_default_hyperparam.eps,
                 eta=_default_hyperparam.eta):
        super(AMSBound, self).__init__(
            alpha=alpha, beta1=beta1, beta2=beta2, eps=eps, eta=eta,
            amsgrad=True, adabound=True, final_lr=final_lr, gamma=gamma)
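# For reference only (a sketch of how the wrappers above relate): because the
# specialized classes simply preset flags on Adam, the following two lines
# construct equivalent optimizers:
#
#     AMSBound(final_lr=0.1, gamma=1e-3)
#     Adam(amsgrad=True, adabound=True, final_lr=0.1, gamma=1e-3)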