1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from mxnet.test_utils import *
19from mxnet.base import MXNetError
20from common import setup_module, with_seed, teardown, assertRaises
21import random
22import warnings
23
24def is_scalar(var):
25    return False if hasattr(var, "__len__") else True
26
27def get_result_type(call, dflt_stype):
28    """Try to infer result storage type for a sparse matrix and a given unary operation"""
29    if call is not None and dflt_stype != 'default':
30        zero = np.zeros(([1]))
31        result = do_normalize(call(zero))
32        if not almost_equal(result, zero, equal_nan=True):
33            expected_result_type = 'default'
34        else:
35            if dflt_stype is not None:
36                expected_result_type = dflt_stype;
37            else:
38                expected_result_type = 'default'
39    else:
40        expected_result_type = 'default'
41
42    return expected_result_type
43
44
45def get_result_type_with_scalar(call, dflt_stype):
46    """Try to infer result storage type when operating a sparse matrices and a scalar"""
47    if call is not None and dflt_stype != 'default':
48        zero = np.zeros(([1]))
49        result = call(zero, 5)
50
51        if not almost_equal(result, zero, equal_nan=True):
52            expected_result_type = 'default'
53        else:
54            if dflt_stype is not None:
55                expected_result_type = dflt_stype;
56            else:
57                expected_result_type = 'default'
58    else:
59        expected_result_type = 'default'
60
61    return expected_result_type
62
63
64def get_result_type_2(call, dflt_stype):
65    """Try to infer result storage type when operating on two sparse matrices"""
66    if call is not None and dflt_stype != 'default':
67        zero = np.zeros(([1]))
68        need_default = False
69        for outer in [zero, np.ones(zero.shape)]:
70            for inner in [zero, np.ones(zero.shape)]:
71                result = do_normalize(call(outer, inner))
72                if not almost_equal(result, zero, equal_nan=True):
73                    need_default = True
74                    break
75            if need_default is True:
76                break
77
78        if not need_default and dflt_stype is not None:
79            expected_result_type = dflt_stype
80        else:
81            expected_result_type = 'default'
82    else:
83        expected_result_type = 'default'
84
85    return expected_result_type
86
87
88def get_result_type_3(call, dflt_stype):
89    """Try to infer result storage type when operating on three sparse matrices"""
90    if call is not None and dflt_stype != 'default':
91        zero = np.zeros(([1]))
92        need_default = False
93        for moon in [zero]:
94            for outer in [zero]:
95                for inner in [zero]:
96                    res_1, res_2 = call(moon, outer, inner)
97                    result = do_normalize(res_1)
98                    if not almost_equal(result, zero, equal_nan=True):
99                        need_default = True
100                        break
101                    result = do_normalize(res_2)
102                    if not almost_equal(result, zero, equal_nan=True):
103                        need_default = True
104                        break
105                if need_default is True:
106                    break
107            if need_default is True:
108                break
109
110        if not need_default and dflt_stype is not None:
111            expected_result_type = dflt_stype
112        else:
113            expected_result_type = 'default'
114    else:
115        expected_result_type = 'default'
116
117    return expected_result_type
118
119
120def get_fw_bw_result_types(forward_numpy_call,  fwd_res_dflt,
121                           backward_numpy_call, bwd_res_dflt):
122
123    return (get_result_type(forward_numpy_call,  fwd_res_dflt),
124            get_result_type(backward_numpy_call, bwd_res_dflt))
125
126
127def get_fw_bw_result_types_2(forward_numpy_call,  fwd_res_dflt,
128                             backward_numpy_call, bwd_res_dflt):
129    return (get_result_type(forward_numpy_call,  fwd_res_dflt),
130            get_result_type_2(backward_numpy_call, bwd_res_dflt))
131
132def get_fw_bw_result_types_with_scalar(forward_numpy_call,  fwd_res_dflt,
133                                       backward_numpy_call, bwd_res_dflt):
134    return (get_result_type_with_scalar(forward_numpy_call,  fwd_res_dflt),
135            get_result_type_with_scalar(backward_numpy_call, bwd_res_dflt))
136
137def gen_rsp_random_indices(shape, density=.5, force_indices=None):
138    assert density >= 0 and density <= 1
139    indices = set()
140    if force_indices is not None:
141        for val in force_indices:
142            indices.add(int(val))
143    if not np.isclose(density, .0, rtol=1.e-3, atol=1.e-3, equal_nan=True) and len(shape) > 0:
144        row_count = shape[0]
145        for i in range(row_count):
146            r = random.uniform(0, 1)
147            if r <= density and len(indices) < shape[0]:
148                indices.add(i)
149    assert len(indices) <= shape[0]
150    return list(indices)
151
152
153def all_zero(var):
154    return 0
155
156@with_seed()
157def test_elemwise_binary_ops():
158    def test_elemwise_binary_op(name, lhs_stype, rhs_stype, shape,
159                                forward_mxnet_call, forward_numpy_call, backward_numpy_call,
160                                lhs_grad_stype,
161                                rhs_grad_stype,
162                                expected_result_storage_type=None,
163                                modifier_func=None,
164                                lhs_density=.5,
165                                rhs_density=.5,
166                                force_lr_overlap=False,
167                                force_grad_overlap=False,
168                                ograd_density=0.0,
169                                skip_gradient_check=False,
170                                shuffle_csr_indices=True,
171                                verbose=False):
172        if lhs_grad_stype is None:
173            lhs_grad_stype = lhs_stype
174        if rhs_grad_stype is None:
175            rhs_grad_stype = rhs_stype
176
177        lhs_grad_stype = get_result_type_3(backward_numpy_call, lhs_grad_stype)
178        rhs_grad_stype = get_result_type_3(backward_numpy_call, rhs_grad_stype)
179
180        if verbose is True:
181            print("testing: {}  lhs={}, rhs={}, lhs_grad_stype={}, rhs_grad_stype={}"
182                  .format(name, lhs_stype, rhs_stype, lhs_grad_stype, rhs_grad_stype))
183
184        # Output type should be same as lvalue type, unless otherwise specified
185        if expected_result_storage_type is None:
186            if lhs_stype == 'default' or rhs_stype == 'default':
187                expected_result_storage_type = 'default'
188            else:
189                expected_result_storage_type = lhs_stype
190
191        lhs = mx.symbol.Variable('lhs', stype=lhs_stype)
192        rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
193
194        grad_stypes = dict()
195        grad_stypes['lhs'] = lhs_grad_stype
196        grad_stypes['rhs'] = rhs_grad_stype
197
198        if lhs_stype == 'default':
199            lhs_nd = rand_ndarray(shape, 'default')
200            if abs(lhs_density) < 1e-4:
201                func = all_zero
202            else:
203                func = modifier_func
204            lhs_nd = mx.nd.array(assign_each(lhs_nd.asnumpy(), func))
205        else:
206            lhs_nd = create_sparse_array_zd(
207                shape, lhs_stype, density=lhs_density,
208                modifier_func=modifier_func,
209                shuffle_csr_indices=shuffle_csr_indices,
210                rsp_indices=gen_rsp_random_indices(
211                    shape,
212                    density=lhs_density,
213                    force_indices=[(shape[0]/2)] if force_lr_overlap is True else None
214                ))
215
216        if rhs_stype == 'default':
217            rhs_nd = rand_ndarray(shape, 'default')
218            if abs(rhs_density) < 1e-4:
219                func = all_zero
220            else:
221                func = modifier_func
222            rhs_nd = mx.nd.array(assign_each(rhs_nd.asnumpy(), func))
223        else:
224            rhs_nd = create_sparse_array_zd(
225                shape, rhs_stype, density=rhs_density,
226                modifier_func=modifier_func,
227                shuffle_csr_indices=shuffle_csr_indices,
228                rsp_indices=gen_rsp_random_indices(
229                    shape,
230                    density=rhs_density,
231                    force_indices=[(shape[0]/2)] if force_lr_overlap is True else None
232                ))
233
234        lhs_np = lhs_nd.asnumpy()
235        rhs_np = rhs_nd.asnumpy()
236
237        if verbose is True:
238            print("lhs input: {}".format(lhs_np))
239            print("rhs input: {}".format(rhs_np))
240
241        out_np = forward_numpy_call(lhs_np, rhs_np)
242
243        if verbose is True:
244            print("out_np: {}".format(out_np))
245
246        test = forward_mxnet_call(lhs, rhs)
247
248        location = {'lhs': lhs_nd, 'rhs': rhs_nd}
249
250        outputs = check_symbolic_forward(test, location, [out_np], equal_nan=True)
251        assert len(outputs) == 1
252        assert outputs[0].stype == expected_result_storage_type
253
254        if verbose is True:
255            print ("mx forward output: ", outputs[0].asnumpy())
256            print ("lhs_nd: ", lhs_nd.stype)
257            print ("rhs_nd: ", rhs_nd.stype)
258            print ("forward output: ", outputs[0].stype)
259
260        if outputs[0].stype != 'default':
261            out_grad = create_sparse_array_zd(
262                shape, outputs[0].stype, density=ograd_density,
263                data_init=1,
264                modifier_func=lambda x: 2,
265                shuffle_csr_indices=shuffle_csr_indices,
266                rsp_indices=gen_rsp_random_indices(
267                    shape,
268                    density=ograd_density,
269                    force_indices=[(shape[0]/2)] if force_grad_overlap is True else None
270                ))
271        else:
272            if abs(ograd_density) < 1e-4:
273                out_grad = mx.nd.array(np.zeros(shape))
274            else:
275                out_grad = mx.nd.array(np.ones(shape))
276
277
278        out_grad_np = out_grad.asnumpy()
279
280        if verbose is True:
281            print("out_grad_np", out_grad_np)
282
283        ingrad_lhs_np, ingrad_rhs_np = backward_numpy_call(out_grad_np, lhs_np, rhs_np)
284
285        if verbose is True:
286            print("out_grad", out_grad.asnumpy())
287            print("ingrad_lhs_np", ingrad_lhs_np)
288            print("ingrad_rhs_np", ingrad_rhs_np)
289
290        igrads_result = check_symbolic_backward(test, location, [out_grad],
291                                                [ingrad_lhs_np, ingrad_rhs_np],
292                                                grad_stypes=grad_stypes,
293                                                equal_nan=True)
294
295        if verbose is True:
296            print("ingrad_lhs", igrads_result['lhs'].asnumpy())
297            print("ingrad_rhs", igrads_result['rhs'].asnumpy())
298
299        assert len(igrads_result) == 2
300
301        if lhs_grad_stype is not None:
302            assert igrads_result['lhs'].stype == lhs_grad_stype
303        if rhs_grad_stype is not None:
304            assert igrads_result['rhs'].stype == rhs_grad_stype
305
306        if skip_gradient_check is not True:
307            check_numeric_gradient(test, location,
308                                   grad_stype_dict=grad_stypes)
309
310    def check_all(l, r, check_function):
311        assert l.shape == r.shape
312        return check_function(l, r)
313
314    def gt(l, r):
315        return check_all(l, r, lambda a, b: a > b)
316
317    def ge(l, r):
318        return check_all(l, r, lambda a, b: a >= b)
319
320    def lt(l, r):
321        return check_all(l, r, lambda a, b: a < b)
322
323    def le(l, r):
324        return check_all(l, r, lambda a, b: a <= b)
325
326    def elemwise_mul_stype(lstype, rstype):
327        if lstype == rstype:
328            return lstype
329        elif lstype == 'default' and rstype == 'row_sparse':
330            return 'row_sparse'
331        elif lstype == 'row_sparse' and rstype == 'default':
332            return 'row_sparse'
333        elif lstype == 'default' and rstype == 'csr':
334            return 'csr'
335        elif lstype == 'csr' and rstype == 'default':
336            return 'csr'
337        else:
338            return 'default'
339
340    def elemwise_mul_lhs_grad_stype(lstype, rstype):
341        return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), rstype)
342
343    def elemwise_mul_rhs_grad_stype(lstype, rstype):
344        return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), lstype)
345
346    def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape,
347                                  lhs_grad_stype=None, rhs_grad_stype=None,
348                                  lhs_density=.5, rhs_density=.5,
349                                  force_lr_overlap=False,
350                                  force_grad_overlap=False,
351                                  ograd_density=0.0):
352        test_elemwise_binary_op("elemwise_add", lhs_stype, rhs_stype, shape,
353                                lambda l, r: mx.sym.sparse.elemwise_add(l, r),
354                                lambda l, r: l + r,
355                                lambda outg, l, r: (outg, outg),
356                                lhs_grad_stype, rhs_grad_stype,
357                                ograd_density=ograd_density,
358                                force_lr_overlap=force_lr_overlap,
359                                force_grad_overlap=force_grad_overlap,
360                                lhs_density=lhs_density, rhs_density=rhs_density,
361                                verbose=False)
362
363        if ((lhs_stype is 'default' and rhs_stype is 'row_sparse') or
364            (lhs_stype is 'default' and rhs_stype is 'csr') or
365            (lhs_stype is 'row_sparse' and rhs_stype is 'row_sparse') and (rhs_density == 0.0)):
366            test_elemwise_binary_op("elemwise_add", lhs_stype, rhs_stype, shape,
367                                    lambda l, r: mx.sym.sparse.elemwise_add(l, r, out=l),
368                                    lambda l, r: l + r,
369                                    lambda outg, l, r: (outg, outg),
370                                    lhs_grad_stype, rhs_grad_stype,
371                                    ograd_density=ograd_density,
372                                    force_lr_overlap=force_lr_overlap,
373                                    force_grad_overlap=force_grad_overlap,
374                                    lhs_density=lhs_density, rhs_density=rhs_density,
375                                    verbose=False)
376            test_elemwise_binary_op("elemwise_sub", lhs_stype, rhs_stype, shape,
377                                    lambda l, r: mx.sym.sparse.elemwise_sub(l, r, out=l),
378                                    lambda l, r: l - r,
379                                    lambda outg, l, r: (outg, -outg),
380                                    lhs_grad_stype, rhs_grad_stype,
381                                    ograd_density=ograd_density,
382                                    force_lr_overlap=force_lr_overlap,
383                                    force_grad_overlap=force_grad_overlap,
384                                    lhs_density=lhs_density, rhs_density=rhs_density,
385                                    verbose=False)
386
387        if ((lhs_stype is 'row_sparse' and rhs_stype is 'row_sparse') and (lhs_density == 0.0)):
388            test_elemwise_binary_op("elemwise_add", lhs_stype, rhs_stype, shape,
389                                    lambda l, r: mx.sym.sparse.elemwise_add(l, r, out=r),
390                                    lambda l, r: l + r,
391                                    lambda outg, l, r: (outg, outg),
392                                    lhs_grad_stype, rhs_grad_stype,
393                                    ograd_density=ograd_density,
394                                    force_lr_overlap=force_lr_overlap,
395                                    force_grad_overlap=force_grad_overlap,
396                                    lhs_density=lhs_density, rhs_density=rhs_density,
397                                    verbose=False)
398            test_elemwise_binary_op("elemwise_sub", lhs_stype, rhs_stype, shape,
399                                    lambda l, r: mx.sym.sparse.elemwise_sub(l, r, out=l),
400                                    lambda l, r: l - r,
401                                    lambda outg, l, r: (outg, -outg),
402                                    lhs_grad_stype, rhs_grad_stype,
403                                    ograd_density=ograd_density,
404                                    force_lr_overlap=force_lr_overlap,
405                                    force_grad_overlap=force_grad_overlap,
406                                    lhs_density=lhs_density, rhs_density=rhs_density,
407                                    verbose=False)
408
409        test_elemwise_binary_op("elemwise_sub", lhs_stype, rhs_stype, shape,
410                                lambda l, r: mx.sym.sparse.elemwise_sub(l, r),
411                                lambda l, r: l - r,
412                                lambda outg, l, r: (outg, -outg),
413                                lhs_grad_stype, rhs_grad_stype,
414                                ograd_density=ograd_density,
415                                force_lr_overlap=force_lr_overlap,
416                                force_grad_overlap=force_grad_overlap,
417                                lhs_density=lhs_density,
418                                rhs_density=rhs_density,
419                                verbose=False)
420
421        test_elemwise_binary_op("elemwise_mul", lhs_stype, rhs_stype, shape,
422                                lambda l, r: mx.sym.sparse.elemwise_mul(l, r),
423                                lambda l, r: l * r,
424                                lambda outg, l, r: (outg * r, outg * l),
425                                elemwise_mul_lhs_grad_stype(lhs_stype, rhs_stype),
426                                elemwise_mul_rhs_grad_stype(lhs_stype, rhs_stype),
427                                expected_result_storage_type=elemwise_mul_stype(lhs_stype, rhs_stype),
428                                ograd_density=ograd_density,
429                                force_lr_overlap=force_lr_overlap,
430                                force_grad_overlap=force_grad_overlap,
431                                lhs_density=lhs_density, rhs_density=rhs_density,
432                                verbose=False)
433
434        test_elemwise_binary_op("elemwise_div", lhs_stype, rhs_stype, shape,
435                                lambda l, r: mx.sym.sparse.elemwise_div(l, r),
436                                lambda l, r: l / r,
437                                lambda outg, l, r: (outg * (1/r), outg * (-l/(r*r))),
438                                lhs_grad_stype, rhs_grad_stype,
439                                modifier_func=lambda a: a if abs(a) > 0.25 else abs(a) + 1,
440                                force_lr_overlap=force_lr_overlap,
441                                force_grad_overlap=force_grad_overlap,
442                                lhs_density=lhs_density, rhs_density=rhs_density,
443                                ograd_density=ograd_density,
444                                expected_result_storage_type='default',
445                                skip_gradient_check=True,
446                                verbose=False)
447
448        test_elemwise_binary_op("maximum", lhs_stype, rhs_stype, shape,
449                                lambda l, r: mx.sym._internal._maximum(l, r),
450                                lambda l, r: np.maximum(l, r),
451                                lambda outg, l, r: (outg * ge(l, r), outg * lt(l, r)),
452                                lhs_grad_stype, rhs_grad_stype,
453                                modifier_func=lambda a: a if abs(a) > 0.25 else abs(a) + 1,
454                                force_lr_overlap=force_lr_overlap,
455                                force_grad_overlap=force_grad_overlap,
456                                lhs_density=lhs_density, rhs_density=rhs_density,
457                                skip_gradient_check=True,
458                                ograd_density=ograd_density,
459                                verbose=False)
460
461        test_elemwise_binary_op("minimum", lhs_stype, rhs_stype, shape,
462                                lambda l, r: mx.sym._internal._minimum(l, r),
463                                lambda l, r: np.minimum(l, r),
464                                lambda outg, l, r: (outg * le(l, r), outg * gt(l, r)),
465                                lhs_grad_stype, rhs_grad_stype,
466                                modifier_func=lambda a: a if abs(a) > 0.25 else abs(a) + 1,
467                                force_lr_overlap=force_lr_overlap,
468                                force_grad_overlap=force_grad_overlap,
469                                lhs_density=lhs_density, rhs_density=rhs_density,
470                                ograd_density=ograd_density,
471                                skip_gradient_check=True,
472                                verbose=False)
473
474        test_elemwise_binary_op("hypot", lhs_stype, rhs_stype, shape,
475                                lambda l, r: mx.sym._internal._hypot(l, r),
476                                lambda l, r: np.hypot(l, r),
477                                lambda outg, l, r: (
478                                    outg * assign_each2(
479                                        l, r, lambda a, b: a/np.sqrt(a * a + b * b)),
480                                    outg * assign_each2(
481                                        l, r, lambda a, b: b/np.sqrt(a * a + b * b))
482                                ),
483                                lhs_grad_stype, rhs_grad_stype,
484                                force_lr_overlap=force_lr_overlap,
485                                force_grad_overlap=force_grad_overlap,
486                                lhs_density=lhs_density, rhs_density=rhs_density,
487                                ograd_density=ograd_density,
488                                skip_gradient_check=True,
489                                verbose=False)
490
491    # Run basic tests
492    with warnings.catch_warnings():
493        warnings.simplefilter("ignore")
494
495        for ii in range(1):
496            # Run defaults
497            check_elemwise_binary_ops('default', 'default', rand_shape_2d(5, 5))
498
499            # Try different densities
500            shape = rand_shape_2d(5, 5)
501            for lhs_density in [0.0, random.uniform(0, 1), 1.0]:
502                for rhs_density in [0.0, random.uniform(0, 1), 1.0]:
503                    for ograd_density in [0.0, random.uniform(0, 1), 1.0]:
504
505                        print("lhs_density={}, rhs_density={}, ograd_density={}, shape: {}".format(
506                            lhs_density, rhs_density, ograd_density, shape))
507
508                        # Try row_sparse overlaps
509                        for force_lr_overlap in [False, True]:
510                            for force_grad_overlap in [False, True]:
511
512                                print("  force_lr_overlap={}, force_grad_overlap={}, shape={}".
513                                      format(force_lr_overlap, force_grad_overlap, shape))
514
515                                # Left and right always overlap when one is default storage
516                                # (assuming the row_sparse one has some entries in it)
517                                if force_lr_overlap is False:
518                                    check_elemwise_binary_ops('default', 'row_sparse', shape,
519                                                              lhs_density=lhs_density,
520                                                              rhs_density=rhs_density,
521                                                              force_lr_overlap=force_lr_overlap,
522                                                              force_grad_overlap=force_grad_overlap,
523                                                              ograd_density=ograd_density)
524                                    check_elemwise_binary_ops('row_sparse', 'default', shape,
525                                                              lhs_density=lhs_density,
526                                                              rhs_density=rhs_density,
527                                                              force_lr_overlap=force_lr_overlap,
528                                                              force_grad_overlap=force_grad_overlap,
529                                                              ograd_density=ograd_density)
530
531                                # Back to left-right overlap possiblities
532                                check_elemwise_binary_ops('row_sparse', 'row_sparse', shape,
533                                                          lhs_grad_stype='row_sparse',
534                                                          rhs_grad_stype='row_sparse',
535                                                          lhs_density=lhs_density,
536                                                          rhs_density=rhs_density,
537                                                          force_lr_overlap=force_lr_overlap,
538                                                          force_grad_overlap=force_grad_overlap,
539                                                          ograd_density=ograd_density)
540
541                        # No overlap flags for CSR
542                        check_elemwise_binary_ops('csr', 'csr', shape,
543                                                  lhs_grad_stype='csr',
544                                                  rhs_grad_stype='csr',
545                                                  lhs_density=lhs_density,
546                                                  rhs_density=rhs_density,
547                                                  ograd_density=ograd_density)
548                        check_elemwise_binary_ops('csr', 'csr', shape,
549                                                  lhs_grad_stype='default',
550                                                  rhs_grad_stype='default',
551                                                  lhs_density=lhs_density,
552                                                  rhs_density=rhs_density,
553                                                  ograd_density=ograd_density)
554                        check_elemwise_binary_ops('default', 'csr', shape,
555                                                  lhs_grad_stype='csr',
556                                                  rhs_grad_stype='csr',
557                                                  lhs_density=lhs_density,
558                                                  rhs_density=rhs_density,
559                                                  ograd_density=ograd_density)
560                        check_elemwise_binary_ops('csr', 'default', shape,
561                                                  lhs_grad_stype='csr',
562                                                  rhs_grad_stype='csr',
563                                                  lhs_density=lhs_density,
564                                                  rhs_density=rhs_density,
565                                                  ograd_density=ograd_density)
566
567
568@with_seed()
569def test_elemwise_csr_same_zeros():
570    # Zeroes
571    a = mx.nd.sparse.zeros('csr', (1,1))
572    b = mx.nd.elemwise_add(a,a)
573    res = a.asnumpy() + a.asnumpy()
574    assert_almost_equal(b.asnumpy(), res)
575
576
577def as_dense(arr):
578    if arr.stype != 'default':
579        return mx.nd.cast_storage(arr, stype='default')
580    else:
581        return arr;
582
583# Make sure that 0's look like 0's when we do a comparison
584def do_normalize(arr):
585    ret = arr.copy()
586    idx = np.isclose(arr, -0, rtol=1.e-3, atol=1.e-3, equal_nan=True)
587    ret[idx] = 0
588    return ret
589
590def check_sparse_mathematical_core(name, stype,
591                                   forward_mxnet_call, forward_numpy_call, backward_numpy_call=None,
592                                   rhs_arg=None, data_init=9., grad_init=2., output_grad_stype=None,
593                                   input_grad_stype=None, force_overlap=False, density=.5,
594                                   ograd_density=.5, verbose=False, shuffle_csr_indices=True):
595    if verbose is True:
596        print("TESTING: " + name)
597
598    data = mx.symbol.Variable('data', stype=stype)
599
600    temp_input_grad_stype = input_grad_stype
601
602    if temp_input_grad_stype is None:
603        temp_input_grad_stype = stype
604
605    if rhs_arg is not None:
606        if is_scalar(rhs_arg):
607            expected_result_type, expected_grad_result_type = \
608                get_fw_bw_result_types_with_scalar(forward_numpy_call, stype,
609                                                   backward_numpy_call, temp_input_grad_stype)
610        else:
611            expected_result_type, expected_grad_result_type = \
612                get_fw_bw_result_types_2(forward_numpy_call, stype,
613                                         backward_numpy_call, temp_input_grad_stype)
614    else:
615        expected_result_type, expected_grad_result_type = \
616            get_fw_bw_result_types(forward_numpy_call, stype,
617                                   backward_numpy_call, temp_input_grad_stype)
618
619    if input_grad_stype is not None and input_grad_stype != expected_grad_result_type:
620        print("{}: explicit override of deduced input grad type '{}' with '{}'".format(
621            name, expected_grad_result_type, input_grad_stype))
622        expected_grad_result_type = input_grad_stype
623
624    shape = rand_shape_2d()
625
626    if verbose is True:
627        print("Shape: ", shape, "density: ", density, "force_overlap", force_overlap)
628
629    if stype == 'default':
630        data_tmp = np.zeros(shape)
631        if abs(density) >= 1e-4:
632            data_tmp[:] = data_init
633        arr_data = mx.nd.array(data_tmp)
634    else:
635        arr_data = create_sparse_array_zd(
636            shape, stype, density=density,
637            data_init=data_init,
638            shuffle_csr_indices=shuffle_csr_indices,
639            rsp_indices=gen_rsp_random_indices(
640                shape,
641                density=density,
642                force_indices=[(shape[0]/2)] if force_overlap is True else None
643            )
644        )
645        data_tmp = arr_data.asnumpy()
646        if verbose is True:
647            print("arr_data indices", arr_data.indices.asnumpy())
648
649    if verbose is True:
650        print("input", data_tmp)
651
652    if backward_numpy_call is None:
653        arr_grad = None
654    elif expected_grad_result_type == 'default':
655        if abs(density) < 1e-4:
656            arr_grad = mx.nd.zeros(shape)
657        else:
658            arr_grad = mx.nd.ones(shape)
659    else:
660        arr_grad = create_sparse_array_zd(
661            shape,
662            expected_grad_result_type,
663            density=density,
664            data_init=1,
665            shuffle_csr_indices=shuffle_csr_indices,
666            rsp_indices=gen_rsp_random_indices(
667                shape,
668                density=density,
669                force_indices=[(shape[0]/2)] if force_overlap is True else None
670            )
671        )
672
673    if rhs_arg is not None:
674        test = forward_mxnet_call(data, rhs_arg)
675    else:
676        test = forward_mxnet_call(data)
677
678    args = list()
679    args.append(arr_data)
680
681    if arr_grad is not None:
682        exe_test = test.bind(default_context(), args=args, args_grad=[arr_grad])
683    else:
684        exe_test = test.bind(default_context(), args=args)
685
686    exe_test.forward(is_train=True)
687    assert exe_test.outputs[0].stype == expected_result_type
688    out = exe_test.outputs[0].asnumpy()
689
690    if rhs_arg is not None:
691        npout = forward_numpy_call(data_tmp, rhs_arg)
692    else:
693        npout = forward_numpy_call(data_tmp)
694
695    if verbose is True:
696        print("out", out)
697        print("npout", npout)
698
699    assert_almost_equal(out, npout, equal_nan=True)
700
701    if backward_numpy_call is not None:
702        if output_grad_stype == 'default' or output_grad_stype is None:
703            out_grad = mx.nd.empty(shape)
704            out_grad[:] = grad_init
705        else:
706            out_grad = create_sparse_array_zd(
707                shape, output_grad_stype,
708                density=density,
709                data_init=grad_init,
710                shuffle_csr_indices=shuffle_csr_indices,
711                rsp_indices=gen_rsp_random_indices(
712                    shape,
713                    density=ograd_density,
714                    force_indices=[(shape[0]/2)] if force_overlap is True else None))
715
716        npout_grad = out_grad.asnumpy()
717
718        if verbose is True:
719            print("npout_grad", npout_grad)
720
721        if rhs_arg is not None:
722            temp = backward_numpy_call(data_tmp, rhs_arg)
723        else:
724            temp = backward_numpy_call(data_tmp)
725        input_grad = npout_grad * temp
726
727        if verbose is True:
728            print(arr_grad.asnumpy())
729        exe_test.backward(out_grad)
730        if verbose is True:
731            print(arr_grad.asnumpy())
732
733        assert arr_grad.stype == expected_grad_result_type
734
735        if verbose is True:
736            print(name)
737            print("arr_grad", arr_grad.asnumpy())
738            print("input_grad", input_grad)
739
740        assert_almost_equal(arr_grad, input_grad, equal_nan=True)
741
742
743@with_seed()
744def test_sparse_mathematical_core():
745    def util_sign(a):
746        if np.isclose(a, -0, rtol=1.e-3, atol=1.e-3, equal_nan=True):
747            return 0
748        elif np.isclose(a, 0, rtol=1.e-3, atol=1.e-3, equal_nan=True):
749            return 0
750        elif a < 0.0:
751            return -1
752        else:  # a > 0.0:
753            return 1
754
755    # Check scalar binary operators
756    def check_binary_op_with_scalar(stype,
757                                    output_grad_stype=None,
758                                    input_grad_stype=None,
759                                    density=.5, ograd_density=.5,
760                                    force_overlap=False,):
761        # mul_scalar
762        check_sparse_mathematical_core("mul_scalar", stype,
763                                       lambda x, y: x * y,
764                                       lambda x, y: x * y,
765                                       lambda input, rhs: rhs,
766                                       rhs_arg=5.0,
767                                       data_init=2, grad_init=3,
768                                       output_grad_stype=output_grad_stype,
769                                       input_grad_stype=input_grad_stype,
770                                       density=density, ograd_density=ograd_density,
771                                       force_overlap=force_overlap,
772                                       verbose=False)
773
774        # plus_scalar
775        check_sparse_mathematical_core("plus_scalar", stype,
776                                       lambda x, y: x + y,
777                                       lambda x, y: x + y,
778                                       lambda input, rhs: 1,
779                                       rhs_arg=5.0,
780                                       data_init=2, grad_init=3,
781                                       output_grad_stype=output_grad_stype,
782                                       input_grad_stype=input_grad_stype,
783                                       density=density, ograd_density=ograd_density,
784                                       force_overlap=force_overlap,
785                                       verbose=False)
786
787        # minus_scalar
788        check_sparse_mathematical_core("minus_scalar", stype,
789                                       lambda x, y: x - y,
790                                       lambda x, y: x - y,
791                                       lambda input, rhs: 1,
792                                       rhs_arg=5.0,
793                                       data_init=2, grad_init=3,
794                                       output_grad_stype=output_grad_stype,
795                                       input_grad_stype=input_grad_stype,
796                                       density=density, ograd_density=ograd_density,
797                                       force_overlap=force_overlap,
798                                       verbose=False)
799
800    # Check many basic unary operators
801    def check_mathematical_core(stype, output_grad_stype=None,
802                                input_grad_stype=None, force_overlap=False,
803                                density=.5, ograd_density=.5):
804
805        # negative
806        check_sparse_mathematical_core("negative", stype,
807                                       lambda x: mx.sym.sparse.negative(x),
808                                       lambda x: np.negative(x),
809                                       force_overlap=force_overlap,
810                                       density=density,
811                                       input_grad_stype=input_grad_stype,
812                                       ograd_density=ograd_density)
813
814        # square
815        check_sparse_mathematical_core("square", stype,
816                                       lambda x: mx.sym.sparse.square(x),
817                                       lambda x: np.square(x),
818                                       lambda x: 2 * x,
819                                       output_grad_stype=output_grad_stype,
820                                       input_grad_stype=input_grad_stype,
821                                       force_overlap=force_overlap,
822                                       density=density, ograd_density=ograd_density,
823                                       verbose=False)
824
825        # sqrt
826        check_sparse_mathematical_core("sqrt", stype,
827                                       lambda x: mx.sym.sparse.sqrt(x),
828                                       lambda x: np.sqrt(x),
829                                       lambda x: 1.0/(2.0 * np.sqrt(x)),
830                                       output_grad_stype=output_grad_stype,
831                                       input_grad_stype=input_grad_stype,
832                                       force_overlap=force_overlap,
833                                       density=density, ograd_density=ograd_density,
834                                       verbose=False)
835
836        # cbrt
837        check_sparse_mathematical_core("cbrt", stype,
838                                       lambda x: mx.sym.sparse.cbrt(x),
839                                       lambda x: np.cbrt(x),
840                                       lambda x: 1.0/(3.0 * np.cbrt(x) * np.cbrt(x)),
841                                       output_grad_stype=output_grad_stype,
842                                       input_grad_stype=input_grad_stype,
843                                       force_overlap=force_overlap,
844                                       density=density, ograd_density=ograd_density,
845                                       verbose=False)
846
847        # rint
848        check_sparse_mathematical_core("rint", stype,
849                                       lambda x: mx.sym.sparse.rint(x),
850                                       lambda x: np.rint(x),
851                                       force_overlap=force_overlap, density=density,
852                                       input_grad_stype=input_grad_stype,
853                                       ograd_density=ograd_density)
854
855        # fix
856        check_sparse_mathematical_core("fix", stype,
857                                       lambda x: mx.sym.sparse.fix(x),
858                                       lambda x: np.fix(x),
859                                       force_overlap=force_overlap, density=density,
860                                       input_grad_stype=input_grad_stype,
861                                       ograd_density=ograd_density)
862
863        # floor
864        check_sparse_mathematical_core("floor", stype, lambda x: mx.sym.sparse.floor(x),
865                                       lambda x: np.floor(x),
866                                       force_overlap=force_overlap,
867                                       input_grad_stype=input_grad_stype,
868                                       density=density, ograd_density=ograd_density)
869
870        # ceil
871        check_sparse_mathematical_core("ceil", stype,
872                                       lambda x: mx.sym.sparse.ceil(x),
873                                       lambda x: np.ceil(x),
874                                       force_overlap=force_overlap,
875                                       input_grad_stype=input_grad_stype,
876                                       density=density, ograd_density=ograd_density)
877
878        # round
879        check_sparse_mathematical_core("round", stype,
880                                       lambda x: mx.sym.sparse.round(x),
881                                       lambda x: np.round(x),
882                                       force_overlap=force_overlap,
883                                       input_grad_stype=input_grad_stype,
884                                       density=density, ograd_density=ograd_density)
885
886        # trunc
887        check_sparse_mathematical_core("trunc", stype,
888                                       lambda x: mx.sym.sparse.trunc(x),
889                                       lambda x: np.trunc(x),
890                                       force_overlap=force_overlap,
891                                       input_grad_stype=input_grad_stype,
892                                       density=density, ograd_density=ograd_density)
893
894        # sign
895        check_sparse_mathematical_core("sign", stype,
896                                       lambda x: mx.sym.sparse.sign(x),
897                                       lambda x: np.sign(x),
898                                       lambda x: np.zeros(x.shape),
899                                       output_grad_stype=output_grad_stype,
900                                       force_overlap=force_overlap,
901                                       density=density, ograd_density=ograd_density)
902
903        # log1p
904        check_sparse_mathematical_core("log1p", stype,
905                                       lambda x: mx.sym.sparse.log1p(x),
906                                       lambda x: np.log1p(x),
907                                       lambda x: 1. / (1.0 + x),
908                                       data_init=0.5, grad_init=0.5,
909                                       output_grad_stype=output_grad_stype,
910                                       input_grad_stype=input_grad_stype,
911                                       force_overlap=force_overlap, density=density,
912                                       ograd_density=ograd_density)
913
914        # expm1
915        check_sparse_mathematical_core("expm1", stype,
916                                        lambda x: mx.sym.sparse.expm1(x),
917                                        lambda x: np.expm1(x),
918                                        lambda x: np.exp(x),
919                                        data_init=0.5, grad_init=0.5,
920                                        output_grad_stype=output_grad_stype,
921                                        input_grad_stype=input_grad_stype,
922                                        force_overlap=force_overlap, density=density,
923                                        ograd_density=ograd_density)
924
925        # sin
926        check_sparse_mathematical_core("sin", stype,
927                                       lambda x: mx.sym.sparse.sin(x),
928                                       lambda x: np.sin(x),
929                                       lambda x: np.cos(x),
930                                       output_grad_stype=output_grad_stype,
931                                       input_grad_stype=input_grad_stype,
932                                       force_overlap=force_overlap,
933                                       density=density, ograd_density=ograd_density)
934
935        # tan
936        check_sparse_mathematical_core("tan", stype,
937                                       lambda x: mx.sym.sparse.tan(x),
938                                       lambda x: np.tan(x),
939                                       lambda x: np.tan(x) ** 2 + 1,
940                                       output_grad_stype=output_grad_stype,
941                                       input_grad_stype=input_grad_stype,
942                                       density=density,
943                                       ograd_density=ograd_density)
944
945        # arcsin
946        check_sparse_mathematical_core("arcsin", stype,
947                                       lambda x: mx.sym.sparse.arcsin(x),
948                                       lambda x: np.arcsin(x),
949                                       lambda x: 1. / (1. - x ** 2) ** (1. / 2.),
950                                       data_init=0.5, grad_init=0.5,
951                                       output_grad_stype=output_grad_stype,
952                                       input_grad_stype=input_grad_stype,
953                                       force_overlap=force_overlap,
954                                       density=density, ograd_density=ograd_density)
955
956        # arctan
957        check_sparse_mathematical_core("arctan", stype,
958                                       lambda x: mx.sym.sparse.arctan(x),
959                                       lambda x: np.arctan(x),
960                                       lambda x: 1. / (x ** 2. + 1.),
961                                       data_init=0.5, grad_init=0.5,
962                                       output_grad_stype=output_grad_stype,
963                                       input_grad_stype=input_grad_stype,
964                                       force_overlap=force_overlap,
965                                       density=density, ograd_density=ograd_density)
966
967        # degrees
968        check_sparse_mathematical_core("degrees", stype,
969                                       lambda x: mx.sym.sparse.degrees(x),
970                                       lambda x: np.degrees(x),
971                                       lambda x: assign_each(x, lambda a: 180./np.pi),
972                                       data_init=0.5, grad_init=0.5,
973                                       output_grad_stype=output_grad_stype,
974                                       input_grad_stype=input_grad_stype,
975                                       force_overlap=force_overlap,
976                                       density=density, ograd_density=ograd_density)
977
978        # radians
979        check_sparse_mathematical_core("radians", stype,
980                                       lambda x: mx.sym.sparse.radians(x),
981                                       lambda x: np.radians(x),
982                                       lambda x: assign_each(x, lambda a: np.pi / 180.),
983                                       data_init=0.6, grad_init=1,
984                                       output_grad_stype=output_grad_stype,
985                                       input_grad_stype=input_grad_stype,
986                                       force_overlap=force_overlap,
987                                       density=density, ograd_density=ograd_density)
988
989        # sinh
990        check_sparse_mathematical_core("sinh", stype,
991                                       lambda x: mx.sym.sparse.sinh(x),
992                                       lambda x: np.sinh(x),
993                                       lambda x: np.cosh(x),
994                                       output_grad_stype=output_grad_stype,
995                                       input_grad_stype=input_grad_stype,
996                                       force_overlap=force_overlap,
997                                       density=density, ograd_density=ograd_density)
998
999        # tanh
1000        check_sparse_mathematical_core("tanh", stype,
1001                                       lambda x: mx.sym.sparse.tanh(x),
1002                                       lambda x: np.tanh(x),
1003                                       lambda x: 1. - np.tanh(x) ** 2,
1004                                       data_init=0.5, grad_init=1,
1005                                       output_grad_stype=output_grad_stype,
1006                                       input_grad_stype=input_grad_stype,
1007                                       force_overlap=force_overlap, density=density,
1008                                       ograd_density=ograd_density)
1009
1010        # arcsinh
1011        check_sparse_mathematical_core("arcsinh", stype,
1012                                       lambda x: mx.sym.sparse.arcsinh(x),
1013                                       lambda x: np.arcsinh(x),
1014                                       lambda x: 1./(x**2 + 1.)**(1./2.),
1015                                       output_grad_stype=output_grad_stype,
1016                                       input_grad_stype=input_grad_stype,
1017                                       force_overlap=force_overlap, density=density,
1018                                       ograd_density=ograd_density)
1019
1020        # arctanh
1021        check_sparse_mathematical_core("arctanh", stype,
1022                                       lambda x: mx.sym.sparse.arctanh(x),
1023                                       lambda x: np.arctanh(x),
1024                                       lambda x: -1./(x**2 - 1.),
1025                                       data_init=0.5,
1026                                       output_grad_stype=output_grad_stype,
1027                                       input_grad_stype=input_grad_stype,
1028                                       force_overlap=force_overlap, density=density,
1029                                       ograd_density=ograd_density)
1030
1031        # abs
1032        check_sparse_mathematical_core("abs", stype,
1033                                       lambda x: mx.sym.sparse.abs(x),
1034                                       lambda x: np.abs(x),
1035                                       lambda x: assign_each(x, function=util_sign),
1036                                       output_grad_stype=output_grad_stype,
1037                                       input_grad_stype=input_grad_stype,
1038                                       force_overlap=force_overlap,
1039                                       density=density, ograd_density=ograd_density)
1040
1041        if stype != "csr":
1042            # rsqrt
1043            check_sparse_mathematical_core("rsqrt", stype,
1044                                           lambda x: mx.sym.sparse.rsqrt(x),
1045                                           lambda x: 1 / np.sqrt(x),
1046                                           lambda x: -(1.0 / (2.0 * x * np.sqrt(x))),
1047                                           output_grad_stype=output_grad_stype,
1048                                           input_grad_stype=input_grad_stype,
1049                                           force_overlap=force_overlap,
1050                                           density=density, ograd_density=ograd_density)
1051
1052            # cos
1053            check_sparse_mathematical_core("cos", stype,
1054                                           lambda x: mx.sym.sparse.cos(x),
1055                                           lambda x: np.cos(x),
1056                                           lambda x: -np.sin(x),
1057                                           output_grad_stype=output_grad_stype,
1058                                           input_grad_stype=input_grad_stype,
1059                                           force_overlap=force_overlap,
1060                                           density=density, ograd_density=ograd_density)
1061
1062            # arccos
1063            check_sparse_mathematical_core("arccos", stype,
1064                                           lambda x: mx.sym.sparse.arccos(x),
1065                                           lambda x: np.arccos(x),
1066                                           lambda x: -1. / (1. - x ** 2.) ** (1. / 2.),
1067                                           data_init=0.5, grad_init=0.5,
1068                                           output_grad_stype=output_grad_stype,
1069                                           input_grad_stype=input_grad_stype,
1070                                           force_overlap=force_overlap, density=density,
1071                                           ograd_density=ograd_density)
1072
1073            # cosh
1074            check_sparse_mathematical_core("cosh", stype,
1075                                           lambda x: mx.sym.sparse.cosh(x),
1076                                           lambda x: np.cosh(x),
1077                                           lambda x: np.sinh(x),
1078                                           data_init=5, grad_init=5,
1079                                           output_grad_stype=output_grad_stype,
1080                                           input_grad_stype=input_grad_stype,
1081                                           force_overlap=force_overlap,
1082                                           density=density, ograd_density=ograd_density)
1083
1084            # arccosh
1085            check_sparse_mathematical_core("arccosh", stype,
1086                                           lambda x: mx.sym.sparse.arccosh(x),
1087                                           lambda x: np.arccosh(x),
1088                                           lambda x: 1./(x**2 - 1.)**(1./2.),
1089                                           output_grad_stype=output_grad_stype,
1090                                           input_grad_stype=input_grad_stype,
1091                                           force_overlap=force_overlap, density=density,
1092                                           ograd_density=ograd_density)
1093
1094            # log10
1095            check_sparse_mathematical_core("log10", stype,
1096                                           lambda x: mx.sym.sparse.log10(x),
1097                                           lambda x: np.log10(x),
1098                                           lambda x: 1. / (x * np.log(10.)),
1099                                           output_grad_stype=output_grad_stype,
1100                                           input_grad_stype=input_grad_stype,
1101                                           force_overlap=force_overlap, density=density,
1102                                           ograd_density=ograd_density)
1103
1104            # log2
1105            check_sparse_mathematical_core("log2", stype,
1106                                           lambda x: mx.sym.sparse.log2(x),
1107                                           lambda x: np.log2(x),
1108                                           lambda x: 1. / (x * np.log(2.)),
1109                                           output_grad_stype=output_grad_stype,
1110                                           input_grad_stype=input_grad_stype,
1111                                           force_overlap=force_overlap, density=density,
1112                                           ograd_density=ograd_density)
1113
1114
1115            try:
1116                from scipy import special as scipy_special
1117                # On scipy v1.0, psi([0, -1, -2, -3, ...]) = [ inf, inf, inf, inf, ...]
1118                # On scipy v1.1, psi([0, -1, -2, -3, ...]) = [-inf, nan, nan, nan, ...]
1119                # Map the behavior of v1.1 psi() to that of v1.0 for ints <= 0 for consistency
1120                scipy_psi = np.vectorize(lambda x: np.inf if float(x).is_integer() and x <= 0 else
1121                                         scipy_special.psi(x))
1122                # gamma
1123                check_sparse_mathematical_core("gamma", stype,
1124                                               lambda x: mx.sym.sparse.gamma(x),
1125                                               lambda x: scipy_special.gamma(x),
1126                                               lambda x: scipy_special.gamma(x) * scipy_psi(x),
1127                                               output_grad_stype=output_grad_stype,
1128                                               input_grad_stype=input_grad_stype,
1129                                               force_overlap=force_overlap,
1130                                               density=density, ograd_density=ograd_density)
1131                # gammaln
1132                check_sparse_mathematical_core("gammaln", stype,
1133                                               lambda x: mx.sym.sparse.gammaln(x),
1134                                               lambda x: scipy_special.gammaln(x),
1135                                               lambda x: scipy_psi(x),
1136                                               output_grad_stype=output_grad_stype,
1137                                               input_grad_stype=input_grad_stype,
1138                                               force_overlap=force_overlap,
1139                                               density=density, ograd_density=ograd_density)
1140
1141            except ImportError:
1142                print("Could not import scipy. Skipping unit tests for special functions")
1143
1144    for i in range(1):
1145        print("pass", i)
1146        for density in [0.0, random.uniform(0, 1), 1.0]:
1147            for ograd_density in [0.0, random.uniform(0, 1), 1.0]:
1148                for force_overlap in [False, True]:
1149                    print("{}, {}, {}".format(density, ograd_density, force_overlap))
1150                    with warnings.catch_warnings():
1151                        warnings.simplefilter("ignore")
1152
1153                        # Check unary ops (unary fwd, binary bwd)
1154                        check_mathematical_core('default', force_overlap=force_overlap,
1155                                                density=density, ograd_density=ograd_density)
1156                        check_mathematical_core('row_sparse', force_overlap=force_overlap,
1157                                                density=density, ograd_density=ograd_density)
1158                        check_mathematical_core('row_sparse', output_grad_stype='default',
1159                                                force_overlap=force_overlap,
1160                                                density=density, ograd_density=ograd_density)
1161                        check_mathematical_core('row_sparse', output_grad_stype='row_sparse',
1162                                                force_overlap=force_overlap,
1163                                                density=density, ograd_density=ograd_density)
1164                        check_mathematical_core('csr', output_grad_stype='default',
1165                                                force_overlap=force_overlap,
1166                                                density=density, ograd_density=ograd_density)
1167                        check_mathematical_core('csr', output_grad_stype='csr',
1168                                                force_overlap=force_overlap,
1169                                                density=density, ograd_density=ograd_density)
1170
1171                        # Check binary with scalar ops
1172                        check_binary_op_with_scalar('default',
1173                                                    density=density,
1174                                                    ograd_density=ograd_density,
1175                                                    force_overlap=force_overlap)
1176                        check_binary_op_with_scalar('row_sparse',
1177                                                    density=density,
1178                                                    ograd_density=ograd_density,
1179                                                    force_overlap=force_overlap)
1180                        check_binary_op_with_scalar('row_sparse', output_grad_stype='default',
1181                                                    density=density,
1182                                                    ograd_density=ograd_density,
1183                                                    force_overlap=force_overlap)
1184                        check_binary_op_with_scalar('row_sparse',
1185                                                    output_grad_stype='row_sparse',
1186                                                    density=density, ograd_density=ograd_density,
1187                                                    force_overlap=force_overlap)
1188                        check_binary_op_with_scalar('csr',
1189                                                    output_grad_stype='csr',
1190                                                    input_grad_stype='default',
1191                                                    density=density,
1192                                                    ograd_density=ograd_density,
1193                                                    force_overlap=force_overlap)
1194                        check_binary_op_with_scalar('csr',
1195                                                    output_grad_stype='csr',
1196                                                    input_grad_stype='csr',
1197                                                    density=density,
1198                                                    ograd_density=ograd_density,
1199                                                    force_overlap=force_overlap)
1200                        check_binary_op_with_scalar('csr',
1201                                                    output_grad_stype='default',
1202                                                    density=density,
1203                                                    ograd_density=ograd_density,
1204                                                    force_overlap=force_overlap)
1205
1206
1207
1208@with_seed()
1209def test_elemwise_add_ex():
1210    def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_grad_stype=None):
1211        lhs = mx.symbol.Variable('lhs', stype=lhs_stype)
1212        rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
1213        lhs_nd = rand_ndarray(shape, lhs_stype)
1214        rhs_nd = rand_ndarray(shape, rhs_stype)
1215        lhs_np = lhs_nd.asnumpy()
1216        rhs_np = rhs_nd.asnumpy()
1217
1218        out_np = lhs_np + rhs_np
1219        test = mx.symbol.sparse.elemwise_add(lhs, rhs)
1220        location = {'lhs': lhs_nd, 'rhs': rhs_nd}
1221        check_symbolic_forward(test, location, [out_np])
1222        check_numeric_gradient(test, location)
1223        grad_stypes = {}
1224        if lhs_grad_stype is not None and lhs_grad_stype != 'default':
1225            grad_stypes['lhs'] = lhs_grad_stype
1226        if rhs_grad_stype is not None and rhs_grad_stype != 'default':
1227            grad_stypes['rhs'] = rhs_grad_stype
1228        check_symbolic_backward(test, location, [out_np], [out_np, out_np],
1229                                grad_stypes=grad_stypes)
1230
1231    shapes = [rand_shape_2d(), rand_shape_3d()]
1232    for shape in shapes:
1233        check_elemwise_add_ex('default', 'default', shape)
1234        check_elemwise_add_ex('row_sparse', 'row_sparse', shape,
1235                              lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse')
1236
1237
1238@with_seed()
1239def test_cast_storage_ex():
1240    def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad=True):
1241        x = mx.symbol.Variable('x', stype=from_stype)
1242        x_nd = rand_ndarray(shape, from_stype, density=density)
1243        x_np = x_nd.asnumpy()
1244        out_np = x_np
1245        test = mx.symbol.cast_storage(x, stype=to_stype)
1246        location = {'x': x_nd}
1247        check_symbolic_forward(test, location, [out_np])
1248        # consider disable the numeric grad check for gpu block kernel since the input is large
1249        if check_numeric_grad:
1250            check_numeric_gradient(test, location)
1251        grad_stypes = {'x': to_stype}
1252        check_symbolic_backward(test, location, [out_np], [out_np], grad_stypes=grad_stypes)
1253
1254    density = [1.00, 0.50, 0.01]
1255    for d in density:
1256        shape_2d = rand_shape_2d()
1257        shape_3d = rand_shape_3d()
1258        check_cast_storage(shape_2d, d, 'csr', 'default')
1259        check_cast_storage(shape_2d, d, 'default', 'csr')
1260        check_cast_storage(shape_2d, d, 'csr', 'csr')
1261        check_cast_storage(shape_2d, d, 'row_sparse', 'default')
1262        check_cast_storage(shape_2d, d, 'default', 'row_sparse')
1263        check_cast_storage(shape_2d, d, 'row_sparse', 'row_sparse')
1264        check_cast_storage(shape_3d, d, 'row_sparse', 'default')
1265        check_cast_storage(shape_3d, d, 'default', 'row_sparse')
1266        check_cast_storage(shape_3d, d, 'row_sparse', 'row_sparse')
1267        for i in range(4, 6):
1268            shape = rand_shape_nd(i, 5)
1269            check_cast_storage(shape, d, 'default', 'row_sparse')
1270            check_cast_storage(shape, d, 'row_sparse', 'default')
1271        # Test specific gpu kernels
1272        if default_context().device_type is 'gpu':
1273            dim0 = rnd.randint(1, 10)
1274            # test gpu thread kernel
1275            check_cast_storage((dim0, rnd.randint(  1,   32)), d, 'default', 'csr')
1276            # test gpu warp   kernel
1277            check_cast_storage((dim0, rnd.randint( 32,  512)), d, 'default', 'csr')
1278            # test gpu block  kernel
1279            check_cast_storage((dim0, rnd.randint(512, 1024)), d, 'default', 'csr',
1280                               check_numeric_grad=False)
1281            # check race condition in block kernel
1282            check_cast_storage((200, 128 * 2 + 1), d, 'default', 'csr',
1283                               check_numeric_grad=False)
1284            # test gpu thread kernel
1285            check_cast_storage((dim0, rnd.randint(  1,   32)), d, 'default', 'row_sparse')
1286            # test gpu warp   kernel
1287            check_cast_storage((dim0, rnd.randint( 32,  512)), d, 'default', 'row_sparse')
1288            # test gpu block  kernel
1289            check_cast_storage((dim0, rnd.randint(512, 1024)), d, 'default', 'row_sparse',
1290                               check_numeric_grad=False)
1291
1292
1293@with_seed()
1294def test_sparse_dot():
1295    def test_infer_forward_stype(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_a, trans_b):
1296        all_stypes = ["default", "csr", "row_sparse"]
1297        lhs_nd = rand_ndarray(lhs_shape, 'default', density=lhs_density)
1298        rhs_nd = rand_ndarray(rhs_shape, 'default', density=rhs_density)
1299        out_nd = mx.nd.dot(lhs_nd, rhs_nd, transpose_a=trans_a, transpose_b=trans_b)
1300        out_np = out_nd.asnumpy()
1301        for lhs_stype in all_stypes:
1302            for rhs_stype in all_stypes:
1303                for forward_stype in all_stypes:
1304                    lhs = lhs_nd.tostype(lhs_stype)
1305                    rhs = rhs_nd.tostype(rhs_stype)
1306                    out = mx.nd.dot(lhs, rhs, forward_stype=forward_stype,
1307                                    transpose_a=trans_a, transpose_b=trans_b)
1308                    assert_almost_equal(out.tostype('default').asnumpy(), out_np, rtol=1e-3, atol=1e-4)
1309                    lhs_var = mx.symbol.Variable('lhs', stype=lhs_stype)
1310                    rhs_var = mx.symbol.Variable('rhs', stype=rhs_stype)
1311                    out = mx.symbol.sparse.dot(lhs_var, rhs_var,
1312                                               forward_stype=forward_stype,
1313                                               transpose_a=trans_a, transpose_b=trans_b)
1314                    location = {'lhs': lhs, 'rhs': rhs}
1315                    check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)
1316    def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_density):
1317        lhs_nd = rand_ndarray(lhs_shape, 'csr', density=lhs_density, shuffle_csr_indices=False)
1318        lhs_dns = lhs_nd.tostype('default')
1319        rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=rhs_density)
1320        rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.tostype('default')
1321
1322        out = mx.nd.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs)
1323        out_dns = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs)
1324        out_np = out_dns.asnumpy()
1325        assert_almost_equal(out.asnumpy(), out_np, rtol=1e-3, atol=1e-5)
1326
1327        # test symbolic forward
1328        lhs = mx.symbol.Variable('lhs', stype='csr')
1329        rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
1330        out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs)
1331        location = {'lhs': lhs_nd, 'rhs': rhs_nd}
1332        check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)
1333
1334        # test symbolic backward
1335        backward_trans = not trans_lhs
1336        rhs_backward_grad = mx.nd.dot(lhs_dns, out_dns, transpose_a=backward_trans).asnumpy()
1337        expected = {'rhs': rhs_backward_grad}
1338        check_symbolic_backward(out, location, [out_np], expected,
1339                                grad_req={'lhs': 'null', 'rhs': 'write'},
1340                                rtol=1e-3, atol=1e-4)
1341
1342    def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=False, trans_rhs=False):
1343        lhs_nd = rand_ndarray(lhs_shape, stype='default', density=lhs_density)
1344        rhs_nd = rand_ndarray(rhs_shape, stype='csr', density=rhs_density)
1345        rhs_dns = rhs_nd.tostype('default')
1346
1347        if default_context() == mx.cpu():
1348            forward_stype = 'csr'
1349        else:
1350            forward_stype = 'default'
1351        out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype)
1352        out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype)
1353        out_np = out_dns.asnumpy()
1354        assert_almost_equal(out.asnumpy(), out_np, rtol=1e-3, atol=1e-5)
1355
1356        # test symbolic forward
1357        lhs = mx.symbol.Variable('lhs', stype='default')
1358        rhs = mx.symbol.Variable('rhs', stype='csr')
1359        out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype)
1360        location = {'lhs': lhs_nd, 'rhs': rhs_nd}
1361        check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)
1362
1363        if default_context() == mx.cpu():
1364            # test symbolic backward
1365            backward_trans = not trans_lhs
1366            rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy()
1367            if trans_rhs is True:
1368                rhs_backward_grad = rhs_backward_grad.T
1369            expected = {'rhs': rhs_backward_grad}
1370            check_symbolic_backward(out, location, [out_np], expected,
1371                                    grad_req={'lhs': 'null', 'rhs': 'write'},
1372                                    rtol=1e-3, atol=1e-4)
1373        else:
1374            transpose_b = not trans_rhs
1375            lhs_backward_grad = mx.nd.dot(out_dns, rhs_dns, transpose_b=transpose_b)
1376            expected = {'lhs': lhs_backward_grad.asnumpy()}
1377            check_symbolic_backward(out, location, [out_np], expected,
1378                                    grad_req={'lhs': 'write', 'rhs': 'null'},
1379                                    rtol=1e-3, atol=1e-4)
1380
1381    def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols):
1382        """Test for nnr_out = 0. Before the fix, the test would fail."""
1383        lhs = mx.nd.zeros(lhs_shape)
1384        irow = np.random.randint(0, lhs_shape[0])
1385        icol = np.random.randint(0, lhs_shape[1])
1386        lhs[irow, icol] = 1.0
1387        if trans_lhs:
1388            rhs = rand_ndarray(shape=(lhs_shape[0], rhs_num_cols), stype='default')
1389            rhs[irow, :] = 0
1390        else:
1391            rhs = rand_ndarray(shape=(lhs_shape[1], rhs_num_cols), stype='default')
1392            rhs[icol, :] = 0
1393        dns_out = mx.nd.dot(lhs, rhs, transpose_a=trans_lhs)
1394        assert mx.nd.sum(mx.nd.abs(dns_out)).asscalar() == 0
1395        sps_out = mx.nd.sparse.dot(lhs.tostype('csr'), rhs.tostype('row_sparse'), transpose_a=trans_lhs)
1396        assert same(dns_out.asnumpy(), sps_out.asnumpy())
1397
1398    density = [1.00, 0.5, 0.01]
1399    for lhs_d in density:
1400        lhs_shape = rand_shape_2d(50, 200)
1401        rhs_d = 1
1402        test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False, lhs_d, rhs_d)  # test gpu SpMV
1403        test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True,  lhs_d, rhs_d)  # (vector kernel)
1404        test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d)  # test gpu SpMM
1405        test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d)  # (scalar kernel)
1406        test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(50, 200)), lhs_d, lhs_d)
1407        test_dot_dns_csr(lhs_shape, (rnd.randint(50, 200), lhs_shape[1]), lhs_d, lhs_d, trans_rhs=True)
1408        for rhs_d in density:
1409            test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d)
1410            test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d)
1411            test_infer_forward_stype(lhs_shape, (lhs_shape[1], rnd.randint(10, 20)),
1412                                     lhs_d, rhs_d, False, False)
1413            test_infer_forward_stype(lhs_shape, (rnd.randint(10, 20), lhs_shape[1]),
1414                                     lhs_d, rhs_d, False, True)
1415            test_infer_forward_stype(lhs_shape, (lhs_shape[0], rnd.randint(10, 20)),
1416                                     lhs_d, rhs_d, True, False)
1417            test_infer_forward_stype(lhs_shape, (rnd.randint(10, 20), lhs_shape[0]),
1418                                     lhs_d, rhs_d, True, True)
1419
1420    test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40)
1421    test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40)
1422
1423@with_seed()
1424def test_sparse_dot_determinism():
1425    def check_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b, forward_stype):
1426        lhs_row = rnd.randint(50, 100)
1427        lhs_col = rnd.randint(50, 100)
1428        if transpose_a:
1429            if transpose_b:
1430                rhs_shape = (rnd.randint(50, 100), lhs_row)
1431            else:
1432                rhs_shape = (lhs_row, rnd.randint(50, 100))
1433        else:
1434            if transpose_b:
1435                rhs_shape = (rnd.randint(50, 100), lhs_col)
1436            else:
1437                rhs_shape = (lhs_col, rnd.randint(50, 100))
1438        lhs_shape = (lhs_row, lhs_col)
1439        lhs = rand_ndarray(lhs_shape, lhs_stype, density=lhs_density)
1440        rhs = rand_ndarray(rhs_shape, rhs_stype, density=rhs_density)
1441        res1 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b, forward_stype=forward_stype)
1442        res2 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b, forward_stype=forward_stype)
1443        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.0, atol=0.0)
1444
1445    check_dot_determinism('csr', 'default', 0.1, 1.0, True, False, 'row_sparse')
1446    forward_stype = 'csr' if default_context() == mx.cpu() else 'default'
1447    check_dot_determinism('default', 'csr', 1.0, 0.1, False, False, forward_stype)
1448    check_dot_determinism('default', 'csr', 1.0, 0.1, False, True, forward_stype)
1449    check_dot_determinism('csr', 'default', 0.1, 1.0, True, False, 'default')
1450
1451
1452@with_seed()
1453def test_sparse_slice():
1454    def check_csr_slice(shape, slice_input):
1455        storage_type = 'csr'
1456        B, _ = rand_sparse_ndarray(shape, storage_type)
1457        np = B.asnumpy()
1458        begin = rnd.randint(0, B.shape[0] - 1)
1459        end = rnd.randint(begin + 1, B.shape[0])
1460        nd_slice = mx.nd.crop(B, begin=begin, end=end)
1461        assert same(nd_slice.asnumpy(), np[begin:end]), (nd_slice.asnumpy(), np[begin:end])
1462
1463    shape = (rnd.randint(7, 15), rnd.randint(1, 10))
1464    check_csr_slice(shape, True)
1465    check_csr_slice(shape, False)
1466
1467
1468@with_seed()
1469def test_sparse_retain():
1470    def check_sparse_retain(shape, density, index_type=np.int64):
1471        num_rows = shape[0]
1472        rsp, _ = rand_sparse_ndarray(shape=shape, stype='row_sparse', density=density)
1473        length = np.random.randint(1, num_rows + 1)
1474        idx = random_sample(list(range(0, num_rows)), length)
1475        idx.sort()
1476        dns = rsp.asnumpy()
1477        tensor_retained_expected = np.zeros(shape)
1478        for i in idx:
1479            tensor_retained_expected[i][:] = dns[i]
1480        indices = mx.nd.array(idx, dtype=index_type)
1481        rsp_retained = mx.nd.sparse.retain(rsp, indices=indices)
1482        assert same(tensor_retained_expected, rsp_retained.asnumpy())
1483
1484        # check numeric gradient
1485        data = mx.symbol.Variable('data')
1486        idx = mx.symbol.Variable('indices')
1487        sym = mx.sym.sparse.retain(data=data, indices=idx)
1488        check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'],
1489                               grad_stype_dict={'data': 'row_sparse'})
1490
1491    shape = rand_shape_2d()
1492    shape_3d = rand_shape_3d()
1493    densities = [0.01, 0.5, 1.0]
1494    index_types = [np.float32, np.int32, np.int64]
1495    for density in densities:
1496        for itype in index_types:
1497            check_sparse_retain(shape, density, itype)
1498            check_sparse_retain(shape_3d, density, itype)
1499
1500
1501@with_seed()
1502def test_sparse_unary_with_numerics():
1503    def check_sparse_simple(name, stype, mxnet_func, forward_numpy_call,
1504                            backward_numpy_call, output_grad_stype=None,
1505                            backward_is_use_output=False):
1506        if output_grad_stype is None:
1507            output_grad_stype = stype
1508
1509        expected_result_type, expected_grad_result_type = \
1510            get_fw_bw_result_types_2(forward_numpy_call, stype, backward_numpy_call, output_grad_stype)
1511        if backward_is_use_output is True:
1512            expected_grad_result_type = expected_result_type
1513
1514        shape = (3, 4)
1515        data = mx.symbol.Variable("data")
1516
1517        grad_stypes = {'data' : expected_grad_result_type}
1518
1519        y = mxnet_func(data)
1520        if stype == 'default':
1521            xa = np.random.uniform(low=-1.0, high=1.0, size=shape)
1522            xa_np = xa
1523        else:
1524            xa = create_sparse_array(shape, stype, data_init=None, rsp_indices=[1],
1525                                     modifier_func=lambda a: a - 0.5,
1526                                     shuffle_csr_indices=True)
1527            xa_np = xa.asnumpy()
1528
1529        if output_grad_stype != 'default':
1530            out_grad = create_sparse_array(shape, output_grad_stype, data_init=None,
1531                                           rsp_indices=[1, 2],
1532                                           modifier_func=lambda a: a - 0.5,
1533                                           shuffle_csr_indices=True)
1534            out_grad_np = out_grad.asnumpy()
1535        else:
1536            out_grad_np = np.ones(xa.shape)
1537            out_grad = mx.nd.array(out_grad_np)
1538
1539        output_np = forward_numpy_call(xa_np)
1540        input_grad_np = backward_numpy_call(output_np, out_grad_np)
1541
1542        outputs = check_symbolic_forward(y, [xa], [output_np])
1543        output = outputs[0]
1544
1545        assert output.stype == expected_result_type
1546
1547        input_grad_dict = check_symbolic_backward(y, location=[xa], out_grads=[out_grad],
1548                                                  expected=[input_grad_np],
1549                                                  grad_stypes=grad_stypes)
1550        inp_grad = input_grad_dict["data"]
1551
1552        assert inp_grad.stype == expected_grad_result_type
1553
1554    def check_sparse_function(name, mxnet_func, forward_numpy_call, backward_numpy_call,
1555                              backward_is_use_output=False):
1556        check_sparse_simple(name, 'default', mxnet_func, forward_numpy_call, backward_numpy_call)
1557        for output_grad_stype in [None, "row_sparse", "default"]:
1558            check_sparse_simple(name, 'row_sparse', mxnet_func, forward_numpy_call, backward_numpy_call,
1559                                output_grad_stype=output_grad_stype,
1560                                backward_is_use_output=backward_is_use_output)
1561
1562        for output_grad_stype in [None, "csr", "default"]:
1563            check_sparse_simple(name, 'csr', mxnet_func, forward_numpy_call, backward_numpy_call,
1564                                output_grad_stype=output_grad_stype,
1565                                backward_is_use_output=backward_is_use_output)
1566
1567    check_sparse_function('relu',
1568                          lambda x: mx.sym.relu(x),
1569                          lambda x: np.maximum(x, 0.0),
1570                          lambda output, outg: outg * assign_each(output, lambda x: x > 0.0), backward_is_use_output=True)
1571
1572    check_sparse_function('sigmoid',
1573                          lambda x: mx.sym.sigmoid(x),
1574                          lambda x: np.divide(1.0, (1.0 + np.exp(-x))),
1575                          lambda output, outg: outg * assign_each(output, lambda x: x * (1.0 - x)),
1576                          backward_is_use_output=True)
1577
1578
1579@with_seed()
1580def test_sparse_nd_zeros():
1581    def check_sparse_nd_zeros(stype, shape):
1582        zero = mx.nd.zeros(shape)
1583        sparse_zero = mx.nd.zeros(shape=shape, stype=stype)
1584        assert_almost_equal(sparse_zero.asnumpy(), zero.asnumpy())
1585
1586    shape = rand_shape_2d()
1587    check_sparse_nd_zeros('row_sparse', shape)
1588    check_sparse_nd_zeros('csr', shape)
1589    check_sparse_nd_zeros('default', shape)
1590
1591
1592@with_seed()
1593def test_sparse_nd_zeros_like():
1594    def check_sparse_nd_zeros_like(stype, shape):
1595        zero = mx.nd.zeros(shape, stype=stype)
1596        zero_like = mx.nd.sparse.zeros_like(zero)
1597        assert_almost_equal(zero.asnumpy(), zero_like.asnumpy())
1598
1599    shape = rand_shape_2d()
1600    check_sparse_nd_zeros_like('row_sparse', shape)
1601    check_sparse_nd_zeros_like('csr', shape)
1602
1603
1604@with_seed()
1605def test_sparse_axis_operations():
1606    def test_variations(func_name):
1607        dim0 = 30
1608        dim1 = 100
1609        axes = [0, 1]
1610        densities = [0, 0.5, 1]
1611        for density in densities:
1612            shape = rand_shape_2d(dim0, dim1)
1613            csr_array = rand_ndarray(shape=shape, stype='csr', density=density)
1614            dns = csr_array.tostype('default')
1615            for axis in axes:
1616                ret = func_name(csr_array, axis=axis)
1617                assert ret.stype == 'default'
1618                ret_expected = func_name(dns, axis=axis)
1619                assert_almost_equal(ret.asnumpy(), ret_expected.asnumpy())
1620
1621    def test_fallback(func_name, axis=0, keepdims=True, exclude=True):
1622        dim0 = 30
1623        dim1 = 100
1624        shape = rand_shape_2d(dim0, dim1)
1625        csr_array = rand_ndarray(shape=shape, stype='csr', density=0.01)
1626        ret= func_name(csr_array, axis=axis, keepdims=keepdims,
1627                       exclude=exclude)
1628
1629    test_variations(mx.nd.sum)
1630    test_fallback(mx.nd.sum, axis=0, keepdims=True, exclude=True)
1631    test_variations(mx.nd.mean)
1632    test_fallback(mx.nd.mean, axis=0, keepdims=True, exclude=True)
1633
1634
1635@with_seed()
1636def test_sparse_square_sum():
1637    dim0 = 30
1638    dim1 = 30
1639    axes = [0, 1]
1640    keepdims = [False, True]
1641    densities = [0, 0.01, 0.2, 0.5, 1.0]
1642    for density in densities:
1643        shape = rand_shape_2d(dim0, dim1)
1644        rsp = rand_ndarray(shape, 'row_sparse', density)
1645        dns = rsp.tostype('default')
1646        for axis in axes:
1647            for keepdim in keepdims:
1648                ret = mx.nd._internal._square_sum(rsp, axis=axis, keepdims=keepdim)
1649                if axis == 1 and keepdim:
1650                    assert ret.stype == 'row_sparse'
1651                else:
1652                    assert ret.stype == 'default'
1653                ret_expected = mx.nd.sum(dns*dns, axis=axis, keepdims=keepdim)
1654                # check forward result
1655                assert_almost_equal(ret.asnumpy(), ret_expected.asnumpy())
1656
1657                rsp_data = mx.sym.Variable('data', stype='row_sparse')
1658                test = mx.symbol._internal._square_sum(rsp_data, axis=axis, keepdims=keepdim)
1659
1660                # check symbolic backward since ograd can be an rsp
1661                # and cannot be checked through check_numeric_gradient
1662                # because it will add a loss layer as the output layer
1663                # which makes ograd of the square_sum dense
1664                if axis == 1 and keepdim:
1665                    dns_data = mx.sym.Variable('data')
1666                    baseline = mx.sym.sum(mx.sym.square(dns_data), axis=axis, keepdims=keepdim)
1667                    igrad_expected = mx.nd.empty(dns.shape)
1668                    baseline_exec = baseline.bind(default_context(), args=[dns],
1669                                                  args_grad=[igrad_expected])
1670                    baseline_exec.forward(is_train=True)
1671                    baseline_exec.backward([ret_expected])
1672                    # check backward when ograd is row sparse
1673                    check_symbolic_backward(test, [rsp], [ret_expected.tostype('row_sparse')],
1674                                            [igrad_expected.asnumpy()], grad_stypes={'data': 'row_sparse'})
1675
1676                    # check backward when ograd is dense
1677                    # the stype of output of the square_sum is deteremined in symbol binding stage.
1678                    # The ograd stype of the last layer is the same as the output stype of the last layer.
1679                    # Need to add one more layer after square_sum to trigger the kernel for ograd
1680                    # with default stype in square_sum op.
1681                    baseline1 = baseline + 1
1682                    baseline_exec1 = baseline1.bind(default_context(), args=[dns],
1683                                                    args_grad=[igrad_expected])
1684                    baseline_exec1.forward(is_train=True)
1685                    baseline_exec1.backward([ret_expected])
1686                    test1 = test + 1
1687                    check_symbolic_backward(test1, [rsp], [ret_expected], [igrad_expected.asnumpy()],
1688                                            grad_stypes={'data': 'row_sparse'})
1689
1690                # check numeric gradient
1691                check_numeric_gradient(test, [rsp], grad_stype_dict={'data': 'row_sparse'},
1692                                       atol=1e-2, rtol=0.1)
1693
1694
1695@with_seed()
1696def test_sparse_storage_fallback():
1697    """ test operators which don't implement FComputeEx or FStatefulComputeEx """
1698    def check_broadcast_add(shape, lhs_stype, rhs_stype):
1699        lhs = mx.symbol.Variable('lhs', stype=lhs_stype)
1700        rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
1701        lhs_nd = rand_ndarray(shape, lhs_stype)
1702        rhs_nd = rand_ndarray(shape, rhs_stype)
1703        lhs_dns = mx.nd.cast_storage(lhs_nd, stype='default')
1704        rhs_dns = mx.nd.cast_storage(rhs_nd, stype='default')
1705
1706        out_dns = (lhs_dns + rhs_dns).asnumpy()
1707        test = mx.symbol.broadcast_add(lhs, rhs)
1708        location = {'lhs': lhs_nd, 'rhs': rhs_nd}
1709        check_symbolic_forward(test, location, [out_dns])
1710        check_numeric_gradient(test, location)
1711        check_symbolic_backward(test, location, [out_dns], [out_dns, out_dns])
1712
1713    def np_softmax(x, axis=-1):
1714        # fix for old numpy on Travis not supporting keepdims
1715        x = x - np.max(x, axis=axis, keepdims=True)
1716        x = np.exp(x)
1717        x /= np.sum(x, axis=axis, keepdims=True)
1718        return x
1719
1720    def check_softmax_with_shape(lhs_stype, rhs_stype, shape, preserve_shape=False):
1721        # bind with label
1722        ctx = default_context()
1723        X = mx.symbol.Variable('X', stype=lhs_stype)
1724        L = mx.symbol.Variable('L', stype=rhs_stype)
1725        Y = mx.symbol.SoftmaxOutput(data=X, label=L, preserve_shape=preserve_shape)
1726        x = rand_ndarray(shape, lhs_stype)
1727        l = rand_ndarray(shape, rhs_stype)
1728        l[:] = np_softmax(l.asnumpy())
1729        grad = mx.nd.empty(shape, ctx=ctx)
1730        exec1 = Y.bind(ctx, args = [x, l], args_grad = {'X': grad})
1731        exec1.forward(is_train=True)
1732        out = exec1.outputs[0].asnumpy()
1733        assert_almost_equal(out, np_softmax(x.asnumpy()), rtol=1e-4)
1734        exec1.backward()
1735        assert_almost_equal(grad.asnumpy(), np_softmax(x.asnumpy()) - l.asnumpy(),
1736                            rtol=1e-3, atol=1e-4)
1737
1738    def check_concat(shape, lhs_stype, rhs_stype):
1739        x = mx.symbol.Variable('x', stype=lhs_stype)
1740        w = mx.symbol.Variable('w', stype=rhs_stype)
1741        test = mx.sym.Concat(x, w)
1742        x_nd = rand_ndarray(shape, lhs_stype)
1743        w_nd = rand_ndarray(shape, rhs_stype)
1744        location = {'x': x_nd, 'w': w_nd}
1745        check_numeric_gradient(test, location)
1746
1747    def check_operator_with_temp_resource(shape, stype):
1748        x = mx.symbol.Variable('x', stype=stype)
1749        test = mx.sym.sum(x)
1750        x_nd = rand_ndarray(shape, stype)
1751        location = {'x': x_nd}
1752        check_numeric_gradient(test, location)
1753
1754    shape = rand_shape_2d()
1755    stypes = ['default', 'csr', 'row_sparse']
1756    for lhs in stypes:
1757        check_operator_with_temp_resource(shape, lhs)
1758        for rhs in stypes:
1759            check_broadcast_add(shape, lhs, rhs)
1760            check_concat(shape, lhs, rhs)
1761            check_softmax_with_shape(lhs, rhs, shape, preserve_shape=False)
1762            check_softmax_with_shape(rhs, rhs, shape, preserve_shape=True)
1763
1764
1765@with_seed()
1766def test_sparse_elementwise_sum():
1767    def check_sparse_elementwise_sum_with_shape(stypes, shape, n):
1768        # forward
1769        inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
1770        out = mx.symbol.sparse.add_n(*inputs, name='esum')
1771        arr = []
1772        arr_grad = [mx.nd.empty(shape, stype=stype) for stype in stypes]
1773        densities = [0, 0.01, 0.5, 1.0]
1774        for stype in stypes:
1775            arr.append(rand_ndarray(shape, stype, densities[np.random.randint(0, len(densities))]))
1776
1777        exec1 = out.bind(default_context(),
1778                         args=arr,
1779                         args_grad=arr_grad)
1780        exec1.forward(is_train=True)
1781        out1 = exec1.outputs[0].asnumpy()
1782        out = sum(a.asnumpy() for a in arr)
1783        assert_almost_equal(out, out1, atol=1e-5)
1784
1785        out_grad = mx.nd.empty(shape)
1786        out_grad[:] = np.random.uniform(-10, 10, shape)
1787        # backward
1788        exec1.backward([out_grad])
1789        for a in arr_grad:
1790            assert_almost_equal(a.asnumpy(), out_grad.asnumpy(), atol=1e-5)
1791
1792    all_stypes = ['default', 'csr', 'row_sparse']
1793    for dim in range(2, 4):
1794        shape = tuple(np.random.randint(5, 10, size=dim))
1795        rsp_test_cnt = np.random.randint(1, 9)
1796        check_sparse_elementwise_sum_with_shape(['row_sparse' for i in range(rsp_test_cnt)], shape, rsp_test_cnt)
1797        if dim is 2:
1798            check_sparse_elementwise_sum_with_shape(['default', 'csr', 'default'], shape, 3)
1799            test_len = np.random.randint(5, 10)
1800            # at least one default type
1801            stypes = ['default']
1802            for i in range(test_len):
1803                pick_side = np.random.randint(2)
1804                pick_type = np.random.randint(3)
1805                stypes = ([all_stypes[pick_type]] if pick_side is 0 else []) + stypes + ([all_stypes[pick_type]] if pick_side is 1 else [])
1806            check_sparse_elementwise_sum_with_shape(stypes, shape, test_len+1)
1807
1808
1809@with_seed()
1810def test_contrib_sparse_embedding():
1811    ''' test sparse embedding operator '''
1812    def check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic, weight_stype):
1813        # init executor
1814        data = mx.sym.Variable("data")
1815        weight = mx.sym.Variable("embed_weight", stype=weight_stype)
1816        embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=in_dim,
1817                                               output_dim=out_dim, deterministic=deterministic,
1818                                               name="embed")
1819        grad_req = {'data': 'null', 'embed_weight': 'write'}
1820        exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,))
1821        arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays))
1822        grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays))
1823        # init data
1824        np_data = np.random.randint(low=0, high=in_dim, size=batch)
1825        np_onehot = np.zeros((batch, in_dim)).astype(np.float32)
1826        np_onehot[np.arange(batch), np_data] = 1.0
1827        arg_map["data"][:] = np_data
1828        # init grad
1829        np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
1830        grad = mx.nd.zeros(np_grad.shape)
1831        grad[:] = np_grad
1832        # weight
1833        weight = arg_map["embed_weight"]
1834        for density in densities:
1835            # update weight based on density
1836            weight[:] = rand_ndarray(weight.shape, weight_stype, density=density)
1837            # check forward
1838            exe_test.forward(is_train=True)
1839            assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, weight.asnumpy()), atol=1e-4)
1840            # check backward
1841            exe_test.backward([grad])
1842            assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, grad.asnumpy()), atol=1e-4)
1843            # run twice to check if the result is deterministic when passing "deterministic=True" to SparseEmbedding
1844            if deterministic:
1845                grad_ref = grad_map["embed_weight"].asnumpy()
1846                exe_test.backward([grad])
1847                assert_almost_equal(grad_map["embed_weight"].asnumpy(), grad_ref, atol=0, rtol=0)
1848
1849    densities = [0, 0.5, 1]
1850    in_dim = 50
1851    out_dim = 3
1852    batch = 8
1853    stypes = ['default', 'row_sparse']
1854    deterministics = [True, False]
1855    for stype in stypes:
1856        for deterministic in deterministics:
1857            check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic, stype)
1858            check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic, stype)
1859
1860@with_seed()
1861def test_sparse_embedding():
1862    ''' test sparse embedding operator '''
1863    def check_sparse_embedding(in_dim, out_dim, batch, densities, sparse_grad, weight_stype):
1864        target_stype = 'row_sparse' if sparse_grad else 'default'
1865        # init executor
1866        data = mx.sym.Variable("data")
1867        weight = mx.sym.Variable("embed_weight", stype=weight_stype)
1868        embed = mx.sym.sparse.Embedding(data=data, weight=weight, input_dim=in_dim,
1869                                        sparse_grad=sparse_grad, output_dim=out_dim, name='embed')
1870        grad_req = {'data': 'null', 'embed_weight': 'write'}
1871        exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,))
1872        arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays))
1873        grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays))
1874        # init data
1875        np_data = np.random.randint(low=0, high=in_dim, size=batch)
1876        np_onehot = np.zeros((batch, in_dim)).astype(np.float32)
1877        np_onehot[np.arange(batch), np_data] = 1.0
1878        arg_map["data"][:] = np_data
1879        # init grad
1880        np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
1881        grad = mx.nd.zeros(np_grad.shape)
1882        grad[:] = np_grad
1883        # weight
1884        weight = arg_map["embed_weight"]
1885        for density in densities:
1886            # update weight based on density
1887            weight[:] = rand_ndarray(weight.shape, weight_stype, density=density)
1888            # check forward
1889            exe_test.forward(is_train=True)
1890            assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, weight.asnumpy()), atol=1e-4)
1891            # check backward
1892            exe_test.backward([grad])
1893            assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, grad.asnumpy()), atol=1e-4)
1894            # check grad stype
1895            assert(grad_map["embed_weight"].stype == target_stype)
1896
1897    densities = [0, 0.5, 1]
1898    in_dim = 50
1899    out_dim = 3
1900    batch = 8
1901    weight_stypes = ['default', 'row_sparse']
1902    sparse_grads = [True, False]
1903    for weight_stype in weight_stypes:
1904        for sparse_grad in sparse_grads:
1905            check_sparse_embedding(in_dim, out_dim, batch, densities, sparse_grad, weight_stype)
1906            check_sparse_embedding(in_dim, out_dim, batch, densities, sparse_grad, weight_stype)
1907
1908@with_seed()
1909def test_sparse_broadcast_add_sub():
1910    def check_broadcast_add(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
1911        assert_almost_equal(mx.nd.sparse.add(mx_lhs, mx_rhs).asnumpy(), np.add(np_lhs, np_rhs), atol=1e-4)
1912    def check_broadcast_sub(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
1913        assert_almost_equal(mx.nd.sparse.subtract(mx_lhs, mx_rhs).asnumpy(), np.subtract(np_lhs, np_rhs), atol=1e-4)
1914    stype = 'csr'
1915    shape = rand_shape_2d()
1916    num_rows = shape[0]
1917    num_cols = shape[1]
1918    for density in [0.1 * i for i in range(10)]:
1919        mx_lhs = rand_ndarray(shape, stype, density)
1920        np_lhs = mx_lhs.asnumpy()
1921        mx_rhs_row_2D = rand_ndarray((1, num_cols), 'default')
1922        mx_rhs_row_1D = mx_rhs_row_2D.reshape((num_cols))
1923        mx_rhs_col = rand_ndarray((num_rows, 1), 'default')
1924        mx_rhs_scalar_2D = rand_ndarray((1, 1), 'default')
1925        mx_rhs_scalar_1D = mx_rhs_scalar_2D.reshape((1, ))
1926        for mx_rhs in [mx_rhs_row_2D, mx_rhs_row_1D, mx_rhs_col, mx_rhs_scalar_2D, mx_rhs_scalar_1D]:
1927            np_rhs = mx_rhs.asnumpy()
1928            check_broadcast_add(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32)
1929            check_broadcast_sub(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32)
1930            check_broadcast_add(mx_rhs, mx_lhs, np_rhs, np_lhs, np.float32)
1931            check_broadcast_sub(mx_rhs, mx_lhs, np_rhs, np_lhs, np.float32)
1932
1933@with_seed()
1934def test_sparse_broadcast_mul_div():
1935    def check_broadcast_mul(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
1936        assert_almost_equal(mx.nd.sparse.multiply(mx_lhs, mx_rhs).asnumpy(), np.multiply(np_lhs, np_rhs), atol=1e-4)
1937    def check_broadcast_div(mx_lhs, mx_rhs, np_lhs, np_rhs, dtype):
1938        assert_almost_equal(mx.nd.sparse.divide(mx_lhs, mx_rhs).asnumpy(), np.divide(np_lhs, np_rhs), atol=1e-4)
1939    stype = 'csr'
1940    shape = rand_shape_2d()
1941    num_rows = shape[0]
1942    num_cols = shape[1]
1943    for density in [0.1 * i for i in range(10)]:
1944        mx_lhs = rand_ndarray(shape, stype, density)
1945        np_lhs = mx_lhs.asnumpy()
1946        mx_rhs_row_2D = rand_ndarray((1, num_cols), 'default')
1947        mx_rhs_row_1D = mx_rhs_row_2D.reshape((num_cols))
1948        mx_rhs_col = rand_ndarray((num_rows, 1), 'default')
1949        mx_rhs_scalar_2D = rand_ndarray((1, 1), 'default')
1950        mx_rhs_scalar_1D = mx_rhs_scalar_2D.reshape((1, ))
1951        for mx_rhs in [mx_rhs_row_2D, mx_rhs_row_1D, mx_rhs_col, mx_rhs_scalar_2D, mx_rhs_scalar_1D]:
1952            np_rhs = mx_rhs.asnumpy()
1953            check_broadcast_mul(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32)
1954            check_broadcast_div(mx_lhs, mx_rhs, np_lhs, np_rhs, np.float32)
1955
1956@with_seed()
1957def test_scatter_ops():
1958    def csr_get_seen_points(name, csr_array, verbose=False):
1959        """Get a unique list of points int he CSR array as well as a
1960        corresponding parallel list of points and values"""
1961        seen_points = set()
1962        seen_point_list = list()
1963        values = list()
1964        row_count = csr_array.shape[0]
1965        row_pointers = csr_array.indptr.asnumpy()
1966        col_indexes  = csr_array.indices.asnumpy()
1967        data = csr_array.data.asnumpy()
1968        for row in range(row_count):
1969            start_pos = row_pointers[row]
1970            end_pos = row_pointers[row + 1]
1971            for col_index in range(start_pos, end_pos):
1972                col = col_indexes[col_index]
1973                val = data[col_index]
1974                if verbose is True:
1975                    print("{}: (row, col = ({}, {}) = {}".format(name, row, col, val))
1976                seen_points.add((row, col))
1977                seen_point_list.append((row, col))
1978                values.append(val)
1979        return seen_points, values, seen_point_list
1980
1981    def check_scatter_ops(name, shape, lhs_stype, rhs_stype, forward_mxnet_call, forward_numpy_call,
1982                          density=0.25, rhs_is_scalar=False, verbose=False):
1983        lhs = mx.symbol.Variable('lhs', stype=lhs_stype)
1984        if rhs_is_scalar is False:
1985            rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
1986
1987        if verbose is True:
1988            print(name)
1989
1990        if lhs_stype != 'default':
1991            lhs_nd = create_sparse_array_zd(
1992                shape, lhs_stype, density=density,
1993                rsp_indices=gen_rsp_random_indices(
1994                    shape,
1995                    density=density,
1996                    force_indices=[(shape[0]/2)]  # force at least one overlap
1997                ))
1998        else:
1999            lhs_nd = rand_ndarray(shape, 'default')
2000
2001        if rhs_is_scalar is False:
2002            if rhs_stype != 'default':
2003                rhs_nd = create_sparse_array_zd(
2004                    shape, rhs_stype, density=density,
2005                    rsp_indices=gen_rsp_random_indices(
2006                        shape,
2007                        density=density,
2008                        force_indices=[(shape[0]/2)]  # force at least one overlap
2009                    ))
2010            else:
2011                rhs_nd = rand_ndarray(shape, 'default')
2012        else:
2013            rhs_nd = 9
2014            rhs = rhs_nd
2015
2016        lhs_np = lhs_nd.asnumpy()
2017        rhs_np = rhs_nd if rhs_is_scalar is True else rhs_nd.asnumpy()
2018
2019        if verbose is True:
2020            print("lhs = {}".format(lhs_np))
2021            print("rhs = {}".format(rhs_np))
2022
2023        out_np = forward_numpy_call(lhs_np, rhs_np)
2024
2025        if verbose is True:
2026            print("Numpy: out_np = {}".format(out_np))
2027
2028        location = {'lhs': lhs_nd, 'rhs': rhs_nd}
2029
2030        out = forward_mxnet_call(lhs, rhs)
2031        exe_test = out.bind(default_context(), args=location)
2032        exe_test.forward(is_train=False)
2033        out_nd = exe_test.outputs[0]
2034
2035        if verbose is True:
2036            print("Sym: out_nd = {}".format(out_nd.asnumpy()))
2037
2038        # For row_sparse, check that rows only exist for rows that are
2039        # either int lhs or rhs, and if they exist, they should equal
2040        # the numpy values
2041        if lhs_stype == 'default':
2042            almost_equal(out_nd.asnumpy(), out_np, equal_nan=True)
2043        elif lhs_stype == 'row_sparse':
2044            seen_rows = set()
2045            indices = lhs_nd.indices.asnumpy()
2046            for i in range(len(indices)):
2047                seen_rows.add(indices[i])
2048            assert len(out_nd.indices.asnumpy()) == len(seen_rows)
2049            out_nd_np = out_nd.asnumpy()
2050            for row in seen_rows:
2051                row_nd = out_nd_np[row]
2052                row_np = out_np[row]
2053                almost_equal(row_nd, row_np, equal_nan=True)
2054        elif lhs_stype == 'csr' and rhs_is_scalar is False:
2055            almost_equal(out_nd.asnumpy(), out_np, equal_nan=True)
2056        else:
2057            assert rhs_is_scalar
2058            lhs_seen_points, _, _ = csr_get_seen_points("lhs", lhs_nd, verbose)
2059            if rhs_is_scalar is False:
2060                rhs_seen_points, _, _ = csr_get_seen_points("rhs", rhs_nd, verbose)
2061            else:
2062                rhs_seen_points = set()
2063            input_seen_points = lhs_seen_points.union(rhs_seen_points)
2064            out_seen_pounts, out_values, seen_point_list = csr_get_seen_points("out_nd", out_nd, verbose)
2065            # Some may have been zero
2066            assert len(out_seen_pounts) <= len(input_seen_points)
2067            out_nd_np = out_nd.asnumpy()
2068            val_index = 0
2069            for row_col in seen_point_list:
2070                row = row_col[0]
2071                col = row_col[1]
2072                val = out_values[val_index]
2073                val_np = out_nd_np[row, col]
2074                almost_equal(val, val_np, equal_nan=True)
2075                val_index += 1
2076
2077    shape = (10, 5)
2078
2079    for lhs_stype in ['row_sparse', 'default', 'csr']:
2080        for rhs_stype in ['row_sparse', 'default', 'csr']:
2081            print("op: {}, lhs_stype: {}, rhs_stype: {}".format('_scatter_elemwise_div',
2082                                                                lhs_stype, rhs_stype))
2083            check_scatter_ops('_scatter_elemwise_div', shape, lhs_stype, rhs_stype,
2084                              lambda l, r: mx.sym._internal._scatter_elemwise_div(l, r),
2085                              lambda l, r: l / r,
2086                              verbose=False)
2087
2088    for lhs_stype in ['row_sparse', 'default', 'csr']:
2089        print("op: {}, lhs_stype: {}".format('_scatter_plus', lhs_stype))
2090        check_scatter_ops('_scatter_plus', shape, lhs_stype, 'scalar',
2091                          lambda l, r: mx.sym._internal._scatter_plus_scalar(l, r),
2092                          lambda l, r: l + r,
2093                          rhs_is_scalar=True, verbose=False)
2094
2095        print("op: {}, lhs_stype: {}".format('_scatter_minus', lhs_stype))
2096        check_scatter_ops('_scatter_minus', shape, lhs_stype, 'scalar',
2097                          lambda l, r: mx.sym._internal._scatter_minus_scalar(l, r),
2098                          lambda l, r: l + r,
2099                          rhs_is_scalar=True, verbose=False, density=0.5)
2100
2101
2102@with_seed()
2103def test_batchnorm_fallback():
2104    # same test as test_operator.test_batchnorm_training, but tests fallback logic of batchnorm
2105    stype = 'row_sparse'
2106    for shape in [(2, 3), (2, 3, 2, 2)]:
2107        data_tmp = np.random.normal(-0.1, 0.1, size=shape)
2108        s = shape[1],
2109        gamma = np.ones(s)
2110        beta = np.ones(s)
2111        gamma[1] = 3
2112        beta[0] = 3
2113
2114        rolling_mean = np.random.uniform(size=s)
2115        rolling_std = np.random.uniform(size=s)
2116
2117        data = mx.symbol.Variable('data', stype=stype)
2118        in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
2119                        mx.nd.array(beta).tostype(stype)]
2120        mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
2121
2122        test = mx.symbol.BatchNorm(data, fix_gamma=True)
2123        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
2124
2125        test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
2126        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
2127
2128        test = mx.symbol.BatchNorm(data, fix_gamma=False)
2129        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
2130
2131        test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
2132        check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
2133
2134        # Test varying channel axis
2135        dim = len(shape)
2136        for chaxis in range(-dim, dim):
2137            chaxis_true = chaxis
2138            if chaxis < 0:
2139                chaxis_true = dim + chaxis
2140
2141            shapex = shape
2142
2143            channel_count = shapex[chaxis_true]
2144            data_tmp = np.random.normal(-0.1, 0.1, size=shapex)
2145
2146            gamma = np.ones(channel_count)
2147            beta = np.ones(channel_count)
2148            if channel_count > 1:
2149                gamma[1] = 3
2150            beta[0] = 3
2151
2152            in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
2153                            mx.nd.array(beta).tostype(stype)]
2154
2155            xrolling_mean = np.random.uniform(size=channel_count)
2156            xrolling_std = np.random.uniform(size=channel_count)
2157            xmean_std = [mx.nd.array(xrolling_mean).tostype(stype),
2158                            mx.nd.array(xrolling_std).tostype(stype)]
2159
2160            test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
2161            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
2162
2163            test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
2164            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
2165
2166            test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
2167            check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
2168
2169            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
2170            check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
2171
2172
2173@with_seed()
2174def test_mkldnn_sparse():
2175    # This test is trying to create a race condition describedd in
2176    # https://github.com/apache/incubator-mxnet/issues/10189
2177    arr = mx.nd.random.uniform(shape=(10, 10, 32, 32))
2178    weight1 = mx.nd.random.uniform(shape=(10, 10, 3, 3))
2179    arr = mx.nd.Convolution(data=arr, weight=weight1, no_bias=True, kernel=(3, 3), num_filter=10)
2180
2181    rs_arr = mx.nd.sparse.row_sparse_array((mx.nd.zeros_like(arr), np.arange(arr.shape[0])))
2182    weight2 = mx.nd.random.uniform(shape=(10, np.prod(arr.shape[1:4])))
2183    fc_res = mx.nd.FullyConnected(data=arr, weight=weight2, no_bias=True, num_hidden=10)
2184    sum_res = mx.nd.elemwise_sub(arr, rs_arr)
2185    res1 = np.dot(mx.nd.flatten(sum_res).asnumpy(), weight2.asnumpy().T)
2186    print(res1 - fc_res.asnumpy())
2187    almost_equal(res1, fc_res.asnumpy())
2188
2189@with_seed()
2190def test_sparse_nd_where():
2191    def get_forward_expected_output(condition, x, y):
2192        original_shape = x.shape
2193        out = np.zeros(original_shape)
2194        if condition.shape == x.shape:
2195            for index, c in np.ndenumerate(condition):
2196                if c != 0:
2197                    out[index] = x[index]
2198                else:
2199                    out[index] = y[index]
2200        else:
2201            raise RuntimeError("Invalid condition shape for where op")
2202
2203        out = out.reshape(original_shape)
2204        return out
2205
2206    def get_forward_inputs_same_shape(shape):
2207        condition_np = np.random.randint(0, 2, np.prod(shape)).reshape(shape)
2208        x_np = np.random.randint(1, 6, np.prod(shape)).reshape(shape)
2209        y_np = np.random.randint(7, 11, np.prod(shape)).reshape(shape)
2210        return condition_np, x_np, y_np
2211
2212    def get_backward_input(shape):
2213        return np.random.randint(20, 30, np.prod(shape)).reshape(shape)
2214
2215    def get_backward_expected_outputs(grad_in, condition):
2216        shape = grad_in.shape
2217        grad_cond = np.zeros(condition.shape)
2218        grad_x = np.empty(shape)
2219        grad_y = np.empty(shape)
2220
2221        for index, c in np.ndenumerate(condition):
2222            if 0 != c:
2223                grad_x[index] = grad_in[index]
2224                grad_y[index] = 0
2225            else:
2226                grad_x[index] = 0
2227                grad_y[index] = grad_in[index]
2228
2229        return grad_cond, grad_x, grad_y
2230
2231    def test_where_helper(shape):
2232        condition_np, x_np, y_np = get_forward_inputs_same_shape(shape)
2233
2234        out_expected = get_forward_expected_output(condition_np, x_np, y_np)
2235
2236        grad_in_np = get_backward_input(shape)
2237        grad_expected_cond, grad_expected_x, grad_expected_y \
2238            = get_backward_expected_outputs(grad_in_np, condition_np)
2239
2240        condition = mx.sym.Variable('condition', stype='csr')
2241        x = mx.sym.Variable('x')
2242        y = mx.sym.Variable('y')
2243        grad_in_mx = mx.nd.array(grad_in_np, dtype=np.int32)
2244        where_sym = mx.sym.where(condition, x, y)
2245
2246        # test req='write'
2247        where_exe_write = where_sym.simple_bind(ctx=default_context(),
2248                                                condition=condition_np.shape,
2249                                                x=x_np.shape, y=y_np.shape,
2250                                                grad_req='write')
2251        # test forward req='write'
2252        cond_nd = mx.nd.array(condition_np).tostype('csr')
2253        outputs = where_exe_write.forward(is_train=True, \
2254                                          condition=cond_nd, x=x_np, y=y_np)
2255        assert same(outputs[0].asnumpy(), out_expected)
2256        # test backward req='write'
2257        where_exe_write.backward(grad_in_mx)
2258        assert same(where_exe_write.grad_dict['x'].asnumpy(), grad_expected_x)
2259        assert same(where_exe_write.grad_dict['y'].asnumpy(), grad_expected_y)
2260        assert same(where_exe_write.grad_dict['condition'].asnumpy(), grad_expected_cond)
2261
2262        # test req='add'
2263        x_grad_init = np.random.randint(30, 40, np.prod(shape)).reshape(shape)
2264        y_grad_init = np.random.randint(40, 50, np.prod(shape)).reshape(shape)
2265        where_exe_add = where_sym.simple_bind(ctx=default_context(),
2266                                              condition=cond_nd.shape,
2267                                              x=x_np.shape, y=y_np.shape,
2268                                              grad_req='add')
2269        where_exe_add.grad_dict['x'][:] = x_grad_init
2270        where_exe_add.grad_dict['y'][:] = y_grad_init
2271        # test forward req='add'
2272        outputs = where_exe_add.forward(is_train=True, condition=cond_nd, x=x_np, y=y_np)
2273        assert same(outputs[0].asnumpy(), out_expected)
2274
2275    def test_where_numeric_gradient(shape):
2276        condition = mx.sym.Variable('condition', stype='csr')
2277        x = mx.sym.Variable('x')
2278        y = mx.sym.Variable('y')
2279        where_sym = mx.sym.where(condition, x, y)
2280        condition_np, x_np, y_np = get_forward_inputs_same_shape(shape)
2281        check_numeric_gradient(where_sym, [condition_np, x_np, y_np], grad_nodes=['x', 'y'])
2282
2283    test_where_helper((5, 9))
2284    test_where_numeric_gradient((5, 9))
2285
2286@with_seed()
2287def test_sparse_quadratic_function():
2288    def f(x, a, b, c):
2289        return a * x**2 + b * x + c
2290
2291    def check_sparse_quadratic_function(a, b, c, expected_stype):
2292      # check forward and compare the result with dense op
2293      ndim = 2
2294      shape = rand_shape_nd(ndim, 5)
2295      data = rand_ndarray(shape=shape, stype='csr')
2296      data_np = data.asnumpy()
2297      expected = f(data_np, a, b, c)
2298      output = mx.nd.contrib.quadratic(data, a=a, b=b, c=c)
2299      assert(output.stype == expected_stype)
2300      assert_almost_equal(output.asnumpy(), expected)
2301
2302    a = np.random.random_sample()
2303    b = np.random.random_sample()
2304    check_sparse_quadratic_function(a, b, 0.0, 'csr')
2305    check_sparse_quadratic_function(a, b, 1.0, 'default')
2306
2307def test_reshape_backward_fallback():
2308    """
2309     out
2310     |  \
2311    w_x  x
2312     /
2313    w
2314    in which x is a sparse tensor.
2315    Due to sparse gradient optimization in sym.dot, grad(w_x) is sparse.
2316    Though sym.reshape itself does not have sparse version,
2317    if we somehow make grad(w) sparse as well, e.g.,
2318        - by setting args_grad in symbol.bind
2319        - or, we can have out_y = sym.dot(sparse_y, w), then grad(w) will be inferred as sparse
2320    reshape backward (from w_x to w) needs to understand how to handle sparse inputs.
2321    """
2322    ctx = default_context()
2323    w_shape = (12, 4)
2324    w_x_shape = (1, 48)
2325    x_nd = rand_ndarray((4, 1), 'csr')
2326
2327    w_nd = rand_ndarray(w_shape)
2328
2329    w_x_nd = w_nd.reshape(w_x_shape)
2330    out_x_nd = mx.nd.dot(x_nd, w_x_nd)
2331
2332    w_x_backward_grad = mx.nd.dot(x_nd, out_x_nd, transpose_a=True).asnumpy()
2333    expected_grad_nd = w_x_backward_grad.reshape(w_shape)
2334
2335    x = mx.sym.Variable('x', stype='csr')
2336    w = mx.sym.Variable('w')
2337
2338    w_x = mx.sym.reshape(w, w_x_shape, name="w_x")
2339    out = mx.sym.sparse.dot(x, w_x, name='out_x')
2340
2341    grad_w_nd = rand_ndarray(w_shape, 'row_sparse')
2342    executor = out.bind(ctx=ctx, args={"x": x_nd, "w": w_nd},
2343                        args_grad={"w": grad_w_nd})
2344    executor.forward(is_train=True)
2345    executor.backward(out_x_nd)
2346
2347    assert_almost_equal(grad_w_nd.asnumpy(), expected_grad_nd)
2348
2349if __name__ == '__main__':
2350    import nose
2351    nose.runmodule()
2352