1# Lexer Implementation
2
3import re
4
5from .utils import Str, classify, get_regexp_width, Py36, Serialize, suppress
6from .exceptions import UnexpectedCharacters, LexError, UnexpectedToken
7
8###{standalone
9from warnings import warn
10from copy import copy
11
12
13class Pattern(Serialize):
14    raw = None
15    type = None
16
17    def __init__(self, value, flags=(), raw=None):
18        self.value = value
19        self.flags = frozenset(flags)
20        self.raw = raw
21
22    def __repr__(self):
23        return repr(self.to_regexp())
24
25    # Pattern Hashing assumes all subclasses have a different priority!
26    def __hash__(self):
27        return hash((type(self), self.value, self.flags))
28
29    def __eq__(self, other):
30        return type(self) == type(other) and self.value == other.value and self.flags == other.flags
31
32    def to_regexp(self):
33        raise NotImplementedError()
34
35    def min_width(self):
36        raise NotImplementedError()
37
38    def max_width(self):
39        raise NotImplementedError()
40
41    if Py36:
42        # Python 3.6 changed syntax for flags in regular expression
43        def _get_flags(self, value):
44            for f in self.flags:
45                value = ('(?%s:%s)' % (f, value))
46            return value
47
48    else:
49        def _get_flags(self, value):
50            for f in self.flags:
51                value = ('(?%s)' % f) + value
52            return value
53
54
55
56class PatternStr(Pattern):
57    __serialize_fields__ = 'value', 'flags'
58
59    type = "str"
60
61    def to_regexp(self):
62        return self._get_flags(re.escape(self.value))
63
64    @property
65    def min_width(self):
66        return len(self.value)
67    max_width = min_width
68
69
70class PatternRE(Pattern):
71    __serialize_fields__ = 'value', 'flags', '_width'
72
73    type = "re"
74
75    def to_regexp(self):
76        return self._get_flags(self.value)
77
78    _width = None
79    def _get_width(self):
80        if self._width is None:
81            self._width = get_regexp_width(self.to_regexp())
82        return self._width
83
84    @property
85    def min_width(self):
86        return self._get_width()[0]
87
88    @property
89    def max_width(self):
90        return self._get_width()[1]
91
92
93class TerminalDef(Serialize):
94    __serialize_fields__ = 'name', 'pattern', 'priority'
95    __serialize_namespace__ = PatternStr, PatternRE
96
97    def __init__(self, name, pattern, priority=1):
98        assert isinstance(pattern, Pattern), pattern
99        self.name = name
100        self.pattern = pattern
101        self.priority = priority
102
103    def __repr__(self):
104        return '%s(%r, %r)' % (type(self).__name__, self.name, self.pattern)
105
106    def user_repr(self):
107        if self.name.startswith('__'): # We represent a generated terminal
108            return self.pattern.raw or self.name
109        else:
110            return self.name
111
112
113class Token(Str):
114    """A string with meta-information, that is produced by the lexer.
115
116    When parsing text, the resulting chunks of the input that haven't been discarded,
117    will end up in the tree as Token instances. The Token class inherits from Python's ``str``,
118    so normal string comparisons and operations will work as expected.
119
120    Attributes:
121        type: Name of the token (as specified in grammar)
122        value: Value of the token (redundant, as ``token.value == token`` will always be true)
123        start_pos: The index of the token in the text
124        line: The line of the token in the text (starting with 1)
125        column: The column of the token in the text (starting with 1)
126        end_line: The line where the token ends
127        end_column: The next column after the end of the token. For example,
128            if the token is a single character with a column value of 4,
129            end_column will be 5.
130        end_pos: the index where the token ends (basically ``start_pos + len(token)``)
131    """
132    __slots__ = ('type', 'start_pos', 'value', 'line', 'column', 'end_line', 'end_column', 'end_pos')
133
134    def __new__(cls, type_, value, start_pos=None, line=None, column=None, end_line=None, end_column=None, end_pos=None, pos_in_stream=None):
135        try:
136            inst = super(Token, cls).__new__(cls, value)
137        except UnicodeDecodeError:
138            value = value.decode('latin1')
139            inst = super(Token, cls).__new__(cls, value)
140
141        inst.type = type_
142        inst.start_pos = start_pos if start_pos is not None else pos_in_stream
143        inst.value = value
144        inst.line = line
145        inst.column = column
146        inst.end_line = end_line
147        inst.end_column = end_column
148        inst.end_pos = end_pos
149        return inst
150
151    @property
152    def pos_in_stream(self):
153        warn("Attribute Token.pos_in_stream was renamed to Token.start_pos", DeprecationWarning, 2)
154        return self.start_pos
155
156    def update(self, type_=None, value=None):
157        return Token.new_borrow_pos(
158            type_ if type_ is not None else self.type,
159            value if value is not None else self.value,
160            self
161        )
162
163    @classmethod
164    def new_borrow_pos(cls, type_, value, borrow_t):
165        return cls(type_, value, borrow_t.start_pos, borrow_t.line, borrow_t.column, borrow_t.end_line, borrow_t.end_column, borrow_t.end_pos)
166
167    def __reduce__(self):
168        return (self.__class__, (self.type, self.value, self.start_pos, self.line, self.column))
169
170    def __repr__(self):
171        return 'Token(%r, %r)' % (self.type, self.value)
172
173    def __deepcopy__(self, memo):
174        return Token(self.type, self.value, self.start_pos, self.line, self.column)
175
176    def __eq__(self, other):
177        if isinstance(other, Token) and self.type != other.type:
178            return False
179
180        return Str.__eq__(self, other)
181
182    __hash__ = Str.__hash__
183
184
185class LineCounter:
186    __slots__ = 'char_pos', 'line', 'column', 'line_start_pos', 'newline_char'
187
188    def __init__(self, newline_char):
189        self.newline_char = newline_char
190        self.char_pos = 0
191        self.line = 1
192        self.column = 1
193        self.line_start_pos = 0
194
195    def __eq__(self, other):
196        if not isinstance(other, LineCounter):
197            return NotImplemented
198
199        return self.char_pos == other.char_pos and self.newline_char == other.newline_char
200
201    def feed(self, token, test_newline=True):
202        """Consume a token and calculate the new line & column.
203
204        As an optional optimization, set test_newline=False if token doesn't contain a newline.
205        """
206        if test_newline:
207            newlines = token.count(self.newline_char)
208            if newlines:
209                self.line += newlines
210                self.line_start_pos = self.char_pos + token.rindex(self.newline_char) + 1
211
212        self.char_pos += len(token)
213        self.column = self.char_pos - self.line_start_pos + 1
214
215
216class UnlessCallback:
217    def __init__(self, scanner):
218        self.scanner = scanner
219
220    def __call__(self, t):
221        res = self.scanner.match(t.value, 0)
222        if res:
223            _value, t.type = res
224        return t
225
226
227class CallChain:
228    def __init__(self, callback1, callback2, cond):
229        self.callback1 = callback1
230        self.callback2 = callback2
231        self.cond = cond
232
233    def __call__(self, t):
234        t2 = self.callback1(t)
235        return self.callback2(t) if self.cond(t2) else t2
236
237
238def _get_match(re_, regexp, s, flags):
239    m = re_.match(regexp, s, flags)
240    if m:
241        return m.group(0)
242
243def _create_unless(terminals, g_regex_flags, re_, use_bytes):
244    tokens_by_type = classify(terminals, lambda t: type(t.pattern))
245    assert len(tokens_by_type) <= 2, tokens_by_type.keys()
246    embedded_strs = set()
247    callback = {}
248    for retok in tokens_by_type.get(PatternRE, []):
249        unless = []
250        for strtok in tokens_by_type.get(PatternStr, []):
251            if strtok.priority > retok.priority:
252                continue
253            s = strtok.pattern.value
254            if s == _get_match(re_, retok.pattern.to_regexp(), s, g_regex_flags):
255                unless.append(strtok)
256                if strtok.pattern.flags <= retok.pattern.flags:
257                    embedded_strs.add(strtok)
258        if unless:
259            callback[retok.name] = UnlessCallback(Scanner(unless, g_regex_flags, re_, match_whole=True, use_bytes=use_bytes))
260
261    new_terminals = [t for t in terminals if t not in embedded_strs]
262    return new_terminals, callback
263
264
265
266class Scanner:
267    def __init__(self, terminals, g_regex_flags, re_, use_bytes, match_whole=False):
268        self.terminals = terminals
269        self.g_regex_flags = g_regex_flags
270        self.re_ = re_
271        self.use_bytes = use_bytes
272        self.match_whole = match_whole
273
274        self.allowed_types = {t.name for t in self.terminals}
275
276        self._mres = self._build_mres(terminals, len(terminals))
277
278    def _build_mres(self, terminals, max_size):
279        # Python sets an unreasonable group limit (currently 100) in its re module
280        # Worse, the only way to know we reached it is by catching an AssertionError!
281        # This function recursively tries less and less groups until it's successful.
282        postfix = '$' if self.match_whole else ''
283        mres = []
284        while terminals:
285            pattern = u'|'.join(u'(?P<%s>%s)' % (t.name, t.pattern.to_regexp() + postfix) for t in terminals[:max_size])
286            if self.use_bytes:
287                pattern = pattern.encode('latin-1')
288            try:
289                mre = self.re_.compile(pattern, self.g_regex_flags)
290            except AssertionError:  # Yes, this is what Python provides us.. :/
291                return self._build_mres(terminals, max_size//2)
292
293            mres.append((mre, {i: n for n, i in mre.groupindex.items()}))
294            terminals = terminals[max_size:]
295        return mres
296
297    def match(self, text, pos):
298        for mre, type_from_index in self._mres:
299            m = mre.match(text, pos)
300            if m:
301                return m.group(0), type_from_index[m.lastindex]
302
303
304def _regexp_has_newline(r):
305    r"""Expressions that may indicate newlines in a regexp:
306        - newlines (\n)
307        - escaped newline (\\n)
308        - anything but ([^...])
309        - any-char (.) when the flag (?s) exists
310        - spaces (\s)
311    """
312    return '\n' in r or '\\n' in r or '\\s' in r or '[^' in r or ('(?s' in r and '.' in r)
313
314
315class Lexer(object):
316    """Lexer interface
317
318    Method Signatures:
319        lex(self, text) -> Iterator[Token]
320    """
321    lex = NotImplemented
322
323    def make_lexer_state(self, text):
324        line_ctr = LineCounter(b'\n' if isinstance(text, bytes) else '\n')
325        return LexerState(text, line_ctr)
326
327
328class TraditionalLexer(Lexer):
329
330    def __init__(self, conf):
331        terminals = list(conf.terminals)
332        assert all(isinstance(t, TerminalDef) for t in terminals), terminals
333
334        self.re = conf.re_module
335
336        if not conf.skip_validation:
337            # Sanitization
338            for t in terminals:
339                try:
340                    self.re.compile(t.pattern.to_regexp(), conf.g_regex_flags)
341                except self.re.error:
342                    raise LexError("Cannot compile token %s: %s" % (t.name, t.pattern))
343
344                if t.pattern.min_width == 0:
345                    raise LexError("Lexer does not allow zero-width terminals. (%s: %s)" % (t.name, t.pattern))
346
347            if not (set(conf.ignore) <= {t.name for t in terminals}):
348                raise LexError("Ignore terminals are not defined: %s" % (set(conf.ignore) - {t.name for t in terminals}))
349
350        # Init
351        self.newline_types = frozenset(t.name for t in terminals if _regexp_has_newline(t.pattern.to_regexp()))
352        self.ignore_types = frozenset(conf.ignore)
353
354        terminals.sort(key=lambda x: (-x.priority, -x.pattern.max_width, -len(x.pattern.value), x.name))
355        self.terminals = terminals
356        self.user_callbacks = conf.callbacks
357        self.g_regex_flags = conf.g_regex_flags
358        self.use_bytes = conf.use_bytes
359        self.terminals_by_name = conf.terminals_by_name
360
361        self._scanner = None
362
363    def _build_scanner(self):
364        terminals, self.callback = _create_unless(self.terminals, self.g_regex_flags, self.re, self.use_bytes)
365        assert all(self.callback.values())
366
367        for type_, f in self.user_callbacks.items():
368            if type_ in self.callback:
369                # Already a callback there, probably UnlessCallback
370                self.callback[type_] = CallChain(self.callback[type_], f, lambda t: t.type == type_)
371            else:
372                self.callback[type_] = f
373
374        self._scanner = Scanner(terminals, self.g_regex_flags, self.re, self.use_bytes)
375
376    @property
377    def scanner(self):
378        if self._scanner is None:
379            self._build_scanner()
380        return self._scanner
381
382    def match(self, text, pos):
383        return self.scanner.match(text, pos)
384
385    def lex(self, state, parser_state):
386        with suppress(EOFError):
387            while True:
388                yield self.next_token(state, parser_state)
389
390    def next_token(self, lex_state, parser_state=None):
391        line_ctr = lex_state.line_ctr
392        while line_ctr.char_pos < len(lex_state.text):
393            res = self.match(lex_state.text, line_ctr.char_pos)
394            if not res:
395                allowed = self.scanner.allowed_types - self.ignore_types
396                if not allowed:
397                    allowed = {"<END-OF-FILE>"}
398                raise UnexpectedCharacters(lex_state.text, line_ctr.char_pos, line_ctr.line, line_ctr.column,
399                                           allowed=allowed, token_history=lex_state.last_token and [lex_state.last_token],
400                                           state=parser_state, terminals_by_name=self.terminals_by_name)
401
402            value, type_ = res
403
404            if type_ not in self.ignore_types:
405                t = Token(type_, value, line_ctr.char_pos, line_ctr.line, line_ctr.column)
406                line_ctr.feed(value, type_ in self.newline_types)
407                t.end_line = line_ctr.line
408                t.end_column = line_ctr.column
409                t.end_pos = line_ctr.char_pos
410                if t.type in self.callback:
411                    t = self.callback[t.type](t)
412                    if not isinstance(t, Token):
413                        raise LexError("Callbacks must return a token (returned %r)" % t)
414                lex_state.last_token = t
415                return t
416            else:
417                if type_ in self.callback:
418                    t2 = Token(type_, value, line_ctr.char_pos, line_ctr.line, line_ctr.column)
419                    self.callback[type_](t2)
420                line_ctr.feed(value, type_ in self.newline_types)
421
422        # EOF
423        raise EOFError(self)
424
425
426class LexerState(object):
427    __slots__ = 'text', 'line_ctr', 'last_token'
428
429    def __init__(self, text, line_ctr, last_token=None):
430        self.text = text
431        self.line_ctr = line_ctr
432        self.last_token = last_token
433
434    def __eq__(self, other):
435        if not isinstance(other, LexerState):
436            return NotImplemented
437
438        return self.text is other.text and self.line_ctr == other.line_ctr and self.last_token == other.last_token
439
440    def __copy__(self):
441        return type(self)(self.text, copy(self.line_ctr), self.last_token)
442
443
444class ContextualLexer(Lexer):
445
446    def __init__(self, conf, states, always_accept=()):
447        terminals = list(conf.terminals)
448        terminals_by_name = conf.terminals_by_name
449
450        trad_conf = copy(conf)
451        trad_conf.terminals = terminals
452
453        lexer_by_tokens = {}
454        self.lexers = {}
455        for state, accepts in states.items():
456            key = frozenset(accepts)
457            try:
458                lexer = lexer_by_tokens[key]
459            except KeyError:
460                accepts = set(accepts) | set(conf.ignore) | set(always_accept)
461                lexer_conf = copy(trad_conf)
462                lexer_conf.terminals = [terminals_by_name[n] for n in accepts if n in terminals_by_name]
463                lexer = TraditionalLexer(lexer_conf)
464                lexer_by_tokens[key] = lexer
465
466            self.lexers[state] = lexer
467
468        assert trad_conf.terminals is terminals
469        self.root_lexer = TraditionalLexer(trad_conf)
470
471    def make_lexer_state(self, text):
472        return self.root_lexer.make_lexer_state(text)
473
474    def lex(self, lexer_state, parser_state):
475        try:
476            while True:
477                lexer = self.lexers[parser_state.position]
478                yield lexer.next_token(lexer_state, parser_state)
479        except EOFError:
480            pass
481        except UnexpectedCharacters as e:
482            # In the contextual lexer, UnexpectedCharacters can mean that the terminal is defined, but not in the current context.
483            # This tests the input against the global context, to provide a nicer error.
484            try:
485                last_token = lexer_state.last_token  # Save last_token. Calling root_lexer.next_token will change this to the wrong token
486                token = self.root_lexer.next_token(lexer_state, parser_state)
487                raise UnexpectedToken(token, e.allowed, state=parser_state, token_history=[last_token], terminals_by_name=self.root_lexer.terminals_by_name)
488            except UnexpectedCharacters:
489                raise e  # Raise the original UnexpectedCharacters. The root lexer raises it with the wrong expected set.
490
491class LexerThread(object):
492    """A thread that ties a lexer instance and a lexer state, to be used by the parser"""
493
494    def __init__(self, lexer, text):
495        self.lexer = lexer
496        self.state = lexer.make_lexer_state(text)
497
498    def lex(self, parser_state):
499        return self.lexer.lex(self.state, parser_state)
500
501    def __copy__(self):
502        copied = object.__new__(LexerThread)
503        copied.lexer = self.lexer
504        copied.state = copy(self.state)
505        return copied
506###}
507