# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import functools
import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
from mxnet.ndarray import zeros_like
from mxnet.autograd import *
from mxnet.test_utils import *
from common import setup_module, with_seed, teardown
from mxnet.test_utils import environment


def grad_and_loss(func, argnum=None):
    """Return a function that computes both the gradients of the arguments and the loss value.

    Parameters
    ----------
    func: a python function
        The forward (loss) function.
    argnum: an int or a list of int
        The index (or indices) of the arguments to calculate gradients for.

    Returns
    -------
    grad_and_loss_func: a python function
        A function that computes both the gradients of the arguments and the loss value.
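
    Examples
    --------
    >>> # illustrative sketch (example-only names): gradients and loss
    >>> # of x**2 evaluated at x = [1, 2, 3]
    >>> g_and_l = grad_and_loss(lambda x: x**2)
    >>> grads, loss = g_and_l(nd.array([1, 2, 3]))
    >>> grads[0].asnumpy()   # expected: array([2., 4., 6.], dtype=float32)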
    """
    @functools.wraps(func)
    def wrapped(*args):
        """Wrapped function."""
        variables = args
        if argnum is not None:
            argnum_ = argnum if isinstance(argnum, list) else [argnum]
            variables = [args[i] for i in argnum_]
        for x in variables:
            assert isinstance(x, NDArray), "type of autograd input should be NDArray."
        grads = [zeros_like(x) for x in variables]
        mark_variables(variables, grads)
        with record():
            outputs = func(*args)
        backward([outputs] if isinstance(outputs, NDArray) else outputs)
        return grads, outputs
    return wrapped

def grad(func, argnum=None):
    """Return a function that computes the gradients of the arguments.

    Parameters
    ----------
    func: a python function
        The forward (loss) function.
    argnum: an int or a list of int
        The index (or indices) of the arguments to calculate gradients for.

    Returns
    -------
    grad_func: a python function
        A function that computes the gradients of the arguments.

    Examples
    --------
    >>> # autograd supports dynamic graphs, which can change
    >>> # on every call
    >>> def func(x):
    ...     r = random.randint(0, 1)
    ...     if r % 2:
    ...         return x**2
    ...     else:
    ...         return x/3
    >>> # use `grad(func)` to get the gradient function
    >>> for x in range(10):
    ...     grad_func = grad(func)
    ...     inputs = nd.array([[1, 2, 3], [4, 5, 6]])
    ...     grad_vals = grad_func(inputs)
    """
    grad_with_loss_func = grad_and_loss(func, argnum)
    @functools.wraps(grad_with_loss_func)
    def wrapped(*args):
        return grad_with_loss_func(*args)[0]
    return wrapped

def autograd_assert(*args, **kwargs):
    """Check that autograd's gradients of `func` on `args` match the reference `grad_func`."""
    func = kwargs["func"]
    grad_f = kwargs["grad_func"]
    argnum = kwargs.get("argnum", None)

    grad_func = grad_and_loss(func, argnum)
    grad_vals, output = grad_func(*args)
    res = func(*args)
    assert same(output.asnumpy(), res.asnumpy())
    grad_res = grad_f(*args)
    assert len(grad_vals) == len(grad_res)
    for a, b in zip(grad_vals, grad_res):
        assert same(a.asnumpy(), b.asnumpy())

@with_seed()
def test_unary_func():
    def check_unary_func(x):
        f_exp         = lambda x: nd.exp(x)
        f_exp_grad    = lambda x: [nd.exp(x)]
        autograd_assert(x, func=f_exp, grad_func=f_exp_grad)
        f_half        = lambda x: x/2
        f_half_grad   = lambda x: [nd.ones(x.shape) * 0.5]
        autograd_assert(x, func=f_half, grad_func=f_half_grad)
        f_square      = lambda x: x**2
        f_square_grad = lambda x: [2*x]
        autograd_assert(x, func=f_square, grad_func=f_square_grad)
    uniform = nd.uniform(shape=(4, 5))
    stypes = ['default', 'row_sparse', 'csr']
    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
        for stype in stypes:
            check_unary_func(uniform.tostype(stype))

@with_seed()
def test_binary_func():
    def check_binary_func(x, y):
        f_add      = lambda x, y: x+y
        f_add_grad = lambda x, y: [nd.ones(x.shape), nd.ones(y.shape)]
        autograd_assert(x, y, func=f_add, grad_func=f_add_grad)
        f_mul      = lambda x, y: x*y
        f_mul_grad = lambda x, y: [y, x]
        autograd_assert(x, y, func=f_mul, grad_func=f_mul_grad)
        f_compose  = lambda x, y: x+x*y
        f_compose_grad = lambda x, y: [nd.ones(x.shape) + y, x]
        autograd_assert(x, y, func=f_compose, grad_func=f_compose_grad)
    uniform_x = nd.uniform(shape=(4, 5))
    uniform_y = nd.uniform(shape=(4, 5))
    stypes = ['default', 'row_sparse', 'csr']
    with environment('MXNET_STORAGE_FALLBACK_LOG_VERBOSE', '0'):
        for stype_x in stypes:
            for stype_y in stypes:
                x = uniform_x.tostype(stype_x)
                y = uniform_y.tostype(stype_y)
                check_binary_func(x, y)


@with_seed()
def test_operator_with_state():
    # Smoke test: gradients flow through an operator with parameters
    # (nd.FullyConnected with weight and bias) applied to a product of inputs.
    def f_fc(a, b, weight, bias):
        x = a*b
        fc = nd.FullyConnected(
            x, weight, bias, num_hidden=32)
        return fc

    a = nd.uniform(shape=(64, 50))
    b = nd.uniform(shape=(64, 50))
    weight = nd.uniform(shape=(32, 50))
    bias = nd.uniform(shape=(32, ))

    grad_func = grad_and_loss(f_fc)
    grad_vals, outputs = grad_func(a, b, weight, bias)
    # TODO: add assertions on grad_vals and outputs

@with_seed()
def test_argnum():
    def f_with_mode(a, b, mode):
        if mode:
            return a+b
        else:
            return a*b

    a = nd.uniform(shape=(3, 2))
    b = nd.uniform(shape=(3, 2))
    f_add_grad = lambda x, y, mode: [nd.ones(x.shape), nd.ones(y.shape)]
    f_mul_grad = lambda x, y, mode: [y, x]
    autograd_assert(a, b, True,
        argnum=[0, 1], func=f_with_mode, grad_func=f_add_grad)
    autograd_assert(a, b, False,
        argnum=[0, 1], func=f_with_mode, grad_func=f_mul_grad)


@with_seed()
def test_training():
    x = nd.ones((10, 10))
    with record():
        y = nd.Dropout(x, p=0.5)
        assert not (y.asnumpy() == x.asnumpy()).all()
        with pause():
            y = nd.Dropout(x, p=0.5)
            assert (y.asnumpy() == x.asnumpy()).all()


@with_seed()
def test_out_grads():
    x = nd.ones((3, 5))
    dx = nd.zeros_like(x)
    mark_variables([x], [dx])
    da = None
    db = nd.array([1, 2, 3, 4, 5])
    dc = nd.array([5, 4, 3, 2, 1])

    with record():
        a, b, c = nd.split(x, axis=0, num_outputs=3, squeeze_axis=True)
        backward([a, b, c], [da, db, dc])

    assert (dx.asnumpy() == np.array(
        [[1, 1, 1, 1, 1],
         [1, 2, 3, 4, 5],
         [5, 4, 3, 2, 1]])).all()


@with_seed()
def test_detach_updated_grad():
    x = nd.ones((2, 2))
    dx = nd.zeros_like(x)
    y = nd.ones_like(x)
    dy = nd.zeros_like(x)
    mark_variables([x, y], [dx, dy])
    assert x._fresh_grad == False
    assert y._fresh_grad == False

    with record():
        x2 = x + 2
        y2 = x2 + y
        y2.backward()
    assert (dx.asnumpy() == 1).all()
    assert x._fresh_grad == True
    assert y._fresh_grad == True

    dx[:] = 0
    x._fresh_grad = False
    y._fresh_grad = False
    assert x._fresh_grad == False
    assert y._fresh_grad == False
    with record():
        x2 = x + 2
        # detaching x2 cuts the graph, so no gradient should flow back to x
        x2 = x2.detach()
        y2 = x2 + y
        y2.backward()
    assert (dx.asnumpy() == 0).all()
    assert y._fresh_grad == True
    assert x._fresh_grad == False


@with_seed()
def test_retain_grad():
    x = mx.nd.ones((2, 2))
    dx = mx.nd.zeros((2, 2))
    mark_variables([x], [dx], grad_reqs='add')
    with record():
        y = x + 1
        y.backward(retain_graph=False)
    assert (dx.asnumpy() == 1).all()

    dx[:] = 0
    with record():
        y = x + 1
        y.backward(retain_graph=True)
        y.backward(retain_graph=False)
    assert (dx.asnumpy() == 2).all()

    # The following sequence should throw an exception. We discard the expected
    # stderr stack trace output for this operation to keep the test logs clean.
    with discard_stderr():
        try:
            with record():
                y = x + 1
                y.backward()
                y.backward()
        except Exception:
            return

    raise AssertionError(
        "differentiating the same graph twice without retain_graph should fail")


@with_seed()
def test_attach_grad():
    def check_attach_grad(x):
        assert x.grad is None
        x.attach_grad()
        with record():
            y = x * 2
            assert y.grad is None
            y.backward(out_grad=mx.nd.ones_like(y).tostype(x.stype))
        assert (x.grad.asnumpy() == 2).all()
    zeros = mx.nd.zeros((10, 10))
    stypes = ['default', 'row_sparse', 'csr']
    for stype in stypes:
        x = zeros.tostype(stype)
        check_attach_grad(x)


@with_seed()
def test_is_train():
    x = mx.nd.ones((10, 10))
    x.attach_grad()
    with record(train_mode=True):
        assert is_recording()
        assert is_training()
        y = mx.nd.Dropout(x, p=0.5)
        assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0
        y.backward()
        assert (x.grad.asnumpy() == y.asnumpy()).all()

        with predict_mode():
            assert is_recording()
            assert not is_training()
            y = mx.nd.Dropout(x, p=0.5)
            assert (y.asnumpy() == x.asnumpy()).all()
            y.backward(train_mode=False)
            assert (x.grad.asnumpy() == x.asnumpy()).all()

    with record(train_mode=False):
        assert is_recording()
        assert not is_training()
        y = mx.nd.Dropout(x, p=0.5)
        assert (y.asnumpy() == x.asnumpy()).all()
        y.backward(train_mode=False)
        assert (x.grad.asnumpy() == x.asnumpy()).all()

        with train_mode():
            assert is_recording()
            assert is_training()
            y = mx.nd.Dropout(x, p=0.5)
            assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0
            y.backward()
            assert (x.grad.asnumpy() == y.asnumpy()).all()

    assert not is_recording()
    assert not is_training()
    y = mx.nd.Dropout(x, p=0.5)
    assert (y.asnumpy() == x.asnumpy()).all()

    with train_mode():
        assert not is_recording()
        assert is_training()
        y = mx.nd.Dropout(x, p=0.5)
        assert y.asnumpy().max() == 2 and y.asnumpy().min() == 0

@with_seed()
def test_function():
    class Func(Function):
        def forward(self, x, y):
            m = x / y
            n = x * y
            self.save_for_backward(x, y)
            return m, n

        def backward(self, dm, dn):
            x, y = self.saved_tensors
            dx = dm/y + dn*y
            dy = dn*x - dm * x / y / y
            return dx, dy

    f = Func()
    x = mx.nd.random.uniform(shape=(10,))
    x.attach_grad()
    y = mx.nd.random.uniform(shape=(10,))
    y.attach_grad()
    with record():
        m, n = f(x, y)
        backward([m, n])

    dx1 = x.grad.asnumpy()
    dy1 = y.grad.asnumpy()

    with record():
        backward([x/y, x*y])

    # Non-zero atol required, as exposed by seed 630179191
    atol = 1e-6
    assert_almost_equal(x.grad.asnumpy(), dx1, atol=atol)
    assert_almost_equal(y.grad.asnumpy(), dy1, atol=atol)


@with_seed()
def test_function1():
    class Foo(mx.autograd.Function):
        def __init__(self):
            super(Foo, self).__init__()

        def forward(self, X):
            return X + 1

        def backward(self, dY):
            return dY

    with mx.autograd.record():
        X = mx.nd.zeros((3, 4))
        # X.attach_grad()  # uncommenting this line works
        for i in range(5):
            f = Foo()
            X = f(X)
        X.wait_to_read()


@with_seed()
def test_get_symbol():
    x = mx.nd.ones((1,))
    x.attach_grad()
    with record():
        y = x*x + 2*x - 1
    assert len(get_symbol(y).list_arguments()) == 1

    z = mx.nd.ones((1,))
    z.attach_grad()
    with record():
        y = x*x + 2*z - 1
    assert len(get_symbol(y).list_arguments()) == 2

@with_seed()
def test_grad_with_stype():
    def check_grad_with_stype(array_stype, grad_stype, expected_stype):
        x = mx.nd.zeros((1, 1), stype=array_stype)
        x.attach_grad(stype=grad_stype)
        # check grad attached
        assert x.grad.stype == expected_stype
        y = x.detach()
        # check array detached
        assert y.stype == array_stype

    stypes = ['default', 'csr', 'row_sparse']
    for stype in stypes:
        # check the default stype of the gradient (same as the array stype)
        check_grad_with_stype(stype, None, stype)
        for grad_stype in stypes:
            # check the stype of the gradient when provided
            check_grad_with_stype(stype, grad_stype, grad_stype)

@with_seed()
def test_sparse_dot_grad():
    def check_sparse_dot_grad(rhs):
        lhs = rand_ndarray((2, 8), 'csr')
        with mx.autograd.record():
            y = mx.nd.dot(lhs, rhs)
        y.backward()
        grad = rhs.grad
        grad_np = np.dot(lhs.asnumpy().T, np.ones((lhs.shape[0], rhs.shape[1])))
        assert grad.stype == 'row_sparse'
        assert_almost_equal(grad.asnumpy(), grad_np)

    # check grad with row_sparse weight
    shape = (8, 3)
    rsp = mx.nd.ones(shape).tostype('row_sparse')
    rsp.attach_grad()
    check_sparse_dot_grad(rsp)

    # check grad with dense weight
    dns = mx.nd.ones(shape)
    dns.attach_grad(stype='row_sparse')
    check_sparse_dot_grad(dns)

@with_seed()
def test_gradient():
    x = mx.nd.ones((1,))
    x.attach_grad()

    with mx.autograd.record():
        z = mx.nd.elemwise_add(mx.nd.exp(x), x)
    # dz/dx = exp(x) + 1, i.e. e + 1 ~= 3.71828183 at x = 1
    dx, = mx.autograd.grad(z, [x], create_graph=True)
    assert abs(dx.asscalar() - 3.71828175) < 1e-7
    # with create_graph=True the first-order gradient is itself differentiable:
    # d(dx)/dx = exp(x) ~= 2.71828183 at x = 1
    dx.backward()
    assert abs(x.grad.asscalar() - 2.71828175) < 1e-7


if __name__ == "__main__":
    import nose
    nose.runmodule()