1import re
2
3from django import forms
4from django.conf import settings
5from django.forms.models import fields_for_model
6
7from utilities.choices import unpack_grouped_choices
8from utilities.querysets import RestrictedQuerySet
9from .constants import *
10
# Names exported by a star-import of this module; doubles as the declared public API of these form utilities.
__all__ = (
    'add_blank_choice',
    'expand_alphanumeric_pattern',
    'expand_ipaddress_pattern',
    'form_from_model',
    'get_selected_values',
    'parse_alphanumeric_range',
    'parse_numeric_range',
    'restrict_form_fields',
    'parse_csv',
    'validate_csv',
)
23
24
def parse_numeric_range(string, base=10):
    """
    Expand a numeric range (continuous or not) into a sorted list of unique integers.

    Each comma-separated item is either a single value or an inclusive, dash-separated `begin-end` range. Items
    are parsed as integers in the given base (e.g. 10 for decimal, 16 for hexadecimal).
      '0-3,5' => [0, 1, 2, 3, 5]
      '2,8-b,d,f' (base=16) => [2, 8, 9, 10, 11, 13, 15]

    Raises forms.ValidationError if an item cannot be parsed as an integer in the given base.
    """
    values = set()
    for dash_range in string.split(','):
        try:
            begin, end = dash_range.split('-')
        except ValueError:
            # Not exactly one dash: treat the whole item as a single value. Malformed items (e.g. "1-2-3")
            # are rejected by the int() conversion below.
            begin, end = dash_range, dash_range
        try:
            # The end of the range is inclusive, hence the +1 for range()
            begin, end = int(begin.strip(), base=base), int(end.strip(), base=base) + 1
        except ValueError:
            raise forms.ValidationError(f'Range "{dash_range}" is invalid.')
        values.update(range(begin, end))
    # Sets are unordered; sort so the result is deterministic and ascending.
    return sorted(values)
44
45
def parse_alphanumeric_range(string):
    """
    Expand a comma-separated list of alphanumeric values and ranges into a flat list.
      'a-d,f' => [a, b, c, d, f]
      '0-3,a-d' => [0, 1, 2, 3, a, b, c, d]

    Numeric items are returned as integers; everything else is returned as strings. An empty list is returned
    as soon as a `begin-end` range mixes digits with letters or mixes upper- and lowercase letters.

    Raises forms.ValidationError for a non-numeric range whose endpoints are not single characters.
    """
    expanded = []
    for item in string.split(','):
        endpoints = item.split('-')
        if len(endpoints) == 2:
            begin, end = endpoints
            combined = begin + end
            mixed_kinds = not (combined.isdigit() or combined.isalpha())
            mixed_case = combined.isalpha() and not (combined.isupper() or combined.islower())
            if mixed_kinds or mixed_case:
                # Invalid pattern: signal the error to the caller with an empty result
                return []
        else:
            # Not exactly one dash: the item is handled as a single value
            begin = end = item

        if begin.isdigit() and end.isdigit():
            # Numeric value or range (inclusive)
            expanded.extend(range(int(begin), int(end) + 1))
        elif begin == end:
            # Single non-numeric value
            expanded.append(begin)
        elif len(begin) == len(end) == 1:
            # Single-character range (inclusive), expanded by code point
            expanded.extend(chr(code) for code in range(ord(begin), ord(end) + 1))
        else:
            # Not a valid range (more than a single character)
            raise forms.ValidationError(f'Range "{item}" is invalid.')
    return expanded
77
78
def expand_alphanumeric_pattern(string):
    """
    Lazily expand an alphanumeric pattern into the strings it describes (this is a generator).

    The string is split once on ALPHANUMERIC_EXPANSION_PATTERN into a leading part, a range expression and a
    remainder (the three-way unpacking assumes the pattern has a single capturing group). The range expression
    is expanded with parse_alphanumeric_range(), and any further patterns in the remainder are expanded
    recursively.
    """
    prefix, range_spec, suffix = re.split(ALPHANUMERIC_EXPANSION_PATTERN, string, maxsplit=1)
    for item in parse_alphanumeric_range(range_spec):
        if re.search(ALPHANUMERIC_EXPANSION_PATTERN, suffix):
            # The remainder holds another pattern: combine with each of its expansions
            tails = expand_alphanumeric_pattern(suffix)
        else:
            tails = (suffix,)
        for tail in tails:
            yield f"{prefix}{item}{tail}"
92
def expand_ipaddress_pattern(string, family):
    """
    Lazily expand an IP address pattern into the address strings it describes (this is a generator). Examples:
      '192.0.2.[1,2,100-250]/24' => ['192.0.2.1/24', '192.0.2.2/24', '192.0.2.100/24' ... '192.0.2.250/24']
      '2001:db8:0:[0,fd-ff]::/64' => ['2001:db8:0:0::/64', '2001:db8:0:fd::/64', ... '2001:db8:0:ff::/64']

    `family` must be 4 or 6; any other value raises an Exception once iteration starts. IPv4 ranges are parsed
    and rendered in decimal, IPv6 ranges in hexadecimal.
    """
    if family == 4:
        regex, base, digit_format = IP4_EXPANSION_PATTERN, 10, 'd'
    elif family == 6:
        regex, base, digit_format = IP6_EXPANSION_PATTERN, 16, 'x'
    else:
        raise Exception(f"Invalid IP address family: {family}")

    # Split once into the leading text, the range expression and the remainder
    prefix, range_spec, suffix = re.split(regex, string, maxsplit=1)
    for number in parse_numeric_range(range_spec, base):
        rendered = format(number, digit_format)
        if re.search(regex, suffix):
            # The remainder holds another pattern: combine with each of its expansions
            for tail in expand_ipaddress_pattern(suffix, family):
                yield prefix + rendered + tail
        else:
            yield prefix + rendered + suffix
116
def get_selected_values(form, field_name):
    """
    Return a list of the human-friendly values selected for the named field of a form.

    For a field without choices, the cleaned value is returned as a one-item list of its string form. For a
    field with choices, the matching choice labels are returned, plus the field's `null_option` when that is
    set and None is among the selected values.
    """
    # cleaned_data only exists after validation has run
    if not hasattr(form, 'cleaned_data'):
        form.is_valid()
    selected = form.cleaned_data.get(field_name)
    form_field = form.fields[field_name]

    # Non-selection field
    if not hasattr(form_field, 'choices'):
        return [str(selected)]

    if type(form_field.choices) is forms.models.ModelChoiceIterator:
        # Dynamic choices: report every choice that has been populated on the widget
        labels = [option.choice_label for option in form[field_name].subwidgets]
    else:
        # Static choices: flatten any groups, then keep the labels of the selected values
        flat_choices = unpack_grouped_choices(form_field.choices)
        if type(selected) not in (list, tuple):
            selected = [selected]  # Ensure the selection is iterable
        labels = [
            label for value, label in flat_choices if str(value) in selected or None in selected
        ]

    # Include the field's null option when it is defined and has been selected
    if hasattr(form_field, 'null_option') and form_field.null_option is not None and None in selected:
        labels.append(form_field.null_option)

    return labels
153
154
def add_blank_choice(choices):
    """
    Return the given choices as a tuple, preceded by a blank (None) choice.
    """
    blank_choice = (None, '---------')
    return (blank_choice, *choices)
160
161
def form_from_model(model, fields):
    """
    Build and return a plain Form class whose fields are derived from the given model fields.

    This is useful when a form is needed for creating objects but the model's own validation should be avoided
    (e.g. for bulk create/edit functions). Every field on the returned form is marked as not required.
    """
    attrs = fields_for_model(model, fields=fields)
    for form_field in attrs.values():
        form_field.required = False

    # Create the class dynamically, with the derived fields as its attributes
    form_bases = (forms.Form,)
    return type('FormFromModel', form_bases, attrs)
173
174
def restrict_form_fields(form, user, action='view'):
    """
    Restrict the queryset of every form field backed by a RestrictedQuerySet, so that the user sees only
    permitted objects as available choices. Fields are modified in place; nothing is returned.
    """
    for form_field in form.fields.values():
        # Only fields exposing a queryset are candidates
        if not hasattr(form_field, 'queryset'):
            continue
        if issubclass(form_field.queryset.__class__, RestrictedQuerySet):
            form_field.queryset = form_field.queryset.restrict(user, action)
183
184
def parse_csv(reader):
    """
    Parse a csv_reader object into a headers dictionary and a list of records dictionaries. Return headers and
    records as a tuple.

    The headers dictionary maps each column's field name to an optional "to" field (or None). Each record maps
    field names to the stripped cell values of one row.

    Raises forms.ValidationError if there is no header row, if two columns resolve to the same field name, or
    if a row does not have the same number of columns as the header row.
    """
    records = []
    headers = {}

    # Consume the first line of CSV data as column headers. A bare StopIteration from an empty reader would
    # surface as an unhandled error rather than a form validation message.
    try:
        header_row = next(reader)
    except StopIteration:
        raise forms.ValidationError('No CSV data found; a header row is required.') from None

    # Create a dictionary mapping each header to an optional "to" field specifying how the related object is
    # being referenced. For example, importing a Device might use a `site.slug` header, to indicate the related
    # site is being referenced by its slug.
    for header in header_row:
        # Strip surrounding whitespace, consistent with how cell values are handled below
        field, dot, to_field = header.strip().partition('.')
        # Two columns resolving to the same field (e.g. "site" and "site.slug") would silently collapse into
        # one dictionary key and make every row appear to have the wrong number of columns.
        if field in headers:
            raise forms.ValidationError(f'Duplicate or conflicting column header for "{field}"')
        headers[field] = to_field if dot else None

    # Parse CSV rows into a list of dictionaries mapped from the column headers.
    for i, row in enumerate(reader, start=1):
        if len(row) != len(headers):
            raise forms.ValidationError(
                f"Row {i}: Expected {len(headers)} columns but found {len(row)}"
            )
        record = dict(zip(headers, (col.strip() for col in row)))
        records.append(record)

    return headers, records
215
216
def validate_csv(headers, fields, required_fields):
    """
    Check parsed CSV headers against the fields available on the target object.

    Raises forms.ValidationError if a header names an unknown field, if a dotted header is used on a field
    that is not a related object or names an attribute the related model lacks, or if a required header is
    missing. Returns nothing when the headers are valid.
    """
    # Validate provided column headers
    for name, related_attr in headers.items():
        if name not in fields:
            raise forms.ValidationError(f'Unexpected column header "{name}" found.')
        if not related_attr:
            # Plain header without a dotted accessor: nothing further to check
            continue
        form_field = fields[name]
        if not hasattr(form_field, 'to_field_name'):
            raise forms.ValidationError(f'Column "{name}" is not a related object; cannot use dots')
        if not hasattr(form_field.queryset.model, related_attr):
            raise forms.ValidationError(f'Invalid related object attribute for column "{name}": {related_attr}')

    # Validate required fields
    for required in required_fields:
        if required not in headers:
            raise forms.ValidationError(f'Required column header "{required}" not found.')
235