1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18use strict;
19use warnings;
20use AI::MXNet qw(mx);
21use AI::MXNet::TestUtils qw(same);
22use PDL;
23use Test::More tests => 54;
24
sub test_rnn
{
    # Unroll a 100-unit vanilla RNN cell for three steps, then check the
    # parameter names it registers, the per-step output names and the
    # inferred output shapes for 10x50 inputs.
    my $cell    = mx->rnn->RNNCell(100, prefix=>'rnn_');
    my $grouped = mx->sym->Group(($cell->unroll(3, input_prefix=>'rnn_'))[0]);

    my @expected_params = qw(rnn_h2h_bias rnn_h2h_weight rnn_i2h_bias rnn_i2h_weight);
    is_deeply([sort keys %{ $cell->params->_params }], \@expected_params);
    is_deeply($grouped->list_outputs(), [map { "rnn_t${_}_out_output" } 0 .. 2]);

    # Keep a list (not a hash) so the arguments are passed in t0, t1, t2 order.
    my @input_shapes = map { ("rnn_t${_}_data" => [10, 50]) } 0 .. 2;
    my $out_shapes   = ($grouped->infer_shape(@input_shapes))[1];
    is_deeply($out_shapes, [map { [10, 100] } 0 .. 2]);
}
35
sub test_lstm
{
    # Structural checks for a three-step unroll of a 100-unit LSTM cell
    # built with forget_bias => 1: parameter names, output names, shapes.
    my $num_steps = 3;
    my @step_ids  = (0 .. $num_steps - 1);

    my $cell     = mx->rnn->LSTMCell(100, prefix=>'rnn_', forget_bias => 1);
    my @unrolled = $cell->unroll($num_steps, input_prefix=>'rnn_');
    my $grouped  = mx->sym->Group($unrolled[0]);

    is_deeply(
        [sort keys %{ $cell->params->_params }],
        [qw(rnn_h2h_bias rnn_h2h_weight rnn_i2h_bias rnn_i2h_weight)]
    );
    is_deeply($grouped->list_outputs(), [map { "rnn_t${_}_out_output" } @step_ids]);

    my (undef, $out_shapes) = $grouped->infer_shape(
        map { ("rnn_t${_}_data" => [10, 50]) } @step_ids
    );
    is_deeply($out_shapes, [map { [10, 100] } @step_ids]);
}
46
sub test_lstm_forget_bias
{
    # Two stacked LSTM layers are built with the same forget_bias. After
    # init_params, every layer's i2h bias is expected to be zero except for
    # its second num_hidden-sized block, which should hold forget_bias.
    my $forget_bias = 2;
    my $num_hidden  = 100;
    my @prefixes    = ('l0_', 'l1_');

    my $stack = mx->rnn->SequentialRNNCell();
    for my $prefix (@prefixes)
    {
        $stack->add(mx->rnn->LSTMCell($num_hidden, forget_bias=>$forget_bias, prefix=>$prefix));
    }

    my $dshape = [32, 1, 200];
    my $data   = mx->sym->Variable('data');

    my ($sym) = $stack->unroll(1, inputs => $data, merge_outputs => 1);
    my $mod = mx->mod->Module($sym, context => mx->cpu(0));
    $mod->bind(data_shapes=>[['data', $dshape]]);

    $mod->init_params();
    # get_params returns (arg_params, aux_params); only the first is needed.
    my ($arg_params) = $mod->get_params();
    my $expected_bias = zeros($num_hidden)->glue(
        0, $forget_bias * ones($num_hidden), zeros(2 * $num_hidden)
    );

    # Check every stacked layer's i2h bias, not just the first match.
    my @bias_arguments = grep { /i2h_bias$/ } @{ $sym->list_arguments };
    my @mismatched = grep {
        my $diff = ($arg_params->{$_}->aspdl - $expected_bias)->abs;
        !($diff < 1e-07)->all;
    } @bias_arguments;

    # A single assertion keeps this sub's contribution to the fixed plan at one.
    ok(
        @bias_arguments == @prefixes && !@mismatched,
        'forget-gate bias initialised in every stacked LSTM layer'
    );
}
69
sub test_gru
{
    # Structural checks for a three-step unroll of a 100-unit GRU cell.
    my $gru = mx->rnn->GRUCell(100, prefix=>'rnn_');
    my ($step_outputs) = $gru->unroll(3, input_prefix=>'rnn_');
    my $symbol = mx->sym->Group($step_outputs);

    my @param_names = sort keys %{ $gru->params->_params };
    is_deeply(\@param_names, [map { "rnn_$_" } qw(h2h_bias h2h_weight i2h_bias i2h_weight)]);

    my @expected_names;
    push @expected_names, "rnn_t${_}_out_output" for 0 .. 2;
    is_deeply($symbol->list_outputs(), \@expected_names);

    my @shape_args;
    push @shape_args, "rnn_t${_}_data", [10, 50] for 0 .. 2;
    my (undef, $out_shapes, undef) = $symbol->infer_shape(@shape_args);
    is_deeply($out_shapes, [[10, 100], [10, 100], [10, 100]]);
}
80
sub test_residual
{
    # A ResidualCell wraps a 50-unit GRU cell and adds each step's input to
    # the wrapped output. With every GRU weight and bias zeroed, the test
    # expects the evaluated outputs to equal the all-ones inputs.
    my $cell = mx->rnn->ResidualCell(mx->rnn->GRUCell(50, prefix=>'rnn_'));
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
    my ($outputs) = $cell->unroll(2, inputs => $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{ $cell->params->_params }],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'],
        'residual GRU exposes the wrapped cell parameters'
    );
    is_deeply(
        $outputs->list_outputs,
        ['rnn_t0_out_plus_residual_output', 'rnn_t1_out_plus_residual_output'],
        'residual GRU output names'
    );

    # The shape check and the two value checks form one subtest. The value
    # checks are now real assertions, yet this sub still contributes exactly
    # three top-level tests to the fixed plan declared at the top of the file.
    subtest 'residual GRU output shapes and values' => sub {
        my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10, 50], rnn_t1_data=>[10, 50]);
        is_deeply($outs, [[10, 50], [10, 50]], 'inferred output shapes');

        my $evaluated = $outputs->eval(args => {
            rnn_t0_data=>mx->nd->ones([10, 50]),
            rnn_t1_data=>mx->nd->ones([10, 50]),
            rnn_i2h_weight=>mx->nd->zeros([150, 50]),
            rnn_i2h_bias=>mx->nd->zeros([150]),
            rnn_h2h_weight=>mx->nd->zeros([150, 50]),
            rnn_h2h_bias=>mx->nd->zeros([150])
        });
        my $expected_outputs = mx->nd->ones([10, 50])->aspdl;
        ok(same($evaluated->[0]->aspdl, $expected_outputs), 't0 output equals residual input');
        ok(same($evaluated->[1]->aspdl, $expected_outputs), 't1 output equals residual input');
    };
}
110
sub test_residual_bidirectional
{
    # ResidualCell around a bidirectional pair of 25-unit GRU cells, over
    # 50-wide inputs. With every GRU parameter zeroed, the test expects each
    # evaluated output to equal its input (all ones plus 5).
    my $cell = mx->rnn->ResidualCell(
        mx->rnn->BidirectionalCell(
            mx->rnn->GRUCell(25, prefix=>'rnn_l_'),
            mx->rnn->GRUCell(25, prefix=>'rnn_r_')
        )
    );
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
    my ($outputs) = $cell->unroll(2, inputs => $inputs, merge_outputs=>0);
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{ $cell->params->_params }],
        ['rnn_l_h2h_bias', 'rnn_l_h2h_weight', 'rnn_l_i2h_bias', 'rnn_l_i2h_weight',
        'rnn_r_h2h_bias', 'rnn_r_h2h_weight', 'rnn_r_i2h_bias', 'rnn_r_i2h_weight'],
        'residual bidirectional GRU exposes both directions\' parameters'
    );
    is_deeply(
        $outputs->list_outputs,
        ['bi_t0_plus_residual_output', 'bi_t1_plus_residual_output'],
        'residual bidirectional GRU output names'
    );

    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10, 50], rnn_t1_data=>[10, 50]);
    is_deeply($outs, [[10, 50], [10, 50]], 'residual bidirectional GRU output shapes');

    # Keep the evaluated NDArrays separate from the grouped symbol.
    my $evaluated = $outputs->eval(args => {
        rnn_t0_data=>mx->nd->ones([10, 50])+5,
        rnn_t1_data=>mx->nd->ones([10, 50])+5,
        rnn_l_i2h_weight=>mx->nd->zeros([75, 50]),
        rnn_l_i2h_bias=>mx->nd->zeros([75]),
        rnn_l_h2h_weight=>mx->nd->zeros([75, 25]),
        rnn_l_h2h_bias=>mx->nd->zeros([75]),
        rnn_r_i2h_weight=>mx->nd->zeros([75, 50]),
        rnn_r_i2h_bias=>mx->nd->zeros([75]),
        rnn_r_h2h_weight=>mx->nd->zeros([75, 25]),
        rnn_r_h2h_bias=>mx->nd->zeros([75])
    });
    my $expected_outputs = (mx->nd->ones([10, 50])+5)->aspdl;
    # Use element access (->[i]) rather than a one-element slice (@{...}[i]).
    ok(same($evaluated->[0]->aspdl, $expected_outputs), 't0 output equals residual input');
    ok(same($evaluated->[1]->aspdl, $expected_outputs), 't1 output equals residual input');
}
150
sub test_stack
{
    # Stack five 100-unit LSTM layers, with layer 1 wrapped in a
    # ResidualCell. Then check that every layer registered its four
    # parameters, that outputs come from the last layer, and the shapes.
    my $cell = mx->rnn->SequentialRNNCell();
    for my $layer (0 .. 4)
    {
        my $lstm = mx->rnn->LSTMCell(100, prefix=>"rnn_stack${layer}_");
        $cell->add($layer == 1 ? mx->rnn->ResidualCell($lstm) : $lstm);
    }

    my ($step_outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    my $grouped = mx->sym->Group($step_outputs);

    my $params = $cell->params->_params;
    for my $layer (0 .. 4)
    {
        for my $suffix (qw(h2h_weight h2h_bias i2h_weight i2h_bias))
        {
            ok(exists $params->{"rnn_stack${layer}_${suffix}"});
        }
    }

    is_deeply($grouped->list_outputs(), [map { "rnn_stack4_t${_}_out_output" } 0 .. 2]);
    my $out_shapes = ($grouped->infer_shape(map { ("rnn_t${_}_data" => [10, 50]) } 0 .. 2))[1];
    is_deeply($out_shapes, [map { [10, 100] } 0 .. 2]);
}
179
sub test_bidirectional
{
    # A bidirectional pair of 100-unit LSTMs with an explicit output prefix.
    # The expected 200-wide step outputs reflect both directions combined.
    my $forward  = mx->rnn->LSTMCell(100, prefix=>'rnn_l0_');
    my $backward = mx->rnn->LSTMCell(100, prefix=>'rnn_r0_');
    my $cell = mx->rnn->BidirectionalCell($forward, $backward, output_prefix=>'rnn_bi_');

    my @unrolled = $cell->unroll(3, input_prefix=>'rnn_');
    my $grouped  = mx->sym->Group($unrolled[0]);

    is_deeply($grouped->list_outputs(), [map { "rnn_bi_t${_}_output" } 0 .. 2]);
    my (undef, $out_shapes) = $grouped->infer_shape(map { ("rnn_t${_}_data" => [10, 50]) } 0 .. 2);
    is_deeply($out_shapes, [map { [10, 200] } 0 .. 2]);
}
193
sub test_unfuse
{
    # Unfuse a one-layer bidirectional FusedRNNCell in lstm mode. Then check
    # the unrolled output names (built from the fused prefix) and the
    # 200-wide output shapes.
    my $fused = mx->rnn->FusedRNNCell(
        100, num_layers => 1, mode => 'lstm',
        prefix => 'test_', bidirectional => 1
    );
    my $cell = $fused->unfuse;

    my ($step_outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    my $grouped = mx->sym->Group($step_outputs);

    my @expected_names = map { "test_bi_lstm_0t${_}_output" } 0 .. 2;
    is_deeply($grouped->list_outputs(), \@expected_names);

    my @shape_args = map { ("rnn_t${_}_data" => [10, 50]) } 0 .. 2;
    my $out_shapes = ($grouped->infer_shape(@shape_args))[1];
    is_deeply($out_shapes, [map { [10, 200] } @expected_names]);
}
206
sub test_zoneout
{
    # Wrapping a 100-unit RNN cell in a ZoneoutCell (0.5 on both outputs and
    # states) must leave the unrolled output shapes unchanged.
    my $cell = mx->rnn->ZoneoutCell(
        mx->rnn->RNNCell(100, prefix=>'rnn_'),
        zoneout_outputs => 0.5,
        zoneout_states  => 0.5
    );
    my @data_names = map { "rnn_t${_}_data" } 0 .. 2;
    my @inputs     = map { mx->sym->Variable($_) } @data_names;

    my $grouped    = mx->sym->Group(($cell->unroll(3, inputs => \@inputs))[0]);
    my $out_shapes = ($grouped->infer_shape(map { ($_ => [10, 50]) } @data_names))[1];
    is_deeply($out_shapes, [map { [10, 100] } @data_names]);
}
220
sub test_convrnn
{
    # Three-step unroll of a convolutional RNN cell over 1x3x16x10 inputs.
    # The 3x3 i2h convolution (stride 1, pad 1) keeps the 16x10 spatial
    # extent, so each step is expected to emit num_hidden (10) channels.
    my @frame = (1, 3, 16, 10);
    my $cell = mx->rnn->ConvRNNCell(
        input_shape => [@frame],
        num_hidden  => 10,
        h2h_kernel  => [3, 3],
        h2h_dilate  => [1, 1],
        i2h_kernel  => [3, 3],
        i2h_stride  => [1, 1],
        i2h_pad     => [1, 1],
        i2h_dilate  => [1, 1],
        prefix      => 'rnn_'
    );
    my @names = map { "rnn_t${_}_data" } 0 .. 2;
    my ($step_outputs) = $cell->unroll(3, inputs => [map { mx->sym->Variable($_) } @names]);
    my $grouped = mx->sym->Group($step_outputs);

    is_deeply(
        [sort keys %{ $cell->params->_params }],
        [qw(rnn_h2h_bias rnn_h2h_weight rnn_i2h_bias rnn_i2h_weight)]
    );
    is_deeply($grouped->list_outputs(), [map { "rnn_t${_}_out_output" } 0 .. 2]);
    my (undef, $out_shapes) = $grouped->infer_shape(map { ($_ => [@frame]) } @names);
    is_deeply($out_shapes, [map { [1, 10, 16, 10] } @names]);
}
239
sub test_convlstm
{
    # Three-step unroll of a convolutional LSTM cell (forget_bias => 1) over
    # 1x3x16x10 inputs. Each step is expected to produce 10 channels at the
    # unchanged 16x10 spatial size.
    my $num_steps = 3;
    my @frame     = (1, 3, 16, 10);
    my $cell = mx->rnn->ConvLSTMCell(
        input_shape => [@frame],
        num_hidden  => 10,
        h2h_kernel  => [3, 3],
        h2h_dilate  => [1, 1],
        i2h_kernel  => [3, 3],
        i2h_stride  => [1, 1],
        i2h_pad     => [1, 1],
        i2h_dilate  => [1, 1],
        prefix      => 'rnn_',
        forget_bias => 1
    );
    my @steps  = (0 .. $num_steps - 1);
    my @inputs = map { mx->sym->Variable("rnn_t${_}_data") } @steps;
    my @unrolled = $cell->unroll($num_steps, inputs => \@inputs);
    my $grouped  = mx->sym->Group($unrolled[0]);

    my @expected_params = map { "rnn_$_" } qw(h2h_bias h2h_weight i2h_bias i2h_weight);
    is_deeply([sort keys %{ $cell->params->_params }], \@expected_params);
    is_deeply($grouped->list_outputs(), [map { "rnn_t${_}_out_output" } @steps]);

    my $out_shapes = ($grouped->infer_shape(map { ("rnn_t${_}_data" => [@frame]) } @steps))[1];
    is_deeply($out_shapes, [map { [1, 10, 16, 10] } @steps]);
}
258
sub test_convgru
{
    # Three-step unroll of a convolutional GRU cell over 1x3x16x10 inputs.
    # Each step is expected to produce 10 channels at the unchanged 16x10
    # spatial size.
    my @frame = (1, 3, 16, 10);
    # NOTE(review): GRU cells have no forget gate, so forget_bias is
    # presumably ignored by ConvGRUCell. It is kept only to preserve the
    # existing call; confirm and drop it.
    my $cell = mx->rnn->ConvGRUCell(
        input_shape => [@frame],
        num_hidden  => 10,
        h2h_kernel  => [3, 3],
        h2h_dilate  => [1, 1],
        i2h_kernel  => [3, 3],
        i2h_stride  => [1, 1],
        i2h_pad     => [1, 1],
        i2h_dilate  => [1, 1],
        prefix      => 'rnn_',
        forget_bias => 1
    );
    my @data_names = map { "rnn_t${_}_data" } 0 .. 2;
    my @inputs     = map { mx->sym->Variable($_) } @data_names;
    my ($step_outputs) = $cell->unroll(3, inputs => \@inputs);
    my $grouped = mx->sym->Group($step_outputs);

    is_deeply(
        [sort keys %{ $cell->params->_params }],
        [qw(rnn_h2h_bias rnn_h2h_weight rnn_i2h_bias rnn_i2h_weight)]
    );
    is_deeply($grouped->list_outputs(), [map { "rnn_t${_}_out_output" } 0 .. 2]);

    my @shape_args;
    push @shape_args, $_, [@frame] for @data_names;
    my (undef, $out_shapes) = $grouped->infer_shape(@shape_args);
    is_deeply($out_shapes, [map { [1, 10, 16, 10] } @data_names]);
}
277
# Run every test sub in declaration order; their combined assertion count
# must match the plan passed to Test::More at the top of the file.
for my $test (
    \&test_rnn,
    \&test_lstm,
    \&test_lstm_forget_bias,
    \&test_gru,
    \&test_residual,
    \&test_residual_bidirectional,
    \&test_stack,
    \&test_bidirectional,
    \&test_unfuse,
    \&test_zoneout,
    \&test_convrnn,
    \&test_convlstm,
    \&test_convgru,
)
{
    $test->();
}
291