1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18package AI::MXNet::RNN::Params;
19use Mouse;
20use AI::MXNet::Function::Parameters;
21
22=head1 NAME
23
24    AI::MXNet::RNN::Params - A container for holding variables.
25=cut
26
27=head1 DESCRIPTION
28
29    A container for holding variables.
30    Used by RNN cells for parameter sharing between cells.
31
32    Parameters
33    ----------
34    prefix : str
        All variable names created by this container will
        be prefixed with this prefix
37=cut
# Prefix prepended to every variable name created via get().
has '_prefix' => (is => 'ro', init_arg => 'prefix', isa => 'Str', default => '');
# Internal registry: maps full variable name -> AI::MXNet::Symbol variable.
has '_params' => (is => 'rw', init_arg => undef);
# Allow the single-argument form ->new($prefix) as a shortcut
# for ->new(prefix => $prefix).
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(prefix => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};
46
# Initialize the variable registry to an empty hash.
sub BUILD
{
    my ($self) = @_;
    $self->_params(+{});
}
52
53
54=head2 get
55
56    Get a variable with the name or create a new one if does not exist.
57
58    Parameters
59    ----------
60    $name : str
61        name of the variable
62    @kwargs:
63        more arguments that are passed to mx->sym->Variable call
64=cut
65
method get(Str $name, @kwargs)
{
    # Fetch (or lazily create) the variable registered under the
    # prefixed name; @kwargs are forwarded to Variable() on creation only.
    my $full_name = $self->_prefix . $name;
    unless(exists $self->_params->{$full_name})
    {
        $self->_params->{$full_name} = AI::MXNet::Symbol->Variable($full_name, @kwargs);
    }
    return $self->_params->{$full_name};
}
75
76package AI::MXNet::RNN::Cell::Base;
77=head1 NAME
78
    AI::MXNet::RNN::Cell::Base
80=cut
81
82=head1 DESCRIPTION
83
84    Abstract base class for RNN cells
85
86    Parameters
87    ----------
88    prefix : str
89        prefix for name of layers
90        (and name of weight if params is undef)
91    params : AI::MXNet::RNN::Params or undef
92        container for weight sharing between cells.
93        created if undef.
94=cut
95
use AI::MXNet::Base;
use Mouse;
# Make the cell callable: &{$cell}($inputs, $states) delegates to $cell->call(...).
use overload "&{}"  => sub { my $self = shift; sub { $self->call(@_) } };
# Prefix for names of layers (and of weights when params is created here).
has '_prefix'       => (is => 'rw', init_arg => 'prefix', isa => 'Str', default => '');
# Shared parameter container; created in BUILD when not supplied by the caller.
has '_params'       => (is => 'rw', init_arg => 'params', isa => 'Maybe[AI::MXNet::RNN::Params]');
# Internal bookkeeping:
#   _own_params   - true when this cell created its own parameter container
#   _modified     - set when a modifier cell (e.g. DropoutCell) wraps this one
#   _init_counter - numbers the variables created by begin_state()
#   _counter      - numbers the per-step symbols created by call()
has [qw/_own_params
        _modified
        _init_counter
        _counter
                 /] => (is => 'rw', init_arg => undef);

# Allow the single-argument form ->new($prefix) as a shortcut
# for ->new(prefix => $prefix).
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(prefix => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};
113
# Create an owned parameter container when none was supplied,
# then reset the per-graph counters.
sub BUILD
{
    my ($self) = @_;
    if(defined $self->_params)
    {
        $self->_own_params(0);
    }
    else
    {
        $self->_own_params(1);
        $self->_params(AI::MXNet::RNN::Params->new($self->_prefix));
    }
    $self->_modified(0);
    $self->reset;
}
129
130=head2 reset
131
132    Reset before re-using the cell for another graph
133=cut
134
method reset()
{
    # Both counters restart below zero; their first use pre-increments to 0.
    $self->_counter(-1);
    $self->_init_counter(-1);
}
140
141=head2 call
142
143    Construct symbol for one step of RNN.
144
145    Parameters
146    ----------
147    $inputs : mx->sym->Variable
148        input symbol, 2D, batch * num_units
149    $states : mx->sym->Variable or ArrayRef[AI::MXNet::Symbol]
150        state from previous step or begin_state().
151
152    Returns
153    -------
154    $output : AI::MXNet::Symbol
155        output symbol
156    $states : ArrayRef[AI::MXNet::Symbol]
157        state to next step of RNN.
158    Can be called via overloaded &{}: &{$cell}($inputs, $states);
159=cut
160
method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    # Abstract method: concrete cells construct the symbol for one RNN step here.
    confess("Not Implemented");
}
165
method _gate_names()
{
    # The base cell has a single, unsuffixed gate.
    return [''];
}
170
171=head2 params
172
173    Parameters of this cell
174=cut
175
method params()
{
    # Handing out the container marks it as shared: this cell no longer
    # exclusively owns its parameters.
    $self->_own_params(0);
    return $self->_params;
}
181
182=head2 state_shape
183
184    shape(s) of states
185=cut
186
method state_shape()
{
    # Project just the shape entry out of each state_info record.
    my @shapes;
    for my $info (@{ $self->state_info })
    {
        push @shapes, $info->{shape};
    }
    return \@shapes;
}
191
192=head2 state_info
193
194    shape and layout information of states
195=cut
196
method state_info()
{
    # Abstract method: concrete cells describe their state shapes/layouts here.
    confess("Not Implemented");
}
201
202=head2 begin_state
203
204    Initial state for this cell.
205
206    Parameters
207    ----------
208    :$func : sub ref, default is AI::MXNet::Symbol->can('zeros')
209        Function for creating initial state.
210        Can be AI::MXNet::Symbol->can('zeros'),
211        AI::MXNet::Symbol->can('uniform'), AI::MXNet::Symbol->can('Variable') etc.
212        Use AI::MXNet::Symbol->can('Variable') if you want to directly
213        feed the input as states.
214    @kwargs :
215        more keyword arguments passed to func. For example
216        mean, std, dtype, etc.
217
218    Returns
219    -------
220    $states : ArrayRef[AI::MXNet::Symbol]
221        starting states for first RNN step
222=cut
223
method begin_state(CodeRef :$func=AI::MXNet::Symbol->can('zeros'), @kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    my @states;
    # Variable() takes its name positionally; the other factories
    # (zeros, uniform, ...) take it as a name => ... keyword pair.
    my $func_needs_named_name = $func ne AI::MXNet::Symbol->can('Variable');
    for my $info (@{ $self->state_info })
    {
        # Unique suffix per state variable created by this cell.
        $self->_init_counter($self->_init_counter + 1);
        my @name = (sprintf("%sbegin_state_%d", $self->_prefix, $self->_init_counter));
        my %info = %{ $info//{} };
        if($func_needs_named_name)
        {
            unshift(@name, 'name');
        }
        else
        {
            # For Variable(), the layout attribute is passed nested
            # under a kwargs hash rather than as a top-level key.
            if(exists $info{__layout__})
            {
                $info{kwargs} = { __layout__ => delete $info{__layout__} };
            }
        }
        # state_info entries (shape, layout) override caller-supplied kwargs.
        my %kwargs = (@kwargs, %info);
        my $state = $func->(
            'AI::MXNet::Symbol',
            @name,
            %kwargs
        );
        push @states, $state;
    }
    return \@states;
}
259
260=head2 unpack_weights
261
262    Unpack fused weight matrices into separate
263    weight matrices
264
265    Parameters
266    ----------
267    $args : HashRef[AI::MXNet::NDArray]
268        hash ref containing packed weights.
269        usually from AI::MXNet::Module->get_output()
270
271    Returns
272    -------
273    $args : HashRef[AI::MXNet::NDArray]
274        hash ref with weights associated with
275        this cell, unpacked.
276=cut
277
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Work on a shallow copy so the caller's hash ref is left untouched.
    my %args = %{ $args };
    my $h = $self->_num_hidden;
    for my $group_name ('i2h', 'h2h')
    {
        # Pull the fused weight/bias for this group out of the copy...
        my $weight = delete $args{ sprintf('%s%s_weight', $self->_prefix, $group_name) };
        my $bias   = delete $args{ sprintf('%s%s_bias', $self->_prefix, $group_name) };
        # ...and re-insert one $h-row slice per gate, as independent copies.
        # Bug fix: these slices must go into the local copy %args (which is
        # what gets returned), not into the caller's hash ref $args — the
        # previous code returned a copy that was missing every unpacked
        # entry while mutating the caller's hash. pack_weights already
        # writes to %args; this makes the pair symmetric.
        enumerate(sub {
            my ($j, $name) = @_;
            my $wname = sprintf('%s%s%s_weight', $self->_prefix, $group_name, $name);
            $args{$wname} = $weight->slice([$j*$h,($j+1)*$h-1])->copy;
            my $bname = sprintf('%s%s%s_bias', $self->_prefix, $group_name, $name);
            $args{$bname} = $bias->slice([$j*$h,($j+1)*$h-1])->copy;
        }, $self->_gate_names);
    }
    return \%args;
}
296
297=head2 pack_weights
298
299    Pack fused weight matrices into common
300    weight matrices
301
302    Parameters
303    ----------
304    args : HashRef[AI::MXNet::NDArray]
305        hash ref containing unpacked weights.
306
307    Returns
308    -------
309    $args : HashRef[AI::MXNet::NDArray]
310        hash ref with weights associated with
311        this cell, packed.
312=cut
313
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Work on a shallow copy so the caller's hash ref is left untouched.
    my %args = %{ $args };
    my $h = $self->_num_hidden;
    for my $group_name ('i2h', 'h2h')
    {
        # Collect the per-gate weight/bias slices in gate order...
        my @weight;
        my @bias;
        for my $name (@{ $self->_gate_names })
        {
            my $wname = sprintf('%s%s%s_weight', $self->_prefix, $group_name, $name);
            push @weight, delete $args{$wname};
            my $bname = sprintf('%s%s%s_bias', $self->_prefix, $group_name, $name);
            push @bias, delete $args{$bname};
        }
        # ...then fuse them back into single weight/bias matrices.
        $args{ sprintf('%s%s_weight', $self->_prefix, $group_name) } = AI::MXNet::NDArray->concatenate(
            \@weight
        );
        $args{ sprintf('%s%s_bias', $self->_prefix, $group_name) } = AI::MXNet::NDArray->concatenate(
            \@bias
        );
    }
    return \%args;
}
338
339=head2 unroll
340
341    Unroll an RNN cell across time steps.
342
343    Parameters
344    ----------
345    :$length : Int
346        number of steps to unroll
347    :$inputs : AI::MXNet::Symbol, array ref of Symbols, or undef
348        if inputs is a single Symbol (usually the output
349        of Embedding symbol), it should have shape
        of [$batch_size, $length, ...] if layout == 'NTC' (batch, time series)
        or [$length, $batch_size, ...] if layout == 'TNC' (time series, batch).
352
353        If inputs is a array ref of symbols (usually output of
354        previous unroll), they should all have shape
355        ($batch_size, ...).
356
357        If inputs is undef, a placeholder variables are
358        automatically created.
359    :$begin_state : array ref of Symbol
360        input states. Created by begin_state()
361        or output state of another cell. Created
362        from begin_state() if undef.
363    :$input_prefix : str
364        prefix for automatically created input
        placeholders.
366    :$layout : str
367        layout of input symbol. Only used if the input
368        is a single Symbol.
369    :$merge_outputs : Bool
370        If 0, returns outputs as an array ref of Symbols.
371        If 1, concatenates the output across the time steps
372        and returns a single symbol with the shape
        [$batch_size, $length, ...] if the layout is equal to 'NTC',
        or [$length, $batch_size, ...] if the layout is equal to 'TNC'.
375        If undef, output whatever is faster
376
377    Returns
378    -------
379    $outputs : array ref of Symbol or Symbol
380        output symbols.
381    $states : Symbol or nested list of Symbol
382        has the same structure as begin_state()
383=cut
384
385
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    $self->reset;
    # Position of the time axis in the layout string: 1 for 'NTC', 0 for 'TNC'.
    my $axis = index($layout, 'T');
    if(not defined $inputs)
    {
        # No inputs given: create one placeholder variable per time step.
        $inputs = [
            map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$length-1)
        ];
    }
    elsif(blessed($inputs))
    {
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        # Single symbol: split it along the time axis into $length steps.
        $inputs = AI::MXNet::Symbol->SliceChannel(
            $inputs,
            axis         => $axis,
            num_outputs  => $length,
            squeeze_axis => 1
        );
    }
    else
    {
        # Already a list of per-step symbols; must match the unroll length.
        assert(@$inputs == $length);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my $outputs;
    my @inputs = @{ $inputs };
    for my $i (0..$length-1)
    {
        my $output;
        # $self->(...) goes through the overloaded &{} and invokes ->call().
        ($output, $states) = $self->(
            $inputs[$i],
            $states
        );
        push @$outputs, $output;
    }
    if($merge_outputs)
    {
        # Re-insert the time axis on each step and concatenate along it.
        @$outputs = map { AI::MXNet::Symbol->expand_dims($_, axis => $axis) } @$outputs;
        $outputs = AI::MXNet::Symbol->Concat(@$outputs, dim => $axis);
    }
    return($outputs, $states);
}
441
method _get_activation($inputs, $activation, @kwargs)
{
    # A string activation goes through the Activation operator;
    # a code ref is invoked directly.
    return ref $activation
        ? $activation->($inputs, @kwargs)
        : AI::MXNet::Symbol->Activation($inputs, act_type => $activation, @kwargs);
}
453
method _cells_state_shape($cells)
{
    # Flatten the state shapes of all cells into one list.
    my @shapes;
    for my $cell (@$cells)
    {
        push @shapes, @{ $cell->state_shape };
    }
    return \@shapes;
}
458
method _cells_state_info($cells)
{
    # Flatten the state info records of all cells into one list.
    my @info;
    for my $cell (@$cells)
    {
        push @info, @{ $cell->state_info };
    }
    return \@info;
}
463
method _cells_begin_state($cells, @kwargs)
{
    # Collect the initial states of all cells into one flat list.
    my @states;
    for my $cell (@$cells)
    {
        push @states, @{ $cell->begin_state(@kwargs) };
    }
    return \@states;
}
468
method _cells_unpack_weights($cells, $args)
{
    # Thread the args hash through every cell's unpack_weights in turn.
    for my $cell (@$cells)
    {
        $args = $cell->unpack_weights($args);
    }
    return $args;
}
474
method _cells_pack_weights($cells, $args)
{
    # Thread the args hash through every cell's pack_weights in turn.
    for my $cell (@$cells)
    {
        $args = $cell->pack_weights($args);
    }
    return $args;
}
480
481package AI::MXNet::RNN::Cell;
482use Mouse;
483extends 'AI::MXNet::RNN::Cell::Base';
484
485=head1 NAME
486
487    AI::MXNet::RNN::Cell
488=cut
489
490=head1 DESCRIPTION
491
492    Simple recurrent neural network cell
493
494    Parameters
495    ----------
496    num_hidden : int
497        number of units in output symbol
498    activation : str or Symbol, default 'tanh'
499        type of activation function
500    prefix : str, default 'rnn_'
501        prefix for name of layers
502        (and name of weight if params is undef)
    params : AI::MXNet::RNN::Params or undef
504        container for weight sharing between cells.
505        created if undef.
506=cut
507
# Number of units in the output symbol.
has '_num_hidden'  => (is => 'ro', init_arg => 'num_hidden', isa => 'Int', required => 1);
# Optional bias added to the forget gate (consumed in BUILD; used by
# the LSTM subclass, which defaults it to 1).
has 'forget_bias'  => (is => 'ro', isa => 'Num');
# Activation applied to i2h + h2h; an act_type string or a code ref.
has '_activation'  => (
    is       => 'ro',
    init_arg => 'activation',
    isa      => 'Activation',
    default  => 'tanh'
);
has '+_prefix'    => (default => 'rnn_');
# Weight/bias variables fetched from the shared parameter container in BUILD.
has [qw/_iW _iB
        _hW _hB/] => (is => 'rw', init_arg => undef);

# Allow the single-argument form ->new($num_hidden) as a shortcut
# for ->new(num_hidden => $num_hidden).
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(num_hidden => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};
526
sub BUILD
{
    my $self = shift;
    # Register (or fetch, when shared) this cell's weight/bias variables.
    $self->_iW($self->params->get('i2h_weight'));
    $self->_iB(
        $self->params->get(
            'i2h_bias',
            # When forget_bias is set (LSTM subclass), initialize the i2h
            # bias with the special LSTMBias initializer.
            (defined($self->forget_bias)
                ? (init => AI::MXNet::LSTMBias->new(forget_bias => $self->forget_bias))
                : ()
            )
        )
    );
    $self->_hW($self->params->get('h2h_weight'));
    $self->_hB($self->params->get('h2h_bias'));
}
543
method state_info()
{
    # One hidden state; the leading 0 is a placeholder for the batch dimension.
    my $num_hidden = $self->_num_hidden;
    return [{ __layout__ => 'NC', shape => [0, $num_hidden] }];
}
548
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # One vanilla-RNN step: h' = act(W_i x + b_i + W_h h + b_h).
    $self->_counter($self->_counter + 1);
    my $prefix = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my $in2hid = AI::MXNet::Symbol->FullyConnected(
        data       => $inputs,
        weight     => $self->_iW,
        bias       => $self->_iB,
        num_hidden => $self->_num_hidden,
        name       => "${prefix}i2h"
    );
    my $hid2hid = AI::MXNet::Symbol->FullyConnected(
        data       => $states->[0],
        weight     => $self->_hW,
        bias       => $self->_hB,
        num_hidden => $self->_num_hidden,
        name       => "${prefix}h2h"
    );
    my $next_h = $self->_get_activation(
        $in2hid + $hid2hid,
        $self->_activation,
        name       => "${prefix}out"
    );
    return ($next_h, [$next_h]);
}
574
575package AI::MXNet::RNN::LSTMCell;
576use Mouse;
577use AI::MXNet::Base;
578extends 'AI::MXNet::RNN::Cell';
579
580=head1 NAME
581
582    AI::MXNet::RNN::LSTMCell
583=cut
584
585=head1 DESCRIPTION
586
587    Long-Short Term Memory (LSTM) network cell.
588
589    Parameters
590    ----------
591    num_hidden : int
592        number of units in output symbol
593    prefix : str, default 'lstm_'
594        prefix for name of layers
595        (and name of weight if params is undef)
    params : AI::MXNet::RNN::Params or undef
597        container for weight sharing between cells.
598        created if undef.
599    forget_bias : bias added to forget gate, default 1.0.
600        Jozefowicz et al. 2015 recommends setting this to 1.0
601=cut
602
has '+_prefix'     => (default => 'lstm_');
# LSTM gates use fixed sigmoid/tanh activations, so the generic
# activation attribute cannot be set from the constructor.
has '+_activation' => (init_arg => undef);
# Forget-gate bias defaults to 1 (Jozefowicz et al. 2015 recommendation).
has '+forget_bias' => (is => 'ro', isa => 'Num', default => 1);
606
method state_info()
{
    # Two states (hidden output h and memory cell c), each batch x num_hidden;
    # map builds a distinct hash/shape ref per state.
    return [map { +{ shape => [0, $self->_num_hidden], __layout__ => 'NC' } } 0 .. 1];
}
611
method _gate_names()
{
    # Suffixes in slice order: input, forget, candidate (cell), output.
    return ['_i', '_f', '_c', '_o'];
}
616
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my @states = @{ $states };
    # Project input and previous hidden state to all four gates at once
    # (hence num_hidden*4), then slice the fused result apart.
    my $i2h = AI::MXNet::Symbol->FullyConnected(
        data       => $inputs,
        weight     => $self->_iW,
        bias       => $self->_iB,
        num_hidden => $self->_num_hidden*4,
        name       => "${name}i2h"
    );
    my $h2h = AI::MXNet::Symbol->FullyConnected(
        data       => $states[0],
        weight     => $self->_hW,
        bias       => $self->_hB,
        num_hidden => $self->_num_hidden*4,
        name       => "${name}h2h"
    );
    my $gates = $i2h + $h2h;
    # Slice order matches _gate_names: input, forget, candidate, output.
    my @slice_gates = @{ AI::MXNet::Symbol->SliceChannel(
        $gates, num_outputs => 4, name => "${name}slice"
    ) };
    my $in_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[0], act_type => "sigmoid", name => "${name}i"
    );
    my $forget_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[1], act_type => "sigmoid", name => "${name}f"
    );
    my $in_transform = AI::MXNet::Symbol->Activation(
        $slice_gates[2], act_type => "tanh", name => "${name}c"
    );
    my $out_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[3], act_type => "sigmoid", name => "${name}o"
    );
    # c' = f * c + i * candidate
    my $next_c = AI::MXNet::Symbol->_plus(
        $forget_gate * $states[1], $in_gate * $in_transform,
        name => "${name}state"
    );
    # h' = o * tanh(c')
    my $next_h = AI::MXNet::Symbol->_mul(
        $out_gate,
        AI::MXNet::Symbol->Activation(
            $next_c, act_type => "tanh"
        ),
        name => "${name}out"
    );
    return ($next_h, [$next_h, $next_c]);

}
666
667package AI::MXNet::RNN::GRUCell;
668use Mouse;
669use AI::MXNet::Base;
670extends 'AI::MXNet::RNN::Cell';
671
672=head1 NAME
673
674    AI::MXNet::RNN::GRUCell
675=cut
676
677=head1 DESCRIPTION
678
    Gated Recurrent Unit (GRU) network cell.
680    Note: this is an implementation of the cuDNN version of GRUs
681    (slight modification compared to Cho et al. 2014).
682
683    Parameters
684    ----------
685    num_hidden : int
686        number of units in output symbol
687    prefix : str, default 'gru_'
688        prefix for name of layers
689        (and name of weight if params is undef)
690    params : AI::MXNet::RNN::Params or undef
691        container for weight sharing between cells.
692        created if undef.
693=cut
694
# GRU cells default their layer/weight name prefix to 'gru_'.
has '+_prefix'     => (default => 'gru_');
696
method _gate_names()
{
    # Suffixes in slice order: reset, update, candidate ("output").
    return ['_r', '_z', '_o'];
}
701
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my $prev_state_h = @{ $states }[0];
    # Project input and previous hidden state to all three gates at once
    # (hence num_hidden*3), then slice into reset/update/candidate parts.
    my $i2h = AI::MXNet::Symbol->FullyConnected(
        data       => $inputs,
        weight     => $self->_iW,
        bias       => $self->_iB,
        num_hidden => $self->_num_hidden*3,
        name       => "${name}i2h"
    );
    my $h2h = AI::MXNet::Symbol->FullyConnected(
        data       => $prev_state_h,
        weight     => $self->_hW,
        bias       => $self->_hB,
        num_hidden => $self->_num_hidden*3,
        name       => "${name}h2h"
    );
    # After slicing, $i2h/$h2h are rebound to the candidate parts.
    my ($i2h_r, $i2h_z);
    ($i2h_r, $i2h_z, $i2h) = @{ AI::MXNet::Symbol->SliceChannel(
        $i2h, num_outputs => 3, name => "${name}_i2h_slice"
    ) };
    my ($h2h_r, $h2h_z);
    ($h2h_r, $h2h_z, $h2h) = @{ AI::MXNet::Symbol->SliceChannel(
        $h2h, num_outputs => 3, name => "${name}_h2h_slice"
    ) };
    my $reset_gate = AI::MXNet::Symbol->Activation(
        $i2h_r + $h2h_r, act_type => "sigmoid", name => "${name}_r_act"
    );
    my $update_gate = AI::MXNet::Symbol->Activation(
        $i2h_z + $h2h_z, act_type => "sigmoid", name => "${name}_z_act"
    );
    # Candidate: the reset gate scales only the h2h contribution
    # (the cuDNN GRU variant mentioned in the class description).
    my $next_h_tmp = AI::MXNet::Symbol->Activation(
        $i2h + $reset_gate * $h2h, act_type => "tanh", name => "${name}_h_act"
    );
    # h' = (1 - z) * candidate + z * h
    my $next_h = AI::MXNet::Symbol->_plus(
        (1 - $update_gate) * $next_h_tmp, $update_gate * $prev_state_h,
        name => "${name}out"
    );
    return ($next_h, [$next_h]);
}
744
745package AI::MXNet::RNN::FusedCell;
746use Mouse;
747use AI::MXNet::Types;
748use AI::MXNet::Base;
749extends 'AI::MXNet::RNN::Cell::Base';
750
751=head1 NAME
752
753    AI::MXNet::RNN::FusedCell
754=cut
755
756=head1 DESCRIPTION
757
758    Fusing RNN layers across time step into one kernel.
759    Improves speed but is less flexible. Currently only
760    supported if using cuDNN on GPU.
761=cut
762
# Number of units in the output symbol.
has '_num_hidden'      => (is => 'ro', isa => 'Int',  init_arg => 'num_hidden',     required => 1);
# Number of stacked RNN layers fused into one kernel.
has '_num_layers'      => (is => 'ro', isa => 'Int',  init_arg => 'num_layers',     default => 1);
# Dropout probability passed to the RNN operator.
has '_dropout'         => (is => 'ro', isa => 'Num',  init_arg => 'dropout',        default => 0);
# When true, unroll() also returns the final state symbols.
has '_get_next_state'  => (is => 'ro', isa => 'Bool', init_arg => 'get_next_state', default => 0);
has '_bidirectional'   => (is => 'ro', isa => 'Bool', init_arg => 'bidirectional',  default => 0);
# Bias added to the forget gate (LSTM mode).
has 'forget_bias'      => (is => 'ro', isa => 'Num',  default => 1);
# Weight initializer; wrapped in AI::MXNet::FusedRNN during BUILD.
has 'initializer'      => (is => 'rw', isa => 'Maybe[Initializer]');
has '_mode'            => (
    is => 'ro',
    isa => enum([qw/rnn_relu rnn_tanh lstm gru/]),
    init_arg => 'mode',
    default => 'lstm'
);
# _parameter: the single fused parameter variable;
# _directions: ['l'] or [qw/l r/] when bidirectional.
has [qw/_parameter
        _directions/] => (is => 'rw', init_arg => undef);

# Allow the single-argument form ->new($num_hidden) as a shortcut
# for ->new(num_hidden => $num_hidden).
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(num_hidden => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};
785
sub BUILD
{
    my $self = shift;
    # Default prefix is derived from the mode, e.g. 'lstm_'.
    if(not $self->_prefix)
    {
        $self->_prefix($self->_mode.'_');
    }
    if(not defined $self->initializer)
    {
        $self->initializer(
            AI::MXNet::Xavier->new(
                factor_type => 'in',
                magnitude   => 2.34
            )
        );
    }
    # Wrap the base initializer so it knows how to fill the single fused
    # parameter blob for this RNN configuration.
    if(not $self->initializer->isa('AI::MXNet::FusedRNN'))
    {
        $self->initializer(
            AI::MXNet::FusedRNN->new(
                init           => $self->initializer,
                num_hidden     => $self->_num_hidden,
                num_layers     => $self->_num_layers,
                mode           => $self->_mode,
                bidirectional  => $self->_bidirectional,
                forget_bias    => $self->forget_bias
            )
        );
    }
    # All weights and biases live in one flat 'parameters' variable.
    $self->_parameter($self->params->get('parameters', init => $self->initializer));
    $self->_directions($self->_bidirectional ? [qw/l r/] : ['l']);
}
818
819
method state_info()
{
    # LSTM mode carries two states (h and c); all other modes carry one.
    # State shape is (layers*directions, batch, hidden); 0 marks the
    # batch dimension placeholder.
    my $dirs       = scalar @{ $self->_directions };
    my $num_states = $self->_mode eq 'lstm' ? 2 : 1;
    my @info;
    for (1 .. $num_states)
    {
        push @info, {
            shape      => [$dirs * $self->_num_layers, 0, $self->_num_hidden],
            __layout__ => 'LNC',
        };
    }
    return \@info;
}
826
method _gate_names()
{
    # Gate name suffixes, looked up by mode.
    my %names_for = (
        rnn_relu => [''],
        rnn_tanh => [''],
        lstm     => [qw/_i _f _c _o/],
        gru      => [qw/_r _z _o/],
    );
    return $names_for{ $self->_mode };
}
836
method _num_gates()
{
    # Gate count follows directly from the mode's gate-name list.
    return scalar @{ $self->_gate_names };
}
841
method _slice_weights($arr, $li, $lh)
{
    # Map named per-gate weight/bias matrices onto slices (views) of the
    # flat parameter blob $arr. Layout: all weights first (layer-major,
    # then direction, i2h gates before h2h gates), then all biases in
    # the same order.
    # $li - input size of the first layer; $lh - hidden size.
    my %args;
    my @gate_names = @{ $self->_gate_names };
    my @directions = @{ $self->_directions };

    my $b = @directions;
    my $p = 0;
    for my $layer (0..$self->_num_layers-1)
    {
        for my $direction (@directions)
        {
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_i2h%s_weight', $self->_prefix, $direction, $layer, $gate);
                my $size;
                if($layer > 0)
                {
                    # Deeper layers take the (possibly direction-concatenated)
                    # hidden output of the previous layer as input.
                    $size = $b*$lh*$lh;
                    $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $b*$lh]);
                }
                else
                {
                    $size = $li*$lh;
                    $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $li]);
                }
                $p += $size;
            }
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_h2h%s_weight', $self->_prefix, $direction, $layer, $gate);
                my $size = $lh**2;
                $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $lh]);
                $p += $size;
            }
        }
    }
    # Biases follow all weights: one vector of length $lh per gate.
    for my $layer (0..$self->_num_layers-1)
    {
        for my $direction (@directions)
        {
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_i2h%s_bias', $self->_prefix, $direction, $layer, $gate);
                $args{$name} = $arr->slice([$p,$p+$lh-1]);
                $p += $lh;
            }
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_h2h%s_bias', $self->_prefix, $direction, $layer, $gate);
                $args{$name} = $arr->slice([$p,$p+$lh-1]);
                $p += $lh;
            }
        }
    }
    # Sanity check: the slices must tile the blob exactly.
    assert($p == $arr->size, "Invalid parameters size for FusedRNNCell");
    return %args;
}
900
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %args = %{ $args };
    # Remove the fused parameter blob; all other entries pass through.
    my $arr = delete $args{ $self->_parameter->name };
    my $b = @{ $self->_directions };
    my $m = $self->_num_gates;
    my $h = $self->_num_hidden;
    # Recover the first layer's input size from the total blob size —
    # the inverse of the size formula used in pack_weights.
    my $num_input = int(int(int($arr->size/$b)/$h)/$m) - ($self->_num_layers - 1)*($h+$b*$h+2) - $h - 2;
    # Copy each slice out so the returned arrays are independent of $arr.
    my %nargs = $self->_slice_weights($arr, $num_input, $self->_num_hidden);
    %args = (%args, map { $_ => $nargs{$_}->copy } keys %nargs);
    return \%args
}
913
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %args = %{ $args };
    my $b = @{ $self->_directions };
    my $m = $self->_num_gates;
    my @c = @{ $self->_gate_names };
    my $h = $self->_num_hidden;
    # Infer the input size from the first layer's first i2h weight matrix.
    my $w0 = $args{ sprintf('%sl0_i2h%s_weight', $self->_prefix, $c[0]) };
    my $num_input = $w0->shape->[1];
    # Total blob size: first layer plus (num_layers-1) deeper layers,
    # weights and biases included ('+2' covers the two bias vectors).
    my $total = ($num_input+$h+2)*$h*$m*$b + ($self->_num_layers-1)*$m*$h*($h+$b*$h+2)*$b;
    my $arr = AI::MXNet::NDArray->zeros([$total], ctx => $w0->context, dtype => $w0->dtype);
    my %nargs = $self->_slice_weights($arr, $num_input, $h);
    while(my ($name, $nd) = each %nargs)
    {
        # '.=' is NDArray's overloaded in-place assignment: it copies the
        # unpacked values into the corresponding slice view of $arr.
        $nd .= delete $args{ $name };
    }
    $args{ $self->_parameter->name } = $arr;
    return \%args;
}
933
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # The fused kernel processes the whole sequence at once;
    # per-step invocation is unsupported by design.
    confess("AI::MXNet::RNN::FusedCell cannot be stepped. Please use unroll");
}
938
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    $self->reset;
    # Position of the time axis in the layout string: 1 for 'NTC', 0 for 'TNC'.
    my $axis = index($layout, 'T');
    $inputs //= AI::MXNet::Symbol->Variable("${input_prefix}data");
    if(blessed($inputs))
    {
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        # The fused RNN operator consumes time-major (TNC) input.
        if($axis == 1)
        {
            AI::MXNet::Logging->warning(
                "NTC layout detected. Consider using "
                ."TNC for RNN::FusedCell for faster speed"
            );
            $inputs = AI::MXNet::Symbol->SwapAxis($inputs, dim1 => 0, dim2 => 1);
        }
        else
        {
            assert($axis == 0, "Unsupported layout $layout");
        }
    }
    else
    {
        # Per-step symbols: stack them along a new leading time axis.
        assert(@$inputs == $length);
        $inputs = [map { AI::MXNet::Symbol->expand_dims($_, axis => 0) } @{ $inputs }];
        $inputs = AI::MXNet::Symbol->Concat(@{ $inputs }, dim => 0);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my @states = @{ $states };
    my %states;
    # LSTM mode carries both a hidden state and a cell state.
    if($self->_mode eq 'lstm')
    {
        %states = (state => $states[0], state_cell => $states[1]);
    }
    else
    {
        %states = (state => $states[0]);
    }
    my $rnn = AI::MXNet::Symbol->RNN(
        data          => $inputs,
        parameters    => $self->_parameter,
        state_size    => $self->_num_hidden,
        num_layers    => $self->_num_layers,
        bidirectional => $self->_bidirectional,
        p             => $self->_dropout,
        state_outputs => $self->_get_next_state,
        mode          => $self->_mode,
        name          => $self->_prefix.'rnn',
        %states
    );
    my $outputs;
    my %attr = (__layout__ => 'LNC');
    if(not $self->_get_next_state)
    {
        ($outputs, $states) = ($rnn, []);
    }
    elsif($self->_mode eq 'lstm')
    {
        # With state_outputs set, the op yields (output, state, state_cell).
        my @rnn = @{ $rnn };
        $rnn[1]->_set_attr(%attr);
        $rnn[2]->_set_attr(%attr);
        ($outputs, $states) = ($rnn[0], [$rnn[1], $rnn[2]]);
    }
    else
    {
        my @rnn = @{ $rnn };
        $rnn[1]->_set_attr(%attr);
        ($outputs, $states) = ($rnn[0], [$rnn[1]]);
    }
    if(defined $merge_outputs and not $merge_outputs)
    {
        AI::MXNet::Logging->warning(
            "Call RNN::FusedCell->unroll with merge_outputs=1 "
            ."for faster speed"
        );
        # Split the time-major output back into per-step symbols.
        $outputs = [@ {
            AI::MXNet::Symbol->SliceChannel(
                $outputs,
                axis         => 0,
                num_outputs  => $length,
                squeeze_axis => 1
            )
        }];
    }
    elsif($axis == 1)
    {
        # Restore the caller's NTC layout.
        $outputs = AI::MXNet::Symbol->SwapAxis($outputs, dim1 => 0, dim2 => 1);
    }
    return ($outputs, $states);
}
1041
1042=head2 unfuse
1043
1044    Unfuse the fused RNN
1045
1046    Returns
1047    -------
1048    $cell : AI::MXNet::RNN::SequentialCell
1049        unfused cell that can be used for stepping, and can run on CPU.
1050=cut
1051
method unfuse()
{
    # Build a non-fused, CPU-capable equivalent of this fused RNN: one
    # (optionally bidirectional) layer cell per fused layer, stacked in a
    # SequentialCell.  The returned cell shares no parameters implicitly;
    # layer prefixes follow the fused naming scheme ('<prefix>l<i>_' /
    # '<prefix>r<i>_').
    my %ctor_for = (
        rnn_relu => sub {
            AI::MXNet::RNN::Cell->new(
                num_hidden => $self->_num_hidden,
                activation => 'relu',
                prefix     => shift
            )
        },
        rnn_tanh => sub {
            AI::MXNet::RNN::Cell->new(
                num_hidden => $self->_num_hidden,
                activation => 'tanh',
                prefix     => shift
            )
        },
        lstm     => sub {
            AI::MXNet::RNN::LSTMCell->new(
                num_hidden => $self->_num_hidden,
                prefix     => shift
            )
        },
        gru      => sub {
            AI::MXNet::RNN::GRUCell->new(
                num_hidden => $self->_num_hidden,
                prefix     => shift
            )
        },
    );
    my $make_cell = $ctor_for{ $self->_mode };
    my $stack = AI::MXNet::RNN::SequentialCell->new;
    for my $layer (0 .. $self->_num_layers - 1)
    {
        if($self->_bidirectional)
        {
            $stack->add(
                AI::MXNet::RNN::BidirectionalCell->new(
                    $make_cell->(sprintf('%sl%d_', $self->_prefix, $layer)),
                    $make_cell->(sprintf('%sr%d_', $self->_prefix, $layer)),
                    output_prefix => sprintf('%sbi_%s_%d', $self->_prefix, $self->_mode, $layer)
                )
            );
        }
        else
        {
            $stack->add($make_cell->(sprintf('%sl%d_', $self->_prefix, $layer)));
        }
    }
    return $stack;
}
1102
package AI::MXNet::RNN::SequentialCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::SequentialCell
=cut

=head1 DESCRIPTION

    Sequentially stacking multiple RNN cells

    Parameters
    ----------
    params : AI::MXNet::RNN::Params or undef
        container for weight sharing between cells.
        created if undef.
=cut

# _override_cell_params : Bool - true when an explicit 'params' container was
#     passed to the constructor (set in BUILD; child cells then must own theirs).
# _cells : ArrayRef - the stacked child cells, in the order they were add()-ed.
has [qw/_override_cell_params _cells/] => (is => 'rw', init_arg => undef);
1125
sub BUILD
{
    # Record whether the caller supplied a shared 'params' container and
    # start with an empty stack of cells.
    my ($self, $args) = @_;
    $self->_override_cell_params(defined $args->{params});
    $self->_cells([]);
}
1132
1133=head2 add
1134
1135    Append a cell to the stack.
1136
1137    Parameters
1138    ----------
1139    $cell : AI::MXNet::RNN::Cell::Base
1140=cut
1141
method add(AI::MXNet::RNN::Cell::Base $cell)
{
    # Append the cell to the stack and reconcile parameter containers so the
    # stack's container always mirrors every child's parameters.
    push @{ $self->_cells }, $cell;
    if($self->_override_cell_params)
    {
        assert(
            $cell->_own_params,
            "Either specify params for SequentialRNNCell "
            ."or child cells, not both."
        );
        # Seed the child with the shared parameters (shared entries win).
        my $shared = $self->params->_params;
        @{ $cell->params->_params }{ keys %$shared } = values %$shared;
    }
    # Mirror the child's parameters back into the stack's container.
    my $added = $cell->params->_params;
    @{ $self->params->_params }{ keys %$added } = values %$added;
}
1156
method state_info()
{
    # Concatenated state descriptors of all stacked cells, in stack order.
    my $cells = $self->_cells;
    return $self->_cells_state_info($cells);
}
1161
method begin_state(@kwargs)
{
    # Initial states of all stacked cells as one flat list; only valid on an
    # unmodified cell (a wrapping modifier cell must be asked instead).
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    return $self->_cells_begin_state($self->_cells, @kwargs);
}
1171
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Delegate weight unpacking to each stacked cell in turn.
    my $cells = $self->_cells;
    return $self->_cells_unpack_weights($cells, $args);
}
1176
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Delegate weight packing to each stacked cell in turn.
    my $cells = $self->_cells;
    return $self->_cells_pack_weights($cells, $args);
}
1181
method call($inputs, $states)
{
    # Step every stacked cell once, threading the output of one cell into the
    # next and collecting each cell's slice of the flat state list.
    $self->_counter($self->_counter + 1);
    my @next_states;
    my $p = 0;
    for my $cell (@{ $self->_cells })
    {
        # Bidirectional cells cannot be single-stepped.  The original check
        # used the non-existent class name 'AI::MXNet::BidirectionalCell'
        # (the real class is AI::MXNet::RNN::BidirectionalCell), so it could
        # never fire.
        assert(not $cell->isa('AI::MXNet::RNN::BidirectionalCell'));
        my $n = scalar(@{ $cell->state_info });
        my $state = [@{ $states }[$p..$p+$n-1]];
        $p += $n;
        ($inputs, $state) = $cell->($inputs, $state);
        push @next_states, $state;
    }
    # Flatten the per-cell state lists back into one flat list.
    return ($inputs, [map { @$_ } @next_states]);
}
1198
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    # Unroll each stacked cell over the full sequence in turn, feeding the
    # outputs of one cell as the inputs of the next.
    my $num_cells = @{ $self->_cells };
    $begin_state //= $self->begin_state;
    my $p = 0;
    my $states;
    my @next_states;
    enumerate(sub {
        my ($i, $cell) = @_;
        # Each cell consumes its own slice of the flat begin_state list.
        my $n   = @{ $cell->state_info };
        $states = [@{$begin_state}[$p..$p+$n-1]];
        $p += $n;
        # Only the last cell honours the caller's merge_outputs preference;
        # intermediate cells are left free (undef) to pick the faster form.
        ($inputs, $states) = $cell->unroll(
            $length,
            inputs          => $inputs,
            input_prefix    => $input_prefix,
            begin_state     => $states,
            layout          => $layout,
            merge_outputs   => ($i < $num_cells-1) ? undef : $merge_outputs
        );
        push @next_states, $states;
    }, $self->_cells);
    # Flatten the per-cell state lists back into one flat list.
    return ($inputs, [map { @{ $_ } } @next_states]);
}
1230
package AI::MXNet::RNN::BidirectionalCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::BidirectionalCell
=cut

=head1 DESCRIPTION

    Bidirectional RNN cell

    Parameters
    ----------
    l_cell : AI::MXNet::RNN::Cell::Base
        cell for forward unrolling
    r_cell : AI::MXNet::RNN::Cell::Base
        cell for backward unrolling
    output_prefix : str, default 'bi_'
        prefix for name of output
=cut

has 'l_cell'         => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
has 'r_cell'         => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
# Prefix used when naming the concatenated per-step output symbols in unroll().
has '_output_prefix' => (is => 'ro', init_arg => 'output_prefix', isa => 'Str', default => 'bi_');
# _override_cell_params : Bool - set in BUILD when an explicit 'params'
#     container was supplied.  _cells : [l_cell, r_cell] (set in BUILD).
has [qw/_override_cell_params _cells/] => (is => 'rw', init_arg => undef);
1259
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    # Two leading blessed positional arguments are the forward and backward
    # cells; map them onto the named l_cell/r_cell constructor parameters.
    if(@args >= 2 and blessed $args[0] and blessed $args[1])
    {
        my ($l_cell, $r_cell) = splice(@args, 0, 2);
        return $class->$orig(
            l_cell => $l_cell,
            r_cell => $r_cell,
            @args
        );
    }
    return $class->$orig(@args);
};
1275
sub BUILD
{
    my ($self, $original_arguments) = @_;
    # If an explicit 'params' container was supplied, the child cells must
    # own their parameters; seed each child from the shared container.
    $self->_override_cell_params(defined $original_arguments->{params});
    if($self->_override_cell_params)
    {
        assert(
            ($self->l_cell->_own_params and $self->r_cell->_own_params),
            "Either specify params for BidirectionalCell ".
            "or child cells, not both."
        );
        %{ $self->l_cell->params->_params } = (%{ $self->l_cell->params->_params }, %{ $self->params->_params });
        %{ $self->r_cell->params->_params } = (%{ $self->r_cell->params->_params }, %{ $self->params->_params });
    }
    # Mirror both children's parameters into this cell's own container.
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $self->l_cell->params->_params });
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $self->r_cell->params->_params });
    $self->_cells([$self->l_cell, $self->r_cell]);
}
1294
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Delegate to the shared helper over both directional children.
    my $cells = $self->_cells;
    return $self->_cells_unpack_weights($cells, $args);
}
1299
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Delegate to the shared helper over both directional children.
    my $cells = $self->_cells;
    return $self->_cells_pack_weights($cells, $args);
}
1304
method call($inputs, $states)
{
    # A bidirectional cell needs the whole sequence; single-stepping is undefined.
    confess("Bidirectional cannot be stepped. Please use unroll");
}
1309
method state_info()
{
    # Combined state layout of the forward and backward children.
    my $cells = $self->_cells;
    return $self->_cells_state_info($cells);
}
1314
method begin_state(@kwargs)
{
    # Combined initial states of both directional children; only valid on an
    # unmodified cell (a wrapping modifier cell must be asked instead).
    assert((not $self->_modified),
            "After applying modifier cells (e.g. DropoutCell) the base "
            ."cell cannot be called directly. Call the modifier cell instead."
    );
    return $self->_cells_begin_state($self->_cells, @kwargs);
}
1323
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    # Unroll the forward cell over the sequence and the backward cell over
    # the reversed sequence, then concatenate the two per-step outputs.

    my $axis = index($layout, 'T');
    if(not defined $inputs)
    {
        # No inputs given: create one input variable per time step.
        $inputs = [
            map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$length-1)
        ];
    }
    elsif(blessed($inputs))
    {
        # A single symbol holds the whole sequence: split it on the time axis.
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        $inputs = [ @{ AI::MXNet::Symbol->SliceChannel(
            $inputs,
            axis         => $axis,
            num_outputs  => $length,
            squeeze_axis => 1
        ) }];
    }
    else
    {
        assert(@$inputs == $length);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my ($l_cell, $r_cell) = @{ $self->_cells };
    # Forward pass consumes the first slice of the flat state list ...
    my ($l_outputs, $l_states) = $l_cell->unroll(
        $length, inputs => $inputs,
        begin_state     => [@{$states}[0..@{$l_cell->state_info}-1]],
        layout          => $layout,
        merge_outputs   => $merge_outputs
    );
    # ... the backward pass the remainder, over the reversed inputs.
    my ($r_outputs, $r_states) = $r_cell->unroll(
        $length, inputs => [reverse @{$inputs}],
        begin_state     => [@{$states}[@{$l_cell->state_info}..@{$states}-1]],
        layout          => $layout,
        merge_outputs   => $merge_outputs
    );
    if(not defined $merge_outputs)
    {
        # No preference given: merge only if both children already produced
        # a single merged symbol; otherwise normalize both to per-step lists.
        $merge_outputs = (
            blessed $l_outputs and $l_outputs->isa('AI::MXNet::Symbol')
                and
            blessed $r_outputs and $r_outputs->isa('AI::MXNet::Symbol')
        );
        if(not $merge_outputs)
        {
            if(blessed $l_outputs and $l_outputs->isa('AI::MXNet::Symbol'))
            {
                $l_outputs = [
                    @{ AI::MXNet::Symbol->SliceChannel(
                        $l_outputs, axis => $axis,
                        num_outputs      => $length,
                        squeeze_axis     => 1
                    ) }
                ];
            }
            if(blessed $r_outputs and $r_outputs->isa('AI::MXNet::Symbol'))
            {
                $r_outputs = [
                    @{ AI::MXNet::Symbol->SliceChannel(
                        $r_outputs, axis => $axis,
                        num_outputs      => $length,
                        squeeze_axis     => 1
                    ) }
                ];
            }
        }
    }
    # Bring the backward outputs back into forward time order: symbolically
    # (reverse op) when merged, by reversing the list when unmerged.
    if($merge_outputs)
    {
        $l_outputs = [@{ $l_outputs }];
        $r_outputs = [@{ AI::MXNet::Symbol->reverse(blessed $r_outputs ? $r_outputs : @{ $r_outputs }, axis=>$axis) }];
    }
    else
    {
        $r_outputs = [reverse(@{ $r_outputs })];
    }
    my $outputs = [];
    # Concatenate forward and backward outputs along the feature dimension
    # (dim 2 for a merged 3-D output, dim 1 for per-step 2-D outputs).
    for(zip([0..@{ $l_outputs }-1], [@{ $l_outputs }], [@{ $r_outputs }])) {
        my ($i, $l_o, $r_o) = @$_;
        push @$outputs, AI::MXNet::Symbol->Concat(
            $l_o, $r_o, dim=>(1+($merge_outputs?1:0)),
            name => $merge_outputs
                        ? sprintf('%sout', $self->_output_prefix)
                        : sprintf('%st%d', $self->_output_prefix, $i)
        );
    }
    if($merge_outputs)
    {
        # Merged case produced exactly one concatenated symbol.
        $outputs = @{ $outputs }[0];
    }
    $states = [$l_states, $r_states];
    return($outputs, $states);
}
1431
package AI::MXNet::RNN::ConvCell::Base;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::ConvCell::Base
=cut

=head1 DESCRIPTION

    Abstract base class for Convolutional RNN cells

=cut

# State-to-state (h2h) convolution geometry; _h2h_pad is derived in BUILD so
# the spatial size of the state is preserved (requires odd kernel sizes).
has '_h2h_kernel'  => (is => 'ro', isa => 'Shape', init_arg => 'h2h_kernel');
has '_h2h_dilate'  => (is => 'ro', isa => 'Shape', init_arg => 'h2h_dilate');
has '_h2h_pad'     => (is => 'rw', isa => 'Shape', init_arg => undef);
# Input-to-state (i2h) convolution geometry.
has '_i2h_kernel'  => (is => 'ro', isa => 'Shape', init_arg => 'i2h_kernel');
has '_i2h_stride'  => (is => 'ro', isa => 'Shape', init_arg => 'i2h_stride');
has '_i2h_dilate'  => (is => 'ro', isa => 'Shape', init_arg => 'i2h_dilate');
has '_i2h_pad'     => (is => 'ro', isa => 'Shape', init_arg => 'i2h_pad');
has '_num_hidden'  => (is => 'ro', isa => 'DimSize', init_arg => 'num_hidden');
# Full shape of a single time step's input; used in BUILD to infer the state shape.
has '_input_shape' => (is => 'ro', isa => 'Shape', init_arg => 'input_shape');
has '_conv_layout' => (is => 'ro', isa => 'Str', init_arg => 'conv_layout', default => 'NCHW');
has '_activation'  => (is => 'ro', init_arg => 'activation');
# Inferred in BUILD; its batch dimension is zeroed so it binds to any batch size.
has '_state_shape' => (is => 'rw', init_arg => undef);
has [qw/i2h_weight_initializer h2h_weight_initializer
    i2h_bias_initializer h2h_bias_initializer/] => (is => 'rw', isa => 'Maybe[Initializer]');
1462
sub BUILD
{
    my $self = shift;
    assert (
        ($self->_h2h_kernel->[0] % 2 == 1 and $self->_h2h_kernel->[1] % 2 == 1),
        "Only support odd numbers, got h2h_kernel= (@{[ $self->_h2h_kernel ]})"
    );
    # 'Same' padding for the state-to-state convolution, accounting for dilation.
    $self->_h2h_pad([
        int($self->_h2h_dilate->[0] * ($self->_h2h_kernel->[0] - 1) / 2),
        int($self->_h2h_dilate->[1] * ($self->_h2h_kernel->[1] - 1) / 2)
    ]);
    # Infer state shape by running shape inference over a probe i2h convolution.
    my $data = AI::MXNet::Symbol->Variable('data');
    my $state_shape = AI::MXNet::Symbol->Convolution(
        data => $data,
        num_filter => $self->_num_hidden,
        kernel => $self->_i2h_kernel,
        stride => $self->_i2h_stride,
        pad => $self->_i2h_pad,
        dilate => $self->_i2h_dilate,
        layout => $self->_conv_layout
    );
    $state_shape = ($state_shape->infer_shape(data=>$self->_input_shape))[1]->[0];
    # Zero the batch dimension so the state can later bind to any batch size.
    $state_shape->[0] = 0;
    $self->_state_shape($state_shape);
}
1489
method state_info()
{
    # Two identically shaped state descriptors, laid out like the conv data.
    my %info = (shape => $self->_state_shape, __layout__ => $self->_conv_layout);
    return [ { %info }, { %info } ];
}
1497
method call($inputs, $states)
{
    # Abstract: concrete conv cells (ConvCell, ConvLSTMCell, ...) implement the step.
    confess("AI::MXNet::RNN::ConvCell::Base is abstract class for convolutional RNN");
}
1502
package AI::MXNet::RNN::ConvCell;
use Mouse;
extends 'AI::MXNet::RNN::ConvCell::Base';

=head1 NAME

    AI::MXNet::RNN::ConvCell
=cut

=head1 DESCRIPTION

    Convolutional RNN cells

    Parameters
    ----------
    input_shape : array ref of int
        Shape of input in single timestep.
    num_hidden : int
        Number of units in output symbol.
    h2h_kernel : array ref of int, default (3, 3)
        Kernel of Convolution operator in state-to-state transitions.
    h2h_dilate : array ref of int, default (1, 1)
        Dilation of Convolution operator in state-to-state transitions.
    i2h_kernel : array ref of int, default (3, 3)
        Kernel of Convolution operator in input-to-state transitions.
    i2h_stride : array ref of int, default (1, 1)
        Stride of Convolution operator in input-to-state transitions.
    i2h_pad : array ref of int, default (1, 1)
        Pad of Convolution operator in input-to-state transitions.
    i2h_dilate : array ref of int, default (1, 1)
        Dilation of Convolution operator in input-to-state transitions.
    activation : str or Symbol,
        default: a sub wrapping AI::MXNet::Symbol->LeakyReLU(act_type => 'leaky', slope => 0.2)
        Type of activation function.
    prefix : str, default 'ConvRNN_'
        Prefix for name of layers (and name of weight if params is undef).
    params : AI::MXNet::RNN::Params, default undef
        Container for weight sharing between cells. Created if undef.
    conv_layout : str, default 'NCHW'
        Layout of ConvolutionOp
=cut

has '+_h2h_kernel' => (default => sub { [3, 3] });
has '+_h2h_dilate' => (default => sub { [1, 1] });
has '+_i2h_kernel' => (default => sub { [3, 3] });
has '+_i2h_stride' => (default => sub { [1, 1] });
has '+_i2h_dilate' => (default => sub { [1, 1] });
has '+_i2h_pad'    => (default => sub { [1, 1] });
has '+_prefix'     => (default => 'ConvRNN_');
has '+_activation' => (default => sub { sub { AI::MXNet::Symbol->LeakyReLU(@_, act_type => 'leaky', slope => 0.2) } });
has '+i2h_bias_initializer' => (default => 'zeros');
has '+h2h_bias_initializer' => (default => 'zeros');
# Optional forget-gate bias; used by the LSTM-style subclass (ConvLSTMCell).
has 'forget_bias'  => (is => 'ro', isa => 'Num');
# Cached weight/bias parameter symbols, resolved from the params container in BUILD.
has [qw/_iW _iB
        _hW _hB/] => (is => 'rw', init_arg => undef);
1558
1559
sub BUILD
{
    my $self = shift;
    # Fetch (or create) the shared convolution weight/bias parameter symbols.
    $self->_iW($self->_params->get('i2h_weight', init => $self->i2h_weight_initializer));
    $self->_hW($self->_params->get('h2h_weight', init => $self->h2h_weight_initializer));
    # NOTE(review): the condition below reads defined($a and not defined $b),
    # i.e. it tests the defined-ness of the whole boolean expression rather
    # than of forget_bias itself.  It was most likely intended as
    # (defined($self->forget_bias) and not defined($self->i2h_bias_initializer)).
    # Left unchanged to preserve current behaviour -- confirm intent (which
    # initializer ConvLSTMCell should get) before fixing.
    $self->_iB(
        $self->params->get(
            'i2h_bias',
            (defined($self->forget_bias and not defined $self->i2h_bias_initializer)
                ? (init => AI::MXNet::LSTMBias->new(forget_bias => $self->forget_bias))
                : (init => $self->i2h_bias_initializer)
            )
        )
    );
    $self->_hB($self->_params->get('h2h_bias', init => $self->h2h_bias_initializer));
}
1576
method _num_gates()
{
    # One convolution output group per gate-name suffix.
    my @gates = @{ $self->_gate_names() };
    return scalar @gates;
}
1581
method _gate_names()
{
    # A plain convolutional RNN has a single, unsuffixed gate.
    return ['']
}
1586
method _conv_forward($inputs, $states, $name)
{
    # Shared helper: compute the input-to-hidden and hidden-to-hidden
    # convolution pre-activations for one time step.  num_filter is
    # num_hidden * number-of-gates so a single convolution yields all gate
    # pre-activations at once.
    my $i2h = AI::MXNet::Symbol->Convolution(
        name       => "${name}i2h",
        data       => $inputs,
        num_filter => $self->_num_hidden*$self->_num_gates(),
        kernel     => $self->_i2h_kernel,
        stride     => $self->_i2h_stride,
        pad        => $self->_i2h_pad,
        dilate     => $self->_i2h_dilate,
        weight     => $self->_iW,
        bias       => $self->_iB
    );
    # h2h uses stride 1 and the padding derived in BUILD, so the state keeps
    # its spatial size across steps.
    my $h2h = AI::MXNet::Symbol->Convolution(
        name       => "${name}h2h",
        data       => @{ $states }[0],
        num_filter => $self->_num_hidden*$self->_num_gates(),
        kernel     => $self->_h2h_kernel,
        stride     => [1, 1],
        pad        => $self->_h2h_pad,
        dilate     => $self->_h2h_dilate,
        weight     => $self->_hW,
        bias       => $self->_hB
    );
    return ($i2h, $h2h);
}
1613
method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    # One ConvRNN step: output = activation(conv_i2h(input) + conv_h2h(state)).
    $self->_counter($self->_counter + 1);
    my $prefix = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my ($input_part, $state_part) = $self->_conv_forward($inputs, $states, $prefix);
    my $activated = $self->_get_activation($input_part + $state_part, $self->_activation, name => "${prefix}out");
    return ($activated, [$activated]);
}
1622
package AI::MXNet::RNN::ConvLSTMCell;
use Mouse;
extends 'AI::MXNet::RNN::ConvCell';
# Bias the forget gate towards "remember" by default.
has '+forget_bias' => (default => 1);
has '+_prefix'     => (default => 'ConvLSTM_');

=head1 NAME

    AI::MXNet::RNN::ConvLSTMCell
=cut

=head1 DESCRIPTION

    Convolutional LSTM network cell.

    Reference:
        Xingjian et al. NIPS2015
=cut

method _gate_names()
{
    # Input, forget, candidate and output gate suffixes, in slice order.
    return ['_i', '_f', '_c', '_o'];
}
1646
method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    # One ConvLSTM step over $states = [h, c].
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my ($i2h, $h2h) = $self->_conv_forward($inputs, $states, $name);
    my $gates = $i2h + $h2h;
    # Split the joint pre-activation into the four gates along the channel axis.
    my @slice_gates = @{ AI::MXNet::Symbol->SliceChannel(
        $gates,
        num_outputs => 4,
        axis => index($self->_conv_layout, 'C'),
        name => "${name}slice"
    ) };
    my $in_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[0],
        act_type => "sigmoid",
        name => "${name}i"
    );
    my $forget_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[1],
        act_type => "sigmoid",
        name => "${name}f"
    );
    my $in_transform = $self->_get_activation(
        $slice_gates[2],
        $self->_activation,
        name => "${name}c"
    );
    my $out_gate = AI::MXNet::Symbol->Activation(
        $slice_gates[3],
        act_type => "sigmoid",
        name => "${name}o"
    );
    # c' = f * c + i * candidate
    my $next_c = AI::MXNet::Symbol->_plus(
        $forget_gate * @{$states}[1],
        $in_gate * $in_transform,
        name => "${name}state"
    );
    # h' = o * act(c')
    my $next_h = AI::MXNet::Symbol->_mul(
        $out_gate, $self->_get_activation($next_c, $self->_activation),
        name => "${name}out"
    );
    return ($next_h, [$next_h, $next_c]);
}
1690
package AI::MXNet::RNN::ConvGRUCell;
use Mouse;
extends 'AI::MXNet::RNN::ConvCell';
has '+_prefix'     => (default => 'ConvGRU_');

=head1 NAME

    AI::MXNet::RNN::ConvGRUCell
=cut

=head1 DESCRIPTION

    Convolutional GRU network cell.
=cut

method _gate_names()
{
    # Reset, update and candidate gate suffixes, in slice order.
    return ['_r', '_z', '_o'];
}
1710
method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    # One ConvGRU step over $states = [h].
    $self->_counter($self->_counter + 1);
    my $name = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my ($i2h, $h2h) = $self->_conv_forward($inputs, $states, $name);
    my ($i2h_r, $i2h_z, $h2h_r, $h2h_z);
    # Split both pre-activations into reset, update and candidate parts.
    ($i2h_r, $i2h_z, $i2h) = @{ AI::MXNet::Symbol->SliceChannel($i2h, num_outputs => 3, name => "${name}_i2h_slice") };
    ($h2h_r, $h2h_z, $h2h) = @{ AI::MXNet::Symbol->SliceChannel($h2h, num_outputs => 3, name => "${name}_h2h_slice") };
    my $reset_gate = AI::MXNet::Symbol->Activation(
        $i2h_r + $h2h_r, act_type => "sigmoid",
        name => "${name}_r_act"
    );
    my $update_gate = AI::MXNet::Symbol->Activation(
        $i2h_z + $h2h_z, act_type => "sigmoid",
        name => "${name}_z_act"
    );
    # candidate = act(i2h + r * h2h)
    my $next_h_tmp = $self->_get_activation($i2h + $reset_gate * $h2h, $self->_activation, name => "${name}_h_act");
    # h' = (1 - z) * candidate + z * h
    my $next_h = AI::MXNet::Symbol->_plus(
        (1 - $update_gate) * $next_h_tmp, $update_gate * @{$states}[0],
        name => "${name}out"
    );
    return ($next_h, [$next_h]);
}
1734
package AI::MXNet::RNN::ModifierCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::ModifierCell
=cut

=head1 DESCRIPTION

    Base class for modifier cells. A modifier
    cell takes a base cell, apply modifications
    on it (e.g. Dropout), and returns a new cell.

    After applying modifiers the base cell should
    no longer be called directly. The modifier cell
    should be used instead.
=cut

# The wrapped cell; BUILD marks it 'modified' so it rejects direct calls.
has 'base_cell' => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
1757
around BUILDARGS => sub {
    my ($orig, $class, @args) = @_;
    # An odd argument count means the first positional argument is the
    # wrapped base cell; the remainder are key/value pairs.
    if(@args % 2)
    {
        my $base_cell = shift @args;
        return $class->$orig(base_cell => $base_cell, @args);
    }
    return $class->$orig(@args);
};
1768
sub BUILD
{
    # Once wrapped, the base cell may no longer be driven directly.
    my $self = shift;
    $self->base_cell->_modified(1);
}
1774
method params()
{
    # Expose the wrapped cell's parameter container; this cell owns none itself.
    $self->_own_params(0);
    return $self->base_cell->params;
}
1780
method state_info()
{
    # State layout is entirely determined by the wrapped cell.
    my $inner = $self->base_cell;
    return $inner->state_info;
}
1785
method begin_state(CodeRef :$init_sym=AI::MXNet::Symbol->can('zeros'), @kwargs)
{
    # Temporarily lift the wrapped cell's 'modified' flag so it will answer
    # begin_state, then restore the flag.
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    $self->base_cell->_modified(0);
    my $begin_state = $self->base_cell->begin_state(func => $init_sym, @kwargs);
    $self->base_cell->_modified(1);
    return $begin_state;
}
1798
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Weight unpacking is handled entirely by the wrapped cell.
    my $inner = $self->base_cell;
    return $inner->unpack_weights($args);
}
1803
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    # Weight packing is handled entirely by the wrapped cell.
    my $inner = $self->base_cell;
    return $inner->pack_weights($args);
}
1808
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # Abstract: concrete modifier cells implement the modified step.
    confess("Not Implemented");
}
1813
package AI::MXNet::RNN::DropoutCell;
use Mouse;
extends 'AI::MXNet::RNN::ModifierCell';
# dropout_outputs / dropout_states : Num - dropout probabilities applied to
#     the cell output and to each state symbol respectively (0 disables each).
has [qw/dropout_outputs dropout_states/] => (is => 'ro', isa => 'Num', default => 0);

=head1 NAME

    AI::MXNet::RNN::DropoutCell
=cut

=head1 DESCRIPTION

    Apply the dropout on base cell
=cut
1828
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # Step the wrapped cell, then optionally apply dropout to its output
    # and/or to every returned state symbol.
    my ($output, $next_states) = $self->base_cell->($inputs, $states);
    $output = AI::MXNet::Symbol->Dropout(data => $output, p => $self->dropout_outputs)
        if $self->dropout_outputs > 0;
    if($self->dropout_states > 0)
    {
        $next_states = [
            map { AI::MXNet::Symbol->Dropout(data => $_, p => $self->dropout_states) } @{ $next_states }
        ];
    }
    return ($output, $next_states);
}
1842
package AI::MXNet::RNN::ZoneoutCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::ModifierCell';
# zoneout_outputs / zoneout_states : Num - probabilities of keeping the
#     previous output / previous states instead of the new ones (0 disables each).
has [qw/zoneout_outputs zoneout_states/] => (is => 'ro', isa => 'Num', default => 0);
# Cached output of the previous step; cleared by reset().
has 'prev_output' => (is => 'rw', init_arg => undef);

=head1 NAME

    AI::MXNet::RNN::ZoneoutCell
=cut

=head1 DESCRIPTION

    Apply Zoneout on base cell.
=cut
1859
sub BUILD
{
    my $self = shift;
    # Zoneout works by single-stepping; reject base cells that cannot be
    # stepped or that hide a bidirectional sub-graph.
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::FusedCell')),
        "FusedRNNCell doesn't support zoneout. ".
        "Please unfuse first."
    );
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::BidirectionalCell')),
        "BidirectionalCell doesn't support zoneout since it doesn't support step. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
    # The original read $self->_bidirectional, but ZoneoutCell declares no
    # such attribute -- the flag (when present) lives on the wrapped cell.
    # Guard with can() since not every cell type carries it.
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::SequentialCell')
            or not ($self->base_cell->can('_bidirectional') and $self->base_cell->_bidirectional)),
        "Bidirectional SequentialCell doesn't support zoneout. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
}
1879
method reset()
{
    # Clear the inherited step state plus the cached previous-step output.
    $self->SUPER::reset;
    $self->prev_output(undef);
}
1885
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # One zoneout step: elementwise, with probability p keep the previous
    # output/state, otherwise take the freshly computed one.  Implemented as
    # where(dropout(ones), new, old) so the "keep" pattern is random per unit.
    my ($cell, $p_outputs, $p_states) = ($self->base_cell, $self->zoneout_outputs, $self->zoneout_states);
    my ($next_output, $next_states) = $cell->($inputs, $states);
    my $mask = sub {
        my ($p, $like) = @_;
        AI::MXNet::Symbol->Dropout(
            AI::MXNet::Symbol->ones_like(
                $like
            ),
            p => $p
        );
    };
    # First step has no previous output; use a shape-inferable zero tensor.
    my $prev_output = $self->prev_output // AI::MXNet::Symbol->zeros(shape => [0, 0]);
    my $output = $p_outputs != 0
        ? AI::MXNet::Symbol->where(
            $mask->($p_outputs, $next_output),
            $next_output,
            $prev_output
        )
        : $next_output;
    my @states;
    if($p_states != 0)
    {
        # Blend each new state with the corresponding incoming state.
        for(zip($next_states, $states)) {
            my ($new_s, $old_s) = @$_;
            push @states, AI::MXNet::Symbol->where(
                $mask->($p_states, $new_s),
                $new_s,
                $old_s
            );
        }
    }
    $self->prev_output($output);
    return ($output, @states ? \@states : $next_states);
}
1922
# Residual (skip-connection) modifier: output = base_cell(input) + input.
package AI::MXNet::RNN::ResidualCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::ModifierCell';

=head1 NAME

    AI::MXNet::RNN::ResidualCell
=cut

=head1 DESCRIPTION

    Adds residual connection as described in Wu et al, 2016
    (https://arxiv.org/abs/1609.08144).
    Output of the cell is output of the base cell plus input.
=cut
1939
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # Step the wrapped cell and add the input back onto its output
    # (residual/skip connection).
    my ($base_out, $next_states) = $self->base_cell->($inputs, $states);
    my $residual = AI::MXNet::Symbol->elemwise_add($base_out, $inputs, name => $base_out->name.'_plus_residual');
    return ($residual, $next_states)
}
1947
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    # Unroll the wrapped cell over the whole sequence, then add the inputs
    # back onto the outputs, matching whichever representation (merged
    # symbol vs per-step list) the base cell produced.
    $self->reset;
    $self->base_cell->_modified(0);
    my ($outputs, $states) = $self->base_cell->unroll($length, inputs=>$inputs, begin_state=>$begin_state,
                                                layout=>$layout, merge_outputs=>$merge_outputs);
    $self->base_cell->_modified(1);
    # Follow the representation the base cell actually returned.
    $merge_outputs //= (blessed($outputs) and $outputs->isa('AI::MXNet::Symbol'));
    # Normalize the inputs to the same representation as the outputs.
    ($inputs) = _normalize_sequence($length, $inputs, $layout, $merge_outputs);
    if($merge_outputs)
    {
        $outputs = AI::MXNet::Symbol->elemwise_add($outputs, $inputs, name => $outputs->name . "_plus_residual");
    }
    else
    {
        my @temp;
        for(zip([@{ $outputs }], [@{ $inputs }])) {
            my ($output_sym, $input_sym) = @$_;
            push @temp, AI::MXNet::Symbol->elemwise_add($output_sym, $input_sym,
                            name=>$output_sym->name."_plus_residual");
        }
        $outputs = \@temp;
    }
    return ($outputs, $states);
}
1980
# Normalize an unrolled sequence's inputs into the requested form.
#
# Parameters
# ----------
# $length    : Int
#     expected number of time steps
# $inputs    : AI::MXNet::Symbol or ArrayRef[AI::MXNet::Symbol]
#     either one symbol merged along the time axis, or one symbol per step
# $layout    : Str
#     target layout string; the position of 'T' marks the time axis
# $merge     : Bool
#     true  -> return a single Symbol merged along the time axis
#     false -> return an array ref of per-step Symbols
# $in_layout : Maybe[Str]
#     layout $inputs currently uses; defaults to $layout
#
# Returns ($inputs, $axis) where $axis is the time-axis index in $layout.
func _normalize_sequence($length, $inputs, $layout, $merge, $in_layout=)
{
    assert((defined $inputs),
        "unroll(inputs=>undef) has been deprecated. ".
        "Please create input variables outside unroll."
    );

    # Time-axis positions in the target and source layouts.
    my $axis = index($layout, 'T');
    my $in_axis = defined $in_layout ? index($in_layout, 'T') : $axis;
    if(blessed($inputs))
    {
        # $inputs is a single (merged) Symbol.
        # NOTE(review): the Python original tests `merge is False`, which
        # leaves the symbol untouched when merge is None; here an undefined
        # $merge also triggers the split. Callers in this file pass a defined
        # flag, but confirm before relying on undef here.
        if(not $merge)
        {
            assert(
                (@{ $inputs->list_outputs() } == 1),
                "unroll doesn't allow grouped symbol as input. Please "
                ."convert to list first or let unroll handle splitting"
            );
            # Split the merged symbol into $length per-step symbols along the
            # source time axis; squeeze_axis drops that axis from each slice.
            $inputs = [ @{ AI::MXNet::Symbol->split(
                $inputs,
                axis         => $in_axis,
                num_outputs  => $length,
                squeeze_axis => 1
            ) }];
        }
    }
    else
    {
        # $inputs is an array ref of per-step symbols.
        assert(not defined $length or @$inputs == $length);
        if($merge)
        {
            # Re-merge: restore a time axis on each step, then concatenate
            # all steps along it.
            $inputs = [map { AI::MXNet::Symbol->expand_dims($_, axis=>$axis) } @{ $inputs }];
            $inputs = AI::MXNet::Symbol->Concat(@{ $inputs }, dim=>$axis);
            $in_axis = $axis;
        }
    }

    # A still-merged symbol whose time axis differs from the requested layout
    # needs its axes swapped into place.
    if(blessed($inputs) and $axis != $in_axis)
    {
        $inputs = AI::MXNet::Symbol->swapaxes($inputs, dim0=>$axis, dim1=>$in_axis);
    }
    return ($inputs, $axis);
}
2024
20251;
2026