1from . import idnadata
2import bisect
3import unicodedata
4import re
5import sys
6from .intranges import intranges_contain
7
# Canonical combining class value that valid_contextj() looks for on the
# character immediately preceding a joiner.
_virama_combining_class = 9
# Byte prefix that alabel() prepends to Punycode output and ulabel() strips.
_alabel_prefix = b'xn--'
# Label separators used when a domain is split in non-strict mode:
# U+002E, U+3002, U+FF0E and U+FF61.
_unicode_dots_re = re.compile(u'[\u002e\u3002\uff0e\uff61]')

# On Python 3, define the Python 2 names that the rest of this module uses.
if sys.version_info[0] >= 3:
    unicode = str
    unichr = chr
15
class IDNAError(UnicodeError):
    """Root of the exception hierarchy raised for IDNA encoding problems."""
19
20
class IDNABidiError(IDNAError):
    """Raised when a label violates the bidirectional-text requirements."""
24
25
class InvalidCodepoint(IDNAError):
    """Raised when a disallowed or unallocated codepoint is encountered."""
29
30
class InvalidCodepointContext(IDNAError):
    """Raised when a codepoint is not permitted in the context it appears in."""
34
35
def _combining_class(cp):
    """Return the canonical combining class of the code point *cp*.

    A class of 0 is only trusted when ``unicodedata`` has a name for the
    character; ``unicodedata.name`` itself raises ``ValueError`` for a
    character without a name, so unknown characters surface as
    ``ValueError`` either way.
    """
    char = unichr(cp)
    combining = unicodedata.combining(char)
    if combining == 0 and not unicodedata.name(char):
        raise ValueError("Unknown character in unicodedata")
    return combining
42
def _is_script(cp, script):
    """Tell whether the character *cp* falls in the ranges of *script*.

    *cp* is a single character (not an integer); *script* is a key of
    ``idnadata.scripts``.
    """
    script_ranges = idnadata.scripts[script]
    code_point = ord(cp)
    return intranges_contain(code_point, script_ranges)
45
46def _punycode(s):
47    return s.encode('punycode')
48
49def _unot(s):
50    return 'U+{0:04X}'.format(s)
51
52
def valid_label_length(label):
    """Return True when *label* is at most 63 items long, else False."""
    return len(label) <= 63
58
59
def valid_string_length(label, trailing_dot):
    """Return True when the whole domain fits the length limit.

    The limit is 254 when *trailing_dot* is truthy and 253 otherwise.
    """
    limit = 254 if trailing_dot else 253
    return len(label) <= limit
65
66
def check_bidi(label, check_ltr=False):
    """Validate *label* against the six Bidi rules checked below.

    The rules are only enforced when the label contains a character of
    bidirectional class R, AL or AN, or when *check_ltr* is true.

    Returns True when the label passes (or the rules do not apply).
    Raises IDNABidiError when a character has no known directionality or
    when any rule is violated.
    """
    # First pass: reject unknown directionality and detect RTL content.
    has_rtl = False
    for position, char in enumerate(label, 1):
        bidi_class = unicodedata.bidirectional(char)
        if bidi_class == '':
            # String likely comes from a newer version of Unicode
            raise IDNABidiError('Unknown directionality in label {0} at position {1}'.format(repr(label), position))
        if bidi_class in ('R', 'AL', 'AN'):
            has_rtl = True
    if not (has_rtl or check_ltr):
        return True

    # Bidi rule 1: the first character decides the label's direction.
    first_class = unicodedata.bidirectional(label[0])
    if first_class == 'L':
        rtl = False
    elif first_class in ('R', 'AL'):
        rtl = True
    else:
        raise IDNABidiError('First codepoint in label {0} must be directionality L, R or AL'.format(repr(label)))

    if rtl:
        permitted = ('R', 'AL', 'AN', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM')  # rule 2
        enders = ('R', 'AL', 'EN', 'AN')                                          # rule 3
    else:
        permitted = ('L', 'EN', 'ES', 'CS', 'ET', 'ON', 'BN', 'NSM')              # rule 5
        enders = ('L', 'EN')                                                      # rule 6

    ends_ok = False
    numeral_kind = False
    for position, char in enumerate(label, 1):
        bidi_class = unicodedata.bidirectional(char)

        # Bidi rules 2 / 5: only the permitted classes may appear.
        if bidi_class not in permitted:
            if rtl:
                raise IDNABidiError('Invalid direction for codepoint at position {0} in a right-to-left label'.format(position))
            raise IDNABidiError('Invalid direction for codepoint at position {0} in a left-to-right label'.format(position))

        # Bidi rules 3 / 6: track whether the label currently ends validly;
        # trailing NSM characters leave the state untouched.
        if bidi_class in enders:
            ends_ok = True
        elif bidi_class != 'NSM':
            ends_ok = False

        # Bidi rule 4: an RTL label may not mix AN and EN numerals.
        if rtl and bidi_class in ('AN', 'EN'):
            if not numeral_kind:
                numeral_kind = bidi_class
            elif numeral_kind != bidi_class:
                raise IDNABidiError('Can not mix numeral types in a right-to-left label')

    if not ends_ok:
        raise IDNABidiError('Label ends with illegal codepoint directionality')

    return True
125
126
def check_initial_combiner(label):
    """Reject a label whose first character has a Mark (``M*``) category.

    Returns True otherwise; raises IDNAError for a leading combining mark.
    """
    first_category = unicodedata.category(label[0])
    if first_category.startswith('M'):
        raise IDNAError('Label begins with an illegal combining character')
    return True
132
133
def check_hyphen_ok(label):
    """Check hyphen placement in *label*.

    Raises IDNAError when positions 3 and 4 are both hyphens, or when the
    label starts or ends with a hyphen. Returns True otherwise.
    """
    third_and_fourth = label[2:4]
    if third_and_fourth == '--':
        raise IDNAError('Label has disallowed hyphens in 3rd and 4th position')
    first, last = label[0], label[-1]
    if first == '-' or last == '-':
        raise IDNAError('Label must not start or end with a hyphen')
    return True
141
142
def check_nfc(label):
    """Raise IDNAError unless *label* is already in Normalization Form C.

    Returns None when the label is unchanged by NFC normalization.
    """
    normalized = unicodedata.normalize('NFC', label)
    if normalized == label:
        return
    raise IDNAError('Label must be in Normalization Form C')
147
148
def valid_contextj(label, pos):
    """Apply the CONTEXTJ rule for the joiner at ``label[pos]``.

    U+200C (ZERO WIDTH NON-JOINER) is accepted when the preceding character
    has the virama combining class, or when it sits between a character of
    Joining_Type L/D on the left and R/D on the right, with only
    Joining_Type T characters in between on either side.

    U+200D (ZERO WIDTH JOINER) is accepted only when the preceding
    character has the virama combining class.

    Any other code point yields False.

    :param label: the label being validated (text).
    :param pos: index of the joiner within *label*.
    :returns: True when the joiner is allowed at that position, else False.
    :raises ValueError: propagated from ``_combining_class`` when the
        preceding character is unknown to ``unicodedata``.
    """
    cp_value = ord(label[pos])

    if cp_value == 0x200c:

        if pos > 0:
            if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
                return True

        # Scan left: skip transparent (T) characters; the first
        # non-transparent character must be L or D. Anything else ends the
        # scan unsuccessfully rather than being skipped over.
        ok = False
        for i in range(pos - 1, -1, -1):
            joining_type = idnadata.joining_types.get(ord(label[i]))
            if joining_type == ord('T'):
                continue
            ok = joining_type in [ord('L'), ord('D')]
            break

        if not ok:
            return False

        # Scan right: same idea, the first non-transparent character must
        # be R or D.
        ok = False
        for i in range(pos + 1, len(label)):
            joining_type = idnadata.joining_types.get(ord(label[i]))
            if joining_type == ord('T'):
                continue
            ok = joining_type in [ord('R'), ord('D')]
            break
        return ok

    if cp_value == 0x200d:

        if pos > 0:
            if _combining_class(ord(label[pos - 1])) == _virama_combining_class:
                return True
        return False

    else:

        return False
191
192
def valid_contexto(label, pos, exception=False):
    """Apply the CONTEXTO rule for the code point at ``label[pos]``.

    Rules implemented, by code point:

    * U+00B7: must have U+006C immediately before and after it.
    * U+0375: the following character must be in the Greek script.
    * U+05F3 / U+05F4: the preceding character must be in the Hebrew script.
    * U+30FB: the label must contain at least one Hiragana, Katakana or
      Han character.
    * U+0660..U+0669: the label must contain no U+06F0..U+06F9 digit.
    * U+06F0..U+06F9: the label must contain no U+0660..U+0669 digit.

    :param label: the label being validated (text).
    :param pos: index of the code point within *label*.
    :param exception: unused; kept for interface compatibility.
    :returns: True when the code point is allowed in this context, False
        otherwise (including for code points that have no rule here).
    """
    cp_value = ord(label[pos])

    if cp_value == 0x00b7:
        if 0 < pos < len(label) - 1:
            if ord(label[pos - 1]) == 0x006c and ord(label[pos + 1]) == 0x006c:
                return True
        return False

    elif cp_value == 0x0375:
        if pos < len(label) - 1 and len(label) > 1:
            return _is_script(label[pos + 1], 'Greek')
        return False

    elif cp_value == 0x05f3 or cp_value == 0x05f4:
        if pos > 0:
            return _is_script(label[pos - 1], 'Hebrew')
        return False

    elif cp_value == 0x30fb:
        for cp in label:
            if cp == u'\u30fb':
                continue
            if _is_script(cp, 'Hiragana') or _is_script(cp, 'Katakana') or _is_script(cp, 'Han'):
                return True
        return False

    elif 0x660 <= cp_value <= 0x669:
        # Arabic-Indic digits may not be mixed with Extended Arabic-Indic digits.
        for cp in label:
            if 0x6f0 <= ord(cp) <= 0x06f9:
                return False
        return True

    elif 0x6f0 <= cp_value <= 0x6f9:
        # ... and vice versa.
        for cp in label:
            if 0x660 <= ord(cp) <= 0x0669:
                return False
        return True

    # No CONTEXTO rule covers this code point: answer explicitly instead of
    # falling off the end and returning None.
    return False
232
233
def check_label(label):
    """Validate a single label, raising on the first problem found.

    Bytes input is decoded as UTF-8 first. The label must be non-empty, in
    NFC, have acceptable hyphen placement and not start with a combining
    mark. Every code point must then be PVALID, or be a CONTEXTJ/CONTEXTO
    code point whose contextual rule is satisfied. Finally the Bidi rules
    are checked.

    :param label: the label as text, ``bytes`` or ``bytearray``.
    :raises IDNAError: for an empty label, for failures of the NFC, hyphen
        or initial-combiner checks, or when a joiner sits next to a
        character unknown to ``unicodedata``.
    :raises InvalidCodepointContext: when a CONTEXTJ or CONTEXTO code point
        fails its contextual rule.
    :raises InvalidCodepoint: when a code point is in none of the allowed
        classes.
    :raises IDNABidiError: when the Bidi rules are violated.
    """
    if isinstance(label, (bytes, bytearray)):
        label = label.decode('utf-8')
    if not label:
        raise IDNAError('Empty Label')

    check_nfc(label)
    check_hyphen_ok(label)
    check_initial_combiner(label)

    for (pos, cp) in enumerate(label):
        cp_value = ord(cp)
        if intranges_contain(cp_value, idnadata.codepoint_classes['PVALID']):
            continue
        elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTJ']):
            # Only the rule evaluation is guarded. InvalidCodepointContext is
            # itself a ValueError (via UnicodeError), so raising it inside
            # the try block would be swallowed and reported as the
            # "unknown codepoint" error below.
            try:
                joiner_ok = valid_contextj(label, pos)
            except ValueError:
                raise IDNAError('Unknown codepoint adjacent to joiner {0} at position {1} in {2}'.format(
                    _unot(cp_value), pos+1, repr(label)))
            if not joiner_ok:
                raise InvalidCodepointContext('Joiner {0} not allowed at position {1} in {2}'.format(
                    _unot(cp_value), pos+1, repr(label)))
        elif intranges_contain(cp_value, idnadata.codepoint_classes['CONTEXTO']):
            if not valid_contexto(label, pos):
                raise InvalidCodepointContext('Codepoint {0} not allowed at position {1} in {2}'.format(_unot(cp_value), pos+1, repr(label)))
        else:
            raise InvalidCodepoint('Codepoint {0} at position {1} of {2} not allowed'.format(_unot(cp_value), pos+1, repr(label)))

    check_bidi(label)
264
265
def alabel(label):
    """Convert a single label to its A-label (ASCII bytes) form.

    A label that encodes to ASCII is validated through ``ulabel`` and
    returned as bytes. Otherwise the label is validated with
    ``check_label``, Punycode-encoded and given the A-label prefix.

    :raises IDNAError: when the label is empty, invalid or too long.
    """
    # Fast path: the label is already pure ASCII.
    try:
        label = label.encode('ascii')
        ulabel(label)
        if valid_label_length(label):
            return label
        raise IDNAError('Label too long')
    except UnicodeEncodeError:
        pass

    if not label:
        raise IDNAError('No Input')

    text = unicode(label)
    check_label(text)
    encoded = _alabel_prefix + _punycode(text)

    if valid_label_length(encoded):
        return encoded
    raise IDNAError('Label too long')
289
290
def ulabel(label):
    """Convert a single label to its U-label (text) form.

    Non-ASCII text is validated and returned unchanged. ASCII input is
    lower-cased; if it carries the A-label prefix the remainder is
    Punycode-decoded and validated, otherwise the label itself is
    validated and returned as text.

    :raises IDNAError: for a malformed A-label or a label that fails
        ``check_label``.
    """
    if not isinstance(label, (bytes, bytearray)):
        try:
            label = label.encode('ascii')
        except UnicodeEncodeError:
            # Already a non-ASCII text label: validate and hand it back.
            check_label(label)
            return label

    label = label.lower()
    if not label.startswith(_alabel_prefix):
        check_label(label)
        return label.decode('ascii')

    payload = label[len(_alabel_prefix):]
    if not payload:
        raise IDNAError('Malformed A-label, no Punycode eligible content found')
    if payload.decode('ascii')[-1] == '-':
        raise IDNAError('A-label must not end with a hyphen')

    decoded = payload.decode('punycode')
    check_label(decoded)
    return decoded
314
315
def uts46_remap(domain, std3_rules=True, transitional=False):
    """Re-map the characters of *domain* using the UTS46 mapping table.

    Each character is looked up in ``uts46data``; depending on the row's
    status and the *std3_rules* / *transitional* flags it is kept,
    replaced, silently dropped (status "I") or rejected. The result is
    returned NFC-normalized.

    :raises InvalidCodepoint: when a character is not allowed.
    """
    from .uts46data import uts46data
    pieces = []
    try:
        for pos, char in enumerate(domain):
            code_point = ord(char)
            # The first 256 rows are indexed directly by code point; the
            # rest of the table is located with a binary search.
            if code_point < 256:
                row_index = code_point
            else:
                row_index = bisect.bisect_left(uts46data, (code_point, "Z")) - 1
            uts46row = uts46data[row_index]
            status = uts46row[1]
            replacement = uts46row[2] if len(uts46row) == 3 else None

            keep_as_is = (
                status == "V"
                or (status == "D" and not transitional)
                or (status == "3" and not std3_rules and replacement is None)
            )
            if keep_as_is:
                pieces.append(char)
                continue

            use_mapping = replacement is not None and (
                status == "M"
                or (status == "3" and not std3_rules)
                or (status == "D" and transitional)
            )
            if use_mapping:
                pieces.append(replacement)
            elif status != "I":
                raise IndexError()
        return unicodedata.normalize("NFC", u"".join(pieces))
    except IndexError:
        raise InvalidCodepoint(
            "Codepoint {0} not allowed at position {1} in {2}".format(
            _unot(code_point), pos + 1, repr(domain)))
342
343
def encode(s, strict=False, uts46=False, std3_rules=False, transitional=False):
    """Encode a domain name to its ASCII (A-label) byte form.

    :param s: the domain as text, or ASCII ``bytes``/``bytearray``.
    :param strict: when true, split labels on ``.`` only; otherwise on any
        separator matched by ``_unicode_dots_re``.
    :param uts46: when true, run ``uts46_remap`` on the input first.
    :param std3_rules: forwarded to ``uts46_remap``.
    :param transitional: forwarded to ``uts46_remap``.
    :returns: the encoded domain as bytes; a trailing dot is preserved.
    :raises IDNAError: for an empty domain, an empty label or a domain
        that is too long (label-level errors come from ``alabel``).
    """
    if isinstance(s, (bytes, bytearray)):
        s = s.decode("ascii")
    if uts46:
        s = uts46_remap(s, std3_rules, transitional)

    labels = s.split('.') if strict else _unicode_dots_re.split(s)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')

    # A final empty label means the input ended with a separator.
    trailing_dot = labels[-1] == ''
    if trailing_dot:
        del labels[-1]

    encoded_labels = []
    for label in labels:
        encoded = alabel(label)
        if not encoded:
            raise IDNAError('Empty label')
        encoded_labels.append(encoded)
    if trailing_dot:
        encoded_labels.append(b'')

    domain = b'.'.join(encoded_labels)
    if not valid_string_length(domain, trailing_dot):
        raise IDNAError('Domain too long')
    return domain
373
374
def decode(s, strict=False, uts46=False, std3_rules=False):
    """Decode a domain name to its Unicode (U-label) text form.

    :param s: the domain as text, or ASCII ``bytes``/``bytearray``.
    :param strict: when true, split labels on ``.`` only; otherwise on any
        separator matched by ``_unicode_dots_re``.
    :param uts46: when true, run ``uts46_remap`` (non-transitional) first.
    :param std3_rules: forwarded to ``uts46_remap``.
    :returns: the decoded domain as text; a trailing dot is preserved.
    :raises IDNAError: for an empty domain or an empty label (label-level
        errors come from ``ulabel``).
    """
    if isinstance(s, (bytes, bytearray)):
        s = s.decode("ascii")
    if uts46:
        s = uts46_remap(s, std3_rules, False)

    labels = s.split(u'.') if strict else _unicode_dots_re.split(s)
    if not labels or labels == ['']:
        raise IDNAError('Empty domain')

    # A final empty label means the input ended with a separator.
    trailing_dot = not labels[-1]
    if trailing_dot:
        del labels[-1]

    decoded_labels = []
    for label in labels:
        decoded = ulabel(label)
        if not decoded:
            raise IDNAError('Empty label')
        decoded_labels.append(decoded)
    if trailing_dot:
        decoded_labels.append(u'')
    return u'.'.join(decoded_labels)
401