1from __future__ import absolute_import
2# Copyright (c) 2010-2019 openpyxl
3
4from collections import defaultdict
5from itertools import chain
6from operator import itemgetter
7
8from openpyxl.descriptors.serialisable import Serialisable
9from openpyxl.descriptors import (
10    Bool,
11    NoneSet,
12    String,
13    Sequence,
14    Alias,
15    Integer,
16    Convertible,
17)
18from openpyxl.descriptors.nested import NestedText
19from openpyxl.compat import unicode
20from openpyxl.utils import (
21    rows_from_range,
22    coordinate_to_tuple,
23    get_column_letter,
24)
25
26
def collapse_cell_addresses(cells, input_ranges=()):
    """ Collapse a collection of cell co-ordinates down into an optimal
        range or collection of ranges.

        E.g. Cells A1, A2, A3, B1, B2 and B3 should have the data-validation
        object applied, attempt to collapse down to a single range, A1:B3.

        Currently only contiguous vertical runs are collapsed (i.e. the above
        example results in "A1:A3 B1:B3"). Rows of the same column that are
        not adjacent become separate entries, so A1, A2 and A4 give
        "A1:A2 A4". Duplicate cells are ignored.

        :param cells: iterable of cell coordinates such as "A1"
        :param input_ranges: range strings emitted first, unchanged
        :return: space-separated string of ranges / single coordinates,
            columns in ascending order after any ``input_ranges``
    """

    ranges = list(input_ranges)

    # group the distinct row numbers of each column; a set drops duplicate
    # cells which would otherwise yield degenerate ranges such as "A1:A1"
    rows_by_col = defaultdict(set)
    for cell in cells:
        row, col = coordinate_to_tuple(cell)
        rows_by_col[col].add(row)

    # sort columns explicitly: plain dict iteration order is not guaranteed
    # on every Python version this module supports
    for col in sorted(rows_by_col):
        letter = get_column_letter(col)

        # split the sorted rows into runs of consecutive row numbers so a
        # range never covers a cell that was not asked for
        runs = []
        for row in sorted(rows_by_col[col]):
            if runs and row == runs[-1][1] + 1:
                runs[-1][1] = row  # extends the current run
            else:
                runs.append([row, row])  # gap (or first row): new run

        for first, last in runs:
            if first == last:
                ranges.append("{0}{1}".format(letter, first))
            else:
                ranges.append("{0}{1}:{0}{2}".format(letter, first, last))

    return " ".join(ranges)
58
59
def expand_cell_ranges(range_string):
    """
    Expand a whitespace-separated string of cell ranges into addresses.

    This is the inverse operation of collapse_cell_addresses: for example
    "A1:A2 B1:B2" becomes {"A1", "A2", "B1", "B2"}. The result is a set, so
    the addresses are unordered and overlapping ranges contribute each
    address only once.
    """
    addresses = set()
    for chunk in range_string.split():
        # each item produced by rows_from_range is an iterable of the
        # coordinates in one row of the range
        for row in rows_from_range(chunk):
            addresses.update(row)
    return addresses
70
71
72from .cell_range import MultiCellRange
73
74
class DataValidation(Serialisable):
    """A single ``<dataValidation>`` rule applied to one or more cell ranges.

    ``sqref`` holds the cells the rule applies to; it is also reachable via
    the ``cells`` and ``ranges`` aliases. Snake-case aliases
    (``hide_drop_down``, ``allow_blank``, ``validation_type``) map onto the
    corresponding camelCase descriptors.
    """

    tagname = "dataValidation"

    sqref = Convertible(expected_type=MultiCellRange)
    cells = Alias("sqref")
    ranges = Alias("sqref")

    showErrorMessage = Bool()
    showDropDown = Bool(allow_none=True)
    hide_drop_down = Alias('showDropDown')
    showInputMessage = Bool()
    allowBlank = Bool()
    allow_blank = Alias('allowBlank')

    errorTitle = String(allow_none = True)
    error = String(allow_none = True)
    promptTitle = String(allow_none = True)
    prompt = String(allow_none = True)
    formula1 = NestedText(allow_none=True, expected_type=unicode)
    formula2 = NestedText(allow_none=True, expected_type=unicode)

    type = NoneSet(values=("whole", "decimal", "list", "date", "time",
                           "textLength", "custom"))
    errorStyle = NoneSet(values=("stop", "warning", "information"))
    imeMode = NoneSet(values=("noControl", "off", "on", "disabled",
                              "hiragana", "fullKatakana", "halfKatakana", "fullAlpha","halfAlpha",
                              "fullHangul", "halfHangul"))
    operator = NoneSet(values=("between", "notBetween", "equal", "notEqual",
                               "lessThan", "lessThanOrEqual", "greaterThan", "greaterThanOrEqual"))
    validation_type = Alias('type')

    def __init__(self,
                 type=None,
                 formula1=None,
                 formula2=None,
                 allow_blank=None,
                 showErrorMessage=True,
                 showInputMessage=True,
                 showDropDown=None,
                 allowBlank=None,
                 sqref=(),
                 promptTitle=None,
                 errorStyle=None,
                 error=None,
                 prompt=None,
                 errorTitle=None,
                 imeMode=None,
                 operator=None,
                 ):
        self.sqref = sqref
        self.showDropDown = showDropDown
        self.imeMode = imeMode
        self.operator = operator
        self.formula1 = formula1
        self.formula2 = formula2
        # ``allow_blank`` is the snake_case spelling of ``allowBlank`` and
        # wins only when explicitly given. It must default to None: with a
        # default of False it always overrode ``allowBlank``, so an explicit
        # ``allowBlank=True`` was silently lost.
        if allow_blank is not None:
            allowBlank = allow_blank
        self.allowBlank = allowBlank
        self.showErrorMessage = showErrorMessage
        self.showInputMessage = showInputMessage
        self.type = type
        self.promptTitle = promptTitle
        self.errorStyle = errorStyle
        self.error = error
        self.prompt = prompt
        self.errorTitle = errorTitle


    def add(self, cell):
        """Adds a cell or cell coordinate to this validator.

        :param cell: a coordinate string, or any object with a
            ``coordinate`` attribute (e.g. a worksheet cell)
        """
        if hasattr(cell, "coordinate"):
            cell = cell.coordinate
        self.sqref += cell


    def __contains__(self, cell):
        """Return whether ``cell`` (coordinate or cell object) is covered."""
        if hasattr(cell, "coordinate"):
            cell = cell.coordinate
        return cell in self.sqref
156
157
class DataValidationList(Serialisable):
    """The ``<dataValidations>`` container: a sequence of DataValidation rules
    plus the optional ``disablePrompts`` / ``xWindow`` / ``yWindow``
    attributes. ``count`` is always derived from the number of rules.
    """

    tagname = "dataValidations"

    disablePrompts = Bool(allow_none=True)
    xWindow = Integer(allow_none=True)
    yWindow = Integer(allow_none=True)
    dataValidation = Sequence(expected_type=DataValidation)

    __elements__ = ('dataValidation',)
    __attrs__ = ('disablePrompts', 'xWindow', 'yWindow', 'count')

    def __init__(self,
                 disablePrompts=None,
                 xWindow=None,
                 yWindow=None,
                 count=None,
                 dataValidation=(),
                ):
        # ``count`` is accepted (presumably so a parsed ``count`` attribute
        # can be passed straight through) but deliberately ignored: the
        # ``count`` property recomputes it from ``dataValidation``.
        self.disablePrompts = disablePrompts
        self.xWindow = xWindow
        self.yWindow = yWindow
        self.dataValidation = dataValidation


    @property
    def count(self):
        """Number of validations; serialised as the ``count`` attribute."""
        return len(self)


    def __len__(self):
        return len(self.dataValidation)


    def append(self, dv):
        """Add a DataValidation to the end of the list."""
        self.dataValidation.append(dv)


    def to_tree(self, tagname=None):
        """
        Serialise to XML, skipping validations that have no cell ranges.

        Empty validations are hidden only for the duration of the
        serialisation (so ``count`` also reflects just the written ones);
        ``self.dataValidation`` is restored afterwards, even if
        serialisation raises.
        """
        original = self.dataValidation  # a reference, not a copy
        self.dataValidation = [dv for dv in original if dv.sqref]
        try:
            return super(DataValidationList, self).to_tree(tagname)
        finally:
            self.dataValidation = original
205