1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import os
19import sys
20import tempfile
21import math
22import numpy as np
23import mxnet as mx
24
25curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
26sys.path.append(os.path.join(curr_path, '../python/unittest/'))
27
28from mxnet.test_utils import rand_ndarray, assert_almost_equal, rand_coord_2d, default_context, check_symbolic_forward, create_2d_tensor, get_identity_mat, get_identity_mat_batch
29from mxnet import gluon, nd
30from common import with_seed, assertRaises
31from mxnet.base import MXNetError
32import unittest
33
34# dimension constants
35MEDIUM_X = 10000
36VLARGE_X = 4300000000
37LARGE_X = 100000000
38SMALL_X = 100
39SMALL_Y = 50
40LARGE_SQ_X = 70000
41LARGE_SIZE = LARGE_X * SMALL_Y
42LARGE_TENSOR_SHAPE = 2**32
43RNN_LARGE_TENSOR = 2**28
44
45
46def test_nn():
47    def check_gluon_embedding():
48        m = gluon.nn.Embedding(SMALL_Y, MEDIUM_X)
49        m.initialize()
50        a = nd.zeros((MEDIUM_X, SMALL_Y))
51        b = m(a)
52        assert b.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X)
53        assert b.asnumpy().size == LARGE_SIZE
54
55    def check_fully_connected():
56        a = nd.ones(shape=(LARGE_X, SMALL_Y))
57        b = nd.ones(shape=(SMALL_Y, SMALL_Y))
58        c = nd.ones(shape=(b.shape[0],))
59
60        # w/o bias
61        res = nd.FullyConnected(a, b, num_hidden=b.shape[0], no_bias=True)
62        assert np.sum(res[-1].asnumpy() == a.shape[1]) == b.shape[0]
63
64        # w/ bias
65        res = nd.FullyConnected(a, b, c, num_hidden=b.shape[0], no_bias=False)
66        assert np.sum(res[-1].asnumpy() == a.shape[1] + 1) == b.shape[0]
67
68    def check_dense():
69        data = mx.nd.ones(shape=(50*1000*1000, 100))
70        linear = gluon.nn.Dense(100)
71        linear.initialize()
72        res = linear(data)
73        assert res.shape == (50000000, 100)
74
75    def check_softmax():
76        input_data = mx.nd.ones((SMALL_Y, LARGE_X))
77        for axis in [0, 1]:
78            true_output = np.full((SMALL_Y, LARGE_X), (1 / input_data.shape[axis]))
79            output = nd.softmax(input_data, axis=axis)
80            assert_almost_equal(output.asnumpy(), true_output, rtol=1e-5, atol=1e-5)
81
82    def check_softmax_cross_entropy():
83        # dtype of input data, mxnet cross entropy set explicitly to float64
84        # numpy implicitly takes care of double precision
85        batch_size = SMALL_Y
86        num_labels = LARGE_X
87        input_data = mx.nd.ones((batch_size, num_labels), dtype="float64")
88        input_label = mx.nd.zeros((batch_size,), dtype="float64")
89        true_softmax = np.full((batch_size, num_labels), (1 / num_labels))
90        # use 1/batch_size when softmax axis=0
91        # here 1/num_labels since softmax_cross_entropy uses default axis
92        # by default axis=1
93        np_one_hot_label = np.zeros((batch_size, num_labels))
94        np_one_hot_label[:, 0] = 1
95        true_softmax_cross_entropy = np.sum(-np.log(true_softmax) *
96                                            np_one_hot_label)
97        mx_softmax_cross_entropy = mx.nd.softmax_cross_entropy(input_data,
98                                                               input_label,
99                                                               dtype="float64")
100        assert_almost_equal(mx_softmax_cross_entropy.asnumpy(),
101                            true_softmax_cross_entropy, rtol=1e-3, atol=1e-5)
102
103    def check_softmax_output():
104        x = mx.sym.Variable('x')
105        label = mx.sym.Variable('label')
106        x_nd = mx.nd.ones((LARGE_X, SMALL_Y))
107        grad_x = mx.nd.zeros((LARGE_X, SMALL_Y))
108        label_nd = mx.nd.ones((LARGE_X))
109        sym = mx.sym.SoftmaxOutput(data=x, label=label, ignore_label=0,
110                                   use_ignore=False)
111
112        ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd},
113                      args_grad=None)
114        ex.forward(is_train=False)
115        softmax_out = ex.outputs[0][0].asnumpy()
116        expected_softmax_out = (1 / SMALL_Y) * mx.nd.ones((SMALL_Y)).asnumpy()
117        assert np.isclose(softmax_out, expected_softmax_out).all()
118
119        ex = sym.bind(ctx=default_context(), args={'x': x_nd, 'label': label_nd},
120                      args_grad={'x': grad_x})
121        ex.forward(is_train=True)
122        softmax_out = ex.outputs[0][0].asnumpy()
123        expected_softmax_out = (1 / SMALL_Y) * mx.nd.ones((SMALL_Y)).asnumpy()
124        assert np.isclose(softmax_out, expected_softmax_out).all()
125
126        ex.backward(is_train=True)
127        grad_out = ex.grad_arrays[0][0].asnumpy()
128        k = int(label_nd[0].asscalar())
129        expected_grad_out = np.zeros((SMALL_Y,))
130        expected_grad_out[k] = -1
131        assert np.isclose(grad_out - softmax_out, expected_grad_out).all()
132
133    def check_softmax_activation():
134        data = nd.random_normal(shape=(2**29, 2, 2, 2))
135        out = nd.random_normal(shape=(2**29, 2, 2, 2))
136
137        res = nd.SoftmaxActivation(data=data, out=out)
138
139        assert res.shape[0] == 536870912
140        assert res.shape[1] == 2
141        assert res.shape[2] == 2
142        assert res.shape[3] == 2
143
144    def np_softmax(x, axis=-1, temperature=1.0):
145        x = x - np.max(x, axis=axis, keepdims=True)
146        x = np.exp(x/temperature)
147        x /= np.sum(x, axis=axis, keepdims=True)
148        return x
149
150    @unittest.skip("log_softmax flaky, tracked at "
151                   "https://github.com/apache/incubator-mxnet/issues/17397")
152    def check_log_softmax():
153        ndim = 2
154        shape = (SMALL_Y, LARGE_X)
155        axis = np.random.randint(0, ndim)
156        data = np.random.uniform(-2, 2, size=shape)
157        sym = mx.sym.log_softmax(axis=axis-ndim)
158        check_symbolic_forward(sym, [data], [np.log(np_softmax(data, axis=axis)+1e-20)])
159
160    # TODO: correctness of prelu (currently flaky)
161    def check_leaky_relu():
162        a = -1*mx.nd.ones((LARGE_X, SMALL_Y))
163
164        def check_leaky():
165            res = mx.nd.LeakyReLU(a, act_type="leaky", slope=0.3)
166            assert_almost_equal(res[-1][-1].asnumpy(), 0.3*a[-1][-1].asnumpy(), atol=1e-3, rtol=1e-3)
167
168        def check_elu():
169            res = mx.nd.LeakyReLU(a, act_type="elu", slope=0.3)
170            assert_almost_equal(res[-1][-1].asnumpy(), 0.3*(np.exp(a[-1][-1].asnumpy())-1), atol=1e-3, rtol=1e-3)
171
172        def check_selu():
173            lam = 1.0507009873554804934193349852946
174            alpha = 1.6732632423543772848170429916717
175            res = mx.nd.LeakyReLU(a, act_type="selu")
176            assert_almost_equal(res[-1][-1].asnumpy(), (lam * alpha * (np.exp(a[-1][-1].asnumpy())-1)), atol=1e-3, rtol=1e-3)
177
178        def check_rrelu():
179            lower = 0.125
180            upper = 0.333999991
181            res = mx.nd.LeakyReLU(a, act_type="rrelu")
182            assert_almost_equal(res[0][-1][-1].asnumpy(), (lower + upper) / 2 * a[-1][-1].asnumpy(), atol=1e-3, rtol=1e-3)
183
184        check_leaky()
185        check_elu()
186        check_selu()
187        check_rrelu()
188
189    def check_pooling():
190        a = mx.nd.ones((MEDIUM_X, 200, SMALL_Y, SMALL_Y))
191
192        def check_avg_pooling():
193            res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='avg')
194            assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 1.0000001, atol=1e-3, rtol=1e-3)
195            assert res.shape[-1] == SMALL_Y - 5 + 1
196
197        def check_max_pooling():
198            res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='max')
199            assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 1., atol=1e-3, rtol=1e-3)
200            assert res.shape[-1] == SMALL_Y - 5 + 1
201
202        def check_sum_pooling():
203            res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='sum')
204            assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 25, atol=1e-3, rtol=1e-3)
205            assert res.shape[-1] == SMALL_Y - 5 + 1
206
207        def check_lp_pooling():
208            res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=2)
209            assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 5., atol=1e-3, rtol=1e-3)
210            assert res.shape[-1] == SMALL_Y - 5 + 1
211
212            res = mx.nd.Pooling(a, kernel=(5, 5), pool_type='lp', p_value=1)
213            assert_almost_equal(res[-1][-1][-1][-1].asnumpy(), 25., atol=1e-3, rtol=1e-3)
214            assert res.shape[-1] == SMALL_Y - 5 + 1
215
216        check_avg_pooling()
217        check_max_pooling()
218        check_sum_pooling()
219        check_lp_pooling()
220
221    def check_layer_norm():
222        dtype = np.float32
223        forward_check_eps = 1E-3
224        axis = 1
225        eps = 1E-5
226        in_shape = (LARGE_X, SMALL_Y)
227        ctx = mx.cpu()
228
229        def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
230            if axis < 0:
231                axis += data.ndim
232            broadcast_shape = [1 for _ in range(data.ndim)]
233            broadcast_shape[axis] = data.shape[axis]
234            mean = data.mean(axis=axis, keepdims=True).astype(dtype)
235            var = data.var(axis=axis, keepdims=True).astype(dtype)
236            std = np.sqrt(var + dtype(eps)).astype(dtype)
237            out = np.reshape(gamma, broadcast_shape) * (data - mean) / std + \
238                  np.reshape(beta, broadcast_shape)
239            return out
240        data = np.random.normal(0, 1, in_shape).astype(dtype)
241        gamma = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
242        beta = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
243        data_s = mx.symbol.Variable('data')
244        gamma_s = mx.symbol.Variable('gamma')
245        beta_s = mx.symbol.Variable('beta')
246        out_s = mx.symbol.LayerNorm(data=data_s, gamma=gamma_s, beta=beta_s,
247                                    axis=axis, eps=eps)
248        exe = out_s.simple_bind(ctx, data=in_shape)
249        exe.arg_dict['data'][:] = data
250        exe.arg_dict['gamma'][:] = gamma
251        exe.arg_dict['beta'][:] = beta
252        out_nd = exe.forward()[0]
253        out = npy_layer_norm(data, gamma, beta, axis, eps)
254        assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
255                            forward_check_eps)
256
257    # TODO: correctness of dropout
258    # currently only test for dropout to work
259    # since testing for correctness involves flakiness issue #14288
260    def check_dropout():
261        shape = (LARGE_X, SMALL_Y)
262        x = mx.sym.var('data')
263        y = mx.sym.Dropout(x, p=1, cudnn_off=True)
264        exe = y.simple_bind(ctx=default_context(), data=shape)
265        exe.arg_arrays[0][:] = 1
266        out = exe.forward(is_train=True)
267        nd.waitall()
268        assert out[0].shape == shape
269
270    def check_activation():
271        x = mx.nd.ones((LARGE_X, SMALL_Y))
272        check_x = -2
273        x[-1, -1] = check_x
274        # Hyperbolic tangent (tanh)
275        # y = (exp(x)-exp(-x))/(exp(x)+exp(-x))
276        y = mx.nd.Activation(x, act_type="tanh")
277        tanh_x = ((np.exp(check_x)-np.exp(-check_x))/(np.exp(check_x)+np.exp(-check_x)))
278        assert y[-1][-1] == np.float32(tanh_x)
279        # Recitified Linear Unit (relu)
280        # y = max(x,0)
281        y = mx.nd.Activation(x, act_type="relu")
282        assert y[-1][-1] == 0
283        # Sigmoid
284        # y = x/(1+abs(x))
285        y = mx.nd.Activation(x, act_type="sigmoid")
286        sigmoid_x = (1/(1+math.exp(-check_x)))
287        assert_almost_equal(y[-1][-1].asnumpy(), np.float32(sigmoid_x), atol=1e-3, rtol=1e-3)
288        # Soft Sign
289        # y = 1/(1+exp(-x))
290        y = mx.nd.Activation(x, act_type="softsign")
291        softsign_x = (check_x/(1+abs(check_x)))
292        assert y[-1][-1] == np.float32(softsign_x)
293
294
295    # TODO: correctness of batchnorm
296    # in future, we could test if mean, var of output
297    # matches target output's mean, var
298    def check_batchnorm():
299        def get_np_mean_var(data, running_mean, running_var, eps, use_global_status=True):
300            if not use_global_status:
301                # train mode, calculate the real mean and var
302                mean = np.mean(data, axis=(0, 2, 3))
303                mean_broad = np.expand_dims(mean, axis=0)
304                mean_broad = np.expand_dims(mean_broad, axis=2)
305                mean_broad = np.expand_dims(mean_broad, axis=3)
306                mean_broad = np.broadcast_to(mean_broad, data.shape)
307                var = np.square(data - mean_broad)
308                var = np.mean(var, axis=(0, 2, 3))
309            else:
310                # inference mode, use running_mean and running_var instead
311                mean = np.full((data.shape[1],), running_mean)
312                var = np.full((data.shape[1],), running_var)
313            # calculate the inverse of standard variance
314            invstdvar = 1. / np.sqrt(var + eps)
315            return mean, invstdvar
316        # Here use 4D input to cover mkldnn BN and non-mkldnn BN
317        shape = (1, 2, LARGE_X, SMALL_Y)
318        axis = 1  # default
319        eps = 1e-3
320        nch = shape[axis]
321        data = mx.nd.ones(shape=shape)
322        bn_gamma = mx.nd.random.uniform(shape=(nch,))
323        bn_beta = mx.nd.random.uniform(shape=(nch,))
324        bn_running_mean = mx.nd.zeros(nch)
325        bn_running_var = mx.nd.ones(nch)
326        output = mx.nd.BatchNorm(data, bn_gamma, bn_beta,
327                                 bn_running_mean, bn_running_var, output_mean_var=True)
328        assert output[0].shape == shape
329        mean, invstdvar = output[1], output[2]
330        np_mean, np_invstdvar = get_np_mean_var(data.asnumpy(), bn_running_mean.asnumpy(), bn_running_var.asnumpy(),
331                                                eps, use_global_status=True)
332        assert_almost_equal(mean.asnumpy(), np_mean)
333        assert_almost_equal(invstdvar.asnumpy(), np_invstdvar)
334
335    def check_relu():
336        def frelu(x):
337            return np.maximum(x, 0.0)
338
339        def frelu_grad(x):
340            return 1.0 * (x > 0.0)
341        shape = (SMALL_Y, LARGE_X)
342        x = mx.symbol.Variable("x")
343        y = mx.sym.relu(x)
344        xa = np.random.uniform(low=-1.0, high=1.0, size=shape)
345        eps = 1e-4
346        xa[abs(xa) < eps] = 1.0
347        ya = frelu(xa)
348        ga = frelu_grad(xa)
349        check_symbolic_forward(y, [xa], [ya])
350
351    def check_sigmoid():
352        def fsigmoid(a):
353            return np.divide(1.0, (1.0 + np.exp(-a)))
354        shape = (SMALL_Y, LARGE_X)
355        x = mx.symbol.Variable("x")
356        y = mx.sym.sigmoid(x)
357        xa = np.random.uniform(low=-1.0, high=1.0, size=shape)
358        ya = fsigmoid(xa)
359        check_symbolic_forward(y, [xa], [ya])
360
361    def check_linear_and_logistic_regression():
362        shape = (LARGE_X, SMALL_Y)
363
364        def check_regression(symbol, forward, backward, shape):
365            # init executor
366            data_s = mx.symbol.Variable('data')
367            label_s = mx.symbol.Variable('label')
368            out_s = symbol(data=data_s, label=label_s)
369            grad_req = {'data': 'write', 'label': 'null'}
370            exe = out_s.simple_bind(ctx=default_context(), data=shape, label=shape, grad_req=grad_req)
371            arg_map = dict(zip(out_s.list_arguments(), exe.arg_arrays))
372            grad_map = dict(zip(out_s.list_arguments(), exe.grad_arrays))
373            # init data
374            data = mx.random.uniform(-1, -1, shape)
375            arg_map["data"][:] = data
376            atol = 1e-5
377            density = 0.5
378            stype = 'default'
379            label = arg_map["label"]
380            label[:] = rand_ndarray(shape, stype, density=density)
381            exe.forward(is_train=True)
382            exe.backward()
383            np_out = forward(data.asnumpy())
384            out_grad = backward(np_out, label.asnumpy().reshape(np_out.shape)) / shape[1]
385            assert_almost_equal(exe.outputs[0].asnumpy(), np_out, atol=atol)
386            assert_almost_equal(grad_map["data"].asnumpy(), out_grad, atol=atol)
387
388        check_regression(mx.symbol.LogisticRegressionOutput,
389                         lambda x: 1.0 / (1.0 + np.exp(-x)),
390                         lambda x, y: x - y,
391                         shape)
392        check_regression(mx.symbol.LinearRegressionOutput,
393                         lambda x: x,
394                         lambda x, y: x - y,
395                         shape)
396
397    def check_l2_normalization():
398        x = nd.ones((2, LARGE_X*2))
399        x[0] = 3
400        x[1] = 4
401        # Channel Mode
402        z = x.reshape(1, 2, LARGE_X*2)
403        y = nd.L2Normalization(z, mode='channel')
404        assert y[0][0][0] == 0.6
405        assert y[0][0][-1] == 0.6
406        assert y[0][1][0] == 0.8
407        assert y[0][1][-1] == 0.8
408        # Instance Mode
409        z = x.T
410        y = nd.L2Normalization(z, mode='instance')
411        assert y[0][0] == 0.6
412        assert y[0][1] == 0.8
413        assert y[-1][0] == 0.6
414        assert y[-1][1] == 0.8
415        # Spatial Mode
416        z = z.reshape(1, 200000000, 2)
417        y = nd.L2Normalization(z, mode='spatial')
418        assert y[0][0][0] == 0.6
419        assert y[0][0][1] == 0.8
420        assert y[0][-1][0] == 0.6
421        assert y[0][-1][1] == 0.8
422
423    def check_instance_norm():
424        dtype = np.float32
425        forward_check_eps = 1E-3
426        axis = -1
427        eps = 1E-5
428        in_shape = (LARGE_X, 1, SMALL_Y)
429        ctx = mx.cpu()
430
431        # Implementation of instance normalization using numpy
432        def npy_instance_norm(data, gamma, beta, axis, eps=1E-5):
433            if axis < 0:
434                axis += data.ndim
435            broadcast_shape = [1 for _ in range(data.ndim)]
436            broadcast_shape[axis] = data.shape[axis]
437            mean = data.mean(axis=axis, keepdims=True).astype(dtype)
438            var = data.var(axis=axis, keepdims=True).astype(dtype)
439            std = np.sqrt(var + dtype(eps)).astype(dtype)
440            out = gamma * (data - mean) / std + \
441                  beta
442            return out
443        data = np.random.normal(0, 1, in_shape).astype(dtype)
444        gamma = np.random.normal(0, 1, (1,)).astype(dtype)
445        beta = np.random.normal(0, 1, (1,)).astype(dtype)
446        data_s = mx.symbol.Variable('data')
447        gamma_s = mx.symbol.Variable('gamma')
448        beta_s = mx.symbol.Variable('beta')
449        out_s = mx.symbol.InstanceNorm(data=data_s, gamma=gamma_s, beta=beta_s,
450                                       eps=eps)
451        exe = out_s.simple_bind(ctx, data=in_shape)
452        exe.arg_dict['data'][:] = data
453        exe.arg_dict['gamma'][:] = gamma
454        exe.arg_dict['beta'][:] = beta
455        out_nd = exe.forward()[0]
456        # Calls implementation of instance norm in numpy and compares the output
457        out = npy_instance_norm(data, gamma, beta, axis, eps)
458        assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps,
459                            forward_check_eps)
460
461    def check_col2im():
462        data = nd.random_normal(shape=(1, 2**30, 4))
463        output_size = (2, 2, 1)
464        kernel = (1, 1, 1)
465
466        res = nd.col2im(data=data, output_size=output_size, kernel=kernel)
467
468        assert res.shape[0] == 1
469        assert res.shape[1] == 1073741824
470        assert res.shape[2] == 2
471        assert res.shape[3] == 2
472        assert res.shape[4] == 1
473
474    def check_embedding():
475        data = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1))
476        weight = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1))
477        input_dim = LARGE_TENSOR_SHAPE
478        output_dim = 1
479
480        out = nd.Embedding(data=data, weight=weight, input_dim=input_dim, output_dim=output_dim)
481
482        assert out.shape[0] == LARGE_TENSOR_SHAPE
483        assert out.shape[1] == 1
484
485    def check_spatial_transformer():
486        data = nd.random_normal(shape=(2, 2**29, 1, 6))
487        loc = nd.random_normal(shape=(2, 6))
488        transform_type = 'affine'
489        sampler_type = 'bilinear'
490        target_shape = (2, 6)
491
492        res = nd.SpatialTransformer(data=data, loc=loc, transform_type=transform_type,
493                                    sampler_type=sampler_type, target_shape=target_shape)
494
495        assert res.shape[0] == 2
496        assert res.shape[1] == 536870912
497        assert res.shape[2] == 2
498        assert res.shape[3] == 6
499
500    def check_ravel():
501        data = nd.random_normal(shape=(2, LARGE_TENSOR_SHAPE))
502        shape = (2, 10)
503
504        out = nd.ravel_multi_index(data=data, shape=shape)
505
506        assert out.shape[0] == LARGE_TENSOR_SHAPE
507
508    def check_rnn():
509        data = nd.random_normal(shape=(RNN_LARGE_TENSOR, 4, 4))
510        parameters_relu_tanh = nd.random_normal(shape=(7,))
511        parameters_lstm = nd.random_normal(shape=(28,))
512        parameters_gru = nd.random_normal(shape=(21,))
513        state = nd.random_normal(shape=(1, 4, 1))
514        state_cell = nd.random_normal(shape=(1, 4, 1))
515        mode_relu = 'rnn_relu'
516        mode_tanh = 'rnn_tanh'
517        mode_lstm = 'lstm'
518        mode_gru = 'gru'
519        state_size = 1
520        num_layers = 1
521
522        out_relu = nd.RNN(data=data, parameters=parameters_relu_tanh, state=state, mode=mode_relu,
523                          state_size=state_size, num_layers=num_layers)
524
525        out_tanh = nd.RNN(data=data, parameters=parameters_relu_tanh, state=state, mode=mode_tanh,
526                          state_size=state_size, num_layers=num_layers)
527
528        out_lstm = nd.RNN(data=data, parameters=parameters_lstm, state=state, mode=mode_lstm,
529                          state_cell=state_cell, state_size=state_size, num_layers=num_layers)
530
531        out_gru = nd.RNN(data=data, parameters=parameters_gru, state=state, mode=mode_gru,
532                         state_size=state_size, num_layers=num_layers)
533
534        for out in [out_relu, out_tanh, out_lstm, out_gru]:
535            assert out.shape[0] == RNN_LARGE_TENSOR
536            assert out.shape[1] == 4
537            assert out.shape[2] == 1
538
539            assert type(out[0, 0, 0].asscalar()).__name__ == 'float32'
540
541    check_gluon_embedding()
542    check_fully_connected()
543    check_dense()
544    check_softmax()
545    check_softmax_cross_entropy()
546    check_softmax_output()
547    check_softmax_activation()
548    check_log_softmax()
549    check_leaky_relu()
550    check_pooling()
551    check_layer_norm()
552    check_dropout()
553    check_activation()
554    check_batchnorm()
555    check_relu()
556    check_sigmoid()
557    check_linear_and_logistic_regression()
558    check_l2_normalization()
559    check_instance_norm()
560    check_col2im()
561    check_embedding()
562    check_spatial_transformer()
563    check_ravel()
564    check_rnn()
565
566
567def test_tensor():
568    def check_ndarray_zeros():
569        a = nd.zeros(shape=(LARGE_X, SMALL_Y))
570        assert a[-1][0] == 0
571        assert a.shape == (LARGE_X, SMALL_Y)
572        assert a.size == LARGE_SIZE
573
574    def check_ndarray_ones():
575        a = nd.ones(shape=(LARGE_X, SMALL_Y))
576        assert a[-1][0] == 1
577        assert nd.sum(a).asnumpy() == LARGE_SIZE
578
579    @with_seed()
580    def check_ndarray_random_uniform():
581        a = nd.random.uniform(shape=(LARGE_X, SMALL_Y))
582        assert a[-1][0] != 0
583
584    @unittest.skip("Randint flaky, tracked at "
585                   "https://github.com/apache/incubator-mxnet/issues/16172")
586    @with_seed()
587    def check_ndarray_random_randint():
588        a = nd.random.randint(100, 10000, shape=(LARGE_X, SMALL_Y))
589        assert a.shape == (LARGE_X, SMALL_Y)
590        # check if randint can generate value greater than 2**32 (large)
591        low_large_value = 2**32
592        high_large_value = 2**34
593        a = nd.random.randint(low_large_value, high_large_value, dtype=np.int64)
594        low = mx.nd.array([low_large_value], dtype='int64')
595        high = mx.nd.array([high_large_value], dtype='int64')
596        assert a >= low and a < high
597        assert a[-1][0].dtype == np.int64
598
599    @with_seed()
600    def check_ndarray_random_exponential():
601        scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
602        a = nd.random.exponential(scale=scale_array, shape=(SMALL_X, SMALL_Y))
603        assert a[-1][0][0][0] >= 0
604        assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)
605
606    @with_seed()
607    def check_ndarray_random_gamma():
608        alpha_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
609        beta_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
610        a = nd.random.gamma(alpha=alpha_array, beta=beta_array,
611                            shape=(SMALL_X, SMALL_Y))
612        assert a[-1][0][0][0] >= 0
613        assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)
614
615    @with_seed()
616    def check_ndarray_random_multinomial():
617        # test 1 shape dimension
618        probs = nd.random.uniform(shape=(LARGE_X, SMALL_Y))
619        a = nd.random.multinomial(probs)
620        assert a[-1] >= 0
621        assert a.shape == (LARGE_X,)
622        # test for NDArray multi-dimension shape
623        a = nd.random.multinomial(probs, shape=(2, SMALL_Y))
624        assert a[-1][0][0] >= 0
625        assert a.shape == (LARGE_X, 2, SMALL_Y)
626        # test log_likelihood output shape
627        a = nd.random.multinomial(probs, shape=(2, SMALL_Y), get_prob=True)
628        assert a[0][0][0][0] >= 0
629        assert a[0].shape == (LARGE_X, 2, SMALL_Y) and a[0].shape == a[1].shape
630
631    @with_seed()
632    def check_ndarray_random_generalized_negative_binomial():
633        alpha_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
634        mu_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
635        a = nd.random.generalized_negative_binomial(mu=mu_array, alpha=alpha_array,
636                                                    shape=(SMALL_X, SMALL_Y))
637        assert a[-1][0][0][0] >= 0
638        assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)
639
640    @with_seed()
641    def check_ndarray_random_negative_binomial():
642        k_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
643        p_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
644        a = nd.random.negative_binomial(k=k_array, p=p_array,
645                                        shape=(SMALL_X, SMALL_Y))
646        assert a[-1][0][0][0] >= 0
647        assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)
648
649    @with_seed()
650    def check_ndarray_random_normal():
651        scale_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
652        loc_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
653        a = nd.random.normal(loc=loc_array, scale=scale_array,
654                             shape=(SMALL_X, SMALL_Y))
655        assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)
656
657    @with_seed()
658    def check_ndarray_random_poisson():
659        lambda_array = nd.random.uniform(shape=(MEDIUM_X, SMALL_X))
660        a = nd.random.poisson(lam=lambda_array, shape=(SMALL_X, SMALL_Y))
661        assert a[-1][0][0][0] >= 0
662        assert a.shape == (MEDIUM_X, SMALL_X, SMALL_X, SMALL_Y)
663
664    @with_seed()
665    def check_ndarray_random_randn():
666        a = nd.random.randn(LARGE_X, SMALL_Y)
667        assert a.shape == (LARGE_X, SMALL_Y)
668        # TODO: Once PR #15772 for randn ndarray dtype for loc,scale param merged
669        # Add check for (x,y,m,n) where x,y shape of loc,scale and m,n input shape
670
671    @with_seed()
672    def check_ndarray_random_shuffle():
673        a = nd.ones(shape=(LARGE_X, SMALL_Y))
674        a[-1] = 3  # assign 3 to entire last row
675        a = nd.random.shuffle(a)
676        # slice first column from shuffled array
677        # pass LARGE_X values to numpy instead of LARGE_X*SMALL_Y
678        # could have assigned to last column (so as to pass SMALL_Y)
679        # but shuffle operation is performed along first axis
680        unique_a = np.unique(a[:, 0].asnumpy())
681        assert len(unique_a) == 2  # only 2 unique values
682        assert unique_a[0] == 1  # first unique value is 1
683        assert unique_a[1] == 3  # second unique value is 3
684        assert a.shape == (LARGE_X, SMALL_Y)
685
686    def check_ndarray_empty():
687        a = nd.empty((LARGE_X, SMALL_Y))
688        assert a.shape == (LARGE_X, SMALL_Y)
689
690    def check_zeros_like():
691        a = nd.array(np.ones((SMALL_Y, LARGE_X)))
692        b = nd.zeros_like(a)
693        assert b[-1][-1] == 0
694        assert b.shape == a.shape
695
696    def check_ones_like():
697        a = nd.array(np.zeros((SMALL_Y, LARGE_X)))
698        b = nd.ones_like(a)
699        assert b[-1][-1] == 1
700        assert b.shape == a.shape
701
702    def check_broadcast():
703        a = nd.ones(shape=(LARGE_X, SMALL_Y))
704        b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
705        res = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y))
706        assert np.sum(res[-1].asnumpy() == LARGE_X) == res.shape[1]
707        res = mx.nd.broadcast_like(b, a)
708        assert np.sum(res[-1].asnumpy() == LARGE_X) == a.shape[1]
709
710    def check_clip():
711        a = nd.arange(0, LARGE_X * SMALL_Y).reshape(LARGE_X, SMALL_Y)
712        res = nd.clip(a, a_min=100, a_max=1000)
713        assert np.sum(res[-1].asnumpy() == 1000) == a.shape[1]
714
715    def check_split():
716        a = nd.arange(0, LARGE_X * SMALL_Y).reshape(LARGE_X, SMALL_Y)
717        outs = nd.split(a, num_outputs=SMALL_Y, axis=1)
718        result = sum(1 for i, v in enumerate(outs) if i == v[0].asnumpy())
719        assert result == a.shape[1]
720
721    def check_tile():
722        a = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
723        b = nd.tile(a, reps=(1, SMALL_Y))
724        assert np.sum(b[-1].asnumpy() == LARGE_X) == b.shape[1]
725
726    def check_take():
727        a = nd.ones(shape=(LARGE_X, SMALL_Y))
728        idx = nd.arange(LARGE_X - 1000, LARGE_X)
729        res = nd.take(a, idx)
730        assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]
731
732    def check_slice():
733        a = nd.ones(shape=(LARGE_X, SMALL_Y))
734        res = nd.slice(a, begin=(LARGE_X-1000, 1), end=(LARGE_X, SMALL_Y))
735        assert np.sum(res[-1].asnumpy() == 1) == res.shape[1]
736
737    def check_slice_assign():
738        a = nd.ones(shape=(LARGE_X, SMALL_Y))
739        a[LARGE_X-1:LARGE_X] = 1000
740        assert np.sum(a[-1].asnumpy() == 1000) == a.shape[1]
741
742    def check_slice_like():
743        a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
744        b = nd.array(np.ones((SMALL_Y//2, LARGE_X//2)))
745        c = nd.slice_like(a, b)
746        d = nd.slice_like(a, b, axes=(0))
747        e = nd.slice_like(a, b, axes=(-1))
748        assert c.shape == b.shape
749        assert d.shape[0] == b.shape[0]
750        assert e.shape[-1] == b.shape[-1]
751        assert c[0][-1] == 0
752        assert d[-1][0] == (SMALL_Y//2-1)
753        assert e[-1][-1] == (SMALL_Y-1)
754
755    def check_slice_axis():
756        a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
757        c = nd.slice_axis(a, axis=0, begin=0, end=SMALL_Y//2)
758        d = nd.slice_axis(a, axis=1, begin=0, end=LARGE_X//2)
759        assert c.shape[0] == a.shape[0]//2
760        assert d.shape[1] == a.shape[1]//2
761        assert c[-1][0] == (SMALL_Y//2-1)
762        assert d[-1][-1] == (SMALL_Y-1)
763
764    def check_expand_dims():
765        a = nd.ones(shape=(LARGE_X, SMALL_Y))
766        res = nd.expand_dims(a, axis=1)
767        res.wait_to_read()
768        assert a[0][0][0] == 1
769        assert res.shape == (a.shape[0], 1, a.shape[1])
770
771    def check_squeeze():
772        a = nd.ones(shape=(LARGE_X, SMALL_Y))
773        data = nd.expand_dims(a, axis=1)
774        res = nd.squeeze(data)
775        assert res.shape == a.shape
776
777    def check_broadcast_div():
778        a = nd.ones(shape=(LARGE_X, SMALL_Y))
779        b = nd.ones(shape=(LARGE_X, 1)) * 2
780        res = a / b
781        assert np.sum(res[-1].asnumpy() == 0.5) == a.shape[1]
782
783    def check_where():
784        a = nd.ones(shape=(LARGE_X, SMALL_Y))
785        b = nd.arange(0, LARGE_X * SMALL_Y).reshape(LARGE_X, SMALL_Y)
786        res = nd.where(b > 100, a, b)
787        assert np.sum(res[-1].asnumpy() == 1) == b.shape[1]
788        csr_cond = nd.sparse.cast_storage(b < 10, 'csr')
789        res = nd.sparse.where(csr_cond, a, b)
790        assert np.sum(res[0].asnumpy() == 1) == 10
791
792    def check_pick():
793        a = mx.nd.ones(shape=(256 * 35, 1024 * 1024))
794        b = mx.nd.ones(shape=(256 * 35, ))
795        res = mx.nd.pick(a, b)
796        assert res.shape == b.shape
797
798    @unittest.skip("Memory doesn't free up after stacked execution with other ops, "
799                   "tracked at https://github.com/apache/incubator-mxnet/issues/17411")
800    def check_depthtospace():
801        def numpy_depth_to_space(x, blocksize):
802            b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
803            tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h,
804                             w])
805            tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2])
806            y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize,
807                           w * blocksize])
808            return y
809
810        shape_inp = (LARGE_X, 8, 4, 2)
811        data = rand_ndarray(shape_inp, 'default')
812        data_np = data.asnumpy()
813        expected = numpy_depth_to_space(data_np, 2)
814        output = mx.nd.depth_to_space(data, 2)
815        assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)
816
817    @unittest.skip("Memory doesn't free up after stacked execution with other ops, "
818                   "tracked at https://github.com/apache/incubator-mxnet/issues/17411")
819    def check_spacetodepth():
820        def numpy_space_to_depth(x, blocksize):
821            b, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
822            tmp = np.reshape(x, [b, c, h // blocksize, blocksize, w // blocksize,
823                             blocksize])
824            tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4])
825            y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize,
826                           w // blocksize])
827            return y
828
829        shape_inp = (LARGE_X, 2, 8, 4)
830        data = rand_ndarray(shape_inp, 'default')
831        data_np = data.asnumpy()
832        expected = numpy_space_to_depth(data_np, 2)
833        output = mx.nd.space_to_depth(data, 2)
834        assert_almost_equal(output.asnumpy(), expected, atol=1e-3, rtol=1e-3)
835
836    @with_seed()
837    def check_diag():
838        a_np = np.random.random((LARGE_X, SMALL_Y)).astype(np.float32)
839        a = mx.nd.array(a_np)
840
841        # k == 0
842        r = mx.nd.diag(a)
843        assert_almost_equal(r.asnumpy(), np.diag(a_np))
844
845        # k == 1
846        k = 1
847        r = mx.nd.diag(a, k=k)
848        assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
849
850        # k == -1
851        k = -1
852        r = mx.nd.diag(a, k=k)
853        assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
854
855        # random k
856        k = np.random.randint(-min(LARGE_X, SMALL_Y) + 1, min(LARGE_X, SMALL_Y))
857        r = mx.nd.diag(a, k=k)
858        assert_almost_equal(r.asnumpy(), np.diag(a_np, k=k))
859
860    @with_seed()
861    def check_ravel_multi_index():
862        x1, y1 = rand_coord_2d((LARGE_X - 100), LARGE_X, 10, SMALL_Y)
863        x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, 9, SMALL_Y)
864        x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
865        indices_2d = [[x1, x2, x3], [y1, y2, y3]]
866        idx = mx.nd.ravel_multi_index(mx.nd.array(indices_2d, dtype=np.int64),
867                                      shape=(LARGE_X, SMALL_Y))
868        idx_numpy = np.ravel_multi_index(indices_2d, (LARGE_X, SMALL_Y))
869        assert np.sum(1 for i in range(idx.size) if idx[i] == idx_numpy[i]) == 3
870
871    @with_seed()
872    def check_unravel_index():
873        x1, y1 = rand_coord_2d((LARGE_X - 100), LARGE_X, 10, SMALL_Y)
874        x2, y2 = rand_coord_2d((LARGE_X - 200), LARGE_X, 9, SMALL_Y)
875        x3, y3 = rand_coord_2d((LARGE_X - 300), LARGE_X, 8, SMALL_Y)
876        original_2d_indices = [[x1, x2, x3], [y1, y2, y3]]
877        idx_numpy = np.ravel_multi_index(original_2d_indices, (LARGE_X, SMALL_Y))
878        indices_2d = mx.nd.unravel_index(mx.nd.array(idx_numpy, dtype=np.int64),
879                                         shape=(LARGE_X, SMALL_Y))
880        assert (indices_2d.asnumpy() == np.array(original_2d_indices)).all()
881
882    @unittest.skip("Memory doesn't free up after stacked execution with other ops, " +
883                   "tracked at https://github.com/apache/incubator-mxnet/issues/17411")
884    def check_transpose():
885        check_dtypes = [np.float32, np.int64]
886        for dtype in check_dtypes:
887            b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y, dtype=dtype)
888            t = b.T
889            assert t.shape == (SMALL_Y, LARGE_X)
890            ref_out = np.transpose(b.asnumpy())
891            assert_almost_equal(t.asnumpy(), ref_out, rtol=1e-10)
892
893    @unittest.skip("Memory doesn't free up after stacked execution with other ops, " +
894                   "tracked at https://github.com/apache/incubator-mxnet/issues/17411")
895    def check_swapaxes():
896        b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
897        t = nd.swapaxes(b, dim1=0, dim2=1)
898        assert np.sum(t[:, -1].asnumpy() == (LARGE_X - 1)) == b.shape[1]
899        assert t.shape == (SMALL_Y, LARGE_X)
900
901    @unittest.skip("Memory doesn't free up after stacked execution with other ops, " +
902                   "tracked at https://github.com/apache/incubator-mxnet/issues/17411")
903    def check_flip():
904        b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
905        t = nd.flip(b, axis=0)
906        assert np.sum(t[-1, :].asnumpy() == 0) == b.shape[1]
907        assert t.shape == (LARGE_X, SMALL_Y)
908
909    def check_sequence_mask():
910        # Sequence Mask input [max_sequence_length, batch_size, other_feature_dims]
911        # test with input batch_size = 2
912        a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
913        # test as identity operator
914        b = nd.SequenceMask(a)
915        assert b[-1][0][1] == a[-1][0][1]
916        assert b.shape == a.shape
917        # test with default mask
918        b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
919                            use_sequence_length=True)
920        assert b[0][1][-1] == a[0][1][-1]  # first sequence of each batch kept
921        assert b[-1][-1][-1] != a[-1][-1][-1]  # rest sequences masked
922        assert b[-1][-1][-1] == 0
923
924        # test with mask value
925        b = nd.SequenceMask(a, sequence_length=nd.array([1, 1]),
926                            use_sequence_length=True, value=-1)
927        assert b[-1][-1][-1] == -1
928
929    def check_sequence_reverse():
930        a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
931        # test as reverse operator
932        b = nd.SequenceReverse(a)
933        assert b[-1][0][0] == a[0][0][0]
934        assert b.shape == a.shape
935        # test with sequence length
936        # 2 rows of batch 1 and 3 rows of batch 2 reversed
937        b = nd.SequenceReverse(a, sequence_length=nd.array([2, 3]),
938                               use_sequence_length=True)
939        assert b[1][0][0] == a[0][0][0]  # check if reversed
940        assert b[-1][0][0] == a[-1][0][0]  # check if intact
941        assert b.shape == a.shape
942
943    def check_sequence_last():
944        a = nd.arange(0, LARGE_X * SMALL_Y * 2).reshape(LARGE_X, 2, SMALL_Y)
945        # test if returns last sequence
946        b = nd.SequenceLast(a)
947        assert_almost_equal(b.asnumpy(), a[-1].asnumpy())  # only checks for (2, SMALL_Y) tensor
948        assert b.shape == (2, SMALL_Y)
949        # test with sequence length
950        # parameter sequence_length - NDArray with shape (batch_size)
951        # (2,3) indicates 2nd sequence from batch 1 and 3rd sequence from batch 2
952        b = nd.SequenceLast(a, sequence_length=mx.nd.array([2, 3]),
953                            use_sequence_length=True)
954        # check if it takes 2nd sequence from the first batch
955        assert b[0][-1] == a[1][0][-1]
956
957    def check_index_copy():
958        x = mx.nd.zeros((LARGE_X, SMALL_Y))
959        t = mx.nd.arange(1, SMALL_Y + 1).reshape((1, SMALL_Y))
960        index = mx.nd.array([LARGE_X - 1], dtype="int64")
961
962        x = mx.nd.contrib.index_copy(x, index, t)
963        assert x[-1][-1] == t[0][-1]
964
965    def check_one_hot():
966        # default dtype of ndarray is float32 which cannot index elements over 2^32
967        a = nd.array([1, (VLARGE_X - 1)], dtype=np.int64)
968        b = nd.one_hot(a, VLARGE_X)
969        b[0][1] == 1
970        b[1][-1] == 1
971
972    def check_full():
973        a = nd.full((SMALL_Y, LARGE_X), 3)
974        assert a.shape == (SMALL_Y, LARGE_X)
975        assert a[SMALL_Y//2][LARGE_X//2] == 3
976        assert a[-1][-1] == 3
977
978    def check_shape():
979        b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
980        mx.nd.waitall()
981        assert b.shape == (SMALL_Y, LARGE_X)
982
983    def check_size():
984        b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
985        mx.nd.waitall()
986        assert b.size == LARGE_SIZE
987
988    def check_copy():
989        a = nd.ones((SMALL_Y, LARGE_X))
990        b = a.copy()
991        nd.waitall()
992        assert b.shape == a.shape
993        assert b.size == LARGE_SIZE
994
995    def check_copy_to():
996        a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
997        b = nd.array(np.zeros((SMALL_Y, LARGE_X)))
998        c = a.copyto(b)
999        assert c is b
1000        assert b[-1][-1] == SMALL_Y-1
1001
1002    def check_reshape_like():
1003        a = nd.array(np.zeros((SMALL_Y, LARGE_X)))
1004        b = nd.array(np.zeros((SMALL_Y//2, LARGE_X*2)))
1005        c = nd.reshape_like(a, b)
1006        assert c.shape == (SMALL_Y//2, LARGE_X*2)
1007
1008    def check_flatten():
1009        check_dtypes = [np.float32, np.int64]
1010        for dtype in check_dtypes:
1011            a = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y, dtype=dtype).reshape((LARGE_X//2, 2, SMALL_Y))
1012            b = nd.flatten(a)
1013            # Here we removed the value asserts due to different precision of `int64` and `float32`.
1014            # For `float32`, it will lose some precision when `LARGE_X` is too large, that is `LARGE_X-1`
1015            # and `LARGE_X-2` can not represent the accurate value in the current situation.
1016            assert b.shape == (LARGE_X//2, SMALL_Y*2)
1017            assert_almost_equal(b[-1,-1].asnumpy(), a[-1,-1,-1].asnumpy(), rtol=1e-8)
1018
1019    def check_concat():
1020        a = nd.array(np.ones((SMALL_Y, LARGE_X)))
1021        b = nd.array(np.zeros((SMALL_Y, LARGE_X)))
1022        for axis in [0, 1]:
1023            c = nd.concat(a, b, dim=axis)
1024            c.wait_to_read()
1025            assert c.shape[axis] == b.shape[axis] * 2
1026            assert c.shape[1-axis] == b.shape[1-axis]
1027
1028    def check_stack():
1029        a = nd.array(np.ones((SMALL_Y, LARGE_X)))
1030        b = nd.array(np.zeros((SMALL_Y, LARGE_X)))
1031        c = nd.stack(a, b, axis=1)
1032        assert c.shape == (b.shape[0], 2, LARGE_X)
1033
1034    def check_broadcast_axes():
1035        a = create_2d_tensor(rows=1, columns=LARGE_X)
1036        b = nd.broadcast_axis(a, axis=[0], size=2)
1037        assert b.shape == (a.shape[0]*2, a.shape[1])
1038
1039    def check_astype():
1040        x = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
1041        y = x.astype('int32')
1042        assert y.dtype == np.int32
1043        assert y[-1][-1] == SMALL_Y-1
1044
1045    def check_cast():
1046        x = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
1047        y = nd.cast(x, np.int32)
1048        assert y.dtype == np.int32
1049        assert y[-1][-1] == SMALL_Y-1
1050
1051    def check_repeat():
1052        x = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X//2)
1053        y = nd.repeat(x, repeats=2, axis = 1)
1054        assert y.shape == (SMALL_Y, LARGE_X)
1055        assert y[0][1] == 0
1056        assert y[-1][-1] == SMALL_Y-1
1057        x = create_2d_tensor(rows=SMALL_Y//2, columns=LARGE_X)
1058        y = nd.repeat(x, repeats=2, axis = 0)
1059        assert y.shape == (SMALL_Y, LARGE_X)
1060        assert y[0][1] == 0
1061        assert y[-1][0] == SMALL_Y//2-1
1062
1063    def check_ndarray_convert():
1064        a = nd.zeros(shape=(LARGE_X, SMALL_Y))
1065        b = a.astype(np.int32)
1066        assert b.dtype == np.int32
1067        b = a.tostype('row_sparse')
1068        assert isinstance(b, mx.nd.sparse.RowSparseNDArray)
1069
1070    def check_load_save():
1071        x = create_2d_tensor(SMALL_Y, LARGE_X)
1072        tmp = tempfile.mkdtemp()
1073        tmpfile = os.path.join(tmp, 'large_tensor')
1074        nd.save(tmpfile, [x])
1075        y = nd.load(tmpfile)
1076        y = y[0]
1077        assert x[0][0] == y[0][0]
1078        assert x[-1][-1]== y[-1][-1]
1079
1080    def check_pad():
1081        x = create_2d_tensor(rows=SMALL_Y-2, columns=LARGE_X//2-2, dtype=np.float32).reshape(1 , 1, SMALL_Y-2, LARGE_X//2-2)
1082        y = nd.pad(x, mode="edge", pad_width=(0, 0, 0, 0, 1, 1, 1, 1))
1083        assert y[0][0][1][0] == 0
1084        assert y[0][0][1][-1] == 0
1085        assert y[0][0][-1][0] == SMALL_Y-3
1086        assert y[0][0][-1][-1] == SMALL_Y-3
1087        assert y.shape == (1, 1, SMALL_Y, LARGE_X//2)
1088
1089    def check_gather():
1090        arr = mx.nd.ones((LARGE_X, SMALL_Y))
1091        idx = mx.nd.random.randint(0, LARGE_X, SMALL_X)
1092        # Calls gather_nd internally
1093        tmp = arr[idx]
1094        assert np.sum(tmp[0].asnumpy() == 1) == SMALL_Y
1095        # Calls gather_nd internally
1096        arr[idx] += 1
1097        assert np.sum(arr[idx[0]].asnumpy() == 2) == SMALL_Y
1098
1099    def check_binary_broadcast():
1100        def check_correctness(mxnet_op, numpy_op, atol=1e-3):
1101            a = mx.nd.ones((LARGE_X, SMALL_Y)).as_np_ndarray()
1102            b = 2*mx.nd.ones((LARGE_X, SMALL_Y)).as_np_ndarray()
1103            res = mxnet_op(a, b)
1104            np_res = numpy_op(1, 2)
1105            assert np.abs(res[-1][-1] - np_res) < atol
1106        check_correctness(mx.np.arctan2, np.arctan2)
1107        check_correctness(mx.np.hypot, np.hypot)
1108
1109    check_ndarray_zeros()
1110    check_ndarray_ones()
1111    check_ndarray_random_uniform()
1112    check_ndarray_random_randint()
1113    check_ndarray_random_exponential()
1114    check_ndarray_random_gamma()
1115    check_ndarray_random_multinomial()
1116    check_ndarray_random_generalized_negative_binomial()
1117    check_ndarray_random_negative_binomial()
1118    check_ndarray_random_normal()
1119    check_ndarray_random_poisson()
1120    check_ndarray_random_randn()
1121    check_ndarray_random_shuffle()
1122    check_ndarray_empty()
1123    check_zeros_like()
1124    check_ones_like()
1125    check_broadcast()
1126    check_clip()
1127    check_split()
1128    check_tile()
1129    check_take()
1130    check_slice()
1131    check_slice_assign()
1132    check_slice_like()
1133    check_slice_axis()
1134    check_expand_dims()
1135    check_squeeze()
1136    check_broadcast_div()
1137    check_where()
1138    check_pick()
1139    check_depthtospace()
1140    check_spacetodepth()
1141    check_diag()
1142    check_ravel_multi_index()
1143    check_unravel_index()
1144    check_transpose()
1145    check_swapaxes()
1146    check_flip()
1147    check_sequence_mask()
1148    check_sequence_reverse()
1149    check_sequence_last()
1150    check_index_copy()
1151    check_one_hot()
1152    check_full()
1153    check_shape()
1154    check_size()
1155    check_copy()
1156    check_copy_to()
1157    check_reshape_like()
1158    check_flatten()
1159    check_concat()
1160    check_stack()
1161    check_broadcast_axes()
1162    check_astype()
1163    check_cast()
1164    check_repeat()
1165    check_ndarray_convert()
1166    check_load_save()
1167    check_pad()
1168    check_gather()
1169    check_binary_broadcast()
1170
1171def test_linalg():
1172    def check_potrf():
1173        def run_potrf(inp):
1174            inp.attach_grad()
1175            with mx.autograd.record():
1176                out = mx.nd.linalg.potrf(inp)
1177            return inp.grad, out
1178
1179        A = get_identity_mat(LARGE_SQ_X)
1180        grad, out = run_potrf(A)
1181        assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
1182        assert(out[0, 0] == 1)
1183        out.backward()
1184        assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
1185        assert(grad[0, 0] == 0.5)
1186
1187    def check_potri():
1188        def run_potri(inp):
1189            inp.attach_grad()
1190            with mx.autograd.record():
1191                out = mx.nd.linalg.potri(inp)
1192            return inp.grad, out
1193
1194        A = get_identity_mat(LARGE_SQ_X)
1195        grad, out = run_potri(A)
1196        assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
1197        assert(out[0, 0] == 1)
1198        out.backward()
1199        assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
1200        assert(grad[0, 0] == -2)
1201
1202    def check_syrk_batch():
1203        # test both forward and backward
1204        # batch syrk will be applied to the last two dimensions
1205        A = nd.zeros((2, LARGE_SQ_X, LARGE_SQ_X))
1206        for i in range(LARGE_SQ_X):
1207            A[0,i,i] = 1
1208            A[1,i,i] = 0.1
1209        A.attach_grad()
1210        with mx.autograd.record():
1211            out = nd.linalg.syrk(A, alpha=2, transpose=False)
1212        assert out.shape == (2, LARGE_SQ_X, LARGE_SQ_X)
1213        assert out[0,0,0] == 2
1214        assert_almost_equal(out[1,0,0], nd.array([0.02]), rtol=1e-3, atol=1e-5)
1215        out.backward()
1216        assert A.grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X)
1217        assert A.grad[0,0,0] == 4
1218        assert_almost_equal(A.grad[1,0,0], nd.array([0.4]), rtol=1e-3, atol=1e-5)
1219
1220    def check_gemm2():
1221        def run_gemm2(inp1, inp2):
1222            inp1.attach_grad()
1223            inp2.attach_grad()
1224            with mx.autograd.record():
1225                out = mx.nd.linalg.gemm2(inp1, inp2)
1226            return inp1.grad, inp2.grad, out
1227
1228        inp1 = mx.nd.ones(shape=(SMALL_Y, LARGE_X))
1229        perturbation = 0.2
1230        inp1[0][0] = perturbation
1231        inp2 = mx.nd.ones(shape=(LARGE_X, SMALL_Y))
1232        inp1_grad, inp2_grad, out = run_gemm2(inp1, inp2)
1233        assert out.asnumpy()[0][0] == LARGE_X
1234        assert out.shape == (SMALL_Y, SMALL_Y)
1235        out.backward()
1236        assert inp1_grad.shape == (SMALL_Y, LARGE_X)
1237        assert inp2_grad.shape == (LARGE_X, SMALL_Y)
1238        assert_almost_equal(inp1_grad.asnumpy()[0][0], SMALL_Y)
1239        assert_almost_equal(inp2_grad.asnumpy()[0][0], SMALL_Y - (1 - perturbation))
1240
1241    def check_gemm():
1242        def run_gemm(inp1,inp2, inp3):
1243            inp1.attach_grad()
1244            inp2.attach_grad()
1245            inp3.attach_grad()
1246            with mx.autograd.record():
1247                out = mx.nd.linalg.gemm(inp1, inp2, inp3, transpose_b=True)
1248            return inp1.grad, inp2.grad, inp3.grad, out
1249
1250        inp1 = mx.nd.ones(shape=(MEDIUM_X, SMALL_Y, MEDIUM_X))
1251        perturbation = 0.2
1252        inp1[0][0][0] = perturbation
1253        inp2 = mx.nd.ones(shape=(MEDIUM_X, SMALL_Y, MEDIUM_X))
1254        inp3 = mx.nd.ones(shape=(MEDIUM_X, SMALL_Y, SMALL_Y))
1255        inp1_grad, inp2_grad, inp3_grad, out= run_gemm(inp1, inp2, inp3)
1256        assert_almost_equal(out.asnumpy()[0][0][0], MEDIUM_X + perturbation)
1257        assert out.shape == inp3.shape
1258        out.backward()
1259        assert inp1_grad.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X)
1260        assert inp2_grad.shape == (MEDIUM_X, SMALL_Y, MEDIUM_X)
1261        assert inp3_grad.shape == (MEDIUM_X, SMALL_Y, SMALL_Y)
1262        assert_almost_equal(inp1_grad.asnumpy()[0][0][0], SMALL_Y)
1263        assert_almost_equal(inp2_grad.asnumpy()[0][0][0], SMALL_Y - (1 - perturbation))
1264
1265    def check_det():
1266        def run_det(inp):
1267            inp.attach_grad()
1268            with mx.autograd.record():
1269                out = mx.nd.linalg.det(inp)
1270            return inp.grad, out
1271
1272        A = get_identity_mat(LARGE_SQ_X)
1273        grad, out = run_det(A)
1274        assert(out.shape == (1,))
1275        assert(out[0] == 1)
1276        out.backward()
1277        assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
1278        assert(grad[0, 0] == 1)
1279
1280    def check_inverse():
1281        def run_inverse(inp):
1282            inp.attach_grad()
1283            with mx.autograd.record():
1284                out = mx.nd.linalg.inverse(inp)
1285            return inp.grad, out
1286
1287        A = get_identity_mat(LARGE_SQ_X)
1288        grad, out = run_inverse(A)
1289        assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
1290        assert(out[0, 0] == 1)
1291        out.backward()
1292        assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
1293        assert(grad[0, 0] == -1)
1294
1295    def check_trmm():
1296        def run_trmm(inp):
1297            inp.attach_grad()
1298            with mx.autograd.record():
1299                out = mx.nd.linalg.trmm(inp, inp)
1300            return inp.grad, out
1301
1302        A = get_identity_mat(LARGE_SQ_X)
1303        grad, out = run_trmm(A)
1304        assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
1305        assert(out[0, 0] == 1)
1306        out.backward()
1307        assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
1308        assert(grad[0, 0] == 2)
1309
1310    def check_trsm():
1311        def run_trsm(inp):
1312            inp.attach_grad()
1313            with mx.autograd.record():
1314                out = mx.nd.linalg.trsm(inp, inp)
1315            return inp.grad, out
1316
1317        A = get_identity_mat(LARGE_SQ_X)
1318        grad, out = run_trsm(A)
1319        assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X))
1320        assert(out[0, 0] == 1)
1321        out.backward()
1322        assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X))
1323        assert(grad[0, 0] == 0)
1324
1325    def check_batch_inverse():
1326        def run_inverse(inp):
1327            inp.attach_grad()
1328            with mx.autograd.record():
1329                out = mx.nd.linalg.inverse(inp)
1330            return inp.grad, out
1331
1332        B = get_identity_mat_batch(LARGE_SQ_X)
1333        grad, out = run_inverse(B)
1334        assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
1335        assert(out[0, 0, 0] == 1)
1336        assert(out[1, 0, 0] == 1)
1337        out.backward()
1338        assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
1339        assert(grad[0, 0, 0] == -1)
1340        assert(grad[1, 0, 0] == -1)
1341
1342    def check_batch_trmm():
1343        def run_trmm(inp):
1344            inp.attach_grad()
1345            with mx.autograd.record():
1346                out = mx.nd.linalg.trmm(inp, inp)
1347            return inp.grad, out
1348
1349        B = get_identity_mat_batch(LARGE_SQ_X)
1350        grad, out = run_trmm(B)
1351        assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
1352        assert(out[0, 0, 0] == 1)
1353        assert(out[1, 0, 0] == 1)
1354        out.backward()
1355        assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
1356        assert(grad[0, 0, 0] == 2)
1357        assert(grad[1, 0, 0] == 2)
1358
1359    def check_batch_trsm():
1360        def run_trsm(inp):
1361            inp.attach_grad()
1362            with mx.autograd.record():
1363                out = mx.nd.linalg.trsm(inp, inp)
1364            return inp.grad, out
1365
1366        B = get_identity_mat_batch(LARGE_SQ_X)
1367        grad, out = run_trsm(B)
1368        assert(out.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
1369        assert(out[0, 0, 0] == 1)
1370        assert(out[1, 0, 0] == 1)
1371        out.backward()
1372        assert(grad.shape == (2, LARGE_SQ_X, LARGE_SQ_X))
1373        assert(grad[0, 0, 0] == 0)
1374        assert(grad[1, 0, 0] == 0)
1375
1376    check_gemm()
1377    check_potrf()
1378    check_potri()
1379    check_syrk_batch()
1380    check_gemm2()
1381    check_det()
1382    check_inverse()
1383    check_trmm()
1384    check_trsm()
1385    check_batch_inverse()
1386    check_batch_trmm()
1387    check_batch_trsm()
1388
1389
1390def test_linalg_errors():
1391    def check_syevd_error():
1392        A = get_identity_mat(LARGE_SQ_X)
1393        for i in range(LARGE_SQ_X):
1394            A[i,i] = 1
1395        assertRaises(MXNetError, mx.nd.linalg.syevd, A)
1396
1397    check_syevd_error()
1398
1399
1400def test_basic():
1401    def check_elementwise():
1402        a = nd.ones(shape=(LARGE_X, SMALL_Y))
1403        b = nd.ones(shape=(LARGE_X, SMALL_Y))
1404        res = a + b
1405        assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
1406        res = a + 1
1407        assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
1408        res = nd.sqrt(a + 3)
1409        assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
1410
1411    def check_reduce():
1412        a = nd.ones(shape=(LARGE_X, SMALL_Y))
1413        assert nd.sum(a).asnumpy() == a.shape[0] * a.shape[1]
1414
1415    def check_dot():
1416        a = nd.ones(shape=(LARGE_X, SMALL_Y))
1417        b = nd.ones(shape=(SMALL_Y, SMALL_Y))
1418        res = nd.dot(a, b)
1419        assert np.sum(res[-1].asnumpy() == SMALL_Y) == b.shape[1]
1420
1421    def check_argmin():
1422        a = nd.arange(0, LARGE_X * SMALL_Y).reshape(LARGE_X, SMALL_Y)
1423        idx = mx.nd.argmin(a, axis=0)
1424        assert idx.shape[0] == SMALL_Y
1425
1426    @unittest.skip("Memory doesn't free up after stacked execution with other ops, " +
1427                   "tracked at https://github.com/apache/incubator-mxnet/issues/17411")
1428    def check_argsort():
1429        b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
1430        s = nd.argsort(b, axis=0, is_ascend=False, dtype=np.int64)
1431        mx.nd.waitall()
1432        assert (s[0].asnumpy() == (LARGE_X - 1)).all()
1433
1434    @unittest.skip("Memory doesn't free up after stacked execution with other ops, " +
1435                   "tracked at https://github.com/apache/incubator-mxnet/issues/17411")
1436    def check_sort():
1437        b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
1438        s = nd.sort(b, axis=0, is_ascend=False)
1439        assert np.sum(s[-1][SMALL_Y//2:SMALL_Y].asnumpy() == 0).all()
1440        s = nd.sort(b, is_ascend=False)
1441        assert np.sum(s[0].asnumpy() == 0).all()
1442
1443    @unittest.skip("Memory doesn't free up after stacked execution with other ops, " +
1444                   "tracked at https://github.com/apache/incubator-mxnet/issues/17411")
1445    def check_topk():
1446        b = create_2d_tensor(rows=LARGE_X, columns=SMALL_Y)
1447        k = nd.topk(b, k=10, axis=0, dtype=np.int64)
1448        assert np.sum(k.asnumpy() == (LARGE_X - 1)) == SMALL_Y
1449        ind, val = mx.nd.topk(b, k=3, axis=0, dtype=np.int64, ret_typ="both",
1450                              is_ascend=False)
1451        assert np.all(ind == val)
1452        b = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
1453        l = nd.topk(b, k=1, axis=-1, dtype=np.int64, ret_typ="value")
1454        assert l.sum() == np.sum(np.arange(0, SMALL_Y))
1455
1456    def check_exponent_logarithm_operators():
1457        a = 2*nd.ones(shape=(LARGE_X, SMALL_Y))
1458        # exponent
1459        result = nd.exp(a)
1460        assert result[0][-1] == 7.389056
1461        assert result.shape == a.shape
1462        # exponent minus 1
1463        result = nd.expm1(a)
1464        assert result[0][-1] == 6.389056
1465        assert result.shape == a.shape
1466        # log2
1467        result = nd.log2(a)
1468        assert result[0][-1] == 1
1469        assert result.shape == a.shape
1470        # log10
1471        result = nd.log10(a)
1472        assert result[0][-1] == 0.30103
1473        assert result.shape == a.shape
1474        # log1p
1475        result = nd.log1p(a)
1476        assert result[0][-1] == 1.0986123
1477        assert result.shape == a.shape
1478        # log
1479        result = nd.log(a)
1480        assert result[0][-1] == 0.6931472
1481        assert result.shape == a.shape
1482
1483    def check_power_operators():
1484        a = 2*nd.ones(shape=(LARGE_X, SMALL_Y))
1485        # sqrt
1486        result = nd.sqrt(a)
1487        assert result[0][-1] == 1.4142135
1488        assert result.shape == a.shape
1489        # rsqrt
1490        result = nd.rsqrt(a)
1491        assert result[0][-1] == 0.70710677
1492        assert result.shape == a.shape
1493        # cbrt
1494        result = nd.cbrt(a)
1495        assert result[0][-1] == 1.2599211
1496        assert result.shape == a.shape
1497        # rcbrt
1498        result = nd.rcbrt(a)
1499        assert result[0][-1] == 0.7937005
1500        assert result.shape == a.shape
1501        # square
1502        result = nd.square(a)
1503        assert result[0][-1] == 4
1504        assert result.shape == a.shape
1505        # reciprocal
1506        result = nd.reciprocal(a)
1507        assert result[0][-1] == 0.5
1508        assert result.shape == a.shape
1509
1510    def check_elemwise_add():
1511        a = nd.ones(shape=(LARGE_X, SMALL_Y))
1512        b = nd.ones(shape=(LARGE_X, SMALL_Y))
1513        res = nd.elemwise_add(a, b)
1514        assert np.sum(res[-1].asnumpy() == 2) == a.shape[1]
1515
1516    def check_add():
1517        a = nd.ones(shape=(LARGE_X, SMALL_Y))
1518        b = nd.ones(shape=(LARGE_X, SMALL_Y))
1519        c = b.__add__(a)
1520        assert c[0][-1] == 2
1521        assert c.shape == a.shape
1522
1523    def check_sub():
1524        a = 3*nd.ones(shape=(LARGE_X, SMALL_Y))
1525        b = nd.ones(shape=(LARGE_X, SMALL_Y))
1526        c = b.__sub__(a)
1527        assert c[0][-1] == -2
1528        assert c.shape == a.shape
1529
1530    def check_rsub():
1531        a = 3*nd.ones(shape=(LARGE_X, SMALL_Y))
1532        b = nd.ones(shape=(LARGE_X, SMALL_Y))
1533        c = b.__rsub__(a)
1534        assert c[0][-1] == 2
1535        assert c.shape == a.shape
1536
1537    def check_neg():
1538        a = nd.ones(shape=(LARGE_X, SMALL_Y))
1539        c = a.__neg__()
1540        assert c[0][-1] == -1
1541        assert c.shape == a.shape
1542
1543    def check_mul():
1544        a = 2*nd.ones(shape=(LARGE_X, SMALL_Y))
1545        b = 3*nd.ones(shape=(LARGE_X, SMALL_Y))
1546        c = b.__mul__(a)
1547        assert c[0][-1] == 6
1548        assert c.shape == a.shape
1549
1550    def check_div():
1551        a = 2*nd.ones(shape=(LARGE_X, SMALL_Y))
1552        b = 3*nd.ones(shape=(LARGE_X, SMALL_Y))
1553        c = b.__div__(a)
1554        mx_divide = nd.divide(b, a)
1555        assert c[0][-1] == 3/2
1556        assert mx_divide[0][-1] == c[0][-1]
1557        assert c.shape == a.shape
1558
1559    def check_rdiv():
1560        a = 2*nd.ones(shape=(LARGE_X, SMALL_Y))
1561        b = 3*nd.ones(shape=(LARGE_X, SMALL_Y))
1562        c = b.__rdiv__(a)
1563        assert c[0][-1] == 2/3
1564        assert c.shape == a.shape
1565
1566    def check_mod():
1567        a = 2*nd.ones(shape=(LARGE_X, SMALL_Y))
1568        b = 3*nd.ones(shape=(LARGE_X, SMALL_Y))
1569        c = b.__mod__(a)
1570        assert c[0][-1] == 1
1571        assert c.shape == a.shape
1572
1573    def check_rmod():
1574        a = 2*nd.ones(shape=(LARGE_X, SMALL_Y))
1575        b = 3*nd.ones(shape=(LARGE_X, SMALL_Y))
1576        c = b.__rmod__(a)
1577        assert c[0][-1] == 2
1578        assert c.shape == a.shape
1579
1580    def check_imod():
1581        a = 2*nd.ones(shape=(LARGE_X, SMALL_Y))
1582        b = 3*nd.ones(shape=(LARGE_X, SMALL_Y))
1583        c = b.__imod__(a)
1584        assert c[0][-1] == 1
1585        assert c.shape == a.shape
1586
1587    def check_pow():
1588        a = 2*nd.ones(shape=(LARGE_X, SMALL_Y))
1589        b = 3*nd.ones(shape=(LARGE_X, SMALL_Y))
1590        c = b.__pow__(a)
1591        assert c[0][-1] == 9
1592        assert c.shape == a.shape
1593
1594    def check_rpow():
1595        a = 2*nd.ones(shape=(LARGE_X, SMALL_Y))
1596        b = 3*nd.ones(shape=(LARGE_X, SMALL_Y))
1597        c = b.__rpow__(a)
1598        assert c[0][-1] == 8
1599        assert c.shape == a.shape
1600
1601    def check_sum():
1602        a = nd.array(np.ones((SMALL_Y, LARGE_X)))
1603        b = nd.sum(a, axis=1)
1604        assert b.shape[0] == SMALL_Y
1605
1606    def check_prod():
1607        a = nd.array(np.ones((SMALL_Y, LARGE_X)))
1608        b = nd.prod(a, axis=1)
1609        assert b.shape[0] == SMALL_Y
1610
1611    def check_mean():
1612        a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
1613        b = nd.mean(a, axis=0)
1614        assert b[0] == (SMALL_Y/2-1)
1615
1616    def check_min():
1617        a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
1618        b = nd.min(a, axis=0)
1619        assert b[0] == 0
1620        assert b[-1] == 0
1621
1622    def check_max():
1623        a = create_2d_tensor(rows=SMALL_Y, columns=LARGE_X)
1624        b = nd.max(a, axis=0)
1625        assert b[0] == (SMALL_Y-1)
1626        assert b[-1] == (SMALL_Y-1)
1627
1628    def check_norm():
1629        a = np.array(np.full((1, LARGE_X), 3))
1630        b = np.array(np.full((1, LARGE_X), 4))
1631        c = nd.array(np.concatenate((a, b), axis=0))
1632        d = nd.norm(c, ord=2, axis=0)
1633        e = nd.norm(c, ord=1, axis=0)
1634        assert d.shape[0] == LARGE_X
1635        assert e.shape[0] == LARGE_X
1636        assert d[-1] == 5
1637        assert e[-1] == 7
1638
1639    def check_argmax():
1640        a = np.ones((SMALL_Y, LARGE_X))
1641        b = np.zeros((SMALL_Y, LARGE_X))
1642        c = nd.array(np.concatenate((a, b), axis=0))
1643        d = nd.argmax(c, axis=0)
1644        assert d.shape[0] == LARGE_X
1645        assert d[-1] == d[0] == 0
1646
1647    def check_iadd():
1648        a = nd.array(np.ones((SMALL_Y, LARGE_X)))
1649        b = nd.array(np.ones((SMALL_Y, LARGE_X)))
1650        c = b + a
1651        assert c.shape == a.shape
1652        assert c[0][-1] == 2
1653
1654    def check_isub():
1655        a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
1656        b = nd.array(np.ones((SMALL_Y, LARGE_X)))
1657        c = a - b
1658        assert c.shape == a.shape
1659        assert c[0][-1] == 2
1660
1661    def check_imul():
1662        a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
1663        b = nd.array(np.ones((SMALL_Y, LARGE_X)))
1664        c = b * a
1665        assert c.shape == a.shape
1666        assert c[0][-1] == 3
1667
1668    def check_idiv():
1669        a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 4)))
1670        b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2)))
1671        c = a / b
1672        assert c.shape == a.shape
1673        assert c[0][-1] == 2
1674
1675    def check_eq():
1676        a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
1677        b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
1678        c = (a == b)
1679        assert np.sum(c[0].asnumpy() == 1).all()
1680
1681    def check_neq():
1682        a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2)))
1683        b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
1684        c = (a != b)
1685        assert np.sum(c[0].asnumpy() == 1).all()
1686
1687    def check_lt():
1688        a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2)))
1689        b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
1690        d = (a <= b)
1691        assert np.sum(d[0].asnumpy() == 1).all()
1692
1693    def check_lte():
1694        a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2)))
1695        b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
1696        c = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2)))
1697        d = (a <= b)
1698        e = (a <= c)
1699        assert np.sum(d[0].asnumpy() == 1).all()
1700        assert np.sum(e[0].asnumpy() == 1).all()
1701
1702    def check_gt():
1703        a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
1704        b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2)))
1705        d = (a >= b)
1706        assert np.sum(d[0].asnumpy() == 1).all()
1707
1708    def check_gte():
1709        a = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
1710        b = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 2)))
1711        c = nd.array(np.array(np.full((SMALL_Y, LARGE_X), 3)))
1712        d = (a >= b)
1713        e = (a >= c)
1714        assert np.sum(d[0].asnumpy() == 1).all()
1715        assert np.sum(e[0].asnumpy() == 1).all()
1716
1717    def check_sign():
1718        a = mx.nd.random.normal(-1,1, shape=(LARGE_X, SMALL_Y))
1719        mx_res = mx.nd.sign(a)
1720        assert_almost_equal(mx_res[-1][-1].asnumpy(), np.sign(a[-1][-1].asnumpy()))
1721
1722    def check_logical():
1723        def check_logical_and(a, b):
1724            mx_res = mx.nd.logical_and(a, b)
1725            assert_almost_equal(mx_res[-1][-1].asnumpy(), np.logical_and(a[-1][-1].asnumpy(), b[-1][-1].asnumpy()))
1726
1727        def check_logical_or(a, b):
1728            mx_res = mx.nd.logical_or(a, b)
1729            assert_almost_equal(mx_res[-1][-1].asnumpy(), np.logical_or(a[-1][-1].asnumpy(), b[-1][-1].asnumpy()))
1730
1731        def check_logical_not(a, b):
1732            mx_res = mx.nd.logical_not(a, b)
1733            assert_almost_equal(mx_res[-1][-1].asnumpy(), np.logical_not(a[-1][-1].asnumpy(), b[-1][-1].asnumpy()))
1734
1735        def check_logical_xor(a, b):
1736            mx_res = mx.nd.logical_xor(a, b)
1737            assert_almost_equal(mx_res[-1][-1].asnumpy(), np.logical_xor(a[-1][-1].asnumpy(), b[-1][-1].asnumpy()))
1738
1739        a = mx.nd.ones((LARGE_X, SMALL_Y))
1740        b = mx.nd.zeros((LARGE_X, SMALL_Y))
1741        check_logical_and(a, b)
1742        check_logical_or(a, b)
1743        check_logical_not(a, b)
1744        check_logical_xor(a, b)
1745
1746    def create_input_for_rounding_ops():
1747        # Creates an vector with values (-LARGE_X/2 .... -2, -1, 0, 1, 2, .... , LARGE_X/2-1)
1748        # then divides each element by 2 i.e (-LARGE_X/4 .... -1, -0.5, 0, 0.5, 1, .... , LARGE_X/4-1)
1749        # and finally broadcasts to
1750        inp = nd.arange(-LARGE_X//2, LARGE_X//2, dtype=np.float64).reshape(1, LARGE_X)
1751        inp = inp/2
1752        inp = nd.broadcast_to(inp, (SMALL_Y, LARGE_X))
1753        return inp
1754
1755    def assert_correctness_of_rounding_ops(output, mid, expected_vals):
1756        # checks verifies 5 values at the middle positions of the input vector
1757        # i.e mid-2, mid-1, mid, mid+1, mid+2
1758        output_idx_to_inspect = [mid-2, mid-1, mid, mid+1, mid+2]
1759        for i in range(len(output_idx_to_inspect)):
1760            assert output[1][output_idx_to_inspect[i]] == expected_vals[i]
1761
1762    # TODO(access2rohit): merge similar tests in large vector and array into one file.
1763    def check_rounding_ops():
1764        x = create_input_for_rounding_ops()
1765        def check_ceil():
1766            y = nd.ceil(x)
1767            # expected ouput for middle 5 values after applying ceil()
1768            expected_output = [-1, 0, 0, 1, 1]
1769            assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)
1770        def check_fix():
1771            y = nd.fix(x)
1772            # expected ouput for middle 5 values after applying fix()
1773            expected_output = [-1, 0, 0, 0, 1]
1774            assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)
1775        def check_floor():
1776            y = nd.floor(x)
1777            # expected ouput for middle 5 values after applying floor()
1778            expected_output = [-1, -1, 0, 0, 1]
1779            assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)
1780        def check_rint():
1781            y = nd.rint(x)
1782            # expected ouput for middle 5 values after applying rint()
1783            expected_output = [-1, -1, 0, 0, 1]
1784            assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)
1785        def check_round():
1786            y = nd.round(x)
1787            # expected ouput for middle 5 values after applying round()
1788            expected_output = [-1, -1, 0, 1, 1]
1789            assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)
1790        def check_trunc():
1791            y = nd.trunc(x)
1792            # expected ouput for middle 5 values after applying trunc()
1793            expected_output = [-1, 0, 0, 0, 1]
1794            assert_correctness_of_rounding_ops(y, LARGE_X//2, expected_output)
1795        check_ceil()
1796        check_fix()
1797        check_floor()
1798        check_rint()
1799        check_round()
1800        check_trunc()
1801
1802    def create_input_for_trigonometric_ops(vals):
1803        # Creates large vector input of size(LARGE_X*10, SMALL_Y/10) from vals using broadcast_to operator
1804        inp = nd.array(vals).reshape(1, 5)
1805        inp = nd.broadcast_to(inp, (LARGE_X*10, SMALL_Y//10))
1806        return inp
1807
1808    def assert_correctness_of_trigonometric_ops(output, expected_vals, atol=1e-3):
1809        # checks verifies 5 values at positions(0, 1, -3, -2, -1) of the input vector
1810        output_idx_to_inspect = [0, 1, -3, -2, -1]
1811        for i in range(len(output_idx_to_inspect)):
1812            assert np.abs(output[1][output_idx_to_inspect[i]].asnumpy()-expected_vals[i]) <= atol
1813
1814    def check_trigonometric_ops():
1815        def check_arcsin():
1816            x = create_input_for_trigonometric_ops([-1, -.707, 0, .707, 1])
1817            y = nd.arcsin(x)
1818            # expected ouput for indices=(0, 1, -3, -2, -1) after applying arcsin()
1819            expected_output = [-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2]
1820            assert_correctness_of_trigonometric_ops(y, expected_output)
1821
1822        def check_arccos():
1823            x = create_input_for_trigonometric_ops([-1, -.707, 0, .707, 1])
1824            y = nd.arccos(x)
1825            # expected ouput for indices=(0, 1, -3, -2, -1) after applying arccos()
1826            expected_output = [np.pi, 3*np.pi/4, np.pi/2, np.pi/4, 0]
1827            assert_correctness_of_trigonometric_ops(y, expected_output)
1828
1829        def check_arctan():
1830            x = create_input_for_trigonometric_ops([-np.Inf, -1, 0, 1, np.Inf])
1831            y = nd.arctan(x)
1832            # expected ouput for indices=(0, 1, -3, -2, -1) after applying arctan()
1833            expected_output = [-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2]
1834            assert_correctness_of_trigonometric_ops(y, expected_output)
1835
1836        def check_sin():
1837            x = create_input_for_trigonometric_ops([-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2])
1838            y = nd.sin(x)
1839            # expected ouput for indices=(0, 1, -3, -2, -1) after applying sin()
1840            expected_output = [-1, -.707, 0, .707, 1]
1841            assert_correctness_of_trigonometric_ops(y, expected_output)
1842
1843        def check_cos():
1844            x = create_input_for_trigonometric_ops([0, np.pi/4, np.pi/2, 3*np.pi/4, np.pi])
1845            y = nd.cos(x)
1846            # expected ouput for indices=(0, 1, -3, -2, -1) after applying cos()
1847            expected_output = [1, .707, 0, -.707, -1]
1848            assert_correctness_of_trigonometric_ops(y, expected_output)
1849
1850        def check_tan():
1851            x = create_input_for_trigonometric_ops([-np.pi/6, -np.pi/4, 0, np.pi/4, np.pi/6])
1852            y = nd.tan(x)
1853            # expected ouput for indices=(0, 1, -3, -2, -1) after applying tan()
1854            expected_output = [-.577, -1, 0, 1, .577]
1855            assert_correctness_of_trigonometric_ops(y, expected_output)
1856
1857        def check_arcsinh():
1858            x = create_input_for_trigonometric_ops([-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2])
1859            y = nd.arcsinh(x)
1860            # expected ouput for indices=(0, 1, -3, -2, -1) after applying arcsinh()
1861            expected_output = [np.arcsinh(-np.pi/2), np.arcsinh(-np.pi/4), 0, np.arcsinh(np.pi/4), np.arcsinh(np.pi/2)]
1862            assert_correctness_of_trigonometric_ops(y, expected_output)
1863
1864        def check_arccosh():
1865            x = create_input_for_trigonometric_ops([1, np.pi/2, 3*np.pi/4, np.pi, 5*np.pi/4])
1866            y = nd.arccosh(x)
1867            # expected ouput for indices=(0, 1, -3, -2, -1) after applying arccosh()
1868            expected_output = [0, np.arccosh(np.pi/2), np.arccosh(3*np.pi/4), np.arccosh(np.pi), np.arccosh(5*np.pi/4)]
1869            assert_correctness_of_trigonometric_ops(y, expected_output)
1870
1871        def check_arctanh():
1872            x = create_input_for_trigonometric_ops([-1/4, -1/2, 0, 1/4, 1/2])
1873            y = nd.arctanh(x)
1874            # expected ouput for indices=(0, 1, -3, -2, -1) after applying arctanh()
1875            expected_output = [np.arctanh(-1/4), np.arctanh(-1/2), 0, np.arctanh(1/4), np.arctanh(1/2)]
1876            assert_correctness_of_trigonometric_ops(y, expected_output)
1877
1878        def check_sinh():
1879            x = create_input_for_trigonometric_ops([-np.pi/2, -np.pi/4, 0, np.pi/4, np.pi/2])
1880            y = nd.sinh(x)
1881            # expected ouput for indices=(0, 1, -3, -2, -1) after applying sinh()
1882            expected_output = [np.sinh(-np.pi/2), np.sinh(-np.pi/4), 0, np.sinh(np.pi/4), np.sinh(np.pi/2)]
1883            assert_correctness_of_trigonometric_ops(y, expected_output)
1884
1885        def check_cosh():
1886            x = create_input_for_trigonometric_ops([0, 1, np.pi/2, 3*np.pi/4, np.pi])
1887            y = nd.cosh(x)
1888            # expected ouput for indices=(0, 1, -3, -2, -1) after applying cosh()
1889            expected_output = [1, np.cosh(1), np.cosh(np.pi/2), np.cosh(3*np.pi/4), np.cosh(np.pi)]
1890            assert_correctness_of_trigonometric_ops(y, expected_output)
1891
1892        def check_tanh():
1893            x = create_input_for_trigonometric_ops([-1/4, -1/2, 0, 1/4, 1/2])
1894            y = nd.tanh(x)
1895            # expected ouput for indices=(0, 1, -3, -2, -1) after applying tanh()
1896            expected_output = [np.tanh(-1/4), np.tanh(-1/2), 0, np.tanh(1/4), np.tanh(1/2)]
1897            assert_correctness_of_trigonometric_ops(y, expected_output)
1898
1899        def check_radians():
1900            x = create_input_for_trigonometric_ops([0, 90, 180, 270, 360])
1901            y = nd.radians(x)
1902            # expected ouput for indices=(0, 1, -3, -2, -1) after applying radians()
1903            expected_output = [0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi]
1904            assert_correctness_of_trigonometric_ops(y, expected_output)
1905
1906        def check_degrees():
1907            x = create_input_for_trigonometric_ops([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi])
1908            y = nd.degrees(x)
1909            # expected ouput for indices=(0, 1, -3, -2, -1) after applying degrees()
1910            expected_output = [0, 90, 180, 270, 360]
1911            assert_correctness_of_trigonometric_ops(y, expected_output)
1912
1913        check_arcsin()
1914        check_arccos()
1915        check_arctan()
1916        check_sin()
1917        check_cos()
1918        check_tan()
1919        check_arcsinh()
1920        check_arccosh()
1921        check_arctanh()
1922        check_sinh()
1923        check_cosh()
1924        check_tanh()
1925        check_radians()
1926        check_degrees()
1927
1928    def check_add_n():
1929        x = [nd.ones(LARGE_X) for j in range(SMALL_Y)]
1930        y = nd.add_n(*x)
1931        assert y[0] == SMALL_Y
1932        assert y[-1] == SMALL_Y
1933
1934    def check_modulo():
1935        x = mx.nd.ones((SMALL_Y, LARGE_X))*6
1936        y = mx.nd.ones(LARGE_X)*4
1937        z = (x%y)
1938        assert z[0][0] == 2
1939        assert z[-1][-1] == 2
1940        x = mx.nd.ones((SMALL_Y, LARGE_X))*5
1941        z = nd.modulo(x,y)
1942        assert z[0][0] == 1
1943        assert z[-1][-1] == 1
1944
1945    def check_maximum():
1946        x = mx.nd.ones((SMALL_Y, LARGE_X))*3
1947        y = mx.nd.ones(LARGE_X)*4
1948        z = nd.maximum(x, y)
1949        assert z[0][0] == 4
1950        assert z[-1][-1] == 4
1951        z = nd.maximum(x, 5)
1952        assert z[0][0] == 5
1953        assert z[-1][-1] == 5
1954
1955    def check_minimum():
1956        x = mx.nd.ones((SMALL_Y, LARGE_X))*3
1957        y = mx.nd.ones(LARGE_X)*2
1958        z = nd.minimum(x, y)
1959        assert z[0][0] == 2
1960        assert z[-1][-1] == 2
1961        z = nd.minimum(x, 5)
1962        assert z[0][0] == 3
1963        assert z[-1][-1] == 3
1964
1965    check_elementwise()
1966    check_reduce()
1967    check_dot()
1968    check_argmin()
1969    check_argsort()
1970    check_sort()
1971    check_topk()
1972    check_exponent_logarithm_operators()
1973    check_power_operators()
1974    check_elemwise_add()
1975    check_add()
1976    check_sub()
1977    check_rsub()
1978    check_neg()
1979    check_mul()
1980    check_div()
1981    check_rdiv()
1982    check_mod()
1983    check_rmod()
1984    check_imod()
1985    check_pow()
1986    check_rpow()
1987    check_sum()
1988    check_prod()
1989    check_mean()
1990    check_min()
1991    check_max()
1992    check_norm()
1993    check_argmax()
1994    check_iadd()
1995    check_isub()
1996    check_imul()
1997    check_idiv()
1998    check_eq()
1999    check_neq()
2000    check_lt()
2001    check_lte()
2002    check_gt()
2003    check_gte()
2004    check_sign()
2005    check_logical()
2006    check_rounding_ops()
2007    check_trigonometric_ops()
2008    check_add_n()
2009    check_modulo()
2010    check_maximum()
2011    check_minimum()
2012
2013
2014def test_sparse_dot():
2015    shape = (2, VLARGE_X)
2016    sp_mat1 = nd.sparse.csr_matrix(([2], [6], [0, 1, 1]), shape=shape)
2017    mat2 = nd.ones((VLARGE_X, 2))
2018    out = nd.dot(sp_mat1, mat2)
2019    assert out.asnumpy()[0][0] == 2
2020    assert out.shape == (2, 2)
2021
2022
2023def test_slice_assign():
2024    # test _slice_assign
2025    A = np.zeros((2**31, 2))
2026    A[-1] = np.ones((1))
2027    assert A[-1, 0] == 1 and A[-1, 1] == 1
2028    # test _slice_assign_scalar
2029    B = np.zeros((2**31, 2))
2030    B[-1] = 2
2031    assert B[-1, 0] == 2 and B[-1, 1] == 2
2032
2033
2034if __name__ == '__main__':
2035    import nose
2036    nose.runmodule()
2037