from sympy.matrices.expressions.matexpr  import MatrixExpr
from sympy import Tuple, Basic
from sympy.functions.elementary.integers import ceiling, floor
4
def normalize(i, parentsize):
    """Turn an index specification into a ``(start, stop, step)`` triple.

    Parameters
    ==========

    i : slice, tuple, list, Tuple or a single index
        A ``slice``; a 2- or 3-element sequence ``(start, stop[, step])``;
        or a single index ``k``, which is read as ``(k, k + 1, 1)``.
        Negative positions are counted back from ``parentsize``.
    parentsize :
        Length of the dimension being sliced; it is also the default
        ``stop``.

    A missing (or zero) ``start`` becomes 0 and a missing (or zero)
    ``step`` becomes 1.  Comparisons are written as ``(...) == True`` so a
    symbolic comparison that cannot be decided is treated as false.

    Raises
    ======

    IndexError
        If ``(stop - start)*step < 1`` is definitely true, i.e. the slice
        selects nothing.
    """
    def wrap(value):
        # Shift a definitely-negative position by the dimension length.
        return value + parentsize if (value < 0) == True else value

    if isinstance(i, slice):
        bounds = [i.start, i.stop, i.step]
    elif isinstance(i, (tuple, list, Tuple)):
        bounds = list(i)
    else:
        # A single index k selects exactly one position: k:k+1:1.
        position = wrap(i)
        bounds = [position, position + 1, 1]

    if len(bounds) == 2:
        bounds.append(1)
    start, stop, step = bounds

    start = wrap(start or 0)
    stop = wrap(parentsize if stop is None else stop)
    if not step:
        step = 1

    if ((stop - start)*step < 1) == True:
        raise IndexError()
    return (start, stop, step)
29
class MatrixSlice(MatrixExpr):
    """ A MatrixSlice of a Matrix Expression

    The expression's args are ``(parent, rowslice, colslice)``, where each
    slice is a ``(start, stop, step)`` Tuple produced by ``normalize`` in the
    coordinates of ``parent``.  Slicing a ``MatrixSlice`` again collapses
    both slices into one slice of the underlying matrix.

    Examples
    ========

    >>> from sympy import MatrixSlice, ImmutableMatrix
    >>> M = ImmutableMatrix(4, 4, range(16))
    >>> M
    Matrix([
    [ 0,  1,  2,  3],
    [ 4,  5,  6,  7],
    [ 8,  9, 10, 11],
    [12, 13, 14, 15]])

    >>> B = MatrixSlice(M, (0, 2), (2, 4))
    >>> ImmutableMatrix(B)
    Matrix([
    [2, 3],
    [6, 7]])

    A step that does not evenly divide the slice length still counts the
    final partial stride, exactly like Python's ``range``:

    >>> MatrixSlice(M, (0, 4, 2), (0, 3, 2)).shape
    (2, 2)
    """
    # args[0] is the sliced matrix; args[1] and args[2] are the normalized
    # (start, stop, step) Tuples for rows and columns.
    parent = property(lambda self: self.args[0])
    rowslice = property(lambda self: self.args[1])
    colslice = property(lambda self: self.args[2])

    def __new__(cls, parent, rowslice, colslice):
        rowslice = normalize(rowslice, parent.shape[0])
        colslice = normalize(colslice, parent.shape[1])
        # Defensive: ``normalize`` always returns a 3-tuple.
        if not (len(rowslice) == len(colslice) == 3):
            raise IndexError()
        # ``== True`` keeps undecidable symbolic comparisons from raising;
        # only bounds that are definitely out of range are rejected.
        if ((0 > rowslice[0]) == True or
            (parent.shape[0] < rowslice[1]) == True or
            (0 > colslice[0]) == True or
            (parent.shape[1] < colslice[1]) == True):
            raise IndexError()
        # A slice of a slice is rewritten as one slice of the original matrix.
        if isinstance(parent, MatrixSlice):
            return mat_slice_of_slice(parent, rowslice, colslice)
        return Basic.__new__(cls, parent, Tuple(*rowslice), Tuple(*colslice))

    @property
    def shape(self):
        """Return ``(rows, cols)``, the number of selected rows and columns.

        For a slice ``(start, stop, step)`` the count is
        ``ceiling((stop - start)/step)``, which for integer bounds equals
        ``len(range(start, stop, step))``.  (Using ``floor`` here would drop
        the last element whenever ``step`` does not divide the span, e.g.
        ``(0, 5, 2)`` selects 0, 2, 4 -- three entries, not two.)
        """
        rows = self.rowslice[1] - self.rowslice[0]
        rows = rows if self.rowslice[2] == 1 else ceiling(rows/self.rowslice[2])
        cols = self.colslice[1] - self.colslice[0]
        cols = cols if self.colslice[2] == 1 else ceiling(cols/self.colslice[2])
        return rows, cols

    def _entry(self, i, j, **kwargs):
        # Entry (i, j) of the slice is entry
        # (rowstart + i*rowstep, colstart + j*colstep) of the parent.
        return self.parent._entry(i*self.rowslice[2] + self.rowslice[0],
                                  j*self.colslice[2] + self.colslice[0],
                                  **kwargs)

    @property
    def on_diag(self):
        """True when the row slice and the column slice are identical."""
        return self.rowslice == self.colslice
85
86
def slice_of_slice(s, t):
    """Compose two ``(start, stop, step)`` slices into one.

    ``s`` is a slice of some base dimension, and ``t`` is a slice taken in
    the coordinates of ``s`` (i.e. ``t`` indexes the entries that ``s``
    selects).  The result is a single slice of the base dimension that
    selects the same entries as applying ``s`` and then ``t``.

    Comparisons use ``(...) == True`` so that symbolic bounds whose order
    cannot be decided are accepted rather than raising ``TypeError``.

    Raises
    ======

    IndexError
        If ``t`` definitely reaches past the last entry selected by ``s``.
    """
    start1, stop1, step1 = s
    start2, stop2, step2 = t

    start = start1 + start2*step1
    step = step1 * step2
    stop = start1 + step1*stop2

    # Base position of index ``stop2 - 1`` of ``s``; if that is already at
    # or beyond ``stop1`` then ``t`` asks for more entries than ``s`` has.
    if (start1 + step1*(stop2 - 1) >= stop1) == True:
        raise IndexError()
    # For step1 > 1 the computed ``stop`` may overshoot ``stop1`` by less
    # than one stride even though every selected entry is valid (e.g.
    # s = (0, 5, 2), t = (0, 3, 1) gives stop = 6).  Clamp it so the
    # resulting slice stays within the base dimension.
    if (stop > stop1) == True:
        stop = stop1

    return start, stop, step
99
100
def mat_slice_of_slice(parent, rowslice, colslice):
    """ Collapse nested matrix slices

    ``parent`` is an existing ``MatrixSlice``; ``rowslice`` and ``colslice``
    are normalized ``(start, stop, step)`` triples in ``parent``'s
    coordinates.  Returns an equivalent single ``MatrixSlice`` of
    ``parent.parent``.

    >>> from sympy import MatrixSymbol
    >>> X = MatrixSymbol('X', 10, 10)
    >>> X[:, 1:5][5:8, :]
    X[5:8, 1:5]
    >>> X[1:9:2, 2:6][1:3, 2]
    X[3:7:2, 4:5]
    """
    # Compose rows first, then columns, each against the parent's own slice.
    combined = [slice_of_slice(outer, inner)
                for outer, inner in ((parent.rowslice, rowslice),
                                     (parent.colslice, colslice))]
    return MatrixSlice(parent.parent, *combined)
114