1from . import idnadata
2import bisect
3import unicodedata
4import re
5import sys
6from .intranges import intranges_contain
7
# Canonical Combining Class value of a virama; compared against the class of
# the character preceding a ZWNJ/ZWJ in valid_contextj().
_virama_combining_class = 9
# ACE prefix identifying an A-label (Punycode-encoded label).
_alabel_prefix = b'xn--'
# Label separators accepted in non-strict mode: FULL STOP, IDEOGRAPHIC FULL
# STOP, FULLWIDTH FULL STOP and HALFWIDTH IDEOGRAPHIC FULL STOP.
_unicode_dots_re = re.compile(u'[\u002e\u3002\uff0e\uff61]')

# Python 2/3 compatibility: the Python 2 builtins ``unicode`` and ``unichr``
# are aliased to ``str``/``chr`` on Python 3 and later.  ``>= 3`` (not ``== 3``)
# so that a future major version is not left without these names.
if sys.version_info[0] >= 3:
    unicode = str
    unichr = chr
15
class IDNAError(UnicodeError):
    """Base class for the IDNA-specific exceptions raised by this module.

    Note that ``UnicodeError`` derives from ``ValueError``, so an
    ``except ValueError`` clause also catches every exception in this family.
    """


class IDNABidiError(IDNAError):
    """A label breaks the bidirectional-text (Bidi) requirements."""


class InvalidCodepoint(IDNAError):
    """A label contains a code point that is disallowed or unassigned."""


class InvalidCodepointContext(IDNAError):
    """A context-dependent code point appears where its context rule fails."""
34
35
def _combining_class(cp):
    """Return the Canonical Combining Class of the integer code point *cp*.

    A class of 0 is ambiguous: it can mean "not a combining mark" or "not
    known to this unicodedata version".  In that case the character is also
    looked up by name.  ``unicodedata.name`` raises ``ValueError`` for an
    unnamed character, and an empty name is rejected explicitly.

    Raises:
        ValueError: if the character is not known to ``unicodedata``.
    """
    char = unichr(cp)
    ccc = unicodedata.combining(char)
    if ccc == 0 and not unicodedata.name(char):
        raise ValueError("Unknown character in unicodedata")
    return ccc
42
def _is_script(cp, script):
    """Return whether the single character *cp* belongs to *script*.

    *script* is a key of ``idnadata.scripts`` (e.g. 'Greek', 'Hebrew',
    'Han'); the value is an intranges table searched with intranges_contain.
    """
    code_point = ord(cp)
    return intranges_contain(code_point, idnadata.scripts[script])
45
46def _punycode(s):
47    return s.encode('punycode')
48
49def _unot(s):
50    return 'U+{0:04X}'.format(s)
51
52
def valid_label_length(label):
    """Return True when *label* has at most 63 items, False otherwise.

    Callers in this module pass the encoded label as bytes, so the limit is
    then measured in octets.
    """
    return not len(label) > 63
58
59
def valid_string_length(label, trailing_dot):
    """Return True when the whole encoded name *label* is short enough.

    The limit is 253 items, plus one extra when *trailing_dot* says the
    name ends with a (root) dot.
    """
    limit = 253
    if trailing_dot:
        limit += 1
    return len(label) <= limit
65
66
def check_bidi(label, check_ltr=False):
    """Validate *label* against the six Bidi Rule conditions (RFC 5893, section 2).

    The rules are only enforced when the label contains an R, AL or AN
    character, or when *check_ltr* forces the check for left-to-right labels.

    Returns:
        True if the label passes (or the rules do not apply).

    Raises:
        IDNABidiError: on an unknown directionality or any rule violation.
    """
    # Any character without a bidi class is rejected up front; meanwhile
    # note whether the label contains right-to-left content at all.
    has_rtl = False
    for position, char in enumerate(label, 1):
        bidi_class = unicodedata.bidirectional(char)
        if not bidi_class:
            # String likely comes from a newer version of Unicode
            raise IDNABidiError('Unknown directionality in label {0} at position {1}'.format(repr(label), position))
        if bidi_class in ('R', 'AL', 'AN'):
            has_rtl = True
    if not (has_rtl or check_ltr):
        return True

    # Rule 1: the first character decides whether the label is RTL or LTR.
    leading = unicodedata.bidirectional(label[0])
    if leading in ('R', 'AL'):
        rtl = True
    elif leading == 'L':
        rtl = False
    else:
        raise IDNABidiError('First codepoint in label {0} must be directionality L, R or AL'.format(repr(label)))

    if rtl:
        permitted = ('R', 'AL', 'AN', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM')  # rule 2
        terminal = ('R', 'AL', 'EN', 'AN')  # rule 3
        bad_direction = 'Invalid direction for codepoint at position {0} in a right-to-left label'
    else:
        permitted = ('L', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM')  # rule 5
        terminal = ('L', 'EN')  # rule 6
        bad_direction = 'Invalid direction for codepoint at position {0} in a left-to-right label'

    ends_ok = False
    seen_numeral = None
    for position, char in enumerate(label, 1):
        bidi_class = unicodedata.bidirectional(char)
        if bidi_class not in permitted:
            raise IDNABidiError(bad_direction.format(position))
        # Rules 3/6: trailing NSMs keep whatever verdict the previous
        # character produced; any other class resets it.
        if bidi_class in terminal:
            ends_ok = True
        elif bidi_class != 'NSM':
            ends_ok = False
        # Rule 4: an RTL label may use EN or AN digits, but not both.
        if rtl and bidi_class in ('AN', 'EN'):
            if seen_numeral is None:
                seen_numeral = bidi_class
            elif seen_numeral != bidi_class:
                raise IDNABidiError('Can not mix numeral types in a right-to-left label')

    if not ends_ok:
        raise IDNABidiError('Label ends with illegal codepoint directionality')

    return True
125
126
def check_initial_combiner(label):
    """Reject a label whose first character is a mark (general category M*).

    Returns True when the label is acceptable; raises IDNAError otherwise.
    Expects a non-empty label.
    """
    leading_category = unicodedata.category(label[0])
    if leading_category.startswith('M'):
        raise IDNAError('Label begins with an illegal combining character')
    return True
132
133
def check_hyphen_ok(label):
    """Reject disallowed hyphen placements in *label*.

    Raises IDNAError when the 3rd and 4th characters are both hyphens, or
    when the label starts or ends with a hyphen; returns True otherwise.
    Expects a non-empty label.
    """
    if label[2:4] == '--':
        raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
    first, last = label[0], label[-1]
    if '-' in (first, last):
        raise IDNAError('Label must not start or end with a hyphen')
    return True
141
142
def check_nfc(label):
    """Require *label* to already be in Unicode Normalization Form C.

    Returns:
        True when the label is NFC (consistent with the other ``check_*``
        helpers, which also return True on success).

    Raises:
        IDNAError: if normalizing to NFC would change the label.
    """
    if unicodedata.normalize('NFC', label) != label:
        raise IDNAError('Label must be in Normalization Form C')
    return True
147
148
def valid_contextj(label, pos):
    """Evaluate the CONTEXTJ rule for the joiner at ``label[pos]``.

    * U+200C ZERO WIDTH NON-JOINER is allowed right after a virama, or
      when the label matches ``(L|D) T* ZWNJ T* (R|D)`` by joining type
      (RFC 5892, Appendix A.1).
    * U+200D ZERO WIDTH JOINER is allowed only right after a virama
      (RFC 5892, Appendix A.2).
    * Any other code point yields False.

    Raises:
        ValueError: propagated from _combining_class when the preceding
            character is unknown to ``unicodedata``.
    """
    cp_value = ord(label[pos])

    if cp_value == 0x200c:
        if pos > 0 and _combining_class(ord(label[pos - 1])) == _virama_combining_class:
            return True
        # Before the ZWNJ: skip transparent (T) characters; the first other
        # character must be Left- or Dual-joining.
        if not _joining_type_follows(label, range(pos - 1, -1, -1), (ord('L'), ord('D'))):
            return False
        # After the ZWNJ: likewise, the first non-T character must be
        # Right- or Dual-joining.
        return _joining_type_follows(label, range(pos + 1, len(label)), (ord('R'), ord('D')))

    if cp_value == 0x200d:
        return pos > 0 and _combining_class(ord(label[pos - 1])) == _virama_combining_class

    return False


def _joining_type_follows(label, indices, wanted):
    """Return whether the first non-transparent character at *indices* has a joining type in *wanted*.

    Transparent (T) characters are skipped.  The scan stops at the first
    other character, whatever its type, so that a non-joining character
    breaks the context as RFC 5892 requires.  Returns False if only
    transparent characters (or none) are found.
    """
    for i in indices:
        joining_type = idnadata.joining_types.get(ord(label[i]))
        if joining_type == ord('T'):
            continue
        return joining_type in wanted
    return False
191
192
def valid_contexto(label, pos, exception=False):
    """Evaluate the CONTEXTO rule for the code point at ``label[pos]`` (RFC 5892, Appendix A).

    Args:
        label: the label being validated.
        pos: index of the CONTEXTO code point within *label*.
        exception: unused; retained for backward compatibility.

    Returns:
        True if the code point's context rule is satisfied, otherwise False
        (including code points for which no rule is implemented here).
    """
    cp_value = ord(label[pos])

    if cp_value == 0x00b7:
        # MIDDLE DOT: only between two 'l' (U+006C) characters.
        return (0 < pos < len(label) - 1
                and ord(label[pos - 1]) == 0x006c
                and ord(label[pos + 1]) == 0x006c)

    if cp_value == 0x0375:
        # GREEK LOWER NUMERAL SIGN: the following character must be Greek.
        if pos < len(label) - 1 and len(label) > 1:
            return _is_script(label[pos + 1], 'Greek')
        return False

    if cp_value in (0x05f3, 0x05f4):
        # HEBREW PUNCTUATION GERESH / GERSHAYIM: the preceding character must be Hebrew.
        if pos > 0:
            return _is_script(label[pos - 1], 'Hebrew')
        return False

    if cp_value == 0x30fb:
        # KATAKANA MIDDLE DOT: the label needs at least one Hiragana,
        # Katakana or Han character other than the dot itself.
        for cp in label:
            if cp == u'\u30fb':
                continue
            if _is_script(cp, 'Hiragana') or _is_script(cp, 'Katakana') or _is_script(cp, 'Han'):
                return True
        return False

    if 0x660 <= cp_value <= 0x669:
        # ARABIC-INDIC DIGITS must not be mixed with EXTENDED ARABIC-INDIC DIGITS.
        return not any(0x6f0 <= ord(cp) <= 0x6f9 for cp in label)

    if 0x6f0 <= cp_value <= 0x6f9:
        # EXTENDED ARABIC-INDIC DIGITS must not be mixed with ARABIC-INDIC DIGITS.
        return not any(0x660 <= ord(cp) <= 0x669 for cp in label)

    # No rule for this code point: reject explicitly rather than returning None.
    return False
232
233
def check_label(label):
    """Validate one label against the IDNA 2008 code point and context rules.

    *label* may be text or UTF-8 bytes.  The label is checked for emptiness,
    NFC form, hyphen placement and a leading combining mark.  Then every code
    point is classified as PVALID, CONTEXTJ, CONTEXTO or disallowed, and
    finally the Bidi rule is applied.

    Raises:
        IDNAError: empty label, failed structural check, or a code point next
            to a joiner that ``unicodedata`` does not know.
        InvalidCodepointContext: a CONTEXTJ/CONTEXTO code point whose
            context rule fails.
        InvalidCodepoint: a code point in none of the allowed classes.
        IDNABidiError: propagated from check_bidi.
    """
    if isinstance(label, (bytes, bytearray)):
        label = label.decode('utf-8')
    if len(label) == 0:
        raise IDNAError('Empty Label')

    check_nfc(label)
    check_hyphen_ok(label)
    check_initial_combiner(label)

    # Loop-invariant table lookups.
    pvalid = idnadata.codepoint_classes['PVALID']
    contextj = idnadata.codepoint_classes['CONTEXTJ']
    contexto = idnadata.codepoint_classes['CONTEXTO']

    for (pos, cp) in enumerate(label):
        cp_value = ord(cp)
        if intranges_contain(cp_value, pvalid):
            continue
        elif intranges_contain(cp_value, contextj):
            # Only the context evaluation may legitimately raise ValueError
            # (unknown neighbouring code point).  InvalidCodepointContext is
            # itself a ValueError subclass, so it must be raised outside this try
            # or it would be swallowed and mis-reported.
            try:
                joiner_ok = valid_contextj(label, pos)
            except ValueError:
                raise IDNAError('Unknown codepoint adjacent to joiner {0} at position {1} in {2}'.format(
                    _unot(cp_value), pos+1, repr(label)))
            if not joiner_ok:
                raise InvalidCodepointContext('Joiner {0} not allowed at position {1} in {2}'.format(
                    _unot(cp_value), pos+1, repr(label)))
        elif intranges_contain(cp_value, contexto):
            if not valid_contexto(label, pos):
                raise InvalidCodepointContext('Codepoint {0} not allowed at position {1} in {2}'.format(_unot(cp_value), pos+1, repr(label)))
        else:
            raise InvalidCodepoint('Codepoint {0} at position {1} of {2} not allowed'.format(_unot(cp_value), pos+1, repr(label)))

    check_bidi(label)
264
265
def alabel(label):
    """Convert a single label to its A-label (ASCII bytes) form.

    An all-ASCII label is validated by passing it through ulabel(), which
    also decodes and checks an existing ``xn--`` label, and is then returned
    as the same ASCII bytes.  Any other label is validated with
    check_label(), Punycode-encoded and prefixed with ``xn--``.

    Raises:
        IDNAError: among others, for an empty label (via ulabel/check_label)
            or when the resulting label exceeds 63 bytes.
    """
    # Keep the try minimal: only the ASCII probe should be able to send us
    # down the non-ASCII path.
    try:
        ascii_label = label.encode('ascii')
    except UnicodeEncodeError:
        ascii_label = None

    if ascii_label is not None:
        ulabel(ascii_label)
        if not valid_label_length(ascii_label):
            raise IDNAError('Label too long')
        return ascii_label

    if not label:
        # Defensive: a label that failed the ASCII probe is never empty.
        raise IDNAError('No Input')

    label = unicode(label)
    check_label(label)
    encoded = _alabel_prefix + _punycode(label)

    if not valid_label_length(encoded):
        raise IDNAError('Label too long')

    return encoded
289
290
def ulabel(label):
    """Convert a single label to its U-label (text) form, validating it.

    * Non-ASCII text is validated with check_label() and returned unchanged.
    * ASCII input (text or bytes) is lower-cased.  If it carries the
      ``xn--`` prefix, the remainder is Punycode-decoded and validated;
      otherwise it is validated as-is and returned as text.

    Raises:
        IDNAError: if the Punycode payload cannot be decoded ('Invalid
            A-label'), or on any check_label failure.
    """
    if not isinstance(label, (bytes, bytearray)):
        try:
            label = label.encode('ascii')
        except UnicodeEncodeError:
            check_label(label)
            return label

    label = label.lower()
    if not label.startswith(_alabel_prefix):
        check_label(label)
        return label.decode('ascii')

    # IDNAError subclasses UnicodeError, so callers that caught the codec's
    # UnicodeError still catch this.
    try:
        label = label[len(_alabel_prefix):].decode('punycode')
    except UnicodeError:
        raise IDNAError('Invalid A-label')
    check_label(label)
    return label
310
311
def uts46_remap(domain, std3_rules=True, transitional=False):
    """Re-map the characters in the string according to UTS46 processing.

    Each character is looked up in ``uts46data`` and, according to its
    status, is kept ('V'; 'D' unless *transitional*; '3' without a mapping
    when std3 rules are off), replaced by its mapping ('M'; '3' when std3
    rules are off; 'D' when *transitional*), or dropped ('I').  Anything
    else is rejected.  The result is NFC-normalized.

    Raises:
        InvalidCodepoint: for the first character that is not allowed.
    """
    from .uts46data import uts46data
    # Collect pieces and join once; repeated ``str +=`` is quadratic on
    # interpreters without CPython's in-place concatenation shortcut.
    output = []
    for pos, char in enumerate(domain):
        code_point = ord(char)
        # Below 256 the table is indexed directly (presumably one row per
        # code point there).  Above that, find the last row starting at or
        # before code_point; "Z" sorts after the status letters used here.
        uts46row = uts46data[code_point if code_point < 256 else
            bisect.bisect_left(uts46data, (code_point, "Z")) - 1]
        status = uts46row[1]
        replacement = uts46row[2] if len(uts46row) == 3 else None
        if (status == "V" or
                (status == "D" and not transitional) or
                (status == "3" and not std3_rules and replacement is None)):
            output.append(char)
        elif replacement is not None and (status == "M" or
                (status == "3" and not std3_rules) or
                (status == "D" and transitional)):
            output.append(replacement)
        elif status != "I":
            raise InvalidCodepoint(
                "Codepoint {0} not allowed at position {1} in {2}".format(
                _unot(code_point), pos + 1, repr(domain)))
    return unicodedata.normalize("NFC", u"".join(output))
338
339
def encode(s, strict=False, uts46=False, std3_rules=False, transitional=False):
    """Encode a domain name into its ASCII (A-label) byte form.

    Bytes input is first decoded as ASCII.  With *uts46* the text is
    remapped with uts46_remap().  Labels are split on '.' only when *strict*,
    otherwise on any of the dot characters in ``_unicode_dots_re``.  A
    trailing dot is preserved in the output.

    Raises:
        IDNAError: for an empty domain or label, or when the encoded name
            is too long; plus anything raised by alabel()/uts46_remap().
    """
    if isinstance(s, (bytes, bytearray)):
        s = s.decode("ascii")
    if uts46:
        s = uts46_remap(s, std3_rules, transitional)

    labels = s.split('.') if strict else _unicode_dots_re.split(s)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')

    # An empty final label means the name ended with a dot.
    trailing_dot = labels[-1] == ''
    if trailing_dot:
        labels.pop()

    encoded = []
    for label in labels:
        ace = alabel(label)
        if not ace:
            raise IDNAError('Empty label')
        encoded.append(ace)
    if trailing_dot:
        encoded.append(b'')

    domain = b'.'.join(encoded)
    if not valid_string_length(domain, trailing_dot):
        raise IDNAError('Domain too long')
    return domain
369
370
def decode(s, strict=False, uts46=False, std3_rules=False):
    """Decode a domain name into its Unicode (U-label) text form.

    Bytes input is first decoded as ASCII.  With *uts46* the text is
    remapped with uts46_remap() (non-transitional).  Labels are split on '.'
    only when *strict*, otherwise on any dot in ``_unicode_dots_re``.  A
    trailing dot is preserved in the output.

    Raises:
        IDNAError: for an empty domain or label; plus anything raised by
            ulabel()/uts46_remap().
    """
    if isinstance(s, (bytes, bytearray)):
        s = s.decode("ascii")
    if uts46:
        s = uts46_remap(s, std3_rules, False)

    labels = s.split(u'.') if strict else _unicode_dots_re.split(s)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')

    # An empty final label means the name ended with a dot.
    trailing_dot = not labels[-1]
    if trailing_dot:
        labels.pop()

    decoded = []
    for label in labels:
        text = ulabel(label)
        if not text:
            raise IDNAError('Empty label')
        decoded.append(text)
    if trailing_dot:
        decoded.append(u'')
    return u'.'.join(decoded)
397