1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18use strict;
19use warnings;
20use Test::More tests => 2285;
21use AI::MXNet qw(mx);
22use AI::MXNet::TestUtils qw(reldiff pdl_maximum pdl_minimum);
23use PDL;
24
# Bind a two-input symbol three different ways, run forward on each executor
# and compare the outputs with a PDL reference; then run backward on the first
# executor and compare both input gradients with a PDL reference.
#
# Arguments:
#   $uf     - reference op on PDLs: ($lhs, $rhs) -> $out. Also used to build
#             the symbol when $sf is undef.
#   $gf     - reference gradient on PDLs: ($out_grad, $lhs, $rhs)
#             -> ($lhs_grad, $rhs_grad).
#   $dim    - rank of the random shape used for any omitted input shape.
#   $sf     - optional symbol builder: ($lhs_sym, $rhs_sym) -> $sym.
#   $lshape - optional MXNet shape (array ref) for lhs.
#   $rshape - optional MXNet shape (array ref) for rhs.
#
# Emits six tests per call; the fixed plan in `use Test::More tests => ...`
# relies on that count.
sub check_bind_with_uniform
{
    my ($uf, $gf, $dim, $sf, $lshape, $rshape) = @_;
    # Every axis gets an extent in [1, int(1000**(1/$dim))], so the random
    # shape holds at most 1000 elements. It is drawn even when both shapes
    # are supplied by the caller.
    my $shape = (random($dim)*int(1000**(1.0/$dim))+1)->floor->unpdl;
    my $lhs = mx->symbol->Variable('lhs');
    my $rhs = mx->symbol->Variable('rhs');
    my $ret = defined $sf ? $sf->($lhs, $rhs) : $uf->($lhs, $rhs);

    is_deeply($ret->list_arguments(), ['lhs', 'rhs'], 'symbol arguments are (lhs, rhs)');
    $lshape //= $shape;
    $rshape //= $shape;

    # PDL orders dimensions opposite to MXNet shapes, hence the reverse().
    my $lhs_arr  = mx->nd->array(random(reverse @$lshape));
    my $rhs_arr  = mx->nd->array(random(reverse @$rshape));
    my $lhs_grad = mx->nd->empty($lshape);
    my $rhs_grad = mx->nd->empty($rshape);

    # 1) positional args with gradient buffers (also used for backward)
    my $executor = $ret->bind(
        ctx       => mx->Context('cpu'),
        args      => [$lhs_arr, $rhs_arr],
        args_grad => [$lhs_grad, $rhs_grad]
    );

    # 2) positional args, no gradient buffers
    my $exec3 = $ret->bind(
        ctx  => mx->Context('cpu'),
        args => [$lhs_arr, $rhs_arr]
    );

    # 3) args and gradient buffers passed by name
    my $exec4 = $ret->bind(
        ctx       => mx->Context('cpu'),
        args      => {'rhs' => $rhs_arr, 'lhs' => $lhs_arr},
        args_grad => {'lhs' => $lhs_grad, 'rhs' => $rhs_grad}
    );

    $executor->forward(1);
    $exec3->forward(1);
    $exec4->forward(1);
    my $out2 = $executor->outputs->[0]->aspdl;
    my $out1 = $uf->($lhs_arr->aspdl, $rhs_arr->aspdl);
    my $out3 = $exec3->outputs->[0]->aspdl;
    my $out4 = $exec4->outputs->[0]->aspdl;
    # cmp_ok (unlike ok) reports the actual relative difference on failure.
    cmp_ok(reldiff($out1, $out2), '<', 1e-6, 'forward matches reference (positional args, with grads)');
    cmp_ok(reldiff($out1, $out3), '<', 1e-6, 'forward matches reference (positional args, no grads)');
    cmp_ok(reldiff($out1, $out4), '<', 1e-6, 'forward matches reference (named args)');

    # test gradient: backpropagate a tensor of ones shaped like the output
    my $out_grad = mx->nd->ones([reverse @{$out2->shape->unpdl}]);
    my ($lhs_grad2, $rhs_grad2) = $gf->(
        $out_grad->aspdl,
        $lhs_arr->aspdl,
        $rhs_arr->aspdl
    );
    $executor->backward([$out_grad]);

    cmp_ok(reldiff($lhs_grad->aspdl, $lhs_grad2), '<', 1e-6, 'lhs gradient matches reference');
    cmp_ok(reldiff($rhs_grad->aspdl, $rhs_grad2), '<', 1e-6, 'rhs gradient matches reference');
    return;
}
89
# Seed the RNG, then run check_bind_with_uniform for the elementwise binary
# ops (+, -, *, /, maximum, minimum) at ranks 1..3, ten repeats each.
sub _check_bind_elementwise_ops
{
    srand(0);
    my $nrepeat = 9;
    my $maxdim  = 3;
    for my $repeat (0..$nrepeat)
    {
        for my $dim (1..$maxdim)
        {
            check_bind_with_uniform(sub { my ($x, $y) = @_; $x + $y },
                                    sub { my ($g) = @_; ($g, $g) },
                                    $dim);
            check_bind_with_uniform(sub { my ($x, $y) = @_; $x - $y },
                                    sub { my ($g) = @_; ($g, -$g) },
                                    $dim);
            check_bind_with_uniform(sub { my ($x, $y) = @_; $x * $y },
                                    sub { my ($g, $x, $y) = @_; ($g*$y, $g*$x) },
                                    $dim);
            check_bind_with_uniform(sub { my ($x, $y) = @_; $x / $y },
                                    sub { my ($g, $x, $y) = @_; ($g / $y, -$x * $g/ ($y**2)) },
                                    $dim);
            check_bind_with_uniform(sub { my ($x, $y) = @_; pdl_maximum($x, $y) },
                                    sub { my ($g, $x, $y) = @_; ($g * ($x>$y), $g * ($y>$x)) },
                                    $dim,
                                    sub { $_[0]->maximum($_[1]) });
            check_bind_with_uniform(sub { my ($x, $y) = @_; pdl_minimum($x, $y) },
                                    sub { my ($g, $x, $y) = @_; ($g * ($x<$y), $g * ($y<$x)) },
                                    $dim,
                                    sub { $_[0]->minimum($_[1]) });
        }
    }
    return;
}

# Run the elementwise bind checks.
#
#   $disable_bulk_exec - when true, MXNET_EXEC_BULK_FWD_THRESHOLD_TRAIN and
#                        MXNET_EXEC_BULK_BWD_TRAIN are set to 0 for the
#                        duration of the checks.
sub test_bind
{
    my ($disable_bulk_exec) = @_;
    if ($disable_bulk_exec)
    {
        # `local` puts each variable back when this block exits. A variable
        # that was unset beforehand is deleted again, and the restore also
        # happens if a check dies part-way through.
        local $ENV{MXNET_EXEC_BULK_FWD_THRESHOLD_TRAIN} = 0;
        local $ENV{MXNET_EXEC_BULK_BWD_TRAIN}           = 0;
        _check_bind_elementwise_ops();
    }
    else
    {
        _check_bind_elementwise_ops();
    }
    return;
}
136
137
# Check mx->symbol->dot against PDL's `x` operator:
#   - ten matrix products with MXNet shapes [m, k] x [k, n], where each extent
#     is drawn from [1, 500];
#   - ten vector-vector products with a shared length drawn from [1, 500].
# The `2` passed as $dim only sizes the random shape that
# check_bind_with_uniform always draws; the explicit shapes override it.
sub test_dot
{
    srand(0);
    my $nrepeat = 9;
    for my $repeat (0..$nrepeat)
    {
        my $shape = (random(3)*500+1)->floor->unpdl;
        check_bind_with_uniform(sub { my ($x, $y) = @_; $x x $y },
                                sub { my ($g, $x, $y) = @_; ($g x $y->transpose, $x->transpose x $g) },
                                2,
                                sub { mx->symbol->dot(@_) },
                                [@{$shape}[0, 1]],
                                [@{$shape}[1, 2]],
        );
    }
    for my $repeat (0..$nrepeat)
    {
        my $shape = (random(1)*500+1)->floor->unpdl;
        check_bind_with_uniform(sub { my ($x, $y) = @_; $x x $y->transpose },
                                sub { my ($g, $x, $y) = @_; ($g * $y, $g * $x) },
                                2,
                                sub { mx->symbol->dot(@_) },
                                [@{$shape}[0]],
                                [@{$shape}[0]],
        );
    }
    return;
}
166
# Exercise Executor->reshape on a FullyConnected layer with 4 hidden units.
# The data is all ones, the weights are all ones and arg_arrays->[2] is
# zeroed, so every output element is 4 (4 input features x 1 x 1).
# Checks:
#   - a smaller reshape (x: 5x4 -> 3x4) computes the same values and shares
#     output memory with the base executor;
#   - an up-sized reshape (x: 6x4) gets its own data array but shares the
#     weight array with the base executor.
sub test_reshape
{
    my $data_sym = mx->sym->Variable('x');
    my $fc_sym   = mx->sym->FullyConnected($data_sym, num_hidden=>4);
    my $base     = $fc_sym->simple_bind(ctx => mx->cpu(), shapes => { x=>[5,4] }, grad_req=>'null');
    my $all_equal = sub {
        my ($pdl, $value) = @_;
        return ($pdl == $value)->all;
    };

    # arg_arrays: 0 = data, 1 = weight, 2 = presumably the bias
    $base->arg_arrays->[0] .= 1;
    $base->arg_arrays->[1] .= mx->nd->ones([4,4]);
    $base->arg_arrays->[2] .= 0;

    my $child = $base->reshape({ x=>[3,4] });
    $child->forward(0);
    # the reshaped executor computes the expected values
    ok($all_equal->($child->outputs->[0]->aspdl, 4));
    # its output aliases the first 3 rows of the base executor's output
    ok($all_equal->($base->outputs->[0]->aspdl->slice('X', [0,2]), 4));
    # running the base executor leaves the reshaped output intact
    $base->forward(0);
    ok($all_equal->($child->outputs->[0]->aspdl, 4));

    $child = $base->reshape({ x=>[6,4] }, allow_up_sizing=>1);
    # up-sized: the data array is separate, so zeroing it leaves the base alone
    $child->arg_arrays->[0] .= 0;
    ok($all_equal->($base->arg_arrays->[0]->aspdl, 1));
    # ...while the weight array is still shared
    ok($all_equal->($child->arg_arrays->[1]->aspdl, 1));
}
191
# Driver: bind checks with bulk execution left as configured, then with it
# disabled; followed by the dot-product and reshape checks.
for my $disable_bulk_exec (0, 1)
{
    test_bind($disable_bulk_exec);
}
test_dot();
test_reshape();
196