1from sympy.core import S
2from sympy.core.sympify import _sympify
3from sympy.functions import KroneckerDelta
4
5from .matexpr import MatrixExpr
6from .special import ZeroMatrix, Identity, OneMatrix
7
8
class PermutationMatrix(MatrixExpr):
    """A Permutation Matrix

    The entry at ``(i, j)`` is ``1`` when the permutation maps ``i`` to
    ``j`` and ``0`` otherwise.

    Parameters
    ==========

    perm : Permutation
        The permutation the matrix uses.

        The size of the permutation determines the matrix size.

        See the documentation of
        :class:`sympy.combinatorics.permutations.Permutation` for
        further information on how to create a permutation object.

    Examples
    ========

    >>> from sympy.matrices import Matrix, PermutationMatrix
    >>> from sympy.combinatorics import Permutation

    Creating a permutation matrix:

    >>> p = Permutation(1, 2, 0)
    >>> P = PermutationMatrix(p)
    >>> P = P.as_explicit()
    >>> P
    Matrix([
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0]])

    Permuting the rows and the columns of a matrix:

    >>> M = Matrix([0, 1, 2])
    >>> Matrix(P*M)
    Matrix([
    [1],
    [2],
    [0]])

    >>> Matrix(M.T*P)
    Matrix([[2, 0, 1]])

    See Also
    ========

    sympy.combinatorics.permutations.Permutation
    """

    def __new__(cls, perm):
        from sympy.combinatorics.permutations import Permutation

        perm = _sympify(perm)
        if not isinstance(perm, Permutation):
            raise ValueError(
                "{} must be a SymPy Permutation instance.".format(perm))

        return super().__new__(cls, perm)

    @property
    def shape(self):
        """Square shape whose side is the size of the permutation."""
        size = self.args[0].size
        return (size, size)

    @property
    def is_Identity(self):
        """True when the underlying permutation is the identity."""
        return self.args[0].is_Identity

    def doit(self, **hints):
        """Return ``Identity`` for an identity permutation, else ``self``.

        The usual ``doit`` hints (such as ``deep``) are accepted and
        ignored, so container expressions may forward them unchanged.
        """
        if self.is_Identity:
            return Identity(self.rows)
        return self

    def _entry(self, i, j, **kwargs):
        """Entry ``(i, j)``: 1 iff the permutation sends ``i`` to ``j``."""
        perm = self.args[0]
        return KroneckerDelta(perm.apply(i), j)

    def _eval_power(self, exp):
        # Matrix powers correspond to powers of the permutation; ``doit``
        # collapses an identity result into ``Identity``.
        return PermutationMatrix(self.args[0] ** exp).doit()

    def _eval_inverse(self):
        return PermutationMatrix(self.args[0] ** -1)

    # A permutation matrix is real and orthogonal, so its transpose and
    # its adjoint both coincide with its inverse.
    _eval_transpose = _eval_adjoint = _eval_inverse

    def _eval_determinant(self):
        """The determinant is the signature (sign) of the permutation."""
        sign = self.args[0].signature()
        if sign == 1:
            return S.One
        elif sign == -1:
            return S.NegativeOne
        raise NotImplementedError

    def _eval_rewrite_as_BlockDiagMatrix(self, *args, **kwargs):
        """Split the matrix into the finest block-diagonal decomposition.

        The cycles of ``full_cyclic_form`` are consumed in order and
        grouped until a group covers a contiguous index range
        ``[start, top]`` exactly. Each such group is an independent
        diagonal block. Its cycles are shifted down by ``start`` so that
        they form a permutation starting at ``0``.

        NOTE(review): this relies on ``full_cyclic_form`` listing every
        index exactly once, with cycles ordered by their smallest element.
        """
        from sympy.combinatorics.permutations import Permutation
        from .blockmatrix import BlockDiagMatrix

        blocks = []
        pending = []    # cycles of the block currently being built
        start = 0       # first index covered by the pending block
        top = -1        # largest index seen in the pending block
        count = 0       # number of indices in the pending block
        for cycle in self.args[0].full_cyclic_form:
            pending.append(cycle)
            top = max(top, max(cycle))
            count += len(cycle)
            # All pending indices are >= start and distinct, so they fill
            # [start, top] exactly when their count equals its length.
            if top + 1 == start + count:
                shifted = [[i - start for i in cyc] for cyc in pending]
                blocks.append(PermutationMatrix(Permutation(shifted)))
                pending = []
                start, count = top + 1, 0

        return BlockDiagMatrix(*blocks)
166
167
class MatrixPermute(MatrixExpr):
    r"""Symbolic representation for permuting matrix rows or columns.

    Parameters
    ==========

    mat : Matrix or MatrixExpr
        The matrix whose rows or columns are permuted.

    perm : Permutation, PermutationMatrix
        The permutation to use for permuting the matrix.
        If its size differs from the size of ``mat`` along ``axis``, it
        is resized to match when possible; otherwise ``ValueError`` is
        raised.

    axis : 0 or 1
        The axis to permute alongside.
        If `0`, it will permute the matrix rows.
        If `1`, it will permute the matrix columns.

    Notes
    =====

    This follows the same notation used in
    :meth:`sympy.matrices.common.MatrixCommon.permute`.

    Examples
    ========

    >>> from sympy.matrices import Matrix, MatrixPermute
    >>> from sympy.combinatorics import Permutation

    Permuting the matrix rows:

    >>> p = Permutation(1, 2, 0)
    >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> B = MatrixPermute(A, p, axis=0)
    >>> B.as_explicit()
    Matrix([
    [4, 5, 6],
    [7, 8, 9],
    [1, 2, 3]])

    Permuting the matrix columns:

    >>> B = MatrixPermute(A, p, axis=1)
    >>> B.as_explicit()
    Matrix([
    [2, 3, 1],
    [5, 6, 4],
    [8, 9, 7]])

    See Also
    ========

    sympy.matrices.common.MatrixCommon.permute
    """
    def __new__(cls, mat, perm, axis=S.Zero):
        from sympy.combinatorics.permutations import Permutation

        mat = _sympify(mat)
        if not mat.is_Matrix:
            # Report the offending matrix argument, not the permutation.
            raise ValueError(
                "{} must be a SymPy matrix instance.".format(mat))

        perm = _sympify(perm)
        # A PermutationMatrix is accepted as a wrapper of its permutation.
        if isinstance(perm, PermutationMatrix):
            perm = perm.args[0]

        if not isinstance(perm, Permutation):
            raise ValueError(
                "{} must be a SymPy Permutation or a PermutationMatrix "
                "instance".format(perm))

        axis = _sympify(axis)
        if axis not in (0, 1):
            raise ValueError("The axis must be 0 or 1.")

        mat_size = mat.shape[axis]
        if mat_size != perm.size:
            try:
                perm = perm.resize(mat_size)
            except ValueError as err:
                raise ValueError(
                    "Size does not match between the permutation {} "
                    "and the matrix {} threaded over the axis {} "
                    "and cannot be converted."
                    .format(perm, mat, axis)) from err

        return super().__new__(cls, mat, perm, axis)

    def doit(self, deep=True, **hints):
        """Evaluate and simplify the permutation expression.

        When ``deep`` is true, the matrix and the permutation are
        evaluated first, and any further ``hints`` are forwarded to them.
        The result is then simplified:

        * an identity permutation returns the matrix itself;
        * permuting an identity matrix gives a ``PermutationMatrix``;
        * zero and all-ones matrices are returned unchanged;
        * nested permutations along the same axis are merged into one.
        """
        mat, perm, axis = self.args

        if deep:
            mat = mat.doit(deep=deep, **hints)
            perm = perm.doit(deep=deep, **hints)

        if perm.is_Identity:
            return mat

        if mat.is_Identity:
            if axis is S.Zero:
                return PermutationMatrix(perm)
            elif axis is S.One:
                # Column permutation of I equals the matrix of the inverse.
                return PermutationMatrix(perm**-1)

        # Permuting rows/columns of an all-zeros or all-ones matrix is a
        # no-op.
        if isinstance(mat, (ZeroMatrix, OneMatrix)):
            return mat

        # Merge nested permutations along the same axis. This relies on
        # SymPy's product convention ``(p*q)(i) == q(p(i))``.
        if isinstance(mat, MatrixPermute) and mat.args[2] == axis:
            return MatrixPermute(mat.args[0], perm * mat.args[1], axis)

        # Rebuild from the (possibly) evaluated arguments rather than
        # returning ``self``, which would discard the deep evaluation.
        return self.func(mat, perm, axis)

    @property
    def shape(self):
        """Permuting rows or columns does not change the shape."""
        return self.args[0].shape

    def _entry(self, i, j, **kwargs):
        """Read the entry from the permuted row (axis 0) or column (axis 1)."""
        mat, perm, axis = self.args

        if axis == 0:
            return mat[perm.apply(i), j]
        elif axis == 1:
            return mat[i, perm.apply(j)]

    def _eval_rewrite_as_MatMul(self, *args, **kwargs):
        """Express the permutation as a product with a ``PermutationMatrix``.

        Row permutation multiplies from the left. Column permutation
        multiplies from the right by the matrix of the inverse permutation.
        """
        from .matmul import MatMul

        mat, perm, axis = self.args

        deep = kwargs.get("deep", True)

        if deep:
            mat = mat.rewrite(MatMul)

        if axis == 0:
            return MatMul(PermutationMatrix(perm), mat)
        elif axis == 1:
            return MatMul(mat, PermutationMatrix(perm**-1))
304