1import string
2from types import MappingProxyType
3from typing import Any, BinaryIO, Dict, FrozenSet, Iterable, NamedTuple, Optional, Tuple
4import warnings
5
6from tomli._re import (
7    RE_DATETIME,
8    RE_LOCALTIME,
9    RE_NUMBER,
10    match_to_datetime,
11    match_to_localtime,
12    match_to_number,
13)
14from tomli._types import Key, ParseFloat, Pos
15
# ASCII control characters: U+0000..U+001F plus DEL (U+007F).
ASCII_CTRL = frozenset(map(chr, range(32))) | {chr(127)}

# None of the "illegal character" sets below contain the quotation mark or
# the backslash; the parser functions deal with those two characters as
# separate cases.
ILLEGAL_BASIC_STR_CHARS = ASCII_CTRL - {"\t"}
ILLEGAL_MULTILINE_BASIC_STR_CHARS = ASCII_CTRL - {"\t", "\n"}

# Literal strings and comments share the restrictions of basic strings.
ILLEGAL_LITERAL_STR_CHARS = ILLEGAL_BASIC_STR_CHARS
ILLEGAL_MULTILINE_LITERAL_STR_CHARS = ILLEGAL_MULTILINE_BASIC_STR_CHARS
ILLEGAL_COMMENT_CHARS = ILLEGAL_BASIC_STR_CHARS

TOML_WS = frozenset({" ", "\t"})
TOML_WS_AND_NEWLINE = TOML_WS | {"\n"}
BARE_KEY_CHARS = frozenset(f"{string.ascii_letters}{string.digits}-_")
# A key may also start with a quoted (basic or literal) part.
KEY_INITIAL_CHARS = BARE_KEY_CHARS | {'"', "'"}
HEXDIGIT_CHARS = frozenset(string.hexdigits)

# Read-only mapping from a two-character escape sequence, as written inside
# a basic string, to the character it denotes.
BASIC_STR_ESCAPE_REPLACEMENTS = MappingProxyType(
    dict(
        [
            ("\\b", "\b"),  # backspace
            ("\\t", "\t"),  # tab
            ("\\n", "\n"),  # linefeed
            ("\\f", "\f"),  # form feed
            ("\\r", "\r"),  # carriage return
            ('\\"', '"'),  # quote
            ("\\\\", "\\"),  # backslash
        ]
    )
)
45
46
class TOMLDecodeError(ValueError):
    """Raised when the input is not a valid TOML document.

    It subclasses ``ValueError``, so callers may catch either type.
    """
49
50
def load(fp: BinaryIO, *, parse_float: ParseFloat = float) -> Dict[str, Any]:
    """Read a whole TOML document from a file object and parse it.

    ``fp`` should be opened in binary mode. Its content is decoded with the
    defaults of ``bytes.decode()`` and then handed to ``loads``. If
    ``read()`` returns something without a ``decode`` method (for example a
    ``str`` from a text-mode file), that value is passed to ``loads`` as-is
    and a ``DeprecationWarning`` is emitted.

    Args:
        fp: File object whose ``read()`` returns the full document.
        parse_float: Forwarded unchanged to ``loads``.

    Returns:
        The parsed document as a dict.
    """
    raw = fp.read()
    try:
        document = raw.decode()
    except AttributeError:
        # Legacy text-mode support: keep it working, but tell the caller.
        warnings.warn(
            "Text file object support is deprecated in favor of binary file objects."
            ' Use `open("foo.toml", "rb")` to open the file in binary mode.',
            DeprecationWarning,
            stacklevel=2,
        )
        document = raw  # type: ignore[assignment]
    return loads(document, parse_float=parse_float)
65
66
def loads(s: str, *, parse_float: ParseFloat = float) -> Dict[str, Any]:  # noqa: C901
    """Parse a TOML document given as a string and return it as a dict.

    Args:
        s: The TOML document.
        parse_float: Forwarded to value parsing. It is called with the source
            text of special floats (inf/nan) and handed to ``match_to_number``
            for numeric literals.

    Returns:
        The parsed document.

    Raises:
        TOMLDecodeError: if the document is not valid TOML.
    """
    # The spec allows converting "\r\n" to "\n", even in string
    # literals. Doing so up front simplifies every later step.
    src = s.replace("\r\n", "\n")
    end = len(src)
    pos = 0
    out = Output(NestedDict(), Flags())
    # Key of the most recent [table] / [[array]] header; () means top level.
    header: Key = ()

    # Each iteration consumes one statement (normally one source line).
    while True:
        # Leading whitespace of the line.
        pos = skip_chars(src, pos, TOML_WS)
        if pos >= end:
            break
        char = src[pos]

        if char == "\n":
            # Blank line.
            pos += 1
            continue

        if char in KEY_INITIAL_CHARS:
            pos = key_value_rule(src, pos, out, header, parse_float)
            pos = skip_chars(src, pos, TOML_WS)
        elif char == "[":
            # "[[" opens an array-of-tables element, a single "[" a table.
            header_rule = create_list_rule if src.startswith("[[", pos) else create_dict_rule
            pos, header = header_rule(src, pos, out)
            pos = skip_chars(src, pos, TOML_WS)
        elif char != "#":
            raise suffixed_err(src, pos, "Invalid statement")

        # An optional trailing comment.
        pos = skip_comment(src, pos)

        # The statement must be followed by a newline or the end of input.
        if pos >= end:
            break
        if src[pos] != "\n":
            raise suffixed_err(
                src, pos, "Expected newline or end of document after a statement"
            )
        pos += 1

    return out.data.dict
129
130
class Flags:
    """Registry of flags attached to parsed keys and namespaces.

    The registry is a tree of dicts keyed by key part. Each entry holds:
    ``"flags"``, which applies to that exact key; ``"recursive_flags"``,
    which applies to the key and to everything nested below it; and
    ``"nested"``, which holds the child entries.
    """

    # Marks an immutable namespace (inline array or inline table).
    FROZEN = 0
    # Marks a nest that has been explicitly created and can no longer
    # be opened using the "[table]" syntax.
    EXPLICIT_NEST = 1

    def __init__(self) -> None:
        self._flags: Dict[str, dict] = {}

    @staticmethod
    def _new_entry() -> dict:
        """Return a fresh, flag-less registry entry."""
        return {"flags": set(), "recursive_flags": set(), "nested": {}}

    def unset_all(self, key: Key) -> None:
        """Forget every flag on ``key`` and on all keys nested below it.

        Flags on ancestors of ``key`` are kept. Nothing happens if the path
        to ``key`` was never registered.
        """
        node = self._flags
        for part in key[:-1]:
            entry = node.get(part)
            if entry is None:
                return
            node = entry["nested"]
        node.pop(key[-1], None)

    def set_for_relative_key(self, head_key: Key, rel_key: Key, flag: int) -> None:
        """Add non-recursive ``flag`` to every prefix of ``head_key + rel_key``
        that extends ``head_key``.

        Entries along ``head_key`` are created if needed but receive no flag.
        """
        node = self._flags
        for part in head_key:
            node = node.setdefault(part, self._new_entry())["nested"]
        for part in rel_key:
            entry = node.setdefault(part, self._new_entry())
            entry["flags"].add(flag)
            node = entry["nested"]

    def set(self, key: Key, flag: int, *, recursive: bool) -> None:  # noqa: A003
        """Attach ``flag`` to ``key``; with ``recursive`` it also covers nested keys."""
        node = self._flags
        for part in key[:-1]:
            node = node.setdefault(part, self._new_entry())["nested"]
        bucket = "recursive_flags" if recursive else "flags"
        node.setdefault(key[-1], self._new_entry())[bucket].add(flag)

    def is_(self, key: Key, flag: int) -> bool:
        """Return whether ``flag`` applies to ``key``.

        It applies if it is set directly on ``key`` or recursively on
        ``key`` or on any of its ancestors.
        """
        if not key:
            return False  # document root has no flags
        node = self._flags
        for part in key[:-1]:
            entry = node.get(part)
            if entry is None:
                return False
            if flag in entry["recursive_flags"]:
                return True
            node = entry["nested"]
        entry = node.get(key[-1])
        if entry is None:
            return False
        return flag in entry["flags"] or flag in entry["recursive_flags"]
191
192
class NestedDict:
    """Tree of nested dicts (and arrays of tables) holding the parsed data."""

    def __init__(self) -> None:
        # The parsed content of the TOML document
        self.dict: Dict[str, Any] = {}

    def get_or_create_nest(
        self,
        key: Key,
        *,
        access_lists: bool = True,
    ) -> dict:
        """Return the dict at ``key``, creating empty dicts for missing parts.

        Args:
            key: Path of key parts from the document root.
            access_lists: When true, a list met along the path is entered
                through its last element (the most recently appended
                array-of-tables item).

        Returns:
            The dict stored at ``key``.

        Raises:
            KeyError: if some part of ``key`` resolves to something other
                than a dict (after the optional list step).
        """
        cont: Any = self.dict
        for k in key:
            if k not in cont:
                cont[k] = {}
            cont = cont[k]
            if access_lists and isinstance(cont, list):
                cont = cont[-1]
            if not isinstance(cont, dict):
                raise KeyError("There is no nest behind this key")
        return cont

    def append_nest_to_list(self, key: Key) -> None:
        """Append a new empty dict to the array of tables at ``key``.

        The list is created if ``key`` does not exist yet.

        Raises:
            KeyError: if the parent path is not a nest, or if ``key`` already
                holds something other than a list.
        """
        cont = self.get_or_create_nest(key[:-1])
        last_key = key[-1]
        if last_key not in cont:
            cont[last_key] = [{}]
            return
        list_ = cont[last_key]
        # Check the type explicitly instead of relying on an ``append``
        # attribute. Values produced elsewhere (e.g. by a custom
        # ``parse_float``) may have an ``append`` that would mutate them or
        # raise something other than the KeyError callers handle.
        if not isinstance(list_, list):
            raise KeyError("An object other than list found behind this key")
        list_.append({})
226
227
class Output(NamedTuple):
    """The two structures built up while parsing one document."""

    # Parsed document content.
    data: NestedDict
    # Per-key flags (frozen / explicitly declared) guarding ``data``.
    flags: Flags
231
232
def skip_chars(src: str, pos: Pos, chars: Iterable[str]) -> Pos:
    """Return the first index at or after ``pos`` whose character is not in ``chars``.

    Scanning stops at the end of ``src``. A ``pos`` already at or past the
    end is returned unchanged. ``pos`` is expected to be non-negative, as
    all parser positions are.
    """
    end = len(src)
    while pos < end and src[pos] in chars:
        pos += 1
    return pos
240
241
def skip_until(
    src: str,
    pos: Pos,
    expect: str,
    *,
    error_on: FrozenSet[str],
    error_on_eof: bool,
) -> Pos:
    """Return the index of the next occurrence of ``expect`` at or after ``pos``.

    If ``expect`` does not occur, ``len(src)`` is returned, or an error is
    raised when ``error_on_eof`` is true.

    Raises:
        TOMLDecodeError: if ``expect`` is missing and ``error_on_eof`` is
            set, or if a character from ``error_on`` appears between ``pos``
            and the returned index. The error points at the first such
            character.
    """
    end = src.find(expect, pos)
    if end == -1:
        end = len(src)
        if error_on_eof:
            raise suffixed_err(src, end, f"Expected {expect!r}") from None

    if error_on.isdisjoint(src[pos:end]):
        return end

    # Locate the first offending character so the error points at it.
    bad = pos
    while src[bad] not in error_on:
        bad += 1
    raise suffixed_err(src, bad, f"Found invalid character {src[bad]!r}")
262
263
def skip_comment(src: str, pos: Pos) -> Pos:
    """Skip a ``#`` comment that starts at ``pos``, if there is one.

    Returns:
        The index of the newline ending the comment, or ``len(src)`` when the
        comment runs to the end of the document. If no comment starts at
        ``pos``, ``pos`` itself.

    Raises:
        TOMLDecodeError: if the comment contains a character from
            ``ILLEGAL_COMMENT_CHARS``.
    """
    if not src.startswith("#", pos):
        return pos
    return skip_until(
        src, pos + 1, "\n", error_on=ILLEGAL_COMMENT_CHARS, error_on_eof=False
    )
274
275
def skip_comments_and_array_ws(src: str, pos: Pos) -> Pos:
    """Skip any mix of spaces, tabs, newlines and comments starting at ``pos``.

    This is the filler allowed between array elements. The loop repeats
    until a full pass makes no progress.
    """
    previous = None
    while pos != previous:
        previous = pos
        pos = skip_comment(src, skip_chars(src, pos, TOML_WS_AND_NEWLINE))
    return pos
283
284
def create_dict_rule(src: str, pos: Pos, out: Output) -> Tuple[Pos, Key]:
    """Handle a ``[table]`` header whose opening bracket is at ``pos``.

    Marks the key as explicitly declared and makes sure the table exists in
    ``out.data``.

    Returns:
        The position after the closing bracket, and the table's key (the new
        current header).

    Raises:
        TOMLDecodeError: if the table was already declared or is frozen, if
            its path runs into a non-table value, or if the closing bracket
            is missing.
    """
    pos, key = parse_key(src, skip_chars(src, pos + 1, TOML_WS))

    flags = out.flags
    if any(flags.is_(key, flag) for flag in (Flags.EXPLICIT_NEST, Flags.FROZEN)):
        raise suffixed_err(src, pos, f"Can not declare {key} twice")
    flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
    try:
        out.data.get_or_create_nest(key)
    except KeyError:
        raise suffixed_err(src, pos, "Can not overwrite a value") from None

    if src.startswith("]", pos):
        return pos + 1, key
    raise suffixed_err(src, pos, 'Expected "]" at the end of a table declaration')
301
302
def create_list_rule(src: str, pos: Pos, out: Output) -> Tuple[Pos, Key]:
    """Handle an ``[[array.of.tables]]`` header whose first ``[`` is at ``pos``.

    Appends a new empty table to the array at the key, creating the array if
    needed.

    Returns:
        The position after the closing ``]]``, and the array's key (the new
        current header).

    Raises:
        TOMLDecodeError: if the key is inside a frozen namespace, if it holds
            a non-array value, or if the closing ``]]`` is missing.
    """
    pos, key = parse_key(src, skip_chars(src, pos + 2, TOML_WS))

    if out.flags.is_(key, Flags.FROZEN):
        raise suffixed_err(src, pos, f"Can not mutate immutable namespace {key}")
    # A new array element begins: drop flags recorded for the previous one...
    out.flags.unset_all(key)
    # ...but this exact key still must not be reopened as a "[table]".
    out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
    try:
        out.data.append_nest_to_list(key)
    except KeyError:
        raise suffixed_err(src, pos, "Can not overwrite a value") from None

    if src.startswith("]]", pos):
        return pos + 2, key
    raise suffixed_err(src, pos, 'Expected "]]" at the end of an array declaration')
322
323
def key_value_rule(
    src: str, pos: Pos, out: Output, header: Key, parse_float: ParseFloat
) -> Pos:
    """Parse a ``key = value`` statement and store it under the current table.

    The (possibly dotted) key is resolved relative to ``header``, the key of
    the most recent table / array-of-tables header (``()`` at top level).

    Returns:
        The position just after the parsed value.

    Raises:
        TOMLDecodeError: if the target namespace is frozen, or if the
            assignment would overwrite an existing value.
    """
    pos, key, value = parse_key_value_pair(src, pos, parse_float)
    parent = header + key[:-1]
    stem = key[-1]

    if out.flags.is_(parent, Flags.FROZEN):
        raise suffixed_err(
            src, pos, f"Can not mutate immutable namespace {parent}"
        )
    # Tables created along this dotted key may not be opened with the
    # "[table]" syntax afterwards.
    # NOTE(review): nothing here rejects a dotted key that extends a table
    # already declared by an earlier header (e.g. "[a.b]" ... "[a]" then
    # "b.x = 1"). The TOML 1.0 spec appears to forbid using dotted keys to
    # redefine tables defined in [table] form. Confirm before relying on it.
    out.flags.set_for_relative_key(header, key, Flags.EXPLICIT_NEST)
    try:
        target = out.data.get_or_create_nest(parent)
    except KeyError:
        raise suffixed_err(src, pos, "Can not overwrite a value") from None
    if stem in target:
        raise suffixed_err(src, pos, "Can not overwrite a value")
    if isinstance(value, (dict, list)):
        # Inline tables and arrays are immutable, including their contents.
        out.flags.set(header + key, Flags.FROZEN, recursive=True)
    target[stem] = value
    return pos
348
349
def parse_key_value_pair(
    src: str, pos: Pos, parse_float: ParseFloat
) -> Tuple[Pos, Key, Any]:
    """Parse ``key = value`` starting at ``pos``.

    Returns:
        The position just after the value, the key as a tuple of parts, and
        the parsed value.

    Raises:
        TOMLDecodeError: if the key is not followed by ``=``, or if the key
            or value is malformed.
    """
    pos, key = parse_key(src, pos)
    if not src.startswith("=", pos):
        raise suffixed_err(src, pos, 'Expected "=" after a key in a key/value pair')
    value_start = skip_chars(src, pos + 1, TOML_WS)
    pos, value = parse_value(src, value_start, parse_float)
    return pos, key, value
364
365
def parse_key(src: str, pos: Pos) -> Tuple[Pos, Key]:
    """Parse a possibly dotted key such as ``a."b c".d`` starting at ``pos``.

    Spaces and tabs around the dots are skipped, and so is any whitespace
    after the last part.

    Returns:
        The position after the key (and trailing whitespace), and the key
        parts as a tuple.
    """
    parts = []
    while True:
        pos, part = parse_key_part(src, pos)
        parts.append(part)
        pos = skip_chars(src, pos, TOML_WS)
        if not src.startswith(".", pos):
            return pos, tuple(parts)
        pos = skip_chars(src, pos + 1, TOML_WS)
382
383
def parse_key_part(src: str, pos: Pos) -> Tuple[Pos, str]:
    """Parse one part of a key: a bare key, a literal-quoted or a basic-quoted string.

    Returns:
        The position after the part, and its string value.

    Raises:
        TOMLDecodeError: if the character at ``pos`` cannot start a key part.
    """
    first = src[pos : pos + 1]
    if first == "'":
        return parse_literal_str(src, pos)
    if first == '"':
        return parse_one_line_basic_str(src, pos)
    if first in BARE_KEY_CHARS:
        end = skip_chars(src, pos, BARE_KEY_CHARS)
        return end, src[pos:end]
    raise suffixed_err(src, pos, "Invalid initial character for a key part")
398
399
def parse_one_line_basic_str(src: str, pos: Pos) -> Tuple[Pos, str]:
    """Parse a single-line basic string whose opening quote is at ``pos``."""
    return parse_basic_str(src, pos + 1, multiline=False)
403
404
def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> Tuple[Pos, list]:
    """Parse an array whose opening ``[`` is at ``pos``.

    Whitespace, newlines and comments may appear between elements, and a
    trailing comma before the closing ``]`` is accepted.

    Returns:
        The position after the closing ``]``, and the list of values.

    Raises:
        TOMLDecodeError: if an element is not followed by ``,`` or ``]``.
    """
    array: list = []
    pos = skip_comments_and_array_ws(src, pos + 1)
    while not src.startswith("]", pos):
        pos, item = parse_value(src, pos, parse_float)
        array.append(item)
        pos = skip_comments_and_array_ws(src, pos)

        separator = src[pos : pos + 1]
        if separator == "]":
            break
        if separator != ",":
            raise suffixed_err(src, pos, "Unclosed array")
        pos = skip_comments_and_array_ws(src, pos + 1)
    return pos + 1, array
427
428
def parse_inline_table(src: str, pos: Pos, parse_float: ParseFloat) -> Tuple[Pos, dict]:
    """Parse an inline table whose opening ``{`` is at ``pos``.

    Only spaces and tabs are skipped between tokens. A trailing comma is not
    accepted: after a comma another key/value pair must follow.

    Returns:
        The position after the closing ``}``, and the table as a dict.

    Raises:
        TOMLDecodeError: on duplicate keys, on writes into a nested inline
            table or array defined earlier in the same table, or if a pair
            is not followed by ``,`` or ``}``.
    """
    table = NestedDict()
    # Local flags: only FROZEN is used, to protect nested tables/arrays.
    flags = Flags()

    pos = skip_chars(src, pos + 1, TOML_WS)
    if src.startswith("}", pos):
        return pos + 1, table.dict
    while True:
        pos, key, value = parse_key_value_pair(src, pos, parse_float)
        if flags.is_(key, Flags.FROZEN):
            raise suffixed_err(src, pos, f"Can not mutate immutable namespace {key}")
        try:
            nest = table.get_or_create_nest(key[:-1], access_lists=False)
        except KeyError:
            raise suffixed_err(src, pos, "Can not overwrite a value") from None
        stem = key[-1]
        if stem in nest:
            raise suffixed_err(src, pos, f"Duplicate inline table key {stem!r}")
        nest[stem] = value
        if isinstance(value, (dict, list)):
            # Later dotted keys in this table may not extend this value.
            flags.set(key, Flags.FROZEN, recursive=True)

        pos = skip_chars(src, pos, TOML_WS)
        separator = src[pos : pos + 1]
        if separator == "}":
            return pos + 1, table.dict
        if separator != ",":
            raise suffixed_err(src, pos, "Unclosed inline table")
        pos = skip_chars(src, pos + 1, TOML_WS)
459
460
def parse_basic_str_escape(  # noqa: C901
    src: str, pos: Pos, *, multiline: bool = False
) -> Tuple[Pos, str]:
    """Decode the backslash escape sequence whose backslash is at ``pos``.

    Args:
        src: Document source.
        pos: Index of the backslash.
        multiline: Whether the escape sits in a multiline basic string. There
            a backslash at the end of a line (optionally followed by spaces
            or tabs) removes the line break and all whitespace and newlines
            after it.

    Returns:
        The position after the escape (and any whitespace it consumed), and
        the decoded text. The text is empty for a line-ending backslash.

    Raises:
        TOMLDecodeError: on an unknown escape, a malformed hex (unicode)
            escape, or input that ends right after the backslash.
    """
    escape_id = src[pos : pos + 2]
    pos += 2

    if multiline and escape_id in {"\\ ", "\\\t", "\\\n"}:
        # Line-ending backslash. Only spaces/tabs may sit between the
        # backslash and the newline.
        if escape_id != "\\\n":
            pos = skip_chars(src, pos, TOML_WS)
            if pos >= len(src):
                return pos, ""
            if src[pos] != "\n":
                raise suffixed_err(src, pos, 'Unescaped "\\" in a string')
            pos += 1
        return skip_chars(src, pos, TOML_WS_AND_NEWLINE), ""

    hex_len = {"\\u": 4, "\\U": 8}.get(escape_id)
    if hex_len is not None:
        return parse_hex_char(src, pos, hex_len)

    replacement = BASIC_STR_ESCAPE_REPLACEMENTS.get(escape_id)
    if replacement is not None:
        return pos, replacement
    if len(escape_id) != 2:
        # The document ended right after the backslash.
        raise suffixed_err(src, pos, "Unterminated string") from None
    raise suffixed_err(src, pos, 'Unescaped "\\" in a string') from None
490
491
def parse_basic_str_escape_multiline(src: str, pos: Pos) -> Tuple[Pos, str]:
    """Decode an escape inside a multiline basic string.

    Equivalent to ``parse_basic_str_escape`` with ``multiline=True``, which
    also accepts line-ending backslashes.
    """
    return parse_basic_str_escape(src, pos, multiline=True)
494
495
def parse_hex_char(src: str, pos: Pos, hex_len: int) -> Tuple[Pos, str]:
    """Decode exactly ``hex_len`` hex digits at ``pos`` into one character.

    Returns:
        The position after the digits, and the decoded character.

    Raises:
        TOMLDecodeError: if fewer than ``hex_len`` hex digits are present
            (reported at ``pos``), or if the value is not a Unicode scalar
            value (reported after the digits).
    """
    end = pos + hex_len
    digits = src[pos:end]
    if len(digits) != hex_len or not HEXDIGIT_CHARS.issuperset(digits):
        raise suffixed_err(src, pos, "Invalid hex value")
    codepoint = int(digits, 16)
    if not is_unicode_scalar_value(codepoint):
        raise suffixed_err(src, end, "Escaped character is not a Unicode scalar value")
    return end, chr(codepoint)
505
506
def parse_literal_str(src: str, pos: Pos) -> Tuple[Pos, str]:
    """Parse a single-line literal string whose opening apostrophe is at ``pos``.

    The content is taken verbatim, with no escape processing.

    Returns:
        The position after the closing apostrophe, and the string content.

    Raises:
        TOMLDecodeError: if the closing apostrophe is missing or an illegal
            character occurs before it.
    """
    start = pos + 1
    end = skip_until(
        src, start, "'", error_on=ILLEGAL_LITERAL_STR_CHARS, error_on_eof=True
    )
    return end + 1, src[start:end]
514
515
def parse_multiline_str(src: str, pos: Pos, *, literal: bool) -> Tuple[Pos, str]:
    """Parse a multiline string whose three-character opening delimiter is at ``pos``.

    A newline directly after the opening delimiter is dropped. Up to two
    extra quote characters right before the closing delimiter belong to the
    content (the end sequence may be 4 or 5 quotes long).

    Args:
        literal: True for the apostrophe form (no escapes), False for the
            double-quote form (escapes processed).

    Returns:
        The position after the closing delimiter, and the string value.
    """
    pos += 3  # skip the opening delimiter
    if src.startswith("\n", pos):
        pos += 1

    if literal:
        delim = "'"
        closing = skip_until(
            src,
            pos,
            "'''",
            error_on=ILLEGAL_MULTILINE_LITERAL_STR_CHARS,
            error_on_eof=True,
        )
        result = src[pos:closing]
        pos = closing + 3
    else:
        delim = '"'
        pos, result = parse_basic_str(src, pos, multiline=True)

    extra = 0
    while extra < 2 and src.startswith(delim, pos):
        extra += 1
        pos += 1
    return pos, result + delim * extra
545
546
def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> Tuple[Pos, str]:
    """Parse the body of a basic string starting just after its opening quote(s).

    Escapes are decoded. In single-line mode the first unescaped quote ends
    the string. In multiline mode only three consecutive quotes do.

    Returns:
        The position after the closing quote(s), and the decoded content.

    Raises:
        TOMLDecodeError: if the input ends before the closing quote(s), or
            on an illegal control character or bad escape.
    """
    if multiline:
        error_on = ILLEGAL_MULTILINE_BASIC_STR_CHARS
        parse_escapes = parse_basic_str_escape_multiline
        closing_len = 3
    else:
        error_on = ILLEGAL_BASIC_STR_CHARS
        parse_escapes = parse_basic_str_escape
        closing_len = 1

    chunks = []
    # Start of the current run of characters that are copied verbatim.
    chunk_start = pos
    while True:
        try:
            char = src[pos]
        except IndexError:
            raise suffixed_err(src, pos, "Unterminated string") from None
        if char == '"':
            if not multiline or src.startswith('"""', pos):
                chunks.append(src[chunk_start:pos])
                return pos + closing_len, "".join(chunks)
            pos += 1
        elif char == "\\":
            chunks.append(src[chunk_start:pos])
            pos, unescaped = parse_escapes(src, pos)
            chunks.append(unescaped)
            chunk_start = pos
        elif char in error_on:
            raise suffixed_err(src, pos, f"Illegal character {char!r}")
        else:
            pos += 1
577
578
def parse_value(  # noqa: C901
    src: str, pos: Pos, parse_float: ParseFloat
) -> Tuple[Pos, Any]:
    """Parse the TOML value that starts at ``pos``.

    Value kinds are tried in a fixed order. Dates and times must come before
    plain numbers, because the number pattern would otherwise greedily match
    the leading digits of a date.

    Returns:
        The position just after the value, and the value as a Python object.

    Raises:
        TOMLDecodeError: if nothing matches, or if a date/datetime matches
            the pattern but ``match_to_datetime`` rejects it.
    """
    char = src[pos : pos + 1]

    # Strings: basic and literal, each possibly multiline.
    if char == '"':
        if src.startswith('"""', pos):
            return parse_multiline_str(src, pos, literal=False)
        return parse_one_line_basic_str(src, pos)
    if char == "'":
        if src.startswith("'''", pos):
            return parse_multiline_str(src, pos, literal=True)
        return parse_literal_str(src, pos)

    # Booleans.
    if src.startswith("true", pos):
        return pos + 4, True
    if src.startswith("false", pos):
        return pos + 5, False

    # Dates and times (before numbers, see docstring).
    datetime_match = RE_DATETIME.match(src, pos)
    if datetime_match:
        try:
            datetime_obj = match_to_datetime(datetime_match)
        except ValueError as e:
            raise suffixed_err(src, pos, "Invalid date or datetime") from e
        return datetime_match.end(), datetime_obj
    localtime_match = RE_LOCALTIME.match(src, pos)
    if localtime_match:
        return localtime_match.end(), match_to_localtime(localtime_match)

    # Integers and "normal" floats.
    number_match = RE_NUMBER.match(src, pos)
    if number_match:
        return number_match.end(), match_to_number(number_match, parse_float)

    # Containers.
    if char == "[":
        return parse_array(src, pos, parse_float)
    if char == "{":
        return parse_inline_table(src, pos, parse_float)

    # Special floats, unsigned then signed.
    if src.startswith(("inf", "nan"), pos):
        return pos + 3, parse_float(src[pos : pos + 3])
    if src.startswith(("-inf", "+inf", "-nan", "+nan"), pos):
        return pos + 4, parse_float(src[pos : pos + 4])

    raise suffixed_err(src, pos, "Invalid value")
643
644
def suffixed_err(src: str, pos: Pos, msg: str) -> TOMLDecodeError:
    """Build (but do not raise) a ``TOMLDecodeError`` locating ``pos`` in ``src``.

    The message is ``msg`` followed by either "(at line L, column C)", with
    both numbers 1-based, or "(at end of document)" when ``pos`` is at or
    past the end.
    """
    if pos >= len(src):
        location = "end of document"
    else:
        line = src.count("\n", 0, pos) + 1
        # rfind returns -1 on the first line, so one formula covers all lines.
        column = pos - src.rfind("\n", 0, pos)
        location = f"line {line}, column {column}"
    return TOMLDecodeError(f"{msg} (at {location})")
660
661
def is_unicode_scalar_value(codepoint: int) -> bool:
    """Return whether ``codepoint`` is a Unicode scalar value.

    That is any code point in U+0000..U+10FFFF except the surrogate range
    U+D800..U+DFFF.
    """
    return 0 <= codepoint <= 0x10FFFF and not 0xD800 <= codepoint <= 0xDFFF
664