1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18use strict;
19use warnings;
20package AI::MXNet::Gluon::RNN::RecurrentCell;
21use Mouse::Role;
22use AI::MXNet::Base;
23use AI::MXNet::Function::Parameters;
24
method _cells_state_info($cells, $batch_size)
{
    # Flatten the state_info specs of every child cell into one array ref.
    my @flattened;
    for my $cell ($cells->values)
    {
        push @flattened, @{ $cell->state_info($batch_size) };
    }
    return \@flattened;
}
29
method _cells_begin_state($cells, %kwargs)
{
    # Collect each child cell's initial states into one flat array ref.
    my @states;
    for my $cell ($cells->values)
    {
        push @states, @{ $cell->begin_state(%kwargs) };
    }
    return \@states;
}
34
method _get_begin_state(GluonClass $F, $begin_state, GluonInput $inputs, $batch_size)
{
    # Returns $begin_state unchanged when the caller supplied one; otherwise
    # builds zero-initialized starting states via $self->begin_state().
    if(not defined $begin_state)
    {
        if($F =~ /AI::MXNet::NDArray/)
        {
            # NDArray mode: allocate states on the same device as the input
            # ($inputs is either a single NDArray or an array ref of them).
            my $ctx = blessed $inputs ? $inputs->context : $inputs->[0]->context;
            {
                # Temporarily make $ctx the default context so zeros() allocates there.
                local($AI::MXNet::current_ctx) = $ctx;
                my $func = sub {
                    my %kwargs = @_;
                    # 'shape' is passed positionally to zeros(); remaining kwargs
                    # (e.g. name, dtype) are forwarded as-is.
                    my $shape = delete $kwargs{shape};
                    return AI::MXNet::NDArray->zeros($shape, %kwargs);
                };
                $begin_state = $self->begin_state(batch_size => $batch_size, func => $func);
            }
        }
        else
        {
            # Symbol mode: delegate to the symbolic zeros operator.
            $begin_state = $self->begin_state(batch_size => $batch_size, func => sub { return $F->zeros(@_) });
        }
    }
    return $begin_state;
}
59
60
method _format_sequence($length, $inputs, $layout, $merge, $in_layout=)
{
    # Normalizes $inputs to the representation requested by $merge (single
    # stacked tensor/symbol vs. array ref of per-step pieces) and returns
    # ($inputs, $axis, $F, $batch_size), where $axis is the time axis in the
    # target $layout, $F the operating class (NDArray or Symbol) and
    # $batch_size the 'N' dimension size (0 when it cannot be read, e.g. Symbol).
    assert(
        (defined $inputs),
        "unroll(inputs=None) has been deprecated. ".
        "Please create input variables outside unroll."
    );

    # Positions of the time ('T') and batch ('N') dimensions in the target layout.
    my $axis = index($layout, 'T');
    my $batch_axis = index($layout, 'N');
    my $batch_size = 0;
    # Time axis of the incoming data; defaults to the target axis when no
    # separate input layout is given.
    my $in_axis = defined $in_layout ? index($in_layout, 'T') : $axis;
    my $F;
    if(blessed $inputs and $inputs->isa('AI::MXNet::Symbol'))
    {
        $F = 'AI::MXNet::Symbol';
        if(not $merge)
        {
            # Caller wants per-step symbols: split the single symbol along time.
            assert(
                (@{ $inputs->list_outputs() } == 1),
                "unroll doesn't allow grouped symbol as input. Please convert ".
                "to list with list(inputs) first or let unroll handle splitting"
            );
            $inputs = [
                AI::MXNet::Symbol->split(
                    $inputs, axis => $in_axis, num_outputs => $length, squeeze_axis => 1
                )
            ];
        }
    }
    elsif(blessed $inputs and $inputs->isa('AI::MXNet::NDArray'))
    {
        $F = 'AI::MXNet::NDArray';
        $batch_size = $inputs->shape->[$batch_axis];
        if(not $merge)
        {
            # Split along the time axis into an array ref of per-step NDArrays;
            # the time dimension is squeezed away from each piece.
            assert(not defined $length or $length == $inputs->shape->[$in_axis]);
            $inputs = as_array(
                AI::MXNet::NDArray->split(
                    $inputs, axis=>$in_axis,
                    num_outputs => $inputs->shape->[$in_axis],
                    squeeze_axis => 1
                )
            );
        }
    }
    else
    {
        # $inputs is already an array ref of per-step tensors/symbols.
        assert(not defined $length or @{ $inputs } == $length);
        if($inputs->[0]->isa('AI::MXNet::Symbol'))
        {
            $F = 'AI::MXNet::Symbol';
        }
        else
        {
            $F = 'AI::MXNet::NDArray';
            $batch_size = $inputs->[0]->shape->[$batch_axis];
        }
        if($merge)
        {
            # Re-stack the per-step pieces into one tensor along the time axis.
            $inputs  = [map { $F->expand_dims($_, axis => $axis) } @{ $inputs }];
            $inputs  = $F->stack(@{ $inputs }, axis => $axis);
            $in_axis = $axis;
        }
    }
    # Merged output whose time axis differs from the target layout: swap axes.
    if(blessed $inputs and $axis != $in_axis)
    {
        $inputs = $F->swapaxes($inputs, dim1=>$axis, dim2=>$in_axis);
    }
    return ($inputs, $axis, $F, $batch_size);
}
132
method _mask_sequence_variable_length($F, $data, $length, $valid_length, $time_axis, $merge)
{
    # Zeros out the time steps of $data beyond each sample's $valid_length,
    # optionally splitting the result back into per-step pieces when !$merge.
    assert(defined $valid_length);
    if(not blessed $data)
    {
        # Array ref of per-step tensors: stack into one tensor along time first.
        $data = $F->stack(@$data, axis=>$time_axis);
    }
    my $outputs = $F->SequenceMask($data, { sequence_length=>$valid_length, use_sequence_length=>1,
                             axis=>$time_axis});
    if(not $merge)
    {
        # Caller wants a list again: split along time and squeeze that axis.
        $outputs = $F->split($outputs, { num_outputs=>$length, axis=>$time_axis,
                                   squeeze_axis=>1});
        if(not ref $outputs eq 'ARRAY')
        {
            # split() may return a single element rather than an array ref
            # (e.g. when $length == 1); normalize to an array ref.
            $outputs = [$outputs];
        }
    }
    return $outputs;
}
153
method _reverse_sequences($sequences, $unroll_step, $valid_length=)
{
    # Reverses a list of per-step tensors along time. Without $valid_length a
    # plain Perl reverse suffices; with it, SequenceReverse reverses only each
    # sample's valid prefix, leaving padding in place.
    my $F;
    if($sequences->[0]->isa('AI::MXNet::Symbol'))
    {
        $F = 'AI::MXNet::Symbol';
    }
    else
    {
        $F = 'AI::MXNet::NDArray';
    }

    my $reversed_sequences;
    if(not defined $valid_length)
    {
        $reversed_sequences = [reverse(@$sequences)];
    }
    else
    {
        # Stack on a new leading time axis, reverse per-sample, then split
        # back into $unroll_step per-step pieces.
        $reversed_sequences = $F->SequenceReverse($F->stack(@$sequences, axis=>0),
                                               {sequence_length=>$valid_length,
                                               use_sequence_length=>1});
        $reversed_sequences = $F->split($reversed_sequences, {axis=>0, num_outputs=>$unroll_step, squeeze_axis=>1});
    }
    return $reversed_sequences;
}
180
181=head1 NAME
182
183    AI::MXNet::Gluon::RNN::RecurrentCell
184=cut
185
186=head1 DESCRIPTION
187
188    Abstract role for RNN cells
189
190    Parameters
191    ----------
192    prefix : str, optional
193        Prefix for names of `Block`s
194        (this prefix is also used for names of weights if `params` is `None`
195        i.e. if `params` are being created and not reused)
196    params : Parameter or None, optional
197        Container for weight sharing between cells.
198        A new Parameter container is created if `params` is `None`.
199=cut
200
201=head2 reset
202
203    Reset before re-using the cell for another graph.
204=cut
205
method reset()
{
    # Rewind the step and state-init counters, then recursively reset
    # every child cell so the cell can be unrolled into a fresh graph.
    $self->counter(-1);
    $self->init_counter(-1);
    for my $child ($self->_children->values)
    {
        $child->reset;
    }
}
212
213=head2 state_info
214
215    Shape and layout information of states
216=cut
method state_info(Int $batch_size=0)
{
    # Abstract: concrete cells return an array ref of hash refs, each holding
    # 'shape' and '__layout__' for one state tensor.
    confess('Not Implemented');
}
221
222=head2 begin_state
223
224        Initial state for this cell.
225
226        Parameters
227        ----------
        $func : CodeRef, default creates zero NDArrays via AI::MXNet::NDArray->zeros
229            Function for creating initial state.
230
231            For Symbol API, func can be `symbol.zeros`, `symbol.uniform`,
232            `symbol.var etc`. Use `symbol.var` if you want to directly
233            feed input as states.
234
235            For NDArray API, func can be `ndarray.zeros`, `ndarray.ones`, etc.
236        $batch_size: int, default 0
237            Only required for NDArray API. Size of the batch ('N' in layout)
238            dimension of input.
239
240        %kwargs :
241            Additional keyword arguments passed to func. For example
242            `mean`, `std`, `dtype`, etc.
243
244        Returns
245        -------
246        states : nested array ref of Symbol
247            Starting states for the first RNN step.
248=cut
249
method begin_state(Int :$batch_size=0, CodeRef :$func=, %kwargs)
{
    # Default state factory: zero-filled NDArrays. 'shape' is popped from the
    # kwargs and passed positionally; everything else is forwarded to zeros().
    $func //= sub {
        my %kwargs = @_;
        my $shape = delete $kwargs{shape};
        return AI::MXNet::NDArray->zeros($shape, %kwargs);
    };
    assert(
        (not $self->modified),
        "After applying modifier cells (e.g. ZoneoutCell) the base ".
        "cell cannot be called directly. Call the modifier cell instead."
    );
    my @states;
    # One state per state_info() entry; caller kwargs (e.g. dtype) are merged
    # over the per-state spec, and each state is named after this cell's
    # prefix plus a running init counter.
    for my $info (@{ $self->state_info($batch_size) })
    {
        $self->init_counter($self->init_counter + 1);
        if(defined $info)
        {
            %$info = (%$info, %kwargs);
        }
        else
        {
            $info = \%kwargs;
        }
        my $state = $func->(
            name => "${\ $self->_prefix }begin_state_${\ $self->init_counter }",
            %$info
        );
        push @states, $state;
    }
    return \@states;
}
282
283=head2 unroll
284
285        Unrolls an RNN cell across time steps.
286
287        Parameters
288        ----------
289        $length : int
290            Number of steps to unroll.
291        $inputs : Symbol, list of Symbol, or None
292            If `inputs` is a single Symbol (usually the output
293            of Embedding symbol), it should have shape
294            (batch_size, length, ...) if `layout` is 'NTC',
295            or (length, batch_size, ...) if `layout` is 'TNC'.
296
297            If `inputs` is a list of symbols (usually output of
298            previous unroll), they should all have shape
299            (batch_size, ...).
300        :$begin_state : nested list of Symbol, optional
301            Input states created by `begin_state()`
302            or output state of another cell.
303            Created from `begin_state()` if `None`.
304        :$layout : str, optional
305            `layout` of input symbol. Only used if inputs
306            is a single Symbol.
307        :$merge_outputs : bool, optional
308            If `False`, returns outputs as a list of Symbols.
309            If `True`, concatenates output across time steps
310            and returns a single symbol with shape
311            (batch_size, length, ...) if layout is 'NTC',
312            or (length, batch_size, ...) if layout is 'TNC'.
313            If `None`, output whatever is faster.
314
315        Returns
316        -------
317        outputs : list of Symbol or Symbol
318            Symbol (if `merge_outputs` is True) or list of Symbols
319            (if `merge_outputs` is False) corresponding to the output from
320            the RNN from this unrolling.
321
322        states : list of Symbol
323            The new state of this RNN after this unrolling.
324            The type of this symbol is same as the output of `begin_state()`.
325=cut
326
method unroll(
    Int $length,
    Maybe[GluonInput] $inputs,
    Maybe[GluonInput] :$begin_state=,
    Str :$layout='NTC',
    Maybe[Bool] :$merge_outputs=,
    Maybe[GluonInput] :$valid_length=
)
{
    # Unrolls the cell over $length time steps, threading states step to step.
    # Fix: $valid_length is a per-sample sequence-length tensor (it is fed to
    # SequenceLast/SequenceMask below), so its constraint is Maybe[GluonInput],
    # not Maybe[Bool] as before.
    $self->reset();
    my ($F, $batch_size, $axis);
    ($inputs, $axis, $F, $batch_size) = $self->_format_sequence($length, $inputs, $layout, 0);
    $begin_state //= $self->_get_begin_state($F, $begin_state, $inputs, $batch_size);

    my $states = $begin_state;
    my $outputs = [];
    my $all_states = [];
    for my $i (0..$length-1)
    {
        my $output;
        # One step: the cell is invoked through its overloaded call operator.
        ($output, $states) = $self->($inputs->[$i], $states);
        push @$outputs, $output;
        if(defined $valid_length)
        {
            # Keep every intermediate state so SequenceLast can pick each
            # sample's state at its own last valid step.
            push @$all_states, $states;
        }
    }
    if(defined $valid_length)
    {
        $states = [];
        for(zip(@$all_states))
        {
            push @$states, $F->SequenceLast($F->stack(@$_, axis=>0),
                                     sequence_length=>$valid_length,
                                     use_sequence_length=>1,
                                     axis=>0);
        }
        # Zero out outputs beyond each sample's valid length.
        $outputs = $self->_mask_sequence_variable_length($F, $outputs, $length, $valid_length, $axis, 1);
    }
    ($outputs) = $self->_format_sequence($length, $outputs, $layout, $merge_outputs);
    return ($outputs, $states);
}
369
method _get_activation(GluonClass $F, GluonInput $inputs, Activation $activation, %kwargs)
{
    # Applies $activation to $inputs: a plain string selects either a
    # dedicated operator or the generic Activation op; a LeakyReLU object
    # maps to the LeakyReLU op; any other blessed activation is called directly.
    if(not blessed $activation)
    {
        if(grep { $_ eq $activation } qw(tanh relu sigmoid softsign))
        {
            # These four have dedicated operators on $F.
            return $F->$activation($inputs, %kwargs);
        }
        return $F->Activation($inputs, act_type=>$activation, %kwargs);
    }
    if(ref($activation) =~ /LeakyReLU/)
    {
        return $F->LeakyReLU($inputs, act_type=>'leaky', slope => $activation->alpha, %kwargs);
    }
    return $activation->($inputs, %kwargs);
}
390
391=head2 forward
392
393        Unrolls the recurrent cell for one time step.
394
395        Parameters
396        ----------
397        inputs : sym.Variable
398            Input symbol, 2D, of shape (batch_size * num_units).
399        states : list of sym.Variable
400            RNN state from previous step or the output of begin_state().
401
402        Returns
403        -------
404        output : Symbol
405            Symbol corresponding to the output from the RNN when unrolling
406            for a single time step.
407        states : list of Symbol
408            The new state of this RNN after this unrolling.
409            The type of this symbol is same as the output of `begin_state()`.
410            This can be used as an input state to the next time step
411            of this RNN.
412
413        See Also
414        --------
415        begin_state: This function can provide the states for the first time step.
416        unroll: This function unrolls an RNN for a given number of (>=1) time steps.
417=cut
418
package AI::MXNet::Gluon::RNN::HybridRecurrentCell;
use AI::MXNet::Gluon::Mouse;
extends 'AI::MXNet::Gluon::HybridBlock';
with 'AI::MXNet::Gluon::RNN::RecurrentCell';
# Set by wrapper cells (e.g. ZoneoutCell); begin_state() refuses to run on a
# modified base cell.
has 'modified'      => (is => 'rw', isa => 'Bool', default => 0);
# Per-graph step counter (used in op-name prefixes) and state-creation
# counter; both start at -1 and are rewound by reset().
has [qw/counter
     init_counter/] => (is => 'rw', isa => 'Int', default => -1);

sub BUILD
{
    my $self = shift;
    $self->reset;
}

# Stringification, e.g. "RNNCell(10 -> 20, tanh)"; input size is omitted
# when it is 0 (not yet inferred).
use overload '""' => sub {
    my $self = shift;
    my $s = '%s(%s';
    if($self->can('activation'))
    {
        $s .= ", ${\ $self->activation }";
    }
    $s .= ')';
    my $mapping = $self->input_size ? $self->input_size . " -> " . $self->hidden_size : $self->hidden_size;
    return sprintf($s, $self->_class_name, $mapping);
};

# Single-step forward: bump the step counter, then delegate to
# HybridBlock::forward (which dispatches to hybrid_forward).
method forward(GluonInput $inputs, Maybe[GluonInput|ArrayRef[GluonInput]] $states)
{
    $self->counter($self->counter + 1);
    $self->SUPER::forward($inputs, $states);
}
450
451package AI::MXNet::Gluon::RNN::RNNCell;
452use AI::MXNet::Gluon::Mouse;
453extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';
454
455=head1 NAME
456
457    AI::MXNet::Gluon::RNN::RNNCell
458=cut
459
460=head1 DESCRIPTION
461
462    Simple recurrent neural network cell.
463
464    Parameters
465    ----------
466    hidden_size : int
467        Number of units in output symbol
468    activation : str or Symbol, default 'tanh'
469        Type of activation function.
470    i2h_weight_initializer : str or Initializer
471        Initializer for the input weights matrix, used for the linear
472        transformation of the inputs.
473    h2h_weight_initializer : str or Initializer
474        Initializer for the recurrent weights matrix, used for the linear
475        transformation of the recurrent state.
476    i2h_bias_initializer : str or Initializer
477        Initializer for the bias vector.
478    h2h_bias_initializer : str or Initializer
479        Initializer for the bias vector.
480    prefix : str, default 'rnn_'
481        Prefix for name of `Block`s
482        (and name of weight if params is `None`).
483    params : Parameter or None
484        Container for weight sharing between cells.
485        Created if `None`.
486=cut
487
# Number of units in the output/hidden state.
has 'hidden_size' => (is => 'rw', isa => 'Int', required => 1);
has 'activation'  => (is => 'rw', isa => 'Activation', default => 'tanh');
# Weight initializers; undef defers to params->get's default behavior.
has [qw/
    i2h_weight_initializer
    h2h_weight_initializer
    /]            => (is => 'rw', isa => 'Maybe[Initializer]');
# Bias initializers default to zeros.
has [qw/
    i2h_bias_initializer
    h2h_bias_initializer
    /]            => (is => 'rw', isa => 'Maybe[Initializer]', default => 'zeros');
# 0 means "infer from the first input" via deferred initialization.
has 'input_size'  => (is => 'rw', isa => 'Int', default => 0);
# Parameter objects created in BUILD; not settable from the constructor.
has [qw/
        i2h_weight
        h2h_weight
        i2h_bias
        h2h_bias
    /]            => (is => 'rw', init_arg => undef);

# Positional-argument order accepted by the Python-style constructor.
method python_constructor_arguments()
{
    [qw/
        hidden_size activation
        i2h_weight_initializer h2h_weight_initializer
        i2h_bias_initializer h2h_bias_initializer
        input_size
    /];
}
515
sub BUILD
{
    my $self = shift;
    # Register the four parameters with allow_deferred_init: input_size may
    # still be 0 here, so shapes are completed on the first forward pass.
    $self->i2h_weight($self->params->get(
        'i2h_weight', shape=>[$self->hidden_size, $self->input_size],
        init => $self->i2h_weight_initializer,
        allow_deferred_init => 1
    ));
    $self->h2h_weight($self->params->get(
        'h2h_weight', shape=>[$self->hidden_size, $self->hidden_size],
        init => $self->h2h_weight_initializer,
        allow_deferred_init => 1
    ));
    $self->i2h_bias($self->params->get(
        'i2h_bias', shape=>[$self->hidden_size],
        init => $self->i2h_bias_initializer,
        allow_deferred_init => 1
    ));
    $self->h2h_bias($self->params->get(
        'h2h_bias', shape=>[$self->hidden_size],
        init => $self->h2h_bias_initializer,
        allow_deferred_init => 1
    ));
}
540
method state_info(Int $batch_size=0)
{
    # One recurrent state: the hidden vector h, layout (N, C).
    return [{ shape => [$batch_size, $self->hidden_size], __layout__ => 'NC' }];
}

method _alias() { 'rnn' }
547
method hybrid_forward(
    GluonClass $F, GluonInput $inputs, GluonInput $states,
    GluonInput :$i2h_weight, GluonInput :$h2h_weight, GluonInput :$i2h_bias, GluonInput :$h2h_bias
)
{
    # One step of the simple RNN: h' = act(W_i2h*x + b_i2h + W_h2h*h + b_h2h).
    # Ops carry a per-step name prefix so unrolled graphs get unique names.
    my $prefix = "t${\ $self->counter}_";
    my $i2h = $F->FullyConnected(
        data => $inputs, weight => $i2h_weight, bias => $i2h_bias,
        num_hidden => $self->hidden_size,
        name => "${prefix}i2h"
    );
    my $h2h = $F->FullyConnected(
        data => $states->[0], weight => $h2h_weight, bias => $h2h_bias,
        num_hidden => $self->hidden_size,
        name => "${prefix}h2h"
    );
    my $i2h_plus_h2h = $F->elemwise_add($i2h, $h2h, name => "${prefix}plus0");
    my $output = $self->_get_activation($F, $i2h_plus_h2h, $self->activation, name => "${prefix}out");
    # The output doubles as the sole recurrent state.
    return ($output, [$output]);
}
568
569__PACKAGE__->register('AI::MXNet::Gluon::RNN');
570
571package AI::MXNet::Gluon::RNN::LSTMCell;
572use AI::MXNet::Gluon::Mouse;
573extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';
574
575=head1 NAME
576
577    AI::MXNet::Gluon::RNN::LSTMCell
578=cut
579
580=head1 DESCRIPTION
581
    Long Short-Term Memory (LSTM) network cell.
583
584    Parameters
585    ----------
586    hidden_size : int
587        Number of units in output symbol.
588    i2h_weight_initializer : str or Initializer
589        Initializer for the input weights matrix, used for the linear
590        transformation of the inputs.
591    h2h_weight_initializer : str or Initializer
592        Initializer for the recurrent weights matrix, used for the linear
593        transformation of the recurrent state.
594    i2h_bias_initializer : str or Initializer, default 'lstmbias'
595        Initializer for the bias vector. By default, bias for the forget
596        gate is initialized to 1 while all other biases are initialized
597        to zero.
598    h2h_bias_initializer : str or Initializer
599        Initializer for the bias vector.
600    prefix : str, default 'lstm_'
601        Prefix for name of `Block`s
602        (and name of weight if params is `None`).
603    params : Parameter or None
604        Container for weight sharing between cells.
605        Created if `None`.
606=cut
607
# Number of units in the output/hidden state (each gate uses this many units).
has 'hidden_size' => (is => 'rw', isa => 'Int', required => 1);
# Weight initializers; undef defers to params->get's default behavior.
has [qw/
    i2h_weight_initializer
    h2h_weight_initializer
    /]            => (is => 'rw', isa => 'Maybe[Initializer]');
# Bias initializers default to zeros.
has [qw/
    i2h_bias_initializer
    h2h_bias_initializer
    /]            => (is => 'rw', isa => 'Maybe[Initializer]', default => 'zeros');
# 0 means "infer from the first input" via deferred initialization.
has 'input_size'  => (is => 'rw', isa => 'Int', default => 0);
# Parameter objects created in BUILD; not settable from the constructor.
has [qw/
        i2h_weight
        h2h_weight
        i2h_bias
        h2h_bias
    /]            => (is => 'rw', init_arg => undef);

# Positional-argument order accepted by the Python-style constructor.
method python_constructor_arguments()
{
    [qw/
        hidden_size
        i2h_weight_initializer h2h_weight_initializer
        i2h_bias_initializer h2h_bias_initializer
        input_size
    /];
}
634
635
sub BUILD
{
    my $self = shift;
    # All four gates are packed into single parameters, hence the 4x factor.
    # allow_deferred_init lets input_size be inferred on first forward.
    $self->i2h_weight($self->params->get(
        'i2h_weight', shape=>[4*$self->hidden_size, $self->input_size],
        init => $self->i2h_weight_initializer,
        allow_deferred_init => 1
    ));
    $self->h2h_weight($self->params->get(
        'h2h_weight', shape=>[4*$self->hidden_size, $self->hidden_size],
        init => $self->h2h_weight_initializer,
        allow_deferred_init => 1
    ));
    $self->i2h_bias($self->params->get(
        'i2h_bias', shape=>[4*$self->hidden_size],
        init => $self->i2h_bias_initializer,
        allow_deferred_init => 1
    ));
    $self->h2h_bias($self->params->get(
        'h2h_bias', shape=>[4*$self->hidden_size],
        init => $self->h2h_bias_initializer,
        allow_deferred_init => 1
    ));
}
660
method state_info(Int $batch_size=0)
{
    # Two recurrent states, both (N, C): hidden state h, then cell state c
    # (matching the [h, c] order used by hybrid_forward).
    return [
        { shape => [$batch_size, $self->hidden_size], __layout__ => 'NC' },
        { shape => [$batch_size, $self->hidden_size], __layout__ => 'NC' }
    ];
}

method _alias() { 'lstm' }
670
method hybrid_forward(
    GluonClass $F, GluonInput $inputs, GluonInput $states,
    GluonInput :$i2h_weight, GluonInput :$h2h_weight, GluonInput :$i2h_bias, GluonInput :$h2h_bias
)
{
    # One LSTM step. $states is [h, c]; returns (h', [h', c']).
    # Ops carry a per-step name prefix so unrolled graphs get unique names.
    my $prefix = "t${\ $self->counter}_";
    # Joint linear transforms compute all four gates at once (4*hidden units).
    my $i2h = $F->FullyConnected(
        $inputs, $i2h_weight, $i2h_bias,
        num_hidden => $self->hidden_size*4,
        name => "${prefix}i2h"
    );
    my $h2h = $F->FullyConnected(
        $states->[0], $h2h_weight, $h2h_bias,
        num_hidden => $self->hidden_size*4,
        name => "${prefix}h2h"
    );
    my $gates = $F->elemwise_add($i2h, $h2h, name => "${prefix}plus0");
    # Slice order: input gate, forget gate, cell candidate, output gate.
    my @slice_gates = @{ $F->SliceChannel($gates, num_outputs => 4, name => "${prefix}slice") };
    my $in_gate = $F->Activation($slice_gates[0], act_type=>"sigmoid", name => "${prefix}i");
    my $forget_gate = $F->Activation($slice_gates[1], act_type=>"sigmoid", name => "${prefix}f");
    my $in_transform = $F->Activation($slice_gates[2], act_type=>"tanh", name => "${prefix}c");
    my $out_gate = $F->Activation($slice_gates[3], act_type=>"sigmoid", name => "${prefix}o");
    # c' = f * c + i * c~
    my $next_c = $F->_plus(
        $F->elemwise_mul($forget_gate, $states->[1], name => "${prefix}mul0"),
        $F->elemwise_mul($in_gate, $in_transform, name => "${prefix}mul1"),
        name => "${prefix}state"
    );
    # h' = o * tanh(c')
    my $next_h = $F->_mul($out_gate, $F->Activation($next_c, act_type=>"tanh", name => "${prefix}activation0"), name => "${prefix}out");
    return ($next_h, [$next_h, $next_c]);
}
701
702__PACKAGE__->register('AI::MXNet::Gluon::RNN');
703
704package AI::MXNet::Gluon::RNN::GRUCell;
705use AI::MXNet::Gluon::Mouse;
706extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';
707
708=head1 NAME
709
710    AI::MXNet::Gluon::RNN::GRUCell
711=cut
712
713=head1 DESCRIPTION
714
    Gated Recurrent Unit (GRU) network cell.
716    Note: this is an implementation of the cuDNN version of GRUs
717    (slight modification compared to Cho et al. 2014).
718
719    Parameters
720    ----------
721    hidden_size : int
722        Number of units in output symbol.
723    i2h_weight_initializer : str or Initializer
724        Initializer for the input weights matrix, used for the linear
725        transformation of the inputs.
726    h2h_weight_initializer : str or Initializer
727        Initializer for the recurrent weights matrix, used for the linear
728        transformation of the recurrent state.
729    i2h_bias_initializer : str or Initializer
730        Initializer for the bias vector.
731    h2h_bias_initializer : str or Initializer
732        Initializer for the bias vector.
733    prefix : str, default 'gru_'
734        prefix for name of `Block`s
735        (and name of weight if params is `None`).
736    params : Parameter or None
737        Container for weight sharing between cells.
738        Created if `None`.
739=cut
740
# Number of units in the output/hidden state (each gate uses this many units).
has 'hidden_size' => (is => 'rw', isa => 'Int', required => 1);
# Weight initializers; undef defers to params->get's default behavior.
has [qw/
    i2h_weight_initializer
    h2h_weight_initializer
    /]            => (is => 'rw', isa => 'Maybe[Initializer]');
# Bias initializers default to zeros.
has [qw/
    i2h_bias_initializer
    h2h_bias_initializer
    /]            => (is => 'rw', isa => 'Maybe[Initializer]', default => 'zeros');
# 0 means "infer from the first input" via deferred initialization.
has 'input_size'  => (is => 'rw', isa => 'Int', default => 0);
# Parameter objects created in BUILD; not settable from the constructor.
has [qw/
        i2h_weight
        h2h_weight
        i2h_bias
        h2h_bias
    /]            => (is => 'rw', init_arg => undef);

# Positional-argument order accepted by the Python-style constructor.
method python_constructor_arguments()
{
    [qw/
        hidden_size
        i2h_weight_initializer h2h_weight_initializer
        i2h_bias_initializer h2h_bias_initializer
        input_size
    /];
}
767
sub BUILD
{
    my $self = shift;
    # The reset/update/candidate transforms are packed into single parameters,
    # hence the 3x factor. allow_deferred_init lets input_size be inferred.
    $self->i2h_weight($self->params->get(
        'i2h_weight', shape=>[3*$self->hidden_size, $self->input_size],
        init => $self->i2h_weight_initializer,
        allow_deferred_init => 1
    ));
    $self->h2h_weight($self->params->get(
        'h2h_weight', shape=>[3*$self->hidden_size, $self->hidden_size],
        init => $self->h2h_weight_initializer,
        allow_deferred_init => 1
    ));
    $self->i2h_bias($self->params->get(
        'i2h_bias', shape=>[3*$self->hidden_size],
        init => $self->i2h_bias_initializer,
        allow_deferred_init => 1
    ));
    $self->h2h_bias($self->params->get(
        'h2h_bias', shape=>[3*$self->hidden_size],
        init => $self->h2h_bias_initializer,
        allow_deferred_init => 1
    ));
}
792
method state_info(Int $batch_size=0)
{
    # One recurrent state: the hidden vector h, layout (N, C).
    return [{ shape => [$batch_size, $self->hidden_size], __layout__ => 'NC' }];
}

method _alias() { 'gru' }
799
method hybrid_forward(
    GluonClass $F, GluonInput $inputs, GluonInput $states,
    GluonInput :$i2h_weight, GluonInput :$h2h_weight, GluonInput :$i2h_bias, GluonInput :$h2h_bias
)
{
    # One GRU step (cuDNN variant: the reset gate multiplies the already
    # linearly-transformed h2h candidate). Returns (h', [h']).
    my $prefix = "t${\ $self->counter}_";
    my $prev_state_h = $states->[0];
    # Joint linear transforms compute reset/update/candidate at once (3*hidden).
    my $i2h = $F->FullyConnected(
        $inputs, $i2h_weight, $i2h_bias,
        num_hidden => $self->hidden_size*3,
        name => "${prefix}i2h"
    );
    my $h2h = $F->FullyConnected(
        $states->[0], $h2h_weight, $h2h_bias,
        num_hidden => $self->hidden_size*3,
        name => "${prefix}h2h"
    );
    my ($i2h_r, $i2h_z, $h2h_r, $h2h_z);
    # Slice order: reset, update, candidate; $i2h/$h2h are reused for the
    # candidate part after slicing.
    ($i2h_r, $i2h_z, $i2h) = @{ $F->SliceChannel($i2h, num_outputs => 3, name => "${prefix}i2h_slice") };
    ($h2h_r, $h2h_z, $h2h) = @{ $F->SliceChannel($h2h, num_outputs => 3, name => "${prefix}h2h_slice") };
    my $reset_gate  = $F->Activation($F->elemwise_add($i2h_r, $h2h_r, name => "${prefix}plus0"), act_type=>"sigmoid", name => "${prefix}r_act");
    my $update_gate = $F->Activation($F->elemwise_add($i2h_z, $h2h_z, name => "${prefix}plus1"), act_type=>"sigmoid", name => "${prefix}z_act");
    # h~ = tanh(i2h_candidate + r * h2h_candidate)
    my $next_h_tmp = $F->Activation(
        $F->elemwise_add(
            $i2h,
            $F->elemwise_mul(
                $reset_gate, $h2h, name => "${prefix}mul0"
            ),
            name => "${prefix}plus2"
        ),
        act_type => "tanh",
        name => "${prefix}h_act"
    );
    # h' = (1 - z) * h~ + z * h_prev
    my $ones = $F->ones_like($update_gate, name => "${prefix}ones_like0");
    my $next_h = $F->_plus(
        $F->elemwise_mul(
            $F->elemwise_sub($ones, $update_gate, name => "${prefix}minus0"),
            $next_h_tmp,
            name => "${prefix}mul1"
        ),
        $F->elemwise_mul($update_gate, $prev_state_h, name => "${prefix}mul2"),
        name => "${prefix}out"
    );
    return ($next_h, [$next_h]);
}
845
846__PACKAGE__->register('AI::MXNet::Gluon::RNN');
847
package AI::MXNet::Gluon::RNN::SequentialRNNCell;
use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
# NOTE(review): presumably silences redefinition warnings from composing the
# role over Block-provided methods — confirm before removing.
no warnings 'redefine';
extends 'AI::MXNet::Gluon::Block';
with 'AI::MXNet::Gluon::RNN::RecurrentCell';
# Set by wrapper cells; begin_state() refuses to run on a modified cell.
has 'modified'      => (is => 'rw', isa => 'Bool', default => 0);
# Step counter and state-creation counter; both start at -1, rewound by reset().
has [qw/counter
     init_counter/] => (is => 'rw', isa => 'Int', default => -1);

sub BUILD
{
    my $self = shift;
    $self->reset;
}
863
864=head1 NAME
865
866    AI::MXNet::Gluon::RNN::SequentialRNNCell
867=cut
868
869=head1 DESCRIPTION
870
871    Sequentially stacking multiple RNN cells.
872=cut
873
874=head2 add
875
876    Appends a cell into the stack.
877
878    Parameters
879    ----------
880        cell : rnn cell
881=cut
882
method add(AI::MXNet::Gluon::Block $cell)
{
    # Register the cell as a child block so it is stacked and its parameters
    # are collected.
    $self->register_child($cell);
}

method state_info(Int $batch_size=0)
{
    # Flattened concatenation of every stacked cell's state specs.
    return $self->_cells_state_info($self->_children, $batch_size);
}
892
method begin_state(%kwargs)
{
    # Flat list of initial states for all stacked cells, in stacking order.
    assert(
        (not $self->modified),
        "After applying modifier cells (e.g. ZoneoutCell) the base ".
        "cell cannot be called directly. Call the modifier cell instead."
    );
    return $self->_cells_begin_state($self->_children, %kwargs);
}
902
method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=, Str :$layout='NTC', Maybe[Bool] :$merge_outputs=)
{
    # Unrolls each stacked cell over the full sequence in turn, feeding the
    # previous cell's output sequence into the next one.
    $self->reset();
    my ($F, $batch_size);
    ($inputs, undef, $F, $batch_size) = $self->_format_sequence($length, $inputs, $layout, undef);
    my $num_cells = $self->_children->keys;
    $begin_state = $self->_get_begin_state($F, $begin_state, $inputs, $batch_size);
    my $p = 0;
    my @next_states;
    my $states;
    enumerate(sub {
        my ($i, $cell) = @_;
        # Each cell consumes its own contiguous slice of the flat state list.
        my $n = @{ $cell->state_info() };
        $states = [@{ $begin_state }[$p..$p+$n-1]];
        $p += $n;
        # Only the last cell honours the caller's merge_outputs preference;
        # intermediate cells leave it undef ("whatever is faster").
        ($inputs, $states) = $cell->unroll(
            $length, $inputs, begin_state => $states, layout => $layout,
            merge_outputs => ($i < ($num_cells - 1)) ? undef : $merge_outputs
        );
        push @next_states, @{ $states };
    }, [$self->_children->values]);
    return ($inputs, \@next_states);
}
926
method call($inputs, $states)
{
    # Runs one time step through each stacked cell in order, feeding each
    # cell's output into the next and collecting the new states of all cells.
    $self->counter($self->counter + 1);
    my @next_states;
    my $p = 0;
    for my $cell ($self->_children->values)
    {
        assert(not $cell->isa('AI::MXNet::Gluon::RNN::BidirectionalCell'));
        my $n = @{ $cell->state_info() };
        # Slice out this cell's contiguous share of the incoming states.
        # Fix: the original slice [$p,$p+$n-1] was a two-element list slice,
        # which duplicated the state for $n == 1 and dropped middle states
        # for $n > 2; the range slice matches unroll() above.
        my $state = [@{ $states }[$p..$p+$n-1]];
        $p += $n;
        ($inputs, $state) = $cell->($inputs, $state);
        push @next_states, @{ $state };
    }
    return ($inputs, \@next_states);
}
943
# Treat the container as an array of its child cells: @{$stack} yields them in order.
use overload '@{}' => sub { [shift->_children->values] };
# Human-readable dump: class name followed by each indexed, indented child.
use overload '""'  => sub {
    my $self = shift;
    my $s = "%s(\n%s\n)";
    my @children;
    enumerate(sub {
        my ($i, $m) = @_;
        push @children, "($i): ". AI::MXNet::Base::_indent("$m", 2);
    }, [$self->_children->values]);
    return sprintf($s, $self->_class_name, join("\n", @children));
};

# The sequential container is driven through call()/unroll() and has no
# single-step hybrid computation of its own.
method hybrid_forward(@args)
{
    confess('Not Implemented');
}

__PACKAGE__->register('AI::MXNet::Gluon::RNN');
962
package AI::MXNet::Gluon::RNN::DropoutCell;
use AI::MXNet::Gluon::Mouse;
extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';

=head1 NAME

    AI::MXNet::Gluon::RNN::DropoutCell
=cut

=head1 DESCRIPTION

    Applies dropout on input.

    Parameters
    ----------
    rate : float
        Percentage of elements to drop out, which
        is 1 - percentage to retain.
=cut

# Fraction of input elements to zero out; 0 makes the cell a pass-through.
has 'rate' => (is => 'ro', isa => 'Num', required => 1);
method python_constructor_arguments() { ['rate'] }

# Dropout keeps no recurrent state, so there is nothing to allocate.
method state_info(Int $batch_size=0) { [] }

method _alias() { 'dropout' }
989
method hybrid_forward(GluonClass $F, GluonInput $inputs, GluonInput $states)
{
    # A zero rate makes dropout a no-op: pass the input straight through.
    return ($inputs, $states) if $self->rate <= 0;
    my $dropped = $F->Dropout($inputs, p => $self->rate, name => "t${\ $self->counter }_fwd");
    return ($dropped, $states);
}
998
method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=, Str :$layout='NTC', Maybe[Bool] :$merge_outputs=)
{
    # When the formatted input comes back as a single merged tensor, dropout
    # can be applied to the whole sequence in one shot; otherwise fall back
    # to the generic step-wise unroll from the parent class.
    $self->reset;
    my $F;
    ($inputs, undef, $F) = $self->_format_sequence($length, $inputs, $layout, $merge_outputs);
    return $self->hybrid_forward($F, $inputs, $begin_state//[]) if blessed $inputs;
    return $self->SUPER::unroll(
        $length, $inputs, begin_state => $begin_state, layout => $layout,
        merge_outputs => $merge_outputs
    );
}
1016
# Stringification shows the configured dropout rate.
use overload '""' => sub {
    my $self = shift;
    return $self->_class_name.'(rate ='.$self->rate.')';
};

__PACKAGE__->register('AI::MXNet::Gluon::RNN');
1023
package AI::MXNet::Gluon::RNN::ModifierCell;
use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';
# The wrapped cell whose outputs/states this modifier post-processes.
has 'base_cell' => (is => 'rw', isa => 'AI::MXNet::Gluon::RNN::HybridRecurrentCell', required => 1);

=head1 NAME

    AI::MXNet::Gluon::RNN::ModifierCell
=cut

=head1 DESCRIPTION

    Base class for modifier cells. A modifier
    cell takes a base cell, applies modifications
    on it (e.g. Zoneout), and returns a new cell.

    After applying modifiers the base cell should
    no longer be called directly. The modifier cell
    should be used instead.
=cut


sub BUILD
{
    my $self = shift;
    # Claim the base cell: marking it modified forbids direct calls on it
    # and prevents it from being wrapped by a second modifier.
    assert(
        (not $self->base_cell->modified),
        "Cell ${\ $self->base_cell->name } is already modified. One cell cannot be modified twice"
    );
    $self->base_cell->modified(1);
}
1056
method params()
{
    # A modifier owns no parameters of its own; expose the base cell's.
    my $base = $self->base_cell;
    return $base->params;
}

method state_info(Int $batch_size=0)
{
    # The state layout is exactly that of the wrapped cell.
    my $base = $self->base_cell;
    return $base->state_info($batch_size);
}
1067
method begin_state(CodeRef :$func=sub{ AI::MXNet::Symbol->zeros(@_) }, %kwargs)
{
    # Initial states are produced by the base cell. The modified flag is
    # lifted just for the delegated call so the base cell's own direct-call
    # assertion does not fire, then restored.
    assert(
        (not $self->modified),
        "After applying modifier cells (e.g. DropoutCell) the base ".
        "cell cannot be called directly. Call the modifier cell instead."
    );
    $self->base_cell->modified(0);
    my $begin = $self->base_cell->begin_state(func => $func, %kwargs);
    $self->base_cell->modified(1);
    return $begin;
}
1080
# Concrete modifiers (Zoneout, Residual, ...) must supply the step computation.
method hybrid_forward(GluonClass $F, GluonInput $inputs, GluonInput $states)
{
    confess('Not Implemented');
}

# Stringification shows the modifier wrapping its base cell.
use overload '""' => sub {
    my $self = shift;
    return $self->_class_name.'('.$self->base_cell.')';
};
1090
package AI::MXNet::Gluon::RNN::ZoneoutCell;
use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Gluon::RNN::ModifierCell';

=head1 NAME

    AI::MXNet::Gluon::RNN::ZoneoutCell
=cut

=head1 DESCRIPTION

    Applies Zoneout on base cell.
=cut
# Zoneout probabilities for the cell output and for the recurrent states.
has [qw/zoneout_outputs
        zoneout_states/] => (is => 'ro', isa => 'Num', default => 0);
# Output of the previous time step; cleared by reset() between unrolls.
has 'prev_output' => (is => 'rw', init_arg => undef);
method python_constructor_arguments() { ['base_cell', 'zoneout_outputs', 'zoneout_states'] }
1109
sub BUILD
{
    my $self = shift;
    # Zoneout steps the base cell one time step at a time, which
    # BidirectionalCell cannot do.
    assert(
        (not $self->base_cell->isa('AI::MXNet::Gluon::RNN::BidirectionalCell')),
        "BidirectionalCell doesn't support zoneout since it doesn't support step. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
    # Bug fix: the class name was misspelled 'SequentialRNNCel', so the isa()
    # check never matched and this guard was silently skipped.
    # NOTE(review): assumes SequentialRNNCell exposes a 'bidirectional'
    # accessor (short-circuited away for other cell types) — confirm.
    assert(
        (not $self->base_cell->isa('AI::MXNet::Gluon::RNN::SequentialRNNCell') or not $self->base_cell->bidirectional),
        "Bidirectional SequentialRNNCell doesn't support zoneout. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
}
1124
# Stringification shows both zoneout probabilities and the wrapped cell.
use overload '""' => sub {
    my $self = shift;
    return $self->_class_name.'(p_out='.$self->zoneout_outputs.', p_state='.$self->zoneout_states.
           ', '.$self->base_cell.')';
};

method _alias() { 'zoneout' }

method reset()
{
    # Clear the remembered previous output along with the inherited counters.
    $self->SUPER::reset();
    $self->prev_output(undef);
}
1138
method hybrid_forward(GluonClass $F, GluonInput $inputs, GluonInput $states)
{
    # One zoneout step: elements of the output/states are randomly kept from
    # the previous time step instead of the freshly computed values.
    my ($cell, $p_outputs, $p_states) = ($self->base_cell, $self->zoneout_outputs, $self->zoneout_states);
    my ($next_output, $next_states) = $cell->($inputs, $states);
    # Dropout on a ones tensor produces a random keep/drop mask shaped like
    # $like; where() below only tests it for non-zero.
    my $mask = sub { my ($p, $like) = @_; $F->Dropout($F->ones_like($like), p=>$p) };

    # First step of an unroll has no previous output; fall back to zeros.
    my $prev_output = $self->prev_output//$F->zeros_like($next_output);
    my $output = $p_outputs != 0 ? $F->where($mask->($p_outputs, $next_output), $next_output, $prev_output) : $next_output;
    if($p_states != 0)
    {
        # Mix each new state with its predecessor under an independent mask.
        my @tmp;
        for(zip($next_states, $states)) {
            my ($new_s, $old_s) = @$_;
            push @tmp, $F->where($mask->($p_states, $new_s), $new_s, $old_s);
        }
        $states = \@tmp;
    }
    else
    {
        $states = $next_states;
    }
    $self->prev_output($output);
    return ($output, $states);
}

__PACKAGE__->register('AI::MXNet::Gluon::RNN');
1165
package AI::MXNet::Gluon::RNN::ResidualCell;
use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Gluon::RNN::ModifierCell';
method python_constructor_arguments() { ['base_cell'] }

=head1 NAME

    AI::MXNet::Gluon::RNN::ResidualCell
=cut

=head1 DESCRIPTION

    Adds a residual connection as described in Wu et al, 2016
    (https://arxiv.org/abs/1609.08144).
    Output of the cell is the output of the base cell plus the input.
=cut
1183
method hybrid_forward(GluonClass $F, GluonInput $inputs, GluonInput $states)
{
    # Bug fix: the type constraint was misspelled 'GluonClas', an undeclared
    # type, which broke the signature check for this method.
    my $output;
    ($output, $states) = $self->base_cell->($inputs, $states);
    # Residual connection: add the step input back onto the cell's output.
    $output = $F->elemwise_add($output, $inputs, name => "t${\ $self->counter }_fwd");
    return ($output, $states);
}
1191
method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=, Str :$layout='NTC', Maybe[Bool] :$merge_outputs=)
{
    # Unrolls the base cell, then adds the (re-formatted) inputs onto its outputs.
    $self->reset();

    # Temporarily lift the modified flag so the delegated unroll passes the
    # base cell's direct-call assertion.
    $self->base_cell->modified(0);
    my ($outputs, $states) = $self->base_cell->unroll(
        $length, $inputs, begin_state => $begin_state, layout => $layout, merge_outputs => $merge_outputs
    );
    $self->base_cell->modified(1);

    # If the caller expressed no preference, mirror the format the base cell
    # produced: a single blessed tensor means merged output.
    $merge_outputs //= blessed $outputs ? 1 : 0;
    my $F;
    ($inputs, undef, $F) = $self->_format_sequence($length, $inputs, $layout, $merge_outputs);
    if($merge_outputs)
    {
        $outputs = $F->elemwise_add($outputs, $inputs);
    }
    else
    {
        # Per-step addition when outputs are a list of per-time-step tensors.
        my @tmp;
        for(zip($outputs, $inputs)) {
            my ($i, $j) = @$_;
            push @tmp, $F->elemwise_add($i, $j);
        }
        $outputs = \@tmp;
    }
    return ($outputs, $states);
}

__PACKAGE__->register('AI::MXNet::Gluon::RNN');
1222
package AI::MXNet::Gluon::RNN::BidirectionalCell;
use AI::MXNet::Gluon::Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Gluon::RNN::HybridRecurrentCell';
# Forward (l_cell) and backward (r_cell) sub-cells whose outputs get concatenated.
has [qw/l_cell r_cell/] => (is => 'ro', isa => 'AI::MXNet::Gluon::RNN::HybridRecurrentCell', required => 1);
# Name prefix used for the concat output symbols created in unroll().
has 'output_prefix'     => (is => 'ro', isa => 'Str', default => 'bi_');
method python_constructor_arguments() { ['l_cell', 'r_cell', 'output_prefix'] }

=head1 NAME

    AI::MXNet::Gluon::RNN::BidirectionalCell
=cut

=head1 DESCRIPTION

    Bidirectional RNN cell.

    Parameters
    ----------
    l_cell : RecurrentCell
        Cell for forward unrolling
    r_cell : RecurrentCell
        Cell for backward unrolling
=cut
1247
# Bidirectional unrolling needs the whole sequence up front, so stepping one
# time step at a time is unsupported.
method call($inputs, $states)
{
    confess("Bidirectional cell cannot be stepped. Please use unroll");
}

# Stringification shows the forward and backward sub-cells.
use overload '""' => sub {
    my $self = shift;
    "${\ $self->_class_name }(forward=${\ $self->l_cell }, backward=${\ $self->r_cell })";
};
1257
method state_info(Int $batch_size=0)
{
    # Forward-cell state entries first, then backward, matching unroll()'s split.
    my $children = $self->_children;
    return $self->_cells_state_info($children, $batch_size);
}
1262
method begin_state(%kwargs)
{
    # Refuse direct use once a modifier cell has claimed this cell.
    assert(
        (not $self->modified),
        "After applying modifier cells (e.g. DropoutCell) the base ".
        "cell cannot be called directly. Call the modifier cell instead."
    );
    my $children = $self->_children;
    return $self->_cells_begin_state($children, %kwargs);
}
1272
method unroll(Int $length, GluonInput $inputs, Maybe[GluonInput] :$begin_state=, Str :$layout='NTC', Maybe[Bool] :$merge_outputs=)
{
    # Unrolls the forward cell over the inputs and the backward cell over the
    # reversed inputs, then concatenates the two output sequences.
    $self->reset();
    my ($axis, $F, $batch_size);
    ($inputs, $axis, $F, $batch_size) = $self->_format_sequence($length, $inputs, $layout, 0);
    $begin_state //= $self->_get_begin_state($F, $begin_state, $inputs, $batch_size);

    my $states = $begin_state;
    my ($l_cell, $r_cell) = $self->_children->values;
    # The first $l_n entries of $states belong to the forward cell, the rest
    # to the backward cell. (Also drops a former no-op state_info call whose
    # result was discarded.)
    my $l_n = @{ $l_cell->state_info($batch_size) };
    my ($l_outputs, $l_states) = $l_cell->unroll(
            $length, $inputs,
            begin_state => [@{ $states }[0..$l_n-1]],
            layout => $layout,
            merge_outputs => $merge_outputs
    );
    my ($r_outputs, $r_states) = $r_cell->unroll(
        $length, [reverse @{$inputs}],
        begin_state     => [@{$states}[$l_n..@{$states}-1]],
        layout          => $layout,
        merge_outputs   => $merge_outputs
    );
    if(not defined $merge_outputs)
    {
        # Bug fix: '=' binds tighter than low-precedence 'and', so the old
        # "$merge_outputs = blessed $l_outputs and blessed $r_outputs" assigned
        # only blessed($l_outputs) and never consulted the backward outputs.
        $merge_outputs = (blessed $l_outputs && blessed $r_outputs) ? 1 : 0;
        ($l_outputs) = $self->_format_sequence(undef, $l_outputs, $layout, $merge_outputs);
        ($r_outputs) = $self->_format_sequence(undef, $r_outputs, $layout, $merge_outputs);
    }
    my $outputs;
    if($merge_outputs)
    {
        # Backward outputs are in reversed time order; flip them back before
        # concatenating along the feature dimension.
        $r_outputs = $F->reverse($r_outputs, axis=>$axis);
        $outputs = $F->concat($l_outputs, $r_outputs, dim=>2, name=>$self->output_prefix.'out');
    }
    else
    {
        # Step-wise concat: pair step i of the forward outputs with step i of
        # the (re-reversed) backward outputs.
        $outputs = [];
        enumerate(sub {
            my ($i, $l_o, $r_o) = @_;
                push @$outputs, $F->concat(
                    $l_o, $r_o, dim=>1,
                    name => sprintf('%st%d', $self->output_prefix, $i)
                );
            }, [@{ $l_outputs }], [reverse(@{ $r_outputs })]
        );
    }
    $states = [@{ $l_states }, @{ $r_states }];
    return ($outputs, $states);
}

__PACKAGE__->register('AI::MXNet::Gluon::RNN');
1324
13251;
1326