1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Tools for testing."""
19# pylint: disable=too-many-lines
20import time
21import gzip
22import struct
23import traceback
24import numbers
25import sys
26import os
27import platform
28import errno
29import logging
30import bz2
31import zipfile
32import json
33from contextlib import contextmanager
34from collections import OrderedDict
35import numpy as np
36import numpy.testing as npt
37import numpy.random as rnd
38try:
39    import scipy.stats as ss
40except ImportError:
41    ss = None
42try:
43    import requests
44except ImportError:
45    # in rare cases requests may be not installed
46    pass
47import mxnet as mx
48from .context import Context, current_context
49from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
50from .ndarray import array
51from .symbol import Symbol
52from .symbol.numpy import _Symbol as np_symbol
53from .util import use_np, getenv, setenv  # pylint: disable=unused-import
54from .runtime import Features
55from .numpy_extension import get_cuda_compute_capability
56
57
58def default_context():
59    """Get default context for regression test."""
60    # _TODO: get context from environment variable to support
61    # testing with GPUs
62    return current_context()
63
64
65def set_default_context(ctx):
66    """Set default context."""
67    Context._default_ctx.value = ctx
68
69
70def default_dtype():
71    """Get default data type for regression test."""
72    # _TODO: get default dtype from environment variable
73    return np.float32
74
75def default_rtols():
76    """Get default relative tolerances for data comparisons involving each data type."""
77    return {np.dtype(np.float16): 1e-2,
78            np.dtype(np.float32): 1e-4,
79            np.dtype(np.float64): 1e-5,
80            np.dtype(np.bool): 0,
81            np.dtype(np.int8): 0,
82            np.dtype(np.uint8): 0,
83            np.dtype(np.int32): 0,
84            np.dtype(np.uint32): 0,
85            np.dtype(np.int64): 0,
86            np.dtype(np.uint64): 0}
87
88def default_atols():
89    """Get default absolute tolerances for data comparisons involving each data type."""
90    return {np.dtype(np.float16): 1e-1,
91            np.dtype(np.float32): 1e-3,
92            np.dtype(np.float64): 1e-20,
93            np.dtype(np.bool): 0,
94            np.dtype(np.int8): 0,
95            np.dtype(np.uint8): 0,
96            np.dtype(np.int32): 0,
97            np.dtype(np.uint32): 0,
98            np.dtype(np.int64): 0,
99            np.dtype(np.uint64): 0}
100
101def default_numeric_eps():
102    """Get default epsilon for finite difference gradient calculations with data type."""
103    # prefer a power-of-two eps, since no bits are dropped when serving as an input delta
104    return {np.dtype(np.float16): 1.0 / 2**6,
105            np.dtype(np.float32): 1.0 / 2**9,
106            np.dtype(np.float64): 1.0 / 2**14}
107
108
109def effective_dtype(dat):
110    """ Return the most appropriate dtype for determining the tolerance used in dat comparisons
111    Parameters
112    ----------
113    dat : np.ndarray or mx.nd.array or mx.np.ndarray
114    """
115    # On arch 80 gpus, a float32-io gemm or conv op will trim the mantissa of data
116    # inputs to be of comparable precision to a float16, so float16 becomes the
117    # 'effective dtype' for tolerance tests involving such op outputs.
118
119    # Is TF32 enabled in the ctx (the default on arch 80 GPUs)
120    def is_TF32_enabled(ctx):
121        try:
122            return (ctx.device_type == 'gpu' and
123                    get_cuda_compute_capability(ctx) >= 80 and
124                    os.environ.get('NVIDIA_TF32_OVERRIDE') != '0')
125        except:  # pylint: disable=bare-except
126            return False
127
128    ctx = dat.ctx if hasattr(dat, 'ctx') else None
129    dtype = np.dtype(dat.dtype)
130    if dtype == np.dtype(np.float32) and is_TF32_enabled(ctx):
131        return np.dtype(np.float16)
132    else:
133        return dtype
134
135
136def get_tolerance(dat, tol, default_tol):
137    """ Return the tolerance to be used for dat comparisons based on the given tol, datatype and context.
138    Parameters
139    ----------
140    dat : np.ndarray or mx.nd.array or mx.np.ndarray
141    tol : float, or a dict of dtype->float
142    default_tol : default dict of dtype->float for all types
143    """
144
145    if isinstance(tol, numbers.Number):
146        return tol
147
148    # If the caller has supplied a tol dict, use that if it has an entry for dtype,
149    # else use the supplied default tol dict.
150    dtype = effective_dtype(dat)
151    tol = {} if tol is None else tol
152    return tol.get(dtype, default_tol[dtype])
153
154
155def get_tols(x, y, rtol, atol):
156    """For comparing two datasets 'x' and 'y', what relative and absolute tolerances should be used."""
157    # Tolerance analysis needs 'dtype' of 'x' and 'y', so convert numbers to numpy scalars as needed
158    if isinstance(x, numbers.Number):
159        x = np.array(x)
160    if isinstance(y, numbers.Number):
161        y = np.array(y)
162
163    # If tols are not specified, use the largest default tol for 'x' and 'y' based on their ctx and dtype.
164    rtol = max(get_tolerance(x, rtol, default_rtols()),
165               get_tolerance(y, rtol, default_rtols()))
166    atol = max(get_tolerance(x, atol, default_atols()),
167               get_tolerance(y, atol, default_atols()))
168
169    return rtol, atol
170
171
172def get_atol(atol=None, dtype=np.dtype(np.float64)):
173    """Get default numerical threshold for regression test."""
174    return default_atols()[dtype] if atol is None else atol
175
176def get_rtol(rtol=None, dtype=np.dtype(np.float64)):
177    """Get default numerical threshold for regression test."""
178    return default_rtols()[dtype] if rtol is None else rtol
179
180def get_etol(etol=None):
181    """Get default numerical threshold for regression test."""
182    # _TODO: get from env variable, different threshold might
183    # be needed for different device and dtype
184    return 0 if etol is None else etol
185
186def random_arrays(*shapes):
187    """Generate some random numpy arrays."""
188    arrays = [np.array(np.random.randn(), dtype=default_dtype())
189              if len(s) == 0 else np.random.randn(*s).astype(default_dtype())
190              for s in shapes]
191    if len(arrays) == 1:
192        return arrays[0]
193    return arrays
194
195
196def random_uniform_arrays(*shapes, **kwargs):
197    """Generate some random numpy arrays."""
198    low = kwargs.pop('low', 0.0)
199    high = kwargs.pop('high', 1.0)
200    dtype = kwargs.pop('dtype', default_dtype())
201    if len(kwargs) > 0:
202        raise TypeError('Got unexpected argument/s : ' + str(kwargs.keys()))
203    arrays = [np.random.uniform(low, high, size=s).astype(dtype)
204              for s in shapes]
205    return arrays
206
207
208def random_sample(population, k):
209    """Return a k length list of the elements chosen from the population sequence."""
210    assert 0 <= k <= len(population)
211    population_copy = population[:]
212    np.random.shuffle(population_copy)
213    return population_copy[0:k]
214
215
216def _sorted_items(d):
217    """Return (key, value) pairs of dict 'd' in a deterministic order (sorted by key)."""
218    return sorted(d.items(), key=lambda t: t[0])
219
220
221def _sorted_dict(d):
222    """Return ordered dictionary containing items ordered by their keys."""
223    return OrderedDict(_sorted_items(d))
224
225
226def _validate_csr_generation_inputs(num_rows, num_cols, density,
227                                    distribution="uniform"):
228    """Validates inputs for csr generation helper functions
229    """
230    total_nnz = int(num_rows * num_cols * density)
231    if density < 0 or density > 1:
232        raise ValueError("density has to be between 0 and 1")
233
234    if num_rows <= 0 or num_cols <= 0:
235        raise ValueError("num_rows or num_cols should be greater than 0")
236
237    if distribution == "powerlaw":
238        if total_nnz < 2 * num_rows:
239            raise ValueError("not supported for this density: %s"
240                             " for this shape (%s, %s)"
241                             " Please keep :"
242                             " num_rows * num_cols * density >= 2 * num_rows"
243                             % (density, num_rows, num_cols))
244
245
246def shuffle_csr_column_indices(csr):
247    """Shuffle CSR column indices per row
248    This allows validation of unordered column indices, which is not a requirement
249    for a valid CSR matrix
250    """
251    row_count = len(csr.indptr) - 1
252    for i in range(row_count):
253        start_index = csr.indptr[i]
254        end_index = csr.indptr[i + 1]
255        sublist = np.array(csr.indices[start_index : end_index])
256        np.random.shuffle(sublist)
257        csr.indices[start_index : end_index] = sublist
258
259
260def _get_uniform_dataset_csr(num_rows, num_cols, density=0.1, dtype=None,
261                             data_init=None, shuffle_csr_indices=False):
262    """Returns CSRNDArray with uniform distribution
263    This generates a csr matrix with totalnnz unique randomly chosen numbers
264    from num_rows*num_cols and arranges them in the 2d array in the
265    following way:
266    row_index = (random_number_generated / num_rows)
267    col_index = random_number_generated - row_index * num_cols
268    """
269    _validate_csr_generation_inputs(num_rows, num_cols, density,
270                                    distribution="uniform")
271    try:
272        from scipy import sparse as spsp
273        csr = spsp.rand(num_rows, num_cols, density, dtype=dtype, format="csr")
274        if data_init is not None:
275            csr.data.fill(data_init)
276        if shuffle_csr_indices is True:
277            shuffle_csr_column_indices(csr)
278        result = mx.nd.sparse.csr_matrix((csr.data, csr.indices, csr.indptr),
279                                         shape=(num_rows, num_cols), dtype=dtype)
280    except ImportError:
281        assert(data_init is None), \
282               "data_init option is not supported when scipy is absent"
283        assert(not shuffle_csr_indices), \
284               "shuffle_csr_indices option is not supported when scipy is absent"
285        # scipy not available. try to generate one from a dense array
286        dns = mx.nd.random.uniform(shape=(num_rows, num_cols), dtype=dtype)
287        masked_dns = dns * (dns < density)
288        result = masked_dns.tostype('csr')
289    return result
290
291def _get_powerlaw_dataset_csr(num_rows, num_cols, density=0.1, dtype=None):
292    """Returns CSRNDArray with powerlaw distribution
293    with exponentially increasing number of non zeros in each row.
294    Not supported for cases where total_nnz < 2*num_rows. This is because
295    the algorithm first tries to ensure that there are rows with no zeros by
296    putting non zeros at beginning of each row.
297    """
298
299    _validate_csr_generation_inputs(num_rows, num_cols, density,
300                                    distribution="powerlaw")
301
302    total_nnz = int(num_rows * num_cols * density)
303
304    unused_nnz = total_nnz
305    output_arr = np.zeros((num_rows, num_cols), dtype=dtype)
306    # Start with ones on each row so that no row is empty
307    for row in range(num_rows):
308        output_arr[row][0] = 1 + rnd.uniform(0.001, 2)
309        unused_nnz = unused_nnz - 1
310        if unused_nnz <= 0:
311            return mx.nd.array(output_arr).tostype("csr")
312
313    # Populate rest of matrix with 2^i items in ith row.
314    # if we have used all total nnz return the sparse matrix
315    # else if we reached max column size then fill up full columns until we use all nnz
316    col_max = 2
317    for row in range(num_rows):
318        col_limit = min(num_cols, col_max)
319        # In case col_limit reached assign same value to all elements, which is much faster
320        if col_limit == num_cols and unused_nnz > col_limit:
321            output_arr[row] = 1 + rnd.uniform(0.001, 2)
322            unused_nnz = unused_nnz - col_limit + 1
323            if unused_nnz <= 0:
324                return mx.nd.array(output_arr).tostype("csr")
325            else:
326                continue
327        for col_index in range(1, col_limit):
328            output_arr[row][col_index] = 1 + rnd.uniform(0.001, 2)
329            unused_nnz = unused_nnz - 1
330            if unused_nnz <= 0:
331                return mx.nd.array(output_arr).tostype("csr")
332        col_max = col_max * 2
333
334    if unused_nnz > 0:
335        raise ValueError("not supported for this density: %s"
336                         " for this shape (%s,%s)" % (density, num_rows, num_cols))
337
338    return mx.nd.array(output_arr).tostype("csr")
339
340
341def assign_each(the_input, function):
342    """Return ndarray composed of passing each array value through some function"""
343    if function is None:
344        output = np.array(the_input)
345    else:
346        it_input = np.nditer(the_input, flags=['f_index'])
347
348        output = np.zeros(the_input.shape)
349        it_out = np.nditer(output, flags=['f_index'], op_flags=['writeonly'])
350
351        while not it_input.finished:
352            val_input = it_input[0]
353            it_out[0] = function(val_input)
354            it_input.iternext()
355            it_out.iternext()
356
357    return output
358
359def assign_each2(input1, input2, function):
360    """Return ndarray composed of passing two array values through some function"""
361    if function is None:
362        output = np.array(input1)
363    else:
364        assert input1.shape == input2.shape
365        it_input1 = np.nditer(input1, flags=['f_index'])
366        it_input2 = np.nditer(input2, flags=['f_index'])
367
368        output = np.zeros(input1.shape)
369        it_out = np.nditer(output, flags=['f_index'], op_flags=['writeonly'])
370
371        while not it_input1.finished:
372            val_input1 = it_input1[0]
373            val_input2 = it_input2[0]
374            it_out[0] = function(val_input1, val_input2)
375            it_input1.iternext()
376            it_input2.iternext()
377            it_out.iternext()
378
379    return output
380
381# For testing Large Tensors having total size > 2^32 elements
382def create_2d_tensor(rows, columns, dtype=np.int64):
383    a = mx.nd.arange(0, rows, dtype=dtype).reshape(rows, 1)
384    b = mx.nd.broadcast_to(a, shape=(a.shape[0], columns))
385    return b
386
387# For testing Large Vectors having total size > 2^32 elements
388def create_vector(size, dtype=np.int64):
389    a = mx.nd.arange(0, size, dtype=dtype)
390    return a
391
392# For testing Large Square Matrix with total size > 2^32 elements
393def get_identity_mat(size):
394    A = mx.nd.zeros((size, size))
395    for i in range(size):
396        A[i, i] = 1
397    return A
398
399# For testing Batch of Large Square Matrix with total size > 2^32 elements
400def get_identity_mat_batch(size):
401    A = get_identity_mat(size)
402    A_np = A.asnumpy()
403    return mx.nd.array([A_np, A_np])
404
405def rand_sparse_ndarray(shape, stype, density=None, dtype=None, distribution=None,
406                        data_init=None, rsp_indices=None, modifier_func=None,
407                        shuffle_csr_indices=False, ctx=None):
408    """Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np)
409
410    Parameters
411    ----------
412    shape: list or tuple
413    stype: str
414        valid values: "csr" or "row_sparse"
415    density: float, optional
416        should be between 0 and 1
417    distribution: str, optional
418        valid values: "uniform" or "powerlaw"
419    dtype: numpy.dtype, optional
420        default value is None
421
422    Returns
423    -------
424    Result of type CSRNDArray or RowSparseNDArray
425
426    Examples
427    --------
428    Below is an example of the powerlaw distribution with csr as the stype.
429    It calculates the nnz using the shape and density.
430    It fills up the ndarray with exponentially increasing number of elements.
431    If there are enough unused_nnzs, n+1th row will have twice more nnzs compared to nth row.
432    else, remaining unused_nnzs will be used in n+1th row
433    If number of cols is too small and we have already reached column size it will fill up
434    all following columns in all followings rows until we reach the required density.
435
436    >>> csr_arr, _ = rand_sparse_ndarray(shape=(5, 16), stype="csr",
437                                         density=0.50, distribution="powerlaw")
438    >>> indptr = csr_arr.indptr.asnumpy()
439    >>> indices = csr_arr.indices.asnumpy()
440    >>> data = csr_arr.data.asnumpy()
441    >>> row2nnz = len(data[indptr[1]:indptr[2]])
442    >>> row3nnz = len(data[indptr[2]:indptr[3]])
443    >>> assert(row3nnz == 2*row2nnz)
444    >>> row4nnz = len(data[indptr[3]:indptr[4]])
445    >>> assert(row4nnz == 2*row3nnz)
446
447    """
448    ctx = ctx if ctx else default_context()
449    density = rnd.rand() if density is None else density
450    dtype = default_dtype() if dtype is None else dtype
451    distribution = "uniform" if distribution is None else distribution
452    if stype == 'row_sparse':
453        assert (distribution == "uniform"), \
454               "Distribution %s not supported for row_sparse" % (distribution)
455        # sample index
456        if rsp_indices is not None:
457            indices = rsp_indices
458            assert(len(indices) <= shape[0])
459        else:
460            idx_sample = rnd.rand(shape[0])
461            indices = np.argwhere(idx_sample < density).flatten()
462        if indices.shape[0] == 0:
463            result = mx.nd.zeros(shape, stype='row_sparse', dtype=dtype, ctx=ctx)
464            return result, (np.array([], dtype=dtype), np.array([]))
465        # generate random values
466        val = rnd.rand(indices.shape[0], *shape[1:]).astype(dtype)
467
468        # Allow caller to override or adjust random values
469        if data_init is not None:
470            val.fill(data_init)
471        if modifier_func is not None:
472            val = assign_each(val, modifier_func)
473
474        arr = mx.nd.sparse.row_sparse_array((val, indices), shape=shape, dtype=dtype, ctx=ctx)
475        return arr, (val, indices)
476    elif stype == 'csr':
477        assert len(shape) == 2
478        if distribution == "uniform":
479            csr = _get_uniform_dataset_csr(shape[0], shape[1], density,
480                                           data_init=data_init,
481                                           shuffle_csr_indices=shuffle_csr_indices, dtype=dtype).as_in_context(ctx)
482            return csr, (csr.indptr, csr.indices, csr.data)
483        elif distribution == "powerlaw":
484            csr = _get_powerlaw_dataset_csr(shape[0], shape[1], density=density, dtype=dtype).as_in_context(ctx)
485            return csr, (csr.indptr, csr.indices, csr.data)
486        else:
487            assert(False), "Distribution not supported: %s" % (distribution)
488            return False
489    else:
490        assert(False), "unknown storage type"
491        return False
492
493def rand_ndarray(shape, stype='default', density=None, dtype=None, modifier_func=None,
494                 shuffle_csr_indices=False, distribution=None, ctx=None):
495    """Generate a random sparse ndarray. Returns the generated ndarray."""
496    ctx = ctx if ctx else default_context()
497    if stype == 'default':
498        arr = mx.nd.array(random_arrays(shape), dtype=dtype, ctx=ctx)
499    else:
500        arr, _ = rand_sparse_ndarray(shape, stype, density=density,
501                                     modifier_func=modifier_func, dtype=dtype,
502                                     shuffle_csr_indices=shuffle_csr_indices,
503                                     distribution=distribution, ctx=ctx)
504    return arr
505
506
507def create_sparse_array(shape, stype, data_init=None, rsp_indices=None,
508                        dtype=None, modifier_func=None, density=.5,
509                        shuffle_csr_indices=False):
510    """Create a sparse array, For Rsp, assure indices are in a canonical format"""
511    if stype == 'row_sparse':
512        if rsp_indices is not None:
513            arr_indices = np.asarray(rsp_indices)
514            arr_indices.sort()
515        else:
516            arr_indices = None
517        arr_data, (_, _) = rand_sparse_ndarray(shape, stype,
518                                               density=density,
519                                               data_init=data_init,
520                                               rsp_indices=arr_indices,
521                                               dtype=dtype,
522                                               modifier_func=modifier_func)
523    elif stype == 'csr':
524        arr_data, (_, _, _) = rand_sparse_ndarray(shape,
525                                                  stype,
526                                                  density=density,
527                                                  data_init=data_init,
528                                                  dtype=dtype,
529                                                  modifier_func=modifier_func,
530                                                  shuffle_csr_indices=shuffle_csr_indices)
531    else:
532        msg = "Unknown storage type: " + stype
533        raise AssertionError(msg)
534
535    return arr_data
536
537
538def create_sparse_array_zd(shape, stype, density, data_init=None,
539                           rsp_indices=None, dtype=None, modifier_func=None,
540                           shuffle_csr_indices=False):
541    """Create sparse array, using only rsp_indices to determine density"""
542    if stype == 'row_sparse':
543        density = 0.0
544        if rsp_indices is not None:
545            assert len(rsp_indices) <= shape[0]
546    return create_sparse_array(shape, stype,
547                               data_init=data_init,
548                               rsp_indices=rsp_indices,
549                               dtype=dtype,
550                               modifier_func=modifier_func,
551                               density=density,
552                               shuffle_csr_indices=shuffle_csr_indices)
553
554
555def rand_shape_2d(dim0=10, dim1=10, allow_zero_size=False):
556    low = 0 if allow_zero_size else 1
557    return rnd.randint(low, dim0 + 1), rnd.randint(low, dim1 + 1)
558
559
560def rand_shape_3d(dim0=10, dim1=10, dim2=10, allow_zero_size=False):
561    low = 0 if allow_zero_size else 1
562    return rnd.randint(low, dim0 + 1), rnd.randint(low, dim1 + 1), rnd.randint(low, dim2 + 1)
563
564
565def rand_shape_nd(num_dim, dim=10, allow_zero_size=False):
566    low = 0 if allow_zero_size else 1
567    return tuple(rnd.randint(low, dim+1, size=num_dim))
568
569
570def rand_coord_2d(x_low, x_high, y_low, y_high):
571    x = np.random.randint(x_low, x_high, dtype=np.int64)
572    y = np.random.randint(y_low, y_high, dtype=np.int64)
573    return x, y
574
575
576def np_reduce(dat, axis, keepdims, numpy_reduce_func):
577    """Compatible reduce for old version of NumPy.
578
579    Parameters
580    ----------
581    dat : np.ndarray
582        Same as NumPy.
583
584    axis : None or int or list-like
585        Same as NumPy.
586
587    keepdims : bool
588        Same as NumPy.
589
590    numpy_reduce_func : function
591        A NumPy reducing function like ``np.sum`` or ``np.max``.
592    """
593    if isinstance(axis, int):
594        axis = [axis]
595    else:
596        axis = list(axis) if axis is not None else range(len(dat.shape))
597    ret = dat
598    for i in reversed(sorted(axis)):
599        ret = numpy_reduce_func(ret, axis=i)
600    if keepdims:
601        keepdims_shape = list(dat.shape)
602        for i in axis:
603            keepdims_shape[i] = 1
604        ret = ret.reshape(tuple(keepdims_shape))
605    return ret
606
607
608def _find_max_violation(a, b, rtol, atol):
609    """Finds and returns the location of maximum violation."""
610    # 'smart' absdiff that considers inf's as equals (to match np.allclose)
611    absdiff = np.where(np.equal(a, b), 0, np.abs(a-b))
612    tol = atol + rtol*np.abs(b)
613    violation = absdiff/(tol+1e-20)
614    loc = np.argmax(violation)
615    idx = np.unravel_index(loc, violation.shape)
616    return idx, np.max(violation)
617
618
619def same(a, b):
620    """Test if two NumPy arrays are the same.
621
622    Parameters
623    ----------
624    a : np.ndarray
625    b : np.ndarray
626    """
627    return np.array_equal(a, b)
628
629
630def checkShapes(a, b):
631    if a.shape != b.shape:
632        msg = npt.build_err_msg([a, b],
633                                err_msg="a.shape = {} and b.shape = {} are not equal"
634                                .format(str(a.shape), str(b.shape)))
635        raise AssertionError(msg)
636
637
638def almost_equal(a, b, rtol=None, atol=None, equal_nan=False, use_broadcast=True):
639    """Test if two numpy arrays are almost equal."""
640    # pylint: disable=unexpected-keyword-arg
641    if not use_broadcast:
642        checkShapes(a, b)
643
644    return np.allclose(a, b, rtol=get_rtol(rtol), atol=get_atol(atol), equal_nan=equal_nan)
645    # pylint: enable=unexpected-keyword-arg
646
647def locationError(a, b, index, names, maxError=False):
648    """Create element mismatch comment
649
650    Parameters
651    ----------
652    a, b : compared np.ndarray's
653    index : tuple of coordinate arrays
654        Location of violation
655    names : tuple of names
656        The names of compared arrays.
657    maxError: boolean, optional
658        Flag indicating that maximum error is reporting.
659    """
660    maximum = "maximum " if maxError else ""
661    return "Location of %serror: %s, %s=%.8f, %s=%.8f" \
662            % (maximum, str(index), names[0], a[index], names[1], b[index])
663
664def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=False,
665                        use_broadcast=True, mismatches=(10, 10)):
666    """Test that two numpy arrays are almost equal. Raise exception message if not.
667
668    Parameters
669    ----------
670    a : np.ndarray or mx.nd.array
671    b : np.ndarray or mx.nd.array
672    rtol : None or float or dict of dtype -> float
673        The relative threshold. Default threshold will be used if set to ``None``.
674    atol : None or float or dict of dtype -> float
675        The absolute threshold. Default threshold will be used if set to ``None``.
676    names : tuple of names, optional
677        The names used in error message when an exception occurs
678    equal_nan : boolean, optional
679        The flag determining how to treat NAN values in comparison
680    mismatches : tuple of mismatches
681        Maximum number of mismatches to be printed (mismatches[0]) and determine (mismatches[1])
682    """
683    if not use_broadcast:
684        checkShapes(a, b)
685
686    rtol, atol = get_tols(a, b, rtol, atol)
687
688    if isinstance(a, mx.numpy.ndarray):
689        a = a.asnumpy()
690    if isinstance(b, mx.numpy.ndarray):
691        b = b.asnumpy()
692    use_np_allclose = isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
693    if not use_np_allclose:
694        if not (hasattr(a, 'context') and hasattr(b, 'context') and a.context == b.context and a.dtype == b.dtype):
695            use_np_allclose = True
696            if isinstance(a, mx.nd.NDArray):
697                a = a.asnumpy()
698            if isinstance(b, mx.nd.NDArray):
699                b = b.asnumpy()
700
701    if use_np_allclose:
702        if hasattr(a, 'dtype') and a.dtype == np.bool_ and hasattr(b, 'dtype') and b.dtype == np.bool_:
703            np.testing.assert_equal(a, b)
704            return
705        if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
706            return
707    else:
708        output = mx.nd.contrib.allclose(a, b, rtol, atol, equal_nan)
709        if output.asnumpy() == 1:
710            return
711
712        a = a.asnumpy()
713        b = b.asnumpy()
714
715    index, rel = _find_max_violation(a, b, rtol, atol)
716    if index != ():
717        # a, b are the numpy arrays
718        indexErr = index
719        relErr = rel
720
721        print('\n*** Maximum errors for vector of size {}:  rtol={}, atol={}\n'.format(a.size, rtol, atol))
722        aTmp = a.copy()
723        bTmp = b.copy()
724        i = 1
725        while i <= a.size:
726            if i <= mismatches[0]:
727                print("%3d: Error %f  %s" %(i, rel, locationError(a, b, index, names)))
728
729            aTmp[index] = bTmp[index] = 0
730            if almost_equal(aTmp, bTmp, rtol, atol, equal_nan=equal_nan):
731                break
732
733            i += 1
734            if i <= mismatches[1] or mismatches[1] <= 0:
735                index, rel = _find_max_violation(aTmp, bTmp, rtol, atol)
736            else:
737                break
738
739        mismatchDegree = "at least " if mismatches[1] > 0 and i > mismatches[1] else ""
740        errMsg = "Error %f exceeds tolerance rtol=%e, atol=%e (mismatch %s%f%%).\n%s" % \
741                 (relErr, rtol, atol, mismatchDegree, 100*i/a.size, \
742                  locationError(a, b, indexErr, names, maxError=True))
743    else:
744        errMsg = "Error %f exceeds tolerance rtol=%e, atol=%e.\n" % (rel, rtol, atol)
745
746    np.set_printoptions(threshold=4, suppress=True)
747    msg = npt.build_err_msg([a, b], err_msg=errMsg)
748
749    raise AssertionError(msg)
750
751
752def assert_allclose(a, b, rtol=1e-07, atol=0, equal_nan=True):
753    assert_almost_equal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
754
755
756def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None,
757                                 names=('a', 'b'), equal_nan=False, mismatches=(10, 10)):
758    """Test that two numpy arrays are almost equal within given error rate. Raise exception message if not.
759
760    Parameters
761    ----------
762    a : np.ndarray
763    b : np.ndarray
764    rtol : None or float or dict of dtype -> float
765        The relative threshold. Default threshold will be used if set to ``None``.
766    atol : None or float or dict of dtype -> float
767        The absolute threshold. Default threshold will be used if set to ``None``.
768    threshold : None or float
769        The checking threshold. Default threshold will be used if set to ``None``.
770    etol : None or float
771        The error rate threshold. If etol is float, return true if error_rate < etol even if
772        any error is found.
773    names : tuple of names, optional
774        The names used in error message when an exception occurs
775    equal_nan : boolean, optional
776        The flag determining how to treat NAN values in comparison
777    mismatches : tuple of mismatches
778        Maximum number of mismatches to be printed (mismatches[0]) and determine (mismatches[1])
779    """
780    etol = get_etol(etol)
781    if etol > 0:
782        rtol, atol = get_tols(a, b, rtol, atol)
783        if isinstance(a, mx.nd.NDArray):
784            a = a.asnumpy()
785        if isinstance(b, mx.nd.NDArray):
786            b = b.asnumpy()
787        equals = np.isclose(a, b, rtol=rtol, atol=atol)
788        err = 1 - np.count_nonzero(equals) / equals.size
789        if err > etol:
790            index, rel = _find_max_violation(a, b, rtol, atol)
791            indexErr = index
792            relErr = rel
793
794            print('\n*** Maximum errors for vector of size {}:  rtol={}, atol={}\n'.format(a.size, rtol, atol))
795            aTmp = a.copy()
796            bTmp = b.copy()
797            i = 1
798            while i <= a.size:
799                if i <= mismatches[0]:
800                    print("%3d: Error %f  %s" %(i, rel, locationError(a, b, index, names)))
801
802                aTmp[index] = bTmp[index] = 0
803                if almost_equal(aTmp, bTmp, rtol, atol, equal_nan=equal_nan):
804                    break
805
806                i += 1
807                if i <= mismatches[1] or mismatches[1] <= 0:
808                    index, rel = _find_max_violation(aTmp, bTmp, rtol, atol)
809                else:
810                    break
811
812            mismatchDegree = "at least " if mismatches[1] > 0 and i > mismatches[1] else ""
813            errMsg = "Error %f exceeds tolerance rtol=%e, atol=%e (mismatch %s%f%%).\n%s" % \
814                    (relErr, rtol, atol, mismatchDegree, 100*i/a.size, \
815                    locationError(a, b, indexErr, names, maxError=True))
816            np.set_printoptions(threshold=4, suppress=True)
817            msg = npt.build_err_msg([a, b], err_msg=errMsg)
818            raise AssertionError(msg)
819    else:
820        assert_almost_equal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
821
822
823def assert_almost_equal_ignore_nan(a, b, rtol=None, atol=None, names=('a', 'b')):
824    """Test that two NumPy arrays are almost equal (ignoring NaN in either array).
825    Combines a relative and absolute measure of approximate eqality.
826    If either the relative or absolute check passes, the arrays are considered equal.
827    Including an absolute check resolves issues with the relative check where all
828    array values are close to zero.
829
830    Parameters
831    ----------
832    a : np.ndarray
833    b : np.ndarray
834    rtol : None or float
835        The relative threshold. Default threshold will be used if set to ``None``.
836    atol : None or float
837        The absolute threshold. Default threshold will be used if set to ``None``.
838    """
839    a = np.copy(a)
840    b = np.copy(b)
841    nan_mask = np.logical_or(np.isnan(a), np.isnan(b))
842    a[nan_mask] = 0
843    b[nan_mask] = 0
844
845    assert_almost_equal(a, b, rtol, atol, names)
846
847def assert_exception(f, exception_type, *args, **kwargs):
848    """Test that function f will throw an exception of type given by `exception_type`"""
849    try:
850        f(*args, **kwargs)
851        assert(False)
852    except exception_type:
853        return
854
855def retry(n):
856    """Retry n times before failing for stochastic test cases."""
857    assert n > 0
858    def decorate(f):
859        """Decorate a test case."""
860        def wrapper(*args, **kwargs):
861            """Wrapper for tests function."""
862            for i in range(n):
863                try:
864                    f(*args, **kwargs)
865                    return
866                except AssertionError as e:
867                    if i == n-1:
868                        raise e
869                    mx.nd.waitall()
870        return wrapper
871    return decorate
872
873
874def simple_forward(sym, ctx=None, is_train=False, **inputs):
875    """A simple forward function for a symbol.
876
877    Primarily used in doctest to test the functionality of a symbol.
878    Takes NumPy arrays as inputs and outputs are also converted to NumPy arrays.
879
880    Parameters
881    ----------
882    ctx : Context
883        If ``None``, will take the default context.
884    inputs : keyword arguments
885        Mapping each input name to a NumPy array.
886
887    Returns
888    -------
889    The result as a numpy array. Multiple results will
890    be returned as a list of NumPy arrays.
891    """
892    ctx = ctx or default_context()
893    inputs = {k: array(v) for k, v in inputs.items()}
894    exe = sym.bind(ctx, args=inputs)
895    exe.forward(is_train=is_train)
896    outputs = [x.asnumpy() for x in exe.outputs]
897    if len(outputs) == 1:
898        outputs = outputs[0]
899    return outputs
900
901
902def _parse_location(sym, location, ctx, dtype=default_dtype()):
903    """Parses the given location to a ordered dictionary.
904
905    Arguments of the provided op `sym` are used as dictionary keys
906    and elements of `location` are used as values.
907
908    Parameters
909    ----------
910    sym : Symbol
911        Symbol containing op
912    location : list or tuple or dict
913        Argument values location
914
915        - if type is list or tuple of `np.ndarray`
916            inner elements are arrays correspoding to
917            ``sym.list_arguments()``.
918        - if type is dict of str -> `np.ndarray`
919            maps the name of arguments to the corresponding `np.ndarray`.
920        *In either case, value of all the arguments must be provided.*
921    ctx : Context
922        Device context.
923    dtype: "asnumpy" or np.float16 or np.float32 or np.float64
924        If dtype is "asnumpy" then the mx.nd.array created will have the same
925        type as th numpy array from which it is copied.
926        Otherwise, dtype is the explicit datatype for all mx.nd.array objects
927        created in this function.
928
929    Returns
930    -------
931    dict
932        Dictionary with `sym` arguments as keys and `location` elements as
933        values.
934
935    Examples
936    -------
937    >>> a = mx.symbol.Variable('a')
938    >>> b = mx.symbol.Variable('b')
939    >>> l1 = np.ndarray([2,3])
940    >>> l2 = np.ndarray([3,4])
941    >>> _parse_location(a * b, [l1, l2], None)
942    {'a': <NDArray 2x3 @cpu(0)>, 'b': <NDArray 3x4 @cpu(0)>}
943    >>> _parse_location(a * b, {'a': l1, 'b': l2}, None)
944    {'a': <NDArray 2x3 @cpu(0)>, 'b': <NDArray 3x4 @cpu(0)>}
945    >>> _parse_location(a * b, {'a': l1}, None)
946    ValueError: Symbol arguments and keys of the given location do not match.
947    """
948    assert isinstance(location, (dict, list, tuple))
949    assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
950    if isinstance(location, dict):
951        if set(location.keys()) != set(sym.list_arguments()):
952            raise ValueError("Symbol arguments and keys of the given location do not match."
953                             "symbol args:%s, location.keys():%s"
954                             % (str(set(sym.list_arguments())), str(set(location.keys()))))
955    else:
956        location = {k: v for k, v in zip(sym.list_arguments(), location)}
957    location = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \
958               if isinstance(v, np.ndarray) else v for k, v in location.items()}
959    return _sorted_dict(location)
960
961
962def _parse_aux_states(sym, aux_states, ctx, dtype=default_dtype()):
963    """Parses the given auxiliary states to a dictionary.
964
965    Auxiliary states of the provided op `sym` are used as dictionary
966    keys and elements of `aux_states` are used as values.
967
968    Parameters
969    ----------
970    sym : Symbol
971        Symbol containing op
972    aux_states : None or list or dict
973        Aux states
974
975        - if type is list or tuple of `np.ndarray`
976            inner elements are arrays correspoding to
977            ``sym.list_auxiliary_states()``.
978        - if type is dict of str -> `np.ndarray`
979            maps the name of arguments to the corresponding `np.ndarray`.
980        *In either case, all aux states of `sym` must be provided.*
981    ctx : Context
982        Device context.
983    dtype: "asnumpy" or np.float16 or np.float32 or np.float64
984        If dtype is "asnumpy" then the mx.nd.array created will have the same
985        type as th numpy array from which it is copied.
986        Otherwise, dtype is the explicit datatype for all mx.nd.array objects
987        created in this function.
988
989    Returns
990    -------
991    dict
992        Dictionary with `sym` aux states as keys and `aux_states` elements
993        as values.
994
995    Examples
996    -------
997    >>> data = mx.symbol.Variable('data')
998    >>> weight = mx.sym.Variable(name='fc1_weight')
999    >>> fc1 = mx.symbol.FullyConnected(data = data, weight=weight, name='fc1', num_hidden=128)
1000    >>> fc2 = mx.symbol.BatchNorm(fc1, name='batchnorm0')
1001    >>> mean_states = np.ones(3)
1002    >>> var_states = np.ones(3)
1003    >>> _parse_aux_states(fc2, [mean_states, var_states], None)
1004    {'batchnorm0_moving_var': <NDArray 3 @cpu(0)>, 'batchnorm0_moving_mean': <NDArray 3 @cpu(0)>}
1005    >>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states,
1006    ...                         'batchnorm0_moving_mean': var_states}, None)
1007    {'batchnorm0_moving_var': <NDArray 3 @cpu(0)>, 'batchnorm0_moving_mean': <NDArray 3 @cpu(0)>}
1008    >>> _parse_aux_states(fc2, {'batchnorm0_moving_var': mean_states}, None)
1009    ValueError: Symbol aux_states names and given aux_states do not match.
1010    """
1011    assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
1012    if aux_states is not None:
1013        if isinstance(aux_states, dict):
1014            if set(aux_states.keys()) != set(sym.list_auxiliary_states()):
1015                raise ValueError("Symbol aux_states names and given aux_states do not match."
1016                                 "symbol aux_names:%s, aux_states.keys:%s"
1017                                 % (str(set(sym.list_auxiliary_states())),
1018                                    str(set(aux_states.keys()))))
1019        elif isinstance(aux_states, (list, tuple)):
1020            aux_names = sym.list_auxiliary_states()
1021            aux_states = {k:v for k, v in zip(aux_names, aux_states)}
1022        aux_states = {k: mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \
1023                      for k, v in aux_states.items()}
1024    return aux_states
1025
1026
1027def numeric_grad(executor, location, aux_states=None, eps=1e-4,
1028                 use_forward_train=True, dtype=default_dtype()):
1029    """Calculates a numeric gradient via finite difference method.
1030
1031    Class based on Theano's `theano.gradient.numeric_grad` [1]
1032
1033    Parameters
1034    ----------
1035    executor : Executor
1036        Executor that computes the forward pass.
1037    location : list of numpy.ndarray or dict of str to numpy.ndarray
1038        Argument values used as location to compute gradient
1039        Maps the name of arguments to the corresponding numpy.ndarray.
1040        Value of all the arguments must be provided.
1041    aux_states : None or list of numpy.ndarray or dict of str to numpy.ndarray, optional
1042        Auxiliary states values used as location to compute gradient
1043        Maps the name of aux_states to the corresponding numpy.ndarray.
1044        Value of all the auxiliary arguments must be provided.
1045    eps : float, optional
1046        Epsilon for the finite-difference method.
1047    use_forward_train : bool, optional
1048        Whether to use `is_train=True` in testing.
1049    dtype: np.float16 or np.float32 or np.float64
1050        Datatype for mx.nd.array.
1051
1052    References
1053    ---------
1054    ..[1] https://github.com/Theano/Theano/blob/master/theano/gradient.py
1055    """
1056    def as_stype(var, stype, dtype):
1057        return mx.nd.cast_storage(mx.nd.array(var, dtype=dtype), stype=stype)
1058
1059    assert dtype in (np.float16, np.float32, np.float64)
1060    approx_grads = {k: np.zeros(v.shape, dtype=dtype)
1061                    for k, v in location.items()}
1062    for k, v in location.items():
1063        stype = executor.arg_dict[k].stype
1064        if stype == 'default':
1065            executor.arg_dict[k][:] = as_stype(v, stype, dtype=dtype)
1066    for k in location:
1067        location[k] = np.asarray(location[k], order='C')
1068    for k, v in location.items():
1069        if v.dtype.kind != 'f':
1070            continue
1071        stype = executor.arg_dict[k].stype
1072        old_value = v.copy()
1073        for i in range(int(np.prod(v.shape))):
1074            # inplace update
1075            v.ravel()[i] += eps/2.0
1076            executor.arg_dict[k][:] = as_stype(v, stype, dtype=dtype)
1077            if aux_states is not None:
1078                for key, val in aux_states.items():
1079                    executor.aux_dict[key][:] = val
1080            executor.forward(is_train=use_forward_train)
1081            f_peps = executor.outputs[0].asnumpy()
1082
1083            v.ravel()[i] -= eps
1084            executor.arg_dict[k][:] = as_stype(v, stype, dtype=dtype)
1085            if aux_states is not None:
1086                for key, val in aux_states.items():
1087                    adstype = executor.aux_dict[key].stype
1088                    executor.aux_dict[key][:] = as_stype(val, adstype, dtype=dtype)
1089            executor.forward(is_train=use_forward_train)
1090            f_neps = executor.outputs[0].asnumpy()
1091
1092            approx_grad = (f_peps - f_neps).sum() / eps
1093            approx_grads[k].ravel()[i] = approx_grad
1094            v.ravel()[i] = old_value.ravel()[i]
1095        # copy back the original value
1096        executor.arg_dict[k][:] = as_stype(old_value, stype, dtype=dtype)
1097
1098    return approx_grads
1099
1100
1101def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=None, rtol=None,
1102                           atol=None, grad_nodes=None, use_forward_train=True, ctx=None,
1103                           grad_stype_dict=None, dtype=default_dtype()):
1104    """Verify an operation by checking backward pass via finite difference method.
1105
1106    Based on Theano's `theano.gradient.verify_grad` [1]
1107
1108    Parameters
1109    ----------
1110    sym : Symbol
1111        Symbol containing op to test
1112    location : list or tuple or dict
1113        Argument values used as location to compute gradient
1114
1115        - if type is list of numpy.ndarray, \
1116            inner elements should have the same order as mxnet.sym.list_arguments().
1117
1118        - if type is dict of str -> numpy.ndarray, \
1119            maps the name of arguments to the corresponding numpy.ndarray.
1120
1121        *In either case, value of all the arguments must be provided.*
1122    aux_states : list or tuple or dict, optional
1123        The auxiliary states required when generating the executor for the symbol.
1124    numeric_eps : float, optional
1125        Delta for the finite difference method that approximates the gradient.
1126    rtol : None or float
1127        The relative threshold. Default threshold will be used if set to ``None``.
1128    atol : None or float
1129        The absolute threshold. Default threshold will be used if set to ``None``.
1130    grad_nodes : None or list or tuple or dict, optional
1131        Names of the nodes to check gradient on
1132    use_forward_train : bool
1133        Whether to use is_train=True when computing the finite-difference.
1134    ctx : Context, optional
1135        Check the gradient computation on the specified device.
1136    grad_stype_dict : dict of str->str, optional
1137        Storage type dictionary for gradient ndarrays.
1138    dtype: np.float16 or np.float32 or np.float64
1139        Datatype for mx.nd.array.
1140
1141    References
1142    ---------
1143    [1] https://github.com/Theano/Theano/blob/master/theano/gradient.py
1144    """
1145    assert dtype in (np.float16, np.float32, np.float64)
1146    if ctx is None:
1147        ctx = default_context()
1148
1149    def random_projection(shape):
1150        """Get a random weight matrix with not too small elements
1151
1152        Parameters
1153        ----------
1154        shape : list or tuple
1155        """
1156        # random_projection should not have elements too small,
1157        # otherwise too much precision is lost in numerical gradient
1158        plain = np.random.rand(*shape) + 0.1
1159        return plain
1160
1161    location = _parse_location(sym=sym, location=location, ctx=ctx, dtype=dtype)
1162    location_npy = {k:v.asnumpy() for k, v in location.items()}
1163    aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx,
1164                                   dtype=dtype)
1165    if aux_states is not None:
1166        aux_states_npy = {k: v.asnumpy() for k, v in aux_states.items()}
1167    else:
1168        aux_states_npy = None
1169    if grad_nodes is None:
1170        grad_nodes = sym.list_arguments()
1171        grad_req = {k: 'write' for k in grad_nodes}
1172    elif isinstance(grad_nodes, (list, tuple)):
1173        grad_nodes = list(grad_nodes)
1174        grad_req = {k: 'write' for k in grad_nodes}
1175    elif isinstance(grad_nodes, dict):
1176        grad_req = grad_nodes.copy()
1177        grad_nodes = grad_nodes.keys()
1178    else:
1179        raise ValueError
1180
1181    input_shape = {k: v.shape for k, v in location.items()}
1182    _, out_shape, _ = sym.infer_shape(**input_shape)
1183    proj = mx.sym.Variable("__random_proj")
1184    is_np_sym = bool(isinstance(sym, np_symbol))
1185    if is_np_sym:  # convert to np symbol for using element-wise multiplication
1186        proj = proj.as_np_ndarray()
1187    out = sym * proj
1188    if is_np_sym:  # convert to classic symbol so that make_loss can be used
1189        out = out.as_nd_ndarray()
1190    out = mx.sym.make_loss(out)
1191
1192    location = dict(list(location.items()) +
1193                    [("__random_proj", mx.nd.array(random_projection(out_shape[0]),
1194                                                   ctx=ctx, dtype=dtype))])
1195    args_grad_npy = dict([(k, np.random.normal(0, 0.01, size=location[k].shape))
1196                          for k in grad_nodes]
1197                         + [("__random_proj", np.random.normal(0, 0.01, size=out_shape[0]))])
1198
1199    args_grad = {k: mx.nd.array(v, ctx=ctx, dtype=dtype) for k, v in args_grad_npy.items()}
1200    if grad_stype_dict is not None:
1201        assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict"
1202        for k, v in grad_stype_dict.items():
1203            if k in args_grad and v in _STORAGE_TYPE_STR_TO_ID and v != 'default':
1204                # create an uninitialized sparse ndarray for executor
1205                # if the symbolic grad is expected to be zero, it should not be initialized at all
1206                args_grad[k] = mx.nd.zeros(args_grad[k].shape, args_grad[k].context,
1207                                           args_grad[k].dtype, v)
1208
1209    executor = out.bind(ctx, grad_req=grad_req,
1210                        args=location, args_grad=args_grad, aux_states=aux_states)
1211
1212    inps = executor.arg_arrays
1213    if len(inps) != len(location):
1214        raise ValueError("Executor arg_arrays and and location len do not match."
1215                         "Got %d inputs and %d locations"%(len(inps), len(location)))
1216    assert len(executor.outputs) == 1
1217
1218    executor.forward(is_train=True)
1219
1220    eps = get_tolerance(executor.outputs[0], numeric_eps, default_numeric_eps())
1221    # cannot use finite differences with small eps without high precision
1222    if dtype in (np.float32, np.float16):
1223        assert eps >= 1e-5
1224
1225    executor.backward()
1226    symbolic_grads = executor.grad_dict
1227
1228    numeric_gradients = numeric_grad(
1229        executor, location_npy, aux_states_npy,
1230        eps=eps, use_forward_train=use_forward_train, dtype=dtype)
1231
1232    for name in grad_nodes:
1233        fd_grad = numeric_gradients[name]
1234        orig_grad = args_grad_npy[name]
1235        sym_grad = symbolic_grads[name]
1236        if grad_req[name] == 'write':
1237            assert_almost_equal(fd_grad, sym_grad, rtol, atol,
1238                                ("NUMERICAL_%s"%name, "BACKWARD_%s"%name))
1239        elif grad_req[name] == 'add':
1240            if isinstance(sym_grad, mx.nd.NDArray):
1241                sym_grad = sym_grad.asnumpy()
1242            assert_almost_equal(fd_grad, sym_grad - orig_grad, rtol, atol,
1243                                ("NUMERICAL_%s"%name, "BACKWARD_%s"%name))
1244        elif grad_req[name] == 'null':
1245            assert_almost_equal(orig_grad, sym_grad, rtol, atol,
1246                                ("NUMERICAL_%s"%name, "BACKWARD_%s"%name))
1247        else:
1248            raise ValueError("Invalid grad_req %s for argument %s"%(grad_req[name], name))
1249
1250
1251def check_symbolic_forward(sym, location, expected, rtol=None, atol=None,
1252                           aux_states=None, ctx=None, equal_nan=False,
1253                           dtype=default_dtype()):
1254    """Compares a symbol's forward results with the expected ones.
1255    Prints error messages if the forward results are not the same as the expected ones.
1256
1257    Parameters
1258    ---------
1259    sym : Symbol
1260        output symbol
1261    location : list of np.ndarray or dict of str to np.ndarray
1262        The evaluation point
1263
1264        - if type is list of np.ndarray
1265            Contains all the numpy arrays corresponding to `sym.list_arguments()`.
1266        - if type is dict of str to np.ndarray
1267            Contains the mapping between argument names and their values.
1268    expected : list of np.ndarray or dict of str to np.ndarray
1269        The expected output value
1270
1271        - if type is list of np.ndarray
1272            Contains arrays corresponding to exe.outputs.
1273        - if type is dict of str to np.ndarray
1274            Contains mapping between sym.list_output() and exe.outputs.
1275    rtol : None or float
1276        The relative threshold. Default threshold will be used if set to ``None``.
1277    atol : None or float
1278        The absolute threshold. Default threshold will be used if set to ``None``.
1279    aux_states : list of np.ndarray of dict, optional
1280        - if type is list of np.ndarray
1281            Contains all the NumPy arrays corresponding to sym.list_auxiliary_states
1282        - if type is dict of str to np.ndarray
1283            Contains the mapping between names of auxiliary states and their values.
1284    ctx : Context, optional
1285        running context
1286    dtype: "asnumpy" or np.float16 or np.float32 or np.float64
1287        If dtype is "asnumpy" then the mx.nd.array created will have the same
1288        type as th numpy array from which it is copied.
1289        Otherwise, dtype is the explicit datatype for all mx.nd.array objects
1290        created in this function.
1291
1292    equal_nan: Boolean
1293        if True, `nan` is a valid value for checking equivalency (ie `nan` == `nan`)
1294
1295    Example
1296    -------
1297    >>> shape = (2, 2)
1298    >>> lhs = mx.symbol.Variable('lhs')
1299    >>> rhs = mx.symbol.Variable('rhs')
1300    >>> sym_dot = mx.symbol.dot(lhs, rhs)
1301    >>> mat1 = np.array([[1, 2], [3, 4]])
1302    >>> mat2 = np.array([[5, 6], [7, 8]])
1303    >>> ret_expected = np.array([[19, 22], [43, 50]])
1304    >>> check_symbolic_forward(sym_dot, [mat1, mat2], [ret_expected])
1305    """
1306    assert dtype == "asnumpy" or dtype in (np.float16, np.float32, np.float64)
1307    if ctx is None:
1308        ctx = default_context()
1309
1310    location = _parse_location(sym=sym, location=location, ctx=ctx, dtype=dtype)
1311    aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx,
1312                                   dtype=dtype)
1313    if isinstance(expected, dict):
1314        expected = [expected[k] for k in sym.list_outputs()]
1315    args_grad_data = {k:mx.nd.empty(v.shape, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype) \
1316                      for k, v in location.items()}
1317
1318    executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states)
1319    for g in executor.grad_arrays:
1320        if g.ndim == 0:
1321            g[()] = 0
1322        else:
1323            g[:] = 0
1324
1325    executor.forward(is_train=False)
1326
1327    outputs = executor.outputs
1328    for output_name, expect, output in zip(sym.list_outputs(), expected, outputs):
1329        assert_almost_equal(expect, output, rtol, atol,
1330                            ("EXPECTED_%s"%output_name, "FORWARD_%s"%output_name),
1331                            equal_nan=equal_nan)
1332    return executor.outputs
1333
1334def check_symbolic_backward(sym, location, out_grads, expected, rtol=None, atol=None,
1335                            aux_states=None, grad_req='write', ctx=None, grad_stypes=None,
1336                            equal_nan=False, dtype=default_dtype()):
1337    """Compares a symbol's backward results with the expected ones.
1338    Prints error messages if the backward results are not the same as the expected results.
1339
1340    Parameters
1341    ---------
1342    sym : Symbol
1343        output symbol
1344    location : list of np.ndarray or dict of str to np.ndarray
1345        The evaluation point
1346
1347        - if type is list of np.ndarray
1348            Contains all the NumPy arrays corresponding to ``mx.sym.list_arguments``.
1349        - if type is dict of str to np.ndarray
1350            Contains the mapping between argument names and their values.
1351    out_grads : None or list of np.ndarray or dict of str to np.ndarray
1352        NumPys arrays corresponding to sym.outputs for incomming gradient.
1353
1354        - if type is list of np.ndarray
1355            Contains arrays corresponding to ``exe.outputs``.
1356        - if type is dict of str to np.ndarray
1357            contains mapping between mxnet.sym.list_output() and Executor.outputs
1358    expected : list of np.ndarray or dict of str to np.ndarray
1359        expected gradient values
1360
1361        - if type is list of np.ndarray
1362            Contains arrays corresponding to exe.grad_arrays
1363        - if type is dict of str to np.ndarray
1364            Contains mapping between ``sym.list_arguments()`` and exe.outputs.
1365    rtol : None or float
1366        The relative threshold. Default threshold will be used if set to ``None``.
1367    atol : None or float
1368        The absolute threshold. Default threshold will be used if set to ``None``.
1369    aux_states : list of np.ndarray or dict of str to np.ndarray
1370    grad_req : str or list of str or dict of str to str, optional
1371        Gradient requirements. 'write', 'add' or 'null'.
1372    ctx : Context, optional
1373        Running context.
1374    grad_stypes: dict of str->str
1375        dictionary of mapping argument name to stype for the gradient
1376    equal_nan: Boolean
1377        if True, `nan` is a valid value for checking equivalency (ie `nan` == `nan`)
1378    dtype: np.float16 or np.float32 or np.float64
1379        Datatype for mx.nd.array.
1380
1381    Example
1382    -------
1383    >>> lhs = mx.symbol.Variable('lhs')
1384    >>> rhs = mx.symbol.Variable('rhs')
1385    >>> sym_add = mx.symbol.elemwise_add(lhs, rhs)
1386    >>> mat1 = np.array([[1, 2], [3, 4]])
1387    >>> mat2 = np.array([[5, 6], [7, 8]])
1388    >>> grad1 = mx.nd.zeros(shape)
1389    >>> grad2 = mx.nd.zeros(shape)
1390    >>> exec_add = sym_add.bind(default_context(), args={'lhs': mat1, 'rhs': mat2},
1391    ... args_grad={'lhs': grad1, 'rhs': grad2}, grad_req={'lhs': 'write', 'rhs': 'write'})
1392    >>> exec_add.forward(is_train=True)
1393    >>> ograd = mx.nd.ones(shape)
1394    >>> grad_expected = ograd.copy().asnumpy()
1395    >>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected])
1396    """
1397    assert dtype == 'asnumpy' or dtype in (np.float16, np.float32, np.float64)
1398    if ctx is None:
1399        ctx = default_context()
1400
1401    location = _parse_location(sym=sym, location=location, ctx=ctx, dtype=dtype)
1402    aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx,
1403                                   dtype=dtype)
1404    if isinstance(expected, (list, tuple)):
1405        expected = {k:v for k, v in zip(sym.list_arguments(), expected)}
1406
1407    # Dirty the output buffer deterministically, for reproducibility.
1408    args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in _sorted_items(expected)}
1409    args_grad_data = {}
1410    for k, v in args_grad_npy.items():
1411        nd = mx.nd.array(v, ctx=ctx, dtype=expected[k].dtype if dtype == "asnumpy" else dtype)
1412        if grad_stypes is not None and k in grad_stypes:
1413            stype = grad_stypes[k]
1414            if stype is not None and stype != 'default':
1415                out = create_sparse_array(v.shape, stype, density=0.0)
1416            else:
1417                out = nd
1418            args_grad_data[k] = out
1419        else:
1420            args_grad_data[k] = nd
1421
1422    if isinstance(grad_req, str):
1423        grad_req = {k:grad_req for k in sym.list_arguments()}
1424    elif isinstance(grad_req, (list, tuple)):
1425        grad_req = {k:v for k, v in zip(sym.list_arguments(), grad_req)}
1426
1427    executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data,
1428                        aux_states=aux_states, grad_req=grad_req)
1429    executor.forward(is_train=True)
1430
1431    if isinstance(out_grads, (tuple, list)):
1432        outg = list()
1433        for arr in out_grads:
1434            if isinstance(arr, np.ndarray):
1435                outg.append(mx.nd.array(arr, ctx=ctx, dtype=arr.dtype if dtype == "asnumpy" else dtype))
1436            else:
1437                outg.append(arr)
1438        out_grads = outg
1439    elif isinstance(out_grads, dict):
1440        outg = dict()
1441        for k, v in out_grads.items():
1442            if isinstance(v, np.ndarray):
1443                outg[k] = mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype)
1444            else:
1445                outg[k] = v
1446        out_grads = outg
1447    else:
1448        assert out_grads is None
1449
1450    executor.backward(out_grads)
1451
1452    grads = args_grad_data
1453
1454    for name in expected:
1455        if grad_req[name] == 'write':
1456            assert_almost_equal(expected[name], grads[name], rtol, atol,
1457                                ("EXPECTED_%s"%name, "BACKWARD_%s"%name),
1458                                equal_nan=equal_nan)
1459        elif grad_req[name] == 'add':
1460            grad = grads[name].asnumpy() if isinstance(grads[name], mx.nd.NDArray) else grads[name]
1461            assert_almost_equal(expected[name], grad - args_grad_npy[name],
1462                                rtol, atol, ("EXPECTED_%s"%name, "BACKWARD_%s"%name),
1463                                equal_nan=equal_nan)
1464        elif grad_req[name] == 'null':
1465            assert_almost_equal(args_grad_npy[name], grads[name],
1466                                rtol, atol, ("EXPECTED_%s"%name, "BACKWARD_%s"%name),
1467                                equal_nan=equal_nan)
1468        else:
1469            raise ValueError("Invalid grad_req %s for argument %s"%(grad_req[name], name))
1470    return args_grad_data
1471
1472def check_speed(sym, location=None, ctx=None, N=20, grad_req=None, typ="whole",
1473                **kwargs):
1474    """Check the running speed of a symbol.
1475
1476    Parameters
1477    ----------
1478    sym : Symbol
1479        Symbol to run the speed test.
1480    location : none or dict of str to np.ndarray
1481        Location to evaluate the inner executor.
1482    ctx : Context
1483        Running context.
1484    N : int, optional
1485        Repeat times.
1486    grad_req : None or str or list of str or dict of str to str, optional
1487        Gradient requirements.
1488    typ : str, optional
1489        "whole" or "forward"
1490
1491        - "whole"
1492            Test the forward_backward speed.
1493        - "forward"
1494            Only test the forward speed.
1495    """
1496    if ctx is None:
1497        ctx = default_context()
1498
1499    if grad_req is None:
1500        grad_req = 'write'
1501    if location is None:
1502        exe = sym.simple_bind(grad_req=grad_req, ctx=ctx, **kwargs)
1503        location = {k: np.random.normal(size=arr.shape, scale=1.0) for k, arr in
1504                    exe.arg_dict.items()}
1505    else:
1506        assert isinstance(location, dict), "Expect dict, get \"location\"=%s" %str(location)
1507        exe = sym.simple_bind(grad_req=grad_req, ctx=ctx,
1508                              **{k: v.shape for k, v in location.items()})
1509
1510    for name, iarr in location.items():
1511        exe.arg_dict[name][:] = iarr.astype(exe.arg_dict[name].dtype)
1512
1513    if typ == "whole":
1514        # Warm up
1515        exe.forward(is_train=True)
1516        exe.backward(out_grads=exe.outputs)
1517        for output in exe.outputs:
1518            output.wait_to_read()
1519        # Test forward + backward
1520        tic = time.time()
1521        for _ in range(N):
1522            exe.forward(is_train=True)
1523            exe.backward(out_grads=exe.outputs)
1524        mx.nd.waitall()
1525        toc = time.time()
1526        forward_backward_time = (toc - tic) * 1.0 / N
1527        return forward_backward_time
1528    elif typ == "forward":
1529        # Warm up
1530        exe.forward(is_train=False)
1531        for output in exe.outputs:
1532            output.wait_to_read()
1533
1534        # Test forward only
1535        tic = time.time()
1536        for _ in range(N):
1537            exe.forward(is_train=False)
1538        mx.nd.waitall()
1539        toc = time.time()
1540        forward_time = (toc - tic) * 1.0 / N
1541        return forward_time
1542    else:
1543        raise ValueError('typ can only be "whole" or "forward".')
1544
1545
1546def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
1547                      arg_params=None, aux_params=None, rtol=None, atol=None,
1548                      raise_on_err=True, ground_truth=None, equal_nan=False,
1549                      use_uniform=False, rand_type=np.float64):
1550    """Check symbol gives the same output for different running context
1551
1552    Parameters
1553    ----------
1554    sym : Symbol or list of Symbols
1555        Symbol(s) to run the consistency test.
1556    ctx_list : list
1557        Running context. See example for more detail.
1558    scale : float, optional
1559        Standard deviation of the inner normal distribution. Used in initialization.
1560    grad_req : str or list of str or dict of str to str
1561        Gradient requirement.
1562    arg_params : dict of input name -> input data
1563        data to use for non-aux inputs
1564    aux_params : dict of input name -> input data
1565        data to use for aux inputs
1566    rtol : float or dictionary dtype->float, optional
1567        The relative error tolerance.
1568    atol : float or dictionary dtype->float, optional
1569        The absolute error tolerance.
1570    raise_on_err : bool, optional, defaults to True
1571        Should an error raise an exception (or just output exception message)
1572    ground_truth : dict of output name -> data, optional
1573        Provided ideal result to be compared against
1574    equal_nan : bool, optional, defaults to False
1575        Should nans be treated as equal in the comparison
1576    use_unifrom: bool
1577        Optional, When flag set to true,
1578        random input data generated follows uniform distribution,
1579        not normal distribution
1580    rand_type: np.dtype
1581        casts the randomly generated data to this type
1582        Optional, when input data is passed via arg_params,
1583        defaults to np.float64 (numpy float default)
1584
1585    Examples
1586    --------
1587    >>> # create the symbol
1588    >>> sym = mx.sym.Convolution(num_filter=3, kernel=(3,3), name='conv')
1589    >>> # initialize the running context
1590    >>> ctx_list =\
1591[{'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float64}},\
1592 {'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}},\
1593 {'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float16}},\
1594 {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float64}},\
1595 {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}}]
1596    >>> check_consistency(sym, ctx_list)
1597    >>> sym = mx.sym.Concat(name='concat', num_args=2)
1598    >>> ctx_list = \
1599[{'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),\
1600  'type_dict': {'concat_arg0': np.float64, 'concat_arg1': np.float64}},\
1601 {'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),\
1602  'type_dict': {'concat_arg0': np.float32, 'concat_arg1': np.float32}},\
1603 {'ctx': mx.gpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),\
1604  'type_dict': {'concat_arg0': np.float16, 'concat_arg1': np.float16}},\
1605 {'ctx': mx.cpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),\
1606  'type_dict': {'concat_arg0': np.float64, 'concat_arg1': np.float64}},\
1607 {'ctx': mx.cpu(0), 'concat_arg1': (2, 10), 'concat_arg0': (2, 10),\
1608  'type_dict': {'concat_arg0': np.float32, 'concat_arg1': np.float32}}]
1609    >>> check_consistency(sym, ctx_list)
1610    """
1611
1612    assert len(ctx_list) > 1
1613    if isinstance(sym, Symbol):
1614        sym = [sym]*len(ctx_list)
1615    else:
1616        assert len(sym) == len(ctx_list)
1617
1618    output_names = sym[0].list_outputs()
1619    arg_names = sym[0].list_arguments()
1620    exe_list = []
1621    for s, ctx in zip(sym, ctx_list):
1622        assert s.list_arguments() == arg_names
1623        assert s.list_outputs() == output_names
1624        exe_list.append(s.simple_bind(grad_req=grad_req, **ctx))
1625
1626    arg_params = {} if arg_params is None else arg_params
1627    aux_params = {} if aux_params is None else aux_params
1628
1629    # returns the least precise of two dtypes
1630    def smaller_dtype(dt1, dt2):
1631        return dt1 if dt2 is None or np.dtype(dt1).itemsize < np.dtype(dt2).itemsize else dt2
1632
1633    # It's important to assign random inputs in a deterministic order, for reproducibility.
1634    for n, arr in _sorted_items(exe_list[0].arg_dict):
1635        if n not in arg_params:
1636            if use_uniform:
1637                arg_params[n] = np.random.uniform(low=-0.92 * scale, high=0.92 * scale,
1638                                                  size=arr.shape).astype(rand_type)
1639            else:
1640                arg_params[n] = np.random.normal(size=arr.shape,
1641                                                 scale=scale).astype(rand_type)
1642    for n, arr in exe_list[0].aux_dict.items():
1643        if n not in aux_params:
1644            aux_params[n] = 0
1645    for exe in exe_list:
1646        for name, arr in exe.arg_dict.items():
1647            arr[:] = arg_params[name]
1648        for name, arr in exe.aux_dict.items():
1649            arr[:] = aux_params[name]
1650        # We need to initialize the gradient arrays if it's add.
1651        if (grad_req == "add"):
1652            for arr in exe.grad_arrays:
1653                arr[:] = np.zeros(arr.shape, dtype=arr.dtype)
1654
1655    dtypes = [np.dtype(exe.outputs[0].dtype) for exe in exe_list]
1656    # Select the ground truth as the first model having the highest precision output[0]
1657    gt_idx = np.argmax(dtypes)
1658    gt = ground_truth
1659    if gt is None:
1660        gt = exe_list[gt_idx].output_dict.copy()
1661        if grad_req != 'null':
1662            gt.update(exe_list[gt_idx].grad_dict)
1663
1664    # test
1665    for exe in exe_list:
1666        exe.forward(is_train=False)
1667
1668    for i, exe in enumerate(exe_list):
1669        if i == gt_idx:
1670            continue
1671
1672        for name, arr in zip(output_names, exe.outputs):
1673            gtarr = gt[name]
1674            try:
1675                assert_almost_equal(arr, gtarr, rtol=rtol, atol=atol, equal_nan=equal_nan)
1676            except AssertionError as e:
1677                print('Predict Err: ctx %d vs ctx %d at %s'%(i, gt_idx, name))
1678                traceback.print_exc()
1679                if raise_on_err:
1680                    raise e
1681
1682                print(str(e))
1683
1684    # train
1685    if grad_req != 'null':
1686        # Perform forward()
1687        for exe in exe_list:
1688            exe.forward(is_train=True)
1689        # Use the first executor's output data, cast to the least precise dtype,
1690        # as the gradient data to pass to all executor's backward() call.
1691        least_precise_dtype = [out.dtype for out in exe_list[0].outputs]
1692        for exe in exe_list:
1693            least_precise_dtype = [smaller_dtype(out1.dtype, dt) \
1694                                    for (out1, dt) in zip(exe.outputs, least_precise_dtype)]
1695        golden_data_np = [out.astype(dt).asnumpy() \
1696                          for (out, dt) in zip(exe_list[0].outputs, least_precise_dtype)]
1697        # Perform backward()
1698        for exe in exe_list:
1699            out_grads = [mx.nd.array(golden_np, ctx=exe._ctx,
1700                                     dtype=out.dtype).tostype(out.stype)
1701                         for (golden_np, out) in zip(golden_data_np, exe.outputs)]
1702            exe.backward(out_grads)
1703
1704        for i, exe in enumerate(exe_list):
1705            if i == gt_idx:
1706                continue
1707
1708            curr = zip(output_names + arg_names, exe.outputs + exe.grad_arrays)
1709            for name, arr in curr:
1710                if gt[name] is None:
1711                    assert arr is None
1712                    continue
1713
1714                gtarr = gt[name]
1715                try:
1716                    rt, at = rtol, atol
1717                    # If the primary data i/o type is float16, then the tolerance used when
1718                    # comparing a float32 input gradient (e.g. batchnorm gamma) should be float16.
1719                    smaller_arr_dtype = smaller_dtype(arr.dtype, dtypes[i])
1720                    smaller_gt_dtype = smaller_dtype(gtarr.dtype, dtypes[gt_idx])
1721                    if smaller_arr_dtype != arr.dtype or \
1722                       smaller_gt_dtype != gtarr.dtype:
1723                        rt, at = get_tols(arr.astype(smaller_arr_dtype),
1724                                          gtarr.astype(smaller_gt_dtype), rtol, atol)
1725                    assert_almost_equal(arr, gtarr, rtol=rt, atol=at, equal_nan=equal_nan)
1726                except AssertionError as e:
1727                    print('Train Err: {} {} ctx {} vs {} {} ctx {} at {}'.format(
1728                        np.dtype(arr.dtype).name, arr.ctx, i,
1729                        np.dtype(gtarr.dtype).name, gtarr.ctx, gt_idx, name))
1730                    traceback.print_exc()
1731                    if raise_on_err:
1732                        raise e
1733
1734                    print(str(e))
1735
1736    return gt
1737
1738def list_gpus():
1739    """Return a list of GPUs
1740
1741    Returns
1742    -------
1743    list of int:
1744        If there are n GPUs, then return a list [0,1,...,n-1]. Otherwise returns
1745        [].
1746    """
1747    return range(mx.util.get_gpu_count())
1748
1749def download(url, fname=None, dirname=None, overwrite=False, retries=5):
1750    """Download an given URL
1751
1752    Parameters
1753    ----------
1754
1755    url : str
1756        URL to download
1757    fname : str, optional
1758        filename of the downloaded file. If None, then will guess a filename
1759        from url.
1760    dirname : str, optional
1761        output directory name. If None, then guess from fname or use the current
1762        directory
1763    overwrite : bool, optional
1764        Default is false, which means skipping download if the local file
1765        exists. If true, then download the url to overwrite the local file if
1766        exists.
1767    retries : integer, default 5
1768        The number of times to attempt the download in case of failure or non 200 return codes
1769
1770    Returns
1771    -------
1772    str
1773        The filename of the downloaded file
1774    """
1775
1776    assert retries >= 0, "Number of retries should be at least 0"
1777
1778    if fname is None:
1779        fname = url.split('/')[-1]
1780
1781    if dirname is None:
1782        dirname = os.path.dirname(fname)
1783    else:
1784        fname = os.path.join(dirname, fname)
1785    if dirname != "":
1786        if not os.path.exists(dirname):
1787            try:
1788                logging.info('create directory %s', dirname)
1789                os.makedirs(dirname)
1790            except OSError as exc:
1791                if exc.errno != errno.EEXIST:
1792                    raise OSError('failed to create ' + dirname)
1793
1794    if not overwrite and os.path.exists(fname):
1795        logging.info("%s exists, skipping download", fname)
1796        return fname
1797
1798    while retries+1 > 0:
1799        # Disable pyling too broad Exception
1800        # pylint: disable=W0703
1801        try:
1802            r = requests.get(url, stream=True)
1803            assert r.status_code == 200, "failed to open %s" % url
1804            with open(fname, 'wb') as f:
1805                for chunk in r.iter_content(chunk_size=1024):
1806                    if chunk: # filter out keep-alive new chunks
1807                        f.write(chunk)
1808                break
1809        except Exception as e:
1810            retries -= 1
1811            if retries <= 0:
1812                raise e
1813
1814            print("download failed, retrying, {} attempt{} left"
1815                  .format(retries, 's' if retries > 1 else ''))
1816    logging.info("downloaded %s into %s successfully", url, fname)
1817    return fname
1818
1819def download_model(model_name, dst_dir='./', meta_info=None):
1820    """Download a model from data.mxnet.io
1821
1822    Parameters
1823    ----------
1824    model_name : str
1825        Model name to download
1826    dst_dir : str
1827        Destination Directory to download the model
1828    meta_info : dict of dict
1829        Mapping from model_name to dict of the following structure:
1830        {'symbol': url, 'params': url}
1831
1832    Returns
1833    -------
1834    Two element tuple containing model_name and epoch for the params saved
1835    """
1836    _base_model_url = 'http://data.mxnet.io/models/'
1837    _default_model_info = {
1838        'imagenet1k-inception-bn': {'symbol':_base_model_url+'imagenet/inception-bn/Inception-BN-symbol.json',
1839                                    'params':_base_model_url+'imagenet/inception-bn/Inception-BN-0126.params'},
1840        'imagenet1k-resnet-18': {'symbol':_base_model_url+'imagenet/resnet/18-layers/resnet-18-symbol.json',
1841                                 'params':_base_model_url+'imagenet/resnet/18-layers/resnet-18-0000.params'},
1842        'imagenet1k-resnet-34': {'symbol':_base_model_url+'imagenet/resnet/34-layers/resnet-34-symbol.json',
1843                                 'params':_base_model_url+'imagenet/resnet/34-layers/resnet-34-0000.params'},
1844        'imagenet1k-resnet-50': {'symbol':_base_model_url+'imagenet/resnet/50-layers/resnet-50-symbol.json',
1845                                 'params':_base_model_url+'imagenet/resnet/50-layers/resnet-50-0000.params'},
1846        'imagenet1k-resnet-101': {'symbol':_base_model_url+'imagenet/resnet/101-layers/resnet-101-symbol.json',
1847                                  'params':_base_model_url+'imagenet/resnet/101-layers/resnet-101-0000.params'},
1848        'imagenet1k-resnet-152': {'symbol':_base_model_url+'imagenet/resnet/152-layers/resnet-152-symbol.json',
1849                                  'params':_base_model_url+'imagenet/resnet/152-layers/resnet-152-0000.params'},
1850        'imagenet1k-resnext-50': {'symbol':_base_model_url+'imagenet/resnext/50-layers/resnext-50-symbol.json',
1851                                  'params':_base_model_url+'imagenet/resnext/50-layers/resnext-50-0000.params'},
1852        'imagenet1k-resnext-101': {'symbol':_base_model_url+'imagenet/resnext/101-layers/resnext-101-symbol.json',
1853                                   'params':_base_model_url+'imagenet/resnext/101-layers/resnext-101-0000.params'},
1854        'imagenet1k-resnext-101-64x4d':
1855            {'symbol':_base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-symbol.json',
1856             'params':_base_model_url+'imagenet/resnext/101-layers/resnext-101-64x4d-0000.params'},
1857        'imagenet11k-resnet-152':
1858            {'symbol':_base_model_url+'imagenet-11k/resnet-152/resnet-152-symbol.json',
1859             'params':_base_model_url+'imagenet-11k/resnet-152/resnet-152-0000.params'},
1860        'imagenet11k-place365ch-resnet-152':
1861            {'symbol':_base_model_url+'imagenet-11k-place365-ch/resnet-152-symbol.json',
1862             'params':_base_model_url+'imagenet-11k-place365-ch/resnet-152-0000.params'},
1863        'imagenet11k-place365ch-resnet-50':
1864            {'symbol':_base_model_url+'imagenet-11k-place365-ch/resnet-50-symbol.json',
1865             'params':_base_model_url+'imagenet-11k-place365-ch/resnet-50-0000.params'},
1866    }
1867
1868
1869    if meta_info is None:
1870        meta_info = _default_model_info
1871    meta_info = dict(meta_info)
1872    if model_name not in meta_info:
1873        return (None, 0)
1874    if not os.path.isdir(dst_dir):
1875        os.mkdir(dst_dir)
1876    meta = dict(meta_info[model_name])
1877    assert 'symbol' in meta, "missing symbol url"
1878    model_name = os.path.join(dst_dir, model_name)
1879    mx.test_utils.download(meta['symbol'], model_name+'-symbol.json')
1880    assert 'params' in meta, "mssing parameter file url"
1881    mx.test_utils.download(meta['params'], model_name+'-0000.params')
1882    return (model_name, 0)
1883
1884
1885def get_mnist():
1886    """Download and load the MNIST dataset
1887
1888    Returns
1889    -------
1890    dict
1891        A dict containing the data
1892    """
1893    def read_data(label_url, image_url):
1894        with gzip.open(mx.test_utils.download(label_url)) as flbl:
1895            struct.unpack(">II", flbl.read(8))
1896            label = np.frombuffer(flbl.read(), dtype=np.int8)
1897        with gzip.open(mx.test_utils.download(image_url), 'rb') as fimg:
1898            _, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
1899            image = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
1900            image = image.reshape(image.shape[0], 1, 28, 28).astype(np.float32)/255
1901        return (label, image)
1902
1903    # changed to mxnet.io for more stable hosting
1904    # path = 'http://yann.lecun.com/exdb/mnist/'
1905    path = 'http://data.mxnet.io/data/mnist/'
1906    (train_lbl, train_img) = read_data(
1907        path+'train-labels-idx1-ubyte.gz', path+'train-images-idx3-ubyte.gz')
1908    (test_lbl, test_img) = read_data(
1909        path+'t10k-labels-idx1-ubyte.gz', path+'t10k-images-idx3-ubyte.gz')
1910    return {'train_data':train_img, 'train_label':train_lbl,
1911            'test_data':test_img, 'test_label':test_lbl}
1912
1913def get_mnist_pkl():
1914    """Downloads MNIST dataset as a pkl.gz into a directory in the current directory
1915    with the name `data`
1916    """
1917    if not os.path.isdir("data"):
1918        os.makedirs('data')
1919    if not os.path.exists('data/mnist.pkl.gz'):
1920        download('http://deeplearning.net/data/mnist/mnist.pkl.gz',
1921                 dirname='data')
1922
1923def get_mnist_ubyte():
1924    """Downloads ubyte version of the MNIST dataset into a directory in the current directory
1925    with the name `data` and extracts all files in the zip archive to this directory.
1926    """
1927    if not os.path.isdir("data"):
1928        os.makedirs('data')
1929    if (not os.path.exists('data/train-images-idx3-ubyte')) or \
1930            (not os.path.exists('data/train-labels-idx1-ubyte')) or \
1931            (not os.path.exists('data/t10k-images-idx3-ubyte')) or \
1932            (not os.path.exists('data/t10k-labels-idx1-ubyte')):
1933        zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
1934                                 dirname='data')
1935        with zipfile.ZipFile(zip_file_path) as zf:
1936            zf.extractall('data')
1937
1938def get_cifar10():
1939    """Downloads CIFAR10 dataset into a directory in the current directory with the name `data`,
1940    and then extracts all files into the directory `data/cifar`.
1941    """
1942    if not os.path.isdir("data"):
1943        os.makedirs('data')
1944    if (not os.path.exists('data/cifar/train.rec')) or \
1945            (not os.path.exists('data/cifar/test.rec')) or \
1946            (not os.path.exists('data/cifar/train.lst')) or \
1947            (not os.path.exists('data/cifar/test.lst')):
1948        zip_file_path = download('http://data.mxnet.io/mxnet/data/cifar10.zip',
1949                                 dirname='data')
1950        with zipfile.ZipFile(zip_file_path) as zf:
1951            zf.extractall('data')
1952
1953def get_mnist_iterator(batch_size, input_shape, num_parts=1, part_index=0):
1954    """Returns training and validation iterators for MNIST dataset
1955    """
1956
1957    get_mnist_ubyte()
1958    flat = len(input_shape) != 3
1959
1960    train_dataiter = mx.io.MNISTIter(
1961        image="data/train-images-idx3-ubyte",
1962        label="data/train-labels-idx1-ubyte",
1963        input_shape=input_shape,
1964        batch_size=batch_size,
1965        shuffle=True,
1966        flat=flat,
1967        num_parts=num_parts,
1968        part_index=part_index)
1969
1970    val_dataiter = mx.io.MNISTIter(
1971        image="data/t10k-images-idx3-ubyte",
1972        label="data/t10k-labels-idx1-ubyte",
1973        input_shape=input_shape,
1974        batch_size=batch_size,
1975        flat=flat,
1976        num_parts=num_parts,
1977        part_index=part_index)
1978
1979    return (train_dataiter, val_dataiter)
1980
1981def get_zip_data(data_dir, url, data_origin_name):
1982    """Download and extract zip data.
1983
1984    Parameters
1985    ----------
1986
1987    data_dir : str
1988        Absolute or relative path of the directory name to store zip files
1989    url : str
1990        URL to download data from
1991    data_origin_name : str
1992        Name of the downloaded zip file
1993
1994    Examples
1995    --------
1996    >>> get_zip_data("data_dir",
1997                     "http://files.grouplens.org/datasets/movielens/ml-10m.zip",
1998                     "ml-10m.zip")
1999    """
2000    data_origin_name = os.path.join(data_dir, data_origin_name)
2001    if not os.path.exists(data_origin_name):
2002        download(url, dirname=data_dir, overwrite=False)
2003        zip_file = zipfile.ZipFile(data_origin_name)
2004        zip_file.extractall(path=data_dir)
2005
2006def get_bz2_data(data_dir, data_name, url, data_origin_name):
2007    """Download and extract bz2 data.
2008
2009    Parameters
2010    ----------
2011
2012    data_dir : str
2013        Absolute or relative path of the directory name to store bz2 files
2014    data_name : str
2015        Name of the output file in which bz2 contents will be extracted
2016    url : str
2017        URL to download data from
2018    data_origin_name : str
2019        Name of the downloaded b2 file
2020
2021    Examples
2022    --------
2023    >>> get_bz2_data("data_dir", "kdda.t",
2024                     "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2",
2025                     "kdda.t.bz2")
2026    """
2027
2028    data_name = os.path.join(data_dir, data_name)
2029    data_origin_name = os.path.join(data_dir, data_origin_name)
2030    if not os.path.exists(data_name):
2031        download(url, fname=data_origin_name, dirname=data_dir, overwrite=False)
2032        bz_file = bz2.BZ2File(data_origin_name, 'rb')
2033        with open(data_name, 'wb') as fout:
2034            for line in bz_file:
2035                fout.write(line)
2036            bz_file.close()
2037        os.remove(data_origin_name)
2038
2039
2040def same_array(array1, array2):
2041    """Check whether two NDArrays sharing the same memory block
2042
2043    Parameters
2044    ----------
2045
2046    array1 : NDArray
2047        First NDArray to be checked
2048    array2 : NDArray
2049        Second NDArray to be checked
2050
2051    Returns
2052    -------
2053    bool
2054        Whether two NDArrays share the same memory
2055    """
2056    array1[:] += 1
2057    if not same(array1.asnumpy(), array2.asnumpy()):
2058        array1[:] -= 1
2059        return False
2060    array1[:] -= 1
2061    return same(array1.asnumpy(), array2.asnumpy())
2062
2063
2064@contextmanager
2065def discard_stderr():
2066    """
2067    Discards error output of a routine if invoked as:
2068
2069    with discard_stderr():
2070        ...
2071    """
2072    with open(os.devnull, 'w') as bit_bucket:
2073        try:
2074            stderr_fileno = sys.stderr.fileno()
2075            old_stderr = os.dup(stderr_fileno)
2076            try:
2077                os.dup2(bit_bucket.fileno(), stderr_fileno)
2078                yield
2079            finally:
2080                os.dup2(old_stderr, stderr_fileno)
2081        except AttributeError:
2082            # On some systems is stderr not a file descriptor but actually a virtual pipeline
2083            # that can not be copied
2084            yield
2085
2086
2087class DummyIter(mx.io.DataIter):
2088    """A dummy iterator that always returns the same batch of data
2089    (the first data batch of the real data iter). This is usually used for speed testing.
2090
2091    Parameters
2092    ----------
2093    real_iter: mx.io.DataIter
2094        The real data iterator where the first batch of data comes from
2095    """
2096    def __init__(self, real_iter):
2097        super(DummyIter, self).__init__()
2098        self.real_iter = real_iter
2099        self.provide_data = real_iter.provide_data
2100        self.provide_label = real_iter.provide_label
2101        self.batch_size = real_iter.batch_size
2102        self.the_batch = next(real_iter)
2103
2104    def __iter__(self):
2105        return self
2106
2107    def next(self):
2108        """Get a data batch from iterator. The first data batch of real iter is always returned.
2109        StopIteration will never be raised.
2110
2111        Returns
2112        -------
2113        DataBatch
2114            The data of next batch.
2115        """
2116        return self.the_batch
2117
2118def gen_buckets_probs_with_ppf(ppf, nbuckets):
2119    """Generate the buckets and probabilities for chi_square test when the ppf (Quantile function)
2120     is specified.
2121
2122    Parameters
2123    ----------
2124    ppf : function
2125        The Quantile function that takes a probability and maps it back to a value.
2126        It's the inverse of the cdf function
2127    nbuckets : int
2128        size of the buckets
2129
2130    Returns
2131    -------
2132    buckets : list of tuple
2133        The generated buckets
2134    probs : list
2135        The generate probabilities
2136    """
2137    assert nbuckets > 0
2138    probs = [1.0 / nbuckets for _ in range(nbuckets)]
2139    buckets = [(ppf(i / float(nbuckets)), ppf((i + 1) / float(nbuckets))) for i in range(nbuckets)]
2140    return buckets, probs
2141
2142def mean_check(generator, mu, sigma, nsamples=1000000):
2143    """Test the generator by matching the mean.
2144
2145    We test the sample mean by checking if it falls inside the range
2146        (mu - 3 * sigma / sqrt(n), mu + 3 * sigma / sqrt(n))
2147
2148    References::
2149
2150        @incollection{goucher2009beautiful,
2151              title={Beautiful Testing: Leading Professionals Reveal How They Improve Software},
2152              author={Goucher, Adam and Riley, Tim},
2153              year={2009},
2154              chapter=10
2155        }
2156
2157    Examples::
2158
2159        generator = lambda x: np.random.normal(0, 1.0, size=x)
2160        mean_check_ret = mean_check(generator, 0, 1.0)
2161
2162    Parameters
2163    ----------
2164    generator : function
2165        The generator function. It's expected to generate N i.i.d samples by calling generator(N).
2166    mu : float
2167    sigma : float
2168    nsamples : int
2169
2170    Returns
2171    -------
2172    ret : bool
2173        Whether the mean test succeeds
2174    """
2175    samples = np.array(generator(nsamples))
2176    sample_mean = samples.mean()
2177    ret = (sample_mean > mu - 3 * sigma / np.sqrt(nsamples)) and\
2178          (sample_mean < mu + 3 * sigma / np.sqrt(nsamples))
2179    return ret
2180
2181def get_im2rec_path(home_env="MXNET_HOME"):
2182    """Get path to the im2rec.py tool
2183
2184    Parameters
2185    ----------
2186
2187    home_env : str
2188        Env variable that holds the path to the MXNET folder
2189
2190    Returns
2191    -------
2192    str
2193        The path to im2rec.py
2194    """
2195    # Check first if the path to MXNET is passed as an env variable
2196    if home_env in os.environ:
2197        mxnet_path = os.environ[home_env]
2198    else:
2199        # Else use currently imported mxnet as reference
2200        mxnet_path = os.path.dirname(mx.__file__)
2201    # If MXNet was installed through pip, the location of im2rec.py
2202    im2rec_path = os.path.join(mxnet_path, 'tools', 'im2rec.py')
2203    if os.path.isfile(im2rec_path):
2204        return im2rec_path
2205    # If MXNet has been built locally
2206    im2rec_path = os.path.join(mxnet_path, '..', '..', 'tools', 'im2rec.py')
2207    if os.path.isfile(im2rec_path):
2208        return im2rec_path
2209    raise IOError('Could not find path to tools/im2rec.py')
2210
2211def var_check(generator, sigma, nsamples=1000000):
2212    """Test the generator by matching the variance.
2213    It will need a large number of samples and is not recommended to use
2214
2215    We test the sample variance by checking if it falls inside the range
2216        (sigma^2 - 3 * sqrt(2 * sigma^4 / (n-1)), sigma^2 + 3 * sqrt(2 * sigma^4 / (n-1)))
2217
2218    References::
2219
2220        @incollection{goucher2009beautiful,
2221              title={Beautiful Testing: Leading Professionals Reveal How They Improve Software},
2222              author={Goucher, Adam and Riley, Tim},
2223              year={2009},
2224              chapter=10
2225        }
2226
2227    Examples::
2228
2229        generator = lambda x: np.random.normal(0, 1.0, size=x)
2230        var_check_ret = var_check(generator, 0, 1.0)
2231
2232    Parameters
2233    ----------
2234    generator : function
2235        The generator function. It's expected to generate N i.i.d samples by calling generator(N).
2236    sigma : float
2237    nsamples : int
2238
2239    Returns
2240    -------
2241    ret : bool
2242        Whether the variance test succeeds
2243    """
2244    samples = np.array(generator(nsamples))
2245    sample_var = samples.var(ddof=1)
2246    ret = (sample_var > sigma ** 2 - 3 * np.sqrt(2 * sigma ** 4 / (nsamples - 1))) and\
2247          (sample_var < sigma ** 2 + 3 * np.sqrt(2 * sigma ** 4 / (nsamples - 1)))
2248    return ret
2249
2250def chi_square_check(generator, buckets, probs, nsamples=1000000):
2251    """Run the chi-square test for the generator. The generator can be both continuous and discrete.
2252
2253    If the generator is continuous, the buckets should contain tuples of (range_min, range_max) \
2254    and the probs should be the corresponding ideal probability within the specific ranges. \
2255    Otherwise, the buckets should contain all the possible values generated over the discrete distribution and the \
2256    probs should be groud-truth probability.
2257
2258    Usually the user is required to specify the probs parameter.
2259
2260    After obtaining the p value, we could further use the standard p > 0.05 (alpha) threshold to get \
2261    the final result.
2262
2263    Examples::
2264
2265      buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.norm.ppf(x, 0, 1), 5)
2266      generator = lambda x: np.random.normal(0, 1.0, size=x)
2267      p = chi_square_check(generator=generator, buckets=buckets, probs=probs)
2268      assert(p > 0.05)
2269
2270    Parameters
2271    ----------
2272    generator: function
2273        A function that is assumed to generate i.i.d samples from a specific distribution.
2274        generator(N) should generate N random samples.
2275    buckets: list of tuple or list of number
2276        The buckets to run the chi-square the test. Make sure that the buckets cover
2277        the whole range of the distribution. Also, the buckets must be in ascending order and have
2278        no intersection
2279    probs: list or tuple
2280        The ground-truth probability of the random value fall in a specific bucket.
2281    nsamples:int
2282        The number of samples to generate for the testing
2283
2284    Returns
2285    -------
2286    p : float
2287        p value that the generator has the expected distribution.
2288        A higher value indicates a larger confidence
2289    obs_freq : list
2290        Observed frequency of buckets
2291    expected_freq : list
2292        The expected (ground-truth) frequency of the buckets
2293    """
2294    if not ss:
2295        raise ImportError("scipy is not available."
2296                          " Please check if the scipy python bindings are installed.")
2297    assert isinstance(buckets, list)
2298    samples = generator(nsamples)
2299    assert len(probs) == len(buckets)
2300    if isinstance(buckets[0], (list, tuple)):
2301        # Check whether the buckets are valid and fill them into a npy array
2302        continuous_dist = True
2303        buckets_npy = np.zeros((len(buckets) * 2, ), dtype=np.float32)
2304        for i, _ in enumerate(buckets):
2305            assert(buckets[i][0] <= buckets[i][1])
2306            if i < len(buckets) - 1:
2307                assert(buckets[i][1] <= buckets[i + 1][0])
2308            buckets_npy[i * 2] = buckets[i][0]
2309            buckets_npy[i * 2 + 1] = buckets[i][1]
2310    else:
2311        continuous_dist = False
2312    expected_freq = (nsamples * np.array(probs, dtype=np.float32)).astype(np.int32)
2313    if continuous_dist:
2314        sample_bucket_ids = np.searchsorted(buckets_npy, samples, side='right')
2315    else:
2316        sample_bucket_ids = np.array(samples)
2317    if continuous_dist:
2318        sample_bucket_ids = sample_bucket_ids // 2
2319    obs_freq = np.zeros(shape=len(buckets), dtype=np.int)
2320    for i, _ in enumerate(buckets):
2321        if continuous_dist:
2322            obs_freq[i] = (sample_bucket_ids == i).sum()
2323        else:
2324            obs_freq[i] = (sample_bucket_ids == buckets[i]).sum()
2325    _, p = ss.chisquare(f_obs=obs_freq, f_exp=expected_freq)
2326    return p, obs_freq, expected_freq
2327
2328def verify_generator(generator, buckets, probs, nsamples=1000000, nrepeat=5, success_rate=0.2, alpha=0.05):
2329    """Verify whether the generator is correct using chi-square testing.
2330
2331    The test is repeated for "nrepeat" times and we check if the success rate is
2332     above the threshold (25% by default).
2333
2334    Parameters
2335    ----------
2336    generator: function
2337        A function that is assumed to generate i.i.d samples from a specific distribution.
2338            generator(N) should generate N random samples.
2339    buckets: list of tuple or list of number
2340        The buckets to run the chi-square the test. Make sure that the buckets cover
2341         the whole range of the distribution. Also, the buckets must be in ascending order and
2342         have no intersection
2343    probs: list or tuple
2344        The ground-truth probability of the random value fall in a specific bucket.
2345    nsamples: int
2346        The number of samples to generate for the testing
2347    nrepeat: int
2348        The times to repeat the test
2349    success_rate: float
2350        The desired success rate
2351    alpha: float
2352        The desired threshold for type-I error i.e. when a true null hypothesis is rejected
2353
2354    Returns
2355    -------
2356    cs_ret_l: list
2357        The p values of the chi-square test.
2358    """
2359    cs_ret_l = []
2360    obs_freq_l = []
2361    expected_freq_l = []
2362    for _ in range(nrepeat):
2363        cs_ret, obs_freq, expected_freq = chi_square_check(generator=generator, buckets=buckets,
2364                                                           probs=probs, nsamples=nsamples)
2365        cs_ret_l.append(cs_ret)
2366        obs_freq_l.append(obs_freq)
2367        expected_freq_l.append(expected_freq)
2368    success_num = (np.array(cs_ret_l) > alpha).sum()
2369    if success_num < nrepeat * success_rate:
2370        raise AssertionError("Generator test fails, Chi-square p=%s, obs_freq=%s, expected_freq=%s."
2371                             "\nbuckets=%s, probs=%s"
2372                             % (str(cs_ret_l), str(obs_freq_l), str(expected_freq_l),
2373                                str(buckets), str(probs)))
2374    return cs_ret_l
2375
2376def compare_ndarray_tuple(t1, t2, rtol=None, atol=None):
2377    """Compare ndarray tuple."""
2378    if t1 is None or t2 is None:
2379        return
2380
2381    if isinstance(t1, tuple):
2382        for s1, s2 in zip(t1, t2):
2383            compare_ndarray_tuple(s1, s2, rtol, atol)
2384    else:
2385        assert_almost_equal(t1, t2, rtol=rtol, atol=atol)
2386
2387
2388def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default',
2389                      rtol=1e-4, atol=1e-5, compare_states=True, ntensors=1):
2390    """Compare opt1 and opt2."""
2391    if not isinstance(shape, list):
2392        assert(ntensors == 1)
2393        if w_stype == 'default':
2394            w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
2395            w1 = w2.copyto(default_context())
2396        elif w_stype in ('row_sparse', 'csr'):
2397            w2 = rand_ndarray(shape, w_stype, density=1, dtype=dtype)
2398            w1 = w2.copyto(default_context()).tostype('default')
2399        else:
2400            raise Exception("type not supported yet")
2401        if g_stype == 'default':
2402            g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype)
2403            g1 = g2.copyto(default_context())
2404        elif g_stype in ('row_sparse', 'csr'):
2405            g2 = rand_ndarray(shape, g_stype, dtype=dtype)
2406            g1 = g2.copyto(default_context()).tostype('default')
2407        else:
2408            raise Exception("type not supported yet")
2409
2410        state1 = opt1.create_state_multi_precision(0, w1)
2411        state2 = opt2.create_state_multi_precision(0, w2)
2412        if compare_states:
2413            compare_ndarray_tuple(state1, state2)
2414
2415        opt1.update_multi_precision(0, w1, g1, state1)
2416        opt2.update_multi_precision(0, w2, g2, state2)
2417        if compare_states:
2418            compare_ndarray_tuple(state1, state2, rtol=rtol, atol=atol)
2419        assert_almost_equal(w1, w2, rtol=rtol, atol=atol)
2420    else:
2421        # test multi-tensor: Opt1 single-tensor reference, Opt2 multi-tensor
2422        from copy import deepcopy
2423        w1, g1 = [], []
2424        for s in shape:
2425            w1.append(mx.random.uniform(shape=s, ctx=default_context(), dtype=dtype))
2426            g1.append(mx.random.uniform(shape=s, ctx=default_context(), dtype=dtype))
2427        w1 = tuple(w1)
2428        w2 = deepcopy(w1)
2429        g1 = tuple(g1)
2430        g2 = deepcopy(g1)
2431        state2 = [opt2.create_state_multi_precision(0, w2[i]) for i in range(ntensors)]
2432
2433        opt2.update_multi_precision(list(range(ntensors)), w2, g2, state2)
2434        for i in range(ntensors):
2435            state1 = opt1.create_state_multi_precision(i, w1[i])
2436            opt1.update_multi_precision(i, w1[i], g1[i], state1)
2437            if compare_states:
2438                compare_ndarray_tuple(state1, state2[i], rtol, atol)
2439            compare_ndarray_tuple(w1[i], w2[i], rtol, atol)
2440
2441def same_symbol_structure(sym1, sym2):
2442    """Compare two symbols to check if they have the same computation graph structure.
2443    Returns true if operator corresponding to a particular node id is same in both
2444    symbols for all nodes
2445    """
2446    conf = json.loads(sym1.tojson())
2447    nodes = conf["nodes"]
2448    conf2 = json.loads(sym2.tojson())
2449    nodes2 = conf2["nodes"]
2450    for node1, node2 in zip(nodes, nodes2):
2451        if node1["op"] != node2["op"]:
2452            return False
2453    return True
2454
2455
2456@contextmanager
2457def environment(*args):
2458    """
2459    Environment variable setter and unsetter via `with` idiom.
2460
2461    Takes a specification of env var names and desired values and adds those
2462    settings to the environment in advance of running the body of the `with`
2463    statement.  The original environment state is restored afterwards, even
2464    if exceptions are raised in the `with` body.
2465
2466    Parameters
2467    ----------
2468    args:
2469        if 2 args are passed:
2470            name, desired_value strings of the single env var to update, or
2471        if 1 arg is passed:
2472            a dict of name:desired_value for env var's to update
2473
2474    """
2475
2476    # On Linux, env var changes made through python's os.environ are seen
2477    # by the backend.  On Windows though, the C runtime gets a snapshot
2478    # of the environment that cannot be altered by os.environ.  Here we
2479    # check, using a wrapped version of the backend's getenv(), that
2480    # the desired env var value is seen by the backend, and otherwise use
2481    # a wrapped setenv() to establish that value in the backend.
2482
2483    # Also on Windows, a set env var can never have the value '', since
2484    # the command 'set FOO= ' is used to unset the variable.  Perhaps
2485    # as a result, the wrapped dmlc::GetEnv() routine returns the same
2486    # value for unset variables and those set to ''.  As a result, we
2487    # ignore discrepancy.
2488    def validate_backend_setting(name, value, can_use_setenv=True):
2489        backend_value = getenv(name)
2490        if value == backend_value or \
2491           value == '' and backend_value is None and platform.system() == 'Windows':
2492            return
2493        if not can_use_setenv:
2494            raise RuntimeError('Could not set env var {}={} within C Runtime'.format(name, value))
2495        setenv(name, value)
2496        validate_backend_setting(name, value, can_use_setenv=False)
2497
2498    # Core routine to alter environment from a dict of env_var_name, env_var_value pairs
2499    def set_environ(env_var_dict):
2500        for env_var_name, env_var_value in env_var_dict.items():
2501            if env_var_value is None:
2502                os.environ.pop(env_var_name, None)
2503            else:
2504                os.environ[env_var_name] = env_var_value
2505            validate_backend_setting(env_var_name, env_var_value)
2506
2507    # Create env_var name:value dict from the two calling methods of this routine
2508    if len(args) == 1 and isinstance(args[0], dict):
2509        env_vars = args[0]
2510    else:
2511        assert len(args) == 2, 'Expecting one dict arg or two args: env var name and value'
2512        env_vars = {args[0]: args[1]}
2513
2514    # Take a snapshot of the existing environment variable state
2515    # for those variables to be changed.  get() return None for unset keys.
2516    snapshot = {x: os.environ.get(x) for x in env_vars.keys()}
2517
2518    # Alter the environment per the env_vars dict
2519    set_environ(env_vars)
2520
2521    # Now run the wrapped code
2522    try:
2523        yield
2524    finally:
2525        # the backend engines may still be referencing the changed env var state
2526        mx.nd.waitall()
2527        # reinstate original env_var state per the snapshot taken earlier
2528        set_environ(snapshot)
2529
2530
2531def collapse_sum_like(a, shape):
2532    """Given `a` as a numpy ndarray, perform reduce_sum on `a` over the axes that do not
2533    exist in `shape`. Note that an ndarray with `shape` must be broadcastable to `a`.
2534    """
2535    assert len(a.shape) >= len(shape)
2536    if np.prod(shape) == 0 or a.size == 0:
2537        return np.zeros(shape, dtype=a.dtype)
2538    axes = []
2539    ndim_diff = len(a.shape) - len(shape)
2540    for i in range(ndim_diff):
2541        axes.append(i)
2542    for i, s in enumerate(shape):
2543        if s != a.shape[i+ndim_diff]:
2544            assert s == 1
2545            axes.append(i+ndim_diff)
2546    return np.sum(a, axis=tuple(axes)).reshape(shape)
2547
2548
2549def is_cd_run():
2550    """Checks if the test is running as part of a Continuous Delivery run"""
2551    return os.environ.get("CD_JOB", 0) == "1"
2552
2553
2554def is_aarch64_run():
2555    """Checks if the test is running on aarch64 instance"""
2556    return platform.machine() == "aarch64"
2557
2558
2559_features = Features()
2560
2561
2562def has_tvm_ops():
2563    """Returns True if MXNet is compiled with TVM generated operators. If current ctx
2564    is GPU, it only returns True for CUDA compute capability > 52 where FP16 is supported.
2565    """
2566    built_with_tvm_op = _features.is_enabled("TVM_OP")
2567    ctx = current_context()
2568    if ctx.device_type == 'gpu':
2569        try:
2570            cc = get_cuda_compute_capability(ctx)
2571        except:  # pylint: disable=bare-except
2572            print('Failed to get CUDA compute capability for context {}. The operators '
2573                  'built with USE_TVM_OP=1 will not be run in unit tests.'.format(ctx))
2574            return False
2575        print('Cuda arch compute capability: sm_{}'.format(str(cc)))
2576        return built_with_tvm_op and cc >= 53
2577    return built_with_tvm_op
2578
2579
2580def is_op_runnable():
2581    """Returns True for all CPU tests. Returns True for GPU tests that are either of the following.
2582    1. Built with USE_TVM_OP=0.
2583    2. Built with USE_TVM_OP=1, but with compute capability >= 53.
2584    """
2585    ctx = current_context()
2586    if ctx.device_type == 'gpu':
2587        if not _features.is_enabled("TVM_OP"):
2588            return True
2589        else:
2590            try:
2591                cc = get_cuda_compute_capability(ctx)
2592            except:  # pylint: disable=bare-except
2593                print('Failed to get CUDA compute capability for context {}. The operators '
2594                      'built with USE_TVM_OP=1 will not be run in unit tests.'.format(ctx))
2595                return False
2596            print('Cuda arch compute capability: sm_{}'.format(str(cc)))
2597            return cc >= 53
2598    return True
2599
2600
2601@use_np
2602def check_gluon_hybridize_consistency(net_builder, data_l, numpy_func=None, test_grad=True,
2603                                      rtol=1E-4, atol=1E-4):
2604    """Check whether a HybridBlock has consistent output between the hybridized
2605     v.s. non-hybridized versions
2606
2607    The network should not contain any random number generators.
2608
2609    Parameters
2610    ----------
2611    net_builder : function
2612        The builder of the HybridBlock that we are going to check the consistency.
2613        Inside the implementation, we will call net_builder() to construct the hybrid block.
2614        Also, the net_builder will need to support specifying the params
2615    data_l : list of mx.np.ndarray
2616        List of input ndarrays.
2617    numpy_func : function, optional
2618        The ground truth numpy function that has the same functionality as net_builder().
2619        Default None.
2620    test_grad : bool, optional
2621        Whether to test the consistency of the gradient. Default True.
2622    rtol : float, optional
2623        The relative error tolerance, default 1E-4. Default 1E-4.
2624    atol : float, optional
2625        The absolute error tolerance, default 1E-4. Default 1E-4.
2626    """
2627    class _NumpyParamDictInit(mx.init.Initializer):
2628        """Initializes parameters with the cached numpy ndarrays dictionary
2629        """
2630        def __init__(self, np_params):
2631            super(_NumpyParamDictInit, self).__init__()
2632            self._np_params = np_params
2633
2634        def _init_weight(self, name, arr):
2635            arr[()] = self._np_params[name]
2636    saved_out_np = None
2637    saved_grad_np_l = None
2638    params_init = None
2639    use_autograd_flags = [False, True] if test_grad else [False]
2640    for hybridize in [False, True]:
2641        for use_autograd in use_autograd_flags:
2642            net = net_builder(prefix='net_')
2643            if params_init is None:
2644                net.initialize()
2645            else:
2646                net.initialize(params_init)
2647            if hybridize:
2648                net.hybridize()
2649            in_data_l = [ele.copy() for ele in data_l]
2650            if use_autograd:
2651                for ele in in_data_l:
2652                    ele.attach_grad()
2653                with mx.autograd.record():
2654                    out = net(*in_data_l)
2655                out.backward(out)
2656            else:
2657                out = net(*in_data_l)
2658            if params_init is None:
2659                np_params = {k: v.data().asnumpy() for k, v in net.collect_params().items()}
2660                params_init = _NumpyParamDictInit(np_params)
2661            if saved_out_np is None:
2662                saved_out_np = out.asnumpy()
2663            else:
2664                # Check for correctness
2665                assert_almost_equal(out.asnumpy(), saved_out_np, rtol=rtol, atol=atol)
2666            if use_autograd:
2667                if saved_grad_np_l is None:
2668                    saved_grad_np_l = [ele.grad.asnumpy() for ele in in_data_l]
2669                else:
2670                    # Check for correctness
2671                    for data, saved_grad_np in zip(in_data_l, saved_grad_np_l):
2672                        assert_almost_equal(data.grad.asnumpy(), saved_grad_np,
2673                                            rtol=rtol, atol=atol)
2674    if numpy_func is not None:
2675        numpy_out = numpy_func(*[ele.asnumpy() for ele in data_l])
2676        assert_almost_equal(saved_out_np, numpy_out, rtol=rtol, atol=atol)
2677
2678
2679def new_matrix_with_real_eigvals_2d(n):
2680    """Generate a well-conditioned matrix with small real eigenvalues."""
2681    shape = (n, n)
2682    q = np.ones(shape)
2683    while 1:
2684        D = np.diag(np.random.uniform(-1.0, 1.0, shape[-1]))
2685        I = np.eye(shape[-1]).reshape(shape)
2686        v = np.random.uniform(-1., 1., shape[-1]).reshape(shape[:-1] + (1,))
2687        v = v / np.linalg.norm(v, axis=-2, keepdims=True)
2688        v_T = np.swapaxes(v, -1, -2)
2689        U = I - 2 * np.matmul(v, v_T)
2690        q = np.matmul(U, D)
2691        if (np.linalg.cond(q, 2) < 3):
2692            break
2693    D = np.diag(np.random.uniform(-10.0, 10.0, n))
2694    q_inv = np.linalg.inv(q)
2695    return np.matmul(np.matmul(q_inv, D), q)
2696
2697
2698def new_matrix_with_real_eigvals_nd(shape):
2699    """Generate well-conditioned matrices with small real eigenvalues."""
2700    n = int(np.prod(shape[:-2])) if len(shape) > 2 else 1
2701    return np.array([new_matrix_with_real_eigvals_2d(shape[-1]) for i in range(n)]).reshape(shape)
2702
2703
2704def new_orthonormal_matrix_2d(n):
2705    """Generate a orthonormal matrix."""
2706    x = np.random.randn(n, n)
2707    x_trans = x.T
2708    sym_mat = np.matmul(x_trans, x)
2709    return np.linalg.qr(sym_mat)[0]
2710
2711
2712def new_sym_matrix_with_real_eigvals_2d(n):
2713    """Generate a sym matrix with real eigenvalues."""
2714    q = new_orthonormal_matrix_2d(n)
2715    D = np.diag(np.random.uniform(-10.0, 10.0, n))
2716    return np.matmul(np.matmul(q.T, D), q)
2717
2718
2719def new_sym_matrix_with_real_eigvals_nd(shape):
2720    """Generate sym matrices with real eigenvalues."""
2721    n = int(np.prod(shape[:-2])) if len(shape) > 2 else 1
2722    return np.array([new_sym_matrix_with_real_eigvals_2d(shape[-1]) for i in range(n)]).reshape(shape)
2723