# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(same);
use PDL;
# Fixed plan: 56 = 3+3+1+3+5+5+22+2+2+1+3+3+3 across the subs below.
use Test::More tests => 56;

# Unroll a plain RNNCell for 3 steps and check parameter names,
# output symbol names, and inferred output shapes.
sub test_rnn
{
    my $cell = mx->rnn->RNNCell(100, prefix=>'rnn_');
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort keys %{$cell->params->_params}], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);
    my (undef, $outs, undef) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

# Same sanity checks as test_rnn, but for an LSTMCell with forget_bias set.
sub test_lstm
{
    my $cell = mx->rnn->LSTMCell(100, prefix=>'rnn_', forget_bias => 1);
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort keys %{$cell->params->_params}], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);
    my (undef, $outs, undef) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

# Verify that LSTMCell's forget_bias argument ends up in the forget-gate
# slice of the initialized i2h bias (layout: [input | forget | cell+output]).
sub test_lstm_forget_bias
{
    my $forget_bias = 2;
    my $stack = mx->rnn->SequentialRNNCell();
    $stack->add(mx->rnn->LSTMCell(100, forget_bias=>$forget_bias, prefix=>'l0_'));
    $stack->add(mx->rnn->LSTMCell(100, forget_bias=>$forget_bias, prefix=>'l1_'));

    my $dshape = [32, 1, 200];
    my $data = mx->sym->Variable('data');

    my ($sym) = $stack->unroll(1, inputs => $data, merge_outputs => 1);
    my $mod = mx->mod->Module($sym, context => mx->cpu(0));
    $mod->bind(data_shapes=>[['data', $dshape]]);

    $mod->init_params();
    my ($bias_argument) = grep { /i2h_bias$/ } @{ $sym->list_arguments };
    my $f = zeros(100);
    # Expected bias: zeros for the input gate, forget_bias for the forget
    # gate, zeros for the remaining two gates (cell transform + output).
    my $expected_bias = $f->glue(0, $forget_bias * ones(100), zeros(200));
    ok(
        ((($mod->get_params())[0]->{$bias_argument}->aspdl - $expected_bias)->abs < 1e-07)->all
    );
}

# Same sanity checks as test_rnn, but for a GRUCell.
sub test_gru
{
    my $cell = mx->rnn->GRUCell(100, prefix=>'rnn_');
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply([sort keys %{$cell->params->_params}], ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']);
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);
    my (undef, $outs, undef) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

# ResidualCell around a GRUCell: with all weights zeroed the inner cell
# contributes nothing, so the output must equal the (residual) input.
sub test_residual
{
    my $cell = mx->rnn->ResidualCell(mx->rnn->GRUCell(50, prefix=>'rnn_'));
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
    my ($outputs) = $cell->unroll(2, inputs => $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{ $cell->params->_params }],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
    );
    is_deeply(
        $outputs->list_outputs,
        ['rnn_t0_out_plus_residual_output', 'rnn_t1_out_plus_residual_output']
    );

    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10, 50], rnn_t1_data=>[10, 50]);
    is_deeply($outs, [[10, 50], [10, 50]]);
    $outputs = $outputs->eval(args => {
        rnn_t0_data=>mx->nd->ones([10, 50]),
        rnn_t1_data=>mx->nd->ones([10, 50]),
        rnn_i2h_weight=>mx->nd->zeros([150, 50]),
        rnn_i2h_bias=>mx->nd->zeros([150]),
        rnn_h2h_weight=>mx->nd->zeros([150, 50]),
        rnn_h2h_bias=>mx->nd->zeros([150])
    });
    my $expected_outputs = mx->nd->ones([10, 50])->aspdl;
    # same() only returns a boolean — it must be wrapped in ok() to count
    # as a test (previously these two results were silently discarded).
    ok(same(@{$outputs}[0]->aspdl, $expected_outputs));
    ok(same(@{$outputs}[1]->aspdl, $expected_outputs));
}

# ResidualCell around a BidirectionalCell of two GRUCells: zeroed weights
# again reduce the whole stack to the identity on its input.
sub test_residual_bidirectional
{
    my $cell = mx->rnn->ResidualCell(
        mx->rnn->BidirectionalCell(
            mx->rnn->GRUCell(25, prefix=>'rnn_l_'),
            mx->rnn->GRUCell(25, prefix=>'rnn_r_')
        )
    );
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..1];
    my ($outputs) = $cell->unroll(2, inputs => $inputs, merge_outputs=>0);
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{ $cell->params->_params }],
        ['rnn_l_h2h_bias', 'rnn_l_h2h_weight', 'rnn_l_i2h_bias', 'rnn_l_i2h_weight',
         'rnn_r_h2h_bias', 'rnn_r_h2h_weight', 'rnn_r_i2h_bias', 'rnn_r_i2h_weight']
    );
    is_deeply(
        $outputs->list_outputs,
        ['bi_t0_plus_residual_output', 'bi_t1_plus_residual_output']
    );

    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10, 50], rnn_t1_data=>[10, 50]);
    is_deeply($outs, [[10, 50], [10, 50]]);
    $outputs = $outputs->eval(args => {
        rnn_t0_data=>mx->nd->ones([10, 50])+5,
        rnn_t1_data=>mx->nd->ones([10, 50])+5,
        rnn_l_i2h_weight=>mx->nd->zeros([75, 50]),
        rnn_l_i2h_bias=>mx->nd->zeros([75]),
        rnn_l_h2h_weight=>mx->nd->zeros([75, 25]),
        rnn_l_h2h_bias=>mx->nd->zeros([75]),
        rnn_r_i2h_weight=>mx->nd->zeros([75, 50]),
        rnn_r_i2h_bias=>mx->nd->zeros([75]),
        rnn_r_h2h_weight=>mx->nd->zeros([75, 25]),
        rnn_r_h2h_bias=>mx->nd->zeros([75])
    });
    my $expected_outputs = (mx->nd->ones([10, 50])+5)->aspdl;
    ok(same(@{$outputs}[0]->aspdl, $expected_outputs));
    ok(same(@{$outputs}[1]->aspdl, $expected_outputs));
}

# SequentialRNNCell of 5 LSTM layers (layer 1 wrapped in a ResidualCell):
# check all per-layer parameters exist and the final layer's outputs/shapes.
sub test_stack
{
    my $cell = mx->rnn->SequentialRNNCell();
    for my $i (0..4)
    {
        if($i == 1)
        {
            $cell->add(mx->rnn->ResidualCell(mx->rnn->LSTMCell(100, prefix=>"rnn_stack${i}_")));
        }
        else
        {
            $cell->add(mx->rnn->LSTMCell(100, prefix=>"rnn_stack${i}_"));
        }
    }
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    my %params = %{ $cell->params->_params };
    for my $i (0..4)
    {
        ok(exists $params{"rnn_stack${i}_h2h_weight"});
        ok(exists $params{"rnn_stack${i}_h2h_bias"});
        ok(exists $params{"rnn_stack${i}_i2h_weight"});
        ok(exists $params{"rnn_stack${i}_i2h_bias"});
    }
    is_deeply($outputs->list_outputs(), ['rnn_stack4_t0_out_output', 'rnn_stack4_t1_out_output', 'rnn_stack4_t2_out_output']);
    my (undef, $outs, undef) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

# BidirectionalCell of two LSTMCells: output width is the concatenation
# of both directions (100 + 100 = 200).
sub test_bidirectional
{
    my $cell = mx->rnn->BidirectionalCell(
        mx->rnn->LSTMCell(100, prefix=>'rnn_l0_'),
        mx->rnn->LSTMCell(100, prefix=>'rnn_r0_'),
        output_prefix=>'rnn_bi_'
    );
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply($outputs->list_outputs(), ['rnn_bi_t0_output', 'rnn_bi_t1_output', 'rnn_bi_t2_output']);
    my (undef, $outs, undef) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 200], [10, 200], [10, 200]]);
}

# FusedRNNCell->unfuse should yield an unrolled bidirectional LSTM with
# the expected output names and concatenated (200-wide) shapes.
sub test_unfuse
{
    my $cell = mx->rnn->FusedRNNCell(
        100, num_layers => 1, mode => 'lstm',
        prefix => 'test_', bidirectional => 1
    )->unfuse;
    my ($outputs) = $cell->unroll(3, input_prefix=>'rnn_');
    $outputs = mx->sym->Group($outputs);
    is_deeply($outputs->list_outputs(), ['test_bi_lstm_0t0_output', 'test_bi_lstm_0t1_output', 'test_bi_lstm_0t2_output']);
    my (undef, $outs, undef) = $outputs->infer_shape(rnn_t0_data=>[10,50], rnn_t1_data=>[10,50], rnn_t2_data=>[10,50]);
    is_deeply($outs, [[10, 200], [10, 200], [10, 200]]);
}

# ZoneoutCell wrapper: shape inference must pass through unchanged.
sub test_zoneout
{
    my $cell = mx->rnn->ZoneoutCell(
        mx->rnn->RNNCell(100, prefix=>'rnn_'),
        zoneout_outputs => 0.5,
        zoneout_states => 0.5
    );
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, inputs => $inputs);
    $outputs = mx->sym->Group($outputs);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[10, 50], rnn_t1_data=>[10, 50], rnn_t2_data=>[10, 50]);
    is_deeply($outs, [[10, 100], [10, 100], [10, 100]]);
}

# Convolutional RNN cell on NCHW input: spatial dims are preserved by the
# 3x3/stride-1/pad-1 convolutions; channel dim becomes num_hidden.
sub test_convrnn
{
    my $cell = mx->rnn->ConvRNNCell(input_shape => [1, 3, 16, 10], num_hidden=>10,
                                    h2h_kernel=>[3, 3], h2h_dilate=>[1, 1],
                                    i2h_kernel=>[3, 3], i2h_stride=>[1, 1],
                                    i2h_pad=>[1, 1], i2h_dilate=>[1, 1],
                                    prefix=>'rnn_');
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, inputs => $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{ $cell->params->_params }],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
    );
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[1, 3, 16, 10], rnn_t1_data=>[1, 3, 16, 10], rnn_t2_data=>[1, 3, 16, 10]);
    is_deeply($outs, [[1, 10, 16, 10], [1, 10, 16, 10], [1, 10, 16, 10]]);
}

# Same checks as test_convrnn, but for a convolutional LSTM cell.
sub test_convlstm
{
    my $cell = mx->rnn->ConvLSTMCell(input_shape => [1, 3, 16, 10], num_hidden=>10,
                                     h2h_kernel=>[3, 3], h2h_dilate=>[1, 1],
                                     i2h_kernel=>[3, 3], i2h_stride=>[1, 1],
                                     i2h_pad=>[1, 1], i2h_dilate=>[1, 1],
                                     prefix=>'rnn_', forget_bias => 1);
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, inputs => $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{ $cell->params->_params }],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
    );
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[1, 3, 16, 10], rnn_t1_data=>[1, 3, 16, 10], rnn_t2_data=>[1, 3, 16, 10]);
    is_deeply($outs, [[1, 10, 16, 10], [1, 10, 16, 10], [1, 10, 16, 10]]);
}

# Same checks as test_convrnn, but for a convolutional GRU cell.
sub test_convgru
{
    my $cell = mx->rnn->ConvGRUCell(input_shape => [1, 3, 16, 10], num_hidden=>10,
                                    h2h_kernel=>[3, 3], h2h_dilate=>[1, 1],
                                    i2h_kernel=>[3, 3], i2h_stride=>[1, 1],
                                    i2h_pad=>[1, 1], i2h_dilate=>[1, 1],
                                    prefix=>'rnn_', forget_bias => 1);
    my $inputs = [map { mx->sym->Variable("rnn_t${_}_data") } 0..2];
    my ($outputs) = $cell->unroll(3, inputs => $inputs);
    $outputs = mx->sym->Group($outputs);
    is_deeply(
        [sort keys %{ $cell->params->_params }],
        ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
    );
    is_deeply($outputs->list_outputs(), ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']);
    my (undef, $outs) = $outputs->infer_shape(rnn_t0_data=>[1, 3, 16, 10], rnn_t1_data=>[1, 3, 16, 10], rnn_t2_data=>[1, 3, 16, 10]);
    is_deeply($outs, [[1, 10, 16, 10], [1, 10, 16, 10], [1, 10, 16, 10]]);
}

test_rnn();
test_lstm();
test_lstm_forget_bias();
test_gru();
test_residual();
test_residual_bidirectional();
test_stack();
test_bidirectional();
test_unfuse();
test_zoneout();
test_convrnn();
test_convlstm();
test_convgru();