# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

use strict;
use warnings;
use AI::MXNet qw(mx);
use AI::MXNet::AutoGrad qw(autograd);
use AI::MXNet::TestUtils qw(same almost_equal rand_ndarray);
use AI::MXNet::Base qw(:DEFAULT pones);
use Test::More tests => 246;
$ENV{MXNET_STORAGE_FALLBACK_LOG_VERBOSE} = 0;
$ENV{MXNET_SUBGRAPH_VERBOSE} = 0;

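# Helper: differentiates $kwargs->{func} via autograd->grad_and_loss over the
# arguments selected by $kwargs->{argnum}, then checks that the recorded output
# matches a plain forward call and that each gradient matches the reference
# gradients returned by $kwargs->{grad_func}.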
sub autograd_assert
{
    my $kwargs = {};
    if(ref $_[-1] eq 'HASH') { $kwargs = pop(@_) };
    my @args = @_;
    my $func   = $kwargs->{func};
    my $grad_f = $kwargs->{grad_func};
    my $argnum = $kwargs->{argnum};
    my $grad_func = autograd->grad_and_loss($func, $argnum);
    my ($grad_vals, $output) = $grad_func->(@args);
    my $res = $func->(@args);
    ok(same($output->aspdl, $res->aspdl));
    my $grad_res = $grad_f->(@args);
    ok(@$grad_vals == @$grad_res);
    for(zip($grad_vals, $grad_res)) {
        my ($a, $b) = @$_;
        ok(same($a->aspdl, $b->aspdl));
    }
}

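# Gradients of simple unary functions (exp(x), x/2, x**2) against hand-written
# reference gradients, on a dense uniform array and its row_sparse/csr/default views.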
sub test_unary_func
{
    my $check_unary_func = sub {
        my ($x) = @_;
        my $f_exp       = sub { $_[0]->exp };
        my $f_exp_grad  = sub { [$_[0]->exp] };
        autograd_assert($x, { func => $f_exp, grad_func => $f_exp_grad });
        my $f_half      = sub { $_[0]/2 };
        my $f_half_grad = sub { [mx->nd->ones($_[0]->shape) * 0.5] };
        autograd_assert($x, { func => $f_half, grad_func => $f_half_grad });
        my $f_square    = sub { $_[0]**2 };
        my $f_square_grad = sub { [2*$_[0]] };
        autograd_assert($x, { func => $f_square, grad_func => $f_square_grad });
    };
    my $uniform = mx->nd->uniform(shape=>[4, 5]);
    $check_unary_func->($uniform);
    my $stypes = ['row_sparse', 'csr', 'default'];
    for my $stype (@$stypes)
    {
        $check_unary_func->($uniform->tostype($stype));
    }
}

test_unary_func();

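# Gradients of binary functions (x+y, x*y, x+x*y) against hand-written reference
# gradients, across all pairings of row_sparse, csr and default storage types.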
sub test_binary_func
{
    my $check_binary_func = sub {
        my ($x, $y) = @_;
        my $f_add      = sub { $_[0]+$_[1] };
        my $f_add_grad = sub { [map { mx->nd->ones($_->shape) } @_] };
        autograd_assert($x, $y, { func => $f_add, grad_func => $f_add_grad });
        my $f_mul      = sub { $_[0]*$_[1] };
        my $f_mul_grad = sub { [reverse(@_)] };
        autograd_assert($x, $y, { func => $f_mul, grad_func => $f_mul_grad });
        my $f_compose  = sub { $_[0]+$_[0]*$_[1] };
        my $f_compose_grad = sub { [mx->nd->ones($_[0]->shape) + $y, $x] };
        autograd_assert($x, $y, { func => $f_compose, grad_func => $f_compose_grad });
    };
    my $uniform_x = mx->nd->uniform(shape=>[4, 5]);
    my $uniform_y = mx->nd->uniform(shape=>[4, 5]);
    $check_binary_func->($uniform_x, $uniform_y);
    my $stypes = ['row_sparse', 'csr', 'default'];
    for my $stype_x (@$stypes)
    {
        for my $stype_y (@$stypes)
        {
            my $x = $uniform_x->tostype($stype_x);
            my $y = $uniform_y->tostype($stype_y);
            $check_binary_func->($x, $y);
        }
    }
}

test_binary_func();

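# Smoke test: runs grad_and_loss through mx->nd->FullyConnected (a stateful
# operator); there are no assertions, it only has to complete without error.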
sub test_operator_with_state
{
    my $f_fc = sub {
        my ($a, $b, $weight, $bias) = @_;
        my $x = $a*$b;
        my $fc = mx->nd->FullyConnected(
            $x, $weight, $bias, num_hidden=>32);
        return $fc;
    };

    my $a = mx->nd->uniform(shape=>[64, 50]);
    my $b = mx->nd->uniform(shape=>[64, 50]);
    my $weight = mx->nd->uniform(shape=>[32, 50]);
    my $bias = mx->nd->uniform(shape=>[32]);

    my $grad_func = autograd->grad_and_loss($f_fc);
    my ($grad_vals, $outputs) = $grad_func->($a, $b, $weight, $bias);
}

test_operator_with_state();

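# The argnum option limits differentiation to the listed argument positions,
# so the trailing $mode flag is passed through without being differentiated.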
sub test_argnum
{
    my $f_with_mode = sub {
        my ($a, $b, $mode) = @_;
        if($mode)
        {
            return $a+$b;
        }
        else
        {
            return $a*$b;
        }
    };
    my $a = mx->nd->uniform(shape=>[3, 2]);
    my $b = mx->nd->uniform(shape=>[3, 2]);
    my $f_add_grad = sub { [map { mx->nd->ones($_->shape) } @_[0,1]] };
    my $f_mul_grad = sub { [reverse(@_[0,1])] };
    autograd_assert($a, $b, 1,
        { argnum=>[0, 1], func=>$f_with_mode, grad_func=>$f_add_grad });
    autograd_assert($a, $b, 0,
        { argnum=>[0, 1], func=>$f_with_mode, grad_func=>$f_mul_grad });
}

test_argnum();

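# Dropout should perturb its input while recording (training mode) and act as
# the identity inside autograd->pause.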
sub test_training
{
    my $x = mx->nd->ones([10, 10]);
    autograd->record(sub {
        my $y = mx->nd->Dropout($x, p=>0.5);
        ok(not ($y->aspdl == $x->aspdl)->all);
        autograd->pause(sub {
            my $y = mx->nd->Dropout($x, p=>0.5);
            ok(($y->aspdl == $x->aspdl)->all);
        });
    });
}

test_training();

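# Head gradients passed to autograd->backward land in the matching rows of $dx;
# the undefined $da falls back to a head gradient of ones.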
sub test_out_grads
{
    my $x = mx->nd->ones([3, 5]);
    my $dx = mx->nd->zeros_like($x);
    autograd->mark_variables([$x], [$dx]);
    my $da;
    my $db = mx->nd->array([1,2,3,4,5]);
    my $dc = mx->nd->array([5,4,3,2,1]);

    autograd->record(sub {
        my ($a, $b, $c) = @{ $x };
        autograd->backward([$a, $b, $c], head_grads => [$da, $db, $dc]);
    });
    ok(($dx->aspdl == pdl(
        [[1,1,1,1,1],
         [1,2,3,4,5],
         [5,4,3,2,1]]))->all);
}

test_out_grads();

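# detach() cuts the recorded graph: no gradient reaches $x through the detached
# branch, and the _fresh_grad flags reflect which variables received gradients.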
sub test_detach_updated_grad
{
    my $x = mx->nd->ones([2, 2]);
    my $dx = mx->nd->zeros_like($x);
    my $y = mx->nd->ones_like($x);
    my $dy = mx->nd->zeros_like($x);
    autograd->mark_variables([$x, $y], [$dx, $dy]);
    ok($x->_fresh_grad == 0);
    ok($y->_fresh_grad == 0);

    autograd->record(sub {
        my $x2 = $x + 2;
        my $y2  = $x2 + $y;
        $y2->backward();
    });
    ok(($dx->aspdl == 1)->all);
    ok($x->_fresh_grad == 1);
    ok($y->_fresh_grad == 1);

    $dx .= 0;
    $x->_fresh_grad(0);
    $y->_fresh_grad(0);
    ok($x->_fresh_grad == 0);
    ok($y->_fresh_grad == 0);

    autograd->record(sub {
        my $x2 = $x + 2;
        $x2 = $x2->detach;
        my $y2  = $x2 + $y;
        $y2->backward();
    });
    ok(($dx->aspdl == 0)->all);
    ok($x->_fresh_grad == 0);
    ok($y->_fresh_grad == 1);
}

test_detach_updated_grad();

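# With grad_reqs=>'add' gradients accumulate across backward calls; a second
# backward needs retain_graph=>1, and repeating backward without it should die.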
sub test_retain_grad
{
    my $x = mx->nd->ones([2, 2]);
    my $dx = mx->nd->zeros([2, 2]);
    autograd->mark_variables([$x], [$dx], grad_reqs=>'add');
    autograd->record(sub {
        my $y = $x + 1;
        $y->backward(retain_graph=>0);
    });
    ok(($dx->aspdl == 1)->all);

    $dx .= 0;
    autograd->record(sub {
        my $y = $x + 1;
        $y->backward(retain_graph=>1);
        $y->backward(retain_graph=>0);
    });
    ok(($dx->aspdl == 2)->all);
    no warnings;
    open(CPERR, ">&STDERR");
    open(STDERR, ">/dev/null");
    eval {
        autograd->record(sub {
            my $y = $x + 1;
            $y->backward();
            $y->backward();
        });
    };
    open(STDERR, ">&CPERR");
    ok($@);
}

test_retain_grad();

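# attach_grad() creates a gradient buffer that backward() fills in, for dense
# as well as row_sparse and csr arrays; intermediate results have no grad.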
sub test_attach_grad
{
    my $check_attach_grad = sub {
        my ($x) = @_;
        ok(not defined $x->grad);
        $x->attach_grad();
        autograd->record(sub {
            my $y = $x * 2;
            ok(not defined $y->grad);
            $y->backward;
        });
        ok(($x->grad->aspdl == 2)->all);
    };
    my $zeros = mx->nd->zeros([10, 10]);
    $check_attach_grad->($zeros);
    my @stypes = ('default', 'row_sparse', 'csr');
    for my $stype (@stypes)
    {
        my $x = $zeros->tostype($stype);
        $check_attach_grad->($x);
    }
}

test_attach_grad();

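# Interaction of recording and training modes: is_recording/is_training flags,
# Dropout behaviour under record/train_mode/predict_mode, and per-call
# overrides via backward(train_mode=>...).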
sub test_is_train
{
    my $x = mx->nd->ones([10, 10]);
    $x->attach_grad();
    autograd->record(sub {
        ok(autograd->is_recording());
        ok(autograd->is_training());
        my $y = mx->nd->Dropout($x, p=>0.5);
        ok($y->aspdl->max == 2 and $y->aspdl->min == 0);
        $y->backward();
        ok(($x->grad->aspdl == $y->aspdl)->all);
        autograd->predict_mode(sub {
            ok(autograd->is_recording());
            ok(not autograd->is_training());
            my $y = mx->nd->Dropout($x, p=>0.5);
            ok(($y->aspdl == $x->aspdl)->all);
            $y->backward(train_mode=>0);
            ok(($x->grad->aspdl == $x->aspdl)->all);
        });
    }, train_mode => 1);

    autograd->record(sub {
        ok(autograd->is_recording());
        ok(not autograd->is_training());
        my $y = mx->nd->Dropout($x, p=>0.5);
        ok(($y->aspdl == $x->aspdl)->all);
        $y->backward(train_mode=>0);
        ok(($x->grad->aspdl == $x->aspdl)->all);

        autograd->train_mode(sub {
            ok(autograd->is_recording);
            ok(autograd->is_training);
            my $y = mx->nd->Dropout($x, p=>0.5);
            ok($y->aspdl->max == 2 and $y->aspdl->min == 0);
            $y->backward;
            ok(($x->grad->aspdl == $y->aspdl)->all);
        });
    }, train_mode => 0);

    ok(not autograd->is_recording);
    ok(not autograd->is_training);
    my $y = mx->nd->Dropout($x, p=>0.5);
    ok(($y->aspdl == $x->aspdl)->all);

    autograd->train_mode(sub {
        ok(not autograd->is_recording);
        ok(autograd->is_training);
        my $y = mx->nd->Dropout($x, p=>0.5);
        ok($y->aspdl->max == 2 and $y->aspdl->min == 0);
    });
}

test_is_train();

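# get_symbol() reconstructs the recorded graph; its argument list has one entry
# per input NDArray used in the computation.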
sub test_get_symbol
{
    my $x = mx->nd->ones([1]);
    $x->attach_grad;
    my $y;
    autograd->record(sub {
        $y = $x*$x + 2*$x - 1;
    });
    ok(@{ autograd->get_symbol($y)->list_arguments } == 1);

    my $z = mx->nd->ones([1]);
    $z->attach_grad;
    autograd->record(sub {
        $y = $x*$x + 2*$z - 1;
    });
    ok(@{ autograd->get_symbol($y)->list_arguments } == 2);
}

test_get_symbol();

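# Higher-order gradients with create_graph=>1: for z = exp(x) + x at x = 1,
# dz/dx = exp(1) + 1 ~ 3.71828, and differentiating that again gives exp(1) ~ 2.71828.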
sub test_gradient
{
    my $x = mx->nd->ones([1]);
    $x->attach_grad;
    my $z;
    mx->autograd->record(sub {
        $z = mx->nd->elemwise_add($x->exp, $x);
    });
    my $dx = mx->autograd->grad($z, $x, create_graph=>1);
    ok(abs($dx->asscalar - 3.71828175) < 1e-7);
    $dx->backward;
    ok(abs($x->grad->asscalar - 2.71828175) < 1e-7);
}

test_gradient();

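# attach_grad(stype=>...) controls the storage type of the gradient and
# defaults to the array's own stype; detach() preserves the array's stype.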
sub test_grad_with_stype
{
    my $check_grad_with_stype = sub {
        my ($array_stype, $grad_stype, $expected_stype) = @_;
        my $x = mx->nd->zeros([1, 1], stype=>$array_stype);
        $x->attach_grad(stype=>$grad_stype);
        # check grad attached
        ok($x->grad->stype eq $expected_stype);
        my $y = $x->detach();
        # check array detached
        ok($y->stype eq $array_stype);
    };
    my @stypes = ('default', 'csr', 'row_sparse');
    for my $stype (@stypes)
    {
        # check the default stype of the gradient (same as the array stype)
        $check_grad_with_stype->($stype, undef, $stype);
        for my $grad_stype (@stypes)
        {
            # check the stype of the gradient when provided
            $check_grad_with_stype->($stype, $grad_stype, $grad_stype);
        }
    }
}

test_grad_with_stype();

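# For y = dot($lhs, $rhs) with a csr lhs, the gradient of $rhs should be
# row_sparse and equal to lhs->transpose x ones (the default head gradient),
# whether $rhs itself is row_sparse or dense with a row_sparse grad.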
sub test_sparse_dot_grad
{
    my $check_sparse_dot_grad = sub {
        my ($rhs) = @_;
        my $lhs = rand_ndarray([2, 8], 'csr');
        my $y;
        mx->autograd->record(sub {
            $y = mx->nd->dot($lhs, $rhs);
        });
        $y->backward();
        my $grad = $rhs->grad;
        my $grad_pdl = $lhs->aspdl->transpose x pones($rhs->shape->[1], $lhs->shape->[0]);
        ok($grad->stype eq 'row_sparse');
        ok(almost_equal($grad->aspdl, $grad_pdl));
    };

    # check grad with row_sparse weight
    my $shape = [8, 3];
    my $rsp = mx->nd->ones($shape)->tostype('row_sparse');
    $rsp->attach_grad();
    $check_sparse_dot_grad->($rsp);

    # check grad with dense weight
    my $dns = mx->nd->ones($shape);
    $dns->attach_grad(stype=>'row_sparse');
    $check_sparse_dot_grad->($dns);
}

test_sparse_dot_grad();
427