1"""Base class for sparse matrix formats using compressed storage."""
# Deliberately empty: ``from <this module> import *`` exports no names.
__all__ = []
3
4from warnings import warn
5import operator
6
7import numpy as np
8from scipy._lib._util import _prune_array
9
10from .base import spmatrix, isspmatrix, SparseEfficiencyWarning
11from .data import _data_matrix, _minmax_mixin
12from .dia import dia_matrix
13from . import _sparsetools
14from ._sparsetools import (get_csr_submatrix, csr_sample_offsets, csr_todense,
15                           csr_sample_values, csr_row_index, csr_row_slice,
16                           csr_column_index1, csr_column_index2)
17from ._index import IndexMixin
18from .sputils import (upcast, upcast_char, to_native, isdense, isshape,
19                      getdtype, isscalarlike, isintlike, get_index_dtype,
20                      downcast_intp_index, get_sum_dtype, check_shape,
21                      matrix, asmatrix, is_pydata_spmatrix)
22
23
24class _cs_matrix(_data_matrix, _minmax_mixin, IndexMixin):
25    """base matrix class for compressed row- and column-oriented matrices"""
26
    def __init__(self, arg1, shape=None, dtype=None, copy=False):
        """Construct a compressed sparse matrix of this class's format.

        Accepted forms of ``arg1``, dispatched in this order:

        * a sparse matrix -- copied when it already has this format and
          ``copy`` is true, otherwise converted with ``asformat``;
        * a shape tuple ``(M, N)`` -- an empty matrix with no stored entries;
        * a 2-tuple ``(data, ij)`` -- routed through ``coo_matrix``;
        * a 3-tuple ``(data, indices, indptr)`` -- the raw compressed arrays;
        * anything else -- converted with ``np.asarray`` and routed through
          ``coo_matrix`` as a dense array.

        Parameters
        ----------
        arg1 : see above
        shape : tuple of int, optional
            When given, it replaces whatever shape the branches above set.
            When omitted and no shape has been set, the shape is inferred from
            ``len(indptr) - 1`` and ``indices.max() + 1``.
        dtype : data-type, optional
            If given, ``data`` is cast to it (``astype(..., copy=False)``).
        copy : bool, optional
            Honoured for same-format sparse input and for the
            ``(data, indices, indptr)`` form.

        Raises
        ------
        ValueError
            For a tuple of unsupported length, input ``np.asarray`` cannot
            convert, a shape that cannot be inferred, or arrays failing the
            basic ``check_format(full_check=False)`` validation.
        """
        _data_matrix.__init__(self)

        if isspmatrix(arg1):
            # Sparse input: copy only when requested and already in this
            # format; otherwise convert to this format.
            if arg1.format == self.format and copy:
                arg1 = arg1.copy()
            else:
                arg1 = arg1.asformat(self.format)
            self._set_self(arg1)

        elif isinstance(arg1, tuple):
            if isshape(arg1):
                # It's a tuple of matrix dimensions (M, N)
                # create empty matrix
                self._shape = check_shape(arg1)
                M, N = self.shape
                # Select index dtype large enough to pass array and
                # scalar parameters to sparsetools
                idx_dtype = get_index_dtype(maxval=max(M, N))
                self.data = np.zeros(0, getdtype(dtype, default=float))
                self.indices = np.zeros(0, idx_dtype)
                # One pointer per major line plus the leading 0; all zero
                # because nothing is stored.
                self.indptr = np.zeros(self._swap((M, N))[0] + 1,
                                       dtype=idx_dtype)
            else:
                if len(arg1) == 2:
                    # (data, ij) format
                    from .coo import coo_matrix
                    other = self.__class__(coo_matrix(arg1, shape=shape,
                                                      dtype=dtype))
                    self._set_self(other)
                elif len(arg1) == 3:
                    # (data, indices, indptr) format
                    (data, indices, indptr) = arg1

                    # Select index dtype large enough to pass array and
                    # scalar parameters to sparsetools
                    maxval = None
                    if shape is not None:
                        maxval = max(shape)
                    idx_dtype = get_index_dtype((indices, indptr),
                                                maxval=maxval,
                                                check_contents=True)

                    self.indices = np.array(indices, copy=copy,
                                            dtype=idx_dtype)
                    self.indptr = np.array(indptr, copy=copy, dtype=idx_dtype)
                    self.data = np.array(data, copy=copy, dtype=dtype)
                else:
                    raise ValueError("unrecognized {}_matrix "
                                     "constructor usage".format(self.format))

        else:
            # must be dense
            try:
                arg1 = np.asarray(arg1)
            except Exception as e:
                raise ValueError("unrecognized {}_matrix constructor usage"
                                 "".format(self.format)) from e
            from .coo import coo_matrix
            self._set_self(self.__class__(coo_matrix(arg1, dtype=dtype)))

        # Read matrix dimensions given, if any
        if shape is not None:
            self._shape = check_shape(shape)
        else:
            if self.shape is None:
                # shape not already set, try to infer dimensions
                # (of the branches above, only the (data, indices, indptr)
                # form leaves ``_shape`` unassigned)
                try:
                    major_dim = len(self.indptr) - 1
                    minor_dim = self.indices.max() + 1
                except Exception as e:
                    raise ValueError('unable to infer matrix dimensions') from e
                else:
                    self._shape = check_shape(self._swap((major_dim,
                                                          minor_dim)))

        if dtype is not None:
            self.data = self.data.astype(dtype, copy=False)

        # Cheap structural validation only; the O(nnz) checks are skipped.
        self.check_format(full_check=False)
107
108    def getnnz(self, axis=None):
109        if axis is None:
110            return int(self.indptr[-1])
111        else:
112            if axis < 0:
113                axis += 2
114            axis, _ = self._swap((axis, 1 - axis))
115            _, N = self._swap(self.shape)
116            if axis == 0:
117                return np.bincount(downcast_intp_index(self.indices),
118                                   minlength=N)
119            elif axis == 1:
120                return np.diff(self.indptr)
121            raise ValueError('axis out of bounds')
122
123    getnnz.__doc__ = spmatrix.getnnz.__doc__
124
125    def _set_self(self, other, copy=False):
126        """take the member variables of other and assign them to self"""
127
128        if copy:
129            other = other.copy()
130
131        self.data = other.data
132        self.indices = other.indices
133        self.indptr = other.indptr
134        self._shape = check_shape(other.shape)
135
136    def check_format(self, full_check=True):
137        """check whether the matrix format is valid
138
139        Parameters
140        ----------
141        full_check : bool, optional
142            If `True`, rigorous check, O(N) operations. Otherwise
143            basic check, O(1) operations (default True).
144        """
145        # use _swap to determine proper bounds
146        major_name, minor_name = self._swap(('row', 'column'))
147        major_dim, minor_dim = self._swap(self.shape)
148
149        # index arrays should have integer data types
150        if self.indptr.dtype.kind != 'i':
151            warn("indptr array has non-integer dtype ({})"
152                 "".format(self.indptr.dtype.name), stacklevel=3)
153        if self.indices.dtype.kind != 'i':
154            warn("indices array has non-integer dtype ({})"
155                 "".format(self.indices.dtype.name), stacklevel=3)
156
157        idx_dtype = get_index_dtype((self.indptr, self.indices))
158        self.indptr = np.asarray(self.indptr, dtype=idx_dtype)
159        self.indices = np.asarray(self.indices, dtype=idx_dtype)
160        self.data = to_native(self.data)
161
162        # check array shapes
163        for x in [self.data.ndim, self.indices.ndim, self.indptr.ndim]:
164            if x != 1:
165                raise ValueError('data, indices, and indptr should be 1-D')
166
167        # check index pointer
168        if (len(self.indptr) != major_dim + 1):
169            raise ValueError("index pointer size ({}) should be ({})"
170                             "".format(len(self.indptr), major_dim + 1))
171        if (self.indptr[0] != 0):
172            raise ValueError("index pointer should start with 0")
173
174        # check index and data arrays
175        if (len(self.indices) != len(self.data)):
176            raise ValueError("indices and data should have the same size")
177        if (self.indptr[-1] > len(self.indices)):
178            raise ValueError("Last value of index pointer should be less than "
179                             "the size of index and data arrays")
180
181        self.prune()
182
183        if full_check:
184            # check format validity (more expensive)
185            if self.nnz > 0:
186                if self.indices.max() >= minor_dim:
187                    raise ValueError("{} index values must be < {}"
188                                     "".format(minor_name, minor_dim))
189                if self.indices.min() < 0:
190                    raise ValueError("{} index values must be >= 0"
191                                     "".format(minor_name))
192                if np.diff(self.indptr).min() < 0:
193                    raise ValueError("index pointer values must form a "
194                                     "non-decreasing sequence")
195
196        # if not self.has_sorted_indices():
197        #    warn('Indices were not in sorted order.  Sorting indices.')
198        #    self.sort_indices()
199        #    assert(self.has_sorted_indices())
200        # TODO check for duplicates?
201
202    #######################
203    # Boolean comparisons #
204    #######################
205
206    def _scalar_binopt(self, other, op):
207        """Scalar version of self._binopt, for cases in which no new nonzeros
208        are added. Produces a new spmatrix in canonical form.
209        """
210        self.sum_duplicates()
211        res = self._with_data(op(self.data, other), copy=True)
212        res.eliminate_zeros()
213        return res
214
215    def __eq__(self, other):
216        # Scalar other.
217        if isscalarlike(other):
218            if np.isnan(other):
219                return self.__class__(self.shape, dtype=np.bool_)
220
221            if other == 0:
222                warn("Comparing a sparse matrix with 0 using == is inefficient"
223                     ", try using != instead.", SparseEfficiencyWarning,
224                     stacklevel=3)
225                all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
226                inv = self._scalar_binopt(other, operator.ne)
227                return all_true - inv
228            else:
229                return self._scalar_binopt(other, operator.eq)
230        # Dense other.
231        elif isdense(other):
232            return self.todense() == other
233        # Pydata sparse other.
234        elif is_pydata_spmatrix(other):
235            return NotImplemented
236        # Sparse other.
237        elif isspmatrix(other):
238            warn("Comparing sparse matrices using == is inefficient, try using"
239                 " != instead.", SparseEfficiencyWarning, stacklevel=3)
240            # TODO sparse broadcasting
241            if self.shape != other.shape:
242                return False
243            elif self.format != other.format:
244                other = other.asformat(self.format)
245            res = self._binopt(other, '_ne_')
246            all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
247            return all_true - res
248        else:
249            return False
250
251    def __ne__(self, other):
252        # Scalar other.
253        if isscalarlike(other):
254            if np.isnan(other):
255                warn("Comparing a sparse matrix with nan using != is"
256                     " inefficient", SparseEfficiencyWarning, stacklevel=3)
257                all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
258                return all_true
259            elif other != 0:
260                warn("Comparing a sparse matrix with a nonzero scalar using !="
261                     " is inefficient, try using == instead.",
262                     SparseEfficiencyWarning, stacklevel=3)
263                all_true = self.__class__(np.ones(self.shape), dtype=np.bool_)
264                inv = self._scalar_binopt(other, operator.eq)
265                return all_true - inv
266            else:
267                return self._scalar_binopt(other, operator.ne)
268        # Dense other.
269        elif isdense(other):
270            return self.todense() != other
271        # Pydata sparse other.
272        elif is_pydata_spmatrix(other):
273            return NotImplemented
274        # Sparse other.
275        elif isspmatrix(other):
276            # TODO sparse broadcasting
277            if self.shape != other.shape:
278                return True
279            elif self.format != other.format:
280                other = other.asformat(self.format)
281            return self._binopt(other, '_ne_')
282        else:
283            return True
284
285    def _inequality(self, other, op, op_name, bad_scalar_msg):
286        # Scalar other.
287        if isscalarlike(other):
288            if 0 == other and op_name in ('_le_', '_ge_'):
289                raise NotImplementedError(" >= and <= don't work with 0.")
290            elif op(0, other):
291                warn(bad_scalar_msg, SparseEfficiencyWarning)
292                other_arr = np.empty(self.shape, dtype=np.result_type(other))
293                other_arr.fill(other)
294                other_arr = self.__class__(other_arr)
295                return self._binopt(other_arr, op_name)
296            else:
297                return self._scalar_binopt(other, op)
298        # Dense other.
299        elif isdense(other):
300            return op(self.todense(), other)
301        # Sparse other.
302        elif isspmatrix(other):
303            # TODO sparse broadcasting
304            if self.shape != other.shape:
305                raise ValueError("inconsistent shapes")
306            elif self.format != other.format:
307                other = other.asformat(self.format)
308            if op_name not in ('_ge_', '_le_'):
309                return self._binopt(other, op_name)
310
311            warn("Comparing sparse matrices using >= and <= is inefficient, "
312                 "using <, >, or !=, instead.", SparseEfficiencyWarning)
313            all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
314            res = self._binopt(other, '_gt_' if op_name == '_le_' else '_lt_')
315            return all_true - res
316        else:
317            raise ValueError("Operands could not be compared.")
318
319    def __lt__(self, other):
320        return self._inequality(other, operator.lt, '_lt_',
321                                "Comparing a sparse matrix with a scalar "
322                                "greater than zero using < is inefficient, "
323                                "try using >= instead.")
324
325    def __gt__(self, other):
326        return self._inequality(other, operator.gt, '_gt_',
327                                "Comparing a sparse matrix with a scalar "
328                                "less than zero using > is inefficient, "
329                                "try using <= instead.")
330
331    def __le__(self, other):
332        return self._inequality(other, operator.le, '_le_',
333                                "Comparing a sparse matrix with a scalar "
334                                "greater than zero using <= is inefficient, "
335                                "try using > instead.")
336
337    def __ge__(self, other):
338        return self._inequality(other, operator.ge, '_ge_',
339                                "Comparing a sparse matrix with a scalar "
340                                "less than zero using >= is inefficient, "
341                                "try using < instead.")
342
343    #################################
344    # Arithmetic operator overrides #
345    #################################
346
347    def _add_dense(self, other):
348        if other.shape != self.shape:
349            raise ValueError('Incompatible shapes ({} and {})'
350                             .format(self.shape, other.shape))
351        dtype = upcast_char(self.dtype.char, other.dtype.char)
352        order = self._swap('CF')[0]
353        result = np.array(other, dtype=dtype, order=order, copy=True)
354        M, N = self._swap(self.shape)
355        y = result if result.flags.c_contiguous else result.T
356        csr_todense(M, N, self.indptr, self.indices, self.data, y)
357        return matrix(result, copy=False)
358
359    def _add_sparse(self, other):
360        return self._binopt(other, '_plus_')
361
362    def _sub_sparse(self, other):
363        return self._binopt(other, '_minus_')
364
    def multiply(self, other):
        """Point-wise multiplication by another matrix, vector, or
        scalar.

        Broadcasting of a matrix against a row/column vector or a 1x1 operand
        is handled case by case below. The branch order matters because the
        1x1, vector and equal-shape cases overlap.

        Parameters
        ----------
        other : scalar, sparse matrix, or array_like

        Returns
        -------
        The element-wise product. The container depends on the branch taken:
        whatever ``_mul_scalar`` / ``_binopt`` / ``_mul_sparse_matrix`` return
        for scalar and sparse operands; a COO matrix for most dense operands;
        and a dense ``np.multiply`` result when ``other`` is not 2-D after
        ``np.atleast_2d`` or when ``self`` is 1x1.

        Raises
        ------
        ValueError
            ("inconsistent shapes") if no supported broadcasting case applies.
        """
        # Scalar multiplication.
        if isscalarlike(other):
            return self._mul_scalar(other)
        # Sparse matrix or vector.
        if isspmatrix(other):
            if self.shape == other.shape:
                other = self.__class__(other)
                return self._binopt(other, '_elmul_')
            # Single element.
            elif other.shape == (1, 1):
                return self._mul_scalar(other.toarray()[0, 0])
            elif self.shape == (1, 1):
                return other._mul_scalar(self.toarray()[0, 0])
            # A row times a column.
            elif self.shape[1] == 1 and other.shape[0] == 1:
                # (M, 1) * (1, N) broadcasts to the outer product self @ other.
                return self._mul_sparse_matrix(other.tocsc())
            elif self.shape[0] == 1 and other.shape[1] == 1:
                # (1, N) * (M, 1): the outer product other @ self.
                return other._mul_sparse_matrix(self.tocsc())
            # Row vector times matrix. other is a row.
            elif other.shape[0] == 1 and self.shape[1] == other.shape[1]:
                # Scale column j of self by other[0, j]: self @ diag(other).
                other = dia_matrix((other.toarray().ravel(), [0]),
                                   shape=(other.shape[1], other.shape[1]))
                return self._mul_sparse_matrix(other)
            # self is a row.
            elif self.shape[0] == 1 and self.shape[1] == other.shape[1]:
                # other @ diag(self).
                copy = dia_matrix((self.toarray().ravel(), [0]),
                                  shape=(self.shape[1], self.shape[1]))
                return other._mul_sparse_matrix(copy)
            # Column vector times matrix. other is a column.
            elif other.shape[1] == 1 and self.shape[0] == other.shape[0]:
                # Scale row i of self by other[i, 0]: diag(other) @ self.
                other = dia_matrix((other.toarray().ravel(), [0]),
                                   shape=(other.shape[0], other.shape[0]))
                return other._mul_sparse_matrix(self)
            # self is a column.
            elif self.shape[1] == 1 and self.shape[0] == other.shape[0]:
                # diag(self) @ other.
                copy = dia_matrix((self.toarray().ravel(), [0]),
                                  shape=(self.shape[0], self.shape[0]))
                return copy._mul_sparse_matrix(other)
            else:
                raise ValueError("inconsistent shapes")

        # Assume other is a dense matrix/array, which produces a single-item
        # object array if other isn't convertible to ndarray.
        other = np.atleast_2d(other)

        if other.ndim != 2:
            return np.multiply(self.toarray(), other)
        # Single element / wrapped object.
        if other.size == 1:
            return self._mul_scalar(other.flat[0])
        # Fast case for trivial sparse matrix.
        elif self.shape == (1, 1):
            return np.multiply(self.toarray()[0, 0], other)

        from .coo import coo_matrix
        ret = self.tocoo()
        # Matching shapes.
        if self.shape == other.shape:
            data = np.multiply(ret.data, other[ret.row, ret.col])
        # Sparse row vector times...
        elif self.shape[0] == 1:
            if other.shape[1] == 1:  # Dense column vector.
                data = np.multiply(ret.data, other)
            elif other.shape[1] == self.shape[1]:  # Dense matrix.
                data = np.multiply(ret.data, other[:, ret.col])
            else:
                raise ValueError("inconsistent shapes")
            # data is (other.shape[0], nnz): the stored columns repeat once
            # per output row, laid out row-major to match ravel().
            row = np.repeat(np.arange(other.shape[0]), len(ret.row))
            col = np.tile(ret.col, other.shape[0])
            return coo_matrix((data.view(np.ndarray).ravel(), (row, col)),
                              shape=(other.shape[0], self.shape[1]),
                              copy=False)
        # Sparse column vector times...
        elif self.shape[1] == 1:
            if other.shape[0] == 1:  # Dense row vector.
                data = np.multiply(ret.data[:, None], other)
            elif other.shape[0] == self.shape[0]:  # Dense matrix.
                data = np.multiply(ret.data[:, None], other[ret.row])
            else:
                raise ValueError("inconsistent shapes")
            # data is (nnz, other.shape[1]): each stored row expands across
            # every output column.
            row = np.repeat(ret.row, other.shape[1])
            col = np.tile(np.arange(other.shape[1]), len(ret.col))
            return coo_matrix((data.view(np.ndarray).ravel(), (row, col)),
                              shape=(self.shape[0], other.shape[1]),
                              copy=False)
        # Sparse matrix times dense row vector.
        elif other.shape[0] == 1 and self.shape[1] == other.shape[1]:
            data = np.multiply(ret.data, other[:, ret.col].ravel())
        # Sparse matrix times dense column vector.
        elif other.shape[1] == 1 and self.shape[0] == other.shape[0]:
            data = np.multiply(ret.data, other[ret.row].ravel())
        else:
            raise ValueError("inconsistent shapes")
        # view(np.ndarray) drops a possible np.matrix subclass so that
        # ravel() yields a 1-D array.
        ret.data = data.view(np.ndarray).ravel()
        return ret
464
465    ###########################
466    # Multiplication handlers #
467    ###########################
468
469    def _mul_vector(self, other):
470        M, N = self.shape
471
472        # output array
473        result = np.zeros(M, dtype=upcast_char(self.dtype.char,
474                                               other.dtype.char))
475
476        # csr_matvec or csc_matvec
477        fn = getattr(_sparsetools, self.format + '_matvec')
478        fn(M, N, self.indptr, self.indices, self.data, other, result)
479
480        return result
481
482    def _mul_multivector(self, other):
483        M, N = self.shape
484        n_vecs = other.shape[1]  # number of column vectors
485
486        result = np.zeros((M, n_vecs),
487                          dtype=upcast_char(self.dtype.char, other.dtype.char))
488
489        # csr_matvecs or csc_matvecs
490        fn = getattr(_sparsetools, self.format + '_matvecs')
491        fn(M, N, n_vecs, self.indptr, self.indices, self.data,
492           other.ravel(), result.ravel())
493
494        return result
495
496    def _mul_sparse_matrix(self, other):
497        M, K1 = self.shape
498        K2, N = other.shape
499
500        major_axis = self._swap((M, N))[0]
501        other = self.__class__(other)  # convert to this format
502
503        idx_dtype = get_index_dtype((self.indptr, self.indices,
504                                     other.indptr, other.indices))
505
506        fn = getattr(_sparsetools, self.format + '_matmat_maxnnz')
507        nnz = fn(M, N,
508                 np.asarray(self.indptr, dtype=idx_dtype),
509                 np.asarray(self.indices, dtype=idx_dtype),
510                 np.asarray(other.indptr, dtype=idx_dtype),
511                 np.asarray(other.indices, dtype=idx_dtype))
512
513        idx_dtype = get_index_dtype((self.indptr, self.indices,
514                                     other.indptr, other.indices),
515                                    maxval=nnz)
516
517        indptr = np.empty(major_axis + 1, dtype=idx_dtype)
518        indices = np.empty(nnz, dtype=idx_dtype)
519        data = np.empty(nnz, dtype=upcast(self.dtype, other.dtype))
520
521        fn = getattr(_sparsetools, self.format + '_matmat')
522        fn(M, N, np.asarray(self.indptr, dtype=idx_dtype),
523           np.asarray(self.indices, dtype=idx_dtype),
524           self.data,
525           np.asarray(other.indptr, dtype=idx_dtype),
526           np.asarray(other.indices, dtype=idx_dtype),
527           other.data,
528           indptr, indices, data)
529
530        return self.__class__((data, indices, indptr), shape=(M, N))
531
532    def diagonal(self, k=0):
533        rows, cols = self.shape
534        if k <= -rows or k >= cols:
535            return np.empty(0, dtype=self.data.dtype)
536        fn = getattr(_sparsetools, self.format + "_diagonal")
537        y = np.empty(min(rows + min(k, 0), cols - max(k, 0)),
538                     dtype=upcast(self.dtype))
539        fn(k, self.shape[0], self.shape[1], self.indptr, self.indices,
540           self.data, y)
541        return y
542
543    diagonal.__doc__ = spmatrix.diagonal.__doc__
544
545    #####################
546    # Other binary ops  #
547    #####################
548
549    def _maximum_minimum(self, other, npop, op_name, dense_check):
550        if isscalarlike(other):
551            if dense_check(other):
552                warn("Taking maximum (minimum) with > 0 (< 0) number results"
553                     " to a dense matrix.", SparseEfficiencyWarning,
554                     stacklevel=3)
555                other_arr = np.empty(self.shape, dtype=np.asarray(other).dtype)
556                other_arr.fill(other)
557                other_arr = self.__class__(other_arr)
558                return self._binopt(other_arr, op_name)
559            else:
560                self.sum_duplicates()
561                new_data = npop(self.data, np.asarray(other))
562                mat = self.__class__((new_data, self.indices, self.indptr),
563                                     dtype=new_data.dtype, shape=self.shape)
564                return mat
565        elif isdense(other):
566            return npop(self.todense(), other)
567        elif isspmatrix(other):
568            return self._binopt(other, op_name)
569        else:
570            raise ValueError("Operands not compatible.")
571
572    def maximum(self, other):
573        return self._maximum_minimum(other, np.maximum,
574                                     '_maximum_', lambda x: np.asarray(x) > 0)
575
576    maximum.__doc__ = spmatrix.maximum.__doc__
577
578    def minimum(self, other):
579        return self._maximum_minimum(other, np.minimum,
580                                     '_minimum_', lambda x: np.asarray(x) < 0)
581
582    minimum.__doc__ = spmatrix.minimum.__doc__
583
584    #####################
585    # Reduce operations #
586    #####################
587
588    def sum(self, axis=None, dtype=None, out=None):
589        """Sum the matrix over the given axis.  If the axis is None, sum
590        over both rows and columns, returning a scalar.
591        """
592        # The spmatrix base class already does axis=0 and axis=1 efficiently
593        # so we only do the case axis=None here
594        if (not hasattr(self, 'blocksize') and
595                axis in self._swap(((1, -1), (0, 2)))[0]):
596            # faster than multiplication for large minor axis in CSC/CSR
597            res_dtype = get_sum_dtype(self.dtype)
598            ret = np.zeros(len(self.indptr) - 1, dtype=res_dtype)
599
600            major_index, value = self._minor_reduce(np.add)
601            ret[major_index] = value
602            ret = asmatrix(ret)
603            if axis % 2 == 1:
604                ret = ret.T
605
606            if out is not None and out.shape != ret.shape:
607                raise ValueError('dimensions do not match')
608
609            return ret.sum(axis=(), dtype=dtype, out=out)
610        # spmatrix will handle the remaining situations when axis
611        # is in {None, -1, 0, 1}
612        else:
613            return spmatrix.sum(self, axis=axis, dtype=dtype, out=out)
614
615    sum.__doc__ = spmatrix.sum.__doc__
616
617    def _minor_reduce(self, ufunc, data=None):
618        """Reduce nonzeros with a ufunc over the minor axis when non-empty
619
620        Can be applied to a function of self.data by supplying data parameter.
621
622        Warning: this does not call sum_duplicates()
623
624        Returns
625        -------
626        major_index : array of ints
627            Major indices where nonzero
628
629        value : array of self.dtype
630            Reduce result for nonzeros in each major_index
631        """
632        if data is None:
633            data = self.data
634        major_index = np.flatnonzero(np.diff(self.indptr))
635        value = ufunc.reduceat(data,
636                               downcast_intp_index(self.indptr[major_index]))
637        return major_index, value
638
639    #######################
640    # Getting and Setting #
641    #######################
642
643    def _get_intXint(self, row, col):
644        M, N = self._swap(self.shape)
645        major, minor = self._swap((row, col))
646        indptr, indices, data = get_csr_submatrix(
647            M, N, self.indptr, self.indices, self.data,
648            major, major + 1, minor, minor + 1)
649        return data.sum(dtype=self.dtype)
650
651    def _get_sliceXslice(self, row, col):
652        major, minor = self._swap((row, col))
653        if major.step in (1, None) and minor.step in (1, None):
654            return self._get_submatrix(major, minor, copy=True)
655        return self._major_slice(major)._minor_slice(minor)
656
657    def _get_arrayXarray(self, row, col):
658        # inner indexing
659        idx_dtype = self.indices.dtype
660        M, N = self._swap(self.shape)
661        major, minor = self._swap((row, col))
662        major = np.asarray(major, dtype=idx_dtype)
663        minor = np.asarray(minor, dtype=idx_dtype)
664
665        val = np.empty(major.size, dtype=self.dtype)
666        csr_sample_values(M, N, self.indptr, self.indices, self.data,
667                          major.size, major.ravel(), minor.ravel(), val)
668        if major.ndim == 1:
669            return asmatrix(val)
670        return self.__class__(val.reshape(major.shape))
671
672    def _get_columnXarray(self, row, col):
673        # outer indexing
674        major, minor = self._swap((row, col))
675        return self._major_index_fancy(major)._minor_index_fancy(minor)
676
677    def _major_index_fancy(self, idx):
678        """Index along the major axis where idx is an array of ints.
679        """
680        idx_dtype = self.indices.dtype
681        indices = np.asarray(idx, dtype=idx_dtype).ravel()
682
683        _, N = self._swap(self.shape)
684        M = len(indices)
685        new_shape = self._swap((M, N))
686        if M == 0:
687            return self.__class__(new_shape)
688
689        row_nnz = np.diff(self.indptr)
690        idx_dtype = self.indices.dtype
691        res_indptr = np.zeros(M+1, dtype=idx_dtype)
692        np.cumsum(row_nnz[idx], out=res_indptr[1:])
693
694        nnz = res_indptr[-1]
695        res_indices = np.empty(nnz, dtype=idx_dtype)
696        res_data = np.empty(nnz, dtype=self.dtype)
697        csr_row_index(M, indices, self.indptr, self.indices, self.data,
698                      res_indices, res_data)
699
700        return self.__class__((res_data, res_indices, res_indptr),
701                              shape=new_shape, copy=False)
702
703    def _major_slice(self, idx, copy=False):
704        """Index along the major axis where idx is a slice object.
705        """
706        if idx == slice(None):
707            return self.copy() if copy else self
708
709        M, N = self._swap(self.shape)
710        start, stop, step = idx.indices(M)
711        M = len(range(start, stop, step))
712        new_shape = self._swap((M, N))
713        if M == 0:
714            return self.__class__(new_shape)
715
716        row_nnz = np.diff(self.indptr)
717        idx_dtype = self.indices.dtype
718        res_indptr = np.zeros(M+1, dtype=idx_dtype)
719        np.cumsum(row_nnz[idx], out=res_indptr[1:])
720
721        if step == 1:
722            all_idx = slice(self.indptr[start], self.indptr[stop])
723            res_indices = np.array(self.indices[all_idx], copy=copy)
724            res_data = np.array(self.data[all_idx], copy=copy)
725        else:
726            nnz = res_indptr[-1]
727            res_indices = np.empty(nnz, dtype=idx_dtype)
728            res_data = np.empty(nnz, dtype=self.dtype)
729            csr_row_slice(start, stop, step, self.indptr, self.indices,
730                          self.data, res_indices, res_data)
731
732        return self.__class__((res_data, res_indices, res_indptr),
733                              shape=new_shape, copy=False)
734
735    def _minor_index_fancy(self, idx):
736        """Index along the minor axis where idx is an array of ints.
737        """
738        idx_dtype = self.indices.dtype
739        idx = np.asarray(idx, dtype=idx_dtype).ravel()
740
741        M, N = self._swap(self.shape)
742        k = len(idx)
743        new_shape = self._swap((M, k))
744        if k == 0:
745            return self.__class__(new_shape)
746
747        # pass 1: count idx entries and compute new indptr
748        col_offsets = np.zeros(N, dtype=idx_dtype)
749        res_indptr = np.empty_like(self.indptr)
750        csr_column_index1(k, idx, M, N, self.indptr, self.indices,
751                          col_offsets, res_indptr)
752
753        # pass 2: copy indices/data for selected idxs
754        col_order = np.argsort(idx).astype(idx_dtype, copy=False)
755        nnz = res_indptr[-1]
756        res_indices = np.empty(nnz, dtype=idx_dtype)
757        res_data = np.empty(nnz, dtype=self.dtype)
758        csr_column_index2(col_order, col_offsets, len(self.indices),
759                          self.indices, self.data, res_indices, res_data)
760        return self.__class__((res_data, res_indices, res_indptr),
761                              shape=new_shape, copy=False)
762
763    def _minor_slice(self, idx, copy=False):
764        """Index along the minor axis where idx is a slice object.
765        """
766        if idx == slice(None):
767            return self.copy() if copy else self
768
769        M, N = self._swap(self.shape)
770        start, stop, step = idx.indices(N)
771        N = len(range(start, stop, step))
772        if N == 0:
773            return self.__class__(self._swap((M, N)))
774        if step == 1:
775            return self._get_submatrix(minor=idx, copy=copy)
776        # TODO: don't fall back to fancy indexing here
777        return self._minor_index_fancy(np.arange(start, stop, step))
778
779    def _get_submatrix(self, major=None, minor=None, copy=False):
780        """Return a submatrix of this matrix.
781
782        major, minor: None, int, or slice with step 1
783        """
784        M, N = self._swap(self.shape)
785        i0, i1 = _process_slice(major, M)
786        j0, j1 = _process_slice(minor, N)
787
788        if i0 == 0 and j0 == 0 and i1 == M and j1 == N:
789            return self.copy() if copy else self
790
791        indptr, indices, data = get_csr_submatrix(
792            M, N, self.indptr, self.indices, self.data, i0, i1, j0, j1)
793
794        shape = self._swap((i1 - i0, j1 - j0))
795        return self.__class__((data, indices, indptr), shape=shape,
796                              dtype=self.dtype, copy=False)
797
798    def _set_intXint(self, row, col, x):
799        i, j = self._swap((row, col))
800        self._set_many(i, j, x)
801
802    def _set_arrayXarray(self, row, col, x):
803        i, j = self._swap((row, col))
804        self._set_many(i, j, x)
805
806    def _set_arrayXarray_sparse(self, row, col, x):
807        # clear entries that will be overwritten
808        self._zero_many(*self._swap((row, col)))
809
810        M, N = row.shape  # matches col.shape
811        broadcast_row = M != 1 and x.shape[0] == 1
812        broadcast_col = N != 1 and x.shape[1] == 1
813        r, c = x.row, x.col
814
815        x = np.asarray(x.data, dtype=self.dtype)
816        if x.size == 0:
817            return
818
819        if broadcast_row:
820            r = np.repeat(np.arange(M), len(r))
821            c = np.tile(c, M)
822            x = np.tile(x, M)
823        if broadcast_col:
824            r = np.repeat(r, N)
825            c = np.tile(np.arange(N), len(c))
826            x = np.repeat(x, N)
827        # only assign entries in the new sparsity structure
828        i, j = self._swap((row[r, c], col[r, c]))
829        self._set_many(i, j, x)
830
831    def _setdiag(self, values, k):
832        if 0 in self.shape:
833            return
834
835        M, N = self.shape
836        broadcast = (values.ndim == 0)
837
838        if k < 0:
839            if broadcast:
840                max_index = min(M + k, N)
841            else:
842                max_index = min(M + k, N, len(values))
843            i = np.arange(max_index, dtype=self.indices.dtype)
844            j = np.arange(max_index, dtype=self.indices.dtype)
845            i -= k
846
847        else:
848            if broadcast:
849                max_index = min(M, N - k)
850            else:
851                max_index = min(M, N - k, len(values))
852            i = np.arange(max_index, dtype=self.indices.dtype)
853            j = np.arange(max_index, dtype=self.indices.dtype)
854            j += k
855
856        if not broadcast:
857            values = values[:len(i)]
858
859        self[i, j] = values
860
861    def _prepare_indices(self, i, j):
862        M, N = self._swap(self.shape)
863
864        def check_bounds(indices, bound):
865            idx = indices.max()
866            if idx >= bound:
867                raise IndexError('index (%d) out of range (>= %d)' %
868                                 (idx, bound))
869            idx = indices.min()
870            if idx < -bound:
871                raise IndexError('index (%d) out of range (< -%d)' %
872                                 (idx, bound))
873
874        i = np.array(i, dtype=self.indices.dtype, copy=False, ndmin=1).ravel()
875        j = np.array(j, dtype=self.indices.dtype, copy=False, ndmin=1).ravel()
876        check_bounds(i, M)
877        check_bounds(j, N)
878        return i, j, M, N
879
880    def _set_many(self, i, j, x):
881        """Sets value at each (i, j) to x
882
883        Here (i,j) index major and minor respectively, and must not contain
884        duplicate entries.
885        """
886        i, j, M, N = self._prepare_indices(i, j)
887        x = np.array(x, dtype=self.dtype, copy=False, ndmin=1).ravel()
888
889        n_samples = x.size
890        offsets = np.empty(n_samples, dtype=self.indices.dtype)
891        ret = csr_sample_offsets(M, N, self.indptr, self.indices, n_samples,
892                                 i, j, offsets)
893        if ret == 1:
894            # rinse and repeat
895            self.sum_duplicates()
896            csr_sample_offsets(M, N, self.indptr, self.indices, n_samples,
897                               i, j, offsets)
898
899        if -1 not in offsets:
900            # only affects existing non-zero cells
901            self.data[offsets] = x
902            return
903
904        else:
905            warn("Changing the sparsity structure of a {}_matrix is expensive."
906                 " lil_matrix is more efficient.".format(self.format),
907                 SparseEfficiencyWarning, stacklevel=3)
908            # replace where possible
909            mask = offsets > -1
910            self.data[offsets[mask]] = x[mask]
911            # only insertions remain
912            mask = ~mask
913            i = i[mask]
914            i[i < 0] += M
915            j = j[mask]
916            j[j < 0] += N
917            self._insert_many(i, j, x[mask])
918
919    def _zero_many(self, i, j):
920        """Sets value at each (i, j) to zero, preserving sparsity structure.
921
922        Here (i,j) index major and minor respectively.
923        """
924        i, j, M, N = self._prepare_indices(i, j)
925
926        n_samples = len(i)
927        offsets = np.empty(n_samples, dtype=self.indices.dtype)
928        ret = csr_sample_offsets(M, N, self.indptr, self.indices, n_samples,
929                                 i, j, offsets)
930        if ret == 1:
931            # rinse and repeat
932            self.sum_duplicates()
933            csr_sample_offsets(M, N, self.indptr, self.indices, n_samples,
934                               i, j, offsets)
935
936        # only assign zeros to the existing sparsity structure
937        self.data[offsets[offsets > -1]] = 0
938
    def _insert_many(self, i, j, x):
        """Inserts new nonzero at each (i, j) with value x

        Here (i,j) index major and minor respectively.
        i, j and x must be non-empty, 1d arrays.
        Inserts each major group (e.g. all entries per row) at a time.
        Maintains has_sorted_indices property.
        If the same (i, j) occurs more than once, the value of its last
        occurrence is kept.
        The caller's i, j and x are not modified: they are reordered with
        ``take``, which returns new arrays. ``self.indptr``, ``self.indices``
        and ``self.data`` are reassigned, and the index arrays may be widened
        to a larger integer dtype.
        """
        # Stable sort by major index so that entries with equal i keep their
        # original relative order; the "keep last setting" step relies on it.
        order = np.argsort(i, kind='mergesort')  # stable for duplicates
        i = i.take(order, mode='clip')
        j = j.take(order, mode='clip')
        x = x.take(order, mode='clip')

        do_sort = self.has_sorted_indices

        # Update index data type
        # (widen if the post-insertion nnz upper bound requires it)
        idx_dtype = get_index_dtype((self.indices, self.indptr),
                                    maxval=(self.indptr[-1] + x.size))
        self.indptr = np.asarray(self.indptr, dtype=idx_dtype)
        self.indices = np.asarray(self.indices, dtype=idx_dtype)
        i = np.asarray(i, dtype=idx_dtype)
        j = np.asarray(j, dtype=idx_dtype)

        # Collate old and new in chunks by major index
        indices_parts = []
        data_parts = []
        # ui: distinct major indices being inserted into (sorted).
        # i/j/x[ui_indptr[c]:ui_indptr[c+1]] are the new entries for ui[c].
        ui, ui_indptr = np.unique(i, return_index=True)
        ui_indptr = np.append(ui_indptr, len(j))
        new_nnzs = np.diff(ui_indptr)
        prev = 0
        for c, (ii, js, je) in enumerate(zip(ui, ui_indptr, ui_indptr[1:])):
            # old entries
            # Old entries of major vectors prev .. ii-1. After the first
            # iteration this range begins with the previously inserted
            # vector, so its old entries follow its new ones (the result is
            # re-sorted below if the indices were sorted on entry).
            start = self.indptr[prev]
            stop = self.indptr[ii]
            indices_parts.append(self.indices[start:stop])
            data_parts.append(self.data[start:stop])

            # handle duplicate j: keep last setting
            # np.unique reports the first occurrence in the reversed run,
            # which is the last occurrence in the original order.
            uj, uj_indptr = np.unique(j[js:je][::-1], return_index=True)
            if len(uj) == je - js:
                indices_parts.append(j[js:je])
                data_parts.append(x[js:je])
            else:
                indices_parts.append(j[js:je][::-1][uj_indptr])
                data_parts.append(x[js:je][::-1][uj_indptr])
                new_nnzs[c] = len(uj)

            prev = ii

        # remaining old entries
        # (including the old entries of the last inserted major vector)
        start = self.indptr[ii]
        indices_parts.append(self.indices[start:])
        data_parts.append(self.data[start:])

        # update attributes
        self.indices = np.concatenate(indices_parts)
        self.data = np.concatenate(data_parts)
        # New indptr: old per-vector counts plus the number of inserted
        # entries for each vector in ui, accumulated.
        nnzs = np.empty(self.indptr.shape, dtype=idx_dtype)
        nnzs[0] = idx_dtype(0)
        indptr_diff = np.diff(self.indptr)
        indptr_diff[ui] += new_nnzs
        nnzs[1:] = indptr_diff
        self.indptr = np.cumsum(nnzs, out=nnzs)

        if do_sort:
            # TODO: only sort where necessary
            self.has_sorted_indices = False
            self.sort_indices()

        self.check_format(full_check=False)
1010
1011    ######################
1012    # Conversion methods #
1013    ######################
1014
1015    def tocoo(self, copy=True):
1016        major_dim, minor_dim = self._swap(self.shape)
1017        minor_indices = self.indices
1018        major_indices = np.empty(len(minor_indices), dtype=self.indices.dtype)
1019        _sparsetools.expandptr(major_dim, self.indptr, major_indices)
1020        row, col = self._swap((major_indices, minor_indices))
1021
1022        from .coo import coo_matrix
1023        return coo_matrix((self.data, (row, col)), self.shape, copy=copy,
1024                          dtype=self.dtype)
1025
1026    tocoo.__doc__ = spmatrix.tocoo.__doc__
1027
1028    def toarray(self, order=None, out=None):
1029        if out is None and order is None:
1030            order = self._swap('cf')[0]
1031        out = self._process_toarray_args(order, out)
1032        if not (out.flags.c_contiguous or out.flags.f_contiguous):
1033            raise ValueError('Output array must be C or F contiguous')
1034        # align ideal order with output array order
1035        if out.flags.c_contiguous:
1036            x = self.tocsr()
1037            y = out
1038        else:
1039            x = self.tocsc()
1040            y = out.T
1041        M, N = x._swap(x.shape)
1042        csr_todense(M, N, x.indptr, x.indices, x.data, y)
1043        return out
1044
1045    toarray.__doc__ = spmatrix.toarray.__doc__
1046
1047    ##############################################################
1048    # methods that examine or modify the internal data structure #
1049    ##############################################################
1050
1051    def eliminate_zeros(self):
1052        """Remove zero entries from the matrix
1053
1054        This is an *in place* operation.
1055        """
1056        M, N = self._swap(self.shape)
1057        _sparsetools.csr_eliminate_zeros(M, N, self.indptr, self.indices,
1058                                         self.data)
1059        self.prune()  # nnz may have changed
1060
1061    def __get_has_canonical_format(self):
1062        """Determine whether the matrix has sorted indices and no duplicates
1063
1064        Returns
1065            - True: if the above applies
1066            - False: otherwise
1067
1068        has_canonical_format implies has_sorted_indices, so if the latter flag
1069        is False, so will the former be; if the former is found True, the
1070        latter flag is also set.
1071        """
1072
1073        # first check to see if result was cached
1074        if not getattr(self, '_has_sorted_indices', True):
1075            # not sorted => not canonical
1076            self._has_canonical_format = False
1077        elif not hasattr(self, '_has_canonical_format'):
1078            self.has_canonical_format = bool(
1079                _sparsetools.csr_has_canonical_format(
1080                    len(self.indptr) - 1, self.indptr, self.indices))
1081        return self._has_canonical_format
1082
1083    def __set_has_canonical_format(self, val):
1084        self._has_canonical_format = bool(val)
1085        if val:
1086            self.has_sorted_indices = True
1087
1088    has_canonical_format = property(fget=__get_has_canonical_format,
1089                                    fset=__set_has_canonical_format)
1090
1091    def sum_duplicates(self):
1092        """Eliminate duplicate matrix entries by adding them together
1093
1094        This is an *in place* operation.
1095        """
1096        if self.has_canonical_format:
1097            return
1098        self.sort_indices()
1099
1100        M, N = self._swap(self.shape)
1101        _sparsetools.csr_sum_duplicates(M, N, self.indptr, self.indices,
1102                                        self.data)
1103
1104        self.prune()  # nnz may have changed
1105        self.has_canonical_format = True
1106
1107    def __get_sorted(self):
1108        """Determine whether the matrix has sorted indices
1109
1110        Returns
1111            - True: if the indices of the matrix are in sorted order
1112            - False: otherwise
1113
1114        """
1115
1116        # first check to see if result was cached
1117        if not hasattr(self, '_has_sorted_indices'):
1118            self._has_sorted_indices = bool(
1119                _sparsetools.csr_has_sorted_indices(
1120                    len(self.indptr) - 1, self.indptr, self.indices))
1121        return self._has_sorted_indices
1122
1123    def __set_sorted(self, val):
1124        self._has_sorted_indices = bool(val)
1125
1126    has_sorted_indices = property(fget=__get_sorted, fset=__set_sorted)
1127
1128    def sorted_indices(self):
1129        """Return a copy of this matrix with sorted indices
1130        """
1131        A = self.copy()
1132        A.sort_indices()
1133        return A
1134
1135        # an alternative that has linear complexity is the following
1136        # although the previous option is typically faster
1137        # return self.toother().toother()
1138
1139    def sort_indices(self):
1140        """Sort the indices of this matrix *in place*
1141        """
1142
1143        if not self.has_sorted_indices:
1144            _sparsetools.csr_sort_indices(len(self.indptr) - 1, self.indptr,
1145                                          self.indices, self.data)
1146            self.has_sorted_indices = True
1147
1148    def prune(self):
1149        """Remove empty space after all non-zero elements.
1150        """
1151        major_dim = self._swap(self.shape)[0]
1152
1153        if len(self.indptr) != major_dim + 1:
1154            raise ValueError('index pointer has invalid length')
1155        if len(self.indices) < self.nnz:
1156            raise ValueError('indices array has fewer than nnz elements')
1157        if len(self.data) < self.nnz:
1158            raise ValueError('data array has fewer than nnz elements')
1159
1160        self.indices = _prune_array(self.indices[:self.nnz])
1161        self.data = _prune_array(self.data[:self.nnz])
1162
    def resize(self, *shape):
        # Resize in place to `shape`. Entries outside the new shape are
        # discarded, and newly added major-axis vectors are empty. The public
        # docstring is copied from spmatrix.resize right after this def.
        shape = check_shape(shape)
        if hasattr(self, 'blocksize'):
            # Block storage: sizes are counted in blocks, so each new
            # dimension must be a whole multiple of the block size. No
            # _swap is applied on this path.
            bm, bn = self.blocksize
            new_M, rm = divmod(shape[0], bm)
            new_N, rn = divmod(shape[1], bn)
            if rm or rn:
                raise ValueError("shape must be divisible into %s blocks. "
                                 "Got %s" % (self.blocksize, shape))
            M, N = self.shape[0] // bm, self.shape[1] // bn
        else:
            new_M, new_N = self._swap(shape)
            M, N = self._swap(self.shape)

        if new_M < M:
            # Shrink the major axis: keep only the entries of the first
            # new_M major-axis vectors. NOTE: these are slices, i.e. views
            # of the previous arrays, so later in-place writes to indptr
            # go through to that storage.
            self.indices = self.indices[:self.indptr[new_M]]
            self.data = self.data[:self.indptr[new_M]]
            self.indptr = self.indptr[:new_M + 1]
        elif new_M > M:
            # Grow the major axis: the appended vectors are empty, so their
            # indptr entries all repeat the old end offset (overwriting
            # whatever np.resize filled in).
            self.indptr = np.resize(self.indptr, new_M + 1)
            self.indptr[M + 1:].fill(self.indptr[M])

        if new_N < N:
            # Shrink the minor axis: drop entries whose minor index is now
            # out of range, then rebuild indptr from per-vector counts.
            mask = self.indices < new_N
            if not np.all(mask):
                self.indices = self.indices[mask]
                self.data = self.data[mask]
                # self.indptr still describes the unfiltered layout here.
                # _minor_reduce (defined elsewhere) is relied on to sum
                # `mask` over each segment of that indptr, i.e. to count
                # the kept entries per major-axis vector. It must therefore
                # run before indptr is overwritten below.
                major_index, val = self._minor_reduce(np.add, mask)
                self.indptr.fill(0)
                self.indptr[1:][major_index] = val
                np.cumsum(self.indptr, out=self.indptr)

        self._shape = shape

    resize.__doc__ = spmatrix.resize.__doc__
1198
1199    ###################
1200    # utility methods #
1201    ###################
1202
1203    # needed by _data_matrix
1204    def _with_data(self, data, copy=True):
1205        """Returns a matrix with the same sparsity structure as self,
1206        but with different data.  By default the structure arrays
1207        (i.e. .indptr and .indices) are copied.
1208        """
1209        if copy:
1210            return self.__class__((data, self.indices.copy(),
1211                                   self.indptr.copy()),
1212                                  shape=self.shape,
1213                                  dtype=data.dtype)
1214        else:
1215            return self.__class__((data, self.indices, self.indptr),
1216                                  shape=self.shape, dtype=data.dtype)
1217
1218    def _binopt(self, other, op):
1219        """apply the binary operation fn to two sparse matrices."""
1220        other = self.__class__(other)
1221
1222        # e.g. csr_plus_csr, csr_minus_csr, etc.
1223        fn = getattr(_sparsetools, self.format + op + self.format)
1224
1225        maxnnz = self.nnz + other.nnz
1226        idx_dtype = get_index_dtype((self.indptr, self.indices,
1227                                     other.indptr, other.indices),
1228                                    maxval=maxnnz)
1229        indptr = np.empty(self.indptr.shape, dtype=idx_dtype)
1230        indices = np.empty(maxnnz, dtype=idx_dtype)
1231
1232        bool_ops = ['_ne_', '_lt_', '_gt_', '_le_', '_ge_']
1233        if op in bool_ops:
1234            data = np.empty(maxnnz, dtype=np.bool_)
1235        else:
1236            data = np.empty(maxnnz, dtype=upcast(self.dtype, other.dtype))
1237
1238        fn(self.shape[0], self.shape[1],
1239           np.asarray(self.indptr, dtype=idx_dtype),
1240           np.asarray(self.indices, dtype=idx_dtype),
1241           self.data,
1242           np.asarray(other.indptr, dtype=idx_dtype),
1243           np.asarray(other.indices, dtype=idx_dtype),
1244           other.data,
1245           indptr, indices, data)
1246
1247        A = self.__class__((data, indices, indptr), shape=self.shape)
1248        A.prune()
1249
1250        return A
1251
1252    def _divide_sparse(self, other):
1253        """
1254        Divide this matrix by a second sparse matrix.
1255        """
1256        if other.shape != self.shape:
1257            raise ValueError('inconsistent shapes')
1258
1259        r = self._binopt(other, '_eldiv_')
1260
1261        if np.issubdtype(r.dtype, np.inexact):
1262            # Eldiv leaves entries outside the combined sparsity
1263            # pattern empty, so they must be filled manually.
1264            # Everything outside of other's sparsity is NaN, and everything
1265            # inside it is either zero or defined by eldiv.
1266            out = np.empty(self.shape, dtype=self.dtype)
1267            out.fill(np.nan)
1268            row, col = other.nonzero()
1269            out[row, col] = 0
1270            r = r.tocoo()
1271            out[r.row, r.col] = r.data
1272            out = matrix(out)
1273        else:
1274            # integers types go with nan <-> 0
1275            out = r
1276
1277        return out
1278
1279
def _process_slice(sl, num):
    """Turn an axis selector into a half-open ``(start, stop)`` range.

    Parameters
    ----------
    sl : None, slice, or integer-like
        ``None`` selects the whole axis. A slice must have step 1 once
        normalized by ``slice.indices``; if its start exceeds its stop, the
        range is empty. An integer selects one position and may be negative.
    num : int
        Length of the axis.

    Raises
    ------
    ValueError
        For a slice whose step is not 1.
    IndexError
        For an integer outside ``[-num, num)``.
    TypeError
        For any other kind of selector.
    """
    if sl is None:
        return 0, num

    if isinstance(sl, slice):
        start, stop, step = sl.indices(num)
        if step != 1:
            raise ValueError('slicing with step != 1 not supported')
        # Collapse a start-after-stop slice to the empty range [stop, stop).
        return min(start, stop), stop

    if not isintlike(sl):
        raise TypeError('expected slice or scalar')

    pos = sl + num if sl < 0 else sl
    if pos < 0 or pos + 1 > num:
        raise IndexError('index out of bounds: 0 <= %d < %d <= %d' %
                         (pos, pos + 1, num))
    return pos, pos + 1
1299