1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# coding: utf-8
19"""Utility functions for NDArray and BaseSparseNDArray."""
20import ctypes
21
22from ..base import _LIB, check_call, py_str, c_str, string_types, mx_uint, NDArrayHandle
23from ..base import c_array, c_handle_array, c_str_array
24from .ndarray import NDArray
25from .ndarray import array as _array
26from .ndarray import empty as _empty_ndarray
27from .ndarray import zeros as _zeros_ndarray
28from .sparse import zeros as _zeros_sparse_ndarray
29from .sparse import empty as _empty_sparse_ndarray
30from .sparse import array as _sparse_array
31from .sparse import _ndarray_cls
32try:
33    import scipy.sparse as spsp
34except ImportError:
35    spsp = None
36
# Public API re-exported by this module.
__all__ = [
    'zeros',
    'empty',
    'array',
    'load',
    'load_frombuffer',
    'save',
]
38
39
def zeros(shape, ctx=None, dtype=None, stype=None, **kwargs):
    """Return a new array of given shape and type, filled with zeros.

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the array to create.
    ctx : Context, optional
        An optional device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        An optional value type (default is `float32`).
    stype : str, optional
        The storage type of the array, such as 'row_sparse', 'csr', etc.
        ``None`` or ``'default'`` produces a dense NDArray.
    **kwargs
        Extra keyword arguments forwarded unchanged to the underlying dense
        or sparse ``zeros`` constructor.

    Returns
    -------
    NDArray, CSRNDArray or RowSparseNDArray
        A created array.

    Examples
    --------
    >>> mx.nd.zeros((1,2), mx.cpu(), stype='csr')
    <CSRNDArray 1x2 @cpu(0)>
    >>> mx.nd.zeros((1,2), mx.cpu(), 'float16', stype='row_sparse').asnumpy()
    array([[ 0.,  0.]], dtype=float16)
    """
    wants_dense = stype is None or stype == 'default'
    if not wants_dense:
        # The sparse constructor takes the storage type as its first argument.
        return _zeros_sparse_ndarray(stype, shape, ctx, dtype, **kwargs)
    return _zeros_ndarray(shape, ctx, dtype, **kwargs)
70
71
def empty(shape, ctx=None, dtype=None, stype=None):
    """Returns a new array of given shape and type, without initializing entries.

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the array to create.
    ctx : Context, optional
        An optional device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        An optional value type (default is `float32`).
    stype : str, optional
        An optional storage type (default is `default`). ``None`` or
        ``'default'`` produces a dense NDArray; any other value is handed to
        the sparse ``empty`` constructor.

    Returns
    -------
    NDArray, CSRNDArray or RowSparseNDArray
        A created array.

    Examples
    --------
    >>> mx.nd.empty(1)
    <NDArray 1 @cpu(0)>
    >>> mx.nd.empty((1,2), mx.gpu(0))
    <NDArray 1x2 @gpu(0)>
    >>> mx.nd.empty((1,2), mx.gpu(0), 'float16')
    <NDArray 1x2 @gpu(0)>
    >>> mx.nd.empty((1,2), stype='csr')
    <CSRNDArray 1x2 @cpu(0)>
    """
    wants_dense = stype is None or stype == 'default'
    if not wants_dense:
        # The sparse constructor takes the storage type as its first argument.
        return _empty_sparse_ndarray(stype, shape, ctx, dtype)
    return _empty_ndarray(shape, ctx, dtype)
106
107
def array(source_array, ctx=None, dtype=None):
    """Creates an array from any object exposing the array interface.

    A ``scipy.sparse.csr_matrix`` (when SciPy is installed) or a non-dense
    ``NDArray`` is routed to the sparse constructor; everything else goes to
    the dense constructor.

    Parameters
    ----------
    source_array : array_like
        An object exposing the array interface, an object whose `__array__`
        method returns an array, or any (nested) sequence.
    ctx : Context, optional
        Device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        The data type of the output array. The default dtype is ``source_array.dtype``
        if `source_array` is an `NDArray`, `float32` otherwise.

    Returns
    -------
    NDArray, RowSparseNDArray or CSRNDArray
        An array with the same contents as the `source_array`.

    Examples
    --------
    >>> import numpy as np
    >>> mx.nd.array([1, 2, 3])
    <NDArray 3 @cpu(0)>
    >>> mx.nd.array([[1, 2], [3, 4]])
    <NDArray 2x2 @cpu(0)>
    >>> mx.nd.array(np.zeros((3, 2)))
    <NDArray 3x2 @cpu(0)>
    >>> mx.nd.array(np.zeros((3, 2)), mx.gpu(0))
    <NDArray 3x2 @gpu(0)>
    >>> mx.nd.array(mx.nd.zeros((3, 2), stype='row_sparse'))
    <RowSparseNDArray 3x2 @cpu(0)>
    """
    # Use the public ``scipy.sparse.csr_matrix`` name rather than
    # ``scipy.sparse.csr.csr_matrix``: the ``scipy.sparse.csr`` submodule is a
    # deprecated namespace in recent SciPy releases. Both names refer to the
    # same class.
    is_sparse_source = (
        (spsp is not None and isinstance(source_array, spsp.csr_matrix))
        or (isinstance(source_array, NDArray) and source_array.stype != 'default')
    )
    if is_sparse_source:
        return _sparse_array(source_array, ctx=ctx, dtype=dtype)
    return _array(source_array, ctx=ctx, dtype=dtype)
147
148
def load(fname):
    """Loads an array from file.

    See more details in ``save``.

    Parameters
    ----------
    fname : str
        The filename.

    Returns
    -------
    list of NDArray, RowSparseNDArray or CSRNDArray, or \
    dict of str to NDArray, RowSparseNDArray or CSRNDArray
        Loaded data. A list is returned when the file carries no names,
        a dict keyed by the stored names otherwise.

    Raises
    ------
    TypeError
        If `fname` is not a string.
    """
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')
    num_arrays = mx_uint()
    num_names = mx_uint()
    handle_ptr = ctypes.POINTER(NDArrayHandle)()
    name_ptr = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoad(c_str(fname),
                                  ctypes.byref(num_arrays),
                                  ctypes.byref(handle_ptr),
                                  ctypes.byref(num_names),
                                  ctypes.byref(name_ptr)))
    count = num_arrays.value
    if num_names.value == 0:
        # Unnamed payload: preserve the on-disk order as a list.
        return [_ndarray_cls(NDArrayHandle(handle_ptr[idx])) for idx in range(count)]
    # Named payload: the library must report one name per array.
    assert num_names.value == count
    return {py_str(name_ptr[idx]): _ndarray_cls(NDArrayHandle(handle_ptr[idx]))
            for idx in range(count)}
183
184
def load_frombuffer(buf):
    """Loads an array dictionary or list from a buffer.

    See more details in ``save``.

    Parameters
    ----------
    buf : str, bytes or bytearray
        Buffer containing contents of a file as a string or bytes. A text
        string is interpreted as one character per byte (latin-1), so every
        character must have a code point below 256.

    Returns
    -------
    list of NDArray, RowSparseNDArray or CSRNDArray, or \
    dict of str to NDArray, RowSparseNDArray or CSRNDArray
        Loaded data. A list is returned when the buffer carries no names,
        a dict keyed by the stored names otherwise.

    Raises
    ------
    TypeError
        If `buf` is not a string, bytes or bytearray.
    UnicodeEncodeError
        If a text `buf` contains a character that does not fit in one byte.
    """
    if isinstance(buf, (bytes, bytearray)):
        raw = bytes(buf)
    elif isinstance(buf, string_types):
        # ctypes marshals a Python 3 ``str`` as ``wchar_t*`` and ``len`` counts
        # characters rather than bytes, so the C side would read the wrong
        # memory. Map each character to exactly one byte instead.
        raw = buf.encode('latin-1')
    else:
        raise TypeError('buf required to be a string or bytes')
    out_size = mx_uint()
    out_name_size = mx_uint()
    handles = ctypes.POINTER(NDArrayHandle)()
    names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoadFromBuffer(raw,
                                            mx_uint(len(raw)),
                                            ctypes.byref(out_size),
                                            ctypes.byref(handles),
                                            ctypes.byref(out_name_size),
                                            ctypes.byref(names)))
    if out_name_size.value == 0:
        return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)]
    # Named payload: the library must report one name per array.
    assert out_name_size.value == out_size.value
    return {py_str(names[i]): _ndarray_cls(NDArrayHandle(handles[i]))
            for i in range(out_size.value)}
220
221
def save(fname, data):
    """Saves a list of arrays or a dict of str->array to file.

    Examples of filenames:

    - ``/path/to/file``
    - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports)
    - ``hdfs://path/to/file`` (if compiled with HDFS supports)

    Parameters
    ----------
    fname : str
        The filename.
    data : NDArray, RowSparseNDArray or CSRNDArray, \
           or list of NDArray, RowSparseNDArray or CSRNDArray, \
           or dict of str to NDArray, RowSparseNDArray or CSRNDArray
        The data to save. A single array is saved as a one-element list.

    Raises
    ------
    TypeError
        If `fname` is not a string, if a list/dict contains non-NDArray
        values or non-string keys, or if any value is an
        ``mxnet.numpy.ndarray`` (use ``mxnet.numpy.save`` for those).
    ValueError
        If `data` is not an NDArray, list or dict.

    Examples
    --------
    >>> x = mx.nd.zeros((2,3))
    >>> y = mx.nd.ones((1,4))
    >>> mx.nd.save('my_list', [x,y])
    >>> mx.nd.save('my_dict', {'x':x, 'y':y})
    >>> mx.nd.load('my_list')
    [<NDArray 2x3 @cpu(0)>, <NDArray 1x4 @cpu(0)>]
    >>> mx.nd.load('my_dict')
    {'y': <NDArray 1x4 @cpu(0)>, 'x': <NDArray 2x3 @cpu(0)>}
    """
    from ..numpy import ndarray as np_ndarray
    # Validate up front, matching ``load``; otherwise a non-string fails
    # inside ``c_str`` with an unrelated AttributeError.
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')
    if isinstance(data, NDArray):
        # Wrap a single array; the list branch below builds the handles.
        data = [data]
    if isinstance(data, dict):
        str_keys = data.keys()
        nd_vals = data.values()
        if any(not isinstance(k, string_types) for k in str_keys) or \
           any(not isinstance(v, NDArray) for v in nd_vals):
            raise TypeError('save only accept dict str->NDArray or list of NDArray')
        if any(isinstance(v, np_ndarray) for v in nd_vals):
            raise TypeError('cannot save mxnet.numpy.ndarray using mxnet.ndarray.save;'
                            ' use mxnet.numpy.save instead.')
        keys = c_str_array(str_keys)
        handles = c_handle_array(nd_vals)
    elif isinstance(data, list):
        if any(not isinstance(v, NDArray) for v in data):
            raise TypeError('save only accept dict str->NDArray or list of NDArray')
        if any(isinstance(v, np_ndarray) for v in data):
            raise TypeError('cannot save mxnet.numpy.ndarray using mxnet.ndarray.save;'
                            ' use mxnet.numpy.save instead.')
        # A NULL key array tells the library the arrays are unnamed.
        keys = None
        handles = c_handle_array(data)
    else:
        raise ValueError("data needs to either be a NDArray, dict of str, NDArray pairs "
                         "or a list of NDarrays.")
    check_call(_LIB.MXNDArraySave(c_str(fname),
                                  mx_uint(len(handles)),
                                  handles,
                                  keys))
281