""" Utility functions for sparse matrix module
"""

import sys
import operator
import warnings
import numpy as np

try:
    from scipy._lib._util import prod
except ImportError:
    # Fall back to the standard-library equivalent when the private SciPy
    # helper is not available.
    from math import prod

__all__ = ['upcast', 'getdtype', 'getdata', 'isscalarlike', 'isintlike',
           'isshape', 'issequence', 'isdense', 'ismatrix', 'get_sum_dtype']

supported_dtypes = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc,
                    np.uintc, np.int_, np.uint, np.longlong, np.ulonglong,
                    np.single, np.double, np.longdouble, np.csingle,
                    np.cdouble, np.clongdouble]

# Cache shared by `upcast` and `upcast_char`: maps the argument tuple passed
# to either function to the chosen entry of `supported_dtypes`.
_upcast_memo = {}


def upcast(*args):
    """Returns the nearest supported sparse dtype for the
    combination of one or more types.

    upcast(t0, t1, ..., tn) -> T where T is a supported dtype

    Raises TypeError if no argument is given or if no entry of
    `supported_dtypes` can safely hold the promoted type.

    Examples
    --------

    >>> upcast('int32')
    <class 'numpy.int32'>
    >>> upcast('bool')
    <class 'numpy.bool_'>
    >>> upcast('int32','float32')
    <class 'numpy.float64'>
    >>> upcast('bool',complex,float)
    <class 'numpy.complex128'>

    """
    if not args:
        raise TypeError('no supported conversion for types: %r' % (args,))

    # Key the cache on the argument tuple itself rather than on hash(args):
    # distinct tuples can share a hash value and would then share an entry.
    t = _upcast_memo.get(args)
    if t is not None:
        return t

    # np.result_type replaces np.find_common_type(args, []), which was
    # deprecated in NumPy 1.25 and removed in NumPy 2.0.
    common = np.result_type(*args)

    # supported_dtypes is ordered from narrowest to widest, so the first
    # safe cast target is the nearest supported type.
    for t in supported_dtypes:
        if np.can_cast(common, t):
            _upcast_memo[args] = t
            return t

    raise TypeError('no supported conversion for types: %r' % (args,))


def upcast_char(*args):
    """Same as `upcast` but taking dtype.char as input (faster)."""
    t = _upcast_memo.get(args)
    if t is not None:
        return t
    t = upcast(*map(np.dtype, args))
    _upcast_memo[args] = t
    return t


def upcast_scalar(dtype, scalar):
    """Determine data type for binary operation between an array of
    type `dtype` and a scalar.
    """
    return (np.array([0], dtype=dtype) * scalar).dtype


def downcast_intp_index(arr):
    """
    Down-cast index array to np.intp dtype if it is of a larger dtype.

    Raise an error if the array contains a value that is too large for
    intp.
    """
    if arr.dtype.itemsize > np.dtype(np.intp).itemsize:
        if arr.size == 0:
            return arr.astype(np.intp)
        maxval = arr.max()
        minval = arr.min()
        if maxval > np.iinfo(np.intp).max or minval < np.iinfo(np.intp).min:
            raise ValueError("Cannot deal with arrays with indices larger "
                             "than the machine maximum address size "
                             "(e.g. 64-bit indices on 32-bit machine).")
        return arr.astype(np.intp)
    return arr


def to_native(A):
    """Return `A` as an ndarray whose dtype uses native byte order."""
    return np.asarray(A, dtype=A.dtype.newbyteorder('native'))


def getdtype(dtype, a=None, default=None):
    """Function used to simplify argument processing. If 'dtype' is not
    specified (is None), returns a.dtype; otherwise returns a np.dtype
    object created from the specified dtype argument. If 'dtype' and 'a'
    are both None, construct a data type out of the 'default' parameter.
    A warning is emitted if the resulting dtype is the object dtype.

    Raises TypeError if 'dtype' is None, 'a' has no ``dtype`` attribute
    and no 'default' is given.
    """
    # TODO is this really what we want?
    if dtype is None:
        try:
            newdtype = a.dtype
        except AttributeError as e:
            if default is not None:
                newdtype = np.dtype(default)
            else:
                raise TypeError("could not interpret data type") from e
    else:
        newdtype = np.dtype(dtype)
        if newdtype == np.object_:
            warnings.warn("object dtype is not supported by sparse matrices")

    return newdtype


def getdata(obj, dtype=None, copy=False):
    """
    This is a wrapper of `np.array(obj, dtype=dtype, copy=copy)`
    that will generate a warning if the result is an object array.

    A false `copy` (including None) means "copy only if needed". This is
    the NumPy 1.x meaning of ``copy=False``; NumPy 2 made
    ``np.array(..., copy=False)`` raise whenever a copy would be required.
    """
    if copy:
        data = np.array(obj, dtype=dtype, copy=True)
    else:
        data = np.asarray(obj, dtype=dtype)
    # Defer to getdtype for checking that the dtype is OK.
    # This is called for the validation only; we don't need the return value.
    getdtype(data.dtype)
    return data


def get_index_dtype(arrays=(), maxval=None, check_contents=False):
    """
    Based on input (integer) arrays `a`, determine a suitable index data
    type that can hold the data in the arrays.

    Parameters
    ----------
    arrays : tuple of array_like
        Input arrays whose types/contents to check
    maxval : float, optional
        Maximum value needed
    check_contents : bool, optional
        Whether to check the values in the arrays and not just their types.
        Default: False (check only the types)

    Returns
    -------
    dtype : dtype
        Suitable index data type (int32 or int64)

    """

    int32min = np.iinfo(np.int32).min
    int32max = np.iinfo(np.int32).max

    dtype = np.intc
    if maxval is not None:
        if maxval > int32max:
            dtype = np.int64

    # Accept a single bare array as well as a sequence of arrays.
    if isinstance(arrays, np.ndarray):
        arrays = (arrays,)

    for arr in arrays:
        arr = np.asarray(arr)
        if not np.can_cast(arr.dtype, np.int32):
            if check_contents:
                if arr.size == 0:
                    # a bigger type not needed
                    continue
                elif np.issubdtype(arr.dtype, np.integer):
                    maxval = arr.max()
                    minval = arr.min()
                    if minval >= int32min and maxval <= int32max:
                        # a bigger type not needed
                        continue

            dtype = np.int64
            break

    return dtype


def get_sum_dtype(dtype):
    """Mimic numpy's casting for np.sum"""
    if dtype.kind == 'u' and np.can_cast(dtype, np.uint):
        return np.uint
    if np.can_cast(dtype, np.int_):
        return np.int_
    return dtype


def isscalarlike(x):
    """Is x either a scalar, an array scalar, or a 0-dim array?"""
    return np.isscalar(x) or (isdense(x) and x.ndim == 0)


def isintlike(x):
    """Is x appropriate as an index into a sparse matrix? Returns True
    if it can be cast safely to a machine int.
    """
    # Fast-path check to eliminate non-scalar values. operator.index would
    # catch this case too, but the exception catching is slow.
    if np.ndim(x) != 0:
        return False
    try:
        operator.index(x)
    except (TypeError, ValueError):
        try:
            loose_int = bool(int(x) == x)
        except (TypeError, ValueError):
            return False
        if loose_int:
            warnings.warn("Inexact indices into sparse matrices are deprecated",
                          DeprecationWarning)
        return loose_int
    return True


def isshape(x, nonneg=False):
    """Is x a valid 2-tuple of dimensions?

    If nonneg, also checks that the dimensions are non-negative.
    """
    try:
        # Assume it's a tuple of matrix dimensions (M, N)
        (M, N) = x
    except Exception:
        return False
    else:
        if isintlike(M) and isintlike(N):
            if np.ndim(M) == 0 and np.ndim(N) == 0:
                if not nonneg or (M >= 0 and N >= 0):
                    return True
        return False


def issequence(t):
    """True if `t` is a list/tuple that is empty or whose first element is
    a scalar, or if `t` is a 1-D ndarray."""
    return ((isinstance(t, (list, tuple)) and
            (len(t) == 0 or np.isscalar(t[0]))) or
            (isinstance(t, np.ndarray) and (t.ndim == 1)))


def ismatrix(t):
    """True if `t` is a non-empty list/tuple whose first element satisfies
    `issequence`, or if `t` is a 2-D ndarray."""
    return ((isinstance(t, (list, tuple)) and
             len(t) > 0 and issequence(t[0])) or
            (isinstance(t, np.ndarray) and t.ndim == 2))


def isdense(x):
    """True if `x` is an instance of np.ndarray (subclasses included)."""
    return isinstance(x, np.ndarray)


def validateaxis(axis):
    """Validate an `axis` argument: None, or an integer in [-2, 1].

    Raises TypeError for tuples and non-integer types, and ValueError for
    integers outside the accepted range.
    """
    if axis is not None:
        axis_type = type(axis)

        # In NumPy, you can pass in tuples for 'axis', but they are
        # not very useful for sparse matrices given their limited
        # dimensions, so let's make it explicit that they are not
        # allowed to be passed in
        if axis_type == tuple:
            raise TypeError(("Tuples are not accepted for the 'axis' "
                             "parameter. Please pass in one of the "
                             "following: {-2, -1, 0, 1, None}."))

        # If not a tuple, check that the provided axis is actually
        # an integer and raise a TypeError similar to NumPy's
        if not np.issubdtype(np.dtype(axis_type), np.integer):
            raise TypeError("axis must be an integer, not {name}"
                            .format(name=axis_type.__name__))

        if not (-2 <= axis <= 1):
            raise ValueError("axis out of range")


def check_shape(args, current_shape=None):
    """Imitate numpy.matrix handling of shape arguments

    `args` is either a single shape (int or iterable of ints) wrapped in a
    1-tuple, or the individual dimensions. When `current_shape` is given,
    a single negative entry is inferred from the current size.
    """
    if len(args) == 0:
        raise TypeError("function missing 1 required positional argument: "
                        "'shape'")
    elif len(args) == 1:
        try:
            shape_iter = iter(args[0])
        except TypeError:
            new_shape = (operator.index(args[0]), )
        else:
            new_shape = tuple(operator.index(arg) for arg in shape_iter)
    else:
        new_shape = tuple(operator.index(arg) for arg in args)

    if current_shape is None:
        if len(new_shape) != 2:
            raise ValueError('shape must be a 2-tuple of positive integers')
        elif new_shape[0] < 0 or new_shape[1] < 0:
            raise ValueError("'shape' elements cannot be negative")

    else:
        # Check the current size only if needed
        current_size = prod(current_shape)

        # Check for negatives
        negative_indexes = [i for i, x in enumerate(new_shape) if x < 0]
        if len(negative_indexes) == 0:
            new_size = prod(new_shape)
            if new_size != current_size:
                raise ValueError('cannot reshape array of size {} into shape {}'
                                 .format(current_size, new_shape))
        elif len(negative_indexes) == 1:
            skip = negative_indexes[0]
            specified = prod(new_shape[:skip] + new_shape[skip + 1:])
            # If the known dimensions multiply to zero, the unknown one
            # cannot be determined. Report that as a ValueError instead of
            # dividing by zero.
            if specified == 0 or current_size % specified != 0:
                err_shape = tuple('newshape' if x < 0 else x for x in new_shape)
                raise ValueError('cannot reshape array of size {} into shape {}'
                                 .format(current_size, err_shape))
            unspecified = current_size // specified
            new_shape = new_shape[:skip] + (unspecified,) + new_shape[skip + 1:]
        else:
            raise ValueError('can only specify one unknown dimension')

    if len(new_shape) != 2:
        raise ValueError('matrix shape must be two-dimensional')

    return new_shape


def check_reshape_kwargs(kwargs):
    """Unpack keyword arguments for reshape function.

    This is useful because keyword arguments after star arguments are not
    allowed in Python 2, but star keyword arguments are. This function unpacks
    'order' and 'copy' from the star keyword arguments (with defaults) and
    throws an error for any remaining.
    """

    order = kwargs.pop('order', 'C')
    copy = kwargs.pop('copy', False)
    if kwargs:  # Some unused kwargs remain
        raise TypeError('reshape() got unexpected keywords arguments: {}'
                        .format(', '.join(kwargs.keys())))
    return order, copy


def is_pydata_spmatrix(m):
    """
    Check whether object is pydata/sparse matrix, avoiding importing the module.
    """
    base_cls = getattr(sys.modules.get('sparse'), 'SparseArray', None)
    return base_cls is not None and isinstance(m, base_cls)


###############################################################################
# Wrappers for NumPy types that are deprecated

# Numpy versions of these functions raise deprecation warnings, the
# ones below do not.


def matrix(*args, **kwargs):
    """Build an array with ``np.array(*args, **kwargs)`` and view it as
    np.matrix."""
    return np.array(*args, **kwargs).view(np.matrix)


def asmatrix(data, dtype=None):
    """Return `data` unchanged if it is already an np.matrix of the requested
    dtype (or no dtype was requested); otherwise convert with np.asarray and
    view the result as np.matrix."""
    if isinstance(data, np.matrix) and (dtype is None or data.dtype == dtype):
        return data
    return np.asarray(data, dtype=dtype).view(np.matrix)