1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# coding: utf-8
19# pylint: disable=wildcard-import, unused-wildcard-import, too-many-lines
20"""Sparse NDArray API of MXNet."""
21
22try:
23    from __builtin__ import slice as py_slice
24    from __builtin__ import sum as py_sum
25except ImportError:
26    from builtins import slice as py_slice
27    from builtins import sum as py_sum
28
29import ctypes
30import warnings
31import operator
32from array import array as native_array
33
# Public API of this module, i.e. the names exported by a wildcard import.
__all__ = [
    "_ndarray_cls",
    "csr_matrix",
    "row_sparse_array",
    "BaseSparseNDArray",
    "CSRNDArray",
    "RowSparseNDArray",
    "add",
    "subtract",
    "multiply",
    "divide",
]
37
38import numpy as np
39from ..base import NotSupportedForSparseNDArray
40from ..base import _LIB, numeric_types
41from ..base import c_array_buf, mx_real_t, integer_types
42from ..base import NDArrayHandle, check_call
43from ..context import Context, current_context
44from . import _internal
45from . import op
46try:
47    from .gen_sparse import retain as gs_retain # pylint: disable=redefined-builtin
48except ImportError:
49    gs_retain = None
50from ._internal import _set_ndarray_class
51from .ndarray import NDArray, _storage_type, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
52from .ndarray import _STORAGE_TYPE_STR_TO_ID, _STORAGE_TYPE_ROW_SPARSE, _STORAGE_TYPE_CSR, _int64_enabled
53from .ndarray import _STORAGE_TYPE_UNDEFINED, _STORAGE_TYPE_DEFAULT
54from .ndarray import zeros as _zeros_ndarray
55from .ndarray import array as _array
56from .ndarray import _ufunc_helper
57
58
59try:
60    import scipy.sparse as spsp
61except ImportError:
62    spsp = None
63
# dtypes of the auxiliary arrays kept by each sparse storage type, in aux index
# order: row_sparse keeps one (indices), csr keeps two (indptr, then indices).
_STORAGE_AUX_TYPES = dict(
    row_sparse=[np.int64],
    csr=[np.int64] * 2,
)
68
69
def _new_alloc_handle(stype, shape, ctx, delay_alloc, dtype, aux_types, aux_shapes=None):
    """Return a new handle with specified storage type, shape, dtype and context.

    Empty handle is only used to hold results

    Parameters
    ----------
    stype : str
        Storage type key, looked up in ``_STORAGE_TYPE_STR_TO_ID``.
    shape : tuple of int
        Shape of the array.
    ctx : Context
        Device context; ``device_typeid`` and ``device_id`` are read from it.
    delay_alloc : bool
        Passed to the backend as an int flag.
    dtype : str or numpy.dtype
        Data type of the array values.
    aux_types : sequence of dtype-like
        One dtype per auxiliary array; every entry must be int64.
    aux_shapes : sequence of tuple, optional
        Shape of each auxiliary array; defaults to ``(0,)`` for every aux array.

    Returns
    -------
    handle
        A new empty ndarray handle

    Raises
    ------
    NotImplementedError
        If any aux type is not int64.
    """
    handle = NDArrayHandle()
    int64_dtype = np.dtype("int64")
    if any(np.dtype(aux_dtype) != int64_dtype for aux_dtype in aux_types):
        raise NotImplementedError("only int64 is supported for aux types")
    aux_type_ids = [int(_DTYPE_NP_TO_MX[np.dtype(aux_dtype).type]) for aux_dtype in aux_types]
    if aux_shapes is None:
        aux_shapes = [(0,)] * len(aux_types)
    aux_ndims = [len(aux_shape) for aux_shape in aux_shapes]
    # The C API takes all aux shapes concatenated into one flat buffer.
    flat_aux_shapes = py_sum(aux_shapes, ())
    num_aux = ctypes.c_uint(len(aux_types))

    # The 64-bit entry point takes int64 dims and int ndims; the legacy one
    # takes uint32 for both.
    if _int64_enabled():
        create_sparse = _LIB.MXNDArrayCreateSparseEx64
        dim_ctype, dim_code = ctypes.c_int64, 'q'
        ndim_ctype, ndim_code = ctypes.c_int, 'i'
    else:
        create_sparse = _LIB.MXNDArrayCreateSparseEx
        dim_ctype, dim_code = ctypes.c_uint, 'I'
        ndim_ctype, ndim_code = ctypes.c_uint, 'I'

    check_call(create_sparse(
        ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[stype])),
        c_array_buf(dim_ctype, native_array(dim_code, shape)),
        ndim_ctype(len(shape)),
        ctypes.c_int(ctx.device_typeid),
        ctypes.c_int(ctx.device_id),
        ctypes.c_int(int(delay_alloc)),
        ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
        num_aux,
        c_array_buf(ctypes.c_int, native_array('i', aux_type_ids)),
        c_array_buf(ndim_ctype, native_array(ndim_code, aux_ndims)),
        c_array_buf(dim_ctype, native_array(dim_code, flat_aux_shapes)),
        ctypes.byref(handle)))
    return handle
118
119
class BaseSparseNDArray(NDArray):
    """The base class of an NDArray stored in a sparse storage format.

    See CSRNDArray and RowSparseNDArray for more details.
    """

    def __repr__(self):
        """Returns a string representation of the sparse array."""
        shape_info = 'x'.join(['%d' % x for x in self.shape])
        # The data content is not displayed since the array usually has big shape
        return '\n<%s %s @%s>' % (self.__class__.__name__,
                                  shape_info, self.context)

    def __add__(self, other):
        """x.__add__(y) <=> x+y, using this module's ``add``."""
        return add(self, other)

    def __sub__(self, other):
        """x.__sub__(y) <=> x-y, using this module's ``subtract``."""
        return subtract(self, other)

    def __mul__(self, other):
        """x.__mul__(y) <=> x*y, using this module's ``multiply``."""
        return multiply(self, other)

    def __div__(self, other):
        """x.__div__(y) <=> x/y on Python 2, using this module's ``divide``."""
        return divide(self, other)

    def __truediv__(self, other):
        """x.__truediv__(y) <=> x/y on Python 3, using this module's ``divide``.

        Python 3 never calls ``__div__``. Without this override ``/`` resolves to
        whatever ``NDArray`` provides, so sparse arrays would dispatch ``/``
        differently on Python 2 and Python 3.
        """
        return divide(self, other)

    # In-place arithmetic is not supported at this level.
    def __iadd__(self, other):
        raise NotImplementedError()

    def __isub__(self, other):
        raise NotImplementedError()

    def __imul__(self, other):
        raise NotImplementedError()

    def __idiv__(self, other):
        raise NotImplementedError()

    def __itruediv__(self, other):
        raise NotImplementedError()

    def _sync_copyfrom(self, source_array):
        raise NotImplementedError()

    def _at(self, idx):
        raise NotSupportedForSparseNDArray(self._at, '[idx]', idx)

    def _slice(self, start, stop):
        raise NotSupportedForSparseNDArray(self._slice, None, start, stop)

    def reshape(self, *shape, **kwargs):
        raise NotSupportedForSparseNDArray(self.reshape, None, shape)

    @property
    def size(self):
        # the `size` for a sparse ndarray is ambiguous, hence disabled.
        raise NotImplementedError()

    def _aux_type(self, i):
        """Data-type of the array's ith aux data.

        Parameters
        ----------
        i : int
            Index of the auxiliary array.

        Returns
        -------
        numpy.dtype
            This BaseSparseNDArray's aux data type.
        """
        aux_type = ctypes.c_int()
        check_call(_LIB.MXNDArrayGetAuxType(self.handle, i, ctypes.byref(aux_type)))
        return _DTYPE_MX_TO_NP[aux_type.value]

    @property
    def _num_aux(self):
        """The number of aux data used to help store the sparse ndarray.
        """
        return len(_STORAGE_AUX_TYPES[self.stype])

    @property
    def _aux_types(self):
        """The data types of the aux data for the BaseSparseNDArray.
        """
        return [self._aux_type(i) for i in range(self._num_aux)]

    def asnumpy(self):
        """Return a dense ``numpy.ndarray`` object with value copied from this array
        """
        return self.tostype('default').asnumpy()

    def astype(self, dtype, copy=True):
        """Return a copy of the array after casting to a specified type.

        Parameters
        ----------
        dtype : numpy.dtype or str
            The type of the returned array.
        copy : bool
            Default `True`. By default, astype always returns a newly
            allocated ndarray on the same context. If this is set to
            `False`, and the dtype requested is the same as the ndarray's
            dtype, the ndarray is returned instead of a copy.

        Examples
        --------
        >>> x = mx.nd.sparse.zeros('row_sparse', (2,3), dtype='float32')
        >>> y = x.astype('int32')
        >>> y.dtype
        <type 'numpy.int32'>
        """
        if not copy and np.dtype(dtype) == self.dtype:
            return self

        res = zeros(shape=self.shape, ctx=self.context,
                    dtype=dtype, stype=self.stype)
        self.copyto(res)
        return res

    def copyto(self, other):
        """Copies the value of this array to another array.

        Parameters
        ----------
        other : NDArray or CSRNDArray or RowSparseNDArray or Context
            The destination array or context.

        Returns
        -------
        NDArray or CSRNDArray or RowSparseNDArray
            The copied array. If ``other`` is this very array, a RuntimeWarning
            is issued and ``False`` is returned instead.

        Raises
        ------
        TypeError
            If ``other`` is neither an NDArray nor a Context.
        """
        # pylint: disable= no-member, protected-access
        if isinstance(other, NDArray):
            if other.handle is self.handle:
                warnings.warn('You are attempting to copy an array to itself', RuntimeWarning)
                return False
            return _internal._copyto(self, out=other)
        elif isinstance(other, Context):
            # Allocate a same-stype/dtype/aux-type destination on the target context.
            hret = _ndarray_cls(_new_alloc_handle(self.stype, self.shape, other,
                                                  True, self.dtype, self._aux_types))
            return _internal._copyto(self, out=hret)
        else:
            raise TypeError('copyto does not support type ' + str(type(other)))
        # pylint: enable= no-member, protected-access

    def check_format(self, full_check=True):
        """Check whether the NDArray format is valid.

        Parameters
        ----------
        full_check : bool, optional
            If `True`, rigorous check, O(N) operations. Otherwise
            basic check, O(1) operations (default True).
        """
        check_call(_LIB.MXNDArraySyncCheckFormat(self.handle, ctypes.c_bool(full_check)))

    def _data(self):
        """A deep copy NDArray of the data array associated with the BaseSparseNDArray.

        This function blocks. Do not use it in performance critical code.
        """
        self.wait_to_read()
        hdl = NDArrayHandle()
        check_call(_LIB.MXNDArrayGetDataNDArray(self.handle, ctypes.byref(hdl)))
        return NDArray(hdl)

    def _aux_data(self, i):
        """ Get a deep copy NDArray of the i-th aux data array associated with the
        BaseSparseNDArray.

        This function blocks. Do not use it in performance critical code.
        """
        self.wait_to_read()
        hdl = NDArrayHandle()
        check_call(_LIB.MXNDArrayGetAuxNDArray(self.handle, i, ctypes.byref(hdl)))
        return NDArray(hdl)
297
298
299# pylint: disable=abstract-method
class CSRNDArray(BaseSparseNDArray):
    """A sparse representation of 2D NDArray in the Compressed Sparse Row format.

    A CSRNDArray represents an NDArray as three separate arrays: `data`,
    `indptr` and `indices`. It uses the CSR representation where the column indices for
    row i are stored in ``indices[indptr[i]:indptr[i+1]]`` and their corresponding values are stored
    in ``data[indptr[i]:indptr[i+1]]``.

    The column indices for a given row are expected to be sorted in ascending order.
    Duplicate column entries for the same row are not allowed.

    Example
    -------
    >>> a = mx.nd.array([[0, 1, 0], [2, 0, 0], [0, 0, 0], [0, 0, 3]])
    >>> a = a.tostype('csr')
    >>> a.data.asnumpy()
    array([ 1.,  2.,  3.], dtype=float32)
    >>> a.indices.asnumpy()
    array([1, 0, 2])
    >>> a.indptr.asnumpy()
    array([0, 1, 2, 2, 3])

    See Also
    --------
    csr_matrix: Several ways to construct a CSRNDArray
    """

    def __reduce__(self):
        return CSRNDArray, (None,), super(CSRNDArray, self).__getstate__()

    def __iadd__(self, other):
        (self + other).copyto(self)
        return self

    def __isub__(self, other):
        (self - other).copyto(self)
        return self

    def __imul__(self, other):
        (self * other).copyto(self)
        return self

    def __idiv__(self, other):
        (self / other).copyto(self)
        return self

    def __itruediv__(self, other):
        (self / other).copyto(self)
        return self

    def __getitem__(self, key):
        """x.__getitem__(i) <=> x[i]

        Returns a newly created NDArray based on the indexing key.

        Parameters
        ----------
        key : int or slice
            Indexing key. Integer keys (Python or numpy integers, negative values
            counted from the end) select a single row as a 1-row CSRNDArray.
            Slices must have no step; ``[:]`` returns this array itself.

        Examples
        --------
        >>> indptr = np.array([0, 2, 3, 6])
        >>> indices = np.array([0, 2, 2, 0, 1, 2])
        >>> data = np.array([1, 2, 3, 4, 5, 6])
        >>> a = mx.nd.sparse.csr_matrix((data, indices, indptr), shape=(3, 3))
        >>> a.asnumpy()
        array([[ 1.,  0.,  2.],
               [ 0.,  0.,  3.],
               [ 4.,  5.,  6.]], dtype=float32)
        >>> a[1:2].asnumpy()
        array([[ 0.,  0.,  3.]], dtype=float32)
        >>> a[1].asnumpy()
        array([[ 0.,  0.,  3.]], dtype=float32)
        >>> a[-1].asnumpy()
        array([[ 4.,  5.,  6.]], dtype=float32)
        """
        # pylint: disable= no-member, protected-access
        if isinstance(key, integer_types):
            begin = int(key)
            # Wrap every negative index, not only -1, so that ``end`` never
            # becomes 0 (e.g. key=-1 -> begin=-1, end=0 would be empty).
            if begin < 0:
                begin += self.shape[0]
            return op.slice(self, begin=begin, end=begin+1)
        if isinstance(key, py_slice):
            if key.step is not None:
                raise ValueError('CSRNDArray only supports continuous slicing on axis 0')
            if key.start is None and key.stop is None:
                return self
            # slice.indices handles None, negative and out-of-range bounds the
            # same way Python sequences do. Truthiness tests would treat an
            # explicit stop of 0 as "no stop".
            begin, end, _ = key.indices(self.shape[0])
            return op.slice(self, begin=begin, end=end)
        if isinstance(key, tuple):
            raise ValueError('Multi-dimension indexing is not supported')
        raise ValueError('Undefined behaviour for {}'.format(key))
        # pylint: enable= no-member, protected-access

    def __setitem__(self, key, value):
        """x.__setitem__(i, y) <=> x[i]=y

        Set self[key] to value. Only slice key [:] is supported.

        Parameters
        ----------
        key : mxnet.ndarray.NDArray.slice
            The indexing key.
        value : NDArray or CSRNDArray or numpy.ndarray
            The value to set.

        Raises
        ------
        ValueError
            If the array is read-only, the slice is not ``[:]``, or ``value``
            is a plain number.
        TypeError
            If ``key`` is not a slice or ``value`` has an unsupported type.

        Examples
        --------
        >>> src = mx.nd.sparse.zeros('csr', (3,3))
        >>> src.asnumpy()
        array([[ 0.,  0.,  0.],
               [ 0.,  0.,  0.],
               [ 0.,  0.,  0.]], dtype=float32)
        >>> # assign CSRNDArray with same storage type
        >>> x = mx.nd.ones((3,3)).tostype('csr')
        >>> x[:] = src
        >>> x.asnumpy()
        array([[ 0.,  0.,  0.],
               [ 0.,  0.,  0.],
               [ 0.,  0.,  0.]], dtype=float32)
        >>> # assign NDArray to CSRNDArray
        >>> x[:] = mx.nd.ones((3,3)) * 2
        >>> x.asnumpy()
        array([[ 2.,  2.,  2.],
               [ 2.,  2.,  2.],
               [ 2.,  2.,  2.]], dtype=float32)
        """
        if not self.writable:
            raise ValueError('Failed to assign to a readonly CSRNDArray')
        if isinstance(key, py_slice):
            if key.step is not None or key.start is not None or key.stop is not None:
                raise ValueError('Assignment with slice for CSRNDArray is not ' \
                                 'implemented yet.')
            if isinstance(value, NDArray):
                # avoid copying to itself
                if value.handle is not self.handle:
                    value.copyto(self)
            elif isinstance(value, numeric_types):
                raise ValueError("Assigning numeric types to CSRNDArray is " \
                                 "not implemented yet.")
            elif isinstance(value, (np.ndarray, np.generic)):
                # TODO(haibin/anisub) check scipy.sparse and use _sync_copy_from to
                # avoid the temporary copy
                warnings.warn('Assigning non-NDArray object to CSRNDArray is not efficient',
                              RuntimeWarning)
                tmp = _array(value)
                tmp.copyto(self)
            else:
                raise TypeError('type %s not supported' % str(type(value)))
        else:
            # Explicit raise instead of `assert`, which is stripped under -O.
            # TypeError is a subclass of the Exception previously raised here.
            raise TypeError('CSRNDArray only supports [:] for assignment')

    @property
    def indices(self):
        """A deep copy NDArray of the indices array of the CSRNDArray.
        This generates a deep copy of the column indices of the current `csr` matrix.

        Returns
        -------
        NDArray
            This CSRNDArray's indices array.
        """
        return self._aux_data(1)

    @property
    def indptr(self):
        """A deep copy NDArray of the indptr array of the CSRNDArray.
        This generates a deep copy of the `indptr` of the current `csr` matrix.

        Returns
        -------
        NDArray
            This CSRNDArray's indptr array.
        """
        return self._aux_data(0)

    @property
    def data(self):
        """A deep copy NDArray of the data array of the CSRNDArray.
        This generates a deep copy of the `data` of the current `csr` matrix.

        Returns
        -------
        NDArray
            This CSRNDArray's data array.
        """
        return self._data()

    @indices.setter
    def indices(self, indices):
        raise NotImplementedError()

    @indptr.setter
    def indptr(self, indptr):
        raise NotImplementedError()

    @data.setter
    def data(self, data):
        raise NotImplementedError()

    def tostype(self, stype):
        """Return a copy of the array with chosen storage type.

        Parameters
        ----------
        stype : str
            Target storage type. 'row_sparse' is rejected.

        Returns
        -------
        NDArray or CSRNDArray
            A copy of the array with the chosen storage stype

        Raises
        ------
        ValueError
            If ``stype`` is 'row_sparse'.
        """
        # pylint: disable= no-member, protected-access
        if stype == 'row_sparse':
            raise ValueError("cast_storage from csr to row_sparse is not supported")
        return op.cast_storage(self, stype=stype)
        # pylint: enable= no-member, protected-access

    def copyto(self, other):
        """Copies the value of this array to another array.

        If ``other`` is a ``NDArray`` or ``CSRNDArray`` object, then ``other.shape`` and
        ``self.shape`` should be the same. This function copies the value from
        ``self`` to ``other``.

        If ``other`` is a context, a new ``CSRNDArray`` will be first created on
        the target context, and the value of ``self`` is copied.

        Parameters
        ----------
        other : NDArray or CSRNDArray or Context
            The destination array or context.

        Returns
        -------
        NDArray or CSRNDArray
            The copied array. If ``other`` is an ``NDArray`` or ``CSRNDArray``, then the return
            value and ``other`` will point to the same ``NDArray`` or ``CSRNDArray``.
        """
        if isinstance(other, Context):
            return super(CSRNDArray, self).copyto(other)
        elif isinstance(other, NDArray):
            stype = other.stype
            if stype in ('default', 'csr'):
                return super(CSRNDArray, self).copyto(other)
            else:
                raise TypeError('copyto does not support destination NDArray stype ' + str(stype))
        else:
            raise TypeError('copyto does not support type ' + str(type(other)))

    def asscipy(self):
        """Returns a ``scipy.sparse.csr.csr_matrix`` object with value copied from this array

        Raises
        ------
        ImportError
            If scipy is not available.

        Examples
        --------
        >>> x = mx.nd.sparse.zeros('csr', (2,3))
        >>> y = x.asscipy()
        >>> type(y)
        <type 'scipy.sparse.csr.csr_matrix'>
        >>> y
        <2x3 sparse matrix of type '<type 'numpy.float32'>'
        with 0 stored elements in Compressed Sparse Row format>
        """
        # Fail fast: the three accessors below each block and deep-copy.
        if spsp is None:
            raise ImportError("scipy could not be imported. "
                              "Please make sure that the scipy is installed.")
        data = self.data.asnumpy()
        indices = self.indices.asnumpy()
        indptr = self.indptr.asnumpy()
        return spsp.csr_matrix((data, indices, indptr), shape=self.shape, dtype=self.dtype)
572
573# pylint: disable=abstract-method
class RowSparseNDArray(BaseSparseNDArray):
    """A sparse representation of a set of NDArray row slices at given indices.

    A RowSparseNDArray represents a multidimensional NDArray using two separate arrays: `data` and
    `indices`. The number of dimensions has to be at least 2.

    - data: an NDArray of any dtype with shape [D0, D1, ..., Dn].
    - indices: a 1-D int64 NDArray with shape [D0] with values sorted in ascending order.

    The `indices` stores the indices of the row slices with non-zeros,
    while the values are stored in `data`. The corresponding NDArray ``dense``
    represented by RowSparseNDArray ``rsp`` has

    ``dense[rsp.indices[i], :, :, :, ...] = rsp.data[i, :, :, :, ...]``

        >>> dense.asnumpy()
        array([[ 1.,  2., 3.],
               [ 0.,  0., 0.],
               [ 4.,  0., 5.],
               [ 0.,  0., 0.],
               [ 0.,  0., 0.]], dtype=float32)
        >>> rsp = dense.tostype('row_sparse')
        >>> rsp.indices.asnumpy()
        array([0, 2], dtype=int64)
        >>> rsp.data.asnumpy()
        array([[ 1.,  2., 3.],
               [ 4.,  0., 5.]], dtype=float32)

    A RowSparseNDArray is typically used to represent non-zero row slices of a large NDArray
    of shape [LARGE0, D1, .. , Dn] where LARGE0 >> D0 and most row slices are zeros.

    RowSparseNDArray is used principally in the definition of gradients for operations
    that have sparse gradients (e.g. sparse dot and sparse embedding).

    See Also
    --------
    row_sparse_array: Several ways to construct a RowSparseNDArray
    """
    def __reduce__(self):
        return RowSparseNDArray, (None,), super(RowSparseNDArray, self).__getstate__()

    def __iadd__(self, other):
        (self + other).copyto(self)
        return self

    def __isub__(self, other):
        (self - other).copyto(self)
        return self

    def __imul__(self, other):
        (self * other).copyto(self)
        return self

    def __idiv__(self, other):
        (self / other).copyto(self)
        return self

    def __itruediv__(self, other):
        (self / other).copyto(self)
        return self

    def __getitem__(self, key):
        """x.__getitem__(i) <=> x[i]

        Returns a sliced view of this array.

        Parameters
        ----------
        key : mxnet.ndarray.NDArray.slice
            Indexing key. Only ``[:]`` is supported and returns this array itself.

        Examples
        --------
        >>> x = mx.nd.sparse.zeros('row_sparse', (2, 3))
        >>> x[:].asnumpy()
        array([[ 0.,  0.,  0.],
               [ 0.,  0.,  0.]], dtype=float32)
        """
        if isinstance(key, int):
            raise Exception("__getitem__ with int key is not implemented for RowSparseNDArray yet")
        if isinstance(key, py_slice):
            if key.step is not None or key.start is not None or key.stop is not None:
                raise Exception('RowSparseNDArray only supports [:] for __getitem__')

            return self
        if isinstance(key, tuple):
            raise ValueError('Multi-dimension indexing is not supported')
        raise ValueError('Undefined behaviour for {}'.format(key))

    def __setitem__(self, key, value):
        """x.__setitem__(i, y) <=> x[i]=y

        Set self[key] to value. Only slice key [:] is supported.

        Parameters
        ----------
        key : mxnet.ndarray.NDArray.slice
            The indexing key.
        value : NDArray or numpy.ndarray or numeric
            The value to set. A number is written with ``_internal._set_value``.

        Raises
        ------
        ValueError
            If the array is read-only or the slice is not ``[:]``.
        TypeError
            If ``key`` is not a slice or ``value`` has an unsupported type.

        Examples
        --------
        >>> src = mx.nd.sparse.row_sparse_array(([[1, 0, 2], [4, 5, 6]], [0, 2]), shape=(3, 3))
        >>> src.asnumpy()
        array([[ 1.,  0.,  2.],
               [ 0.,  0.,  0.],
               [ 4.,  5.,  6.]], dtype=float32)
        >>> # assign RowSparseNDArray with same storage type
        >>> x = mx.nd.sparse.zeros('row_sparse', (3,3))
        >>> x[:] = src
        >>> x.asnumpy()
        array([[ 1.,  0.,  2.],
               [ 0.,  0.,  0.],
               [ 4.,  5.,  6.]], dtype=float32)
        >>> # assign NDArray to RowSparseNDArray
        >>> x[:] = mx.nd.ones((3,3))
        >>> x.asnumpy()
        array([[ 1.,  1.,  1.],
               [ 1.,  1.,  1.],
               [ 1.,  1.,  1.]], dtype=float32)
        """
        # pylint: disable= no-member, protected-access
        if not self.writable:
            raise ValueError('Failed to assign to a readonly RowSparseNDArray')
        if isinstance(key, py_slice):
            if key.step is not None or key.start is not None or key.stop is not None:
                raise ValueError('Assignment with slice for RowSparseNDArray ' \
                                 'is not implemented yet.')
            if isinstance(value, NDArray):
                # avoid copying to itself
                if value.handle is not self.handle:
                    value.copyto(self)
            elif isinstance(value, numeric_types):
                _internal._set_value(float(value), out=self)
            elif isinstance(value, (np.ndarray, np.generic)):
                warnings.warn('Assigning non-NDArray object to RowSparseNDArray is not efficient',
                              RuntimeWarning)
                tmp = _array(value)
                tmp.copyto(self)
            else:
                raise TypeError('type %s not supported' % str(type(value)))
        else:
            # Explicit raise instead of `assert`, which is stripped under -O.
            raise TypeError('RowSparseNDArray only supports [:] for assignment')
        # pylint: enable= no-member, protected-access

    @property
    def indices(self):
        """A deep copy NDArray of the indices array of the RowSparseNDArray.
        This generates a deep copy of the row indices of the current `row_sparse` matrix.

        Returns
        -------
        NDArray
            This RowSparseNDArray's indices array.
        """
        return self._aux_data(0)

    @property
    def data(self):
        """A deep copy NDArray of the data array of the RowSparseNDArray.
        This generates a deep copy of the `data` of the current `row_sparse` matrix.

        Returns
        -------
        NDArray
            This RowSparseNDArray's data array.
        """
        return self._data()

    @indices.setter
    def indices(self, indices):
        raise NotImplementedError()

    @data.setter
    def data(self, data):
        raise NotImplementedError()

    def tostype(self, stype):
        """Return a copy of the array with chosen storage type.

        Parameters
        ----------
        stype : str
            Target storage type. 'csr' is rejected.

        Returns
        -------
        NDArray or RowSparseNDArray
            A copy of the array with the chosen storage stype

        Raises
        ------
        ValueError
            If ``stype`` is 'csr'.
        """
        # pylint: disable= no-member, protected-access
        if stype == 'csr':
            raise ValueError("cast_storage from row_sparse to csr is not supported")
        return op.cast_storage(self, stype=stype)
        # pylint: enable= no-member, protected-access

    def copyto(self, other):
        """Copies the value of this array to another array.

        If ``other`` is a ``NDArray`` or ``RowSparseNDArray`` object, then ``other.shape``
        and ``self.shape`` should be the same. This function copies the value from
        ``self`` to ``other``.

        If ``other`` is a context, a new ``RowSparseNDArray`` will be first created on
        the target context, and the value of ``self`` is copied.

        Parameters
        ----------
        other : NDArray or RowSparseNDArray or Context
            The destination array or context.

        Returns
        -------
        NDArray or RowSparseNDArray
            The copied array. If ``other`` is an ``NDArray`` or ``RowSparseNDArray``, then the
            return value and ``other`` will point to the same ``NDArray`` or ``RowSparseNDArray``.
        """
        if isinstance(other, Context):
            return super(RowSparseNDArray, self).copyto(other)
        elif isinstance(other, NDArray):
            stype = other.stype
            if stype in ('default', 'row_sparse'):
                return super(RowSparseNDArray, self).copyto(other)
            else:
                raise TypeError('copyto does not support destination NDArray stype ' + str(stype))
        else:
            raise TypeError('copyto does not support type ' + str(type(other)))

    def retain(self, *args, **kwargs):
        """Convenience fluent method for :py:func:`retain`.

        The arguments are the same as for :py:func:`retain`, with
        this array as data.

        Raises
        ------
        ImportError
            If the generated ``gen_sparse.retain`` operator is unavailable.
        """
        if gs_retain is None:
            raise ImportError("gen_sparse could not be imported")
        # Forward this array as the data argument so that
        # ``rsp.retain(indices)`` means ``retain(rsp, indices)``.
        return gs_retain(self, *args, **kwargs)
808
def _prepare_src_array(source_array, dtype):
    """Prepare `source_array` so that it can be used to construct NDArray.

    `source_array` is returned unchanged if it is already an `NDArray` or an
    `np.ndarray`; otherwise it is converted with ``np.array(source_array, dtype=dtype)``.

    Raises
    ------
    TypeError
        If the conversion to ``np.ndarray`` fails.
    """
    if isinstance(source_array, (NDArray, np.ndarray)):
        return source_array
    try:
        return np.array(source_array, dtype=dtype)
    except Exception:  # pylint: disable=broad-except
        # Report any conversion failure (TypeError, ValueError, ...) uniformly.
        # Unlike a bare `except:`, this does not swallow KeyboardInterrupt/SystemExit.
        raise TypeError('values must be array like object')
820
def _prepare_default_dtype(src_array, dtype):
    """Prepare the value of dtype if `dtype` is None.

    Parameters
    ----------
    src_array : object
        The source the array will be built from.
    dtype : str or numpy.dtype or None
        Requested dtype; returned unchanged when not None.

    Returns
    -------
    dtype
        ``dtype`` if given. Otherwise ``src_array.dtype`` when ``src_array`` is an
        NDArray, numpy.ndarray or scipy.sparse.csr_matrix, and ``mx_real_t``
        (float32) for anything else.
    """
    if dtype is not None:
        return dtype
    if isinstance(src_array, (NDArray, np.ndarray)):
        return src_array.dtype
    # `scipy.sparse.csr` is a deprecated namespace; `scipy.sparse.csr_matrix`
    # is the same class exposed through the public API.
    if spsp and isinstance(src_array, spsp.csr_matrix):
        return src_array.dtype
    return mx_real_t
832
833def _check_shape(s1, s2):
834    """check s1 == s2 if both are not None"""
835    if s1 and s2 and s1 != s2:
836        raise ValueError("Shape mismatch detected. " + str(s1) + " v.s. " + str(s2))
837
def csr_matrix(arg1, shape=None, ctx=None, dtype=None):
    """Creates a `CSRNDArray`, a 2D array with compressed sparse row (CSR) format.

    The CSRNDArray can be instantiated in several ways:

    - csr_matrix(D):
        to construct a CSRNDArray with a dense 2D array ``D``
            -  **D** (*array_like*) - An object exposing the array interface, an object whose \
            `__array__` method returns an array, or any (nested) sequence.
            - **ctx** (*Context, optional*) - Device context \
            (default is the current default context).
            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
            The default dtype is ``D.dtype`` if ``D`` is an NDArray or numpy.ndarray, \
            float32 otherwise.

    - csr_matrix(S)
        to construct a CSRNDArray with a sparse 2D array ``S``
            -  **S** (*CSRNDArray, scipy.sparse.csr_matrix or scipy.sparse.coo_matrix*) - \
            A sparse matrix. A scipy COO matrix is converted to CSR first.
            - **ctx** (*Context, optional*) - Device context \
            (default is the current default context).
            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
            The default dtype is ``S.dtype``.

    - csr_matrix((M, N))
        to construct an empty CSRNDArray with shape ``(M, N)``
            -  **M** (*int*) - Number of rows in the matrix
            -  **N** (*int*) - Number of columns in the matrix
            - **ctx** (*Context, optional*) - Device context \
            (default is the current default context).
            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
            The default dtype is float32.

    - csr_matrix((data, indices, indptr))
        to construct a CSRNDArray based on the definition of compressed sparse row format \
        using three separate arrays, \
        where the column indices for row i are stored in ``indices[indptr[i]:indptr[i+1]]`` \
        and their corresponding values are stored in ``data[indptr[i]:indptr[i+1]]``. \
        The column indices for a given row are expected to be **sorted in ascending order.** \
        Duplicate column entries for the same row are not allowed.
            - **data** (*array_like*) - An object exposing the array interface, which \
            holds all the non-zero entries of the matrix in row-major order.
            - **indices** (*array_like*) - An object exposing the array interface, which \
            stores the column index for each non-zero element in ``data``.
            - **indptr** (*array_like*) - An object exposing the array interface, which \
            stores the offset into ``data`` of the first non-zero element number of each \
            row of the matrix.
            - **shape** (*tuple of int, optional*) - The shape of the array. The default \
            shape is inferred from the indices and indptr arrays.
            - **ctx** (*Context, optional*) - Device context \
            (default is the current default context).
            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
            The default dtype is ``data.dtype`` if ``data`` is an NDArray or numpy.ndarray, \
            float32 otherwise.

    - csr_matrix((data, (row, col)))
        to construct a CSRNDArray based on the COOrdinate format \
        using three separate arrays, \
        where ``row[i]`` is the row index of the element, \
        ``col[i]`` is the column index of the element \
        and ``data[i]`` is the data corresponding to the element. All the missing \
        elements in the input are taken to be zeroes.
            - **data** (*array_like*) - An object exposing the array interface, which \
            holds all the non-zero entries of the matrix in COO format.
            - **row** (*array_like*) - An object exposing the array interface, which \
            stores the row index for each non zero element in ``data``.
            - **col** (*array_like*) - An object exposing the array interface, which \
            stores the col index for each non zero element in ``data``.
            - **shape** (*tuple of int, optional*) - The shape of the array. The default \
            shape is inferred from the ``row`` and ``col`` arrays.
            - **ctx** (*Context, optional*) - Device context \
            (default is the current default context).
            - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
            The default dtype is float32.

    Parameters
    ----------
    arg1: tuple of int, tuple of array_like, array_like, CSRNDArray, scipy.sparse.csr_matrix \
    or scipy.sparse.coo_matrix
        The argument to help instantiate the csr matrix. See above for further details.
    shape : tuple of int, optional
        The shape of the csr matrix.
    ctx: Context, optional
        Device context (default is the current default context).
    dtype: str or numpy.dtype, optional
        The data type of the output array.

    Returns
    -------
    CSRNDArray
        A `CSRNDArray` with the `csr` storage representation.

    Raises
    ------
    ImportError
        If the ``(data, (row, col))`` form is used and scipy is not installed.
    ValueError
        For an input tuple of unexpected length, a `RowSparseNDArray` input,
        or when `shape` disagrees with the shape of the input.

    Examples
    --------
    >>> a = mx.nd.sparse.csr_matrix(([1, 2, 3], [1, 0, 2], [0, 1, 2, 2, 3]), shape=(4, 3))
    >>> a.asnumpy()
    array([[ 0.,  1.,  0.],
           [ 2.,  0.,  0.],
           [ 0.,  0.,  0.],
           [ 0.,  0.,  3.]], dtype=float32)

    See Also
    --------
    CSRNDArray : MXNet NDArray in compressed sparse row format.
    """
    # construct a csr matrix from (M, N), (data, indices, indptr) or (data, (row, col))
    if isinstance(arg1, tuple):
        arg_len = len(arg1)
        if arg_len == 2:
            if isinstance(arg1[1], tuple) and len(arg1[1]) == 2:
                # COO input: build a scipy coo matrix and convert it to csr.
                # Check for scipy first so no NDArray -> numpy copies are wasted.
                if not spsp:
                    raise ImportError("scipy could not be imported. "
                                      "Please make sure that the scipy is installed.")
                data, (row, col) = arg1
                if isinstance(data, NDArray):
                    data = data.asnumpy()
                if isinstance(row, NDArray):
                    row = row.asnumpy()
                if isinstance(col, NDArray):
                    col = col.asnumpy()
                coo = spsp.coo_matrix((data, (row, col)), shape=shape)
                _check_shape(coo.shape, shape)
                csr = coo.tocsr()
                return array(csr, ctx=ctx, dtype=dtype)
            # empty matrix with shape (M, N)
            _check_shape(arg1, shape)
            return empty('csr', arg1, ctx=ctx, dtype=dtype)
        elif arg_len == 3:
            # data, indices, indptr
            return _csr_matrix_from_definition(arg1[0], arg1[1], arg1[2], shape=shape,
                                               ctx=ctx, dtype=dtype)
        raise ValueError("Unexpected length of input tuple: " + str(arg_len))

    # construct a csr matrix from a sparse / dense one
    if spsp and isinstance(arg1, spsp.coo_matrix):
        # a scipy COO matrix (documented above as accepted) is converted to CSR
        # and then handled exactly like a scipy CSR matrix
        arg1 = arg1.tocsr()
    # ``spsp.csr_matrix`` is the public name; ``scipy.sparse.csr`` is deprecated
    if isinstance(arg1, CSRNDArray) or (spsp and isinstance(arg1, spsp.csr_matrix)):
        # construct a csr matrix from scipy or CSRNDArray
        _check_shape(arg1.shape, shape)
        return array(arg1, ctx=ctx, dtype=dtype)
    if isinstance(arg1, RowSparseNDArray):
        raise ValueError("Unexpected input type: RowSparseNDArray")
    # construct a csr matrix from a dense one
    # prepare default ctx and dtype since mx.nd.array doesn't use default values
    # based on source_array
    dtype = _prepare_default_dtype(arg1, dtype)
    # create dns array with provided dtype. ctx is not passed since copy across
    # ctx requires dtype to be the same
    dns = _array(arg1, dtype=dtype)
    if ctx is not None and dns.context != ctx:
        dns = dns.as_in_context(ctx)
    _check_shape(dns.shape, shape)
    return dns.tostype('csr')
993
def _csr_matrix_from_definition(data, indices, indptr, shape=None, ctx=None,
                                dtype=None, indices_type=None, indptr_type=None):
    """Create a `CSRNDArray` from its ``data``, ``indices`` and ``indptr`` arrays.

    Parameters
    ----------
    data, indices, indptr : NDArray, numpy.ndarray or array_like
        Non-zero values, their column indices, and the per-row offsets into ``data``.
    shape : tuple of int, optional
        Shape of the result. When None it is inferred as
        ``(len(indptr) - 1, max(indices) + 1)``.
    ctx : Context, optional
        Device context; the current context when None.
    dtype : str or numpy.dtype, optional
        Value type; resolved by `_prepare_default_dtype` when None.
    indices_type, indptr_type : optional
        Aux data types; taken from ``_STORAGE_AUX_TYPES['csr']`` when None.

    Raises
    ------
    ValueError
        If the shape cannot be inferred (no indices) or the inputs have invalid shapes.
    """
    # pylint: disable= no-member, protected-access
    stype = 'csr'
    if ctx is None:
        ctx = current_context()
    dtype = _prepare_default_dtype(data, dtype)
    if indptr_type is None:
        indptr_type = _STORAGE_AUX_TYPES[stype][0]
    if indices_type is None:
        indices_type = _STORAGE_AUX_TYPES[stype][1]

    # normalise every input to NDArray or np.ndarray of the requested type
    data = _prepare_src_array(data, dtype)
    indptr = _prepare_src_array(indptr, indptr_type)
    indices = _prepare_src_array(indices, indices_type)

    def _as_ndarray(arr, arr_type):
        """Return `arr` if it already is an NDArray, else copy it into one on `ctx`."""
        return arr if isinstance(arr, NDArray) else _array(arr, ctx, arr_type)

    # TODO(junwu): a c-api accepting np.ndarray would make these copies unnecessary
    data = _as_ndarray(data, dtype)
    indptr = _as_ndarray(indptr, indptr_type)
    indices = _as_ndarray(indices, indices_type)

    if shape is None:
        if indices.shape[0] == 0:
            raise ValueError('invalid shape')
        num_rows = len(indptr) - 1
        num_cols = op.max(indices).asscalar() + 1
        shape = (num_rows, num_cols)

    aux_shapes = [indptr.shape, indices.shape]
    shapes_ok = (data.ndim == 1 and indptr.ndim == 1 and indices.ndim == 1
                 and indptr.shape[0] != 0 and len(shape) == 2)
    if not shapes_ok:
        raise ValueError('invalid shape')

    handle = _new_alloc_handle(stype, shape, ctx, False, dtype,
                               [indptr_type, indices_type], aux_shapes)
    result = CSRNDArray(handle)
    # index -1 receives ``data``; 0 and 1 receive the aux arrays in the same
    # order as the aux types passed to _new_alloc_handle above
    for source, slot in ((data, -1), (indptr, 0), (indices, 1)):
        check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, source.handle,
                                                     ctypes.c_int(slot)))
    return result
    # pylint: enable= no-member, protected-access
1035
def row_sparse_array(arg1, shape=None, ctx=None, dtype=None):
    """Creates a `RowSparseNDArray`, a multidimensional row sparse array with a set of \
    tensor slices at given indices.

    The RowSparseNDArray can be instantiated in several ways:

    - row_sparse_array(D):
        to construct a RowSparseNDArray with a dense ndarray ``D``
        -  **D** (*array_like*) - An object exposing the array interface, an object whose \
        `__array__` method returns an array, or any (nested) sequence.
        - **ctx** (*Context, optional*) - Device context \
        (default is the current default context).
        - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
        The default dtype is ``D.dtype`` if ``D`` is an NDArray or numpy.ndarray, \
        float32 otherwise.

    - row_sparse_array(S)
        to construct a RowSparseNDArray with a sparse ndarray ``S``
        -  **S** (*RowSparseNDArray*) - A sparse ndarray.
        - **ctx** (*Context, optional*) - Device context \
        (default is the current default context).
        - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
        The default dtype is ``S.dtype``.

    - row_sparse_array((D0, D1 .. Dn))
        to construct an empty RowSparseNDArray with shape ``(D0, D1, ... Dn)``
        -  **D0, D1 .. Dn** (*int*) - The shape of the ndarray
        - **ctx** (*Context, optional*) - Device context \
        (default is the current default context).
        - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
        The default dtype is float32.

    - row_sparse_array((data, indices))
        to construct a RowSparseNDArray based on the definition of row sparse format \
        using two separate arrays, \
        where the `indices` stores the indices of the row slices with non-zeros,
        while the values are stored in `data`. The corresponding NDArray ``dense``
        represented by RowSparseNDArray ``rsp`` has \
        ``dense[rsp.indices[i], :, :, :, ...] = rsp.data[i, :, :, :, ...]``
        The row indices are expected to be **sorted in ascending order.** \
        - **data** (*array_like*) - An object exposing the array interface, which \
        holds all the non-zero row slices of the array.
        - **indices** (*array_like*) - An object exposing the array interface, which \
        stores the row index for each row slice with non-zero elements.
        - **shape** (*tuple of int, optional*) - The shape of the array. The default \
        shape is ``(indices[-1] + 1,) + data.shape[1:]``.
        - **ctx** (*Context, optional*) - Device context \
        (default is the current default context).
        - **dtype** (*str or numpy.dtype, optional*) - The data type of the output array. \
        The default dtype is ``data.dtype`` if ``data`` is an NDArray or numpy.ndarray, \
        float32 otherwise.

    Parameters
    ----------
    arg1 : NDArray, numpy.ndarray, RowSparseNDArray, tuple of int or tuple of array_like
        The argument to help instantiate the row sparse ndarray. See above for further details.
    shape : tuple of int, optional
        The shape of the row sparse ndarray. (Default value = None)
    ctx : Context, optional
        Device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        The data type of the output array. (Default value = None)

    Returns
    -------
    RowSparseNDArray
        A `RowSparseNDArray` with the `row_sparse` storage representation.

    Raises
    ------
    ValueError
        For an input tuple shorter than 2, a `CSRNDArray` input, or when `shape`
        disagrees with the shape of the input.

    Examples
    --------
    >>> a = mx.nd.sparse.row_sparse_array(([[1, 2], [3, 4]], [1, 4]), shape=(6, 2))
    >>> a.asnumpy()
    array([[ 0.,  0.],
           [ 1.,  2.],
           [ 0.,  0.],
           [ 0.,  0.],
           [ 3.,  4.],
           [ 0.,  0.]], dtype=float32)

    See Also
    --------
    RowSparseNDArray : MXNet NDArray in row sparse format.
    """
    # construct a row sparse array from (D0, D1 ..) or (data, indices)
    if isinstance(arg1, tuple):
        arg_len = len(arg1)
        if arg_len < 2:
            raise ValueError("Unexpected length of input tuple: " + str(arg_len))
        # a tuple longer than 2, or a pair of ints, is a shape; otherwise (data, indices)
        if arg_len > 2 or (isinstance(arg1[0], integer_types) and
                           isinstance(arg1[1], integer_types)):
            # empty ndarray with shape
            _check_shape(arg1, shape)
            return empty('row_sparse', arg1, ctx=ctx, dtype=dtype)
        # data, indices
        return _row_sparse_ndarray_from_definition(arg1[0], arg1[1], shape=shape,
                                                   ctx=ctx, dtype=dtype)

    # construct a row sparse ndarray from a dense / sparse array
    if isinstance(arg1, RowSparseNDArray):
        # construct a row sparse ndarray from RowSparseNDArray
        _check_shape(arg1.shape, shape)
        return array(arg1, ctx=ctx, dtype=dtype)
    if isinstance(arg1, CSRNDArray):
        raise ValueError("Unexpected input type: CSRNDArray")
    # construct a row sparse ndarray from a dense one
    # prepare default dtype since mx.nd.array doesn't use default values
    # based on source_array
    dtype = _prepare_default_dtype(arg1, dtype)
    # create dns array with provided dtype. ctx is not passed since copy across
    # ctx requires dtype to be the same
    dns = _array(arg1, dtype=dtype)
    if ctx is not None and dns.context != ctx:
        dns = dns.as_in_context(ctx)
    _check_shape(dns.shape, shape)
    return dns.tostype('row_sparse')
1157
def _row_sparse_ndarray_from_definition(data, indices, shape=None, ctx=None,
                                        dtype=None, indices_type=None):
    """Create a `RowSparseNDArray` from its row slices ``data`` and row ``indices``.

    Parameters
    ----------
    data : NDArray, numpy.ndarray or array_like
        The non-zero row slices.
    indices : NDArray, numpy.ndarray or array_like
        Row index of each slice in ``data``, expected in ascending order.
    shape : tuple of int, optional
        Shape of the result. When None it is ``(indices[-1] + 1,) + data.shape[1:]``.
    ctx : Context, optional
        Device context; the current context when None.
    dtype : str or numpy.dtype, optional
        Value type; resolved by `_prepare_default_dtype` when None.
    indices_type : optional
        Aux data type; taken from ``_STORAGE_AUX_TYPES['row_sparse']`` when None.

    Raises
    ------
    ValueError
        If the shape cannot be inferred (no indices) or the inputs have invalid shapes.
    """
    stype = 'row_sparse'
    if ctx is None:
        ctx = current_context()
    dtype = _prepare_default_dtype(data, dtype)
    if indices_type is None:
        indices_type = _STORAGE_AUX_TYPES[stype][0]

    # normalise both inputs to NDArray or np.ndarray of the requested type
    data = _prepare_src_array(data, dtype)
    indices = _prepare_src_array(indices, indices_type)

    # TODO(junwu): a c-api accepting np.ndarray would make these copies unnecessary
    if not isinstance(data, NDArray):
        data = _array(data, ctx, dtype)
    if not isinstance(indices, NDArray):
        indices = _array(indices, ctx, indices_type)

    if shape is None:
        count = indices.shape[0]
        if count == 0:
            raise ValueError('invalid shape')
        # indices are expected sorted ascending, so the last one fixes the row count
        num_rows = indices[count - 1].asscalar() + 1
        shape = (num_rows, ) + data.shape[1:]

    shape_ok = (data.ndim == len(shape) and indices.ndim == 1
                and np.prod(shape[1:]) != 0)
    if not shape_ok:
        raise ValueError("invalid shape")

    handle = _new_alloc_handle(stype, shape, ctx, False, dtype,
                               [indices_type], [indices.shape])
    result = RowSparseNDArray(handle)
    # index -1 receives the row slices, index 0 the row indices
    for source, slot in ((data, -1), (indices, 0)):
        check_call(_LIB.MXNDArraySyncCopyFromNDArray(result.handle, source.handle,
                                                     ctypes.c_int(slot)))
    return result
1192
def _ndarray_cls(handle, writable=True, stype=_STORAGE_TYPE_UNDEFINED):
    """Wrap `handle` in the NDArray class matching its storage type.

    Parameters
    ----------
    handle : NDArrayHandle
        The handle to wrap.
    writable : bool, optional
        Forwarded to the array constructor.
    stype : optional
        Storage type code. When left as ``_STORAGE_TYPE_UNDEFINED`` it is
        queried from `handle` with `_storage_type`.

    Returns
    -------
    NDArray, CSRNDArray or RowSparseNDArray

    Raises
    ------
    Exception
        If the storage type is not default, csr or row_sparse.
    """
    if stype == _STORAGE_TYPE_UNDEFINED:
        stype = _storage_type(handle)
    if stype == _STORAGE_TYPE_DEFAULT:
        return NDArray(handle, writable=writable)
    if stype == _STORAGE_TYPE_CSR:
        return CSRNDArray(handle, writable=writable)
    if stype == _STORAGE_TYPE_ROW_SPARSE:
        return RowSparseNDArray(handle, writable=writable)
    raise Exception("unknown storage type: %s" % stype)
1204
1205
# Register _ndarray_cls as the handle-wrapping factory of the ``_internal``
# module (where _set_ndarray_class is imported from). NOTE(review): presumably
# this makes operator outputs come back as CSRNDArray / RowSparseNDArray rather
# than plain NDArray; confirm in ``_internal._set_ndarray_class``.
_set_ndarray_class(_ndarray_cls)
1207
1208
def add(lhs, rhs):
    """Returns element-wise sum of the input arrays with broadcasting.

    Equivalent to ``lhs + rhs``, ``mx.nd.broadcast_add(lhs, rhs)`` and
    ``mx.nd.broadcast_plus(lhs, rhs)`` when shapes of lhs and rhs do not
    match. If lhs.shape == rhs.shape, this is equivalent to
    ``mx.nd.elemwise_add(lhs, rhs)``

    .. note::

        If the corresponding dimensions of two arrays have the same size or one of them has size 1,
        then the arrays are broadcastable to a common shape.

    Parameters
    ----------
    lhs : scalar or mxnet.ndarray.sparse.array
        First array to be added.
    rhs : scalar or mxnet.ndarray.sparse.array
        Second array to be added.
        If ``lhs.shape != rhs.shape``, they must be
        broadcastable to a common shape.

    Returns
    -------
    NDArray
        The element-wise sum of the input arrays.

    Examples
    --------
    >>> a = mx.nd.ones((2,3)).tostype('csr')
    >>> b = mx.nd.ones((2,3)).tostype('csr')
    >>> a.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> b.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> (a+b).asnumpy()
    array([[ 2.,  2.,  2.],
           [ 2.,  2.,  2.]], dtype=float32)
    >>> c = mx.nd.ones((2,3)).tostype('row_sparse')
    >>> d = mx.nd.ones((2,3)).tostype('row_sparse')
    >>> c.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> d.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> (c+d).asnumpy()
    array([[ 2.,  2.,  2.],
           [ 2.,  2.,  2.]], dtype=float32)
    """
    # pylint: disable= no-member, protected-access
    # Two arrays of identical shape use elemwise_add; everything else broadcasts.
    if isinstance(lhs, NDArray) and isinstance(rhs, NDArray) and lhs.shape == rhs.shape:
        fn_array = op.elemwise_add
    else:
        fn_array = op.broadcast_add
    return _ufunc_helper(
        lhs,
        rhs,
        fn_array,
        operator.add,
        _internal._plus_scalar,
        None)
    # pylint: enable= no-member, protected-access
1279
1280
def subtract(lhs, rhs):
    """Returns element-wise difference of the input arrays with broadcasting.

    Equivalent to ``lhs - rhs``, ``mx.nd.broadcast_sub(lhs, rhs)`` and
    ``mx.nd.broadcast_minus(lhs, rhs)`` when shapes of lhs and rhs do not
    match. If lhs.shape == rhs.shape, this is equivalent to
    ``mx.nd.elemwise_sub(lhs, rhs)``

    .. note::

        If the corresponding dimensions of two arrays have the same size or one of them has size 1,
        then the arrays are broadcastable to a common shape.

    Parameters
    ----------
    lhs : scalar or mxnet.ndarray.sparse.array
        First array to be subtracted.
    rhs : scalar or mxnet.ndarray.sparse.array
        Second array to be subtracted.
        If ``lhs.shape != rhs.shape``, they must be
        broadcastable to a common shape.

    Returns
    -------
    NDArray
        The element-wise difference of the input arrays.

    Examples
    --------
    >>> a = mx.nd.ones((2,3)).tostype('csr')
    >>> b = mx.nd.ones((2,3)).tostype('csr')
    >>> a.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> b.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> (a-b).asnumpy()
    array([[ 0.,  0.,  0.],
           [ 0.,  0.,  0.]], dtype=float32)
    >>> c = mx.nd.ones((2,3)).tostype('row_sparse')
    >>> d = mx.nd.ones((2,3)).tostype('row_sparse')
    >>> c.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> d.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> (c-d).asnumpy()
    array([[ 0.,  0.,  0.],
           [ 0.,  0.,  0.]], dtype=float32)
    """
    # pylint: disable= no-member, protected-access
    # Two arrays of identical shape use elemwise_sub; everything else broadcasts.
    if isinstance(lhs, NDArray) and isinstance(rhs, NDArray) and lhs.shape == rhs.shape:
        fn_array = op.elemwise_sub
    else:
        fn_array = op.broadcast_sub
    return _ufunc_helper(
        lhs,
        rhs,
        fn_array,
        operator.sub,
        _internal._minus_scalar,
        None)
    # pylint: enable= no-member, protected-access
1351
1352
def multiply(lhs, rhs):
    """Returns element-wise product of the input arrays with broadcasting.

    Equivalent to ``lhs * rhs`` and ``mx.nd.broadcast_mul(lhs, rhs)``
    when shapes of lhs and rhs do not match. If lhs.shape == rhs.shape,
    this is equivalent to ``mx.nd.elemwise_mul(lhs, rhs)``

    .. note::

        If the corresponding dimensions of two arrays have the same size or one of them has size 1,
        then the arrays are broadcastable to a common shape.

    Parameters
    ----------
    lhs : scalar or mxnet.ndarray.sparse.array
        First array to be multiplied.
    rhs : scalar or mxnet.ndarray.sparse.array
        Second array to be multiplied.
        If ``lhs.shape != rhs.shape``, they must be
        broadcastable to a common shape.

    Returns
    -------
    NDArray
        The element-wise multiplication of the input arrays.

    Examples
    --------
    >>> x = mx.nd.ones((2,3)).tostype('csr')
    >>> y = mx.nd.arange(2).reshape((2,1))
    >>> z = mx.nd.arange(3)
    >>> x.asnumpy()
    array([[ 1.,  1.,  1.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> y.asnumpy()
    array([[ 0.],
           [ 1.]], dtype=float32)
    >>> z.asnumpy()
    array([ 0.,  1.,  2.], dtype=float32)
    >>> (x*2).asnumpy()
    array([[ 2.,  2.,  2.],
           [ 2.,  2.,  2.]], dtype=float32)
    >>> (x*y).asnumpy()
    array([[ 0.,  0.,  0.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> mx.nd.sparse.multiply(x, y).asnumpy()
    array([[ 0.,  0.,  0.],
           [ 1.,  1.,  1.]], dtype=float32)
    >>> (x*z).asnumpy()
    array([[ 0.,  1.,  2.],
           [ 0.,  1.,  2.]], dtype=float32)
    >>> mx.nd.sparse.multiply(x, z).asnumpy()
    array([[ 0.,  1.,  2.],
           [ 0.,  1.,  2.]], dtype=float32)
    >>> z = z.reshape((1, 3))
    >>> z.asnumpy()
    array([[ 0.,  1.,  2.]], dtype=float32)
    >>> (x*z).asnumpy()
    array([[ 0.,  1.,  2.],
           [ 0.,  1.,  2.]], dtype=float32)
    >>> mx.nd.sparse.multiply(x, z).asnumpy()
    array([[ 0.,  1.,  2.],
           [ 0.,  1.,  2.]], dtype=float32)
    """
    # pylint: disable= no-member, protected-access
    # Two arrays of identical shape use elemwise_mul; everything else broadcasts.
    if isinstance(lhs, NDArray) and isinstance(rhs, NDArray) and lhs.shape == rhs.shape:
        fn_array = op.elemwise_mul
    else:
        fn_array = op.broadcast_mul
    return _ufunc_helper(
        lhs,
        rhs,
        fn_array,
        operator.mul,
        _internal._mul_scalar,
        None)
    # pylint: enable= no-member, protected-access
1435
1436
def divide(lhs, rhs):
    """Returns element-wise division of the input arrays with broadcasting.

    Equivalent to ``lhs / rhs`` and ``mx.nd.broadcast_div(lhs, rhs)``
    when shapes of lhs and rhs do not match. If lhs.shape == rhs.shape,
    this is equivalent to ``mx.nd.elemwise_div(lhs, rhs)``

    .. note::

        If the corresponding dimensions of two arrays have the same size or one of them has size 1,
        then the arrays are broadcastable to a common shape.

    Parameters
    ----------
    lhs : scalar or mxnet.ndarray.sparse.array
        First array in division (the dividend).
    rhs : scalar or mxnet.ndarray.sparse.array
        Second array in division (the divisor).
        If ``lhs.shape != rhs.shape``, they must be
        broadcastable to a common shape.

    Returns
    -------
    NDArray
        The element-wise division of the input arrays.

    Examples
    --------
    >>> x = (mx.nd.ones((2,3))*6).tostype('csr')
    >>> y = mx.nd.arange(2).reshape((2,1)) + 1
    >>> z = mx.nd.arange(3) + 1
    >>> x.asnumpy()
    array([[ 6.,  6.,  6.],
           [ 6.,  6.,  6.]], dtype=float32)
    >>> y.asnumpy()
    array([[ 1.],
           [ 2.]], dtype=float32)
    >>> z.asnumpy()
    array([ 1.,  2.,  3.], dtype=float32)
    >>> x/2
    <NDArray 2x3 @cpu(0)>
    >>> (x/3).asnumpy()
    array([[ 2.,  2.,  2.],
           [ 2.,  2.,  2.]], dtype=float32)
    >>> (x/y).asnumpy()
    array([[ 6.,  6.,  6.],
           [ 3.,  3.,  3.]], dtype=float32)
    >>> mx.nd.sparse.divide(x,y).asnumpy()
    array([[ 6.,  6.,  6.],
           [ 3.,  3.,  3.]], dtype=float32)
    >>> (x/z).asnumpy()
    array([[ 6.,  3.,  2.],
           [ 6.,  3.,  2.]], dtype=float32)
    >>> mx.nd.sparse.divide(x,z).asnumpy()
    array([[ 6.,  3.,  2.],
           [ 6.,  3.,  2.]], dtype=float32)
    >>> z = z.reshape((1,3))
    >>> z.asnumpy()
    array([[ 1.,  2.,  3.]], dtype=float32)
    >>> (x/z).asnumpy()
    array([[ 6.,  3.,  2.],
           [ 6.,  3.,  2.]], dtype=float32)
    >>> mx.nd.sparse.divide(x,z).asnumpy()
    array([[ 6.,  3.,  2.],
           [ 6.,  3.,  2.]], dtype=float32)
    """
    # pylint: disable= no-member, protected-access
    # Two arrays of identical shape use elemwise_div; everything else broadcasts.
    if isinstance(lhs, NDArray) and isinstance(rhs, NDArray) and lhs.shape == rhs.shape:
        fn_array = op.elemwise_div
    else:
        fn_array = op.broadcast_div
    return _ufunc_helper(
        lhs,
        rhs,
        fn_array,
        operator.truediv,
        _internal._div_scalar,
        None)
    # pylint: enable= no-member, protected-access
1521
1522
def zeros(stype, shape, ctx=None, dtype=None, **kwargs):
    """Return a new array of given shape and type, filled with zeros.

    Parameters
    ----------
    stype: string
        The storage type of the array, such as 'row_sparse', 'csr', etc.
        ``'default'`` delegates to the dense zeros implementation.
    shape : int or tuple of int
        The shape of the array
    ctx : Context, optional
        An optional device context (default is the current default context)
    dtype : str or numpy.dtype, optional
        An optional value type (default is `float32`)

    Returns
    -------
    RowSparseNDArray or CSRNDArray
        A created array

    Raises
    ------
    ValueError
        If `stype` is not 'default', 'row_sparse' or 'csr'.

    Examples
    --------
    >>> mx.nd.sparse.zeros('csr', (1,2))
    <CSRNDArray 1x2 @cpu(0)>
    >>> mx.nd.sparse.zeros('row_sparse', (1,2), ctx=mx.cpu(), dtype='float16').asnumpy()
    array([[ 0.,  0.]], dtype=float16)
    """
    # pylint: disable= no-member, protected-access
    if stype == 'default':
        return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs)
    if isinstance(shape, integer_types):
        # accept a bare int as documented (mirrors `empty`)
        shape = (shape, )
    if ctx is None:
        ctx = current_context()
    dtype = mx_real_t if dtype is None else dtype
    if stype not in ('row_sparse', 'csr'):
        # str() keeps this a ValueError even when stype is not a string
        raise ValueError("unknown storage type: " + str(stype))
    aux_types = _STORAGE_AUX_TYPES[stype]
    out = _ndarray_cls(_new_alloc_handle(stype, shape, ctx, True, dtype, aux_types))
    return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs)
    # pylint: enable= no-member, protected-access
1561
1562
def empty(stype, shape, ctx=None, dtype=None):
    """Returns a new array of given shape and type, without initializing entries.

    Parameters
    ----------
    stype: string
        The storage type of the empty array, such as 'row_sparse', 'csr', etc
    shape : int or tuple of int
        The shape of the empty array.
    ctx : Context, optional
        An optional device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        An optional value type (default is `float32`).

    Returns
    -------
    CSRNDArray or RowSparseNDArray
        A created array.

    Raises
    ------
    ValueError
        If `stype` is not 'csr' or 'row_sparse'.
    """
    # accept any integer flavour the module recognizes (incl. numpy ints),
    # not only the builtin int
    if isinstance(shape, integer_types):
        shape = (shape, )
    if ctx is None:
        ctx = current_context()
    if dtype is None:
        dtype = mx_real_t
    assert stype is not None
    if stype not in ('csr', 'row_sparse'):
        # ValueError subclasses Exception, so existing `except Exception`
        # handlers keep working
        raise ValueError("unknown stype : " + str(stype))
    # sparse arrays are created zero-filled; there is no separate
    # uninitialized allocation path here
    return zeros(stype, shape, ctx=ctx, dtype=dtype)
1593
1594
def array(source_array, ctx=None, dtype=None):
    """Creates a sparse array from any object exposing the array interface.

    Parameters
    ----------
    source_array : RowSparseNDArray, CSRNDArray or scipy.sparse.csr.csr_matrix
        The source sparse array
    ctx : Context, optional
        The device context of the output array. Defaults to the current default
        context, also when ``source_array`` is an NDArray living on another device.
    dtype : str or numpy.dtype, optional
        The data type of the output array. The default dtype is ``source_array.dtype``
        if `source_array` is an `NDArray`, `numpy.ndarray` or `scipy.sparse.csr.csr_matrix`, \
        `float32` otherwise.

    Returns
    -------
    RowSparseNDArray or CSRNDArray
        An array with the same contents as the `source_array`.

    Raises
    ------
    AssertionError
        If `source_array` is an NDArray with 'default' storage.
    ValueError
        If `source_array` is a numpy array/scalar or an unsupported type.

    Examples
    --------
    >>> import scipy.sparse as spsp
    >>> csr = spsp.csr_matrix((2, 100))
    >>> mx.nd.sparse.array(csr)
    <CSRNDArray 2x100 @cpu(0)>
    >>> mx.nd.sparse.array(mx.nd.sparse.zeros('csr', (3, 2)))
    <CSRNDArray 3x2 @cpu(0)>
    >>> mx.nd.sparse.array(mx.nd.sparse.zeros('row_sparse', (3, 2)))
    <RowSparseNDArray 3x2 @cpu(0)>
    """
    ctx = current_context() if ctx is None else ctx
    if isinstance(source_array, NDArray):
        assert(source_array.stype != 'default'), \
               "Please use `tostype` to create RowSparseNDArray or CSRNDArray from an NDArray"
        # prepare dtype based on source_array, if not provided
        dtype = _prepare_default_dtype(source_array, dtype)
        # if both dtype and ctx are different from source_array, we cannot copy
        # directly: cast on the source's own device first, then move to `ctx`
        if source_array.dtype != dtype and source_array.context != ctx:
            arr = empty(source_array.stype, source_array.shape, dtype=dtype,
                        ctx=source_array.context)
            arr[:] = source_array
            arr = arr.as_in_context(ctx)
        else:
            arr = empty(source_array.stype, source_array.shape, dtype=dtype, ctx=ctx)
            arr[:] = source_array
        return arr
    # NOTE(review): `spsp` is presumably the module-level scipy.sparse handle,
    # falsy when scipy is not installed - verify at its definition.
    elif spsp and isinstance(source_array, spsp.csr_matrix):
        # TODO(haibin) implement `_sync_copy_from` with scipy csr object to reduce a copy
        # preprocess scipy csr to canonical form; sorted_indices() returns a
        # copy, so the caller's matrix is not modified by sum_duplicates()
        csr = source_array.sorted_indices()
        csr.sum_duplicates()
        dtype = _prepare_default_dtype(source_array, dtype)
        return csr_matrix((csr.data, csr.indices, csr.indptr), shape=csr.shape, \
                          dtype=dtype, ctx=ctx)
    elif isinstance(source_array, (np.ndarray, np.generic)):
        raise ValueError("Please use mx.nd.array to create an NDArray with source_array of "
                         "type %s" % (type(source_array),))
    else:
        raise ValueError("Unexpected source_array type: %s" % (type(source_array),))
1654