1"""Indexing mixin for sparse matrix classes.
2"""
3import numpy as np
4from .sputils import isintlike
5
# Types treated as scalar integer indices by IndexMixin's dispatch.
# This module is Python 3 only, so the former Python 2 ``long`` fallback
# (a try/except NameError) was dead code and has been removed.
INT_TYPES = (int, np.integer)
11
12
13def _broadcast_arrays(a, b):
14    """
15    Same as np.broadcast_arrays(a, b) but old writeability rules.
16
17    NumPy >= 1.17.0 transitions broadcast_arrays to return
18    read-only arrays. Set writeability explicitly to avoid warnings.
19    Retain the old writeability rules, as our Cython code assumes
20    the old behavior.
21    """
22    x, y = np.broadcast_arrays(a, b)
23    x.flags.writeable = a.flags.writeable
24    y.flags.writeable = b.flags.writeable
25    return x, y
26
27
class IndexMixin:
    """
    This class provides common dispatching and validation logic for indexing.

    ``__getitem__``/``__setitem__`` normalize a key into a ``(row, col)``
    pair with ``_validate_indices`` and then dispatch on the kind of each
    component to one of the ``_get_*``/``_set_*`` hooks. Each component is
    an integer (``INT_TYPES``), a ``slice``, or an integer array. Classes
    mixing this in are expected to implement those hooks. The defaults here
    raise NotImplementedError, except ``_set_arrayXarray_sparse``, which
    falls back to densifying.
    """
    def __getitem__(self, key):
        """Return ``self[key]``.

        Integer/slice/array combinations are routed to the matching
        ``_get_<row>X<col>`` hook. Any remaining array/array combination is
        treated as inner ("fancy") indexing: the two index arrays are
        broadcast together and the matching hook is called on the result.

        Raises
        ------
        IndexError
            If the result would have more than two dimensions, or the row and
            column index arrays do not end up with the same shape.
        """
        row, col = self._validate_indices(key)
        # Dispatch to specialized methods.
        if isinstance(row, INT_TYPES):
            if isinstance(col, INT_TYPES):
                return self._get_intXint(row, col)
            elif isinstance(col, slice):
                return self._get_intXslice(row, col)
            elif col.ndim == 1:
                return self._get_intXarray(row, col)
            raise IndexError('index results in >2 dimensions')
        elif isinstance(row, slice):
            if isinstance(col, INT_TYPES):
                return self._get_sliceXint(row, col)
            elif isinstance(col, slice):
                # ``self[:, :]`` is simply a copy of the whole matrix.
                if row == slice(None) and row == col:
                    return self.copy()
                return self._get_sliceXslice(row, col)
            elif col.ndim == 1:
                return self._get_sliceXarray(row, col)
            raise IndexError('index results in >2 dimensions')
        elif row.ndim == 1:
            if isinstance(col, INT_TYPES):
                return self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                return self._get_arrayXslice(row, col)
            # 1-d row array with a column array: inner indexing below.
        else:  # row.ndim == 2
            if isinstance(col, INT_TYPES):
                return self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                raise IndexError('index results in >2 dimensions')
            elif row.shape[1] == 1 and (col.ndim == 1 or col.shape[0] == 1):
                # special case for outer indexing
                return self._get_columnXarray(row[:,0], col.ravel())

        # The only remaining case is inner (fancy) indexing
        row, col = _broadcast_arrays(row, col)
        if row.shape != col.shape:
            raise IndexError('number of row and column indices differ')
        if row.size == 0:
            # Empty selection: an empty matrix of the (at least 2-d) shape.
            return self.__class__(np.atleast_2d(row).shape, dtype=self.dtype)
        return self._get_arrayXarray(row, col)

    def __setitem__(self, key, x):
        """Assign ``x`` to ``self[key]``.

        A single ``(int, int)`` position takes a one-element value.
        Otherwise the key is expanded before dispatch:

        - a row slice becomes an integer column vector;
        - a column slice becomes an integer row vector;
        - the two are then broadcast into index arrays ``i``/``j``.

        A sparse ``x`` is shape-checked, allowing length-1 axes to broadcast,
        and then passed to ``_set_arrayXarray_sparse``. Any other ``x`` is
        converted to a dense array of ``self.dtype``, broadcast/reshaped to
        ``i.shape``, and passed to ``_set_arrayXarray``.

        Raises
        ------
        ValueError
            If a sequence is assigned to a single item, or a sparse ``x`` has
            an incompatible shape.
        IndexError
            If the row and column index arrays do not end up with the same
            shape.
        """
        row, col = self._validate_indices(key)

        if isinstance(row, INT_TYPES) and isinstance(col, INT_TYPES):
            x = np.asarray(x, dtype=self.dtype)
            if x.size != 1:
                raise ValueError('Trying to assign a sequence to an item')
            self._set_intXint(row, col, x.flat[0])
            return

        if isinstance(row, slice):
            row = np.arange(*row.indices(self.shape[0]))[:, None]
        else:
            row = np.atleast_1d(row)

        if isinstance(col, slice):
            col = np.arange(*col.indices(self.shape[1]))[None, :]
            if row.ndim == 1:
                # Make the row indices a column so they broadcast against
                # the row vector of column indices (outer assignment).
                row = row[:, None]
        else:
            col = np.atleast_1d(col)

        i, j = _broadcast_arrays(row, col)
        if i.shape != j.shape:
            raise IndexError('number of row and column indices differ')

        from .base import isspmatrix
        if isspmatrix(x):
            if i.ndim == 1:
                # Inner indexing, so treat them like row vectors.
                i = i[None]
                j = j[None]
            broadcast_row = x.shape[0] == 1 and i.shape[0] != 1
            broadcast_col = x.shape[1] == 1 and i.shape[1] != 1
            if not ((broadcast_row or x.shape[0] == i.shape[0]) and
                    (broadcast_col or x.shape[1] == i.shape[1])):
                raise ValueError('shape mismatch in assignment')
            if x.shape[0] == 0 or x.shape[1] == 0:
                return
            x = x.tocoo(copy=True)
            x.sum_duplicates()
            self._set_arrayXarray_sparse(i, j, x)
        else:
            # Make x and i into the same shape
            x = np.asarray(x, dtype=self.dtype)
            if x.squeeze().shape != i.squeeze().shape:
                x = np.broadcast_to(x, i.shape)
            if x.size == 0:
                return
            x = x.reshape(i.shape)
            self._set_arrayXarray(i, j, x)

    def _validate_indices(self, key):
        """Normalize ``key`` into a ``(row, col)`` pair.

        Integer-like components become non-negative Python ints (negative
        values are wrapped once). Slices are returned unchanged. Anything
        else is converted by ``_asindices``.

        Raises
        ------
        IndexError
            If an integer component lies outside ``[-M, M)`` / ``[-N, N)``.
        """
        M, N = self.shape
        row, col = _unpack_index(key)

        if isintlike(row):
            row = int(row)
            if row < -M or row >= M:
                raise IndexError('row index (%d) out of range' % row)
            if row < 0:
                row += M
        elif not isinstance(row, slice):
            row = self._asindices(row, M)

        if isintlike(col):
            col = int(col)
            if col < -N or col >= N:
                raise IndexError('column index (%d) out of range' % col)
            if col < 0:
                col += N
        elif not isinstance(col, slice):
            col = self._asindices(col, N)

        return row, col

    def _asindices(self, idx, length):
        """Convert `idx` to a valid index for an axis with a given length.

        Subclasses that need special validation can override this method.

        Parameters
        ----------
        idx : array_like
            Index values; must convert to a 1-d or 2-d array.
        length : int
            Size of the axis being indexed.

        Returns
        -------
        ndarray
            ``idx`` as an array with negative entries wrapped into
            ``[0, length)``. The caller's array is never modified in place.

        Raises
        ------
        IndexError
            If ``idx`` cannot be converted to an array, is not 1-d or 2-d, or
            has an entry outside ``[-length, length)``.
        """
        try:
            x = np.asarray(idx)
        except (ValueError, TypeError, MemoryError) as e:
            raise IndexError('invalid index') from e

        if x.ndim not in (1, 2):
            # 0-d input is rejected as well, so "<= 2" would be misleading.
            raise IndexError('Index dimension must be 1 or 2')

        if x.size == 0:
            return x

        # Check bounds
        max_indx = x.max()
        if max_indx >= length:
            raise IndexError('index (%d) out of range' % max_indx)

        min_indx = x.min()
        if min_indx < 0:
            if min_indx < -length:
                raise IndexError('index (%d) out of range' % min_indx)
            # Wrap negative entries on a fresh copy, so the caller's array is
            # never mutated. Widen to at least intp first: adding ``length``
            # to a narrow integer dtype (e.g. int8 with length 1000) would
            # otherwise overflow or raise.
            x = x.astype(np.result_type(x.dtype, np.intp))
            x[x < 0] += length
        return x

    def getrow(self, i):
        """Return a copy of row i of the matrix, as a (1 x n) row vector.

        Negative ``i`` counts from the end. Raises IndexError when ``i`` is
        outside ``[-M, M)``.
        """
        M, N = self.shape
        i = int(i)
        if i < -M or i >= M:
            raise IndexError('index (%d) out of range' % i)
        if i < 0:
            i += M
        return self._get_intXslice(i, slice(None))

    def getcol(self, i):
        """Return a copy of column i of the matrix, as a (m x 1) column vector.

        Negative ``i`` counts from the end. Raises IndexError when ``i`` is
        outside ``[-N, N)``.
        """
        M, N = self.shape
        i = int(i)
        if i < -N or i >= N:
            raise IndexError('index (%d) out of range' % i)
        if i < 0:
            i += N
        return self._get_sliceXint(slice(None), i)

    # ------------------------------------------------------------------
    # Format-specific hooks. ``row``/``col`` arrive already validated:
    # integers are non-negative and in range, and arrays have negatives
    # wrapped. Subclasses override the ones they support.
    # ------------------------------------------------------------------

    def _get_intXint(self, row, col):
        raise NotImplementedError()

    def _get_intXarray(self, row, col):
        raise NotImplementedError()

    def _get_intXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXint(self, row, col):
        raise NotImplementedError()

    def _get_sliceXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXint(self, row, col):
        raise NotImplementedError()

    def _get_arrayXslice(self, row, col):
        raise NotImplementedError()

    def _get_columnXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXarray(self, row, col):
        raise NotImplementedError()

    def _set_intXint(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray_sparse(self, row, col, x):
        """Fallback sparse assignment: densify ``x``, broadcast it against
        ``row`` and delegate to ``_set_arrayXarray``."""
        # Fall back to densifying x
        x = np.asarray(x.toarray(), dtype=self.dtype)
        x, _ = _broadcast_arrays(x, row)
        self._set_arrayXarray(row, col, x)
244
245
def _unpack_index(index):
    """ Parse index. Always return a tuple of the form (row, col).
    Valid type for row/col is integer, slice, or array of integers.

    Boolean masks are converted to integer positions:

    - a 2-d boolean mask (dense, sparse or nested list) becomes the pair of
      arrays returned by ``nonzero()``;
    - a 0-d/1-d boolean mask selects rows;
    - a boolean mask of higher rank raises IndexError.

    Raises
    ------
    IndexError
        For tuples with more than two entries, for sparse-matrix row/col
        components, and for boolean masks of unsupported rank.
    """
    # First, check if indexing with single boolean matrix.
    from .base import spmatrix, isspmatrix
    if (isinstance(index, (spmatrix, np.ndarray)) and
            index.ndim == 2 and index.dtype.kind == 'b'):
        return index.nonzero()

    # Parse any ellipses.
    index = _check_ellipsis(index)

    # Next, parse the tuple or object
    if isinstance(index, tuple):
        if len(index) == 2:
            row, col = index
        elif len(index) == 1:
            row, col = index[0], slice(None)
        else:
            raise IndexError('invalid number of indices')
    else:
        idx = _compatible_boolean_index(index)
        if idx is None:
            row, col = index, slice(None)
        elif idx.ndim == 2:
            return idx.nonzero()
        else:
            # 0-d/1-d masks select rows. Masks with more than two dimensions
            # previously fell through with ``row``/``col`` unbound
            # (UnboundLocalError); _boolean_index_to_array rejects them with
            # an IndexError instead.
            return _boolean_index_to_array(idx), slice(None)
    # Next, check for validity and transform the index as needed.
    if isspmatrix(row) or isspmatrix(col):
        # Supporting sparse boolean indexing with both row and col does
        # not work because spmatrix.ndim is always 2.
        raise IndexError(
            'Indexing with sparse matrices is not supported '
            'except boolean indexing where matrix and index '
            'are equal shapes.')
    bool_row = _compatible_boolean_index(row)
    bool_col = _compatible_boolean_index(col)
    if bool_row is not None:
        row = _boolean_index_to_array(bool_row)
    if bool_col is not None:
        col = _boolean_index_to_array(bool_col)
    return row, col
290
291
292def _check_ellipsis(index):
293    """Process indices with Ellipsis. Returns modified index."""
294    if index is Ellipsis:
295        return (slice(None), slice(None))
296
297    if not isinstance(index, tuple):
298        return index
299
300    # TODO: Deprecate this multiple-ellipsis handling,
301    #       as numpy no longer supports it.
302
303    # Find first ellipsis.
304    for j, v in enumerate(index):
305        if v is Ellipsis:
306            first_ellipsis = j
307            break
308    else:
309        return index
310
311    # Try to expand it using shortcuts for common cases
312    if len(index) == 1:
313        return (slice(None), slice(None))
314    if len(index) == 2:
315        if first_ellipsis == 0:
316            if index[1] is Ellipsis:
317                return (slice(None), slice(None))
318            return (slice(None), index[1])
319        return (index[0], slice(None))
320
321    # Expand it using a general-purpose algorithm
322    tail = []
323    for v in index[first_ellipsis+1:]:
324        if v is not Ellipsis:
325            tail.append(v)
326    nd = first_ellipsis + len(tail)
327    nslice = max(0, 2 - nd)
328    return index[:first_ellipsis] + (slice(None),)*nslice + tuple(tail)
329
330
331def _maybe_bool_ndarray(idx):
332    """Returns a compatible array if elements are boolean.
333    """
334    idx = np.asanyarray(idx)
335    if idx.dtype.kind == 'b':
336        return idx
337    return None
338
339
340def _first_element_bool(idx, max_dim=2):
341    """Returns True if first element of the incompatible
342    array type is boolean.
343    """
344    if max_dim < 1:
345        return None
346    try:
347        first = next(iter(idx), None)
348    except TypeError:
349        return None
350    if isinstance(first, bool):
351        return True
352    return _first_element_bool(first, max_dim-1)
353
354
def _compatible_boolean_index(idx):
    """Return ``idx`` as a boolean array if it is a boolean mask, else None.

    Objects exposing ``ndim`` are treated as array-like and checked by dtype
    directly. Other objects (e.g. nested lists) are only converted when
    their first, possibly nested, element is a boolean.
    """
    looks_boolean = hasattr(idx, 'ndim') or _first_element_bool(idx)
    if not looks_boolean:
        return None
    return _maybe_bool_ndarray(idx)
363
364
365def _boolean_index_to_array(idx):
366    if idx.ndim > 1:
367        raise IndexError('invalid index shape')
368    return np.where(idx)[0]
369