1""" Utility functions for sparse matrix module
2"""
3
4import sys
5import operator
6import warnings
7import numpy as np
8from scipy._lib._util import prod
9
# Public names exported by ``from ... import *``.
__all__ = [
    'upcast', 'getdtype', 'getdata', 'isscalarlike', 'isintlike',
    'isshape', 'issequence', 'isdense', 'ismatrix', 'get_sum_dtype',
]
12
# Candidate result types for `upcast`, tried in this order: the first entry
# that the promoted input type can be safely cast to is returned.
supported_dtypes = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc,
                    np.uintc, np.int_, np.uint, np.longlong, np.ulonglong, np.single, np.double,
                    np.longdouble, np.csingle, np.cdouble, np.clongdouble]

# Memo shared by `upcast` (keyed by its argument tuple) and `upcast_char`
# (keyed by its tuple of dtype character codes).
_upcast_memo = {}


def upcast(*args):
    """Returns the nearest supported sparse dtype for the
    combination of one or more types.

    upcast(t0, t1, ..., tn) -> T  where T is a supported dtype

    The arguments are promoted together with ``np.result_type`` and the
    first entry of `supported_dtypes` that the promoted type can be safely
    cast to is returned. The result is a NumPy scalar type (e.g.
    ``np.float64``), not a ``np.dtype`` instance. Results are memoized.

    Raises
    ------
    TypeError
        If no supported dtype can safely hold the promoted type
        (for example, object dtype).

    Examples
    --------

    >>> import numpy as np
    >>> np.dtype(upcast('int32'))
    dtype('int32')
    >>> np.dtype(upcast('bool'))
    dtype('bool')
    >>> np.dtype(upcast('int32','float32'))
    dtype('float64')
    >>> np.dtype(upcast('bool',complex,float))
    dtype('complex128')

    """
    # Key on the argument tuple itself, not hash(args): two different tuples
    # with colliding hashes must not share a cached result.
    t = _upcast_memo.get(args)
    if t is not None:
        return t

    # np.find_common_type was deprecated in NumPy 1.25 and removed in 2.0;
    # np.result_type performs the same promotion for dtype-like inputs.
    common = np.result_type(*args)

    for t in supported_dtypes:
        if np.can_cast(common, t):
            _upcast_memo[args] = t
            return t

    raise TypeError('no supported conversion for types: %r' % (args,))
52
53
def upcast_char(*args):
    """Like `upcast`, but each argument is a dtype character code.

    The codes (e.g. ``'i'``, ``'d'``) are converted with ``np.dtype`` before
    being passed to `upcast`. The answer is cached in `_upcast_memo` under
    the tuple of codes, so repeated calls skip the conversion entirely.
    """
    cached = _upcast_memo.get(args)
    if cached is None:
        cached = upcast(*map(np.dtype, args))
        _upcast_memo[args] = cached
    return cached
62
63
def upcast_scalar(dtype, scalar):
    """Return the dtype of ``array_of_dtype * scalar``.

    A one-element probe array of type `dtype` is multiplied by `scalar`, so
    the installed NumPy's own promotion rules decide the result dtype.
    """
    probe = np.array([0], dtype=dtype)
    product = probe * scalar
    return product.dtype
69
70
def downcast_intp_index(arr):
    """
    Cast an index array down to ``np.intp`` when its dtype is wider than intp.

    Arrays whose itemsize does not exceed that of ``np.intp`` are returned
    unchanged (the very same object). Wider arrays are converted with
    ``astype(np.intp)``; an empty wide array is converted without checking.

    Raises
    ------
    ValueError
        If a value in `arr` lies outside the range representable by intp.
    """
    if arr.dtype.itemsize <= np.dtype(np.intp).itemsize:
        return arr
    if arr.size == 0:
        return arr.astype(np.intp)

    limits = np.iinfo(np.intp)
    largest = arr.max()
    smallest = arr.min()
    if largest > limits.max or smallest < limits.min:
        raise ValueError("Cannot deal with arrays with indices larger "
                         "than the machine maximum address size "
                         "(e.g. 64-bit indices on 32-bit machine).")
    return arr.astype(np.intp)
89
90
def to_native(A):
    """Return `A` as an array whose dtype uses the native byte order.

    `A` must expose a ``.dtype`` attribute (e.g. an ndarray). The conversion
    goes through ``np.asarray``, which avoids a copy when none is needed.
    """
    native_dtype = A.dtype.newbyteorder('native')
    return np.asarray(A, dtype=native_dtype)
93
94
def getdtype(dtype, a=None, default=None):
    """Function used to simplify argument processing.

    Parameters
    ----------
    dtype : dtype-like or None
        Requested data type. When given, it is converted with ``np.dtype``
        and `a` / `default` are ignored.
    a : object, optional
        Consulted only when `dtype` is None: its ``.dtype`` attribute is
        returned as-is.
    default : dtype-like, optional
        Consulted only when `dtype` is None and `a` has no ``.dtype``
        attribute; converted with ``np.dtype``.

    Returns
    -------
    newdtype
        The resolved data type.

    Raises
    ------
    TypeError
        If `dtype` is None, `a` has no ``.dtype`` attribute and `default`
        is None.

    Warns
    -----
    UserWarning
        If an explicitly given `dtype` is the object dtype, which sparse
        matrices do not support.
    """
    # TODO is this really what we want?
    if dtype is None:
        try:
            newdtype = a.dtype
        except AttributeError as e:
            if default is not None:
                newdtype = np.dtype(default)
            else:
                raise TypeError("could not interpret data type") from e
    else:
        newdtype = np.dtype(dtype)
        if newdtype == np.object_:
            warnings.warn("object dtype is not supported by sparse matrices")

    return newdtype


def getdata(obj, dtype=None, copy=False):
    """
    Convert `obj` to an ndarray, warning if the result is an object array.

    Behaves like ``np.array(obj, dtype=dtype, copy=copy)`` with the
    historical NumPy 1.x meaning of `copy`: when `copy` is true a copy is
    always made, otherwise a copy is made only if the conversion needs one.
    """
    if copy:
        data = np.array(obj, dtype=dtype, copy=copy)
    else:
        # NumPy >= 2.0 interprets copy=False as "never copy" and raises when a
        # copy is required (e.g. for a list input); np.asarray keeps the
        # "copy only if needed" behaviour on every NumPy version.
        data = np.asarray(obj, dtype=dtype)
    # Defer to getdtype for checking that the dtype is OK.
    # This is called for the validation only; we don't need the return value.
    getdtype(data.dtype)
    return data
129
130
def get_index_dtype(arrays=(), maxval=None, check_contents=False):
    """
    Choose an index dtype (``np.intc`` or ``np.int64``) able to hold the data.

    Parameters
    ----------
    arrays : ndarray or iterable of array_like
        Arrays whose types (and optionally contents) must fit. A single
        ndarray is treated as a one-element collection.
    maxval : number, optional
        Largest value the index dtype must represent; anything above the
        int32 maximum forces ``np.int64``.
    check_contents : bool, optional
        If True, an array whose dtype cannot be safely cast to int32 is
        still accepted when it is empty, or when it is an integer array
        whose values all lie within the int32 range. Default: False (only
        the dtypes are examined).

    Returns
    -------
    dtype : type
        ``np.intc`` when everything fits in 32 bits, else ``np.int64``.

    """
    int32_info = np.iinfo(np.int32)

    if maxval is not None and maxval > int32_info.max:
        result = np.int64
    else:
        result = np.intc

    if isinstance(arrays, np.ndarray):
        arrays = (arrays,)

    for candidate in arrays:
        candidate = np.asarray(candidate)
        if np.can_cast(candidate.dtype, np.int32):
            continue
        if check_contents:
            if candidate.size == 0:
                # Nothing stored, so a wider type is not needed.
                continue
            if np.issubdtype(candidate.dtype, np.integer):
                hi = candidate.max()
                lo = candidate.min()
                if lo >= int32_info.min and hi <= int32_info.max:
                    # Values fit in int32 despite the wider dtype.
                    continue
        return np.int64

    return result
182
183
def get_sum_dtype(dtype):
    """Return the accumulator type ``np.sum`` would use for `dtype`.

    Unsigned dtypes that fit in ``np.uint`` map to ``np.uint``; any dtype
    that can be safely cast to ``np.int_`` (signed ints, bool) maps to
    ``np.int_``; every other dtype is returned unchanged.
    """
    unsigned_fits = dtype.kind == 'u' and np.can_cast(dtype, np.uint)
    if unsigned_fits:
        return np.uint
    return np.int_ if np.can_cast(dtype, np.int_) else dtype
191
192
def isscalarlike(x):
    """Return True for a Python/NumPy scalar or a 0-dimensional ndarray."""
    if np.isscalar(x):
        return True
    return isdense(x) and x.ndim == 0
196
197
def isintlike(x):
    """Is x appropriate as an index into a sparse matrix? Returns True
    if it can be cast safely to a machine int.

    Non-scalars (``np.ndim(x) != 0``) are rejected immediately. Values
    accepted by ``operator.index`` are integer-like. Other values that
    compare equal to their ``int()`` conversion (e.g. ``2.0``) are still
    accepted but emit a DeprecationWarning. Anything else (e.g. ``2.5``,
    strings, None, NaN, +/-inf) returns False.
    """
    # Fast-path check to eliminate non-scalar values. operator.index would
    # catch this case too, but the exception catching is slow.
    if np.ndim(x) != 0:
        return False
    try:
        operator.index(x)
    except (TypeError, ValueError):
        try:
            loose_int = bool(int(x) == x)
        except (TypeError, ValueError, OverflowError):
            # OverflowError comes from int() of +/-inf; like NaN (ValueError)
            # it is simply not an integer.
            return False
        if loose_int:
            warnings.warn("Inexact indices into sparse matrices are deprecated",
                          DeprecationWarning)
        return loose_int
    return True


def isshape(x, nonneg=False):
    """Is x a valid 2-tuple of dimensions?

    If nonneg, also checks that the dimensions are non-negative.

    `x` must unpack into exactly two items, each accepted by `isintlike`.
    Any exception raised while unpacking is treated as "not a shape".
    """
    try:
        # Assume it's a tuple of matrix dimensions (M, N)
        (M, N) = x
    except Exception:
        return False
    else:
        if isintlike(M) and isintlike(N):
            if np.ndim(M) == 0 and np.ndim(N) == 0:
                if not nonneg or (M >= 0 and N >= 0):
                    return True
        return False
236
237
def issequence(t):
    """Return True if `t` looks like a 1-D sequence.

    Accepted are a list or tuple that is empty or whose first element is a
    scalar (per ``np.isscalar``), and any 1-D ndarray. Only the first element
    of a list/tuple is inspected.
    """
    if isinstance(t, (list, tuple)):
        return len(t) == 0 or np.isscalar(t[0])
    return isinstance(t, np.ndarray) and t.ndim == 1
242
243
def ismatrix(t):
    """Return True if `t` looks 2-D.

    That is a 2-D ndarray, or a non-empty list/tuple whose first element
    passes `issequence`.
    """
    if isinstance(t, np.ndarray):
        return t.ndim == 2
    return isinstance(t, (list, tuple)) and len(t) > 0 and issequence(t[0])
248
249
def isdense(x):
    """Return True if `x` is a NumPy ndarray (subclasses such as np.matrix
    included)."""
    return isinstance(x, np.ndarray)
252
253
def validateaxis(axis):
    """Check that `axis` is None or an integer in the range [-2, 1].

    Raises
    ------
    TypeError
        If `axis` is a tuple, or its type is not a NumPy integer type.
    ValueError
        If `axis` is an integer outside [-2, 1].
    """
    if axis is None:
        return

    axis_kind = type(axis)

    # NumPy accepts tuples for 'axis', but they are of little use with the
    # limited dimensions of sparse matrices, so reject them explicitly.
    if axis_kind == tuple:
        raise TypeError("Tuples are not accepted for the 'axis' parameter. "
                        "Please pass in one of the following: "
                        "{-2, -1, 0, 1, None}.")

    # Mirror NumPy's TypeError for non-integer axis types.
    if not np.issubdtype(np.dtype(axis_kind), np.integer):
        raise TypeError("axis must be an integer, not {name}"
                        .format(name=axis_kind.__name__))

    if not (-2 <= axis <= 1):
        raise ValueError("axis out of range")
275
276
def check_shape(args, current_shape=None):
    """Imitate numpy.matrix handling of shape arguments.

    Parameters
    ----------
    args : tuple
        Positional arguments of a ``reshape``-style call: either a single
        shape (an int or an iterable of ints) or several ints.
    current_shape : tuple of int, optional
        Shape of the array being reshaped. When given, the new shape must
        have the same total size, and at most one entry may be negative; that
        entry is inferred from the remaining ones.

    Returns
    -------
    new_shape : tuple of int
        The validated two-dimensional shape.

    Raises
    ------
    TypeError
        If `args` is empty, or an entry is not an integer
        (via ``operator.index``).
    ValueError
        If the shape is not two-dimensional, has negative entries while
        `current_shape` is None, is incompatible with `current_shape`, or has
        more than one unknown (negative) dimension.
    """
    if len(args) == 0:
        raise TypeError("function missing 1 required positional argument: "
                        "'shape'")
    elif len(args) == 1:
        try:
            shape_iter = iter(args[0])
        except TypeError:
            new_shape = (operator.index(args[0]), )
        else:
            new_shape = tuple(operator.index(arg) for arg in shape_iter)
    else:
        new_shape = tuple(operator.index(arg) for arg in args)

    if current_shape is None:
        if len(new_shape) != 2:
            raise ValueError('shape must be a 2-tuple of positive integers')
        elif new_shape[0] < 0 or new_shape[1] < 0:
            raise ValueError("'shape' elements cannot be negative")

    else:
        # Check the current size only if needed
        current_size = prod(current_shape)

        # Check for negatives
        negative_indexes = [i for i, x in enumerate(new_shape) if x < 0]
        if len(negative_indexes) == 0:
            new_size = prod(new_shape)
            if new_size != current_size:
                raise ValueError('cannot reshape array of size {} into shape {}'
                                 .format(current_size, new_shape))
        elif len(negative_indexes) == 1:
            skip = negative_indexes[0]
            specified = prod(new_shape[:skip] + new_shape[skip+1:])
            # When the known dimensions multiply to zero the unknown one is
            # ambiguous; report an invalid reshape (ValueError) instead of
            # letting divmod raise ZeroDivisionError.
            if specified == 0:
                inferable = False
            else:
                unspecified, remainder = divmod(current_size, specified)
                inferable = remainder == 0
            if not inferable:
                err_shape = tuple('newshape' if x < 0 else x for x in new_shape)
                raise ValueError('cannot reshape array of size {} into shape {}'
                                 .format(current_size, err_shape))
            new_shape = new_shape[:skip] + (unspecified,) + new_shape[skip+1:]
        else:
            raise ValueError('can only specify one unknown dimension')

    if len(new_shape) != 2:
        raise ValueError('matrix shape must be two-dimensional')

    return new_shape
325
326
def check_reshape_kwargs(kwargs):
    """Extract the ``order`` and ``copy`` options of ``reshape`` from `kwargs`.

    ``order`` defaults to 'C' and ``copy`` to False. Both keys are popped from
    `kwargs` (the dict is modified in place), and any keys left over raise a
    TypeError. This helper dates from Python 2, which did not allow keyword
    arguments after ``*args`` but did allow ``**kwargs``.

    Returns
    -------
    (order, copy) : tuple
    """
    options = (kwargs.pop('order', 'C'), kwargs.pop('copy', False))
    if kwargs:  # some unrecognised keywords remain
        leftover = ', '.join(kwargs)
        raise TypeError('reshape() got unexpected keywords arguments: {}'
                        .format(leftover))
    return options
342
343
def is_pydata_spmatrix(m):
    """
    Return True if `m` is a pydata/sparse ``SparseArray``.

    The check only looks in ``sys.modules``, so it never imports ``sparse``
    itself: if that module has not been imported yet, the answer is False.
    """
    sparse_module = sys.modules.get('sparse')
    sparse_array_cls = getattr(sparse_module, 'SparseArray', None)
    if sparse_array_cls is None:
        return False
    return isinstance(m, sparse_array_cls)
350
351
352###############################################################################
353# Wrappers for NumPy types that are deprecated
354
# NumPy's versions of these functions emit deprecation warnings; the
# ones below do not.
357
358
def matrix(*args, **kwargs):
    """Build an ``np.matrix`` from ``np.array(*args, **kwargs)``.

    The array is re-viewed as ``np.matrix`` instead of calling the
    ``np.matrix`` constructor.
    """
    base = np.array(*args, **kwargs)
    return base.view(np.matrix)
361
362
def asmatrix(data, dtype=None):
    """Interpret `data` as an ``np.matrix``.

    An existing matrix is returned unchanged when `dtype` is None or already
    matches. Otherwise the data goes through ``np.asarray`` (copying only if
    needed) and is re-viewed as ``np.matrix``.
    """
    if isinstance(data, np.matrix):
        if dtype is None or data.dtype == dtype:
            return data
    return np.asarray(data, dtype=dtype).view(np.matrix)
367