from ...core import Add, Expr, Integer
from ...core.strategies import (bottom_up, condition, do_one, exhaust, typed,
                                unpack)
from ...core.sympify import sympify
from ...logic import false
from ...utilities import sift
from .determinant import Determinant
from .inverse import Inverse
from .matadd import MatAdd
from .matexpr import Identity, MatrixExpr, ZeroMatrix
from .matmul import MatMul
from .slice import MatrixSlice
from .trace import Trace
from .transpose import Transpose, transpose


class BlockMatrix(MatrixExpr):
    """A BlockMatrix is a Matrix composed of other smaller, submatrices

    The submatrices are stored in a Diofant Matrix object but accessed as part of
    a Matrix Expression

    >>> X = MatrixSymbol('X', n, n)
    >>> Y = MatrixSymbol('Y', m, m)
    >>> Z = MatrixSymbol('Z', n, m)
    >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]])
    >>> B
    Matrix([
    [X, Z],
    [0, Y]])

    >>> C = BlockMatrix([[Identity(n), Z]])
    >>> C
    Matrix([[I, Z]])

    >>> block_collapse(C*B)
    Matrix([[X, Z + Z*Y]])

    """

    def __new__(cls, *args):
        """Sympify ``args`` and store them as a single ImmutableMatrix."""
        from ..immutable import ImmutableMatrix
        args = map(sympify, args)
        mat = ImmutableMatrix(*args)

        obj = Expr.__new__(cls, mat)
        return obj

    @property
    def shape(self):
        """Return the total ``(rows, cols)`` of the assembled matrix.

        Row counts are summed over the first block column and column
        counts over the first block row.
        """
        numrows = numcols = Integer(0)
        M = self.blocks
        for i in range(M.shape[0]):
            numrows += M[i, 0].shape[0]
        for i in range(M.shape[1]):
            numcols += M[0, i].shape[1]
        return numrows, numcols

    @property
    def blockshape(self):
        """Return the shape of the matrix of blocks."""
        return self.blocks.shape

    @property
    def blocks(self):
        """Return the matrix holding the blocks (the first argument)."""
        return self.args[0]

    @property
    def rowblocksizes(self):
        """Return the row count of each block in the first block column."""
        return [self.blocks[i, 0].rows for i in range(self.blockshape[0])]

    @property
    def colblocksizes(self):
        """Return the column count of each block in the first block row."""
        return [self.blocks[0, i].cols for i in range(self.blockshape[1])]

    def structurally_equal(self, other):
        """Test whether ``other`` is a BlockMatrix with identical partitioning.

        Shapes, block shapes and the row/column block sizes must all agree.
        """
        return (isinstance(other, BlockMatrix)
                and self.shape == other.shape
                and self.blockshape == other.blockshape
                and self.rowblocksizes == other.rowblocksizes
                and self.colblocksizes == other.colblocksizes)

    def _blockmul(self, other):
        """Multiply block-wise when the partitions are compatible.

        Falls back to the ordinary product ``self * other`` when ``other``
        is not a BlockMatrix or its row block sizes differ from this
        matrix's column block sizes.
        """
        if (isinstance(other, BlockMatrix) and
                self.colblocksizes == other.rowblocksizes):
            return BlockMatrix(self.blocks*other.blocks)

        return self * other

    def _blockadd(self, other):
        """Add block-wise when ``other`` is structurally equal.

        Otherwise falls back to the ordinary sum ``self + other``.
        """
        if (isinstance(other, BlockMatrix)
                and self.structurally_equal(other)):
            return BlockMatrix(self.blocks + other.blocks)

        return self + other

    def _eval_transpose(self):
        """Transpose every block and then the block layout itself."""
        from .. import Matrix

        # Flip all the individual matrices
        matrices = [transpose(matrix) for matrix in self.blocks]
        # Make a copy
        M = Matrix(self.blockshape[0], self.blockshape[1], matrices)
        # Transpose the block structure
        M = M.transpose()
        return BlockMatrix(M)

    def _eval_trace(self):
        """Return the sum of the traces of the diagonal blocks.

        Raises NotImplementedError when the row and column block sizes
        differ.
        """
        if self.rowblocksizes == self.colblocksizes:
            return Add(*[Trace(self.blocks[i, i])
                         for i in range(self.blockshape[0])])
        raise NotImplementedError("Can't perform trace of irregular "
                                  'blockshape')  # pragma: no cover

    def _eval_determinant(self):
        """Return ``Determinant(self)``; no block-wise evaluation is done."""
        return Determinant(self)

    def transpose(self):
        """Return transpose of matrix.

        Examples
        ========

        >>> X = MatrixSymbol('X', n, n)
        >>> Y = MatrixSymbol('Y', m, m)
        >>> Z = MatrixSymbol('Z', n, m)
        >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]])
        >>> B.transpose()
        Matrix([
        [X.T,   0],
        [Z.T, Y.T]])
        >>> _.transpose()
        Matrix([
        [X, Z],
        [0, Y]])

        """
        return self._eval_transpose()

    def _entry(self, i, j):
        """Return element ``(i, j)`` by locating the block that contains it.

        The indices are reduced by the size of every block that is skipped,
        so they end up relative to the selected block.
        """
        # Find row entry; stop at the first block for which ``i < numrows``
        # is not the ``false`` object (i.e. is true or undecided).
        for row_block, numrows in enumerate(self.rowblocksizes):  # pragma: no branch
            if (i < numrows) != false:
                break
            else:
                i -= numrows
        # Same search along the columns.
        for col_block, numcols in enumerate(self.colblocksizes):  # pragma: no branch
            if (j < numcols) != false:
                break
            else:
                j -= numcols
        return self.blocks[row_block, col_block][i, j]

    @property
    def is_Identity(self):
        """Test for identity blocks on the diagonal and zero blocks elsewhere.

        A non-square block shape is never an identity.
        """
        if self.blockshape[0] != self.blockshape[1]:
            return False
        for i in range(self.blockshape[0]):
            for j in range(self.blockshape[1]):
                if i == j and not self.blocks[i, j].is_Identity:
                    return False
                if i != j and not self.blocks[i, j].is_ZeroMatrix:
                    return False
        return True

    @property
    def is_structurally_symmetric(self):
        """Test whether the row and column block sizes coincide."""
        return self.rowblocksizes == self.colblocksizes

    def equals(self, other):
        """Test equality, also accepting BlockMatrices with equal blocks.

        Anything else is delegated to the parent class.
        """
        if self == other:
            return True
        if isinstance(other, BlockMatrix) and self.blocks == other.blocks:
            return True
        return super().equals(other)


class BlockDiagMatrix(BlockMatrix):
    """
    A BlockDiagMatrix is a BlockMatrix with matrices only along the diagonal

    >>> X = MatrixSymbol('X', n, n)
    >>> Y = MatrixSymbol('Y', m, m)
    >>> BlockDiagMatrix(X, Y)
    Matrix([
    [X, 0],
    [0, Y]])

    """

    def __new__(cls, *mats):
        """Store the diagonal blocks ``mats`` directly as the arguments."""
        # Use ``cls`` (as BlockMatrix.__new__ does) so that subclasses are
        # instantiated with their own type.
        return Expr.__new__(cls, *mats)

    @property
    def diag(self):
        """Return the diagonal blocks."""
        return self.args

    @property
    def blocks(self):
        """Return the full block layout, with ZeroMatrix off the diagonal.

        The zero block at ``(i, j)`` takes its row count from block ``i``
        and its column count from block ``j``.
        """
        from ..immutable import ImmutableMatrix
        mats = self.args
        data = [[mats[i] if i == j else ZeroMatrix(mats[i].rows, mats[j].cols)
                 for j in range(len(mats))]
                for i in range(len(mats))]
        return ImmutableMatrix(data)

    @property
    def shape(self):
        """Return the total ``(rows, cols)`` summed over the diagonal blocks."""
        return (sum(block.rows for block in self.args),
                sum(block.cols for block in self.args))

    @property
    def blockshape(self):
        """Return ``(n, n)`` where ``n`` is the number of diagonal blocks."""
        n = len(self.args)
        return n, n

    @property
    def rowblocksizes(self):
        """Return the row count of each diagonal block."""
        return [block.rows for block in self.args]

    @property
    def colblocksizes(self):
        """Return the column count of each diagonal block."""
        return [block.cols for block in self.args]

    def _eval_inverse(self, expand='ignored'):
        """Invert each diagonal block; ``expand`` is not used."""
        return BlockDiagMatrix(*[mat.inverse() for mat in self.args])

    def _blockmul(self, other):
        """Multiply diagonal block by diagonal block when possible.

        Used when ``other`` is a BlockDiagMatrix whose row block sizes equal
        this matrix's column block sizes; otherwise defers to BlockMatrix.
        """
        if (isinstance(other, BlockDiagMatrix) and
                self.colblocksizes == other.rowblocksizes):
            return BlockDiagMatrix(*[a*b for a, b in zip(self.args, other.args)])
        else:
            return BlockMatrix._blockmul(self, other)

    def _blockadd(self, other):
        """Add diagonal block by diagonal block when possible.

        Used when ``other`` is a BlockDiagMatrix with the same block shape
        and block sizes; otherwise defers to BlockMatrix.
        """
        if (isinstance(other, BlockDiagMatrix) and
                self.blockshape == other.blockshape and
                self.rowblocksizes == other.rowblocksizes and
                self.colblocksizes == other.colblocksizes):
            return BlockDiagMatrix(*[a + b for a, b in zip(self.args, other.args)])
        else:
            return BlockMatrix._blockadd(self, other)


def block_collapse(expr):
    """Evaluates a block matrix expression

    >>> X = MatrixSymbol('X', n, n)
    >>> Y = MatrixSymbol('Y', m, m)
    >>> Z = MatrixSymbol('Z', n, m)
    >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]])
    >>> B
    Matrix([
    [X, Z],
    [0, Y]])

    >>> C = BlockMatrix([[Identity(n), Z]])
    >>> C
    Matrix([[I, Z]])

    >>> block_collapse(C*B)
    Matrix([[X, Z + Z*Y]])

    """
    def hasbm(expr):
        # Only matrix expressions that contain a BlockMatrix are rewritten.
        return isinstance(expr, MatrixExpr) and expr.has(BlockMatrix)
    # Dispatch on the expression type to the matching ``bc_*`` rule.
    rule = exhaust(
        bottom_up(exhaust(condition(hasbm, typed(
            {MatAdd: do_one([bc_matadd, bc_block_plus_ident]),
             MatMul: do_one([bc_matmul, bc_dist]),
             Transpose: bc_transpose,
             Inverse: bc_inverse,
             BlockMatrix: do_one([bc_unpack, deblock])})))))
    result = rule(expr)
    return result.doit()


def bc_unpack(expr):
    """Replace a 1x1 BlockMatrix by its only block."""
    if expr.blockshape == (1, 1):
        return expr.blocks[0, 0]
    return expr


def bc_matadd(expr):
    """Merge the BlockMatrix terms of a sum using ``_blockadd``.

    Terms that are not BlockMatrices are kept in a separate MatAdd which is
    added to the merged block.  The expression is returned unchanged when it
    has no BlockMatrix term.
    """
    args = sift(expr.args, lambda M: isinstance(M, BlockMatrix))
    blocks = args[True]
    if not blocks:
        return expr

    nonblocks = args[False]
    block = blocks[0]
    for b in blocks[1:]:
        block = block._blockadd(b)
    if nonblocks:
        return MatAdd(*nonblocks) + block
    else:
        return block


def bc_block_plus_ident(expr):
    """Rewrite identity terms of a sum as a block-diagonal identity.

    Applies only when the BlockMatrix terms are all structurally equal and
    structurally symmetric; otherwise the expression is returned unchanged.
    """
    idents = [arg for arg in expr.args if arg.is_Identity]
    if not idents:
        return expr

    blocks = [arg for arg in expr.args if isinstance(arg, BlockMatrix)]
    if (blocks and all(b.structurally_equal(blocks[0]) for b in blocks)
            and blocks[0].is_structurally_symmetric):
        block_id = BlockDiagMatrix(*[Identity(k)
                                     for k in blocks[0].rowblocksizes])
        return MatAdd(block_id * len(idents), *blocks).doit()

    return expr


def bc_dist(expr):
    """Turn a*[X, Y] into [a*X, a*Y]."""
    factor, mat = expr.as_coeff_mmul()
    if factor != 1 and isinstance(unpack(mat), BlockMatrix):
        B = unpack(mat).blocks
        return BlockMatrix([[factor * B[i, j] for j in range(B.cols)]
                            for i in range(B.rows)])
    return expr


def bc_matmul(expr):
    """Merge adjacent factors of a product when either is a BlockMatrix.

    A factor that is not a BlockMatrix is wrapped as a 1x1 BlockMatrix before
    being combined with a BlockMatrix neighbour via ``_blockmul``.
    """
    factor, matrices = expr.as_coeff_matrices()

    i = 0
    while i + 1 < len(matrices):
        A, B = matrices[i:i+2]
        if isinstance(A, BlockMatrix) and isinstance(B, BlockMatrix):
            matrices[i] = A._blockmul(B)
            matrices.pop(i+1)
        elif isinstance(A, BlockMatrix):
            matrices[i] = A._blockmul(BlockMatrix([[B]]))
            matrices.pop(i+1)
        elif isinstance(B, BlockMatrix):
            matrices[i] = BlockMatrix([[A]])._blockmul(B)
            matrices.pop(i+1)
        else:
            # Neither factor is a BlockMatrix: move on to the next pair.
            i += 1
    return MatMul(factor, *matrices).doit()


def bc_transpose(expr):
    """Collapse the argument of a Transpose and transpose it block-wise."""
    return BlockMatrix(block_collapse(expr.arg).blocks.applyfunc(transpose).T)


def bc_inverse(expr):
    """Invert a block matrix after regrouping it into a 2x2 block layout."""
    return blockinverse_2x2(Inverse(reblock_2x2(expr.arg)))


def blockinverse_2x2(expr):
    """Return the block-wise inverse of a 2x2 BlockMatrix.

    ``expr`` is an expression whose ``arg`` is a BlockMatrix with exactly
    2x2 blocks; any other block layout fails when the blocks are unpacked.
    """
    # Cite: The Matrix Cookbook Section 9.1.3
    [[A, B],
     [C, D]] = expr.arg.blocks.tolist()

    # These two sub-expressions occur several times in the formula below.
    A_inv = A.inverse()
    schur_inv = (D - C*A_inv*B).inverse()

    return BlockMatrix([[+(A - B*D.inverse()*C).inverse(), (-A).inverse()*B*schur_inv],
                        [-schur_inv*C*A_inv, schur_inv]])


def deblock(B):
    """Flatten a BlockMatrix of BlockMatrices."""
    if not isinstance(B, BlockMatrix) or not B.blocks.has(BlockMatrix):
        return B

    def wrap(x):
        return x if isinstance(x, BlockMatrix) else BlockMatrix([[x]])

    bb = B.blocks.applyfunc(wrap)  # everything is a block

    from .. import Matrix
    # Start from a matrix with zero rows and the final number of block
    # columns, then append one joined block row at a time.
    MM = Matrix(0, sum(bb[0, i].blocks.shape[1] for i in range(bb.shape[1])), [])
    for row in range(bb.shape[0]):
        M = Matrix(bb[row, 0].blocks)
        for col in range(1, bb.shape[1]):
            M = M.row_join(bb[row, col].blocks)
        MM = MM.col_join(M)

    return BlockMatrix(MM)


def reblock_2x2(B):
    """Reblock a BlockMatrix so that it has 2x2 blocks of block matrices."""
    # Anything that is not a BlockMatrix with more than two blocks in both
    # directions is returned unchanged.
    if not isinstance(B, BlockMatrix) or not all(d > 2 for d in B.blocks.shape):
        return B

    BM = BlockMatrix  # for brevity's sake
    return BM([[B.blocks[0, 0], BM(B.blocks[0, 1:])],
               [BM(B.blocks[1:, 0]), BM(B.blocks[1:, 1:])]])


def bounds(sizes):
    """Convert sequence of numbers into pairs of low-high pairs

    >>> bounds((1, 10, 50))
    [(0, 1), (1, 11), (11, 61)]

    """
    low = 0
    rv = []
    for size in sizes:
        rv.append((low, low + size))
        low += size
    return rv


def blockcut(expr, rowsizes, colsizes):
    """Cut a matrix expression into Blocks

    >>> M = ImmutableMatrix(4, 4, range(16))
    >>> B = blockcut(M, (1, 3), (1, 3))
    >>> type(B).__name__
    'BlockMatrix'
    >>> ImmutableMatrix(B.blocks[0, 1])
    Matrix([[1, 2, 3]])

    """
    rowbounds = bounds(rowsizes)
    colbounds = bounds(colsizes)
    return BlockMatrix([[MatrixSlice(expr, rowbound, colbound)
                         for colbound in colbounds]
                        for rowbound in rowbounds])