"""Indexing mixin for sparse matrix classes.
"""
import numpy as np
from .sputils import isintlike

try:
    INT_TYPES = (int, long, np.integer)
except NameError:
    # long is not defined in Python3
    INT_TYPES = (int, np.integer)


def _broadcast_arrays(a, b):
    """
    Same as np.broadcast_arrays(a, b) but old writeability rules.

    NumPy >= 1.17.0 transitions broadcast_arrays to return
    read-only arrays. Set writeability explicitly to avoid warnings.
    Retain the old writeability rules, as our Cython code assumes
    the old behavior.
    """
    x, y = np.broadcast_arrays(a, b)
    # Each output inherits the writeable flag of the input it came from.
    x.flags.writeable = a.flags.writeable
    y.flags.writeable = b.flags.writeable
    return x, y


class IndexMixin:
    """
    This class provides common dispatching and validation logic for indexing.

    ``__getitem__`` and ``__setitem__`` normalize the key with
    ``_validate_indices`` and then call one of the ``_get_*`` / ``_set_*``
    hooks defined at the bottom of the class. The versions of those hooks
    defined here raise ``NotImplementedError`` (except
    ``_set_arrayXarray_sparse``, which falls back to a dense assignment).
    """
    def __getitem__(self, key):
        """Return the part of the matrix selected by `key`.

        The key is split into a (row, col) pair, each of which is an
        integer, a slice, or an index array, and the matching ``_get_*``
        hook is called.

        Raises
        ------
        IndexError
            If the index is invalid, would produce a result with more than
            two dimensions, or the row and column index arrays do not
            broadcast to a common shape.
        """
        row, col = self._validate_indices(key)
        # Dispatch to specialized methods.
        if isinstance(row, INT_TYPES):
            if isinstance(col, INT_TYPES):
                return self._get_intXint(row, col)
            elif isinstance(col, slice):
                return self._get_intXslice(row, col)
            elif col.ndim == 1:
                return self._get_intXarray(row, col)
            raise IndexError('index results in >2 dimensions')
        elif isinstance(row, slice):
            if isinstance(col, INT_TYPES):
                return self._get_sliceXint(row, col)
            elif isinstance(col, slice):
                # A[:, :] selects everything, so a plain copy suffices.
                if row == slice(None) and row == col:
                    return self.copy()
                return self._get_sliceXslice(row, col)
            elif col.ndim == 1:
                return self._get_sliceXarray(row, col)
            raise IndexError('index results in >2 dimensions')
        elif row.ndim == 1:
            if isinstance(col, INT_TYPES):
                return self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                return self._get_arrayXslice(row, col)
        else:  # row.ndim == 2
            if isinstance(col, INT_TYPES):
                return self._get_arrayXint(row, col)
            elif isinstance(col, slice):
                raise IndexError('index results in >2 dimensions')
            elif row.shape[1] == 1 and (col.ndim == 1 or
                                        col.shape[0] == 1):
                # special case for outer indexing
                return self._get_columnXarray(row[:,0], col.ravel())

        # The only remaining case is inner (fancy) indexing
        row, col = _broadcast_arrays(row, col)
        if row.shape != col.shape:
            raise IndexError('number of row and column indices differ')
        if row.size == 0:
            # Nothing selected: return an empty matrix of the same class.
            return self.__class__(np.atleast_2d(row).shape, dtype=self.dtype)
        return self._get_arrayXarray(row, col)

    def __setitem__(self, key, x):
        """Assign `x` to the part of the matrix selected by `key`.

        Integer/integer keys go to ``_set_intXint``. Every other key is
        expanded to a pair of broadcast index arrays and handed to
        ``_set_arrayXarray`` (dense `x`) or ``_set_arrayXarray_sparse``
        (sparse `x`).

        Raises
        ------
        ValueError
            If a non-scalar is assigned to a single element, or the shape
            of `x` is incompatible with the indexed region.
        IndexError
            If the index is invalid or the row and column index arrays do
            not broadcast to a common shape.
        """
        row, col = self._validate_indices(key)

        if isinstance(row, INT_TYPES) and isinstance(col, INT_TYPES):
            x = np.asarray(x, dtype=self.dtype)
            if x.size != 1:
                raise ValueError('Trying to assign a sequence to an item')
            self._set_intXint(row, col, x.flat[0])
            return

        # Turn slices into explicit index arrays; a row slice becomes a
        # column vector so that it broadcasts against the column indices.
        if isinstance(row, slice):
            row = np.arange(*row.indices(self.shape[0]))[:, None]
        else:
            row = np.atleast_1d(row)

        if isinstance(col, slice):
            col = np.arange(*col.indices(self.shape[1]))[None, :]
            if row.ndim == 1:
                row = row[:, None]
        else:
            col = np.atleast_1d(col)

        i, j = _broadcast_arrays(row, col)
        if i.shape != j.shape:
            raise IndexError('number of row and column indices differ')

        # Imported here rather than at module level (presumably to avoid a
        # circular import with .base -- confirm before moving).
        from .base import isspmatrix
        if isspmatrix(x):
            if i.ndim == 1:
                # Inner indexing, so treat them like row vectors.
                i = i[None]
                j = j[None]
            # x may be broadcast along an axis on which it has length 1.
            broadcast_row = x.shape[0] == 1 and i.shape[0] != 1
            broadcast_col = x.shape[1] == 1 and i.shape[1] != 1
            if not ((broadcast_row or x.shape[0] == i.shape[0]) and
                    (broadcast_col or x.shape[1] == i.shape[1])):
                raise ValueError('shape mismatch in assignment')
            if x.shape[0] == 0 or x.shape[1] == 0:
                return
            x = x.tocoo(copy=True)
            x.sum_duplicates()
            self._set_arrayXarray_sparse(i, j, x)
        else:
            # Make x and i into the same shape
            x = np.asarray(x, dtype=self.dtype)
            if x.squeeze().shape != i.squeeze().shape:
                x = np.broadcast_to(x, i.shape)
            if x.size == 0:
                return
            x = x.reshape(i.shape)
            self._set_arrayXarray(i, j, x)

    def _validate_indices(self, key):
        """Split `key` into a validated (row, col) pair.

        Integer-like indices are converted to ``int``, bounds-checked
        against the matrix shape and made non-negative. Slices are passed
        through unchanged. Anything else is converted by ``_asindices``.

        Raises
        ------
        IndexError
            If an integer index is out of range (other errors may come
            from ``_unpack_index`` or ``_asindices``).
        """
        M, N = self.shape
        row, col = _unpack_index(key)

        if isintlike(row):
            row = int(row)
            if row < -M or row >= M:
                raise IndexError('row index (%d) out of range' % row)
            if row < 0:
                row += M
        elif not isinstance(row, slice):
            row = self._asindices(row, M)

        if isintlike(col):
            col = int(col)
            if col < -N or col >= N:
                raise IndexError('column index (%d) out of range' % col)
            if col < 0:
                col += N
        elif not isinstance(col, slice):
            col = self._asindices(col, N)

        return row, col

    def _asindices(self, idx, length):
        """Convert `idx` to a valid index for an axis with a given length.

        Subclasses that need special validation can override this method.

        Returns a 1-D or 2-D array in which negative entries have been
        wrapped by adding `length`. Raises IndexError if `idx` cannot be
        converted to an array, has an unsupported number of dimensions,
        or contains an out-of-range entry.
        """
        try:
            x = np.asarray(idx)
        except (ValueError, TypeError, MemoryError) as e:
            raise IndexError('invalid index') from e

        if x.ndim not in (1, 2):
            raise IndexError('Index dimension must be <= 2')

        if x.size == 0:
            return x

        # Check bounds
        max_indx = x.max()
        if max_indx >= length:
            raise IndexError('index (%d) out of range' % max_indx)

        min_indx = x.min()
        if min_indx < 0:
            if min_indx < -length:
                raise IndexError('index (%d) out of range' % min_indx)
            # Copy before the in-place wrap so the caller's array (or the
            # buffer a view refers to) is not modified.
            if x is idx or not x.flags.owndata:
                x = x.copy()
            x[x < 0] += length
        return x

    def getrow(self, i):
        """Return a copy of row i of the matrix, as a (1 x n) row vector.
        """
        M, N = self.shape
        i = int(i)
        if i < -M or i >= M:
            raise IndexError('index (%d) out of range' % i)
        if i < 0:
            i += M
        return self._get_intXslice(i, slice(None))

    def getcol(self, i):
        """Return a copy of column i of the matrix, as a (m x 1) column vector.
        """
        M, N = self.shape
        i = int(i)
        if i < -N or i >= N:
            raise IndexError('index (%d) out of range' % i)
        if i < 0:
            i += N
        return self._get_sliceXint(slice(None), i)

    # Hooks called by __getitem__ / __setitem__ / getrow / getcol. The
    # names encode the (row, col) index kinds being handled.

    def _get_intXint(self, row, col):
        raise NotImplementedError()

    def _get_intXarray(self, row, col):
        raise NotImplementedError()

    def _get_intXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXint(self, row, col):
        raise NotImplementedError()

    def _get_sliceXslice(self, row, col):
        raise NotImplementedError()

    def _get_sliceXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXint(self, row, col):
        raise NotImplementedError()

    def _get_arrayXslice(self, row, col):
        raise NotImplementedError()

    def _get_columnXarray(self, row, col):
        raise NotImplementedError()

    def _get_arrayXarray(self, row, col):
        raise NotImplementedError()

    def _set_intXint(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray(self, row, col, x):
        raise NotImplementedError()

    def _set_arrayXarray_sparse(self, row, col, x):
        """Assign sparse `x` at (row, col) by converting it to dense."""
        # Fall back to densifying x
        x = np.asarray(x.toarray(), dtype=self.dtype)
        x, _ = _broadcast_arrays(x, row)
        self._set_arrayXarray(row, col, x)


def _unpack_index(index):
    """ Parse index. Always return a tuple of the form (row, col).
    Valid type for row/col is integer, slice, or array of integers.

    Raises IndexError if a tuple index does not have one or two entries,
    if a sparse matrix is used as a row or column index, or if a boolean
    index has more dimensions than can be handled.
    """
    # First, check if indexing with single boolean matrix.
    from .base import spmatrix, isspmatrix
    if (isinstance(index, (spmatrix, np.ndarray)) and
            index.ndim == 2 and index.dtype.kind == 'b'):
        return index.nonzero()

    # Parse any ellipses.
    index = _check_ellipsis(index)

    # Next, parse the tuple or object
    if isinstance(index, tuple):
        if len(index) == 2:
            row, col = index
        elif len(index) == 1:
            row, col = index[0], slice(None)
        else:
            raise IndexError('invalid number of indices')
    else:
        idx = _compatible_boolean_index(index)
        if idx is None:
            row, col = index, slice(None)
        elif idx.ndim < 2:
            return _boolean_index_to_array(idx), slice(None)
        elif idx.ndim == 2:
            return idx.nonzero()
        else:
            # A boolean array with more than two dimensions cannot index
            # a 2-D matrix. Without this branch `row`/`col` would be left
            # unbound and the code below would fail with
            # UnboundLocalError instead of a proper IndexError.
            raise IndexError('invalid index shape')
    # Next, check for validity and transform the index as needed.
    if isspmatrix(row) or isspmatrix(col):
        # Supporting sparse boolean indexing with both row and col does
        # not work because spmatrix.ndim is always 2.
        raise IndexError(
            'Indexing with sparse matrices is not supported '
            'except boolean indexing where matrix and index '
            'are equal shapes.')
    bool_row = _compatible_boolean_index(row)
    bool_col = _compatible_boolean_index(col)
    if bool_row is not None:
        row = _boolean_index_to_array(bool_row)
    if bool_col is not None:
        col = _boolean_index_to_array(bool_col)
    return row, col


def _check_ellipsis(index):
    """Process indices with Ellipsis. Returns modified index."""
    if index is Ellipsis:
        return (slice(None), slice(None))

    if not isinstance(index, tuple):
        return index

    # TODO: Deprecate this multiple-ellipsis handling,
    #       as numpy no longer supports it.

    # Find first ellipsis.
    for j, v in enumerate(index):
        if v is Ellipsis:
            first_ellipsis = j
            break
    else:
        # No ellipsis present: nothing to expand.
        return index

    # Try to expand it using shortcuts for common cases
    if len(index) == 1:
        return (slice(None), slice(None))
    if len(index) == 2:
        if first_ellipsis == 0:
            if index[1] is Ellipsis:
                return (slice(None), slice(None))
            return (slice(None), index[1])
        return (index[0], slice(None))

    # Expand it using a general-purpose algorithm
    # Entries after the first ellipsis, with any further ellipses dropped.
    tail = []
    for v in index[first_ellipsis+1:]:
        if v is not Ellipsis:
            tail.append(v)
    # Pad with full slices until there are (at least) two indices.
    nd = first_ellipsis + len(tail)
    nslice = max(0, 2 - nd)
    return index[:first_ellipsis] + (slice(None),)*nslice + tuple(tail)


def _maybe_bool_ndarray(idx):
    """Returns a compatible array if elements are boolean.

    `idx` is converted with ``np.asanyarray``; the result is returned if
    its dtype kind is boolean, otherwise None is returned.
    """
    idx = np.asanyarray(idx)
    if idx.dtype.kind == 'b':
        return idx
    return None


def _first_element_bool(idx, max_dim=2):
    """Returns True if first element of the incompatible
    array type is boolean.

    Descends into the first element of nested iterables, at most
    `max_dim` levels deep. Returns None (not False) when no Python
    ``bool`` is found or `idx` is not iterable.
    """
    if max_dim < 1:
        return None
    try:
        first = next(iter(idx), None)
    except TypeError:
        return None
    if isinstance(first, bool):
        return True
    return _first_element_bool(first, max_dim-1)


def _compatible_boolean_index(idx):
    """Returns a boolean index array that can be converted to
    integer array. Returns None if no such array exists.
    """
    # Presence of attribute `ndim` indicates a compatible array type.
    if hasattr(idx, 'ndim') or _first_element_bool(idx):
        return _maybe_bool_ndarray(idx)
    return None


def _boolean_index_to_array(idx):
    """Return the positions of the true entries of boolean index `idx`.

    Raises IndexError if `idx` has more than one dimension.
    """
    if idx.ndim > 1:
        raise IndexError('invalid index shape')
    return np.where(idx)[0]