"""Implementation of the Kronecker product"""


from sympy.core import Mul, prod, sympify
from sympy.functions import adjoint
from sympy.matrices.common import ShapeError
from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.matrices.expressions.transpose import transpose
from sympy.matrices.expressions.special import Identity
from sympy.matrices.matrices import MatrixBase
from sympy.strategies import (
    canon, condition, distribute, do_one, exhaust, flatten, typed, unpack)
from sympy.strategies.traverse import bottom_up
from sympy.utilities import sift

from .matadd import MatAdd
from .matmul import MatMul
from .matpow import MatPow


def kronecker_product(*matrices):
    """
    The Kronecker product of two or more arguments.

    This computes the explicit Kronecker product for subclasses of
    ``MatrixBase`` i.e. explicit matrices. Otherwise, a symbolic
    ``KroneckerProduct`` object is returned.


    Examples
    ========

    For ``MatrixSymbol`` arguments a ``KroneckerProduct`` object is returned.
    Elements of this matrix can be obtained by indexing, or for MatrixSymbols
    with known dimension the explicit matrix can be obtained with
    ``.as_explicit()``

    >>> from sympy.matrices import kronecker_product, MatrixSymbol
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = MatrixSymbol('B', 2, 2)
    >>> kronecker_product(A)
    A
    >>> kronecker_product(A, B)
    KroneckerProduct(A, B)
    >>> kronecker_product(A, B)[0, 1]
    A[0, 0]*B[0, 1]
    >>> kronecker_product(A, B).as_explicit()
    Matrix([
    [A[0, 0]*B[0, 0], A[0, 0]*B[0, 1], A[0, 1]*B[0, 0], A[0, 1]*B[0, 1]],
    [A[0, 0]*B[1, 0], A[0, 0]*B[1, 1], A[0, 1]*B[1, 0], A[0, 1]*B[1, 1]],
    [A[1, 0]*B[0, 0], A[1, 0]*B[0, 1], A[1, 1]*B[0, 0], A[1, 1]*B[0, 1]],
    [A[1, 0]*B[1, 0], A[1, 0]*B[1, 1], A[1, 1]*B[1, 0], A[1, 1]*B[1, 1]]])

    For explicit matrices the Kronecker product is returned as a Matrix

    >>> from sympy.matrices import Matrix, kronecker_product
    >>> sigma_x = Matrix([
    ... [0, 1],
    ... [1, 0]])
    ...
    >>> Isigma_y = Matrix([
    ... [0, 1],
    ... [-1, 0]])
    ...
    >>> kronecker_product(sigma_x, Isigma_y)
    Matrix([
    [ 0, 0,  0, 1],
    [ 0, 0, -1, 0],
    [ 0, 1,  0, 0],
    [-1, 0,  0, 0]])

    See Also
    ========
    KroneckerProduct

    """
    if not matrices:
        raise TypeError("Empty Kronecker product is undefined")
    validate(*matrices)
    if len(matrices) == 1:
        return matrices[0]
    else:
        return KroneckerProduct(*matrices).doit()


class KroneckerProduct(MatrixExpr):
    """
    The Kronecker product of two or more arguments.

    The Kronecker product is a non-commutative product of matrices.
    Given two matrices of dimension (m, n) and (s, t) it produces a matrix
    of dimension (m s, n t).

    This is a symbolic object that simply stores its argument without
    evaluating it. To actually compute the product, use the function
    ``kronecker_product()`` or call the ``.doit()`` or ``.as_explicit()``
    methods.

    >>> from sympy.matrices import KroneckerProduct, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 5)
    >>> B = MatrixSymbol('B', 5, 5)
    >>> isinstance(KroneckerProduct(A, B), KroneckerProduct)
    True
    """
    is_KroneckerProduct = True

    def __new__(cls, *args, check=True):
        args = list(map(sympify, args))
        # Validate before touching matrix-only attributes such as
        # ``is_Identity`` so that scalar arguments produce the intended
        # TypeError rather than an AttributeError.
        if check:
            validate(*args)
        if all(a.is_Identity for a in args):
            # A Kronecker product of identities is itself an identity whose
            # size is the product of the factor sizes.
            ret = Identity(prod(a.rows for a in args))
            if all(isinstance(a, MatrixBase) for a in args):
                return ret.as_explicit()
            else:
                return ret

        return super().__new__(cls, *args)

    @property
    def shape(self):
        rows, cols = self.args[0].shape
        for mat in self.args[1:]:
            rows *= mat.rows
            cols *= mat.cols
        return (rows, cols)

    def _entry(self, i, j, **kwargs):
        # Decompose (i, j) factor by factor, starting from the last argument:
        # the remainder indexes into the current factor and the quotient is
        # carried to the next factor on the left.
        result = 1
        for mat in reversed(self.args):
            i, m = divmod(i, mat.rows)
            j, n = divmod(j, mat.cols)
            result *= mat[m, n]
        return result

    def _eval_adjoint(self):
        return KroneckerProduct(*list(map(adjoint, self.args))).doit()

    def _eval_conjugate(self):
        return KroneckerProduct(*[a.conjugate() for a in self.args]).doit()

    def _eval_transpose(self):
        return KroneckerProduct(*list(map(transpose, self.args))).doit()

    def _eval_trace(self):
        from .trace import trace
        return prod(trace(a) for a in self.args)

    def _eval_determinant(self):
        from .determinant import det, Determinant
        if not all(a.is_square for a in self.args):
            return Determinant(self)

        # det(A (x) B) = det(A)**n * det(B)**m for A m x m and B n x n: each
        # factor is raised to the product of the other factors' sizes, which
        # is self.rows / a.rows.  ``self.rows`` may be a plain Python int
        # (e.g. for explicit-matrix arguments); sympify it so the exponent is
        # an exact SymPy number instead of a float from true division.
        m = sympify(self.rows)
        return prod(det(a)**(m/a.rows) for a in self.args)

    def _eval_inverse(self):
        # (A (x) B)**-1 == A**-1 (x) B**-1; if a factor cannot be inverted
        # because of its shape, keep the inverse unevaluated.
        try:
            return KroneckerProduct(*[a.inverse() for a in self.args])
        except ShapeError:
            from sympy.matrices.expressions.inverse import Inverse
            return Inverse(self)

    def structurally_equal(self, other):
        '''Determine whether two matrices have the same Kronecker product structure

        Examples
        ========

        >>> from sympy import KroneckerProduct, MatrixSymbol, symbols
        >>> m, n = symbols(r'm, n', integer=True)
        >>> A = MatrixSymbol('A', m, m)
        >>> B = MatrixSymbol('B', n, n)
        >>> C = MatrixSymbol('C', m, m)
        >>> D = MatrixSymbol('D', n, n)
        >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(C, D))
        True
        >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(D, C))
        False
        >>> KroneckerProduct(A, B).structurally_equal(C)
        False
        '''
        # Inspired by BlockMatrix
        return (isinstance(other, KroneckerProduct)
                and self.shape == other.shape
                and len(self.args) == len(other.args)
                and all(a.shape == b.shape for (a, b) in zip(self.args, other.args)))

    def has_matching_shape(self, other):
        '''Determine whether two matrices have the appropriate structure to bring matrix
        multiplication inside the KroneckerProduct

        Examples
        ========
        >>> from sympy import KroneckerProduct, MatrixSymbol, symbols
        >>> m, n = symbols(r'm, n', integer=True)
        >>> A = MatrixSymbol('A', m, n)
        >>> B = MatrixSymbol('B', n, m)
        >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(B, A))
        True
        >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(A, B))
        False
        >>> KroneckerProduct(A, B).has_matching_shape(A)
        False
        '''
        return (isinstance(other, KroneckerProduct)
                and self.cols == other.rows
                and len(self.args) == len(other.args)
                and all(a.cols == b.rows for (a, b) in zip(self.args, other.args)))

    def _eval_expand_kroneckerproduct(self, **hints):
        return flatten(canon(typed({KroneckerProduct: distribute(KroneckerProduct, MatAdd)}))(self))

    def _kronecker_add(self, other):
        # (A (x) B) + (C (x) D) is written factor-wise only when both products
        # have the same structure; otherwise fall back to a plain sum.
        if self.structurally_equal(other):
            return self.__class__(*[a + b for (a, b) in zip(self.args, other.args)])
        else:
            return self + other

    def _kronecker_mul(self, other):
        # (A (x) B)*(C (x) D) == (A*C) (x) (B*D) when the factor shapes allow
        # each pairwise product; otherwise fall back to a plain product.
        if self.has_matching_shape(other):
            return self.__class__(*[a*b for (a, b) in zip(self.args, other.args)])
        else:
            return self * other

    def doit(self, **kwargs):
        """Evaluate the product.

        Unless ``deep=False`` is passed, ``doit`` is first called on every
        argument; the result is then simplified with ``canonicalize``.
        """
        deep = kwargs.get('deep', True)
        if deep:
            args = [arg.doit(**kwargs) for arg in self.args]
        else:
            args = self.args
        return canonicalize(KroneckerProduct(*args))


def validate(*args):
    """Raise ``TypeError`` unless every argument is a matrix.

    Objects without an ``is_Matrix`` attribute (e.g. plain Python numbers)
    are treated as scalars.
    """
    if not all(getattr(arg, 'is_Matrix', False) for arg in args):
        raise TypeError("Mix of Matrix and Scalar symbols")


# rules

def extract_commutative(kron):
    """Pull commutative (scalar) factors out of a KroneckerProduct.

    Each argument is split into its commutative and non-commutative parts;
    the commutative parts are multiplied in front of the KroneckerProduct of
    the remaining parts. ``kron`` is returned unchanged when there is no
    commutative factor other than 1.
    """
    c_part = []
    nc_part = []
    for arg in kron.args:
        c, nc = arg.args_cnc()
        c_part.extend(c)
        nc_part.append(Mul._from_args(nc))

    c_part = Mul(*c_part)
    if c_part != 1:
        return c_part*KroneckerProduct(*nc_part)
    return kron


def matrix_kronecker_product(*matrices):
    """Compute the Kronecker product of a sequence of SymPy Matrices.

    This is the standard Kronecker product of matrices [1].

    Parameters
    ==========

    matrices : tuple of MatrixBase instances
        The matrices to take the Kronecker product of.

    Returns
    =======

    matrix : MatrixBase
        The Kronecker product matrix.

    Raises
    ======

    TypeError
        If any argument is not a ``MatrixBase`` instance, or if no matrices
        are given.

    Examples
    ========

    >>> from sympy import Matrix
    >>> from sympy.matrices.expressions.kronecker import (
    ... matrix_kronecker_product)

    >>> m1 = Matrix([[1,2],[3,4]])
    >>> m2 = Matrix([[1,0],[0,1]])
    >>> matrix_kronecker_product(m1, m2)
    Matrix([
    [1, 0, 2, 0],
    [0, 1, 0, 2],
    [3, 0, 4, 0],
    [0, 3, 0, 4]])
    >>> matrix_kronecker_product(m2, m1)
    Matrix([
    [1, 2, 0, 0],
    [3, 4, 0, 0],
    [0, 0, 1, 2],
    [0, 0, 3, 4]])

    References
    ==========

    [1] https://en.wikipedia.org/wiki/Kronecker_product
    """
    # Make sure we have a non-empty sequence of Matrices
    if not all(isinstance(m, MatrixBase) for m in matrices):
        raise TypeError(
            'Sequence of Matrices expected, got: %s' % repr(matrices)
        )
    if not matrices:
        raise TypeError("Empty Kronecker product is undefined")

    # Start from the last factor and work from right to left: after
    # processing ``mat`` the running expansion is kron(mat, previous expansion).
    matrix_expansion = matrices[-1]
    for mat in reversed(matrices[:-1]):
        rows = mat.rows
        cols = mat.cols
        if rows == 0 or cols == 0:
            # A factor with no entries gives an empty product.  The block
            # loop below would have nothing to join (and ``mat[0]`` would
            # fail), so build the correctly shaped empty matrix directly.
            matrix_expansion = matrix_expansion.zeros(
                rows*matrix_expansion.rows, cols*matrix_expansion.cols)
            continue
        stacked = None
        for i in range(rows):
            # Block row i is mat[i, 0]*E | mat[i, 1]*E | ... where E is the
            # running expansion (``mat[k]`` is row-major flat indexing).
            block_row = matrix_expansion*mat[i*cols]
            for j in range(1, cols):
                block_row = block_row.row_join(
                    matrix_expansion*mat[i*cols + j]
                )
            # Stack the block rows vertically.
            if stacked is None:
                stacked = block_row
            else:
                stacked = stacked.col_join(block_row)
        matrix_expansion = stacked

    MatrixClass = max(matrices, key=lambda M: M._class_priority).__class__
    if isinstance(matrix_expansion, MatrixClass):
        return matrix_expansion
    else:
        return MatrixClass(matrix_expansion)


def explicit_kronecker_product(kron):
    """Evaluate ``kron`` explicitly when all of its arguments are explicit
    matrices (``MatrixBase``); otherwise return it unchanged."""
    # Make sure we have a sequence of Matrices
    if not all(isinstance(m, MatrixBase) for m in kron.args):
        return kron

    return matrix_kronecker_product(*kron.args)


rules = (unpack,
         explicit_kronecker_product,
         flatten,
         extract_commutative)

# While the expression is still a KroneckerProduct, apply the first rule in
# ``rules`` that changes it; repeat until nothing changes.
canonicalize = exhaust(condition(lambda x: isinstance(x, KroneckerProduct),
                                 do_one(*rules)))


def _kronecker_dims_key(expr):
    """Sift key: the tuple of argument shapes of a KroneckerProduct, or
    ``(0,)`` for any other expression."""
    if isinstance(expr, KroneckerProduct):
        return tuple(a.shape for a in expr.args)
    else:
        return (0,)


def kronecker_mat_add(expr):
    """Merge the KroneckerProduct terms of a MatAdd whose arguments have
    identical shapes into single KroneckerProducts.

    Terms that are not KroneckerProducts are kept as they are. ``expr`` is
    returned unchanged when it contains no KroneckerProduct terms.
    """
    from functools import reduce
    args = sift(expr.args, _kronecker_dims_key)
    nonkrons = args.pop((0,), None)
    if not args:
        return expr

    krons = [reduce(lambda x, y: x._kronecker_add(y), group)
             for group in args.values()]

    if not nonkrons:
        return MatAdd(*krons)
    # ``nonkrons`` is a list of terms: pass them to MatAdd directly (adding
    # the list itself to a matrix expression raises TypeError).
    return MatAdd(*krons, *nonkrons)


def kronecker_mat_mul(expr):
    """Merge adjacent KroneckerProduct factors of a MatMul, keeping the
    scalar coefficient in front."""
    # modified from block matrix code
    factor, matrices = expr.as_coeff_matrices()

    i = 0
    while i < len(matrices) - 1:
        A, B = matrices[i:i+2]
        if isinstance(A, KroneckerProduct) and isinstance(B, KroneckerProduct):
            matrices[i] = A._kronecker_mul(B)
            matrices.pop(i+1)
        else:
            i += 1

    return factor*MatMul(*matrices)


def kronecker_mat_pow(expr):
    """Distribute a power over a KroneckerProduct whose arguments are all
    square; any other MatPow is returned unchanged."""
    if isinstance(expr.base, KroneckerProduct) and all(a.is_square for a in expr.base.args):
        return KroneckerProduct(*[MatPow(a, expr.exp) for a in expr.base.args])
    else:
        return expr


def combine_kronecker(expr):
    """Combine KroneckerProduct with expression.

    If possible write operations on KroneckerProducts of compatible shapes
    as a single KroneckerProduct.

    Examples
    ========

    >>> from sympy.matrices.expressions import MatrixSymbol, KroneckerProduct, combine_kronecker
    >>> from sympy import symbols
    >>> m, n = symbols(r'm, n', integer=True)
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', n, m)
    >>> combine_kronecker(KroneckerProduct(A, B)*KroneckerProduct(B, A))
    KroneckerProduct(A*B, B*A)
    >>> combine_kronecker(KroneckerProduct(A, B)+KroneckerProduct(B.T, A.T))
    KroneckerProduct(A + B.T, B + A.T)
    >>> C = MatrixSymbol('C', n, n)
    >>> D = MatrixSymbol('D', m, m)
    >>> combine_kronecker(KroneckerProduct(C, D)**m)
    KroneckerProduct(C**m, D**m)
    """
    def haskron(expr):
        return isinstance(expr, MatrixExpr) and expr.has(KroneckerProduct)

    rule = exhaust(
        bottom_up(exhaust(condition(haskron, typed(
            {MatAdd: kronecker_mat_add,
             MatMul: kronecker_mat_mul,
             MatPow: kronecker_mat_pow})))))
    result = rule(expr)
    doit = getattr(result, 'doit', None)
    if doit is not None:
        return doit()
    else:
        return result