1from copy import copy
2from typing import Any, Optional, Tuple, Type
3
4from .exceptions import ParseError, UnexpectedCharError
5from .toml_char import TOMLChar
6
7
class _State:
    """
    Snapshot of a ``Source``'s cursor, used as a context manager.

    On entry the source's character iterator (copied, so later advancing the
    source does not move the saved one), index, current character and marker
    are saved. On exit the iterator, index and current character are put back
    if ``restore`` is truthy or the ``with`` block raised; the marker is put
    back as well only when ``save_marker`` is truthy. Exceptions raised inside
    the block are never suppressed.
    """

    def __init__(
        self,
        source: "Source",
        save_marker: Optional[bool] = False,
        restore: Optional[bool] = False,
    ) -> None:
        """
        :param source: the source whose position is saved/restored.
        :param save_marker: also roll back the marker when restoring.
        :param restore: roll back even when the block exits without error.
        """
        self._source = source
        self._save_marker = save_marker
        self.restore = restore

    def __enter__(self) -> "_State":
        # Entering this context manager - save the state. The iterator is
        # copied because the source will keep consuming the original one.
        self._chars = copy(self._source._chars)
        self._idx = self._source._idx
        self._current = self._source._current
        self._marker = self._source._marker

        return self

    def __exit__(self, exception_type, exception_val, trace) -> None:
        # Exiting this context manager - restore the prior state when asked to,
        # or unconditionally when the block failed. Returning None lets any
        # exception propagate.
        if self.restore or exception_type:
            self._source._chars = self._chars
            self._source._idx = self._idx
            self._source._current = self._current
            if self._save_marker:
                self._source._marker = self._marker
36
37
class _StateHandler:
    """
    State preserver for the Parser.

    When used as a context manager itself, each ``with`` opens a ``_State``
    with its default flags, i.e. the cursor is rolled back only if the block
    raises. When called with arguments, it returns a ``_State`` configured with
    them (e.g. ``restore=True``), which is then used as the context manager.
    """

    def __init__(self, source: "Source") -> None:
        self._source = source
        # Stack of states opened via ``with handler:``. Using a stack lets nested
        # ``with`` blocks on the same handler each exit their own state.
        self._states = []

    def __call__(self, *args, **kwargs) -> "_State":
        """Create a new ``_State`` for the source, forwarding all arguments."""
        return _State(self._source, *args, **kwargs)

    def __enter__(self) -> "_State":
        state = self()
        self._states.append(state)
        return state.__enter__()

    def __exit__(self, exception_type, exception_val, trace):
        # Close the most recently opened state; its return value decides
        # whether the exception (if any) is suppressed.
        state = self._states.pop()
        return state.__exit__(exception_type, exception_val, trace)
58
59
class Source(str):
    """
    A ``str`` subclass that the parser walks one character at a time.

    It tracks a cursor (``idx`` / ``current``) and a ``marker``; the text
    between marker and cursor can be pulled out with ``extract``. ``state``
    saves and restores the cursor around speculative parsing.
    """

    # Value of ``current`` once the input is exhausted. ``end()`` tests for it
    # by identity (``is``), so only this exact object marks the end of input.
    EOF = TOMLChar("\0")

    def __init__(self, _: str) -> None:
        super().__init__()

        # Iterator over (index, TOMLChar) pairs for every character of the input.
        self._chars = iter([(i, TOMLChar(c)) for i, c in enumerate(self)])

        self._idx = 0
        self._marker = 0
        self._current = TOMLChar("")

        self._state = _StateHandler(self)

        # Load the first character so ``current`` is valid right away.
        self.inc()

    def reset(self):
        # initialize both idx and current
        # NOTE(review): this does not rewind ``_chars``; it advances one
        # character from wherever the iterator is and moves the marker there.
        # Confirm that callers only use it on a freshly built iterator.
        self.inc()

        # reset marker
        self.mark()

    @property
    def state(self) -> _StateHandler:
        """Handler used to save/restore the cursor (see ``_StateHandler``)."""
        return self._state

    @property
    def idx(self) -> int:
        """Index of ``current`` in the source (``len(self)`` at end of input)."""
        return self._idx

    @property
    def current(self) -> TOMLChar:
        """Character under the cursor, or ``EOF`` once the input is exhausted."""
        return self._current

    @property
    def marker(self) -> int:
        """Index set by the last ``mark()`` call."""
        return self._marker

    def extract(self) -> str:
        """
        Extracts the value between marker and index (marker inclusive,
        index exclusive).
        """
        return self[self._marker : self._idx]

    def inc(self, exception: Optional[Type[ParseError]] = None) -> bool:
        """
        Increments the parser if the end of the input has not been reached.
        Returns whether or not it was able to advance.

        On reaching the end, ``idx`` becomes ``len(self)`` and ``current``
        becomes ``EOF``; if *exception* is given, a parse error of that type
        is raised at that position instead of returning False.
        """
        try:
            self._idx, self._current = next(self._chars)

            return True
        except StopIteration:
            self._idx = len(self)
            self._current = self.EOF
            if exception:
                raise self.parse_error(exception)

            return False

    def inc_n(self, n: int, exception: Optional[Type[ParseError]] = None) -> bool:
        """
        Increments the parser by n characters
        if the end of the input has not been reached.

        Returns False as soon as the end is hit (or raises *exception*, see
        ``inc``); True if all *n* steps succeeded.
        """
        for _ in range(n):
            if not self.inc(exception=exception):
                return False

        return True

    def consume(self, chars, min=0, max=-1):
        """
        Consume characters while ``current`` is in *chars*.

        :param chars: container of characters that may be consumed.
        :param min: minimum number of characters that must be consumed.
        :param max: maximum number of characters to consume; a negative value
            means no limit.
        :raises UnexpectedCharError: if fewer than *min* characters were
            consumed.
        """
        while self.current in chars and max != 0:
            min -= 1
            max -= 1
            if not self.inc():
                break

        # Failed to consume the minimum number of characters. ``parse_error``
        # only builds the exception, so it has to be raised explicitly.
        if min > 0:
            raise self.parse_error(UnexpectedCharError, self.current)

    def end(self) -> bool:
        """
        Returns True if the parser has reached the end of the input.
        """
        return self._current is self.EOF

    def mark(self) -> None:
        """
        Sets the marker to the index's current position
        """
        self._marker = self._idx

    def parse_error(
        self, exception: Type[ParseError] = ParseError, *args: Any
    ) -> ParseError:
        """
        Creates a generic "parse error" at the current position.

        The exception is constructed as ``exception(line, col, *args)`` and
        returned, not raised; callers are expected to ``raise`` it.
        """
        line, col = self._to_linecol()

        return exception(line, col, *args)

    def _to_linecol(self) -> Tuple[int, int]:
        """
        Convert the current index into a 1-based line and 0-based column.

        Line breaks are whatever ``str.splitlines`` recognises, and each one is
        counted at its real width, so a two-character CRLF terminator does not
        shift the columns of the lines after it. A position past the last
        line break (e.g. end of input after a trailing newline) falls back to
        ``(number_of_lines, 0)``.
        """
        lines = self.splitlines(keepends=True)
        cur = 0
        for i, line in enumerate(lines):
            # A terminated line spans its full length, terminator included.
            # An unterminated final line gets one extra slot so that the
            # end-of-input position still lands on it.
            width = max(len(line), len(line.splitlines()[0]) + 1)
            if cur + width > self.idx:
                return (i + 1, self.idx - cur)

            cur += len(line)

        return len(lines), 0
179