from copy import copy
from typing import Any, Optional, Tuple, Type

from .exceptions import ParseError, UnexpectedCharError
from .toml_char import TOMLChar


class _State:
    """
    Snapshot of a Source's cursor, usable as a context manager.

    On entry the source's character iterator, index, current character and
    marker are saved. On exit they are put back if ``restore`` is truthy or
    if an exception escaped the ``with`` body.
    """

    def __init__(
        self,
        source: "Source",
        save_marker: Optional[bool] = False,
        restore: Optional[bool] = False,
    ) -> None:
        self._source = source
        # When truthy, the marker is rolled back together with the cursor.
        self._save_marker = save_marker
        # When truthy, the cursor is rolled back even on a clean exit.
        self.restore = restore

    def __enter__(self) -> "_State":
        # Entering this context manager - save the state.
        # The iterator is copied so that advancing the source inside the
        # ``with`` body does not advance the saved one.
        self._chars = copy(self._source._chars)
        self._idx = self._source._idx
        self._current = self._source._current
        self._marker = self._source._marker

        return self

    def __exit__(self, exception_type, exception_val, trace):
        # Exiting this context manager - restore the prior state.
        # Nothing truthy is returned, so an in-flight exception propagates.
        if self.restore or exception_type:
            self._source._chars = self._chars
            self._source._idx = self._idx
            self._source._current = self._current
            if self._save_marker:
                self._source._marker = self._marker


class _StateHandler:
    """
    State preserver for the Parser.

    Calling the handler builds a configured ``_State``; using the handler
    itself in a ``with`` statement creates a default ``_State`` and keeps it
    on an internal stack until the matching exit.
    """

    def __init__(self, source: "Source") -> None:
        self._source = source
        # Stack of states opened through ``with handler:`` (innermost last).
        self._states = []

    def __call__(self, *args, **kwargs):
        """Return a new ``_State`` for the source; arguments are forwarded."""
        return _State(self._source, *args, **kwargs)

    def __enter__(self) -> "_State":
        state = self()
        self._states.append(state)
        return state.__enter__()

    def __exit__(self, exception_type, exception_val, trace):
        state = self._states.pop()
        return state.__exit__(exception_type, exception_val, trace)


class Source(str):
    """
    A string with a forward-only cursor (index, current character, marker).
    """

    # Sentinel stored in ``_current`` once the input is exhausted.
    EOF = TOMLChar("\0")

    def __init__(self, _: str) -> None:
        # The text itself is already held by the ``str`` base object, so the
        # argument is not used here.
        super().__init__()

        # Collection of TOMLChars
        self._chars = iter([(i, TOMLChar(c)) for i, c in enumerate(self)])

        self._idx = 0
        self._marker = 0
        self._current = TOMLChar("")

        self._state = _StateHandler(self)

        # Load the first character (or EOF for an empty string).
        self.inc()

    def reset(self):
        """
        Advance the cursor by one character and set the marker there.
        """
        # initialize both idx and current
        self.inc()

        # reset marker
        self.mark()

    @property
    def state(self) -> _StateHandler:
        return self._state

    @property
    def idx(self) -> int:
        return self._idx

    @property
    def current(self) -> TOMLChar:
        return self._current

    @property
    def marker(self) -> int:
        return self._marker

    def extract(self) -> str:
        """
        Extracts the value between marker and index
        """
        return self[self._marker : self._idx]

    def inc(self, exception: Optional[Type[ParseError]] = None) -> bool:
        """
        Increments the parser if the end of the input has not been reached.
        Returns whether or not it was able to advance.

        If the input is exhausted and ``exception`` is given, an error of
        that type is built with ``parse_error`` and raised instead of
        returning False.
        """
        try:
            self._idx, self._current = next(self._chars)

            return True
        except StopIteration:
            # Park the cursor just past the last character.
            self._idx = len(self)
            self._current = self.EOF
            if exception:
                raise self.parse_error(exception)

            return False

    def inc_n(self, n: int, exception: Optional[Type[ParseError]] = None) -> bool:
        """
        Increments the parser by n characters
        if the end of the input has not been reached.

        Returns True if all n steps succeeded, False as soon as one fails.
        ``exception`` is forwarded to ``inc``.
        """
        for _ in range(n):
            if not self.inc(exception=exception):
                return False

        return True

    def consume(self, chars, min=0, max=-1):
        """
        Consume characters contained in ``chars``, starting at the current one.

        At most ``max`` characters are consumed; the default of -1 never
        reaches zero while counting down, so it means "no limit".

        Raises an ``UnexpectedCharError`` (built with ``parse_error``) if
        fewer than ``min`` characters were consumed.
        """
        while self.current in chars and max != 0:
            min -= 1
            max -= 1
            if not self.inc():
                break

        # failed to consume minimum number of characters
        if min > 0:
            # The error was previously built but never raised, so a failed
            # minimum was silently ignored.
            # NOTE(review): the offending character is passed on the
            # assumption that UnexpectedCharError takes it as its third
            # argument - confirm against .exceptions.
            raise self.parse_error(UnexpectedCharError, self.current)

    def end(self) -> bool:
        """
        Returns True if the parser has reached the end of the input.
        """
        return self._current is self.EOF

    def mark(self) -> None:
        """
        Sets the marker to the index's current position
        """
        self._marker = self._idx

    def parse_error(
        self, exception: Type[ParseError] = ParseError, *args: Any
    ) -> ParseError:
        """
        Creates a generic "parse error" at the current position.

        The error is returned, not raised; extra positional arguments are
        passed to the exception constructor after the line and column.
        """
        line, col = self._to_linecol()

        return exception(line, col, *args)

    def _to_linecol(self) -> Tuple[int, int]:
        """
        Translate the current index into a (line, column) pair.

        Lines are numbered from 1; the column is the offset of the index
        from the start of its line.
        """
        lines = self.splitlines()

        # Offset of the first character of the line being examined.
        cur = 0
        for i, line in enumerate(lines):
            # The "+ 1" stands for the line terminator removed by splitlines().
            # NOTE(review): this assumes one-character terminators; "\r\n"
            # is two characters, so offsets may drift on such input.
            if cur + len(line) + 1 > self.idx:
                return (i + 1, self.idx - cur)

            cur += len(line) + 1

        # Index lies beyond every line.
        return len(lines), 0