# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Utility functions for NDArray and BaseSparseNDArray."""
import ctypes

from ..base import _LIB, check_call, py_str, c_str, string_types, mx_uint, NDArrayHandle
from ..base import c_array, c_handle_array, c_str_array
from .ndarray import NDArray
from .ndarray import array as _array
from .ndarray import empty as _empty_ndarray
from .ndarray import zeros as _zeros_ndarray
from .sparse import zeros as _zeros_sparse_ndarray
from .sparse import empty as _empty_sparse_ndarray
from .sparse import array as _sparse_array
from .sparse import _ndarray_cls
try:
    import scipy.sparse as spsp
except ImportError:
    # scipy is optional; ``array`` skips the scipy check when it is missing.
    spsp = None

__all__ = ['zeros', 'empty', 'array', 'load', 'load_frombuffer', 'save']


def zeros(shape, ctx=None, dtype=None, stype=None, **kwargs):
    """Return a new array of given shape and type, filled with zeros.

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the array
    ctx : Context, optional
        An optional device context (default is the current default context)
    dtype : str or numpy.dtype, optional
        An optional value type (default is `float32`)
    stype: string, optional
        The storage type of the array, such as 'row_sparse', 'csr', etc.
    **kwargs
        Extra keyword arguments, forwarded unchanged to the dense or sparse
        ``zeros`` implementation that is selected by `stype`.

    Returns
    -------
    NDArray, CSRNDArray or RowSparseNDArray
        A created array

    Examples
    --------
    >>> mx.nd.zeros((1,2), mx.cpu(), stype='csr')
    <CSRNDArray 1x2 @cpu(0)>
    >>> mx.nd.zeros((1,2), mx.cpu(), 'float16', stype='row_sparse').asnumpy()
    array([[ 0.,  0.]], dtype=float16)
    """
    if stype is None or stype == 'default':
        return _zeros_ndarray(shape, ctx, dtype, **kwargs)
    return _zeros_sparse_ndarray(stype, shape, ctx, dtype, **kwargs)


def empty(shape, ctx=None, dtype=None, stype=None):
    """Returns a new array of given shape and type, without initializing entries.

    Parameters
    ----------
    shape : int or tuple of int
        The shape of the empty array.
    ctx : Context, optional
        An optional device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        An optional value type (default is `float32`).
    stype : str, optional
        An optional storage type (default is `default`).

    Returns
    -------
    NDArray, CSRNDArray or RowSparseNDArray
        A created array.

    Examples
    --------
    >>> mx.nd.empty(1)
    <NDArray 1 @cpu(0)>
    >>> mx.nd.empty((1,2), mx.gpu(0))
    <NDArray 1x2 @gpu(0)>
    >>> mx.nd.empty((1,2), mx.gpu(0), 'float16')
    <NDArray 1x2 @gpu(0)>
    >>> mx.nd.empty((1,2), stype='csr')
    <CSRNDArray 1x2 @cpu(0)>
    """
    if stype is None or stype == 'default':
        return _empty_ndarray(shape, ctx, dtype)
    return _empty_sparse_ndarray(stype, shape, ctx, dtype)


def array(source_array, ctx=None, dtype=None):
    """Creates an array from any object exposing the array interface.

    A ``scipy.sparse.csr_matrix`` (when scipy is installed) and an `NDArray`
    whose ``stype`` is not ``'default'`` are handed to the sparse constructor;
    every other input goes to the dense constructor.

    Parameters
    ----------
    source_array : array_like
        An object exposing the array interface, an object whose `__array__`
        method returns an array, or any (nested) sequence.
    ctx : Context, optional
        Device context (default is the current default context).
    dtype : str or numpy.dtype, optional
        The data type of the output array. The default dtype is ``source_array.dtype``
        if `source_array` is an `NDArray`, `float32` otherwise.

    Returns
    -------
    NDArray, RowSparseNDArray or CSRNDArray
        An array with the same contents as the `source_array`.

    Examples
    --------
    >>> import numpy as np
    >>> mx.nd.array([1, 2, 3])
    <NDArray 3 @cpu(0)>
    >>> mx.nd.array([[1, 2], [3, 4]])
    <NDArray 2x2 @cpu(0)>
    >>> mx.nd.array(np.zeros((3, 2)))
    <NDArray 3x2 @cpu(0)>
    >>> mx.nd.array(np.zeros((3, 2)), mx.gpu(0))
    <NDArray 3x2 @gpu(0)>
    >>> mx.nd.array(mx.nd.zeros((3, 2), stype='row_sparse'))
    <RowSparseNDArray 3x2 @cpu(0)>
    """
    # Use the public ``scipy.sparse.csr_matrix`` name: reaching the class via
    # the ``scipy.sparse.csr`` submodule is deprecated in recent SciPy.
    is_scipy_csr = spsp is not None and isinstance(source_array, spsp.csr_matrix)
    is_sparse_ndarray = isinstance(source_array, NDArray) and source_array.stype != 'default'
    if is_scipy_csr or is_sparse_ndarray:
        return _sparse_array(source_array, ctx=ctx, dtype=dtype)
    return _array(source_array, ctx=ctx, dtype=dtype)


def _handles_to_arrays(out_size, out_name_size, handles, names):
    """Wrap the raw outputs of a backend load call into Python containers.

    Parameters
    ----------
    out_size : mx_uint
        Number of array handles written by the backend.
    out_name_size : mx_uint
        Number of names written by the backend; zero means the data was
        stored without names.
    handles : ctypes pointer to NDArrayHandle
        The array handles, indexed ``0 .. out_size.value - 1``.
    names : ctypes pointer to c_char_p
        The array names, indexed like `handles` when names are present.

    Returns
    -------
    list or dict
        A list of arrays when there are no names, otherwise a dict mapping
        each name to its array.
    """
    count = out_size.value
    if out_name_size.value == 0:
        return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(count)]
    # Named data must come with exactly one name per array.
    assert out_name_size.value == count
    return {py_str(names[i]): _ndarray_cls(NDArrayHandle(handles[i]))
            for i in range(count)}


def load(fname):
    """Loads an array from file.

    See more details in ``save``.

    Parameters
    ----------
    fname : str
        The filename.

    Returns
    -------
    list of NDArray, RowSparseNDArray or CSRNDArray, or \
    dict of str to NDArray, RowSparseNDArray or CSRNDArray
        Loaded data.

    Raises
    ------
    TypeError
        If `fname` is not a string.
    """
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')
    out_size = mx_uint()
    out_name_size = mx_uint()
    handles = ctypes.POINTER(NDArrayHandle)()
    names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoad(c_str(fname),
                                  ctypes.byref(out_size),
                                  ctypes.byref(handles),
                                  ctypes.byref(out_name_size),
                                  ctypes.byref(names)))
    return _handles_to_arrays(out_size, out_name_size, handles, names)


def load_frombuffer(buf):
    """Loads an array dictionary or list from a buffer

    See more details in ``save``.

    Parameters
    ----------
    buf : str
        Buffer containing contents of a file as a string or bytes.

    Returns
    -------
    list of NDArray, RowSparseNDArray or CSRNDArray, or \
    dict of str to NDArray, RowSparseNDArray or CSRNDArray
        Loaded data.

    Raises
    ------
    TypeError
        If `buf` is neither a string nor bytes.
    """
    if not isinstance(buf, string_types + tuple([bytes])):
        raise TypeError('buf required to be a string or bytes')
    # NOTE(review): a Python 3 ``str`` passes the check above, but ctypes may
    # marshal it differently from ``bytes`` and ``len(buf)`` then counts
    # characters rather than bytes -- confirm that callers pass ``bytes``.
    out_size = mx_uint()
    out_name_size = mx_uint()
    handles = ctypes.POINTER(NDArrayHandle)()
    names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoadFromBuffer(buf,
                                            mx_uint(len(buf)),
                                            ctypes.byref(out_size),
                                            ctypes.byref(handles),
                                            ctypes.byref(out_name_size),
                                            ctypes.byref(names)))
    return _handles_to_arrays(out_size, out_name_size, handles, names)


def save(fname, data):
    """Saves a list of arrays or a dict of str->array to file.

    Examples of filenames:

    - ``/path/to/file``
    - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports)
    - ``hdfs://path/to/file`` (if compiled with HDFS supports)

    Parameters
    ----------
    fname : str
        The filename.
    data : NDArray, RowSparseNDArray or CSRNDArray, \
           or list of NDArray, RowSparseNDArray or CSRNDArray, \
           or dict of str to NDArray, RowSparseNDArray or CSRNDArray
        The data to save.

    Raises
    ------
    TypeError
        If `fname` is not a string, if a dict key is not a string, if any
        value is not an `NDArray`, or if any value is a
        ``mxnet.numpy.ndarray`` (use ``mxnet.numpy.save`` for those).
    ValueError
        If `data` is not an `NDArray`, a dict or a list.

    Examples
    --------
    >>> x = mx.nd.zeros((2,3))
    >>> y = mx.nd.ones((1,4))
    >>> mx.nd.save('my_list', [x,y])
    >>> mx.nd.save('my_dict', {'x':x, 'y':y})
    >>> mx.nd.load('my_list')
    [<NDArray 2x3 @cpu(0)>, <NDArray 1x4 @cpu(0)>]
    >>> mx.nd.load('my_dict')
    {'y': <NDArray 1x4 @cpu(0)>, 'x': <NDArray 2x3 @cpu(0)>}
    """
    from ..numpy import ndarray as np_ndarray
    # Same up-front filename check as ``load``, so a bad `fname` is reported
    # here instead of failing while it is converted for the backend call.
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')
    bad_type_msg = 'save only accept dict str->NDArray or list of NDArray'
    np_array_msg = ('cannot save mxnet.numpy.ndarray using mxnet.ndarray.save;'
                    ' use mxnet.numpy.save instead.')
    if isinstance(data, NDArray):
        # A single array is saved as a one-element list.
        data = [data]
    handles = c_array(NDArrayHandle, [])
    if isinstance(data, dict):
        str_keys = data.keys()
        nd_vals = data.values()
        if any(not isinstance(k, string_types) for k in str_keys) or \
           any(not isinstance(v, NDArray) for v in nd_vals):
            raise TypeError(bad_type_msg)
        if any(isinstance(v, np_ndarray) for v in nd_vals):
            raise TypeError(np_array_msg)
        keys = c_str_array(str_keys)
        handles = c_handle_array(nd_vals)
    elif isinstance(data, list):
        if any(not isinstance(v, NDArray) for v in data):
            raise TypeError(bad_type_msg)
        if any(isinstance(v, np_ndarray) for v in data):
            raise TypeError(np_array_msg)
        # Lists are saved without names.
        keys = None
        handles = c_handle_array(data)
    else:
        raise ValueError("data needs to either be a NDArray, dict of str, NDArray pairs "
                         "or a list of NDarrays.")
    check_call(_LIB.MXNDArraySave(c_str(fname),
                                  mx_uint(len(handles)),
                                  handles,
                                  keys))