1import string, re, sys, datetime
2from .core import TomlError
3
# Code-point -> character helper that yields a text (unicode) character on
# both major versions: ``unichr`` exists only on Python 2, where ``chr``
# would return a byte string; on Python 3 ``chr`` already returns text.
_chr = unichr if sys.version_info[0] == 2 else chr  # noqa: F821
8
def load(fin, translate=lambda t, x, v: v):
    """Read a TOML document from the file-like object *fin* and parse it.

    The whole stream is read with ``fin.read()`` (text, or ``bytes`` that
    :func:`loads` decodes as UTF-8) and handed to :func:`loads` together
    with *translate*.  ``fin.name`` is used as the filename in error reports
    when the object has one; streams without a ``name`` attribute (such as
    ``io.StringIO``) fall back to ``'<string>'`` instead of failing with
    ``AttributeError``.
    """
    return loads(fin.read(), translate=translate,
                 filename=getattr(fin, 'name', '<string>'))
11
def loads(s, filename='<string>', translate=lambda t, x, v: v):
    """Parse a TOML document and return its contents as a nested ``dict``.

    :param s: the document, as text or UTF-8 encoded ``bytes``.  CRLF line
        endings are normalised to LF before parsing.
    :param filename: name reported in :class:`TomlError` for error locations.
    :param translate: hook called as ``translate(kind, text, value)`` for
        every parsed value, innermost values first; whatever it returns is
        stored in the result.  *kind* is the parser's value kind (e.g.
        ``'str'``, ``'int'``, ``'array'``, ``'table'``) and *text* its
        textual form (``None`` for arrays and inline tables).  The default
        returns *value* unchanged.
    :returns: a ``dict``; ``[table]`` headers become nested dicts and
        ``[[table array]]`` headers become lists of dicts.
    :raises TomlError: on syntax errors, arrays mixing value kinds, duplicate
        keys, tables defined twice, ``[table]`` / ``[[table]]`` mismatches,
        and keys that collide with table names.
    """
    if isinstance(s, bytes):
        s = s.decode('utf-8')

    s = s.replace('\r\n', '\n')

    root = {}
    # Maps a table name to (scope, subtables) for a [table], or to a list of
    # such pairs for a [[table array]].  scope is None while the table has
    # only been created implicitly as a parent in a dotted header.
    tables = {}
    scope = root
    # Fallback position for error(); rebound to each statement's position by
    # the loop below.
    pos = (1, 1)

    src = _Source(s, filename=filename)
    ast = _p_toml(src)

    def error(msg, where=None):
        """Raise TomlError at *where*, or at the current statement's position."""
        line, col = where if where is not None else pos
        raise TomlError(msg, line, col, filename)

    def process_value(v):
        """Convert a parser value tuple into its final (translated) value."""
        kind, text, value, vpos = v
        # NOTE(review): this strips a leading newline from *every* string,
        # which is only correct for the newline that directly follows an
        # opening ''' or """; a string whose decoded value starts with an
        # escaped "\n" loses it as well.  A proper fix needs the parser to
        # record which string form was used.
        if kind == 'str' and value.startswith('\n'):
            value = value[1:]
        if kind == 'array':
            # Arrays must be homogeneous: every item has the first item's kind.
            if value and any(item[0] != value[0][0] for item in value[1:]):
                error('array-type-mismatch', vpos)
            value = [process_value(item) for item in value]
        elif kind == 'table':
            value = {k: process_value(value[k]) for k in value}
        return translate(kind, text, value)

    for kind, value, pos in ast:
        if kind == 'kv':
            k, v = value
            if k in scope:
                error('duplicate_keys. Key "{0}" was used more than once.'.format(k))
            scope[k] = process_value(v)
        else:
            is_table_array = (kind == 'table_array')
            # Walk to the parent of the header's last name, creating implicit
            # tables on the way; inside a table array, descend into its most
            # recently added element.
            cur = tables
            for name in value[:-1]:
                if isinstance(cur.get(name), list):
                    cur = cur[name][-1][1]
                else:
                    cur = cur.setdefault(name, (None, {}))[1]

            # Subsequent key/value statements are stored in this new scope.
            scope = {}
            name = value[-1]
            if name not in cur:
                if is_table_array:
                    cur[name] = [(scope, {})]
                else:
                    cur[name] = (scope, {})
            elif isinstance(cur[name], list):
                if not is_table_array:
                    error('table_type_mismatch')
                cur[name].append((scope, {}))
            else:
                if is_table_array:
                    error('table_type_mismatch')
                old_scope, next_table = cur[name]
                # An implicitly created table may be defined explicitly once.
                if old_scope is not None:
                    error('duplicate_tables')
                cur[name] = (scope, next_table)

    def merge_tables(scope, subtables):
        """Nest *subtables* into *scope* (a new dict if None) and return it."""
        if scope is None:
            scope = {}
        for k in subtables:
            if k in scope:
                error('key_table_conflict')
            v = subtables[k]
            if isinstance(v, list):
                scope[k] = [merge_tables(sc, tbl) for sc, tbl in v]
            else:
                scope[k] = merge_tables(v[0], v[1])
        return scope

    return merge_tables(root, tables)
88
class _Source:
    """Parsing cursor over the not-yet-consumed rest of a TOML document.

    Besides the remaining text it tracks the 1-based ``(line, column)`` of
    the cursor, the token produced by the most recent successful
    ``consume*`` call, and a stack of saved ``(text, position)`` states.
    Using the source as a context manager saves the state on entry; if an
    exception escapes the ``with`` body the state is restored, and the
    exception is suppressed when its type is exactly :class:`TomlError`.
    This is how the parser backtracks.  ``consume*`` methods return a false
    value on mismatch; the matching ``expect*`` methods raise
    :class:`TomlError` instead.
    """

    def __init__(self, s, filename=None):
        self.s = s                 # remaining, unconsumed input text
        self._pos = (1, 1)         # (line, column) of the cursor, 1-based
        self._last = None          # token from the last successful consume*
        self._filename = filename  # reported in TomlError
        self.backtrack_stack = []  # saved (s, _pos) pairs, innermost last

    def last(self):
        """Return the token recorded by the last successful ``consume*``."""
        return self._last

    def pos(self):
        """Return the cursor position as a 1-based ``(line, column)`` pair."""
        return self._pos

    def fail(self):
        """Raise :class:`TomlError` at the current position."""
        return self._expect(None)

    def consume_dot(self):
        """Consume one character and return it; return None at end of input."""
        if self.s:
            self._last = self.s[0]
            # Slice the remaining text (the cursor object is not a sequence).
            self.s = self.s[1:]
            self._advance(self._last)
            return self._last
        return None

    def expect_dot(self):
        """Like :meth:`consume_dot`, but raise TomlError at end of input."""
        return self._expect(self.consume_dot())

    def consume_eof(self):
        """Return True (recording ``''`` as last token) if no input is left."""
        if not self.s:
            self._last = ''
            return True
        return False

    def expect_eof(self):
        """Raise TomlError unless all input has been consumed."""
        return self._expect(self.consume_eof())

    def consume(self, s):
        """Consume the literal *s* if the input starts with it; return a bool."""
        if self.s.startswith(s):
            self.s = self.s[len(s):]
            self._last = s
            self._advance(s)
            return True
        return False

    def expect(self, s):
        """Like :meth:`consume`, but raise TomlError on mismatch."""
        return self._expect(self.consume(s))

    def consume_re(self, re):
        """Consume a match of the compiled pattern *re* at the cursor.

        Returns the match object (which is truthy even for an empty match),
        or None when the pattern does not match here.
        """
        m = re.match(self.s)
        if m:
            self.s = self.s[len(m.group(0)):]
            self._last = m
            self._advance(m.group(0))
            return m
        return None

    def expect_re(self, re):
        """Like :meth:`consume_re`, but raise TomlError when nothing matches."""
        return self._expect(self.consume_re(re))

    def __enter__(self):
        # Save the current state so a failure inside the block can rewind.
        self.backtrack_stack.append((self.s, self._pos))

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            self.backtrack_stack.pop()
        else:
            self.s, self._pos = self.backtrack_stack.pop()
        # Swallow parse failures (exactly TomlError) so the caller can try an
        # alternative; any other exception propagates.
        return exc_type == TomlError

    def commit(self):
        """Make the current state the innermost rewind point."""
        self.backtrack_stack[-1] = (self.s, self._pos)

    def _expect(self, r):
        # Return *r* when truthy; otherwise raise a parse error at the cursor.
        if not r:
            raise TomlError('msg', self._pos[0], self._pos[1], self._filename)
        return r

    def _advance(self, s):
        # Move the cursor past the consumed text *s*: the column grows on the
        # same line, or restarts after the last newline contained in *s*.
        suffix_pos = s.rfind('\n')
        if suffix_pos == -1:
            self._pos = (self._pos[0], self._pos[1] + len(s))
        else:
            self._pos = (self._pos[0] + s.count('\n'), len(s) - suffix_pos)
173
# "Extended" whitespace: any mix of blanks, tabs, newlines and comments.
# A comment runs up to and including its newline, or to the end of input.
_ews_re = re.compile(r'(?:[ \t]|#[^\n]*\n|#[^\n]*\Z|\n)*')


def _p_ews(s):
    """Skip extended whitespace (possibly none) at the cursor of *s*."""
    pattern = _ews_re
    s.expect_re(pattern)
177
# Inline whitespace only: spaces and tabs, never newlines.
_ws_re = re.compile(r'[ \t]*')


def _p_ws(s):
    """Skip spaces and tabs (possibly none) at the cursor of *s*."""
    skip = s.expect_re
    skip(_ws_re)
181
# Maps the character after a backslash to the character it stands for.
_escapes = { 'b': '\b', 'n': '\n', 'r': '\r', 't': '\t', '"': '"', '\'': '\'',
    '\\': '\\', '/': '/', 'f': '\f' }

# Run of unescaped characters in a single-line basic string: anything except
# a double quote, a backslash or a control character (U+0000..U+001F).
_basicstr_re = re.compile(r'[^"\\\000-\037]*')
_short_uni_re = re.compile(r'u([0-9a-fA-F]{4})')
_long_uni_re = re.compile(r'U([0-9a-fA-F]{8})')
_escapes_re = re.compile('[bnrt"\'\\\\/f]')
# Line-ending backslash: the newline plus following blanks/newlines are dropped.
_newline_esc_re = re.compile('\n[ \t\n]*')
def _p_basicstr_content(s, content=_basicstr_re):
    """Parse the body of a basic string and return it with escapes decoded.

    Parsing stops before the first character that neither *content* (the
    regex for a run of literal characters) nor an escape sequence accepts,
    normally the closing quote, which the caller then expects.  Raises
    TomlError for an invalid escape, or for a ``\\u``/``\\U`` escape that is
    not a Unicode scalar value.
    """
    res = []
    while True:
        res.append(s.expect_re(content).group(0))
        if not s.consume('\\'):
            break
        if s.consume_re(_newline_esc_re):
            pass
        elif s.consume_re(_short_uni_re) or s.consume_re(_long_uni_re):
            code = int(s.last().group(1), 16)
            # Reject non-scalar values: above U+10FFFF _chr would raise
            # ValueError/OverflowError instead of a TomlError, and surrogates
            # (U+D800..U+DFFF) would produce a string that cannot be encoded.
            if code > 0x10FFFF or 0xD800 <= code <= 0xDFFF:
                s.fail()
            res.append(_chr(code))
        else:
            s.expect_re(_escapes_re)
            res.append(_escapes[s.last().group(0)])
    return ''.join(res)
204
# Bare keys: ASCII letters, digits, '-' and '_'.
_key_re = re.compile(r'[0-9a-zA-Z-_]+')


def _p_key(s):
    """Parse a key name at the cursor of *s* and return it as a string.

    A double-quoted basic string is tried first; if that fails the cursor is
    rewound and a bare key is required instead (TomlError otherwise).
    """
    with s:
        s.expect('"')
        quoted = _p_basicstr_content(s, _basicstr_re)
        s.expect('"')
        return quoted
    bare = s.expect_re(_key_re)
    return bare.group(0)
213
# Integers and floats: optional sign, '_' digit separators, optional fraction
# and/or exponent; the integer part is "0" or has no leading zero.
_float_re = re.compile(r'[+-]?(?:0|[1-9](?:_?\d)*)(?:\.\d(?:_?\d)*)?(?:[eE][+-]?(?:\d(?:_?\d)*))?')
# Offset date-times: date 'T' time, optional fraction, then 'Z' or +HH:MM/-HH:MM.
_datetime_re = re.compile(r'(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2}):(\d{2})(\.\d+)?(?:Z|([+-]\d{2}):(\d{2}))')

# Bodies of multi-line basic strings, single-line literal strings and
# multi-line literal strings; the multi-line forms may contain newlines and
# runs of at most two delimiter quotes.
_basicstr_ml_re = re.compile(r'(?:(?:|"|"")[^"\\\000-\011\013-\037])*')
_litstr_re = re.compile(r"[^'\000-\037]*")
_litstr_ml_re = re.compile(r"(?:(?:|'|'')(?:[^'\000-\011\013-\037]))*")
def _p_value(s):
    """Parse one TOML value at the cursor of *s*.

    Returns ``(kind, text, value, pos)``: *kind* is ``'bool'``, ``'str'``,
    ``'datetime'``, ``'float'``, ``'int'``, ``'array'`` or ``'table'``;
    *text* is the matched text (the decoded content for strings, ``None``
    for arrays and inline tables); *value* is the Python value (for arrays a
    list, and for inline tables a dict, of such tuples); *pos* is where the
    value starts.  Raises TomlError if no valid value starts here.
    """
    pos = s.pos()

    if s.consume('true'):
        return 'bool', s.last(), True, pos
    if s.consume('false'):
        return 'bool', s.last(), False, pos

    if s.consume('"'):
        if s.consume('""'):
            r = _p_basicstr_content(s, _basicstr_ml_re)
            s.expect('"""')
        else:
            r = _p_basicstr_content(s, _basicstr_re)
            s.expect('"')
        return 'str', r, r, pos

    if s.consume('\''):
        if s.consume('\'\''):
            r = s.expect_re(_litstr_ml_re).group(0)
            s.expect('\'\'\'')
        else:
            r = s.expect_re(_litstr_re).group(0)
            s.expect('\'')
        return 'str', r, r, pos

    # Tried before numbers, because _float_re would match just the year.
    if s.consume_re(_datetime_re):
        match = s.last()
        text = match.group(0)
        year, month, day, hour, minute, second = map(int, match.groups()[:6])

        # Fractional seconds: keep the first six digits as microseconds
        # (padding shorter fractions) without a lossy float round-trip.
        fraction = match.group(7)
        micro = int((fraction[1:] + '000000')[:6]) if fraction else 0

        if match.group(8):
            # The sign belongs to the whole offset: "-05:30" is -330 minutes
            # and "-00:30" is -30 minutes.
            sign = -1 if match.group(8)[0] == '-' else 1
            off_hours = int(match.group(8)[1:], 10)
            off_minutes = int(match.group(9), 10)
            if off_hours > 23 or off_minutes > 59:
                s.fail()
            offset = sign * (off_hours * 60 + off_minutes)
            tz = _TimeZone(datetime.timedelta(0, offset * 60))
        else:
            tz = _TimeZone(datetime.timedelta(0, 0))

        try:
            dt = datetime.datetime(year, month, day, hour, minute, second,
                                   micro, tz)
        except ValueError:
            # Out-of-range fields (month 13, Feb 30, ...) are reported as a
            # TOML error instead of leaking ValueError to the caller.
            s.fail()
        return 'datetime', text, dt, pos

    if s.consume_re(_float_re):
        m = s.last().group(0)
        r = m.replace('_','')
        if '.' in m or 'e' in m or 'E' in m:
            return 'float', m, float(r), pos
        else:
            return 'int', m, int(r, 10), pos

    if s.consume('['):
        items = []
        # Items are committed one by one; the first failure (a missing comma
        # or a missing item) rewinds to the last commit, then ']' must follow.
        with s:
            while True:
                _p_ews(s)
                items.append(_p_value(s))
                s.commit()
                _p_ews(s)
                s.expect(',')
                s.commit()
        _p_ews(s)
        s.expect(']')
        return 'array', None, items, pos

    if s.consume('{'):
        _p_ws(s)
        items = {}
        if not s.consume('}'):
            while True:
                key = _p_key(s)
                # Keys in an inline table must be unique.
                if key in items:
                    s.fail()
                _p_ws(s)
                s.expect('=')
                _p_ws(s)
                items[key] = _p_value(s)
                _p_ws(s)
                if not s.consume(','):
                    break
                _p_ws(s)
            s.expect('}')
        return 'table', None, items, pos

    s.fail()
309
def _p_stmt(s):
    """Parse one top-level statement at the cursor of *s*.

    A statement is either a ``[a.b]`` / ``[[a.b]]`` header, returned as
    ``('table' or 'table_array', [names], pos)``, or a ``key = value`` pair,
    returned as ``('kv', (key, value_tuple), pos)``.  *pos* is where the
    statement starts.
    """
    start = s.pos()

    if not s.consume('['):
        # Plain assignment; blanks are allowed on both sides of '='.
        key = _p_key(s)
        _p_ws(s)
        s.expect('=')
        _p_ws(s)
        return 'kv', (key, _p_value(s)), start

    # A second '[' directly after the first marks a table-array header.
    double = s.consume('[')
    _p_ws(s)
    names = []
    while True:
        names.append(_p_key(s))
        _p_ws(s)
        if not s.consume('.'):
            break
        _p_ws(s)
    s.expect(']')
    if double:
        s.expect(']')
    return ('table_array' if double else 'table'), names, start
332
# Statement separator: one or more line breaks, each optionally preceded by
# blanks and a comment, plus any indentation before the next statement.
_stmtsep_re = re.compile(r'(?:[ \t]*(?:#[^\n]*)?\n)+[ \t]*')


def _p_toml(s):
    """Parse a whole document from *s* into a list of statement tuples.

    Statements are read until one (or the separator before it) fails to
    parse.  The cursor is then rewound to just after the last complete
    statement, and only whitespace and comments may remain; otherwise
    TomlError is raised.
    """
    statements = []
    _p_ews(s)
    with s:
        # This loop only ends by an exception; a TomlError is swallowed by
        # the ``with`` after it rewinds to the most recent commit().
        while True:
            statements.append(_p_stmt(s))
            s.commit()
            s.expect_re(_stmtsep_re)
    _p_ews(s)
    s.expect_eof()
    return statements
346
class _TimeZone(datetime.tzinfo):
    """Fixed-offset ``tzinfo`` attached to parsed TOML date-times.

    :param offset: ``datetime.timedelta`` to add to UTC to get local time.
    """

    def __init__(self, offset):
        self._offset = offset

    def utcoffset(self, dt):
        return self._offset

    def dst(self, dt):
        # A bare offset carries no daylight-saving information.
        return None

    def tzname(self, dt):
        """Return the offset formatted as ``+HHMM`` or ``-HHMM``."""
        # Integer arithmetic: total_seconds() is a float, and the old
        # '{:.02}' format rendered e.g. +05:30 as '+5.03e+01'.
        total = int(self._offset.total_seconds() // 60)
        sign = '-' if total < 0 else '+'
        hours, minutes = divmod(abs(total), 60)
        return '{0}{1:02d}{2:02d}'.format(sign, hours, minutes)
367