1from ...core import Add, Expr, Integer
2from ...core.strategies import (bottom_up, condition, do_one, exhaust, typed,
3                                unpack)
4from ...core.sympify import sympify
5from ...logic import false
6from ...utilities import sift
7from .determinant import Determinant
8from .inverse import Inverse
9from .matadd import MatAdd
10from .matexpr import Identity, MatrixExpr, ZeroMatrix
11from .matmul import MatMul
12from .slice import MatrixSlice
13from .trace import Trace
14from .transpose import Transpose, transpose
15
16
class BlockMatrix(MatrixExpr):
    """A matrix expression assembled from a grid of smaller submatrices.

    The submatrices are kept in an immutable Diofant Matrix (the single
    argument of the expression) and are combined as a Matrix Expression.

    >>> X = MatrixSymbol('X', n, n)
    >>> Y = MatrixSymbol('Y', m, m)
    >>> Z = MatrixSymbol('Z', n, m)
    >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]])
    >>> B
    Matrix([
    [X, Z],
    [0, Y]])

    >>> C = BlockMatrix([[Identity(n), Z]])
    >>> C
    Matrix([[I, Z]])

    >>> block_collapse(C*B)
    Matrix([[X, Z + Z*Y]])

    """

    def __new__(cls, *args):
        from ..immutable import ImmutableMatrix
        grid = ImmutableMatrix(*[sympify(arg) for arg in args])
        return Expr.__new__(cls, grid)

    @property
    def shape(self):
        """Total ``(rows, cols)``, summed along the first block column/row."""
        grid = self.blocks
        nblock_rows, nblock_cols = grid.shape
        total_rows = Integer(0)
        for r in range(nblock_rows):
            total_rows += grid[r, 0].shape[0]
        total_cols = Integer(0)
        for c in range(nblock_cols):
            total_cols += grid[0, c].shape[1]
        return total_rows, total_cols

    @property
    def blockshape(self):
        """Shape of the grid of blocks: ``(block rows, block columns)``."""
        return self.blocks.shape

    @property
    def blocks(self):
        """The immutable matrix holding the submatrices."""
        return self.args[0]

    @property
    def rowblocksizes(self):
        """Row counts of the blocks in the first block column."""
        grid = self.blocks
        return [grid[r, 0].rows for r in range(self.blockshape[0])]

    @property
    def colblocksizes(self):
        """Column counts of the blocks in the first block row."""
        grid = self.blocks
        return [grid[0, c].cols for c in range(self.blockshape[1])]

    def structurally_equal(self, other):
        """Return True if *other* is a BlockMatrix with an identical layout."""
        if not isinstance(other, BlockMatrix):
            return False
        return (self.shape == other.shape and
                self.blockshape == other.blockshape and
                self.rowblocksizes == other.rowblocksizes and
                self.colblocksizes == other.colblocksizes)

    def _blockmul(self, other):
        # Block-wise multiplication needs matching inner block partitions;
        # otherwise fall back to an ordinary matrix product.
        compatible = (isinstance(other, BlockMatrix) and
                      self.colblocksizes == other.rowblocksizes)
        if not compatible:
            return self * other
        return BlockMatrix(self.blocks*other.blocks)

    def _blockadd(self, other):
        # Block-wise addition needs identical block layouts; otherwise fall
        # back to an ordinary matrix sum.
        compatible = (isinstance(other, BlockMatrix) and
                      self.structurally_equal(other))
        if not compatible:
            return self + other
        return BlockMatrix(self.blocks + other.blocks)

    def _eval_transpose(self):
        from .. import Matrix

        nrows, ncols = self.blockshape
        # Transpose every block, rebuild the grid, then transpose the grid
        # arrangement itself.
        flipped = [transpose(block) for block in self.blocks]
        return BlockMatrix(Matrix(nrows, ncols, flipped).transpose())

    def _eval_trace(self):
        if self.rowblocksizes == self.colblocksizes:
            grid = self.blocks
            diagonal = [grid[k, k] for k in range(self.blockshape[0])]
            return Add(*[Trace(block) for block in diagonal])
        raise NotImplementedError("Can't perform trace of irregular "
                                  'blockshape')  # pragma: no cover

    def _eval_determinant(self):
        return Determinant(self)

    def transpose(self):
        """Return transpose of matrix.

        Examples
        ========

        >>> X = MatrixSymbol('X', n, n)
        >>> Y = MatrixSymbol('Y', m, m)
        >>> Z = MatrixSymbol('Z', n, m)
        >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]])
        >>> B.transpose()
        Matrix([
        [X.T,  0],
        [Z.T, Y.T]])
        >>> _.transpose()
        Matrix([
        [X, Z],
        [0, Y]])

        """
        return self._eval_transpose()

    def _entry(self, i, j):
        # Locate the block row holding row ``i``.  The comparison may be
        # symbolic, so only a definite ``false`` means "look further down".
        for row_block, height in enumerate(self.rowblocksizes):  # pragma: no branch
            if (i < height) != false:
                break
            i -= height
        # Same search across the block columns for column ``j``.
        for col_block, width in enumerate(self.colblocksizes):  # pragma: no branch
            if (j < width) != false:
                break
            j -= width
        return self.blocks[row_block, col_block][i, j]

    @property
    def is_Identity(self):
        nblock_rows, nblock_cols = self.blockshape
        if nblock_rows != nblock_cols:
            return False
        grid = self.blocks
        for r in range(nblock_rows):
            for c in range(nblock_cols):
                block = grid[r, c]
                if r == c:
                    # Diagonal blocks must themselves be identities.
                    if not block.is_Identity:
                        return False
                elif not block.is_ZeroMatrix:
                    # Off-diagonal blocks must be zero.
                    return False
        return True

    @property
    def is_structurally_symmetric(self):
        return self.rowblocksizes == self.colblocksizes

    def equals(self, other):
        same_blocks = (isinstance(other, BlockMatrix) and
                       self.blocks == other.blocks)
        if self == other or same_blocks:
            return True
        return super().equals(other)
175
176
class BlockDiagMatrix(BlockMatrix):
    """
    A BlockDiagMatrix is a BlockMatrix with matrices only along the diagonal

    >>> X = MatrixSymbol('X', n, n)
    >>> Y = MatrixSymbol('Y', m, m)
    >>> BlockDiagMatrix(X, Y)
    Matrix([
    [X, 0],
    [0, Y]])

    The diagonal blocks are stored directly as the expression's arguments.
    The off-diagonal zero blocks are built on demand by ``blocks``.

    """

    def __new__(cls, *mats):
        # Construct an instance of ``cls`` rather than a hard-coded
        # BlockDiagMatrix, so subclasses and ``expr.func(*expr.args)``
        # rebuild an object of the same type.
        return Expr.__new__(cls, *mats)

    @property
    def diag(self):
        """The diagonal blocks, in order."""
        return self.args

    @property
    def blocks(self):
        """Full square grid of blocks, with ZeroMatrix off the diagonal."""
        from ..immutable import ImmutableMatrix
        mats = self.args
        data = [[mats[i] if i == j else ZeroMatrix(mats[i].rows, mats[j].cols)
                 for j in range(len(mats))]
                for i in range(len(mats))]
        return ImmutableMatrix(data)

    @property
    def shape(self):
        """Total ``(rows, cols)``: the sums of the diagonal blocks' sizes."""
        # Start from Integer(0), as BlockMatrix.shape does, so the result is
        # a Diofant number even when there are no diagonal blocks.
        return (sum((block.rows for block in self.args), Integer(0)),
                sum((block.cols for block in self.args), Integer(0)))

    @property
    def blockshape(self):
        """The block grid is always square: ``(len(diag), len(diag))``."""
        n = len(self.args)
        return n, n

    @property
    def rowblocksizes(self):
        """Row counts of the diagonal blocks."""
        return [block.rows for block in self.args]

    @property
    def colblocksizes(self):
        """Column counts of the diagonal blocks."""
        return [block.cols for block in self.args]

    def _eval_inverse(self, expand='ignored'):
        # The inverse of a block-diagonal matrix inverts each diagonal block.
        return BlockDiagMatrix(*[mat.inverse() for mat in self.args])

    def _blockmul(self, other):
        # Two compatible block-diagonal matrices multiply block by block.
        if (isinstance(other, BlockDiagMatrix) and
                self.colblocksizes == other.rowblocksizes):
            return BlockDiagMatrix(*[a*b for a, b in zip(self.args, other.args)])
        return super()._blockmul(other)

    def _blockadd(self, other):
        # Two block-diagonal matrices with identical layouts add block by
        # block; anything else uses the generic BlockMatrix addition.
        if (isinstance(other, BlockDiagMatrix) and
                self.structurally_equal(other)):
            return BlockDiagMatrix(*[a + b for a, b in zip(self.args, other.args)])
        return super()._blockadd(other)
242
243
def block_collapse(expr):
    """Evaluates a block matrix expression

    >>> X = MatrixSymbol('X', n, n)
    >>> Y = MatrixSymbol('Y', m, m)
    >>> Z = MatrixSymbol('Z', n, m)
    >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]])
    >>> B
    Matrix([
    [X, Z],
    [0, Y]])

    >>> C = BlockMatrix([[Identity(n), Z]])
    >>> C
    Matrix([[I, Z]])

    >>> block_collapse(C*B)
    Matrix([[X, Z + Z*Y]])

    """
    def hasbm(expr):
        # Only matrix expressions that still contain a BlockMatrix need work.
        return isinstance(expr, MatrixExpr) and expr.has(BlockMatrix)

    # Per-node rewrite rules, selected by the node's type.
    handlers = {MatAdd: do_one([bc_matadd, bc_block_plus_ident]),
                MatMul: do_one([bc_matmul, bc_dist]),
                Transpose: bc_transpose,
                Inverse: bc_inverse,
                BlockMatrix: do_one([bc_unpack, deblock])}
    step = condition(hasbm, typed(handlers))
    # Rewrite every node to a fixed point, bottom-up, until nothing changes.
    rule = exhaust(bottom_up(exhaust(step)))
    return rule(expr).doit()
275
276
def bc_unpack(expr):
    """Replace a 1x1 block matrix by its single block; else return *expr*."""
    if expr.blockshape != (1, 1):
        return expr
    return expr.blocks[0, 0]
281
282
def bc_matadd(expr):
    """Fold all BlockMatrix terms of a MatAdd together block-wise.

    Non-block terms are kept and added to the combined block term.
    """
    groups = sift(expr.args, lambda arg: isinstance(arg, BlockMatrix))
    block_terms = groups[True]
    if not block_terms:
        return expr

    other_terms = groups[False]
    total = block_terms[0]
    for term in block_terms[1:]:
        total = total._blockadd(term)
    return MatAdd(*other_terms) + total if other_terms else total
297
298
def bc_block_plus_ident(expr):
    """Rewrite ``Identity + BlockMatrix`` terms as a block-wise sum.

    This applies only when all block terms share one symmetric layout.  Each
    identity term is then replaced by a block-diagonal identity with that
    layout.
    """
    n_idents = sum(1 for arg in expr.args if arg.is_Identity)
    if n_idents == 0:
        return expr

    blocks = [arg for arg in expr.args if isinstance(arg, BlockMatrix)]
    if not blocks:
        return expr
    first = blocks[0]
    if not (all(b.structurally_equal(first) for b in blocks)
            and first.is_structurally_symmetric):
        return expr

    block_id = BlockDiagMatrix(*[Identity(k) for k in first.rowblocksizes])
    return MatAdd(block_id * n_idents, *blocks).doit()
312
313
def bc_dist(expr):
    """Turn  a*[X, Y] into [a*X, a*Y]."""
    factor, mat = expr.as_coeff_mmul()
    inner = unpack(mat)
    if factor != 1 and isinstance(inner, BlockMatrix):
        grid = inner.blocks
        scaled = [[factor * grid[r, c] for c in range(grid.cols)]
                  for r in range(grid.rows)]
        return BlockMatrix(scaled)
    return expr
322
323
def bc_matmul(expr):
    """Multiply adjacent factors block-wise wherever a BlockMatrix is involved.

    A non-block neighbour of a BlockMatrix is wrapped as a 1x1 BlockMatrix
    before the block multiplication.
    """
    factor, mats = expr.as_coeff_matrices()

    pos = 0
    while pos < len(mats) - 1:
        left, right = mats[pos], mats[pos + 1]
        left_is_block = isinstance(left, BlockMatrix)
        right_is_block = isinstance(right, BlockMatrix)
        if not (left_is_block or right_is_block):
            pos += 1
            continue
        if not left_is_block:
            left = BlockMatrix([[left]])
        if not right_is_block:
            right = BlockMatrix([[right]])
        # Merge the pair in place; stay at ``pos`` to try the next neighbour.
        mats[pos] = left._blockmul(right)
        del mats[pos + 1]
    return MatMul(factor, *mats).doit()
342
343
def bc_transpose(expr):
    """Transpose a block matrix by transposing each block and the grid."""
    collapsed = block_collapse(expr.arg)
    flipped = collapsed.blocks.applyfunc(transpose).T
    return BlockMatrix(flipped)
346
347
def bc_inverse(expr):
    """Invert a block matrix via the 2x2 block-inverse formula."""
    reblocked = reblock_2x2(expr.arg)
    return blockinverse_2x2(Inverse(reblocked))
350
351
def blockinverse_2x2(expr):
    """Expand the inverse of a 2x2 BlockMatrix into a BlockMatrix of inverses.

    *expr* is an ``Inverse`` whose argument has exactly 2x2 blocks
    ``[[A, B], [C, D]]``.
    """
    # Cite: The Matrix Cookbook Section 9.1.3
    (A, B), (C, D) = expr.arg.blocks.tolist()

    top_left = +(A - B*D.inverse()*C).inverse()
    neg_A_inv = (-A).inverse()
    A_inv = A.inverse()
    # Inverse of the Schur complement of A, shared by three of the blocks.
    schur_inv = (D - C*A_inv*B).inverse()
    top_right = neg_A_inv*B*schur_inv
    bottom_left = -schur_inv*C*A_inv

    return BlockMatrix([[top_left, top_right],
                        [bottom_left, schur_inv]])
359
360
def deblock(B):
    """Flatten a BlockMatrix whose blocks are themselves BlockMatrices.

    Inputs that are not BlockMatrices, or that contain no nested
    BlockMatrix, are returned unchanged.
    """
    if not isinstance(B, BlockMatrix) or not B.blocks.has(BlockMatrix):
        return B

    def as_block(x):
        # Promote plain blocks to 1x1 BlockMatrices so that every cell of
        # the grid exposes a ``.blocks`` attribute.
        if isinstance(x, BlockMatrix):
            return x
        return BlockMatrix([[x]])

    grid = B.blocks.applyfunc(as_block)

    from .. import Matrix
    nrows, ncols = grid.shape
    width = sum(grid[0, c].blocks.shape[1] for c in range(ncols))
    flat = Matrix(0, width, [])
    for r in range(nrows):
        # Concatenate the inner grids of one block row side by side, then
        # stack that strip under the rows collected so far.
        strip = Matrix(grid[r, 0].blocks)
        for c in range(1, ncols):
            strip = strip.row_join(grid[r, c].blocks)
        flat = flat.col_join(strip)

    return BlockMatrix(flat)
380
381
def reblock_2x2(B):
    """Reblock a BlockMatrix so that it has 2x2 blocks of block matrices.

    The top-left block is kept as is.  The rest of the first block row, the
    rest of the first block column and the lower-right remainder are each
    wrapped in a BlockMatrix.

    *B* is returned unchanged in three cases: it is not a BlockMatrix, it
    already has a 2x2 block shape, or it has a single block row or column
    (which cannot be split in two).
    """
    if not isinstance(B, BlockMatrix):
        return B
    nrows, ncols = B.blocks.shape
    # Any grid with at least two block rows and two block columns can be
    # split after its first row/column.  2xN and Nx2 grids (N > 2) must be
    # reblocked too, otherwise blockinverse_2x2 cannot unpack them.
    if min(nrows, ncols) < 2 or (nrows, ncols) == (2, 2):
        return B

    BM = BlockMatrix  # for brevity's sake
    return BM([[    B.blocks[0,  0], BM(B.blocks[0,  1:])],
               [BM(B.blocks[1:, 0]), BM(B.blocks[1:, 1:])]])
390
391
def bounds(sizes):
    """Convert sequence of numbers into pairs of low-high pairs

    Each pair is the half-open ``(start, stop)`` range that the size
    occupies when the sizes are laid end to end, starting at 0.

    >>> bounds((1, 10, 50))
    [(0, 1), (1, 11), (11, 61)]

    """
    edges = []
    lo = 0
    for size in sizes:
        hi = lo + size
        edges.append((lo, hi))
        lo += size
    return edges
405
406
def blockcut(expr, rowsizes, colsizes):
    """Cut a matrix expression into Blocks

    *rowsizes* and *colsizes* give the heights of the block rows and the
    widths of the block columns; each block is a MatrixSlice of *expr*.

    >>> M = ImmutableMatrix(4, 4, range(16))
    >>> B = blockcut(M, (1, 3), (1, 3))
    >>> type(B).__name__
    'BlockMatrix'
    >>> ImmutableMatrix(B.blocks[0, 1])
    Matrix([[1, 2, 3]])

    """
    row_ranges = bounds(rowsizes)
    col_ranges = bounds(colsizes)
    grid = []
    for row_range in row_ranges:
        grid.append([MatrixSlice(expr, row_range, col_range)
                     for col_range in col_ranges])
    return BlockMatrix(grid)
423