1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18use strict; 19use warnings; 20use AI::MXNet qw(mx); 21use AI::MXNet::AutoGrad qw(autograd); 22use AI::MXNet::TestUtils qw(same almost_equal rand_ndarray); 23use AI::MXNet::Base qw(:DEFAULT pones); 24use Test::More tests => 246; 25$ENV{MXNET_STORAGE_FALLBACK_LOG_VERBOSE} = 0; 26$ENV{MXNET_SUBGRAPH_VERBOSE} = 0; 27 28sub autograd_assert 29{ 30 my $kwargs = {}; 31 if(ref $_[-1] eq 'HASH') { $kwargs = pop(@_) }; 32 my @args = @_; 33 my $func = $kwargs->{func}; 34 my $grad_f = $kwargs->{grad_func}; 35 my $argnum = $kwargs->{argnum}; 36 my $grad_func = autograd->grad_and_loss($func, $argnum); 37 my ($grad_vals, $output) = $grad_func->(@args); 38 my $res = $func->(@args); 39 ok(same($output->aspdl, $res->aspdl)); 40 my $grad_res = $grad_f->(@args); 41 ok(@$grad_vals == @$grad_res); 42 for(zip($grad_vals, $grad_res)) { 43 my ($a, $b) = @$_; 44 ok(same($a->aspdl, $b->aspdl)); 45 } 46} 47 48sub test_unary_func 49{ 50 my $check_unary_func = sub { 51 my ($x) = @_; 52 my $f_exp = sub { $_[0]->exp }; 53 my $f_exp_grad = sub { [$_[0]->exp] }; 54 autograd_assert($x, { func => $f_exp, grad_func => $f_exp_grad }); 55 my $f_half = sub { $_[0]/2 }; 56 my $f_half_grad = sub { [mx->nd->ones($_[0]->shape) * 0.5] }; 57 autograd_assert($x, { func => $f_half, grad_func => $f_half_grad }); 58 my $f_square = sub { $_[0]**2 }; 59 my $f_square_grad = sub { [2*$_[0]] }; 60 autograd_assert($x, { func => $f_square, grad_func => $f_square_grad }); 61 }; 62 my $uniform = mx->nd->uniform(shape=>[4, 5]); 63 $check_unary_func->($uniform); 64 my $stypes = ['row_sparse', 'csr', 'default']; 65 for my $stype (@$stypes) 66 { 67 $check_unary_func->($uniform->tostype($stype)); 68 } 69} 70 71test_unary_func(); 72 73sub test_binary_func 74{ 75 my $check_binary_func = sub { 76 my ($x, $y) = @_; 77 my $f_add = sub { $_[0]+$_[1] }; 78 my $f_add_grad = sub { [map { mx->nd->ones($_->shape) } @_] }; 79 autograd_assert($x, $y, { func => $f_add, grad_func => $f_add_grad }); 80 my $f_mul = sub { $_[0]*$_[1] }; 81 my $f_mul_grad = sub { [reverse(@_)] }; 82 autograd_assert($x, $y, { func => $f_mul, grad_func => $f_mul_grad }); 83 my $f_compose = sub { $_[0]+$_[0]*$_[1] }; 84 my $f_compose_grad = sub { [mx->nd->ones($_[0]->shape) + $y, $x] }; 85 autograd_assert($x, $y, { func => $f_compose, grad_func => $f_compose_grad }); 86 }; 87 my $uniform_x = mx->nd->uniform(shape=>[4, 5]); 88 my $uniform_y = mx->nd->uniform(shape=>[4, 5]); 89 $check_binary_func->($uniform_x, $uniform_y); 90 my $stypes = ['row_sparse', 'csr', 'default']; 91 for my $stype_x (@$stypes) 92 { 93 for my $stype_y (@$stypes) 94 { 95 my $x = $uniform_x->tostype($stype_x); 96 my $y = $uniform_y->tostype($stype_y); 97 $check_binary_func->($x, $y); 98 } 99 } 100} 101 102test_binary_func(); 103 104sub test_operator_with_state 105{ 106 my $f_fc = sub { 107 my ($a, $b, $weight, $bias) = @_; 108 my $x = $a*$b; 109 my $fc = mx->nd->FullyConnected( 110 $x, $weight, $bias, num_hidden=>32); 111 return $fc; 112 }; 113 114 my $a = mx->nd->uniform(shape=>[64, 50]); 115 my $b = mx->nd->uniform(shape=>[64, 50]); 116 my $weight = mx->nd->uniform(shape=>[32, 50]); 117 my $bias = mx->nd->uniform(shape=>[32]); 118 119 my $grad_func = autograd->grad_and_loss($f_fc); 120 my ($grad_vals, $outputs) = $grad_func->($a, $b, $weight, $bias); 121} 122 123test_operator_with_state(); 124 125sub test_argnum 126{ 127 my $f_with_mode = sub { 128 my ($a, $b, $mode) = @_; 129 if($mode) 130 { 131 return $a+$b; 132 } 133 else 134 { 135 return $a*$b; 136 } 137 }; 138 my $a = mx->nd->uniform(shape=>[3, 2]); 139 my $b = mx->nd->uniform(shape=>[3, 2]); 140 my $f_add_grad = sub { [map { mx->nd->ones($_->shape) } @_[0,1]] }; 141 my $f_mul_grad = sub { [reverse(@_[0,1])] }; 142 autograd_assert($a, $b, 1, 143 { argnum=>[0, 1], func=>$f_with_mode, grad_func=>$f_add_grad }); 144 autograd_assert($a, $b, 0, 145 { argnum=>[0, 1], func=>$f_with_mode, grad_func=>$f_mul_grad }); 146} 147 148test_argnum(); 149 150sub test_training 151{ 152 my $x = mx->nd->ones([10, 10]); 153 autograd->record(sub { 154 my $y = mx->nd->Dropout($x, p=>0.5); 155 ok(not ($y->aspdl == $x->aspdl)->all); 156 autograd->pause(sub { 157 my $y = mx->nd->Dropout($x, p=>0.5); 158 ok(($y->aspdl == $x->aspdl)->all); 159 }); 160 }); 161} 162 163test_training(); 164 165sub test_out_grads 166{ 167 my $x = mx->nd->ones([3, 5]); 168 my $dx = mx->nd->zeros_like($x); 169 autograd->mark_variables([$x], [$dx]); 170 my $da; 171 my $db = mx->nd->array([1,2,3,4,5]); 172 my $dc = mx->nd->array([5,4,3,2,1]); 173 174 autograd->record(sub { 175 my ($a, $b, $c) = @{ $x }; 176 autograd->backward([$a, $b, $c], head_grads => [$da, $db, $dc]); 177 }); 178 ok(($dx->aspdl == pdl( 179 [[1,1,1,1,1], 180 [1,2,3,4,5], 181 [5,4,3,2,1]]))->all); 182} 183 184test_out_grads(); 185 186sub test_detach_updated_grad 187{ 188 my $x = mx->nd->ones([2, 2]); 189 my $dx = mx->nd->zeros_like($x); 190 my $y = mx->nd->ones_like($x); 191 my $dy = mx->nd->zeros_like($x); 192 autograd->mark_variables([$x, $y], [$dx, $dy]); 193 ok($x->_fresh_grad == 0); 194 ok($y->_fresh_grad == 0); 195 196 autograd->record(sub { 197 my $x2 = $x + 2; 198 my $y2 = $x2 + $y; 199 $y2->backward(); 200 }); 201 ok(($dx->aspdl == 1)->all); 202 ok($x->_fresh_grad == 1); 203 ok($y->_fresh_grad == 1); 204 205 $dx .= 0; 206 $x->_fresh_grad(0); 207 $y->_fresh_grad(0); 208 ok($x->_fresh_grad == 0); 209 ok($y->_fresh_grad == 0); 210 211 autograd->record(sub { 212 my $x2 = $x + 2; 213 $x2 = $x2->detach; 214 my $y2 = $x2 + $y; 215 $y2->backward(); 216 }); 217 ok(($dx->aspdl == 0)->all); 218 ok($x->_fresh_grad == 0); 219 ok($y->_fresh_grad == 1); 220} 221 222test_detach_updated_grad(); 223 224sub test_retain_grad 225{ 226 my $x = mx->nd->ones([2, 2]); 227 my $dx = mx->nd->zeros([2, 2]); 228 autograd->mark_variables([$x], [$dx], grad_reqs=>'add'); 229 autograd->record(sub { 230 my $y = $x + 1; 231 $y->backward(retain_graph=>0); 232 }); 233 ok(($dx->aspdl == 1)->all); 234 235 $dx .= 0; 236 autograd->record(sub { 237 my $y = $x + 1; 238 $y->backward(retain_graph=>1); 239 $y->backward(retain_graph=>0); 240 }); 241 ok(($dx->aspdl == 2)->all); 242 no warnings; 243 open(CPERR, ">&STDERR"); 244 open(STDERR, ">/dev/null"); 245 eval { 246 autograd->record(sub { 247 my $y = $x + 1; 248 $y->backward(); 249 $y->backward(); 250 }); 251 }; 252 open(STDERR, ">&CPERR"); 253 ok($@); 254} 255 256test_retain_grad(); 257 258sub test_attach_grad 259{ 260 my $check_attach_grad = sub { 261 my ($x) = @_; 262 ok(not defined $x->grad); 263 $x->attach_grad(); 264 autograd->record(sub { 265 my $y = $x * 2; 266 ok(not defined $y->grad); 267 $y->backward; 268 }); 269 ok(($x->grad->aspdl == 2)->all); 270 }; 271 my $zeros = mx->nd->zeros([10, 10]); 272 $check_attach_grad->($zeros); 273 my @stypes = ('default', 'row_sparse', 'csr'); 274 for my $stype (@stypes) 275 { 276 my $x = $zeros->tostype($stype); 277 $check_attach_grad->($x); 278 } 279} 280 281test_attach_grad(); 282 283sub test_is_train 284{ 285 my $x = mx->nd->ones([10, 10]); 286 $x->attach_grad(); 287 autograd->record(sub { 288 ok(autograd->is_recording()); 289 ok(autograd->is_training()); 290 my $y = mx->nd->Dropout($x, p=>0.5); 291 ok($y->aspdl->max == 2 and $y->aspdl->min == 0); 292 $y->backward(); 293 ok(($x->grad->aspdl == $y->aspdl)->all); 294 autograd->predict_mode(sub { 295 ok(autograd->is_recording()); 296 ok(not autograd->is_training()); 297 my $y = mx->nd->Dropout($x, p=>0.5); 298 ok(($y->aspdl == $x->aspdl)->all); 299 $y->backward(train_mode=>0); 300 ok(($x->grad->aspdl == $x->aspdl)->all); 301 }); 302 }, train_mode => 1); 303 304 autograd->record(sub { 305 ok(autograd->is_recording()); 306 ok(not autograd->is_training()); 307 my $y = mx->nd->Dropout($x, p=>0.5); 308 ok(($y->aspdl == $x->aspdl)->all); 309 $y->backward(train_mode=>0); 310 ok(($x->grad->aspdl == $x->aspdl)->all); 311 312 autograd->train_mode(sub { 313 ok(autograd->is_recording); 314 ok(autograd->is_training); 315 my $y = mx->nd->Dropout($x, p=>0.5); 316 ok($y->aspdl->max == 2 and $y->aspdl->min == 0); 317 $y->backward; 318 ok(($x->grad->aspdl == $y->aspdl)->all); 319 }); 320 }, train_mode => 0); 321 322 ok(not autograd->is_recording); 323 ok(not autograd->is_training); 324 my $y = mx->nd->Dropout($x, p=>0.5); 325 ok(($y->aspdl == $x->aspdl)->all); 326 327 autograd->train_mode(sub { 328 ok(not autograd->is_recording); 329 ok(autograd->is_training); 330 my $y = mx->nd->Dropout($x, p=>0.5); 331 ok($y->aspdl->max == 2 and $y->aspdl->min == 0); 332 }); 333} 334 335test_is_train(); 336 337sub test_get_symbol 338{ 339 my $x = mx->nd->ones([1]); 340 $x->attach_grad; 341 my $y; 342 autograd->record(sub { 343 $y = $x*$x + 2*$x - 1; 344 }); 345 ok(@{ autograd->get_symbol($y)->list_arguments } == 1); 346 347 my $z = mx->nd->ones([1]); 348 $z->attach_grad; 349 autograd->record(sub { 350 $y = $x*$x + 2*$z - 1; 351 }); 352 ok(@{ autograd->get_symbol($y)->list_arguments } == 2); 353} 354 355test_get_symbol(); 356 357sub test_gradient 358{ 359 my $x = mx->nd->ones([1]); 360 $x->attach_grad; 361 my $z; 362 mx->autograd->record(sub { 363 $z = mx->nd->elemwise_add($x->exp, $x); 364 }); 365 my $dx = mx->autograd->grad($z, $x, create_graph=>1); 366 ok(abs($dx->asscalar - 3.71828175) < 1e-7); 367 $dx->backward; 368 ok(abs($x->grad->asscalar - 2.71828175) < 1e-7); 369} 370 371test_gradient(); 372 373sub test_grad_with_stype 374{ 375 my $check_grad_with_stype = sub { my ($array_stype, $grad_stype, $expected_stype) = @_; 376 my $x = mx->nd->zeros([1, 1], stype=>$array_stype); 377 $x->attach_grad(stype=>$grad_stype); 378 # check grad attached 379 ok($x->grad->stype eq $expected_stype); 380 my $y = $x->detach(); 381 # check array detached 382 ok($y->stype eq $array_stype); 383 }; 384 my @stypes = ('default', 'csr', 'row_sparse'); 385 for my $stype (@stypes) 386 { 387 # check the default stype of the gradient (same as the array stype) 388 $check_grad_with_stype->($stype, undef, $stype); 389 for my $grad_stype (@stypes) 390 { 391 # check the stype of the gradient when provided 392 $check_grad_with_stype->($stype, $grad_stype, $grad_stype); 393 } 394 } 395} 396 397test_grad_with_stype(); 398 399sub test_sparse_dot_grad 400{ 401 my $check_sparse_dot_grad = sub { my ($rhs) = @_; 402 my $lhs = rand_ndarray([2, 8], 'csr'); 403 my $y; 404 mx->autograd->record(sub { 405 $y = mx->nd->dot($lhs, $rhs); 406 }); 407 $y->backward(); 408 my $grad = $rhs->grad; 409 my $grad_pdl = $lhs->aspdl->transpose x pones($rhs->shape->[1], $lhs->shape->[0]); 410 ok($grad->stype eq 'row_sparse'); 411 ok(almost_equal($grad->aspdl, $grad_pdl)); 412 }; 413 414 # check grad with row_sparse weight 415 my $shape = [8, 3]; 416 my $rsp = mx->nd->ones($shape)->tostype('row_sparse'); 417 $rsp->attach_grad(); 418 $check_sparse_dot_grad->($rsp); 419 420 # check grad with dense weight 421 my $dns = mx->nd->ones($shape); 422 $dns->attach_grad(stype=>'row_sparse'); 423 $check_sparse_dot_grad->($dns); 424} 425 426test_sparse_dot_grad(); 427