# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

use strict;
use warnings;
# Test plan: every check_bind_with_uniform() call emits exactly 6 tests.
#   test_bind (run twice): 2 x 10 repeats x 3 ranks x 6 ops = 360 calls -> 2160
#   test_dot:              2 loops x 10 repeats             =  20 calls ->  120
#   test_reshape:                                                       ->    5
#   total                                                               -> 2285
use Test::More tests => 2285;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(reldiff pdl_maximum pdl_minimum);
use PDL;

# Bind a two-input symbol three ways and check it against a PDL reference.
#
# The three bindings are:
#   - positional args with gradient buffers
#   - positional args without gradient buffers
#   - named args with named gradient buffers
# Each one runs a forward pass on random inputs, and every output is compared
# with the reference $uf. A backward pass with an all-ones head gradient is
# then compared with the reference gradient $gf.
#
# Parameters:
#   $uf     - reference forward: ($lhs_pdl, $rhs_pdl) -> $out_pdl. It also
#             builds the symbol when $sf is not given.
#   $gf     - reference gradient:
#             ($out_grad_pdl, $lhs_pdl, $rhs_pdl) -> ($lhs_grad, $rhs_grad)
#   $dim    - rank of the random shape used when $lshape/$rshape are omitted
#   $sf     - optional symbol builder: ($lhs_sym, $rhs_sym) -> $sym
#   $lshape - optional MXNet-order shape (array ref) for lhs
#   $rshape - optional MXNet-order shape (array ref) for rhs
#
# Emits exactly 6 tests; the plan at the top of the file relies on this.
sub check_bind_with_uniform
{
    my ($uf, $gf, $dim, $sf, $lshape, $rshape) = @_;

    # Random shape of rank $dim. Every axis is in 1..k, where
    # k = int(1000**(1/$dim)), so the tensor has at most 1000 elements.
    # This is computed (and consumes RNG state) even when explicit shapes
    # are passed.
    my $shape = (random($dim) * int(1000**(1.0 / $dim)) + 1)->floor->unpdl;
    my $lhs   = mx->symbol->Variable('lhs');
    my $rhs   = mx->symbol->Variable('rhs');

    # Without an explicit builder, the reference function is applied to the
    # symbols directly (e.g. the overloaded + - * / used by test_bind).
    my $make_sym = $sf // $uf;
    my $ret      = $make_sym->($lhs, $rhs);
    is_deeply($ret->list_arguments(), ['lhs', 'rhs'],
        'symbol arguments are (lhs, rhs)');

    $lshape //= $shape;
    $rshape //= $shape;

    # PDL lists dims in the reverse order of an MXNet shape.
    my $lhs_arr  = mx->nd->array(random(reverse @$lshape));
    my $rhs_arr  = mx->nd->array(random(reverse @$rshape));
    my $lhs_grad = mx->nd->empty($lshape);
    my $rhs_grad = mx->nd->empty($rshape);

    # Positional args with gradient buffers; also used for backward below.
    my $executor = $ret->bind(
        ctx       => mx->Context('cpu'),
        args      => [$lhs_arr, $rhs_arr],
        args_grad => [$lhs_grad, $rhs_grad],
    );
    # Positional args, no gradient buffers.
    my $exec3 = $ret->bind(
        ctx  => mx->Context('cpu'),
        args => [$lhs_arr, $rhs_arr],
    );
    # Args and gradient buffers matched by name.
    my $exec4 = $ret->bind(
        ctx       => mx->Context('cpu'),
        args      => { 'rhs' => $rhs_arr, 'lhs' => $lhs_arr },
        args_grad => { 'lhs' => $lhs_grad, 'rhs' => $rhs_grad },
    );

    $executor->forward(1);
    $exec3->forward(1);
    $exec4->forward(1);
    my $out2 = $executor->outputs->[0]->aspdl;
    my $out1 = $uf->($lhs_arr->aspdl, $rhs_arr->aspdl);
    my $out3 = $exec3->outputs->[0]->aspdl;
    my $out4 = $exec4->outputs->[0]->aspdl;
    ok(reldiff($out1, $out2) < 1e-6, 'forward: positional args with grads');
    ok(reldiff($out1, $out3) < 1e-6, 'forward: positional args, no grads');
    ok(reldiff($out1, $out4) < 1e-6, 'forward: named args with grads');

    # Backward with an all-ones head gradient shaped like the output
    # (the PDL dims are reversed back into MXNet order).
    my $out_grad = mx->nd->ones([reverse @{ $out2->shape->unpdl }]);
    my ($lhs_grad2, $rhs_grad2) = $gf->(
        $out_grad->aspdl,
        $lhs_arr->aspdl,
        $rhs_arr->aspdl,
    );
    $executor->backward([$out_grad]);

    ok(reldiff($lhs_grad->aspdl, $lhs_grad2) < 1e-6, 'backward: lhs gradient');
    ok(reldiff($rhs_grad->aspdl, $rhs_grad2) < 1e-6, 'backward: rhs gradient');
}

# Exercise binding for the elementwise ops + - * / maximum minimum at
# ranks 1..3. Each rank is repeated 10 times under a fixed seed.
#
# Parameters:
#   $disable_bulk_exec - when true, set MXNET_EXEC_BULK_FWD_THRESHOLD_TRAIN
#                        and MXNET_EXEC_BULK_BWD_TRAIN to 0 while the checks
#                        run.
sub test_bind
{
    my ($disable_bulk_exec) = @_;

    # local() restores the caller's values when this sub exits, including
    # when a check dies. Variables that were unset beforehand are removed
    # again instead of being left set to a made-up default.
    local $ENV{MXNET_EXEC_BULK_FWD_THRESHOLD_TRAIN} = 0 if $disable_bulk_exec;
    local $ENV{MXNET_EXEC_BULK_BWD_TRAIN}           = 0 if $disable_bulk_exec;

    # Each case is [reference forward, reference gradient, optional symbol
    # builder]. The list order is the call order, which fixes how the
    # seeded random stream is consumed.
    my @cases = (
        [ sub { my ($x, $y) = @_; $x + $y },
          sub { my ($g) = @_; ($g, $g) } ],
        [ sub { my ($x, $y) = @_; $x - $y },
          sub { my ($g) = @_; ($g, -$g) } ],
        [ sub { my ($x, $y) = @_; $x * $y },
          sub { my ($g, $x, $y) = @_; ($g * $y, $g * $x) } ],
        [ sub { my ($x, $y) = @_; $x / $y },
          sub { my ($g, $x, $y) = @_; ($g / $y, -$x * $g / ($y**2)) } ],
        [ sub { my ($x, $y) = @_; pdl_maximum($x, $y) },
          sub { my ($g, $x, $y) = @_; ($g * ($x > $y), $g * ($y > $x)) },
          sub { $_[0]->maximum($_[1]) } ],
        [ sub { my ($x, $y) = @_; pdl_minimum($x, $y) },
          sub { my ($g, $x, $y) = @_; ($g * ($x < $y), $g * ($y < $x)) },
          sub { $_[0]->minimum($_[1]) } ],
    );

    srand(0);
    my $nrepeat = 10;
    my $maxdim  = 3;
    for (1 .. $nrepeat)
    {
        for my $dim (1 .. $maxdim)
        {
            for my $case (@cases)
            {
                my ($uf, $gf, $sf) = @$case;
                check_bind_with_uniform($uf, $gf, $dim, $sf);
            }
        }
    }
}

# Exercise mx->symbol->dot with random sizes (each axis in 1..500) under a
# fixed seed:
#   - first loop: matrix x matrix
#   - second loop: vector . vector
# PDL's `x` operator is matrix multiplication.
sub test_dot
{
    srand(0);
    my $nrepeat = 10;

    # Matrix product: lhs is (s0, s1) and rhs is (s1, s2), in MXNet order.
    for (1 .. $nrepeat)
    {
        my $shape = (random(3) * 500 + 1)->floor->unpdl;
        check_bind_with_uniform(
            sub { my ($x, $y) = @_; $x x $y },
            sub { my ($g, $x, $y) = @_; ($g x $y->transpose, $x->transpose x $g) },
            2,
            sub { mx->symbol->dot(@_) },
            [@{$shape}[0, 1]],
            [@{$shape}[1, 2]],
        );
    }

    # Vector dot product: both operands are 1-D with the same length. The
    # PDL reference transposes rhs so that `x` yields the inner product.
    for (1 .. $nrepeat)
    {
        my $shape = (random(1) * 500 + 1)->floor->unpdl;
        check_bind_with_uniform(
            sub { my ($x, $y) = @_; $x x $y->transpose },
            sub { my ($g, $x, $y) = @_; ($g * $y, $g * $x) },
            2,
            sub { mx->symbol->dot(@_) },
            [$shape->[0]],
            [$shape->[0]],
        );
    }
}

# Check Executor::reshape on a 4-unit FullyConnected layer.
#
# Shrinking shares output memory with the base executor. Up-sizing with
# allow_up_sizing gets a fresh data array but keeps sharing the weight
# array. Emits 5 tests.
sub test_reshape
{
    my $x   = mx->sym->Variable('x');
    my $y   = mx->sym->FullyConnected($x, num_hidden => 4);
    my $exe = $y->simple_bind(
        ctx      => mx->cpu(),
        shapes   => { x => [5, 4] },
        grad_req => 'null',
    );

    # Index 0 is the data array and index 1 the weight array. Index 2 is
    # presumably the bias (TODO confirm). With these fills, the checks below
    # expect every output element to equal 4.
    $exe->arg_arrays->[0] .= 1;
    $exe->arg_arrays->[1] .= mx->nd->ones([4, 4]);
    $exe->arg_arrays->[2] .= 0;

    my $new_exe = $exe->reshape({ x => [3, 4] });
    $new_exe->forward(0);
    ok(($new_exe->outputs->[0]->aspdl == 4)->all,
        'reshaped executor forward');
    # Shared memory: the first three rows of the base executor's output
    # were written by the reshaped executor's forward pass.
    ok(($exe->outputs->[0]->aspdl->slice('X', [0, 2]) == 4)->all,
        'reshaped executor shares output memory with base executor');
    $exe->forward(0);
    ok(($new_exe->outputs->[0]->aspdl == 4)->all,
        'base executor forward is visible through reshaped executor');

    $new_exe = $exe->reshape({ x => [6, 4] }, allow_up_sizing => 1);
    $new_exe->arg_arrays->[0] .= 0;
    ok(($exe->arg_arrays->[0]->aspdl == 1)->all,
        'data array is not shared after up-sizing');
    ok(($new_exe->arg_arrays->[1]->aspdl == 1)->all,
        'weight array is shared after up-sizing');
}
# Driver: run the bind checks with bulk execution left alone (0), then with it
# disabled (1). After that, run the dot-product and executor-reshape checks.
test_bind($_) for 0, 1;
test_dot();
test_reshape();