# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import functools
import mxnet.ndarray as nd
from mxnet.ndarray import zeros_like
from mxnet.autograd import *
from mxnet.test_utils import *
from common import setup_module, with_seed, teardown
from mxnet.test_utils import environment


def grad_and_loss(func, argnum=None):
    """Return a function that computes both the gradients of the arguments and the loss value.

    Parameters
    ----------
    func: a python function
        The forward (loss) function.
    argnum: an int or a list of int
        The index (or indices) of the arguments to calculate gradients for.

    Returns
    -------
    grad_and_loss_func: a python function
        A function that computes both the gradients of the arguments and the loss value.
    """
    @functools.wraps(func)
    def wrapped(*args):
        """Wrapped function."""
        variables = args
        if argnum is not None:
            argnum_ = argnum if isinstance(argnum, list) else [argnum]
            variables = [args[i] for i in argnum_]
        for x in variables:
            assert isinstance(x, NDArray), "type of autograd input should be NDArray."
        grads = [zeros_like(x) for x in variables]
        mark_variables(variables, grads)
        with record():
            outputs = func(*args)
        backward([outputs] if isinstance(outputs, NDArray) else outputs)
        return grads, outputs
    return wrapped

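# A minimal usage sketch for `grad_and_loss` (illustrative only, not collected
# by nose; the sum-of-squares loss and the concrete input values below are
# assumptions chosen for demonstration, mirroring the doctest style used in
# `grad` further down).
def example_grad_and_loss_usage():
    # The gradient of sum(x**2) w.r.t. x is 2 * x.
    f_square_sum = lambda x: (x ** 2).sum()
    grad_and_loss_func = grad_and_loss(f_square_sum)
    inputs = nd.array([[1, 2, 3], [4, 5, 6]])
    # `grads` is a list with one entry per differentiated argument,
    # `loss` is the recorded forward output.
    grads, loss = grad_and_loss_func(inputs)
    return grads, loss
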
def grad(func, argnum=None):
    """Return a function that computes the gradients of the arguments.

    Parameters
    ----------
    func: a python function
        The forward (loss) function.
    argnum: an int or a list of int
        The index (or indices) of the arguments to calculate gradients for.

    Returns
    -------
    grad_func: a python function
        A function that computes the gradients of the arguments.

    Examples
    --------
    >>> # autograd supports dynamic graphs, which can change
    >>> # on every call
    >>> def func(x):
    >>>     r = random.randint(0, 1)
    >>>     if r % 2:
    >>>         return x**2
    >>>     else:
    >>>         return x/3
    >>> # use `grad(func)` to get the gradient function
    >>> for x in range(10):
    >>>     grad_func = grad(func)
    >>>     inputs = nd.array([[1, 2, 3], [4, 5, 6]])
    >>>     grad_vals = grad_func(inputs)
    """
    grad_with_loss_func = grad_and_loss(func, argnum)
    @functools.wraps(grad_with_loss_func)
    def wrapped(*args):
        return grad_with_loss_func(*args)[0]
    return wrapped

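# Runnable counterpart of the doctest in `grad` above (illustrative only, not
# collected by nose): the random branch makes the recorded graph differ from
# call to call, so `grad` is re-applied on every iteration.
def example_grad_dynamic_graph():
    import random

    def func(x):
        # Either branch may be recorded on a given call.
        if random.randint(0, 1):
            return x ** 2
        return x / 3

    inputs = nd.array([[1, 2, 3], [4, 5, 6]])
    results = []
    for _ in range(10):
        grad_func = grad(func)
        # The gradient is 2*x when the x**2 branch is taken, otherwise 1/3.
        results.append(grad_func(inputs))
    return results
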
def autograd_assert(*args, **kwargs):
    func = kwargs["func"]
    grad_f = kwargs["grad_func"]
    argnum = kwargs["argnum"] if 'argnum' in kwargs else None

    grad_func = grad_and_loss(func, argnum)
    grad_vals, output = grad_func(*args)
    res = func(*args)
    assert same(output.asnumpy(), res.asnumpy())
    grad_res = grad_f(*args)
    assert len(grad_vals) == len(grad_res)
    for a, b in zip(grad_vals, grad_res):
        assert same(a.asnumpy(), b.asnumpy())

@with_seed()
def test_unary_func():
    def check_unary_func(x):
        f_exp = lambda x: nd.exp(x)
        f_exp_grad = lambda x: [nd.exp(x)]
        autograd_assert(x, func=f_exp, grad_func=f_exp_grad)
        f_half = lambda x: x/2
        f_half_grad = lambda x: [nd.ones(x.shape) * 0.5]
        autograd_assert(x, func=f_half, grad_func=f_half_grad)
        f_square = lambda x: x**2
        f_square_grad = lambda x: [2*x]
        autograd_assert(x, func=f_square, grad_func=f_square_grad)
    uniform = nd.uniform(shape=(4, 5))
    stypes = ['default', 'row_sparse', 'csr']
    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
        for stype in stypes:
            check_unary_func(uniform.tostype(stype))

@with_seed()
def test_binary_func():
    def check_binary_func(x, y):
        f_add = lambda x, y: x+y
        f_add_grad = lambda x, y: [nd.ones(x.shape), nd.ones(y.shape)]
        autograd_assert(x, y, func=f_add, grad_func=f_add_grad)
        f_mul = lambda x, y: x*y
        f_mul_grad = lambda x, y: [y, x]
        autograd_assert(x, y, func=f_mul, grad_func=f_mul_grad)
        f_compose = lambda x, y: x+x*y
        f_compose_grad = lambda x, y: [nd.ones(x.shape) + y, x]
        autograd_assert(x, y, func=f_compose, grad_func=f_compose_grad)
    uniform_x = nd.uniform(shape=(4, 5))
    uniform_y = nd.uniform(shape=(4, 5))
    stypes = ['default', 'row_sparse', 'csr']
    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
        for stype_x in stypes:
            for stype_y in stypes:
                x = uniform_x.tostype(stype_x)
                y = uniform_y.tostype(stype_y)
                check_binary_func(x, y)


@with_seed()
def test_operator_with_state():
    def f_fc(a, b, weight, bias):
        x = a*b
        fc = nd.FullyConnected(
            x, weight, bias, num_hidden=32)
        return fc

    a = nd.uniform(shape=(64, 50))
    b = nd.uniform(shape=(64, 50))
    weight = nd.uniform(shape=(32, 50))
    bias = nd.uniform(shape=(32, ))

    grad_func = grad_and_loss(f_fc)
    grad_vals, outputs = grad_func(a, b, weight, bias)
    # (TODO) assert

@with_seed()
def test_argnum():
    def f_with_mode(a, b, mode):
        if mode:
            return a+b
        else:
            return a*b

    a = nd.uniform(shape=(3, 2))
    b = nd.uniform(shape=(3, 2))
    f_add_grad = lambda x, y, mode: [nd.ones(x.shape), nd.ones(y.shape)]
    f_mul_grad = lambda x, y, mode: [y, x]
    autograd_assert(a, b, True,
        argnum=[0, 1], func=f_with_mode, grad_func=f_add_grad)
    autograd_assert(a, b, False,
        argnum=[0, 1], func=f_with_mode, grad_func=f_mul_grad)


@with_seed()
def test_training():
    x = nd.ones((10, 10))
    with record():
        y = nd.Dropout(x, p=0.5)
        assert not (y.asnumpy() == x.asnumpy()).all()
        with pause():
            y = nd.Dropout(x, p=0.5)
            assert (y.asnumpy() == x.asnumpy()).all()


@with_seed()
def test_out_grads():
    x = nd.ones((3, 5))
    dx = nd.zeros_like(x)
    mark_variables([x], [dx])
    da = None
    db = nd.array([1,2,3,4,5])
    dc = nd.array([5,4,3,2,1])

    with record():
        a, b, c = nd.split(x, axis=0, num_outputs=3, squeeze_axis=True)
        backward([a, b, c], [da, db, dc])

    assert (dx.asnumpy() == np.array(
        [[1,1,1,1,1],
         [1,2,3,4,5],
         [5,4,3,2,1]])).all()


@with_seed()
def test_detach_updated_grad():
    x = nd.ones((2, 2))
    dx = nd.zeros_like(x)
    y = nd.ones_like(x)
    dy = nd.zeros_like(x)
    mark_variables([x, y], [dx, dy])
    assert x._fresh_grad == False
    assert y._fresh_grad == False

    with record():
        x2 = x + 2
        y2 = x2 + y
        y2.backward()
    assert (dx.asnumpy() == 1).all()
    assert x._fresh_grad == True
    assert y._fresh_grad == True

    dx[:] = 0
    x._fresh_grad = False
    y._fresh_grad = False
    assert x._fresh_grad == False
    assert y._fresh_grad == False
    with record():
        x2 = x + 2
        x2 = x2.detach()
        y2 = x2 + y
        y2.backward()
    assert (dx.asnumpy() == 0).all()
    assert y._fresh_grad == True
    assert x._fresh_grad == False


@with_seed()
def test_retain_grad():
    x = mx.nd.ones((2, 2))
    dx = mx.nd.zeros((2, 2))
    mark_variables([x], [dx], grad_reqs='add')
    with record():
        y = x + 1
        y.backward(retain_graph=False)
    assert (dx.asnumpy() == 1).all()

    dx[:] = 0
    with record():
        y = x + 1
        y.backward(retain_graph=True)
        y.backward(retain_graph=False)
    assert (dx.asnumpy() == 2).all()

    # The following sequence should throw an exception. We discard the expected
    # stderr stack trace output for this operation to keep the test logs clean.
    with discard_stderr():
        try:
            with record():
                y = x + 1
                y.backward()
                y.backward()
        except Exception:
            return

    raise AssertionError(
        "differentiating the same graph twice without retain_graph should fail")


@with_seed()
def test_attach_grad():
    def check_attach_grad(x):
        assert x.grad is None
        x.attach_grad()
        with record():
            y = x * 2
            assert y.grad is None
            y.backward(out_grad=mx.nd.ones_like(y).tostype(x.stype))
        assert (x.grad.asnumpy() == 2).all()
    zeros = mx.nd.zeros((10, 10))
    stypes = ['default', 'row_sparse', 'csr']
    for stype in stypes:
        x = zeros.tostype(stype)
        check_attach_grad(x)


@with_seed()
def test_is_train():
    x = mx.nd.ones((10, 10))
    x.attach_grad()
    with record(train_mode=True):
        assert is_recording()
        assert is_training()
        y = mx.nd.Dropout(x, p=0.5)
        assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0
        y.backward()
        assert (x.grad.asnumpy() == y.asnumpy()).all()

        with predict_mode():
            assert is_recording()
            assert not is_training()
            y = mx.nd.Dropout(x, p=0.5)
            assert (y.asnumpy() == x.asnumpy()).all()
            y.backward(train_mode=False)
            assert (x.grad.asnumpy() == x.asnumpy()).all()

    with record(train_mode=False):
        assert is_recording()
        assert not is_training()
        y = mx.nd.Dropout(x, p=0.5)
        assert (y.asnumpy() == x.asnumpy()).all()
        y.backward(train_mode=False)
        assert (x.grad.asnumpy() == x.asnumpy()).all()

        with train_mode():
            assert is_recording()
            assert is_training()
            y = mx.nd.Dropout(x, p=0.5)
            assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0
            y.backward()
            assert (x.grad.asnumpy() == y.asnumpy()).all()

    assert not is_recording()
    assert not is_training()
    y = mx.nd.Dropout(x, p=0.5)
    assert (y.asnumpy() == x.asnumpy()).all()

    with train_mode():
        assert not is_recording()
        assert is_training()
        y = mx.nd.Dropout(x, p=0.5)
        assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0

@with_seed()
def test_function():
    class func(Function):
        def forward(self, x, y):
            m = x / y
            n = x * y
            self.save_for_backward(x, y)
            return m, n

        def backward(self, dm, dn):
            x, y = self.saved_tensors
            dx = dm/y + dn*y
            dy = dn*x - dm * x / y / y
            return dx, dy

    f = func()
    x = mx.nd.random.uniform(shape=(10,))
    x.attach_grad()
    y = mx.nd.random.uniform(shape=(10,))
    y.attach_grad()
    with record():
        m, n = f(x, y)
        backward([m, n])

    dx1 = x.grad.asnumpy()
    dy1 = y.grad.asnumpy()

    with record():
        backward([x/y, x*y])

    # Non-zero atol required, as exposed by seed 630179191
    atol = 1e-6
    assert_almost_equal(x.grad.asnumpy(), dx1, atol=atol)
    assert_almost_equal(y.grad.asnumpy(), dy1, atol=atol)


@with_seed()
def test_function1():
    class Foo(mx.autograd.Function):
        def __init__(self):
            super(Foo, self).__init__()

        def forward(self, X):
            return X + 1

        def backward(self, dY):
            return dY

    with mx.autograd.record():
        X = mx.nd.zeros((3, 4))
        #X.attach_grad()  # uncommenting this line works
        for i in range(5):
            f = Foo()
            X = f(X)
    X.wait_to_read()


@with_seed()
def test_get_symbol():
    x = mx.nd.ones((1,))
    x.attach_grad()
    with record():
        y = x*x + 2*x - 1
    assert len(get_symbol(y).list_arguments()) == 1

    z = mx.nd.ones((1,))
    z.attach_grad()
    with record():
        y = x*x + 2*z - 1
    assert len(get_symbol(y).list_arguments()) == 2

@with_seed()
def test_grad_with_stype():
    def check_grad_with_stype(array_stype, grad_stype, expected_stype):
        x = mx.nd.zeros((1, 1), stype=array_stype)
        x.attach_grad(stype=grad_stype)
        # check grad attached
        assert x.grad.stype == expected_stype
        y = x.detach()
        # check array detached
        assert y.stype == array_stype

    stypes = ['default', 'csr', 'row_sparse']
    for stype in stypes:
        # check the default stype of the gradient (same as the array stype)
        check_grad_with_stype(stype, None, stype)
        for grad_stype in stypes:
            # check the stype of the gradient when provided
            check_grad_with_stype(stype, grad_stype, grad_stype)

@with_seed()
def test_sparse_dot_grad():
    def check_sparse_dot_grad(rhs):
        lhs = rand_ndarray((2, 8), 'csr')
        with mx.autograd.record():
            y = mx.nd.dot(lhs, rhs)
        y.backward()
        grad = rhs.grad
        grad_np = np.dot(lhs.asnumpy().T, np.ones((lhs.shape[0], rhs.shape[1])))
        assert grad.stype == 'row_sparse'
        assert_almost_equal(grad.asnumpy(), grad_np)

    # check grad with row_sparse weight
    shape = (8, 3)
    rsp = mx.nd.ones(shape).tostype('row_sparse')
    rsp.attach_grad()
    check_sparse_dot_grad(rsp)

    # check grad with dense weight
    dns = mx.nd.ones(shape)
    dns.attach_grad(stype='row_sparse')
    check_sparse_dot_grad(dns)

@with_seed()
def test_gradient():
    x = mx.nd.ones((1,))
    x.attach_grad()

    with mx.autograd.record():
        z = mx.nd.elemwise_add(mx.nd.exp(x), x)
    dx, = mx.autograd.grad(z, [x], create_graph=True)
    assert abs(dx.asscalar() - 3.71828175) < 1e-7
    dx.backward()
    assert abs(x.grad.asscalar() - 2.71828175) < 1e-7


if __name__ == "__main__":
    import nose
    nose.runmodule()