1from itertools import chain, starmap, product, groupby, islice
2from functools import reduce
3from operator import itemgetter
4from typing import List, Tuple, Iterable, Sequence, Optional, Union
5
6from AnyQt.QtCore import (
7    QModelIndex, QAbstractItemModel, QItemSelectionModel, QItemSelection,
8    QObject
9)
10
11
class BlockSelectionModel(QItemSelectionModel):
    """
    Item selection model ensuring the selection maintains a simple block
    like structure.

    e.g.

        [a b] c [d e]
        [f g] h [i j]

    is allowed but this is not

        [a] b  c  d e
        [f  g] h [i j]

    I.e. select the Cartesian product of row and column indices.

    The `Current` selection flag is ignored and `Toggle` is treated as
    `Select`. Block selection can be switched off with `setSelectBlocks`,
    in which case the model behaves as a plain `QItemSelectionModel`.
    """
    def __init__(
            self, model: QAbstractItemModel, parent: Optional[QObject] = None,
            selectBlocks=True, **kwargs
    ) -> None:
        super().__init__(model, parent, **kwargs)
        self.__selectBlocks = selectBlocks

    def select(self, selection: Union[QItemSelection, QModelIndex],
               flags: QItemSelectionModel.SelectionFlags) -> None:
        """Reimplemented.

        Extend `selection` so the resulting selection stays a Cartesian
        product of row and column ranges, then forward it to the base
        implementation.
        """
        if isinstance(selection, QModelIndex):
            selection = QItemSelection(selection, selection)

        if not self.__selectBlocks:
            super().select(selection, flags)
            return

        model = self.model()

        if flags & QItemSelectionModel.Current:  # no current selection support
            flags &= ~QItemSelectionModel.Current
        if flags & QItemSelectionModel.Toggle:  # no toggle support either
            flags &= ~QItemSelectionModel.Toggle
            flags |= QItemSelectionModel.Select

        clear_and_select = QItemSelectionModel.ClearAndSelect
        if (flags & clear_and_select) == clear_and_select:
            # The previous selection is discarded (Clear), so the new one only
            # has to be closed under the row x column product. Testing the
            # bits (instead of `flags == ClearAndSelect`) also handles commands
            # carrying extra `Rows`/`Columns` flags; those would otherwise be
            # extended with the selection that is about to be cleared.
            sel_rows = selection_rows(selection)
            sel_cols = selection_columns(selection)
            selection = QItemSelection()
            for row_range, col_range in \
                    product(to_ranges(sel_rows), to_ranges(sel_cols)):
                qitemselection_select_range(
                    selection, model, row_range, col_range
                )
        elif flags & (QItemSelectionModel.Select |
                      QItemSelectionModel.Deselect):
            # extend all selection ranges in `selection` with the full current
            # row/col spans
            rows, cols = selection_blocks(self.selection())
            sel_rows = selection_rows(selection)
            sel_cols = selection_columns(selection)
            ext_selection = QItemSelection()
            # current rows x new columns
            for row_range, col_range in \
                    product(to_ranges(rows), to_ranges(sel_cols)):
                qitemselection_select_range(
                    ext_selection, model, row_range, col_range
                )
            # new rows x current columns
            for row_range, col_range in \
                    product(to_ranges(sel_rows), to_ranges(cols)):
                qitemselection_select_range(
                    ext_selection, model, row_range, col_range
                )
            selection.merge(ext_selection, QItemSelectionModel.Select)
        super().select(selection, flags)

    def selectBlocks(self):
        """Is the block selection in effect."""
        return self.__selectBlocks

    def setSelectBlocks(self, state):
        """Set the block selection state.

        If set to False, the selection model behaves as the base
        QItemSelectionModel

        """
        self.__selectBlocks = state
98
99
def selection_rows(selection):
    # type: (QItemSelection) -> List[Tuple[int, int]]
    """
    Return a list of ranges for all referenced rows contained in selection

    Parameters
    ----------
    selection : QItemSelection

    Returns
    -------
    rows : List[Tuple[int, int]]
        Sorted `(start, stop)` row ranges (`stop` exclusive), one per maximal
        run of consecutive rows spanned by any range of `selection`.
    """
    # Distinct row indices touched by any range (top/bottom are inclusive).
    indices = sorted({row for span in selection
                      for row in range(span.top(), span.bottom() + 1)})
    return list(ranges(indices))
116
117
def selection_columns(selection):
    # type: (QItemSelection) -> List[Tuple[int, int]]
    """
    Return a list of ranges for all referenced columns contained in selection

    Parameters
    ----------
    selection : QItemSelection

    Returns
    -------
    columns : List[Tuple[int, int]]
        Sorted `(start, stop)` column ranges (`stop` exclusive), one per
        maximal run of consecutive columns spanned by any range of
        `selection`.
    """
    # Distinct column indices touched by any range (left/right are inclusive).
    indices = sorted({col for span in selection
                      for col in range(span.left(), span.right() + 1)})
    return list(ranges(indices))
134
135
def selection_blocks(selection):
    # type: (QItemSelection) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]
    """
    Return the row and column ranges spanned by `selection`.

    Parameters
    ----------
    selection : QItemSelection

    Returns
    -------
    (rows, columns) : Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]
        The results of `selection_rows(selection)` and
        `selection_columns(selection)`; both lists are empty for an empty
        selection.
    """
    return selection_rows(selection), selection_columns(selection)
149
150
def ranges(indices):
    # type: (Iterable[int]) -> Iterable[Tuple[int, int]]
    """
    Group consecutive indices into `(start, stop)` tuple 'ranges'.

    Runs are detected in input order, so `indices` should be sorted and free
    of duplicates to get a minimal set of disjoint ranges; unsorted input
    yields more (possibly overlapping) ranges.

    >>> list(ranges([1, 2, 3, 5, 3, 4]))
    [(1, 4), (5, 6), (3, 5)]

    """
    # Within a run of consecutive integers, value - position is constant.
    runs = groupby(enumerate(indices), key=lambda t: t[1] - t[0])
    for _, run in runs:
        run = list(run)
        _, start = run[0]
        _, end = run[-1]
        yield start, end + 1
166
167
def merge_ranges(
        ranges: Iterable[Tuple[int, int]]
) -> Sequence[Tuple[int, int]]:
    """
    Merge overlapping or touching `(start, stop)` intervals.

    The intervals are ordered by their start. Each interval whose start does
    not exceed the stop of the last accumulated interval is folded into it;
    otherwise it opens a new interval.

    >>> merge_ranges([(5, 6), (1, 3), (2, 4), (8, 9)])
    [(1, 4), (5, 6), (8, 9)]
    """
    ordered = sorted(ranges, key=itemgetter(0))
    if not ordered:
        return []
    merged: List[Tuple[int, int]] = [ordered[0]]
    for interval in ordered[1:]:
        prev_start, prev_stop = merged[-1]
        start, stop = interval
        assert prev_start <= start
        if start > prev_stop:
            # disjoint from the last interval: start a new one
            merged.append(interval)
        else:
            # overlapping or adjacent: extend the last interval
            merged[-1] = prev_start, max(prev_stop, stop)
    return merged
191
192
def qitemselection_select_range(
        selection: QItemSelection,
        model: QAbstractItemModel,
        rows: range,
        columns: range
) -> None:
    """
    Add the rectangular block `rows` x `columns` of `model` to `selection`.

    Both ranges must have a unit step. The block spans from
    `(rows.start, columns.start)` to `(rows.stop - 1, columns.stop - 1)`
    inclusive.
    """
    assert rows.step == columns.step == 1
    top_left = model.index(rows.start, columns.start)
    bottom_right = model.index(rows.stop - 1, columns.stop - 1)
    selection.select(top_left, bottom_right)
204
205
def to_ranges(spans: Iterable[Tuple[int, int]]) -> Sequence[range]:
    """Convert `(start, stop)` spans into a list of `range` objects."""
    return [range(*span) for span in spans]
208
209
class SymmetricSelectionModel(QItemSelectionModel):
    """
    Item selection model ensuring the selection is symmetric

    Every index referenced as a row is also selected as a column (and vice
    versa), so the selection is the product of one set of indices with
    itself. The `Current` selection flag is ignored and `Toggle` is treated
    as `Select`.
    """
    def select(self, selection: Union[QItemSelection, QModelIndex],
               flags: QItemSelectionModel.SelectionFlags) -> None:
        """Reimplemented.

        Symmetrize `selection` (and, when extending or shrinking an existing
        selection, combine it with the currently selected rows/columns)
        before forwarding it to the base implementation.
        """
        if isinstance(selection, QModelIndex):
            selection = QItemSelection(selection, selection)

        if flags & QItemSelectionModel.Current:  # no current selection support
            flags &= ~QItemSelectionModel.Current
        if flags & QItemSelectionModel.Toggle:  # no toggle support either
            flags &= ~QItemSelectionModel.Toggle
            flags |= QItemSelectionModel.Select

        model = self.model()
        rows, cols = selection_blocks(selection)
        # One list of disjoint index ranges covering both rows and columns.
        sym_ranges = to_ranges(merge_ranges(chain(rows, cols)))
        clear_and_select = QItemSelectionModel.ClearAndSelect
        if (flags & clear_and_select) == clear_and_select:
            # The previous selection is discarded, so replace `selection` with
            # the symmetric product of its indices. Testing the bits also
            # covers commands carrying extra `Rows`/`Columns` flags.
            selection = QItemSelection()
            for rrange, crange in product(sym_ranges, sym_ranges):
                qitemselection_select_range(selection, model, rrange, crange)
        elif flags & (QItemSelectionModel.Select |
                      QItemSelectionModel.Deselect):
            # extend ranges in sym_ranges to span all current rows/columns
            rows_current, cols_current = selection_blocks(self.selection())
            ext_selection = QItemSelection()
            for rrange, crange in chain(
                    product(sym_ranges, sym_ranges),
                    product(sym_ranges, to_ranges(cols_current)),
                    product(to_ranges(rows_current), sym_ranges)):
                # Collect into `ext_selection` (not `selection`) so `merge`
                # below folds the blocks in without duplicating ranges.
                qitemselection_select_range(
                    ext_selection, model, rrange, crange
                )
            selection.merge(ext_selection, QItemSelectionModel.Select)
        super().select(selection, flags)

    def selectedItems(self) -> Sequence[int]:
        """Return the indices of the symmetric selection."""
        ranges_ = starmap(range, selection_rows(self.selection()))
        return sorted(chain.from_iterable(ranges_))

    def setSelectedItems(self, inds: Iterable[int]):
        """Set and select the `inds` indices

        The current selection is replaced by the product of `inds` with
        itself. The indices may be given in any order and may repeat.
        """
        model = self.model()
        selection = QItemSelection()
        # `ranges` groups runs in input order; sort and dedupe first so that
        # unordered/duplicated input yields minimal, non-overlapping blocks.
        sym_ranges = to_ranges(ranges(sorted(set(inds))))
        for rrange, crange in product(sym_ranges, sym_ranges):
            qitemselection_select_range(selection, model, rrange, crange)
        self.select(selection, QItemSelectionModel.ClearAndSelect)
262