1from ...core import Expr, Integer, Tuple
2from ...functions import floor
3from ...logic import true
4from .matexpr import MatrixExpr
5
6
def normalize(i, parentsize):
    """Turn an index specification into a ``(start, stop, step)`` triple.

    ``i`` may be a :class:`slice`, a ``tuple``/``list``/``Tuple`` of the
    form ``(start, stop)`` or ``(start, stop, step)``, or a single index
    ``k``, which becomes the one-element range ``(k, k + 1, 1)``.

    Falsy ``start``/``step`` default to ``0``/``1`` and a ``None`` stop
    defaults to ``parentsize``.  Any bound that is provably negative is
    counted from the end by adding ``parentsize``; symbolic bounds whose
    sign cannot be decided are kept as given.

    Raises
    ======

    IndexError
        If ``(stop - start)*step`` is provably less than one, i.e. the
        range is provably empty.

    """
    def from_end(bound):
        # Only a comparison that definitely evaluates to ``true`` wraps;
        # an undecidable symbolic comparison leaves the bound untouched.
        if (bound < Integer(0)) == true:
            return bound + parentsize
        return bound

    if isinstance(i, slice):
        spec = [i.start, i.stop, i.step]
    elif isinstance(i, (tuple, list, Tuple)):
        spec = list(i)
    else:
        # A lone index selects exactly one row/column.
        first = from_end(i)
        spec = [first, first + 1, 1]

    if len(spec) == 2:
        spec.append(1)
    start, stop, step = spec

    if not start:
        start = 0
    start = from_end(start)
    stop = from_end(parentsize if stop is None else stop)
    if not step:
        step = 1

    if ((stop - start) * step < Integer(1)) == true:
        raise IndexError()

    return start, stop, step
31
32
class MatrixSlice(MatrixExpr):
    """A MatrixSlice of a Matrix Expression

    ``args`` are ``(parent, rowslice, colslice)``, where each slice is a
    normalized ``(start, stop, step)`` triple.  Row ``i`` of the slice is
    row ``start + i*step`` of ``parent``, and columns work the same way.

    Examples
    ========

    >>> M = ImmutableMatrix(4, 4, range(16))
    >>> print(M)
    Matrix([
    [ 0,  1,  2,  3],
    [ 4,  5,  6,  7],
    [ 8,  9, 10, 11],
    [12, 13, 14, 15]])

    >>> B = MatrixSlice(M, (0, 2), (2, 4))
    >>> print(ImmutableMatrix(B))
    Matrix([
    [2, 3],
    [6, 7]])

    A trailing partial step still selects a row/column:

    >>> MatrixSlice(M, (0, 4, 3), (0, 4, 3)).shape
    (2, 2)

    """

    parent = property(lambda self: self.args[0])
    rowslice = property(lambda self: self.args[1])
    colslice = property(lambda self: self.args[2])

    def __new__(cls, parent, rowslice, colslice):
        rowslice = normalize(rowslice, parent.shape[0])
        colslice = normalize(colslice, parent.shape[1])
        # Reject slices that provably fall outside ``parent``.  A symbolic
        # comparison that cannot be decided does not equal ``true`` and is
        # accepted.
        if true in (0 > rowslice[0], parent.shape[0] < rowslice[1],
                    0 > colslice[0], parent.shape[1] < colslice[1]):
            raise IndexError()
        # A slice of a slice is flattened into one slice of the original
        # matrix.
        if isinstance(parent, MatrixSlice):
            return mat_slice_of_slice(parent, rowslice, colslice)
        return Expr.__new__(cls, parent, Tuple(*rowslice), Tuple(*colslice))

    @property
    def shape(self):
        """Return the number of selected rows and columns.

        ``(start, stop, step)`` selects ``ceiling((stop - start)/step)``
        indices, so a trailing partial step still contributes one index.
        The identity ``ceiling(x) == -floor(-x)`` lets this use ``floor``.
        The formula also covers negative steps with ``start > stop``.
        """
        def count(start, stop, step):
            extent = stop - start
            if step == 1:
                return extent
            return -floor(-extent/step)

        return count(*self.rowslice), count(*self.colslice)

    def _entry(self, i, j):
        # Map slice coordinates back to coordinates in ``parent``.
        return self.parent._entry(i*self.rowslice[2] + self.rowslice[0],
                                  j*self.colslice[2] + self.colslice[0])

    @property
    def on_diag(self):
        # True when rows and columns are sliced identically.
        return self.rowslice == self.colslice


def slice_of_slice(s, t):
    """Compose two normalized ``(start, stop, step)`` slices.

    ``t`` is relative to the elements selected by ``s``: its ``k``-th
    position is ``s``'s element number ``t[0] + k*t[2]``.  The returned
    triple selects the same elements directly from whatever ``s`` indexes.

    When ``s``'s stop is not aligned to its step, the composed stop can land
    past ``stop1`` by less than one step of ``s`` without selecting anything
    extra.  In that case it is pulled back to ``stop1`` so the result stays
    within bounds.  Symbolic comparisons that cannot be decided leave the
    stop unchanged.

    Raises
    ======

    IndexError
        If ``t`` provably reaches a whole step of ``s`` beyond ``stop1``,
        i.e. it asks for more elements than ``s`` selects.

    """
    start1, stop1, step1 = s
    start2, stop2, step2 = t

    start = start1 + start2*step1
    step = step1*step2
    stop = start1 + step1*stop2

    # Scaling by ``step1`` makes this positive exactly when ``stop`` lies
    # beyond ``stop1`` in the direction of travel, for either sign of step.
    # Comparing against ``step1**2`` asks whether that overshoot is at least
    # one whole step of ``s``.
    overshoot = (stop - stop1)*step1
    if (overshoot >= step1**2) == true:
        raise IndexError()
    if (overshoot > 0) == true:
        stop = stop1

    return start, stop, step


def mat_slice_of_slice(parent, rowslice, colslice):
    """Collapse nested matrix slices

    ``parent`` is a MatrixSlice, and ``rowslice``/``colslice`` are normalized
    triples relative to it.  The result is a single MatrixSlice of
    ``parent.parent``.

    >>> X = MatrixSymbol('X', 10, 10)
    >>> X[:, 1:5][5:8, :]
    X[5:8, 1:5]
    >>> X[1:9:2, 2:6][1:3, 2]
    X[3:7:2, 4]
    >>> X[0:5:2, :][0:3, :].shape
    (3, 10)

    """
    row = slice_of_slice(parent.rowslice, rowslice)
    col = slice_of_slice(parent.colslice, colslice)
    return MatrixSlice(parent.parent, row, col)
112