1"""
2
3Module for the SDM class.
4
5"""
6
7from operator import add, neg, pos, sub, mul
8from collections import defaultdict
9
10from sympy.utilities.iterables import _strongly_connected_components
11
12from .exceptions import DDMBadInputError, DDMDomainError, DDMShapeError
13
14from .ddm import DDM
15
16
class SDM(dict):
    r"""Sparse matrix based on polys domain elements

    This is a dict subclass and is a wrapper for a dict of dicts that supports
    basic matrix arithmetic +, -, *.

    The outer dict maps a row index to a row dict and each row dict maps a
    column index to the element stored at that position. Positions that are
    not stored are zero. Zero elements are not meant to be stored:
    ``from_list``, ``setitem`` and the arithmetic methods drop them.

    In order to create a new :py:class:`~.SDM`, a dict
    of dicts mapping non-zero elements to their
    corresponding row and column in the matrix is needed.

    We also need to specify the shape and :py:class:`~.Domain`
    of our :py:class:`~.SDM` object.

    We declare a 2x2 :py:class:`~.SDM` matrix belonging
    to QQ domain as shown below.
    The 2x2 Matrix in the example is

    .. math::
           A = \left[\begin{array}{cc}
                0 & \frac{1}{2} \\
                0 & 0 \end{array} \right]


    >>> from sympy.polys.matrices.sdm import SDM
    >>> from sympy import QQ
    >>> elemsdict = {0:{1:QQ(1, 2)}}
    >>> A = SDM(elemsdict, (2, 2), QQ)
    >>> A
    {0: {1: 1/2}}

    We can manipulate :py:class:`~.SDM` the same way
    as a Matrix class

    >>> from sympy import ZZ
    >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
    >>> B  = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ)
    >>> A + B
    {0: {0: 3, 1: 2}, 1: {0: 1, 1: 4}}

    Multiplication

    >>> A*B
    {0: {1: 8}, 1: {0: 3}}
    >>> A*ZZ(2)
    {0: {1: 4}, 1: {0: 2}}

    """

    fmt = 'sparse'

    def __init__(self, elemsdict, shape, domain):
        # elemsdict is a dict of dicts {row: {col: element}} and shape is
        # (rows, cols). The dict contents are not filtered for zeros here;
        # only the indices are validated against the shape.
        super().__init__(elemsdict)
        self.shape = self.rows, self.cols = m, n = shape
        self.domain = domain

        if not all(0 <= r < m for r in self):
            raise DDMBadInputError("Row out of range")
        if not all(0 <= c < n for row in self.values() for c in row):
            raise DDMBadInputError("Column out of range")

    def getitem(self, i, j):
        """Return the element at row *i*, column *j*.

        Negative indices count from the end as for Python sequences. A
        position that is not stored returns ``self.domain.zero``.

        Raises
        ======

        IndexError
            If ``(i, j)`` lies outside the shape of the matrix.
        """
        try:
            return self[i][j]
        except KeyError:
            m, n = self.shape
            if -m <= i < m and -n <= j < n:
                try:
                    return self[i % m][j % n]
                except KeyError:
                    return self.domain.zero
            else:
                raise IndexError("index out of range")

    def setitem(self, i, j, value):
        """Set the element at row *i*, column *j* to *value* in place.

        Negative indices are normalised as for Python sequences. A falsy
        (zero) *value* removes the stored entry, and the row dict is removed
        as well if it becomes empty.

        Raises
        ======

        IndexError
            If ``(i, j)`` lies outside the shape of the matrix.
        """
        m, n = self.shape
        if not (-m <= i < m and -n <= j < n):
            raise IndexError("index out of range")
        i, j = i % m, j % n
        if value:
            try:
                self[i][j] = value
            except KeyError:
                self[i] = {j: value}
        else:
            rowi = self.get(i, None)
            if rowi is not None:
                try:
                    del rowi[j]
                except KeyError:
                    pass
                else:
                    if not rowi:
                        del self[i]

    def extract_slice(self, slice1, slice2):
        """Return the submatrix selected by the row slice *slice1* and the
        column slice *slice2* as a new :py:class:`~.SDM`."""
        m, n = self.shape
        ri = range(m)[slice1]
        ci = range(n)[slice2]

        sdm = {}
        for i, row in self.items():
            if i in ri:
                row = {ci.index(j): e for j, e in row.items() if j in ci}
                if row:
                    sdm[ri.index(i)] = row

        return self.new(sdm, (len(ri), len(ci)), self.domain)

    def extract(self, rows, cols):
        """Return the submatrix formed from the given row and column indices.

        *rows* and *cols* are sequences of (possibly negative, possibly
        repeated) indices. The result has shape ``(len(rows), len(cols))``.

        Raises
        ======

        IndexError
            If any requested index is out of range. This is checked even when
            the matrix has no stored entries.
        """
        m, n = self.shape
        # Validate first so that an all-zero matrix rejects bad indices too.
        if rows and not (-m <= min(rows) <= max(rows) < m):
            raise IndexError('Row index out of range')
        if cols and not (-n <= min(cols) <= max(cols) < n):
            raise IndexError('Column index out of range')

        if not (self and rows and cols):
            return self.zeros((len(rows), len(cols)), self.domain)

        # rows and cols can contain duplicates e.g. M[[1, 2, 2], [0, 1]]
        # Build a map from row/col in self to list of rows/cols in output
        rowmap = defaultdict(list)
        colmap = defaultdict(list)
        for i2, i1 in enumerate(rows):
            rowmap[i1 % m].append(i2)
        for j2, j1 in enumerate(cols):
            colmap[j1 % n].append(j2)

        # Used to efficiently skip zero rows/cols
        rowset = set(rowmap)
        colset = set(colmap)

        sdm1 = self
        sdm2 = {}
        for i1 in rowset & set(sdm1):
            row1 = sdm1[i1]
            row2 = {}
            for j1 in colset & set(row1):
                row1_j1 = row1[j1]
                for j2 in colmap[j1]:
                    row2[j2] = row1_j1
            if row2:
                for i2 in rowmap[i1]:
                    # Each output row gets its own dict (no aliasing).
                    sdm2[i2] = row2.copy()

        return self.new(sdm2, (len(rows), len(cols)), self.domain)

    def __str__(self):
        rowsstr = []
        for i, row in self.items():
            elemsstr = ', '.join('%s: %s' % (j, elem) for j, elem in row.items())
            rowsstr.append('%s: {%s}' % (i, elemsstr))
        return '{%s}' % ', '.join(rowsstr)

    def __repr__(self):
        cls = type(self).__name__
        rows = dict.__repr__(self)
        return '%s(%s, %s, %s)' % (cls, rows, self.shape, self.domain)

    @classmethod
    def new(cls, sdm, shape, domain):
        """

        Parameters
        ==========

        sdm: A dict of dicts for non-zero elements in SDM
        shape: tuple representing dimension of SDM
        domain: Represents :py:class:`~.Domain` of SDM

        Returns
        =======

        An :py:class:`~.SDM` object

        Examples
        ========

        >>> from sympy.polys.matrices.sdm import SDM
        >>> from sympy import QQ
        >>> elemsdict = {0:{1: QQ(2)}}
        >>> A = SDM.new(elemsdict, (2, 2), QQ)
        >>> A
        {0: {1: 2}}

        """
        return cls(sdm, shape, domain)

    def copy(A):
        """
        Returns the copy of a :py:class:`~.SDM` object

        The row dicts are copied, so mutating the copy does not affect *A*.

        Examples
        ========

        >>> from sympy.polys.matrices.sdm import SDM
        >>> from sympy import QQ
        >>> elemsdict = {0:{1:QQ(2)}, 1:{}}
        >>> A = SDM(elemsdict, (2, 2), QQ)
        >>> B = A.copy()
        >>> B
        {0: {1: 2}, 1: {}}

        """
        Ac = {i: Ai.copy() for i, Ai in A.items()}
        return A.new(Ac, A.shape, A.domain)

    @classmethod
    def from_list(cls, ddm, shape, domain):
        """

        Parameters
        ==========

        ddm:
            list of lists containing domain elements
        shape:
            Dimensions of :py:class:`~.SDM` matrix
        domain:
            Represents :py:class:`~.Domain` of :py:class:`~.SDM` object

        Returns
        =======

        :py:class:`~.SDM` containing elements of ddm

        Examples
        ========

        >>> from sympy.polys.matrices.sdm import SDM
        >>> from sympy import QQ
        >>> ddm = [[QQ(1, 2), QQ(0)], [QQ(0), QQ(3, 4)]]
        >>> A = SDM.from_list(ddm, (2, 2), QQ)
        >>> A
        {0: {0: 1/2}, 1: {1: 3/4}}

        """

        m, n = shape
        if not (len(ddm) == m and all(len(row) == n for row in ddm)):
            raise DDMBadInputError("Inconsistent row-list/shape")
        getrow = lambda i: {j:ddm[i][j] for j in range(n) if ddm[i][j]}
        irows = ((i, getrow(i)) for i in range(m))
        sdm = {i: row for i, row in irows if row}
        return cls(sdm, shape, domain)

    @classmethod
    def from_ddm(cls, ddm):
        """
        converts object of :py:class:`~.DDM` to
        :py:class:`~.SDM`

        Examples
        ========

        >>> from sympy.polys.matrices.ddm import DDM
        >>> from sympy.polys.matrices.sdm import SDM
        >>> from sympy import QQ
        >>> ddm = DDM( [[QQ(1, 2), 0], [0, QQ(3, 4)]], (2, 2), QQ)
        >>> A = SDM.from_ddm(ddm)
        >>> A
        {0: {0: 1/2}, 1: {1: 3/4}}

        """
        return cls.from_list(ddm, ddm.shape, ddm.domain)

    def to_list(M):
        """

        Converts a :py:class:`~.SDM` object to a list

        Examples
        ========

        >>> from sympy.polys.matrices.sdm import SDM
        >>> from sympy import QQ
        >>> elemsdict = {0:{1:QQ(2)}, 1:{}}
        >>> A = SDM(elemsdict, (2, 2), QQ)
        >>> A.to_list()
        [[0, 2], [0, 0]]

        """
        m, n = M.shape
        zero = M.domain.zero
        ddm = [[zero] * n for _ in range(m)]
        for i, row in M.items():
            for j, e in row.items():
                ddm[i][j] = e
        return ddm

    def to_list_flat(M):
        """Return all elements as a flat row-major list of length rows*cols,
        with ``M.domain.zero`` at positions that are not stored."""
        m, n = M.shape
        zero = M.domain.zero
        flat = [zero] * (m * n)
        for i, row in M.items():
            for j, e in row.items():
                flat[i*n + j] = e
        return flat

    def to_dok(M):
        """Return a dict mapping ``(i, j)`` tuples to the stored elements."""
        return {(i, j): e for i, row in M.items() for j, e in row.items()}

    def to_ddm(M):
        """
        Convert a :py:class:`~.SDM` object to a :py:class:`~.DDM` object

        Examples
        ========

        >>> from sympy.polys.matrices.sdm import SDM
        >>> from sympy import QQ
        >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ)
        >>> A.to_ddm()
        [[0, 2], [0, 0]]

        """
        return DDM(M.to_list(), M.shape, M.domain)

    def to_sdm(M):
        """Return *M* itself; it is already an :py:class:`~.SDM`."""
        return M

    @classmethod
    def zeros(cls, shape, domain):
        r"""

        Returns a :py:class:`~.SDM` of size shape,
        belonging to the specified domain

        In the example below we declare a matrix A where,

        .. math::
            A := \left[\begin{array}{ccc}
            0 & 0 & 0 \\
            0 & 0 & 0 \end{array} \right]

        >>> from sympy.polys.matrices.sdm import SDM
        >>> from sympy import QQ
        >>> A = SDM.zeros((2, 3), QQ)
        >>> A
        {}

        """
        return cls({}, shape, domain)

    @classmethod
    def ones(cls, shape, domain):
        """Return an :py:class:`~.SDM` of the given shape with every element
        equal to ``domain.one``.

        A matrix with no columns stores no rows at all, so that no empty row
        dicts are kept.
        """
        one = domain.one
        m, n = shape
        if not n:
            return cls({}, shape, domain)
        row = dict(zip(range(n), [one]*n))
        sdm = {i: row.copy() for i in range(m)}
        return cls(sdm, shape, domain)

    @classmethod
    def eye(cls, shape, domain):
        """

        Returns an identity :py:class:`~.SDM` matrix of the given shape
        ``(rows, cols)``, belonging to the specified domain. Ones are placed
        on the main diagonal, ``min(rows, cols)`` of them.

        Examples
        ========

        >>> from sympy.polys.matrices.sdm import SDM
        >>> from sympy import QQ
        >>> I = SDM.eye((2, 2), QQ)
        >>> I
        {0: {0: 1}, 1: {1: 1}}

        """
        rows, cols = shape
        one = domain.one
        sdm = {i: {i: one} for i in range(min(rows, cols))}
        return cls(sdm, shape, domain)

    @classmethod
    def diag(cls, diagonal, domain, shape):
        """Return an :py:class:`~.SDM` of the given shape whose main diagonal
        holds the elements of *diagonal*. Zero entries are not stored."""
        sdm = {i: {i: v} for i, v in enumerate(diagonal) if v}
        return cls(sdm, shape, domain)

    def transpose(M):
        """

        Returns the transpose of a :py:class:`~.SDM` matrix

        Examples
        ========

        >>> from sympy.polys.matrices.sdm import SDM
        >>> from sympy import QQ
        >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ)
        >>> A.transpose()
        {1: {0: 2}}

        """
        MT = sdm_transpose(M)
        return M.new(MT, M.shape[::-1], M.domain)

    def __add__(A, B):
        """A + B for two :py:class:`~.SDM` matrices."""
        if not isinstance(B, SDM):
            return NotImplemented
        return A.add(B)

    def __sub__(A, B):
        """A - B for two :py:class:`~.SDM` matrices."""
        if not isinstance(B, SDM):
            return NotImplemented
        return A.sub(B)

    def __neg__(A):
        """-A"""
        return A.neg()

    def __mul__(A, B):
        """A * B (matrix product for an SDM *B*, scalar product for a domain
        element *B*)."""
        if isinstance(B, SDM):
            return A.matmul(B)
        elif B in A.domain:
            return A.mul(B)
        else:
            return NotImplemented

    def __rmul__(a, b):
        """b * a for a domain element *b*."""
        if b in a.domain:
            return a.rmul(b)
        else:
            return NotImplemented

    def matmul(A, B):
        """
        Performs matrix multiplication of two SDM matrices

        Parameters
        ==========

        A, B: SDM to multiply

        Returns
        =======

        SDM
            SDM after multiplication

        Raises
        ======

        DDMDomainError
            If domain of A does not match
            with that of B

        DDMShapeError
            If the number of columns of A differs from the
            number of rows of B

        Examples
        ========

        >>> from sympy import ZZ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
        >>> B = SDM({0:{0:ZZ(2), 1:ZZ(3)}, 1:{0:ZZ(4)}}, (2, 2), ZZ)
        >>> A.matmul(B)
        {0: {0: 8}, 1: {0: 2, 1: 3}}

        """
        if A.domain != B.domain:
            raise DDMDomainError
        m, n = A.shape
        n2, o = B.shape
        if n != n2:
            raise DDMShapeError
        C = sdm_matmul(A, B, A.domain, m, o)
        return A.new(C, (m, o), A.domain)

    def mul(A, b):
        """
        Multiplies each element of A with a scalar b

        Examples
        ========

        >>> from sympy import ZZ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
        >>> A.mul(ZZ(3))
        {0: {1: 6}, 1: {0: 3}}

        """
        Csdm = unop_dict(A, lambda aij: aij*b)
        return A.new(Csdm, A.shape, A.domain)

    def rmul(A, b):
        """Multiplies each element of A by the scalar b from the left
        (``b*aij``). Products that are zero are not stored."""
        Csdm = unop_dict(A, lambda aij: b*aij)
        return A.new(Csdm, A.shape, A.domain)

    def mul_elementwise(A, B):
        """Elementwise product of two :py:class:`~.SDM` matrices.

        Only positions stored in both A and B can be nonzero in the result.

        Raises
        ======

        DDMDomainError
            If the domains of A and B differ.
        DDMShapeError
            If the shapes of A and B differ.
        """
        if A.domain != B.domain:
            raise DDMDomainError
        if A.shape != B.shape:
            raise DDMShapeError
        zero = A.domain.zero
        fzero = lambda e: zero
        Csdm = binop_dict(A, B, mul, fzero, fzero)
        return A.new(Csdm, A.shape, A.domain)

    def add(A, B):
        """

        Adds two :py:class:`~.SDM` matrices

        Raises DDMDomainError if the domains differ and DDMShapeError if the
        shapes differ.

        Examples
        ========

        >>> from sympy import ZZ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
        >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ)
        >>> A.add(B)
        {0: {0: 3, 1: 2}, 1: {0: 1, 1: 4}}

        """
        if A.domain != B.domain:
            raise DDMDomainError
        if A.shape != B.shape:
            raise DDMShapeError
        Csdm = binop_dict(A, B, add, pos, pos)
        return A.new(Csdm, A.shape, A.domain)

    def sub(A, B):
        """

        Subtracts two :py:class:`~.SDM` matrices

        Raises DDMDomainError if the domains differ and DDMShapeError if the
        shapes differ.

        Examples
        ========

        >>> from sympy import ZZ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
        >>> B  = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ)
        >>> A.sub(B)
        {0: {0: -3, 1: 2}, 1: {0: 1, 1: -4}}

        """
        if A.domain != B.domain:
            raise DDMDomainError
        if A.shape != B.shape:
            raise DDMShapeError
        Csdm = binop_dict(A, B, sub, pos, neg)
        return A.new(Csdm, A.shape, A.domain)

    def neg(A):
        """

        Returns the negative of a :py:class:`~.SDM` matrix

        Examples
        ========

        >>> from sympy import ZZ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
        >>> A.neg()
        {0: {1: -2}, 1: {0: -1}}

        """
        Csdm = unop_dict(A, neg)
        return A.new(Csdm, A.shape, A.domain)

    def convert_to(A, K):
        """

        Converts the :py:class:`~.Domain` of a :py:class:`~.SDM` matrix to K

        Examples
        ========

        >>> from sympy import ZZ, QQ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
        >>> A.convert_to(QQ)
        {0: {1: 2}, 1: {0: 1}}

        """
        Kold = A.domain
        if K == Kold:
            return A.copy()
        Ak = unop_dict(A, lambda e: K.convert_from(e, Kold))
        return A.new(Ak, A.shape, K)

    def scc(A):
        """Strongly connected components of a square matrix *A*.

        Examples
        ========

        >>> from sympy import ZZ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{0: ZZ(2)}, 1:{1:ZZ(1)}}, (2, 2), ZZ)
        >>> A.scc()
        [[0], [1]]

        See also
        ========

        sympy.polys.matrices.domainmatrix.DomainMatrix.scc
        """
        rows, cols = A.shape
        assert rows == cols
        V = range(rows)
        # Edge v -> w for every stored element in row v, column w.
        Emap = {v: list(A.get(v, [])) for v in V}
        return _strongly_connected_components(V, Emap)

    def rref(A):
        """

        Returns reduced-row echelon form and list of pivots for the :py:class:`~.SDM`

        Examples
        ========

        >>> from sympy import QQ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(2), 1:QQ(4)}}, (2, 2), QQ)
        >>> A.rref()
        ({0: {0: 1, 1: 2}}, [0])

        """
        B, pivots, _ = sdm_irref(A)
        return A.new(B, A.shape, A.domain), pivots

    def inv(A):
        """

        Returns inverse of a matrix A

        Examples
        ========

        >>> from sympy import QQ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
        >>> A.inv()
        {0: {0: -2, 1: 1}, 1: {0: 3/2, 1: -1/2}}

        """
        return A.from_ddm(A.to_ddm().inv())

    def det(A):
        """
        Returns determinant of A

        Examples
        ========

        >>> from sympy import QQ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
        >>> A.det()
        -2

        """
        return A.to_ddm().det()

    def lu(A):
        """

        Returns LU decomposition for a matrix A

        Examples
        ========

        >>> from sympy import QQ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
        >>> A.lu()
        ({0: {0: 1}, 1: {0: 3, 1: 1}}, {0: {0: 1, 1: 2}, 1: {1: -2}}, [])

        """
        L, U, swaps = A.to_ddm().lu()
        return A.from_ddm(L), A.from_ddm(U), swaps

    def lu_solve(A, b):
        """

        Uses LU decomposition to solve Ax = b.

        Examples
        ========

        >>> from sympy import QQ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
        >>> b = SDM({0:{0:QQ(1)}, 1:{0:QQ(2)}}, (2, 1), QQ)
        >>> A.lu_solve(b)
        {1: {0: 1/2}}

        """
        return A.from_ddm(A.to_ddm().lu_solve(b.to_ddm()))

    def nullspace(A):
        """

        Returns nullspace for a :py:class:`~.SDM` matrix A

        Examples
        ========

        >>> from sympy import QQ
        >>> from sympy.polys.matrices.sdm import SDM
        >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0: QQ(2), 1: QQ(4)}}, (2, 2), QQ)
        >>> A.nullspace()
        ({0: {0: -2, 1: 1}}, [1])

        """
        ncols = A.shape[1]
        one = A.domain.one
        B, pivots, nzcols = sdm_irref(A)
        K, nonpivots = sdm_nullspace_from_rref(B, one, ncols, pivots, nzcols)
        K = dict(enumerate(K))
        shape = (len(K), ncols)
        return A.new(K, shape, A.domain), nonpivots

    def particular(A):
        """Return a 1 x (ncols - 1) :py:class:`~.SDM` built by
        ``sdm_particular_from_rref`` from the RREF of *A*.

        NOTE(review): presumably *A* is an augmented matrix whose last column
        is the right-hand side, and the result is a particular solution;
        confirm against ``sdm_particular_from_rref``.
        """
        ncols = A.shape[1]
        B, pivots, _ = sdm_irref(A)
        P = sdm_particular_from_rref(B, ncols, pivots)
        rep = {0:P} if P else {}
        return A.new(rep, (1, ncols-1), A.domain)

    def hstack(A, *B):
        """Horizontally stacks :py:class:`~.SDM` matrices.

        Examples
        ========

        >>> from sympy import ZZ
        >>> from sympy.polys.matrices.sdm import SDM

        >>> A = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ)
        >>> B = SDM({0: {0: ZZ(5), 1: ZZ(6)}, 1: {0: ZZ(7), 1: ZZ(8)}}, (2, 2), ZZ)
        >>> A.hstack(B)
        {0: {0: 1, 1: 2, 2: 5, 3: 6}, 1: {0: 3, 1: 4, 2: 7, 3: 8}}

        >>> C = SDM({0: {0: ZZ(9), 1: ZZ(10)}, 1: {0: ZZ(11), 1: ZZ(12)}}, (2, 2), ZZ)
        >>> A.hstack(B, C)
        {0: {0: 1, 1: 2, 2: 5, 3: 6, 4: 9, 5: 10}, 1: {0: 3, 1: 4, 2: 7, 3: 8, 4: 11, 5: 12}}
        """
        # A.copy() copies the row dicts, so writing into them leaves A intact.
        Anew = dict(A.copy())
        rows, cols = A.shape
        domain = A.domain

        for Bk in B:
            Bkrows, Bkcols = Bk.shape
            assert Bkrows == rows
            assert Bk.domain == domain

            for i, Bki in Bk.items():
                Ai = Anew.get(i, None)
                if Ai is None:
                    Anew[i] = Ai = {}
                for j, Bkij in Bki.items():
                    Ai[j + cols] = Bkij
            cols += Bkcols

        return A.new(Anew, (rows, cols), A.domain)

    def vstack(A, *B):
        """Vertically stacks :py:class:`~.SDM` matrices.

        The result never shares row dicts with A or any of B.

        Examples
        ========

        >>> from sympy import ZZ
        >>> from sympy.polys.matrices.sdm import SDM

        >>> A = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ)
        >>> B = SDM({0: {0: ZZ(5), 1: ZZ(6)}, 1: {0: ZZ(7), 1: ZZ(8)}}, (2, 2), ZZ)
        >>> A.vstack(B)
        {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}, 2: {0: 5, 1: 6}, 3: {0: 7, 1: 8}}

        >>> C = SDM({0: {0: ZZ(9), 1: ZZ(10)}, 1: {0: ZZ(11), 1: ZZ(12)}}, (2, 2), ZZ)
        >>> A.vstack(B, C)
        {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}, 2: {0: 5, 1: 6}, 3: {0: 7, 1: 8}, 4: {0: 9, 1: 10}, 5: {0: 11, 1: 12}}
        """
        Anew = dict(A.copy())
        rows, cols = A.shape
        domain = A.domain

        for Bk in B:
            Bkrows, Bkcols = Bk.shape
            assert Bkcols == cols
            assert Bk.domain == domain

            for i, Bki in Bk.items():
                # Copy the row: sharing it would let in-place changes to the
                # result (e.g. setitem) leak back into Bk.
                Anew[i + rows] = Bki.copy()
            rows += Bkrows

        return A.new(Anew, (rows, cols), A.domain)

    def applyfunc(self, func, domain):
        """Apply *func* to every stored element and return the result as an
        :py:class:`~.SDM` over *domain*.

        *func* is applied only to stored (nonzero) elements, not to the
        implicit zeros. Elements that *func* maps to zero are not stored.
        """
        sdm = {}
        for i, row in self.items():
            newrow = {}
            for j, e in row.items():
                fe = func(e)
                if fe:
                    newrow[j] = fe
            if newrow:
                sdm[i] = newrow
        return self.new(sdm, self.shape, domain)

    def charpoly(A):
        """
        Returns the coefficients of the characteristic polynomial
        of the :py:class:`~.SDM` matrix. These elements will be domain elements.
        The domain of the elements will be same as domain of the :py:class:`~.SDM`.

        Examples
        ========

        >>> from sympy import QQ, Symbol
        >>> from sympy.polys.matrices.sdm import SDM
        >>> from sympy.polys import Poly
        >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
        >>> A.charpoly()
        [1, -5, -2]

        We can create a polynomial using the
        coefficients using :py:class:`~.Poly`

        >>> x = Symbol('x')
        >>> p = Poly(A.charpoly(), x, domain=A.domain)
        >>> p
        Poly(x**2 - 5*x - 2, x, domain='QQ')

        """
        return A.to_ddm().charpoly()
833
834
def binop_dict(A, B, fab, fa, fb):
    """Combine two dict-of-dicts matrices position by position.

    A position stored in both *A* and *B* yields ``fab(a, b)``; a position
    stored only in *A* yields ``fa(a)``; one stored only in *B* yields
    ``fb(b)``. Falsy (zero) results are discarded and rows that end up empty
    are left out of the returned dict of dicts.
    """
    def one_sided(row, f):
        # Map f over a row with no counterpart, keeping only nonzero results.
        kept = {}
        for col, val in row.items():
            res = f(val)
            if res:
                kept[col] = res
        return kept

    rows_a = set(A)
    rows_b = set(B)
    result = {}

    # Rows present in both operands: merge column by column.
    for i in rows_a & rows_b:
        row_a = A[i]
        row_b = B[i]
        cols_a = set(row_a)
        cols_b = set(row_b)
        merged = {}
        for j in cols_a & cols_b:
            res = fab(row_a[j], row_b[j])
            if res:
                merged[j] = res
        for j in cols_a - cols_b:
            res = fa(row_a[j])
            if res:
                merged[j] = res
        for j in cols_b - cols_a:
            res = fb(row_b[j])
            if res:
                merged[j] = res
        if merged:
            result[i] = merged

    # Rows present in only one operand.
    for i in rows_a - rows_b:
        only_a = one_sided(A[i], fa)
        if only_a:
            result[i] = only_a

    for i in rows_b - rows_a:
        only_b = one_sided(B[i], fb)
        if only_b:
            result[i] = only_b

    return result
879
880
def unop_dict(A, f):
    """Apply *f* to every stored element of a dict-of-dicts matrix.

    Falsy (zero) images are discarded and rows left empty are omitted from
    the returned dict of dicts.
    """
    result = {}
    for i, row in A.items():
        images = ((j, f(val)) for j, val in row.items())
        kept = {j: img for j, img in images if img}
        if kept:
            result[i] = kept
    return result
892
893
def sdm_transpose(M):
    """Return the transpose of a dict-of-dicts matrix as a new dict of dicts.

    Each stored element ``M[i][j]`` becomes ``MT[j][i]``; the element objects
    themselves are shared, not copied.
    """
    transposed = {}
    for row_index, row in M.items():
        for col_index, elem in row.items():
            transposed.setdefault(col_index, {})[row_index] = elem
    return transposed
903
904
def sdm_matmul(A, B, K, m, o):
    """Matrix product C = A*B of dict-of-dicts matrices over domain *K*.

    *m* is the number of rows of A and *o* the number of columns of B; they
    are only used when *K* is EXRAW, where ``sdm_matmul_exraw`` is used
    instead. Entries of C that sum to zero are not stored.
    """
    # Row-oriented algorithm, fast when A and B are very sparse (think
    # A = B = eye(1000)). Row i of C is the sum of Aik * Bk over k, where Bk
    # is row k of B. Only the k for which both Aik and the row Bk are nonzero
    # contribute, and those are found by intersecting the nonzero column
    # indices of Ai with the nonzero row indices of B -- a cheap set
    # operation on ints in Python.
    if K.is_EXRAW:
        return sdm_matmul_exraw(A, B, K, m, o)

    nonzero_rows_B = set(B)
    C = {}
    for i, Ai in A.items():
        Ci = {}
        for k in set(Ai) & nonzero_rows_B:
            Aik = Ai[k]
            for j, Bkj in B[k].items():
                previous = Ci.get(j)
                product = Aik * Bkj
                total = product if previous is None else previous + product
                if total:
                    Ci[j] = total
                elif previous is not None:
                    # The running sum cancelled to zero: drop the entry.
                    del Ci[j]
        if Ci:
            C[i] = Ci
    return C
947
948
def sdm_matmul_exraw(A, B, K, m, o):
    """Matrix product C = A*B of dict-of-dicts matrices over an EXRAW domain.

    Parameters
    ==========

    A, B : dict of dicts
        Row dicts of the factors; A has *m* rows and B has *o* columns.
    K : Domain
        Domain of the elements; ``K.zero`` and ``K.sum`` are used.
    m, o : int
        Number of rows of A and number of columns of B. Both are needed
        because an element for which ``zero * x != zero`` (e.g. an infinity)
        can make entries of C nonzero even where the other factor has no
        stored elements.

    Returns
    =======

    dict of dicts
        The entries of C that are not zero.

    Raises
    ======

    RuntimeError
        Defensive check in the final loop; not expected to be reached.
    """
    #
    # Like sdm_matmul above except that:
    #
    # - Handles cases like 0*oo -> nan (sdm_matmul skips multiplication by zero)
    # - Uses K.sum (Add(*items)) for efficient addition of Expr
    #
    zero = K.zero
    C = {}
    B_knz = set(B)
    for i, Ai in A.items():
        # Collect all terms of each Cij first and sum them once at the end.
        Ci_list = defaultdict(list)
        Ai_knz = set(Ai)

        # Nonzero row/column pair
        for k in Ai_knz & B_knz:
            Aik = Ai[k]
            if zero * Aik == zero:
                # This is the main inner loop:
                for j, Bkj in B[k].items():
                    Ci_list[j].append(Aik * Bkj)
            else:
                # Aik times zero is not zero, so the unstored zeros of row
                # Bk also contribute: visit every column.
                for j in range(o):
                    Ci_list[j].append(Aik * B[k].get(j, zero))

        # Zero row in B, check for infinities in A
        for k in Ai_knz - B_knz:
            zAik = zero * Ai[k]
            if zAik != zero:
                for j in range(o):
                    Ci_list[j].append(zAik)

        # Add terms using K.sum (Add(*terms)) for efficiency
        Ci = {}
        for j, Cij_list in Ci_list.items():
            Cij = K.sum(Cij_list)
            if Cij:
                Ci[j] = Cij
        if Ci:
            C[i] = Ci

    # Find all infinities in B
    for k, Bk in B.items():
        for j, Bkj in Bk.items():
            if zero * Bkj != zero:
                # Bkj times zero is not zero, so every row i whose Aik is
                # not stored still picks up a contribution in column j.
                for i in range(m):
                    Aik = A.get(i, {}).get(k, zero)
                    # If Aik is not zero then this was handled above
                    if Aik == zero:
                        Ci = C.get(i, {})
                        Cij = Ci.get(j, zero) + Aik * Bkj
                        if Cij != zero:
                            Ci[j] = Cij
                        else:  # pragma: no cover
                            # Not sure how we could get here but let's raise an
                            # exception just in case.
                            raise RuntimeError
                        C[i] = Ci

    return C
1009
1010
def sdm_irref(A):
    """RREF and pivots of a sparse matrix *A*.

    Compute the reduced row echelon form (RREF) of the matrix *A*. This
    routine does not work in place and leaves the original matrix *A*
    unmodified.

    Returns a tuple ``(rref, pivots, nonzero_columns)`` where:

    * ``rref`` is the RREF as a dict of dicts keyed by consecutive row
      indices ``0, 1, ...`` (zero rows are dropped).
    * ``pivots`` is the sorted list of pivot column indices; row ``i`` of
      ``rref`` has its pivot at column ``pivots[i]``.
    * ``nonzero_columns`` maps non-pivot column indices to the set of row
      indices of ``rref`` that have a nonzero entry in that column. Columns
      with no such entries may be missing or map to an empty set.

    Examples
    ========

    This routine works with a dict of dicts sparse representation of a matrix:

    >>> from sympy import QQ
    >>> from sympy.polys.matrices.sdm import sdm_irref
    >>> A = {0: {0: QQ(1), 1: QQ(2)}, 1: {0: QQ(3), 1: QQ(4)}}
    >>> Arref, pivots, _ = sdm_irref(A)
    >>> Arref
    {0: {0: 1}, 1: {1: 1}}
    >>> pivots
    [0, 1]

    The analogous calculation with :py:class:`~.Matrix` would be

    >>> from sympy import Matrix
    >>> M = Matrix([[1, 2], [3, 4]])
    >>> Mrref, pivots = M.rref()
    >>> Mrref
    Matrix([
    [1, 0],
    [0, 1]])
    >>> pivots
    (0, 1)

    Notes
    =====

    The cost of this algorithm is determined purely by the nonzero elements of
    the matrix. No part of the cost of any step in this algorithm depends on
    the number of rows or columns in the matrix. No step depends even on the
    number of nonzero rows apart from the primary loop over those rows. The
    implementation is much faster than ddm_rref for sparse matrices. In fact
    at the time of writing it is also (slightly) faster than the dense
    implementation even if the input is a fully dense matrix so it seems to be
    faster in all cases.

    The elements of the matrix should support exact division with ``/``. For
    example elements of any domain that is a field (e.g. ``QQ``) should be
    fine. No attempt is made to handle inexact arithmetic.

    """
    #
    # Any zeros in the matrix are not stored at all so an element is zero if
    # its row dict has no index at that key. A row is entirely zero if its
    # row index is not in the outer dict. Since rref reorders the rows and
    # removes zero rows we can completely discard the row indices. The first
    # step then collects the row dicts into a list sorted by the index of the
    # first nonzero column in each row.
    #
    # The algorithm then processes each row Ai one at a time. Previously seen
    # rows are used to cancel their pivot columns from Ai. Then a pivot from
    # Ai is chosen and is cancelled from all previously seen rows. At this
    # point Ai joins the previously seen rows. Once all rows are seen all
    # elimination has occurred and the rows are sorted by pivot column index.
    #
    # The previously seen rows are stored in two separate groups. The reduced
    # group consists of all rows that have been reduced to a single nonzero
    # element (the pivot). There is no need to attempt any further reduction
    # with these. Rows that still have other nonzeros need to be considered
    # when Ai is cancelled from the previously seen rows.
    #
    # A dict nonzerocolumns is used to map from a column index to a set of
    # previously seen rows that still have a nonzero element in that column.
    # This means that we can cancel the pivot from Ai into the previously seen
    # rows without needing to loop over each row that might have a zero in
    # that column.
    #

    # Row dicts sorted by index of first nonzero column
    # (Maybe sorting is not needed/useful.)
    #
    # Empty row dicts are skipped: they are zero rows (which rref discards
    # anyway) and min() would raise on them. No copy is needed here because
    # every popped row is rebuilt as a fresh dict below before any mutation,
    # so the input matrix is never modified.
    Arows = sorted((Ai for Ai in A.values() if Ai), key=min)

    # Each processed row has an associated pivot column.
    # pivot_row_map maps from the pivot column index to the row dict.
    # This means that we can represent a set of rows purely as a set of their
    # pivot indices.
    pivot_row_map = {}

    # Set of pivot indices for rows that are fully reduced to a single nonzero.
    reduced_pivots = set()

    # Set of pivot indices for rows not fully reduced
    nonreduced_pivots = set()

    # Map from column index to a set of pivot indices representing the rows
    # that have a nonzero at that column.
    nonzero_columns = defaultdict(set)

    while Arows:
        # Select pivot element and row
        Ai = Arows.pop()

        # Nonzero columns from fully reduced pivot rows can be removed.
        # This also creates a new dict so the input row is left untouched.
        Ai = {j: Aij for j, Aij in Ai.items() if j not in reduced_pivots}

        # Others require full row cancellation
        for j in nonreduced_pivots & set(Ai):
            Aj = pivot_row_map[j]
            Aij = Ai[j]
            Ainz = set(Ai)
            Ajnz = set(Aj)
            for k in Ajnz - Ainz:
                Ai[k] = - Aij * Aj[k]
            Ai.pop(j)
            Ainz.remove(j)
            for k in Ajnz & Ainz:
                Aik = Ai[k] - Aij * Aj[k]
                if Aik:
                    Ai[k] = Aik
                else:
                    Ai.pop(k)

        # We have now cancelled previously seen pivots from Ai.
        # If it is zero then discard it.
        if not Ai:
            continue

        # Choose a pivot from Ai:
        j = min(Ai)
        Aij = Ai[j]
        pivot_row_map[j] = Ai
        Ainz = set(Ai)

        # Normalise the pivot row to make the pivot 1.
        #
        # This approach is slow for some domains. Cross cancellation might be
        # better for e.g. QQ(x) with division delayed to the final steps.
        Aijinv = Aij**-1
        for l in Ai:
            Ai[l] *= Aijinv

        # Use Aij to cancel column j from all previously seen rows
        for k in nonzero_columns.pop(j, ()):
            Ak = pivot_row_map[k]
            Akj = Ak[j]
            Aknz = set(Ak)
            for l in Ainz - Aknz:
                Ak[l] = - Akj * Ai[l]
                nonzero_columns[l].add(k)
            Ak.pop(j)
            Aknz.remove(j)
            # j has been removed from Aknz so l != j in this loop.
            for l in Ainz & Aknz:
                Akl = Ak[l] - Akj * Ai[l]
                if Akl:
                    Ak[l] = Akl
                else:
                    # Drop nonzero elements
                    Ak.pop(l)
                    nonzero_columns[l].remove(k)
            if len(Ak) == 1:
                reduced_pivots.add(k)
                nonreduced_pivots.remove(k)

        if len(Ai) == 1:
            reduced_pivots.add(j)
        else:
            nonreduced_pivots.add(j)
            for l in Ai:
                if l != j:
                    nonzero_columns[l].add(j)

    # All done!
    pivots = sorted(reduced_pivots | nonreduced_pivots)
    # Translate pivot-column labels into final row indices of the rref.
    pivot2row = {p: n for n, p in enumerate(pivots)}
    nonzero_columns = {c: set(pivot2row[p] for p in s) for c, s in nonzero_columns.items()}
    rows = [pivot_row_map[i] for i in pivots]
    rref = dict(enumerate(rows))
    return rref, pivots, nonzero_columns
1189
1190
def sdm_nullspace_from_rref(A, one, ncols, pivots, nonzero_cols):
    """Build a nullspace basis from a matrix *A* that is already in RREF.

    For every non-pivot column ``j`` (in increasing order) a sparse vector
    is produced with ``one`` at index ``j`` and ``-A[i][j]`` at index
    ``pivots[i]`` for each row ``i`` listed in ``nonzero_cols[j]``.

    Returns a tuple ``(K, nonpivots)``: the list of basis vectors (as dicts)
    and the sorted list of non-pivot column indices they correspond to.
    """
    pivot_set = set(pivots)
    nonpivots = sorted(col for col in range(ncols) if col not in pivot_set)

    basis = []
    for col in nonpivots:
        vec = {col: one}
        vec.update((pivots[row], -A[row][col]) for row in nonzero_cols.get(col, ()))
        basis.append(vec)

    return basis, nonpivots
1203
1204
def sdm_particular_from_rref(A, ncols, pivots):
    """Extract a particular solution from a matrix *A* that is in RREF.

    Column ``ncols - 1`` of *A* is treated as the right-hand side. For each
    row ``i`` with pivot column ``pivots[i]`` that has an entry in that last
    column, the solution gets ``A[i][ncols-1] / A[i][pivots[i]]`` at index
    ``pivots[i]``. Entries that would be zero are simply not stored.
    """
    rhs_col = ncols - 1
    solution = {}
    for row_index, pivot_col in enumerate(pivots):
        row = A[row_index]
        rhs = row.get(rhs_col, None)
        if rhs is None:
            continue
        solution[pivot_col] = rhs / row[pivot_col]
    return solution
1213