1"""Implementation of the Kronecker product"""
2
3
4from sympy.core import Mul, prod, sympify
5from sympy.functions import adjoint
6from sympy.matrices.common import ShapeError
7from sympy.matrices.expressions.matexpr import MatrixExpr
8from sympy.matrices.expressions.transpose import transpose
9from sympy.matrices.expressions.special import Identity
10from sympy.matrices.matrices import MatrixBase
11from sympy.strategies import (
12    canon, condition, distribute, do_one, exhaust, flatten, typed, unpack)
13from sympy.strategies.traverse import bottom_up
14from sympy.utilities import sift
15
16from .matadd import MatAdd
17from .matmul import MatMul
18from .matpow import MatPow
19
20
def kronecker_product(*matrices):
    """
    Return the Kronecker product of one or more matrices.

    All arguments are checked with ``validate`` first.  A single argument is
    handed back unchanged; for two or more arguments a ``KroneckerProduct``
    is built and immediately evaluated with ``.doit()``.  For subclasses of
    ``MatrixBase`` (explicit matrices) this yields the explicit product,
    otherwise a symbolic ``KroneckerProduct`` object is returned.

    Raises
    ======

    TypeError
        If called without arguments, or if ``validate`` rejects an argument
        because it is not a matrix.

    Examples
    ========

    For ``MatrixSymbol`` arguments a ``KroneckerProduct`` object is returned.
    Elements of this matrix can be obtained by indexing, or for MatrixSymbols
    with known dimension the explicit matrix can be obtained with
    ``.as_explicit()``

    >>> from sympy.matrices import kronecker_product, MatrixSymbol
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = MatrixSymbol('B', 2, 2)
    >>> kronecker_product(A)
    A
    >>> kronecker_product(A, B)
    KroneckerProduct(A, B)
    >>> kronecker_product(A, B)[0, 1]
    A[0, 0]*B[0, 1]
    >>> kronecker_product(A, B).as_explicit()
    Matrix([
    [A[0, 0]*B[0, 0], A[0, 0]*B[0, 1], A[0, 1]*B[0, 0], A[0, 1]*B[0, 1]],
    [A[0, 0]*B[1, 0], A[0, 0]*B[1, 1], A[0, 1]*B[1, 0], A[0, 1]*B[1, 1]],
    [A[1, 0]*B[0, 0], A[1, 0]*B[0, 1], A[1, 1]*B[0, 0], A[1, 1]*B[0, 1]],
    [A[1, 0]*B[1, 0], A[1, 0]*B[1, 1], A[1, 1]*B[1, 0], A[1, 1]*B[1, 1]]])

    For explicit matrices the Kronecker product is returned as a Matrix

    >>> from sympy.matrices import Matrix, kronecker_product
    >>> sigma_x = Matrix([
    ... [0, 1],
    ... [1, 0]])
    ...
    >>> Isigma_y = Matrix([
    ... [0, 1],
    ... [-1, 0]])
    ...
    >>> kronecker_product(sigma_x, Isigma_y)
    Matrix([
    [ 0, 0,  0, 1],
    [ 0, 0, -1, 0],
    [ 0, 1,  0, 0],
    [-1, 0,  0, 0]])

    See Also
    ========
        KroneckerProduct

    """
    if not matrices:
        raise TypeError("Empty Kronecker product is undefined")
    validate(*matrices)
    if len(matrices) > 1:
        return KroneckerProduct(*matrices).doit()
    # A lone (validated) factor is returned as-is.
    return matrices[0]
84
85
class KroneckerProduct(MatrixExpr):
    """
    The Kronecker product of two or more arguments.

    The Kronecker product is a non-commutative product of matrices.
    Given two matrices of dimension (m, n) and (s, t) it produces a matrix
    of dimension (m s, n t).

    This is a symbolic object that simply stores its arguments without
    evaluating them. To actually compute the product, use the function
    ``kronecker_product()`` or call the ``.doit()`` or ``.as_explicit()``
    methods.

    If every argument is an identity matrix, the constructor does not build
    a ``KroneckerProduct``; it returns an identity matrix whose size is the
    product of the arguments' row counts (in explicit form when all
    arguments are explicit matrices).

    Examples
    ========

    >>> from sympy.matrices import KroneckerProduct, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 5)
    >>> B = MatrixSymbol('B', 5, 5)
    >>> isinstance(KroneckerProduct(A, B), KroneckerProduct)
    True
    """
    is_KroneckerProduct = True

    def __new__(cls, *args, check=True):
        args = list(map(sympify, args))

        # Validate before reading matrix-only attributes (``is_Identity``,
        # ``rows``) below, so that a non-matrix argument is reported through
        # the TypeError raised by ``validate`` rather than by whatever error
        # the attribute access would produce.
        if check:
            validate(*args)

        if all(a.is_Identity for a in args):
            # A Kronecker product of identities is itself an identity.
            ret = Identity(prod(a.rows for a in args))
            if all(isinstance(a, MatrixBase) for a in args):
                return ret.as_explicit()
            else:
                return ret

        return super().__new__(cls, *args)

    @property
    def shape(self):
        """Return ``(rows, cols)``: the products of the factors' row and
        column counts."""
        rows, cols = self.args[0].shape
        for mat in self.args[1:]:
            rows *= mat.rows
            cols *= mat.cols
        return (rows, cols)

    def _entry(self, i, j, **kwargs):
        """Return the ``(i, j)`` entry as a product of one entry per factor.

        The factors are visited from last to first.  For each factor the
        remainders of ``i`` and ``j`` modulo its row/column count select the
        entry inside that factor, and the quotients are carried over to the
        factor on its left.
        """
        result = 1
        for mat in reversed(self.args):
            i, m = divmod(i, mat.rows)
            j, n = divmod(j, mat.cols)
            result *= mat[m, n]
        return result

    def _eval_adjoint(self):
        """Apply ``adjoint`` to every factor and evaluate the result."""
        return KroneckerProduct(*[adjoint(a) for a in self.args]).doit()

    def _eval_conjugate(self):
        """Conjugate every factor and evaluate the result."""
        return KroneckerProduct(*[a.conjugate() for a in self.args]).doit()

    def _eval_transpose(self):
        """Apply ``transpose`` to every factor and evaluate the result."""
        return KroneckerProduct(*[transpose(a) for a in self.args]).doit()

    def _eval_trace(self):
        """Return the product of the traces of the factors."""
        from .trace import trace
        return prod(trace(a) for a in self.args)

    def _eval_determinant(self):
        """Return the determinant expressed through the factors' determinants.

        If any factor is not known to be square, an unevaluated
        ``Determinant`` of the whole product is returned instead.
        """
        from .determinant import det, Determinant
        if not all(a.is_square for a in self.args):
            return Determinant(self)

        # Each factor's determinant is raised to the total row count divided
        # by that factor's own row count.
        m = self.rows
        return prod(det(a)**(m/a.rows) for a in self.args)

    def _eval_inverse(self):
        """Return the Kronecker product of the factors' inverses.

        Falls back to an unevaluated ``Inverse`` of the whole product when
        inverting a factor raises ``ShapeError``.
        """
        try:
            return KroneckerProduct(*[a.inverse() for a in self.args])
        except ShapeError:
            from sympy.matrices.expressions.inverse import Inverse
            return Inverse(self)

    def structurally_equal(self, other):
        '''Determine whether two matrices have the same Kronecker product structure

        This holds when ``other`` is a ``KroneckerProduct`` of the same overall
        shape, with the same number of factors, whose factors have pairwise
        equal shapes.

        Examples
        ========

        >>> from sympy import KroneckerProduct, MatrixSymbol, symbols
        >>> m, n = symbols(r'm, n', integer=True)
        >>> A = MatrixSymbol('A', m, m)
        >>> B = MatrixSymbol('B', n, n)
        >>> C = MatrixSymbol('C', m, m)
        >>> D = MatrixSymbol('D', n, n)
        >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(C, D))
        True
        >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(D, C))
        False
        >>> KroneckerProduct(A, B).structurally_equal(C)
        False
        '''
        # Inspired by BlockMatrix
        return (isinstance(other, KroneckerProduct)
                and self.shape == other.shape
                and len(self.args) == len(other.args)
                and all(a.shape == b.shape for (a, b) in zip(self.args, other.args)))

    def has_matching_shape(self, other):
        '''Determine whether two matrices have the appropriate structure to bring matrix
        multiplication inside the KroneckerProduct

        This holds when ``other`` is a ``KroneckerProduct`` with the same
        number of factors, the column count of ``self`` equals the row count
        of ``other``, and the same is true for every pair of corresponding
        factors.

        Examples
        ========

        >>> from sympy import KroneckerProduct, MatrixSymbol, symbols
        >>> m, n = symbols(r'm, n', integer=True)
        >>> A = MatrixSymbol('A', m, n)
        >>> B = MatrixSymbol('B', n, m)
        >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(B, A))
        True
        >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(A, B))
        False
        >>> KroneckerProduct(A, B).has_matching_shape(A)
        False
        '''
        return (isinstance(other, KroneckerProduct)
                and self.cols == other.rows
                and len(self.args) == len(other.args)
                and all(a.cols == b.rows for (a, b) in zip(self.args, other.args)))

    def _eval_expand_kroneckerproduct(self, **hints):
        """Expand by distributing ``KroneckerProduct`` over ``MatAdd`` factors
        and flattening the result."""
        return flatten(canon(typed({KroneckerProduct: distribute(KroneckerProduct, MatAdd)}))(self))

    def _kronecker_add(self, other):
        """Add ``other`` factor by factor when it is structurally equal to
        ``self``; otherwise return the ordinary sum ``self + other``."""
        if self.structurally_equal(other):
            return self.__class__(*[a + b for (a, b) in zip(self.args, other.args)])
        else:
            return self + other

    def _kronecker_mul(self, other):
        """Multiply by ``other`` factor by factor when the shapes match (see
        ``has_matching_shape``); otherwise return ``self * other``."""
        if self.has_matching_shape(other):
            return self.__class__(*[a*b for (a, b) in zip(self.args, other.args)])
        else:
            return self * other

    def doit(self, **kwargs):
        """Evaluate the product by running it through ``canonicalize``.

        With ``deep=True`` (the default) every factor is evaluated with its
        own ``doit(**kwargs)`` first.
        """
        deep = kwargs.get('deep', True)
        if deep:
            args = [arg.doit(**kwargs) for arg in self.args]
        else:
            args = self.args
        return canonicalize(KroneckerProduct(*args))
233
234
def validate(*args):
    """Check that every argument is a matrix.

    Raises
    ======

    TypeError
        As soon as an argument with a falsy ``is_Matrix`` attribute is found.
    """
    for arg in args:
        if not arg.is_Matrix:
            raise TypeError("Mix of Matrix and Scalar symbols")
238
239
240# rules
241
def extract_commutative(kron):
    """Pull the commutative factors of all arguments out in front.

    Each argument is split with ``args_cnc()`` into commutative and
    non-commutative parts.  The commutative parts of all arguments are
    multiplied into one coefficient; the non-commutative parts are rebuilt
    per argument and become the factors of a new ``KroneckerProduct``.

    Returns ``kron`` itself when the collected coefficient equals 1.
    """
    scalars = []
    factors = []
    for operand in kron.args:
        commutative, noncommutative = operand.args_cnc()
        scalars.extend(commutative)
        factors.append(Mul._from_args(noncommutative))

    coefficient = Mul(*scalars)
    if coefficient == 1:
        # Nothing to extract: hand back the input object unchanged.
        return kron
    return coefficient*KroneckerProduct(*factors)
254
255
def matrix_kronecker_product(*matrices):
    """Compute the Kronecker product of a sequence of SymPy Matrices.

    This is the standard Kronecker product of matrices [1].

    Parameters
    ==========

    matrices : tuple of MatrixBase instances
        The matrices to take the Kronecker product of.

    Returns
    =======

    matrix : MatrixBase
        The Kronecker product matrix.  Its class is the class of the input
        with the highest ``_class_priority``.

    Raises
    ======

    TypeError
        If any argument is not a ``MatrixBase`` instance.

    Examples
    ========

    >>> from sympy import Matrix
    >>> from sympy.matrices.expressions.kronecker import (
    ... matrix_kronecker_product)

    >>> m1 = Matrix([[1,2],[3,4]])
    >>> m2 = Matrix([[1,0],[0,1]])
    >>> matrix_kronecker_product(m1, m2)
    Matrix([
    [1, 0, 2, 0],
    [0, 1, 0, 2],
    [3, 0, 4, 0],
    [0, 3, 0, 4]])
    >>> matrix_kronecker_product(m2, m1)
    Matrix([
    [1, 2, 0, 0],
    [3, 4, 0, 0],
    [0, 0, 1, 2],
    [0, 0, 3, 4]])

    References
    ==========

    [1] https://en.wikipedia.org/wiki/Kronecker_product
    """
    # Every argument has to be an explicit matrix.
    if any(not isinstance(m, MatrixBase) for m in matrices):
        raise TypeError(
            'Sequence of Matrices expected, got: %s' % repr(matrices)
        )

    # Seed the running product with the right-most factor, then fold the
    # remaining factors in from right to left.
    expansion = matrices[-1]
    for left in reversed(matrices[:-1]):
        n_rows, n_cols = left.rows, left.cols
        for r in range(n_rows):
            # ``left`` is read through a single flat index: ``offset`` is
            # where row ``r`` starts.
            offset = r*n_cols
            # Build one horizontal band: the running product scaled by each
            # entry of row ``r`` of ``left``, joined side by side.
            band = expansion*left[offset]
            for c in range(1, n_cols):
                band = band.row_join(expansion*left[offset + c])
            # The first band starts the new block matrix; later bands are
            # stacked underneath it.
            if r == 0:
                stacked = band
            else:
                stacked = stacked.col_join(band)
        expansion = stacked

    # Return the result as the highest-priority matrix class among the inputs.
    dominant = max(matrices, key=lambda M: M._class_priority)
    result_class = dominant.__class__
    return expansion if isinstance(expansion, result_class) else result_class(expansion)
334
335
def explicit_kronecker_product(kron):
    """Evaluate ``kron`` explicitly when all of its arguments are explicit.

    If every argument is a ``MatrixBase`` instance the result of
    ``matrix_kronecker_product`` on the arguments is returned; otherwise
    ``kron`` is returned unchanged.
    """
    factors = kron.args
    if all(isinstance(m, MatrixBase) for m in factors):
        return matrix_kronecker_product(*factors)
    return kron
342
343
# Rewrite rules handed, in this order, to ``do_one`` below.
rules = (
    unpack,
    explicit_kronecker_product,
    flatten,
    extract_commutative,
)

# Strategy used by ``KroneckerProduct.doit``: the rules are only applied to
# ``KroneckerProduct`` instances (the ``condition`` guard), combined with
# ``do_one`` and wrapped in ``exhaust``.
canonicalize = exhaust(
    condition(lambda expr: isinstance(expr, KroneckerProduct),
              do_one(*rules)))
351
352
def _kronecker_dims_key(expr):
    """Return a grouping key describing the Kronecker structure of ``expr``.

    For a ``KroneckerProduct`` the key is the tuple of its factors' shapes;
    every other expression gets the sentinel key ``(0,)``.
    """
    if not isinstance(expr, KroneckerProduct):
        return (0,)
    return tuple(factor.shape for factor in expr.args)
358
359
def kronecker_mat_add(expr):
    """Merge the ``KroneckerProduct`` terms of a sum that share a structure.

    The terms of ``expr`` are grouped by ``_kronecker_dims_key``.  Each group
    of ``KroneckerProduct`` terms with identical factor shapes is folded into
    one term with ``_kronecker_add``; terms that are not Kronecker products
    are carried over unchanged.

    Returns ``expr`` itself when it contains no ``KroneckerProduct`` term,
    otherwise a ``MatAdd`` of the merged terms followed by the remaining
    non-Kronecker terms.
    """
    from functools import reduce

    groups = sift(expr.args, _kronecker_dims_key)
    # ``(0,)`` is the key ``_kronecker_dims_key`` assigns to non-Kronecker
    # terms; the popped value is a list of those terms (or None).
    nonkrons = groups.pop((0,), None)
    if not groups:
        return expr

    krons = [reduce(lambda x, y: x._kronecker_add(y), group)
             for group in groups.values()]

    if not nonkrons:
        return MatAdd(*krons)
    # ``nonkrons`` is a list, so it cannot be added to a matrix expression
    # with ``+``; pass its elements as further terms of the sum instead.
    return MatAdd(*krons, *nonkrons)
374
375
def kronecker_mat_mul(expr):
    """Combine adjacent ``KroneckerProduct`` factors of a matrix product.

    The product is split into a scalar coefficient and its matrix factors
    with ``as_coeff_matrices()``.  Going left to right, whenever the most
    recently kept factor and the next factor are both ``KroneckerProduct``
    instances they are replaced by ``_kronecker_mul`` of the two.  The
    coefficient times a ``MatMul`` of the resulting factors is returned.
    """
    # modified from block matrix code
    factor, matrices = expr.as_coeff_matrices()

    merged = []
    for mat in matrices:
        can_merge = (bool(merged)
                     and isinstance(merged[-1], KroneckerProduct)
                     and isinstance(mat, KroneckerProduct))
        if can_merge:
            # Keep comparing the merged result with the next factor.
            merged[-1] = merged[-1]._kronecker_mul(mat)
        else:
            merged.append(mat)

    return factor*MatMul(*merged)
390
391
def kronecker_mat_pow(expr):
    """Move a matrix power inside a ``KroneckerProduct`` of square factors.

    If the base of ``expr`` is a ``KroneckerProduct`` whose factors are all
    square, return the ``KroneckerProduct`` of each factor raised to the
    exponent of ``expr``; otherwise return ``expr`` unchanged.
    """
    base = expr.base
    if not isinstance(base, KroneckerProduct):
        return expr
    if not all(factor.is_square for factor in base.args):
        return expr
    return KroneckerProduct(*[MatPow(factor, expr.exp) for factor in base.args])
397
398
def combine_kronecker(expr):
    """Combine KroneckerProduct with expression.

    If possible write operations on KroneckerProducts of compatible shapes
    as a single KroneckerProduct.

    Sums, products and powers inside ``expr`` are rewritten with
    ``kronecker_mat_add``, ``kronecker_mat_mul`` and ``kronecker_mat_pow``;
    the result is finally evaluated with ``.doit()`` when it has such a
    method.

    Examples
    ========

    >>> from sympy.matrices.expressions import MatrixSymbol, KroneckerProduct, combine_kronecker
    >>> from sympy import symbols
    >>> m, n = symbols(r'm, n', integer=True)
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', n, m)
    >>> combine_kronecker(KroneckerProduct(A, B)*KroneckerProduct(B, A))
    KroneckerProduct(A*B, B*A)
    >>> combine_kronecker(KroneckerProduct(A, B)+KroneckerProduct(B.T, A.T))
    KroneckerProduct(A + B.T, B + A.T)
    >>> C = MatrixSymbol('C', n, n)
    >>> D = MatrixSymbol('D', m, m)
    >>> combine_kronecker(KroneckerProduct(C, D)**m)
    KroneckerProduct(C**m, D**m)
    """
    def haskron(expr):
        # Only matrix expressions that contain a KroneckerProduct are rewritten.
        return isinstance(expr, MatrixExpr) and expr.has(KroneckerProduct)

    # Pick the rewrite by the type of the node being visited.
    by_type = typed({
        MatAdd: kronecker_mat_add,
        MatMul: kronecker_mat_mul,
        MatPow: kronecker_mat_pow,
    })
    node_rule = exhaust(condition(haskron, by_type))
    rule = exhaust(bottom_up(node_rule))

    combined = rule(expr)
    evaluate = getattr(combined, 'doit', None)
    if evaluate is None:
        return combined
    return evaluate()
436