1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18package AI::MXNet::Optimizer;
19use strict;
20use warnings;
21use AI::MXNet::NS;
22use AI::MXNet::Base;
23use AI::MXNet::NDArray;
24use AI::MXNet::Random;
25use List::Util qw(max);
26
27=head1 NAME
28
    AI::MXNet::Optimizer - Common optimization algorithms with regularization.
30
31=head1  DESCRIPTION
32
    Common optimization algorithms with regularization.
34=cut
35
36use Mouse;
37use AI::MXNet::Function::Parameters;
38my %opt_registry;
39method get_opt_registry()
40{
41    return \%opt_registry;
42}
43
44method register()
45{
46    my $name = $self;
47    ($name) = $name =~ /::(\w+)$/;
48    {  no strict 'refs'; *{__PACKAGE__."::$name"} = sub { shift; $self->new(@_)  }; }
49    $name = lc $name;
50    if(exists $opt_registry{ $name })
51    {
52        my $existing = $opt_registry{ $name };
        warn(
            "WARNING: New optimizer $self.$name "
            ."is overriding existing optimizer $existing.$name"
        );
57    }
58    $opt_registry{ $name } = $self;
59}
60
61=head2 create_optimizer
62
63        Create an optimizer with specified name.
64
65        Parameters
66        ----------
67        $name: Str
68            Name of required optimizer. Should be the name
69            of a subclass of Optimizer. Case insensitive.
70
71        :$rescale_grad : Num
72            Rescaling factor on gradient. Normally should be 1/batch_size.
73
74        %kwargs: Hash
75            Parameters for optimizer
76
77        Returns
78        -------
79        opt : Optimizer
80            The result optimizer.
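
        A minimal usage sketch (the optimizer name and hyper-parameter
        values below are illustrative):

            use AI::MXNet;
            my $opt = AI::MXNet::Optimizer->create_optimizer(
                'sgd',
                learning_rate => 0.01,
                rescale_grad  => 1/128,  # 1/batch_size
                momentum      => 0.9
            );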
81=cut
82
83method create_optimizer(Str $name, %kwargs)
84{
85    if(exists $opt_registry{ lc $name })
86    {
87        my $rescale_grad = delete($kwargs{rescale_grad})//1;
88        return $opt_registry{ lc $name }->new(
89            rescale_grad => $rescale_grad,
90            %kwargs
91        );
92    }
93    confess("Cannot find optimizer $name");
94}
95
96*create = \&create_optimizer;
97
98has 'rescale_grad'        => (is => "rw", isa => "Num", default=>1);
99has 'lr'                  => (is => "rw", isa => "Num");
100has 'learning_rate'       => (is => "rw", isa => "Num", default => 0.01);
101has 'lr_scheduler'        => (is => "rw", isa => "Maybe[AI::MXNet::LRScheduler]");
102has 'wd'                  => (is => "rw", isa => "Num", default => 0);
103has 'lr_mult'             => (is => "rw", isa => "HashRef", default => sub { +{} });
has 'wd_mult'             => (is => "rw", isa => "HashRef", default => sub { +{} });
105has 'num_update'          => (is => "rw", isa => "Int");
106has 'begin_num_update'    => (is => "rw", isa => "Int", default => 0);
107has '_index_update_count' => (is => "rw", isa => "HashRef", default => sub { +{} });
108has 'clip_gradient'       => (is => "rw", isa => "Maybe[Num]");
109has 'param_idx2name'      => (is => "rw", isa => "HashRef[Str]", default => sub { +{} });
110has 'idx2name'            => (is => "rw", isa => "HashRef[Str]");
111has 'sym'                 => (is => "rw", isa => "Maybe[AI::MXNet::Symbol]");
112has 'param_dict'          => (is => "rw", isa => "HashRef", default => sub { +{} });
113
114sub BUILD
115{
116    my $self = shift;
117    if($self->lr_scheduler)
118    {
119        $self->lr_scheduler->base_lr($self->learning_rate);
120    }
121    $self->lr($self->learning_rate);
122    $self->num_update($self->begin_num_update);
123    $self->idx2name({ %{ $self->param_idx2name } });
124    $self->set_lr_mult({});
125    $self->set_wd_mult({});
126}
127# Create additional optimizer state such as momentum.
128# override in implementations.
129method create_state($index, $weight){}
130
131# Update the parameters. override in implementations
132method update($index, $weight, $grad, $state){}
133
134# set lr scale is deprecated. Use set_lr_mult instead.
135method set_lr_scale($args_lrscale)
136{
137    Carp::cluck("set lr scale is deprecated. Use set_lr_mult instead.");
138}
139
140=head2 set_lr_mult
141
        Sets individual learning rate multipliers for parameters.

        Parameters
        ----------
        args_lr_mult : hash ref of Str/Int to Num
            Sets the lr multiplier for each name/index to the given value.
            Setting the multiplier by index is supported for backward compatibility,
            but we recommend using names and symbols.
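
        For example, to train one layer with a smaller step size than the rest
        of the network (the variable and parameter names below are hypothetical):

            $optimizer->set_lr_mult({ fc2_weight => 0.1, fc2_bias => 0.1 });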
150=cut
151
152method set_lr_mult(HashRef[Num] $args_lr_mult)
153{
154    $self->lr_mult({});
155    if($self->sym)
156    {
157        my $attr = $self->sym->attr_dict();
158        for my $name (@{ $self->sym->list_arguments() })
159        {
160            if(exists $attr->{ $name } and exists $attr->{ $name }{ __lr_mult__ })
161            {
162                $self->lr_mult->{ $name } = $attr->{ $name }{ __lr_mult__ };
163            }
164        }
165    }
166    $self->lr_mult({ %{ $self->lr_mult }, %{ $args_lr_mult } });
167}
168
169=head2 set_wd_mult
170
        Sets individual weight decay multipliers for parameters.
        By default the wd multiplier is 0 for all parameters whose name doesn't
        end with _weight or _gamma, if param_idx2name is provided.

        Parameters
        ----------
        args_wd_mult : hash ref of Str/Int to Num
            Sets the wd multiplier for each name/index to the given value.
            Setting the multiplier by index is supported for backward compatibility,
            but we recommend using names and symbols.
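
        For example, to disable weight decay for an embedding matrix
        (the variable and parameter names below are hypothetical):

            $optimizer->set_wd_mult({ embed_weight => 0 });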
181=cut
182
183method set_wd_mult(HashRef[Num] $args_wd_mult)
184{
185    $self->wd_mult({});
186    for my $n (values %{ $self->idx2name })
187    {
188        if(not $n =~ /(?:_weight|_gamma)$/)
189        {
190            $self->wd_mult->{ $n } = 0;
191        }
192    }
193    if($self->sym)
194    {
195        my $attr = $self->sym->attr_dict();
196        for my $name (@{ $self->sym->list_arguments() })
197        {
198            if(exists $attr->{ $name } and exists $attr->{ $name }{ __wd_mult__ })
199            {
200                $self->wd_mult->{ $name } = $attr->{ $name }{ __wd_mult__ };
201            }
202        }
203    }
204    $self->wd_mult({ %{ $self->wd_mult }, %{ $args_wd_mult } });
205}
206
207method _update_count(Index $index)
208{
209    if(not exists $self->_index_update_count->{ $index })
210    {
211        $self->_index_update_count->{ $index } = $self->begin_num_update;
212    }
213    $self->_index_update_count->{ $index } += 1;
214    $self->num_update(max($self->_index_update_count->{ $index }, $self->num_update));
215}
216
217method _get_lr(Index $index)
218{
219    my $lr;
220    if($self->lr_scheduler)
221    {
222        $lr = $self->lr_scheduler->($self->num_update);
223    }
224    else
225    {
226        $lr = $self->lr;
227    }
228
229    if(exists $self->param_dict->{ $index })
230    {
231        $lr *= $self->param_dict->{ $index }->lr_mult;
232    }
233    elsif(exists $self->lr_mult->{ $index })
234    {
235        $lr *= $self->lr_mult->{ $index };
236    }
237    elsif(exists $self->idx2name->{ $index })
238    {
239        $lr *= $self->lr_mult->{ $self->idx2name->{ $index } }//1;
240    }
241    return $lr;
242}
243
244method _get_wd(Index $index)
245{
246    my $wd = $self->wd;
247    if(exists $self->param_dict->{ $index })
248    {
249        $wd *= $self->param_dict->{ $index }->wd_mult;
250    }
251    elsif(exists $self->wd_mult->{ $index })
252    {
253        $wd *= $self->wd_mult->{ $index };
254    }
255    elsif(exists $self->idx2name->{ $index })
256    {
257        $wd *= $self->wd_mult->{ $self->idx2name->{ $index } }//1;
258    }
259    return $wd;
260}
261
262=head1 NAME
263
264    AI::MXNet::SGD - A very simple SGD optimizer with momentum and weight regularization.
265=cut
266
267=head1 DESCRIPTION
268
269    A very simple SGD optimizer with momentum and weight regularization.
270
271    If the storage types of weight and grad are both 'row_sparse', and 'lazy_update' is True,
272    **lazy updates** are applied by
273
274        for row in grad.indices:
275            rescaled_grad[row] = lr * rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row]
276            state[row] = momentum[row] * state[row] + rescaled_grad[row]
277            weight[row] = weight[row] - state[row]
278
279    The sparse update only updates the momentum for the weights whose row_sparse
280    gradient indices appear in the current batch, rather than updating it for all
281    indices. Compared with the original update, it can provide large
282    improvements in model training throughput for some applications. However, it
283    provides slightly different semantics than the original update, and
284    may lead to different empirical results.
285
    Otherwise, **standard updates** are applied by:
287
288        rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
289        state = momentum * state + rescaled_grad
290        weight = weight - state
291
292    Parameters
293    ----------
294    learning_rate : Num, optional
295        learning_rate of SGD
296
297    momentum : Num, optional
298       momentum value
299
300    wd : Num, optional
301        L2 regularization coefficient add to all the weights
302
303    rescale_grad : Num, optional
304        rescaling factor of gradient. Normally should be 1/batch_size.
305
306    clip_gradient : Num, optional
307        clip gradient in range [-clip_gradient, clip_gradient]
308
    param_idx2name : hash ref of Int to Str, optional
        special handling of weight decay for parameters whose names end with
        bias, gamma, or beta
311
312    multi_precision: Bool, optional
313        Flag to control the internal precision of the optimizer.
        0 results in using the same precision as the weights (default),
        1 makes an internal 32-bit copy of the weights and applies gradients
        in 32-bit precision even if the actual weights used in the model have lower precision.
317        Turning this on can improve convergence and accuracy when training with float16.
318
    lazy_update: Bool, optional, default 1
        if true and the storage types of weight and grad are both 'row_sparse',
        the lazy (sparse) update described above is applied
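
    A construction sketch (hyper-parameter values are illustrative):

        my $sgd = AI::MXNet::Optimizer->create_optimizer(
            'sgd',
            learning_rate   => 0.1,
            momentum        => 0.9,
            wd              => 1e-4,
            clip_gradient   => 5,
            multi_precision => 1
        );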
320=cut
321
322package AI::MXNet::SGD;
323use Mouse;
324extends 'AI::MXNet::Optimizer';
325
326has 'kwargs'          => (is => "rw", isa => "HashRef[Num]");
327has 'momentum'        => (is => "rw", isa => "Num", default => 0);
328has 'multi_precision' => (is => "ro", isa => "Bool", default => 0);
329has 'lazy_update'     => (is => "ro", isa => "Bool", default => 1);
330
331sub BUILD
332{
333    my $self = shift;
334    $self->kwargs({});
335    if($self->momentum)
336    {
337        $self->kwargs->{momentum} = $self->momentum;
338    }
339    if($self->clip_gradient)
340    {
341        $self->kwargs->{clip_gradient} = $self->clip_gradient;
342    }
343}
344
345method create_state(Index $index, AI::MXNet::NDArray $weight)
346{
347    my $momentum;
348    my $weight_master_copy;
349    my $stype = $self->lazy_update ? $weight->stype : 'default';
350    if($self->multi_precision and $weight->dtype eq 'float16')
351    {
        $weight_master_copy = AI::MXNet::NDArray->array($weight, ctx => $weight->context, dtype => 'float32');
353        if($self->momentum != 0)
354        {
355            $momentum = AI::MXNet::NDArray->zeros($weight->shape, stype => $stype, ctx => $weight->context, dtype => 'float32');
356        }
357        return [$momentum, $weight_master_copy];
358    }
359    if($weight->dtype eq 'float16' and not $self->multi_precision)
360    {
361        AI::MXNet::Logging->warning(
362            "Accumulating with float16 in optimizer can lead to ".
363            "poor accuracy or slow convergence. ".
364            "Consider using multi_precision=True option of the ".
365            "SGD optimizer"
366        );
367    }
368    if($self->momentum != 0)
369    {
370        $momentum = AI::MXNet::NDArray->zeros($weight->shape, stype => $stype, ctx => $weight->context, dtype => $weight->dtype);
371    }
372    return $momentum;
373}
374
375method update(
376    Index                     $index,
377    AI::MXNet::NDArray        $weight,
378    AI::MXNet::NDArray        $grad,
379    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
380)
381{
382    $self->_update_count($index);
383    my $lr = $self->_get_lr($index);
384    my $wd = $self->_get_wd($index);
385    my $kwargs = {
386        out => $weight,
387        lr  => $lr,
388        wd  => $wd,
389        rescale_grad => $self->rescale_grad,
390        %{ $self->kwargs }
391    };
392    my $use_multi_precision = ref($state) eq 'ARRAY';
393    if(not $use_multi_precision)
394    {
395        if(defined $state)
396        {
397            AI::MXNet::NDArray->sgd_mom_update(
398                $weight, $grad, $state, $kwargs
399            );
400        }
401        else
402        {
403            AI::MXNet::NDArray->sgd_update(
404                $weight, $grad, $kwargs
405            );
406        }
407    }
408    else
409    {
410        if(defined $state->[0])
411        {
412            AI::MXNet::NDArray->mp_sgd_mom_update(
413                $weight, $grad, $state->[0], $state->[1], $kwargs
414            );
415        }
416        else
417        {
418            AI::MXNet::NDArray->mp_sgd_update(
419                $weight, $grad, $state->[1], $kwargs
420            );
421        }
422    }
423}
424
425__PACKAGE__->register;
426
427=head1 NAME
428
429    AI::MXNet::Signum - The Signum optimizer that takes the sign of gradient or momentum.
430=cut
431
432=head1 DESCRIPTION
433
434    The optimizer updates the weight by:
435
436        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
437        state = momentum * state + (1-momentum)*rescaled_grad
438        weight = (1 - lr * wd_lh) * weight - lr * sign(state)
439
440    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf
441
442    This optimizer accepts the following parameters in addition to those accepted
443    by AI::MXNet::Optimizer
444
445    Parameters
446    ----------
447    momentum : Num, optional
448       The momentum value.
449    wd_lh : Num, optional
450       The amount of decoupled weight decay regularization, see details in the original paper at:
451       https://arxiv.org/abs/1711.05101
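
    Because only the sign of the state enters the update, the per-element step
    size is governed directly by the learning rate. A construction sketch
    (hyper-parameter values are illustrative):

        my $signum = AI::MXNet::Signum->new(
            learning_rate => 0.01,
            momentum      => 0.9,
            wd_lh         => 1e-5
        );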
452=cut
453
454package AI::MXNet::Signum;
455use Mouse;
456extends 'AI::MXNet::Optimizer';
457
458has 'momentum' => (is => "rw", isa => "Num", default => 0.9);
459has 'wd_lh'    => (is => "rw", isa => "Num", default => 0);
460
461method create_state(Index $index, AI::MXNet::NDArray $weight)
462{
463
464    my $momentum;
465    if($self->momentum != 0)
466    {
467        $momentum = AI::MXNet::NDArray->zeros(
468            $weight->shape,
469            ctx => $weight->context,
470            dtype=>$weight->dtype,
471            stype=>$weight->stype
472        );
473    }
474    return $momentum;
475}
476
477method update(
478    Index                     $index,
479    AI::MXNet::NDArray        $weight,
480    AI::MXNet::NDArray        $grad,
481    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
482)
483{
484    $self->_update_count($index);
485    my $lr = $self->_get_lr($index);
486    my $wd = $self->_get_wd($index);
487    my %kwargs = (
488        out => $weight,
489        lr  => $lr,
490        wd  => $wd,
491        rescale_grad => $self->rescale_grad,
492    );
493    if($self->momentum > 0)
494    {
495        $kwargs{momentum} = $self->momentum;
496    }
497    if($self->clip_gradient)
498    {
499        $kwargs{clip_gradient} = $self->clip_gradient;
500    }
501    if($self->wd_lh)
502    {
503        $kwargs{wd_lh} = $self->wd_lh;
504    }
505    if(defined $state)
506    {
507        AI::MXNet::NDArray->signum_update(
508            $weight, $grad, $state, %kwargs
509        );
510    }
511    else
512    {
513        AI::MXNet::NDArray->signsgd_update(
514            $weight, $grad, %kwargs
515        );
516    }
517}
518
519__PACKAGE__->register;
520
521=head1 NAME
522
523    AI::MXNet::FTML - The FTML optimizer.
524=cut
525
526=head1 DESCRIPTION
527
528    This class implements the optimizer described in
529    *FTML - Follow the Moving Leader in Deep Learning*,
530    available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.
531
532    This optimizer accepts the following parameters in addition to those accepted
533    by AI::MXNet::Optimizer
534
535    Parameters
536    ----------
537    beta1 : Num, optional
538        0 < beta1 < 1. Generally close to 0.5.
539    beta2 : Num, optional
540        0 < beta2 < 1. Generally close to 1.
541    epsilon : Num, optional
542        Small value to avoid division by 0.
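
    A construction sketch (hyper-parameter values are illustrative):

        my $ftml = AI::MXNet::Optimizer->create_optimizer(
            'ftml',
            learning_rate => 0.0025,
            beta1         => 0.6,
            beta2         => 0.999,
            epsilon       => 1e-8
        );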
543=cut
544
545package AI::MXNet::FTML;
546use Mouse;
547extends 'AI::MXNet::Optimizer';
548
549has 'beta1'   => (is => "rw", isa => "Num", default => 0.6);
550has 'beta2'   => (is => "rw", isa => "Num", default => 0.999);
551has 'epsilon' => (is => "rw", isa => "Num", default => 1e-8);
552
553method create_state(Index $index, AI::MXNet::NDArray $weight)
554{
555    return [
556        AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype), # d_0
557        AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype), # v_0
558        AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype), # z_0
559    ];
560}
561
562method update(
563    Index                     $index,
564    AI::MXNet::NDArray        $weight,
565    AI::MXNet::NDArray        $grad,
566    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
567)
568{
569    my $lr = $self->_get_lr($index);
570    my $wd = $self->_get_wd($index);
    $self->_update_count($index);
    my $t = $self->_index_update_count->{$index};
572    my %kwargs = (
573        out => $weight,
574        lr  => $lr,
575        wd  => $wd,
576        t   => $t,
577        beta1 => $self->beta1,
578        beta2 => $self->beta2,
579        epsilon => $self->epsilon,
580        rescale_grad => $self->rescale_grad
581    );
582    if($self->clip_gradient)
583    {
584        $kwargs{clip_grad} = $self->clip_gradient;
585    }
586    AI::MXNet::NDArray->ftml_update($weight, $grad, @{ $state }, \%kwargs);
587}
588
589__PACKAGE__->register;
590
591=head1 NAME
592
593    AI::MXNet::LBSGD - The Large Batch SGD optimizer with momentum and weight decay.
594=cut
595
596=head1 DESCRIPTION
597
        The optimizer updates the weight by:
599
600        state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
601        weight = weight - state
602
603    Parameters
604    ----------
605    momentum : Num, optional
606       The momentum value.
    multi_precision: Bool, optional
       Flag to control the internal precision of the optimizer.
       0 results in using the same precision as the weights (default),
       1 makes an internal 32-bit copy of the weights and applies gradients
       in 32-bit precision even if the actual weights used in the model have lower precision.
       Turning this on can improve convergence and accuracy when training with float16.
    warmup_strategy: Str ('linear', 'power2', 'sqrt' or 'lars'), default: 'linear'
    warmup_epochs: Int, default: 5
    batch_scale: Num, default: 1 (typically batch_size * num_workers)
    updates_per_epoch: Int, default: 32 (approximate number of batches per epoch; used for warmup)
    begin_epoch: Int, default: 0, starting epoch
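
    A construction sketch (hyper-parameter values are illustrative):

        my $lbsgd = AI::MXNet::Optimizer->create_optimizer(
            'lbsgd',
            learning_rate     => 0.1,
            momentum          => 0.9,
            warmup_strategy   => 'lars',
            warmup_epochs     => 5,
            batch_scale       => 16,
            updates_per_epoch => 100
        );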
618=cut
619
620package AI::MXNet::LBSGD;
621use Mouse;
622extends 'AI::MXNet::Optimizer';
623
624has 'momentum'          => (is => 'rw', isa => 'Num', default => 0);
625has 'multi_precision'   => (is => 'rw', isa => 'Bool', default => 0);
has 'warmup_strategy'   => (is => 'rw', isa => 'Str', default => 'linear');
627has 'warmup_epochs'     => (is => 'rw', isa => 'Int', default => 5);
628has 'batch_scale'       => (is => 'rw', isa => 'Num', default => 1);
629has 'updates_per_epoch' => (is => 'rw', isa => 'Int', default => 32);
630has 'begin_epoch'       => (is => 'rw', isa => 'Int', default => 0);
631has 'num_epochs'        => (is => 'rw', isa => 'Int', default => 60);
632has 'beta2'             => (is => 'rw', isa => 'Num', default => 0.999);
633has 'epsilon'           => (is => 'rw', isa => 'Num', default => 1e-8);
has 'init_updates'      => (is => 'rw', init_arg => undef);
has [qw/lbmult
        cumgrads
        adaptive
        admult/]        => (is => 'rw', init_arg => undef);
640
641sub BUILD
642{
643    my $self = shift;
644    AI::MXNet::Logging->info('Running Large-Batch SGD Algorithm');
645    AI::MXNet::Logging->info(
646        '(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)',
647        map { $self->$_ } qw/batch_scale warmup_epochs warmup_strategy updates_per_epoch/
648    );
649    $self->init_updates($self->begin_epoch * $self->updates_per_epoch);
650    $self->lbmult(1);
651    $self->cumgrads({});
652    $self->adaptive(0);
653    $self->admult(1);
654}
655
656method create_state(Index $index, AI::MXNet::NDArray $weight)
657{
    my $momentum;
664    my $weight_master_copy;
665    if($self->multi_precision and $weight->dtype eq 'float16')
666    {
667        $weight_master_copy = AI::MXNet::NDArray->array($weight, ctx=>$weight->context, dtype=>'float32');
668        if($self->momentum != 0)
669        {
670            $momentum = AI::MXNet::NDArray->zeros(
671                $weight->shape, ctx => $weight->context, dtype => 'float32',
672                stype => $weight->stype
673            );
674        }
675        return [$momentum, $weight_master_copy];
676    }
677    if($weight->dtype eq 'float16' and not $self->multi_precision)
678    {
679        AI::MXNet::Logging->warning(
680            "Accumulating with float16 in optimizer can lead to "
681            ."poor accuracy or slow convergence. "
682            ."Consider using multi_precision=True option of the "
683            ."LBSGD optimizer"
684        );
685    }
686    if($self->momentum != 0)
687    {
688        $momentum = AI::MXNet::NDArray->zeros(
689            $weight->shape, ctx => $weight->context, dtype => $weight->dtype,
690            stype => $weight->stype
691        );
692    }
693    return $momentum;
694}
695
696method _get_lbmult($nup)
697{
698    my $nwup = $self->warmup_epochs * $self->updates_per_epoch;
699    my $strategy = $self->warmup_strategy;
700    my $maxmult = $self->batch_scale;
701    my $mult;
702    if($nup >= $nwup)
703    {
704        $mult = $maxmult;
705    }
706    elsif($nwup <= 1)
707    {
708        $mult = 1;
709    }
710    else
711    {
712        if ($strategy eq 'linear')
713        {
714            $mult = 1 + ($maxmult - 1) * $nup / $nwup;
715        }
716        elsif($strategy eq 'power2')
717        {
718            $mult = 1 + ($maxmult-1) * ($nup*$nup)/($nwup*$nwup);
719        }
720        elsif($strategy eq 'sqrt')
721        {
722            $mult = 1 + ($maxmult - 1) * sqrt($nup / $nwup);
723        }
724        else
725        {
726            $mult = 1;
727        }
728    }
729    return $mult;
730}
731
732
733method _get_lars($weight, $g, $wd)
734{
735    my $weight2 = $self->_l2norm($weight);
736    my $grad2 = $self->_l2norm($g);
737    my $lars = sqrt($weight2 / ($grad2 + $wd * $weight2 + 1e-18));
738    if($lars < 0.01)
739    {
740        $lars = 0.01;
741    }
742    elsif($lars > 100)
743    {
744        $lars = 100;
745    }
746    return $lars;
747}
748
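# Note: returns the squared L2 norm (a sum of squares); _get_lars takes the square root.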
749method _l2norm($v)
750{
751    my $norm = AI::MXNet::NDArray->multiply($v, $v)->aspdl->sum;
752    return $norm;
753}
754
755method  _reset_cum_gradient($index)
756{
757    $self->cumgrads->{$index}{cum_grad} = 0;
758}
759
760method _get_cum_gradient($index)
761{
762    if(exists $self->cumgrads->{$index})
763    {
764        return $self->cumgrads->{$index};
765    }
766    else
767    {
768        return {}
769    }
770}
771
772method _put_cum_gradient($index, $cgrad)
773{
774    $self->cumgrads->{$index} = $cgrad;
775}
776
777method _cumulate_gradient($grad, $index)
778{
779    my $cgrad = $self->_get_cum_gradient($index);
780    my ($num_cums, $cum_grad);
781    if(%{ $cgrad })
782    {
        $num_cums = $cgrad->{num_cums};
784        if($num_cums > 0)
785        {
786            $cum_grad = $cgrad->{cum_grad} + $grad;
787            $num_cums += 1;
788        }
789        else
790        {
791            $cum_grad = $grad;
792            $num_cums = $self->init_updates + 1;
793        }
794    }
795    else
796    {
797        $cum_grad = $grad;
798        $num_cums = $self->init_updates + 1;
799    }
800    $cgrad = {cum_grad => $cum_grad, num_cums => $num_cums};
801    $self->_put_cum_gradient($index, $cgrad);
802    return $cgrad;
803}
804
805
806
807method update(
808    Index                     $index,
809    AI::MXNet::NDArray        $weight,
810    AI::MXNet::NDArray        $grad,
811    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
812)
813{
814    my $lr = $self->_get_lr($index);
815    my $wd = $self->_get_wd($index);
816    my $t = $self->_update_count($index);
817    my $cgrad = $self->_cumulate_gradient($grad, $index);
818    if(($cgrad->{num_cums} % $self->batch_scale) == 0)
819    {
820        my $lbmult;
821        $grad = $cgrad->{cum_grad} / $self->batch_scale;
822        if($self->warmup_strategy eq 'lars')
823        {
824            $lbmult = $self->_get_lars($weight, $grad, $wd);
825        }
826        else
827        {
828            $lbmult = $self->_get_lbmult($cgrad->{num_cums});
829        }
830        $lr = $lr * $lbmult;
831        my %kwargs = (
832            out => $weight,
833            lr  => $lr,
834            wd  => $wd,
835            rescale_grad => $self->rescale_grad
836        );
837        if($self->clip_gradient)
838        {
839            $kwargs{clip_gradient} = $self->clip_gradient;
840        }
841        if($self->momentum > 0)
842        {
843            $kwargs{momentum} = $self->momentum;
844        }
845        my $use_multi_precision = ref($state) eq 'ARRAY';
846        if(not $use_multi_precision)
847        {
848            if(defined $state)
849            {
850                AI::MXNet::NDArray->sgd_mom_update($weight, $grad, $state, %kwargs);
851            }
852            else
853            {
854                AI::MXNet::NDArray->sgd_update($weight, $grad, %kwargs);
855            }
856        }
857        else
858        {
859            if(defined $state->[0])
860            {
861                AI::MXNet::NDArray->mp_sgd_mom_update($weight, $grad, @{ $state }, %kwargs);
862            }
863            else
864            {
865                AI::MXNet::NDArray->mp_sgd_update($weight, $grad, $state->[1], %kwargs);
866            }
867        }
868        $self->_reset_cum_gradient($index);
869    }
870    else
871    {
872        AI::MXNet::NDArray->sgd_update($weight, $grad, out => $weight, lr => 0, wd => $wd);
873    }
874}
875
876__PACKAGE__->register;
877
878package AI::MXNet::DCASGD;
879use Mouse;
880use AI::MXNet::Base;
881extends 'AI::MXNet::Optimizer';
882
883=head1 NAME
884
885    AI::MXNet::DCASGD - DCASGD optimizer with momentum and weight regularization.
886=cut
887
888=head1 DESCRIPTION
889
890    DCASGD optimizer with momentum and weight regularization.
891
892    Implements paper "Asynchronous Stochastic Gradient Descent with
893                    Delay Compensation for Distributed Deep Learning"
894
895    Parameters
896    ----------
897    learning_rate : Num, optional
898        learning_rate of SGD
899
900    momentum : Num, optional
901       momentum value
902
    lamda : Num, optional
       scale of the delay compensation (DC) term
905
906    wd : Num, optional
907        L2 regularization coefficient add to all the weights
908
909    rescale_grad : Num, optional
910        rescaling factor of gradient. Normally should be 1/batch_size.
911
912    clip_gradient : Num, optional
913        clip gradient in range [-clip_gradient, clip_gradient]
914
    param_idx2name : hash ref of Int to Str, optional
        special treatment of weight decay for parameters whose names end with bias, gamma, and beta
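
    A construction sketch (hyper-parameter values are illustrative):

        my $dcasgd = AI::MXNet::Optimizer->create_optimizer(
            'dcasgd',
            learning_rate => 0.1,
            momentum      => 0.9,
            lamda         => 0.04
        );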
917=cut
918has 'momentum'        => (is => 'ro', isa => 'Num', default => 0);
919has 'lamda'           => (is => 'ro', isa => 'Num', default => 0.04);
920has 'weight_previous' => (is => 'rw', init_arg => undef);
921
922sub BUILD
923{
924    my $self = shift;
925    $self->weight_previous({});
926}
927
928method create_state(Index $index, AI::MXNet::NDArray $weight)
929{
930        return [
931            $self->momentum ? AI::MXNet::NDArray->zeros(
932                $weight->shape, ctx => $weight->context, dtype => $weight->dtype
933            ) : undef,
934            $weight->copy
935        ];
936}
937
938method update(
939    Index                     $index,
940    AI::MXNet::NDArray        $weight,
941    AI::MXNet::NDArray        $grad,
942    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
943)
944{
945    my $lr = $self->_get_lr($index);
946    my $wd = $self->_get_wd($index);
947    $self->_update_count($index);
948    $grad *= $self->rescale_grad;
949    if($self->clip_gradient)
950    {
951        $grad = AI::MXNet::NDArray->clip(
952            $grad,
953            -$self->clip_gradient,
954            $self->clip_gradient
955        );
956    }
957    my ($mom, $weight_previous) = @{ $state };
958    if(defined $mom)
959    {
960        $mom *= $self->momentum;
961        $mom += -$lr * (
962                $grad + $wd * $weight
963                    +
964                $self->lamda * $grad * $grad * ($weight - $weight_previous)
965        );
966    }
967    else
968    {
969        assert($self->momentum == 0);
970        $mom = -$lr * (
971                $grad + $wd * $weight
972                    +
973                $self->lamda * $grad * $grad * ($weight - $weight_previous)
974        );
975    }
976    $weight_previous .= $weight;
977    $weight += $mom;
978}
979
980__PACKAGE__->register;
981
982=head1 NAME
983
    AI::MXNet::NAG - SGD with Nesterov momentum (Nesterov Accelerated Gradient).
985=cut
986
987=head1 DESCRIPTION
988
989    It is implemented according to
990    https://github.com/torch/optim/blob/master/sgd.lua
991=cut
992
993package AI::MXNet::NAG;
994use Mouse;
995extends 'AI::MXNet::SGD';
996
997method create_state(Index $index, AI::MXNet::NDArray $weight)
998{
999    my $momentum;
1000    my $weight_master_copy;
1001    my $do_multi_precision = ($self->multi_precision and $weight->dtype eq 'float16');
1002    if($do_multi_precision)
1003    {
1004        if($self->momentum != 0)
1005        {
1006            $momentum = AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>'float32');
1007        }
1008        $weight_master_copy = AI::MXNet::NDArray->array($weight, ctx=>$weight->context, dtype=>'float32');
1009        return [$weight_master_copy, $momentum];
1010    }
1011    else
1012    {
1013        if($self->momentum != 0)
1014        {
1015            $momentum = AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype);
1016        }
1017        return $momentum;
1018    }
1019}
1020
1021method update($index, $weight, $grad, $state)
1022{
1023    my $lr = $self->_get_lr($index);
1024    my $wd = $self->_get_wd($index);
1025    $self->_update_count($index);
    my $use_multi_precision = (defined $state and not Scalar::Util::blessed($state) and ref($state) eq 'ARRAY');
1027    if(not $use_multi_precision)
1028    {
1029        $grad *= $self->rescale_grad;
1030        if(defined $self->clip_gradient)
1031        {
1032            $grad = AI::MXNet::NDArray->clip($grad, -$self->clip_gradient, $self->clip_gradient);
1033        }
1034        if($self->momentum == 0)
1035        {
1036            $weight += -$lr * ($grad + $wd * $weight);
1037        }
1038        else
1039        {
1040            my $mom = $state;
1041            $mom *= $self->momentum;
1042            $grad += $wd * $weight;
1043            $mom += $grad;
1044            $grad += $self->momentum * $mom;
1045            $weight += -$lr * $grad;
1046        }
1047    }
1048    else
1049    {
1050        my $grad32 = AI::MXNet::NDArray->array($grad, ctx=>$grad->context, dtype=>'float32');
1051        $grad32 *= $self->rescale_grad;
1052        if(defined $self->clip_gradient)
1053        {
1054            $grad32 = AI::MXNet::NDArray->clip($grad32, -$self->clip_gradient, $self->clip_gradient);
1055        }
1056        my $mom = $state->[1];
1057        my $weight32 = $state->[0];
1058        if($self->momentum == 0)
1059        {
1060            $weight32 += -$lr * ($grad32 + $wd * $weight32);
1061        }
1062        else
1063        {
1064            $mom *= $self->momentum;
1065            $grad32 += $wd * $weight32;
1066            $mom += $grad32;
1067            $grad32 += $self->momentum * $mom;
1068            $weight32 += -$lr * $grad32;
1069        }
1070        my $tmp = $weight32->astype($weight->dtype);
1071        $tmp->copyto($weight);
1072    }
1073}
1074
1075__PACKAGE__->register;
1076
1077=head1 NAME
1078
1079    AI::MXNet::SGLD - Stochastic Gradient Riemannian Langevin Dynamics.
1080=cut
1081
1082=head1 DESCRIPTION
1083
1084    Stochastic Gradient Riemannian Langevin Dynamics.
1085
1086    This class implements the optimizer described in the paper *Stochastic Gradient
1087    Riemannian Langevin Dynamics on the Probability Simplex*, available at
1088    https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf.
1089
1090    Parameters
1091    ----------
1092    learning_rate : Num, optional
1093        learning_rate of SGD
1094
1095    wd : Num, optional
1096        L2 regularization coefficient add to all the weights
1097
1098    rescale_grad : Num, optional
1099        rescaling factor of gradient. Normally should be 1/batch_size.
1100
1101    clip_gradient : Num, optional
1102        clip gradient in range [-clip_gradient, clip_gradient]
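
    The update adds Gaussian noise with standard deviation sqrt(lr) to each
    step, so the weights are posterior samples rather than a point estimate.
    A construction sketch (hyper-parameter values are illustrative):

        my $sgld = AI::MXNet::Optimizer->create_optimizer(
            'sgld',
            learning_rate => 0.01,
            wd            => 1e-5
        );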
1103=cut
1104
1105package AI::MXNet::SGLD;
1106use Mouse;
1107
1108extends 'AI::MXNet::Optimizer';
1109
1110method create_state(Index $index, AI::MXNet::NDArray $weight)
1111{
1112    return undef;
1113}
1114
1115method update(
1116    Index $index,
1117    AI::MXNet::NDArray $weight,
1118    AI::MXNet::NDArray $grad,
1119    AI::MXNet::NDArray|Undef $state
1120)
1121{
1122    my $lr = $self->_get_lr($index);
1123    my $wd = $self->_get_wd($index);
1124    $self->_update_count($index);
1125    $grad *= $self->rescale_grad;
1126    if($self->clip_gradient)
1127    {
1128        $grad = AI::MXNet::NDArray->clip(
1129            $grad,
1130            -$self->clip_gradient,
1131             $self->clip_gradient
1132        );
1133    }
1134    $weight +=  - $lr/2 * ($grad + $wd * $weight)
1135                    +
1136                AI::MXNet::Random->normal(
1137                        0, sqrt($lr),
1138                        shape => $weight->shape,
1139                        ctx   => $weight->context,
1140                        dtype => $weight->dtype
1141                );
1142}
1143
1144__PACKAGE__->register;
1145
1146=head1 NAME
1147
    AI::MXNet::Adam - Adam optimizer as described in [King2014].
1149=cut
1150
1151=head1 DESCRIPTION
1152
    Adam optimizer as described in [King2014].

    [King2014] Diederik Kingma, Jimmy Ba,
       *Adam: A Method for Stochastic Optimization*,
       http://arxiv.org/abs/1412.6980
1158
1159    Parameters
1160    ----------
1161    learning_rate : Num, optional
1162        Step size.
1163        Default value is set to 0.001.
1164    beta1 : Num, optional
1165        Exponential decay rate for the first moment estimates.
1166        Default value is set to 0.9.
1167    beta2 : Num, optional
1168        Exponential decay rate for the second moment estimates.
1169        Default value is set to 0.999.
1170    epsilon : Num, optional
1171        Default value is set to 1e-8.
1172
    wd : Num, optional
1174        L2 regularization coefficient add to all the weights
1175    rescale_grad : Num, optional
1176        rescaling factor of gradient. Normally should be 1/batch_size.
1177
1178    clip_gradient : Num, optional
1179        clip gradient in range [-clip_gradient, clip_gradient]
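
    Note that update() folds the bias correction into the step size: for update
    count t it rescales the learning rate by sqrt(1 - beta2**t) / (1 - beta1**t)
    before calling the fused kernel. A construction sketch (hyper-parameter
    values are illustrative):

        my $adam = AI::MXNet::Optimizer->create_optimizer(
            'adam',
            learning_rate => 0.001,
            beta1         => 0.9,
            beta2         => 0.999,
            epsilon       => 1e-8,
            wd            => 1e-4
        );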
1180=cut
1181package AI::MXNet::Adam;
1182use Mouse;
1183
1184extends 'AI::MXNet::Optimizer';
1185
1186has 'kwargs'   => (is => "rw", isa => "HashRef[Num]");
1187has '+learning_rate' => (default => 0.001);
1188has 'beta1'    => (is => "rw", isa => "Num", default => 0.9);
1189has 'beta2'    => (is => "rw", isa => "Num", default => 0.999);
1190has 'epsilon'  => (is => "rw", isa => "Num", default => 1e-8);
1191has 'lazy_update' => (is => 'rw', isa => 'Bool', default => 1);
1192
1193sub BUILD
1194{
1195    my $self = shift;
1196    $self->kwargs({
1197        beta1   => $self->beta1,
1198        beta2   => $self->beta2,
1199        epsilon => $self->epsilon
1200    });
1201    if($self->clip_gradient)
1202    {
1203        $self->kwargs->{clip_gradient} = $self->clip_gradient;
1204    }
1205}
1206
1207method create_state(Index $index, AI::MXNet::NDArray $weight)
1208{
1209    my $stype = $self->lazy_update ? $weight->stype : 'default';
1210    return [AI::MXNet::NDArray->zeros(
1211                $weight->shape,
1212                ctx => $weight->context,
1213                dtype => $weight->dtype,
1214                stype => $stype
1215            ),  # mean
1216            AI::MXNet::NDArray->zeros(
1217                $weight->shape,
1218                ctx => $weight->context,
1219                dtype => $weight->dtype,
1220                stype => $stype
1221            )  # variance
1222    ];
1223}
1224
1225method update(
1226    Index $index,
1227    AI::MXNet::NDArray $weight,
1228    AI::MXNet::NDArray $grad,
1229    ArrayRef[AI::MXNet::NDArray] $state
1230)
1231{
1232    my $lr = $self->_get_lr($index);
1233    my $wd = $self->_get_wd($index);
1234    $self->_update_count($index);
1235    my $t = $self->_index_update_count->{$index};
1236    my $coef1 = 1 - $self->beta1**$t;
1237    my $coef2 = 1 - $self->beta2**$t;
1238    $lr *= sqrt($coef2)/$coef1;
1239    my ($mean, $var) = @{ $state };
1240    AI::MXNet::NDArray->adam_update(
1241        $weight, $grad, $mean, $var,
1242        {
1243            out => $weight,
1244            lr  => $lr,
1245            wd  => $wd,
1246            rescale_grad => $self->rescale_grad,
1247            %{ $self->kwargs }
1248        }
1249    );
1250}
1251
1252__PACKAGE__->register;
1253
1254=head1 NAME
1255
1256    AI::MXNet::AdaGrad - AdaGrad optimizer of Duchi et al., 2011
1257=cut
1258
1259=head1 DESCRIPTION
1260
1261    AdaGrad optimizer of Duchi et al., 2011,
1262
1263    This code follows the version in http://arxiv.org/pdf/1212.5701v1.pdf  Eq(5)
1264    by Matthew D. Zeiler, 2012. AdaGrad will help the network to converge faster
1265    in some cases.
1266
1267    Parameters
1268    ----------
1269    learning_rate : Num, optional
1270        Step size.
1271        Default value is set to 0.05.
1272
1273    wd : Num, optional
1274        L2 regularization coefficient add to all the weights
1275
1276    rescale_grad : Num, optional
1277        rescaling factor of gradient. Normally should be 1/batch_size.
1278
1279    eps: Num, optional
        A small constant for numerical stability of the update.
        Default value is set to 1e-7.
1282
1283    clip_gradient : Num, optional
1284        clip gradient in range [-clip_gradient, clip_gradient]
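
    A construction sketch (hyper-parameter values are illustrative):

        my $adagrad = AI::MXNet::Optimizer->create_optimizer(
            'adagrad',
            learning_rate => 0.05,
            eps           => 1e-7,
            wd            => 1e-5
        );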
1285=cut
1286package AI::MXNet::AdaGrad;
1287use Mouse;
1288
1289extends 'AI::MXNet::Optimizer';
1290
1291has 'eps'    => (is => "rw", isa => "Num", default => 1e-7);
1292
1293method create_state(Index $index, AI::MXNet::NDArray $weight)
1294{
1295    return AI::MXNet::NDArray->zeros(
1296                $weight->shape,
1297                ctx => $weight->context,
1298                stype => $weight->stype
1299    );  # history
1300}
1301
1302method update(
1303    Index $index,
1304    AI::MXNet::NDArray $weight,
1305    AI::MXNet::NDArray $grad,
1306    AI::MXNet::NDArray $state
1307)
1308{
1309    my $lr = $self->_get_lr($index);
1310    my $wd = $self->_get_wd($index);
1311    $self->_update_count($index);
1312    my $is_sparse = $grad->stype eq 'row_sparse' ? 1 : 0;
1313    my $history = $state;
1314    if($is_sparse)
1315    {
1316        my %kwargs = (
1317            epsilon => $self->eps,
1318            rescale_grad => $self->rescale_grad
1319        );
1320        if($self->clip_gradient)
1321        {
1322            $kwargs{clip_gradient} = $self->clip_gradient;
1323        }
1324        AI::MXNet::NDArray::Sparse->adagrad_update($weight, $grad, $history, { out=>$weight, lr=>$lr, wd=>$wd, %kwargs });
1325    }
1326    else
1327    {
1328        $grad *= $self->rescale_grad;
1329        if(defined $self->clip_gradient)
1330        {
1331            $grad = AI::MXNet::NDArray->clip($grad, -$self->clip_gradient, $self->clip_gradient);
1332        }
1333        $history += $grad->square;
1334        my $div = $grad / ($history + $self->eps)->sqrt;
1335        $weight += ($div + $weight * $wd) * -$lr;
1336    }
1337}
1338
1339__PACKAGE__->register;
1340
1341=head1 NAME
1342
1343    AI::MXNet::RMSProp - RMSProp optimizer of Tieleman & Hinton, 2012.
1344=cut
1345
1346=head1 DESCRIPTION
1347
1348    RMSProp optimizer of Tieleman & Hinton, 2012,
1349
1350    For centered=False, the code follows the version in
1351    http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf by
1352    Tieleman & Hinton, 2012
1353
1354    For centered=True, the code follows the version in
1355    http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) by Alex Graves, 2013.
1356
1357    Parameters
1358    ----------
1359    learning_rate : Num, optional
1360        Step size.
1361        Default value is set to 0.001.
1362    gamma1: Num, optional
1363        decay factor of moving average for gradient^2.
1364        Default value is set to 0.9.
1365    gamma2: Num, optional
1366        "momentum" factor.
        Default value is set to 0.9.
        Only used if centered=1.
1369    epsilon : Num, optional
1370        Default value is set to 1e-8.
1371    centered : Bool, optional
        Use Graves' (centered) or Tieleman & Hinton's version of RMSProp.
1373    wd : Num, optional
1374        L2 regularization coefficient add to all the weights
1375    rescale_grad : Num, optional
1376        rescaling factor of gradient.
1377    clip_gradient : Num, optional
1378        clip gradient in range [-clip_gradient, clip_gradient]
1379    clip_weights : Num, optional
1380        clip weights in range [-clip_weights, clip_weights]
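
    A construction sketch of the centered (Graves) variant (hyper-parameter
    values are illustrative):

        my $rmsprop = AI::MXNet::Optimizer->create_optimizer(
            'rmsprop',
            learning_rate => 0.001,
            gamma1        => 0.9,
            gamma2        => 0.9,
            centered      => 1,
            clip_weights  => 10
        );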
1381=cut
1382
1383package AI::MXNet::RMSProp;
1384use Mouse;
1385
1386extends 'AI::MXNet::Optimizer';
1387
1388has '+learning_rate' => (default => 0.001);
1389has 'gamma1'         => (is => "ro", isa => "Num",  default => 0.9);
1390has 'gamma2'         => (is => "ro", isa => "Num",  default => 0.9);
1391has 'epsilon'        => (is => "ro", isa => "Num",  default => 1e-8);
1392has 'centered'       => (is => "ro", isa => "Bool", default => 0);
1393has 'clip_weights'   => (is => "ro", isa => "Num");
1394has 'kwargs'         => (is => "rw", init_arg => undef);
1395
1396sub BUILD
1397{
1398    my $self = shift;
1399    $self->kwargs({
1400        gamma1       => $self->gamma1,
1401        epsilon      => $self->epsilon
1402    });
1403    if($self->centered)
1404    {
1405        $self->kwargs->{gamma2} = $self->gamma2;
1406    }
1407    if($self->clip_gradient)
1408    {
1409        $self->kwargs->{clip_gradient} = $self->clip_gradient;
1410    }
1411    if($self->clip_weights)
1412    {
1413        $self->kwargs->{clip_weights} = $self->clip_weights;
1414    }
1415}
1416
1417# For centered=False: n
1418# For centered=True: n, g, delta
1419method create_state(Index $index, AI::MXNet::NDArray $weight)
1420{
1421    return [
1422            $self->centered
1423            ? (
1424                AI::MXNet::NDArray->zeros(
1425                    $weight->shape,
1426                    ctx => $weight->context,
1427                    stype => $weight->stype
1428                ),  # n
1429                AI::MXNet::NDArray->zeros(
1430                    $weight->shape,
1431                    ctx => $weight->context,
1432                    stype => $weight->stype
1433                ),  # g
1434                AI::MXNet::NDArray->zeros(
1435                    $weight->shape,
1436                    ctx => $weight->context,
1437                    stype => $weight->stype
1438                )
1439            )   # delta
1440            : (
1441                AI::MXNet::NDArray->zeros(
1442                    $weight->shape,
1443                    ctx => $weight->context,
1444                    stype => $weight->stype
1445                ),  # n
1446            )
1447    ];
1448}
1449
1450method update(
1451    Index $index,
1452    AI::MXNet::NDArray $weight,
1453    AI::MXNet::NDArray $grad,
1454    ArrayRef[AI::MXNet::NDArray] $state
1455)
1456{
1457    my $lr = $self->_get_lr($index);
1458    my $wd = $self->_get_wd($index);
1459    $self->_update_count($index);
1460    my ($n, $g, $delta) = @{ $state };
1461    if($self->centered)
1462    {
1463        AI::MXNet::NDArray->rmspropalex_update(
1464            $weight, $grad, $n, $g, $delta,
1465            {
1466                out => $weight,
1467                lr  => $lr,
1468                wd  => $wd,
1469                rescale_grad => $self->rescale_grad,
1470                %{ $self->kwargs }
1471            }
1472        );
1473    }
1474    else
1475    {
1476        AI::MXNet::NDArray->rmsprop_update(
1477            $weight, $grad, $n,
1478            {
1479                out => $weight,
1480                lr  => $lr,
1481                wd  => $wd,
1482                rescale_grad => $self->rescale_grad,
1483                %{ $self->kwargs }
1484            }
1485        );
1486    }
1487}
1488
1489__PACKAGE__->register;
1490
1491=head1 NAME
1492
1493    AI::MXNet::AdaDelta - AdaDelta optimizer.
1494=cut
1495
1496=head1 DESCRIPTION
1497
1498    AdaDelta optimizer as described in
1499    Zeiler, M. D. (2012).
1500    *ADADELTA: An adaptive learning rate method.*
1501
1502    http://arxiv.org/abs/1212.5701
1503
1504    Parameters
1505    ----------
1506    rho: Num
1507        Decay rate for both squared gradients and delta x
1508    epsilon : Num
        The constant epsilon as described in the paper.
1510    wd : Num
1511        L2 regularization coefficient add to all the weights
1512    rescale_grad : Num, optional
1513        rescaling factor of gradient. Normally should be 1/batch_size.
1514    clip_gradient : Num, optional
1515        clip gradient in range [-clip_gradient, clip_gradient]
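
    Note that the update rule derives its effective step size from the
    accumulated statistics, so learning_rate is not used by this optimizer.
    A construction sketch (hyper-parameter values are illustrative):

        my $adadelta = AI::MXNet::Optimizer->create_optimizer(
            'adadelta',
            rho     => 0.9,
            epsilon => 1e-5,
            wd      => 1e-5
        );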
1516=cut
1517package AI::MXNet::AdaDelta;
1518use Mouse;
1519
1520extends 'AI::MXNet::Optimizer';
1521
1522has 'rho'    => (is => "rw", isa => "Num", default => 0.9);
1523has 'epsilon'    => (is => "rw", isa => "Num", default => 1e-5);
1524
1525method create_state(Index $index, AI::MXNet::NDArray $weight)
1526{
1527    return [
1528            AI::MXNet::NDArray->zeros(
1529                $weight->shape,
1530                ctx => $weight->context
1531            ),  # accumulated g
1532            AI::MXNet::NDArray->zeros(
1533                $weight->shape,
1534                ctx => $weight->context
1535            )   # accumulated delta
1536    ];
1537}
1538
1539method update(
1540    Index $index,
1541    AI::MXNet::NDArray $weight,
1542    AI::MXNet::NDArray $grad,
1543    ArrayRef[AI::MXNet::NDArray] $state
1544)
1545{
1546    my $wd = $self->_get_wd($index);
1547    $self->_update_count($index);
1548    $grad *= $self->rescale_grad;
1549    if($self->clip_gradient)
1550    {
1551        $grad = AI::MXNet::NDArray->clip(
1552            $grad,
1553            -$self->clip_gradient,
1554             $self->clip_gradient
1555        );
1556    }
1557    my ($acc_g, $acc_delta) = @{ $state };
1558    $acc_g .= $self->rho * $acc_g + (1 - $self->rho) * $grad * $grad;
1559    my $current_delta = ($acc_delta + $self->epsilon)->sqrt
1560                            /
1561                        ($acc_g + $self->epsilon)->sqrt
1562                            *
1563                        $grad;
1564    $acc_delta .= $self->rho * $acc_delta + (1 - $self->rho) * $current_delta * $current_delta;
1565    $weight -= $current_delta + $wd * $weight;
1566}
1567
1568__PACKAGE__->register;
1569
1570# For test use
1571package AI::MXNet::Test;
1572use Mouse;
1573
1574extends 'AI::MXNet::Optimizer';
1575
1576# Create a state to duplicate weight
1577method create_state(Index $index, AI::MXNet::NDArray $weight)
1578{
1579    return AI::MXNet::NDArray->zeros(
1580                $weight->shape,
1581                ctx => $weight->context
1582    );
1583}
1584
1585# performs w += rescale_grad * grad
1586method update(
1587    Index $index,
1588    AI::MXNet::NDArray $weight,
1589    AI::MXNet::NDArray $grad,
1590    AI::MXNet::NDArray $state
1591)
1592{
1593    $weight += $grad * $self->rescale_grad;
1594    $state .= $weight;
1595}
1596
1597__PACKAGE__->register;
1598
1599package AI::MXNet::Ftrl;
1600
1601
1602=head1 NAME
1603
1604    AI::MXNet::Ftrl
1605=cut
1606
1607=head1 DESCRIPTION
1608
1609    Referenced from *Ad Click Prediction: a View from the Trenches*, available at
1610    http://dl.acm.org/citation.cfm?id=2488200.
1611
1612    The optimizer updates the weight by:
1613
1614        rescaled_grad = clip(grad * rescale_grad, clip_gradient)
1615        z += rescaled_grad - (sqrt(n + rescaled_grad**2) - sqrt(n)) * weight / learning_rate
1616        n += rescaled_grad**2
1617        w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1)
1618
1619    If the storage types of weight, state and grad are all row_sparse,
    **sparse updates** are applied by:
1621
1622        for row in grad.indices:
1623            rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient)
1624            z[row] += rescaled_grad[row] - (sqrt(n[row] + rescaled_grad[row]**2) - sqrt(n[row])) * weight[row] / learning_rate
1625            n[row] += rescaled_grad[row]**2
1626            w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1)
1627
1628    The sparse update only updates the z and n for the weights whose row_sparse
1629    gradient indices appear in the current batch, rather than updating it for all
1630    indices. Compared with the original update, it can provide large
1631    improvements in model training throughput for some applications. However, it
1632    provides slightly different semantics than the original update, and
1633    may lead to different empirical results.
1634
1635    This optimizer accepts the following parameters in addition to those accepted
1636    by AI::MXNet::Optimizer
1637
1638    Parameters
1639    ----------
1640    lamda1 : Num, optional
1641        L1 regularization coefficient.
1642    learning_rate : Num, optional
1643        The initial learning rate.
1644    beta : Num, optional
1645        Per-coordinate learning rate correlation parameter.
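
    A construction sketch (hyper-parameter values are illustrative); the L1
    term lamda1 is what drives small weights exactly to zero:

        my $ftrl = AI::MXNet::Optimizer->create_optimizer(
            'ftrl',
            learning_rate => 0.1,
            lamda1        => 0.01,
            beta          => 1,
            wd            => 1e-5
        );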
1646=cut
1647
1648use Mouse;
1649extends 'AI::MXNet::Optimizer';
1650has '+learning_rate' => (default => 0.1);
1651has 'beta'           => (is => "ro", isa => "Num",  default => 1);
1652has 'lamda1'         => (is => "ro", isa => "Num",  default => 0.01);
1653
1654method create_state(Index $index, AI::MXNet::NDArray $weight)
1655{
1656    return [
1657            AI::MXNet::NDArray->zeros(
1658                $weight->shape,
1659                ctx => $weight->context,
1660                stype => $weight->stype
1661            ),  # z
1662            AI::MXNet::NDArray->zeros(
1663                $weight->shape,
1664                ctx => $weight->context,
1665                stype => $weight->stype
1666            )   # n
1667    ];
1668}
1669
1670method update(
1671    Index $index,
1672    AI::MXNet::NDArray $weight,
1673    AI::MXNet::NDArray $grad,
1674    ArrayRef[AI::MXNet::NDArray] $state
1675)
1676{
1677    $self->_update_count($index);
1678    my $wd = $self->_get_wd($index);
1679    my $lr = $self->_get_lr($index);
1680    my %kwargs = (lamda1 => $self->lamda1, beta => $self->beta, rescale_grad => $self->rescale_grad);
1681    if($self->clip_gradient)
1682    {
1683        $kwargs{clip_gradient} = $self->clip_gradient;
1684    }
    # z and n accumulators
1686    my ($z, $n) = @{ $state };
1687    AI::MXNet::NDArray->ftrl_update(
1688        $weight, $grad, $z, $n,
1689        { lr => $lr, wd => $wd, %kwargs, out => $weight }
1690    );
1691}
1692
1693__PACKAGE__->register;
1694
1695package AI::MXNet::Adamax;
1696
1697=head1 NAME
1698
1699    AI::MXNet::Adamax
1700=cut
1701
1702=head1 DESCRIPTION
1703
1704    It is a variant of Adam based on the infinity norm
1705    available at http://arxiv.org/abs/1412.6980 Section 7.
1706
1707    This optimizer accepts the following parameters in addition to those accepted
1708    AI::MXNet::Optimizer.
1709
1710    Parameters
1711    ----------
1712    beta1 : Num, optional
1713        Exponential decay rate for the first moment estimates.
1714    beta2 : Num, optional
1715        Exponential decay rate for the second moment estimates.
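
    A construction sketch (hyper-parameter values are illustrative):

        my $adamax = AI::MXNet::Optimizer->create_optimizer(
            'adamax',
            learning_rate => 0.002,
            beta1         => 0.9,
            beta2         => 0.999
        );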
1716=cut
1717
1718use Mouse;
1719extends 'AI::MXNet::Optimizer';
1720has '+learning_rate' => (default => 0.002);
1721has 'beta1'          => (is => "ro", isa => "Num",  default => 0.9);
1722has 'beta2'          => (is => "ro", isa => "Num",  default => 0.999);
1723
1724method create_state(Index $index, AI::MXNet::NDArray $weight)
1725{
1726    return [
1727            AI::MXNet::NDArray->zeros(
1728                $weight->shape,
1729                ctx => $weight->context,
1730                dtype => $weight->dtype
1731            ),  # mean
1732            AI::MXNet::NDArray->zeros(
1733                $weight->shape,
1734                ctx => $weight->context,
1735                dtype => $weight->dtype
1736            )   # variance
1737    ];
1738}
1739
1740method update(
1741    Index $index,
1742    AI::MXNet::NDArray $weight,
1743    AI::MXNet::NDArray $grad,
1744    ArrayRef[AI::MXNet::NDArray] $state
1745)
1746{
1747    my $wd = $self->_get_wd($index);
1748    my $lr = $self->_get_lr($index);
1749    $self->_update_count($index);
1750    my $t = $self->_index_update_count->{$index};
1751    $lr /= (1 - $self->beta1**$t);
1752
1753    $grad = $grad * $self->rescale_grad + $wd * $weight;
1754    if($self->clip_gradient)
1755    {
1756        $grad = AI::MXNet::NDArray->clip(
1757            $grad,
1758            -$self->clip_gradient,
1759             $self->clip_gradient
1760        );
1761    }
1762
1763    # update m_t and u_t
1764    my($m_t, $u_t) = @{ $state };
1765    $m_t .= $self->beta1 * $m_t + (1 - $self->beta1) * $grad;
1766    $u_t .= AI::MXNet::NDArray->maximum($self->beta2 * $u_t, $grad->abs);
1767
1768    # update weight
1769    $weight -= $lr * $m_t / $u_t;
1770}
1771
1772__PACKAGE__->register;
1773
1774package AI::MXNet::Nadam;
1775
1776=head1 NAME
1777
1778    AI::MXNet::Nadam
1779=cut
1780
1781=head1 DESCRIPTION
1782
1783    The Nesterov Adam optimizer.
1784
    Much like Adam is essentially RMSprop with momentum,
    Nadam is Adam with Nesterov momentum, as described
    at http://cs229.stanford.edu/proj2015/054_report.pdf.
1788
1789    This optimizer accepts the following parameters in addition to those accepted
1790    by AI::MXNet::Optimizer.
1791
1792    Parameters
1793    ----------
1794    beta1 : Num, optional
1795        Exponential decay rate for the first moment estimates.
1796    beta2 : Num, optional
1797        Exponential decay rate for the second moment estimates.
1798    epsilon : Num, optional
1799        Small value to avoid division by 0.
1800    schedule_decay : Num, optional
1801        Exponential decay rate for the momentum schedule
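
    A construction sketch (hyper-parameter values are illustrative):

        my $nadam = AI::MXNet::Optimizer->create_optimizer(
            'nadam',
            learning_rate  => 0.001,
            beta1          => 0.9,
            beta2          => 0.999,
            schedule_decay => 0.004
        );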
1802=cut
1803
1804use Mouse;
1805extends 'AI::MXNet::Optimizer';
1806has '+learning_rate' => (default => 0.001);
1807has 'beta1'          => (is => "ro", isa => "Num",  default => 0.9);
1808has 'beta2'          => (is => "ro", isa => "Num",  default => 0.999);
1809has 'epsilon'        => (is => "ro", isa => "Num",  default => 1e-8);
1810has 'schedule_decay' => (is => "ro", isa => "Num",  default => 0.004);
1811has 'm_schedule'     => (is => "rw", default => 1, init_arg => undef);
1812
1813method create_state(Index $index, AI::MXNet::NDArray $weight)
1814{
1815    return [
1816            AI::MXNet::NDArray->zeros(
1817                $weight->shape,
1818                ctx => $weight->context,
1819                dtype => $weight->dtype
1820            ),  # mean
1821            AI::MXNet::NDArray->zeros(
1822                $weight->shape,
1823                ctx => $weight->context,
1824                dtype => $weight->dtype
1825            )   # variance
1826    ];
1827}
1828
1829method update(
1830    Index $index,
1831    AI::MXNet::NDArray $weight,
1832    AI::MXNet::NDArray $grad,
1833    ArrayRef[AI::MXNet::NDArray] $state
1834)
1835{
1836    my $wd = $self->_get_wd($index);
1837    my $lr = $self->_get_lr($index);
1838    $self->_update_count($index);
1839    my $t = $self->_index_update_count->{$index};
1840    $grad = $grad * $self->rescale_grad + $wd * $weight;
1841    if($self->clip_gradient)
1842    {
1843        $grad = AI::MXNet::NDArray->clip(
1844            $grad,
1845            -$self->clip_gradient,
1846             $self->clip_gradient
1847        );
1848    }
1849    # warming momentum schedule
1850    my $momentum_t    = $self->beta1 * (1 - 0.5 * (0.96**($t * $self->schedule_decay)));
1851    my $momentum_t_1  = $self->beta1 * (1 - 0.5 * (0.96**(($t + 1) * $self->schedule_decay)));
    $self->m_schedule($self->m_schedule * $momentum_t);
1853    my $m_schedule_next  = $self->m_schedule * $momentum_t_1;
1854
1855    # update m_t and v_t
1856    my ($m_t, $v_t) = @{ $state };
1857    $m_t .= $self->beta1 * $m_t + (1 - $self->beta1) * $grad;
1858    $v_t .= $self->beta2 * $v_t + (1 - $self->beta2) * $grad * $grad;
1859
1860    my $grad_prime = $grad / (1 - $self->m_schedule);
1861    my $m_t_prime  = $m_t  / (1 - $m_schedule_next);
1862    my $v_t_prime  = $v_t  / (1 - $self->beta2**$t);
1863    my $m_t_bar    = (1 - $momentum_t) * $grad_prime + $momentum_t_1 * $m_t_prime;
1864
1865    # update weight
1866    $weight -= $lr * $m_t_bar / (sqrt($v_t_prime) + $self->epsilon);
1867}
1868
1869__PACKAGE__->register;
1870
1871=head1 NAME
1872
1873    AI::MXNet::Updater - Updater for kvstore
1874=cut
1875
1876package AI::MXNet::Updater;
1877use Mouse;
1878use Storable qw(thaw freeze);
1879use overload "&{}" => sub { my $self = shift; sub { $self->call(@_) } },
1880             fallback => 1;
1881
1882has "optimizer"     => (is => "rw", isa => "AI::MXNet::Optimizer");
1883has "states"        => (is => "rw", isa => "HashRef", default => sub { +{} });
1884has "states_synced" => (is => "rw", isa => "HashRef", default => sub { +{} });
1885
1886method call(Index $index, AI::MXNet::NDArray $grad, AI::MXNet::NDArray $weight)
1887{
1888    if(not exists $self->states->{ $index })
1889    {
1890        $self->states->{ $index } = $self->optimizer->create_state($index, $weight);
1891        $self->states_synced->{ $index } = 1;
1892    }
1893    elsif(not $self->states_synced->{ $index })
1894    {
1895        $self->states->{ $index } = $self->sync_state_context($self->states->{ $index }, $weight->context);
1896        $self->states_synced->{ $index } = 1;
1897    }
1898    $self->optimizer->update($index, $weight, $grad, $self->states->{ $index });
1899}
1900*slice = *call;
1901
1902method sync_state_context(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $state, AI::MXNet::Context $context)
1903{
1904    if(blessed $state)
1905    {
1906        return $state->as_in_context($context);
1907    }
1908    elsif(ref $state)
1909    {
1910        return [map { $self->sync_state_context($_, $context) } @{ $state }];
1911    }
1912    return $state;
1913}
1914
1915=head2 set_states
1916
1917    Sets updater states.
1918=cut
1919
1920method set_states($states)
1921{
1922    my $thawed_states = thaw($states);
1923    my ($optimizer);
1924    if(ref $thawed_states eq 'ARRAY')
1925    {
1926        ($thawed_states, $optimizer) = @{ $thawed_states };
1927        $self->optimizer($optimizer);
1928    }
1929    $self->states($thawed_states);
1930    %{ $self->states_synced } = map { $_ => 0 } keys %{ $thawed_states };
1931}
1932
1933=head2 get_states
1934
1935        Gets updater states.
1936
1937        Parameters
1938        ----------
1939        dump_optimizer : bool, default False
1940            Whether to also save the optimizer itself. This would also save optimizer
1941            information such as learning rate and weight decay schedules.
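
        A save/restore sketch (the index, gradient and weight NDArrays are
        assumed to come from the caller, e.g. a kvstore):

            my $opt     = AI::MXNet::Optimizer->create_optimizer('sgd', momentum => 0.9);
            my $updater = AI::MXNet::Optimizer->get_updater($opt);
            $updater->($index, $grad, $weight);      # one update step
            my $blob = $updater->get_states(1);      # serialize states and optimizer
            $updater->set_states($blob);             # restore later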
1942=cut
1943
1944method get_states(Bool $dump_optimizer=0)
1945{
1946    if($dump_optimizer)
1947    {
1948        my $param_dict = $self->optimizer->param_dict;
1949        $self->optimizer->param_dict({});
1950        my $freezed = freeze([$self->states, $self->optimizer]);
1951        $self->optimizer->param_dict($param_dict);
1952        return $freezed;
1953    }
1954    return freeze($self->states);
1955}
1956
1957package AI::MXNet::Optimizer;
1958
1959method get_updater(AI::MXNet::Optimizer $optimizer)
1960{
1961    return AI::MXNet::Updater->new(optimizer => $optimizer);
1962}
1963
19641;
1965