1local utils = require "luacheck.utils"
2
-- The lexer should support the syntax of Lua 5.1, 5.2, 5.3, 5.4 and LuaJIT (including 64-bit integer and complex cdata literals).
local lexer = {}

-- Cached standard library functions.
local sbyte = string.byte
local schar = string.char
local sreverse = string.reverse
local tconcat = table.concat
local mfloor = math.floor

-- Byte values of characters the lexer compares against.
-- No point in inlining these, fetching a constant ~= fetching a local.

-- Digit and letter ranges, plus "_" which may start an identifier.
local BYTE_0, BYTE_9 = sbyte("0"), sbyte("9")
local BYTE_a, BYTE_z = sbyte("a"), sbyte("z")
local BYTE_A, BYTE_Z = sbyte("A"), sbyte("Z")
local BYTE_LDASH = sbyte("_")

-- Letters with a special meaning inside numbers and escape sequences.
local BYTE_f, BYTE_F = sbyte("f"), sbyte("F")  -- Upper bounds of hex letter digits.
local BYTE_x, BYTE_X = sbyte("x"), sbyte("X")  -- Hex number prefix, "\x" escape.
local BYTE_e, BYTE_E = sbyte("e"), sbyte("E")  -- Decimal exponent.
local BYTE_p, BYTE_P = sbyte("p"), sbyte("P")  -- Hexadecimal exponent.
local BYTE_i, BYTE_I = sbyte("i"), sbyte("I")  -- Complex (imaginary) literal suffix.
local BYTE_l, BYTE_L = sbyte("l"), sbyte("L")  -- 64-bit integer literal suffix.
local BYTE_u, BYTE_U = sbyte("u"), sbyte("U")  -- Unsigned suffix, "\u{...}" escape.

-- Punctuation.
local BYTE_DOT, BYTE_COLON = sbyte("."), sbyte(":")
local BYTE_OBRACK, BYTE_CBRACK = sbyte("["), sbyte("]")
local BYTE_OBRACE, BYTE_CBRACE = sbyte("{"), sbyte("}")
local BYTE_QUOTE, BYTE_DQUOTE = sbyte("'"), sbyte('"')
local BYTE_PLUS, BYTE_DASH = sbyte("+"), sbyte("-")
local BYTE_SLASH, BYTE_BSLASH = sbyte("/"), sbyte("\\")
local BYTE_EQ, BYTE_NE = sbyte("="), sbyte("~")
local BYTE_LT, BYTE_GT = sbyte("<"), sbyte(">")

-- Whitespace.
local BYTE_LF, BYTE_CR = sbyte("\n"), sbyte("\r")
local BYTE_SPACE, BYTE_TAB = sbyte(" "), sbyte("\t")
local BYTE_FF, BYTE_VTAB = sbyte("\f"), sbyte("\v")
28
--- Decodes one hexadecimal digit.
-- @param b byte value of a character (must not be nil).
-- @return the digit's value (0-15), or nil if `b` is not one of 0-9, a-f, A-F.
local function to_hex(b)
   if BYTE_0 <= b and b <= BYTE_9 then
      return b - BYTE_0
   end

   if BYTE_a <= b and b <= BYTE_f then
      return b - BYTE_a + 10
   end

   if BYTE_A <= b and b <= BYTE_F then
      return b - BYTE_A + 10
   end

   return nil
end
40
--- Decodes one decimal digit.
-- @param b byte value of a character (must not be nil).
-- @return the digit's value (0-9), or nil if `b` is not one of "0".."9".
local function to_dec(b)
   local is_digit = BYTE_0 <= b and b <= BYTE_9
   -- A value of 0 is truthy in Lua, so the and/or form is safe here.
   return is_digit and b - BYTE_0 or nil
end
48
--- Encodes a code point as a UTF-8 byte sequence.
-- Callers limit the value to 0x7FFFFFFF, so sequences of up to six bytes
-- (the original UTF-8 scheme) can be produced.
-- @param codepoint non-negative integer code point.
-- @return the encoded string.
local function to_utf(codepoint)
   -- ASCII is a single byte.
   if codepoint < 0x80 then
      return schar(codepoint)
   end

   -- Bytes are collected last-to-first and reversed at the end.
   local reversed = {}
   -- Largest payload that still fits into the leading byte; it halves with
   -- every continuation byte emitted.
   local lead_limit = 0x3F

   while true do
      -- Continuation byte: 10xxxxxx with the low six bits of the code point.
      reversed[#reversed + 1] = schar(0x80 + codepoint % 0x40)
      codepoint = mfloor(codepoint / 0x40)
      lead_limit = mfloor(lead_limit / 2)

      if codepoint <= lead_limit then
         break
      end
   end

   -- Leading byte: a run of 1-bits encoding the length, then the remaining payload.
   reversed[#reversed + 1] = schar(0xFE - 2 * lead_limit + codepoint)
   return sreverse(tconcat(reversed))
end
66
--- Tells whether byte `b` may start an identifier: an ASCII letter or "_".
-- (Digits are accepted separately after the first character.)
local function is_alpha(b)
   if b == BYTE_LDASH then
      return true
   end

   return (BYTE_a <= b and b <= BYTE_z) or (BYTE_A <= b and b <= BYTE_Z)
end
71
--- Tells whether byte `b` is a line-break character ("\n" or "\r"); false for nil.
local function is_newline(b)
   return b == BYTE_CR or b == BYTE_LF
end
75
--- Tells whether byte `b` is non-newline whitespace: space, tab, vertical tab
-- or form feed. False for nil.
local function is_space(b)
   return b == BYTE_SPACE or b == BYTE_TAB or b == BYTE_VTAB or b == BYTE_FF
end
80
-- Reserved words. lex_ident returns an identifier found here as its own token
-- type instead of "name".
local keywords = utils.array_to_set({
   "and", "break",
   "do",
   "else", "elseif", "end",
   "false", "for", "function",
   "goto",
   "if", "in",
   "local",
   "nil", "not",
   "or",
   "repeat", "return",
   "then", "true",
   "until",
   "while",
})
84
-- Single-character escape sequences in short strings: maps the byte after the
-- backslash to the byte the escape stands for.
local simple_escapes = {}

do
   local letter_escapes = {a = "\a", b = "\b", f = "\f", n = "\n", r = "\r", t = "\t", v = "\v"}

   for letter, value in pairs(letter_escapes) do
      simple_escapes[sbyte(letter)] = sbyte(value)
   end

   -- Backslash and both quote characters escape to themselves.
   for _, self_escaping in ipairs({BYTE_BSLASH, BYTE_QUOTE, BYTE_DQUOTE}) do
      simple_escapes[self_escaping] = self_escaping
   end
end
97
--- Advances the lexer by one position and returns the code point found there,
-- as given by state.src:get_codepoint (callers treat nil as end of input).
local function next_byte(state)
   state.offset = state.offset + 1
   return state.src:get_codepoint(state.offset)
end
103
104-- Skipping helpers.
-- Each takes the current character, skips over something, and returns the character that follows.
106
--- Skips a line break and updates line bookkeeping.
-- `newline` is the byte at the current offset ("\n" or "\r"). If the next byte is
-- the *other* newline character, it is consumed too, so "\r\n" and "\n\r" count as
-- a single break.
-- Records the length of the finished line in state.line_lengths, advances
-- state.line and records where the new line starts in state.line_offsets.
-- @return the byte after the line break.
local function skip_newline(state, newline)
   local break_offset = state.offset
   local after = next_byte(state)

   -- Two-byte line break: a different newline byte directly after the first one.
   if is_newline(after) and after ~= newline then
      after = next_byte(state)
   end

   local finished_line = state.line
   state.line_lengths[finished_line] = break_offset - state.line_offsets[finished_line]
   state.line = finished_line + 1
   state.line_offsets[state.line] = state.offset
   return after
end
123
--- Advances, starting with byte `b`, until a line break or the end of input.
-- The line break itself is not consumed.
-- @return the newline byte it stopped on, or nil at end of input.
local function skip_to_newline(state, b)
   while b and not is_newline(b) do
      b = next_byte(state)
   end

   return b
end
131
--- Skips whitespace and line breaks, starting with byte `b`, keeping the line
-- bookkeeping up to date.
-- @return the first byte that is neither, or nil at end of input.
local function skip_space(state, b)
   while true do
      if is_newline(b) then
         b = skip_newline(state, b)
      elseif is_space(b) then
         b = next_byte(state)
      else
         return b
      end
   end
end
143
-- Skips a long bracket delimiter "[=*" or "]=*" whose bracket is at the current
-- offset. Returns the byte after the run of "=" and the number of "=" skipped.
local function skip_long_bracket(state)
   local level = 0
   local b = next_byte(state)

   while b == BYTE_EQ do
      level = level + 1
      b = next_byte(state)
   end

   return b, level
end
155
156-- Token handlers.
157
-- Reads the contents of a long string or long comment.
-- Called after the opening "[=*" has been skipped, with the offset on the second "[".
-- `opening_long_bracket` is the number of "=" in the opening bracket; `token` is the
-- token type to return ("string" or "long_comment").
-- A line break right after the opening bracket is dropped, and every line break
-- inside the contents becomes "\n".
-- On success returns `token` and the contents, with the offset after the closing
-- bracket. If the input ends first, returns nil and an error message.
local function lex_long_string(state, opening_long_bracket, token)
   local b = next_byte(state)

   if is_newline(b) then
      b = skip_newline(state, b)
   end

   local lines = {}
   local line_start = state.offset
   local closed = false

   while not closed do
      if b == nil then
         return nil, token == "string" and "unfinished long string" or "unfinished long comment"
      elseif b == BYTE_CBRACK then
         -- A closing bracket only counts if its level matches the opening one;
         -- otherwise lexing resumes at the byte after it.
         local closing_level
         b, closing_level = skip_long_bracket(state)
         closed = b == BYTE_CBRACK and closing_level == opening_long_bracket
      elseif is_newline(b) then
         -- The current line is complete.
         lines[#lines + 1] = state.src:get_substring(line_start, state.offset - 1)
         b = skip_newline(state, b)
         line_start = state.offset
      else
         b = next_byte(state)
      end
   end

   -- The offset is on the last "]" of the closing bracket; the contents end just
   -- before its first "]".
   lines[#lines + 1] = state.src:get_substring(line_start, state.offset - opening_long_bracket - 2)
   state.offset = state.offset + 1
   return token, tconcat(lines, "\n")
end
196
-- Escape sequence decoders for short strings.
-- Each one is called with the offset on the character that identifies the escape
-- (the one after the backslash). On success it returns the decoded text and the
-- byte after the escape sequence. On failure it returns nil, an error message and
-- the start of the reported location as an offset relative to the current one.

-- "\xXX": exactly two hexadecimal digits.
local function lex_hex_escape(state)
   local b = next_byte(state)  -- Skip "x".
   local high = b and to_hex(b)

   if not high then
      return nil, "invalid hexadecimal escape sequence", -2
   end

   b = next_byte(state)
   local low = b and to_hex(b)

   if not low then
      return nil, "invalid hexadecimal escape sequence", -3
   end

   return schar(high * 16 + low), next_byte(state)
end

-- "\u{X...}": one or more hexadecimal digits, value at most 0x7FFFFFFF.
local function lex_utf8_escape(state)
   if next_byte(state) ~= BYTE_OBRACE then  -- Skip "u".
      return nil, "invalid UTF-8 escape sequence", -2
   end

   local b = next_byte(state)  -- Skip "{".
   local codepoint = b and to_hex(b)

   if not codepoint then
      return nil, "invalid UTF-8 escape sequence", -3
   end

   -- Digits after the first one.
   local extra_digits = 0

   while true do
      b = next_byte(state)
      local digit = b and to_hex(b)

      if not digit then
         break
      end

      extra_digits = extra_digits + 1
      codepoint = codepoint * 16 + digit

      if codepoint > 0x7FFFFFFF then
         -- UTF-8 value too large.
         return nil, "invalid UTF-8 escape sequence", -extra_digits - 3
      end
   end

   if b ~= BYTE_CBRACE then
      return nil, "invalid UTF-8 escape sequence", -extra_digits - 4
   end

   return to_utf(codepoint), next_byte(state)  -- Skip "}".
end

-- "\d", "\dd" or "\ddd": up to three decimal digits, value at most 255.
-- `b` is the byte at the current offset.
local function lex_decimal_escape(state, b)
   local value = b and to_dec(b)

   if not value then
      return nil, "invalid escape sequence", -1
   end

   b = next_byte(state)
   local digit = b and to_dec(b)

   if digit then
      value = 10 * value + digit
      b = next_byte(state)
      digit = b and to_dec(b)

      if digit then
         value = 10 * value + digit

         if value > 255 then
            return nil, "invalid decimal escape sequence", -3
         end

         b = next_byte(state)
      end
   end

   return schar(value), b
end

-- Decodes the escape sequence whose backslash is at the current offset.
-- Returns the decoded text ("" for "\z") and the byte after the sequence,
-- or nil, an error message and a relative error offset.
local function lex_escape(state)
   local b = next_byte(state)  -- Skip "\".
   local simple = simple_escapes[b]

   if simple then
      return schar(simple), next_byte(state)
   elseif is_newline(b) then
      -- An escaped line break always stands for "\n".
      return "\n", skip_newline(state, b)
   elseif b == BYTE_x then
      return lex_hex_escape(state)
   elseif b == BYTE_u then
      return lex_utf8_escape(state)
   elseif b == BYTE_z then
      -- "\z" skips the following whitespace, including line breaks.
      return "", skip_space(state, next_byte(state))
   end

   -- Anything else must be a decimal escape.
   return lex_decimal_escape(state, b)
end

-- Lexes a short string delimited by `quote` (the byte at the current offset).
-- Returns "string" and the value with escape sequences decoded, leaving the offset
-- after the closing quote. A line break or the end of input before the closing
-- quote gives nil, "unfinished string". An invalid escape gives nil, a message and
-- a relative error offset.
local function lex_short_string(state, quote)
   local b = next_byte(state)
   -- Pieces of the value; only needed once an escape sequence is seen.
   local chunks
   local chunk_start = state.offset

   while b ~= quote do
      if b == nil or is_newline(b) then
         return nil, "unfinished string"
      elseif b == BYTE_BSLASH then
         chunks = chunks or {}

         -- Flush the raw text before the escape.
         if chunk_start ~= state.offset then
            chunks[#chunks + 1] = state.src:get_substring(chunk_start, state.offset - 1)
         end

         local decoded, next_b, relative_error_offset = lex_escape(state)

         if not decoded then
            -- On failure the second value is the error message.
            return nil, next_b, relative_error_offset
         end

         chunks[#chunks + 1] = decoded
         b = next_b
         -- The next raw chunk starts after the escape sequence.
         chunk_start = state.offset
      else
         b = next_byte(state)
      end
   end

   -- The offset is on the closing quote.
   local string_value

   if not chunks then
      -- No escape sequences: the value is the raw source text.
      string_value = state.src:get_substring(chunk_start, state.offset - 1)
   else
      if chunk_start ~= state.offset then
         chunks[#chunks + 1] = state.src:get_substring(chunk_start, state.offset - 1)
      end

      string_value = tconcat(chunks)
   end

   -- Skip the closing quote.
   state.offset = state.offset + 1
   return "string", string_value
end
379
-- Returns how many bytes of a LuaJIT 64-bit integer suffix ("LL", "ULL" or "LLU",
-- in any letter case) start at the current offset, or 0 if there is none.
-- `b` is the byte at the current offset.
local function int64_suffix_length(state, b)
   local is_unsigned_first = b == BYTE_u or b == BYTE_U

   if not is_unsigned_first and b ~= BYTE_l and b ~= BYTE_L then
      return 0
   end

   local b1 = state.src:get_codepoint(state.offset + 1)

   if b1 ~= BYTE_l and b1 ~= BYTE_L then
      return 0
   end

   local b2 = state.src:get_codepoint(state.offset + 2)

   if is_unsigned_first then
      -- "ULL": both "L"s are required.
      return (b2 == BYTE_l or b2 == BYTE_L) and 3 or 0
   end

   -- "LL", optionally followed by "U".
   return (b2 == BYTE_u or b2 == BYTE_U) and 3 or 2
end

-- Lexes a numeric literal starting with byte `b` (a digit or a ".").
-- The token value is the literal's source text, not a converted number: Lua 5.3+
-- integers and LuaJIT cdata literals may not be representable by the interpreter
-- running Luacheck, and the numeric value is not needed here.
-- Returns "number" and the text, or nil and "malformed number".
local function lex_number(state, b)
   local start = state.offset

   -- Decimal by default; "0x"/"0X" switches to hex digits and "p"/"P" exponents.
   local is_digit, exp_lower, exp_upper = to_dec, BYTE_e, BYTE_E
   local has_digits, is_float = false, false

   if b == BYTE_0 then
      b = next_byte(state)

      if b ~= BYTE_x and b ~= BYTE_X then
         -- The leading zero is itself a digit.
         has_digits = true
      else
         is_digit, exp_lower, exp_upper = to_hex, BYTE_p, BYTE_P
         b = next_byte(state)
      end
   end

   -- Integer part.
   while b and is_digit(b) do
      has_digits = true
      b = next_byte(state)
   end

   -- Fractional part.
   if b == BYTE_DOT then
      is_float = true
      b = next_byte(state)

      while b and is_digit(b) do
         has_digits = true
         b = next_byte(state)
      end
   end

   -- Exponent: optional sign, then at least one decimal digit.
   if b == exp_lower or b == exp_upper then
      is_float = true
      b = next_byte(state)

      if b == BYTE_PLUS or b == BYTE_DASH then
         b = next_byte(state)
      end

      if not (b and to_dec(b)) then
         return nil, "malformed number"
      end

      repeat
         b = next_byte(state)
      until not (b and to_dec(b))
   end

   if not has_digits then
      return nil, "malformed number"
   end

   if b == BYTE_i or b == BYTE_I then
      -- LuaJIT complex (imaginary) literal.
      state.offset = state.offset + 1
   elseif not is_float then
      -- 64-bit integer cdata literals can not be fractional.
      state.offset = state.offset + int64_suffix_length(state, b)
   end

   return "number", state.src:get_substring(start, state.offset - 1)
end
484
-- Lexes an identifier or keyword whose first character is at the current offset.
-- A keyword is returned as its own token type with no value. Any other identifier
-- is returned as "name" with its text as the value.
local function lex_ident(state)
   local start = state.offset
   local b

   repeat
      b = next_byte(state)
   until b == nil or not (is_alpha(b) or to_dec(b))

   local ident = state.src:get_substring(start, state.offset - 1)

   if not keywords[ident] then
      return "name", ident
   end

   return ident
end
501
-- Lexes "-" or a comment; the offset is on the first "-".
-- Returns "-" for a lone minus. For "--[=*[" it returns the result of lexing a long
-- comment. Otherwise it returns "short_comment" with the text after "--" up to the
-- end of the line.
local function lex_dash(state)
   if next_byte(state) ~= BYTE_DASH then
      return "-"
   end

   local b = next_byte(state)
   local comment_start = state.offset

   if b == BYTE_OBRACK then
      local level
      b, level = skip_long_bracket(state)

      if b == BYTE_OBRACK then
         return lex_long_string(state, level, "long_comment")
      end
   end

   -- Anything else, including an incomplete long bracket, is a short comment.
   skip_to_newline(state, b)
   return "short_comment", state.src:get_substring(comment_start, state.offset - 1)
end
529
-- Lexes "[" or a long string; the offset is on the "[".
local function lex_bracket(state)
   local b, level = skip_long_bracket(state)

   if b ~= BYTE_OBRACK then
      if level == 0 then
         return "["
      end

      -- "[=" not followed by a second "[".
      return nil, "invalid long string delimiter"
   end

   return lex_long_string(state, level, "string")
end
542
-- Lexes "=" or "=="; the offset is on the first "=".
local function lex_eq(state)
   if next_byte(state) ~= BYTE_EQ then
      return "="
   end

   state.offset = state.offset + 1
   return "=="
end
553
-- Lexes "<", "<=" or "<<"; the offset is on the "<".
local function lex_lt(state)
   local second = next_byte(state)

   if second ~= BYTE_EQ and second ~= BYTE_LT then
      return "<"
   end

   state.offset = state.offset + 1
   return second == BYTE_EQ and "<=" or "<<"
end
567
-- Lexes ">", ">=" or ">>"; the offset is on the ">".
local function lex_gt(state)
   local second = next_byte(state)

   if second ~= BYTE_EQ and second ~= BYTE_GT then
      return ">"
   end

   state.offset = state.offset + 1
   return second == BYTE_EQ and ">=" or ">>"
end
581
-- Lexes "/" or "//" (floor division); the offset is on the first "/".
local function lex_div(state)
   local following = next_byte(state)

   if following == BYTE_SLASH then
      state.offset = state.offset + 1
   end

   return following == BYTE_SLASH and "//" or "/"
end
592
-- Lexes "~=" or a lone "~"; the offset is on the "~".
local function lex_ne(state)
   local is_not_equal = next_byte(state) == BYTE_EQ

   if is_not_equal then
      state.offset = state.offset + 1
      return "~="
   end

   return "~"
end
603
-- Lexes ":" or "::"; the offset is on the first ":".
local function lex_colon(state)
   if next_byte(state) ~= BYTE_COLON then
      return ":"
   end

   state.offset = state.offset + 1
   return "::"
end
614
-- Lexes ".", "..", "..." or a number starting with a dot such as ".5".
-- The offset is on the first ".". "..." is returned with itself as the value.
local function lex_dot(state)
   local b = next_byte(state)

   if b ~= BYTE_DOT then
      if b and to_dec(b) then
         -- Step back so that lex_number starts on the dot.
         state.offset = state.offset - 2
         return lex_number(state, next_byte(state))
      end

      return "."
   end

   if next_byte(state) ~= BYTE_DOT then
      return ".."
   end

   state.offset = state.offset + 1
   return "...", "..."
end
635
-- Fallback handler: consumes one character and returns it as a token named after
-- itself. Code points above 255 are clamped to byte 255.
local function lex_any(state, b)
   state.offset = state.offset + 1
   return schar(b > 255 and 255 or b)
end
645
-- Maps the first byte of a token to the function that lexes it. Bytes without an
-- entry are handled by lex_any.
-- Each handler is called with the lexer state and the first byte, with the offset
-- on that byte. It stops on the character after the token and returns the token
-- type and, optionally, a value associated with the token.
-- On error a handler returns nil, an error message and, optionally, the start of
-- the reported location as a negative offset relative to where it stopped.
local byte_handlers = {
   -- Operators and punctuation.
   [BYTE_DOT] = lex_dot,
   [BYTE_COLON] = lex_colon,
   [BYTE_DASH] = lex_dash,
   [BYTE_SLASH] = lex_div,
   [BYTE_EQ] = lex_eq,
   [BYTE_NE] = lex_ne,
   [BYTE_LT] = lex_lt,
   [BYTE_GT] = lex_gt,
   -- "[" or long strings, and short strings.
   [BYTE_OBRACK] = lex_bracket,
   [BYTE_QUOTE] = lex_short_string,
   [BYTE_DQUOTE] = lex_short_string,
   -- Identifiers may start with "_"; letters are added below.
   [BYTE_LDASH] = lex_ident,
}

-- Numbers start with a digit; identifiers with a letter.
for _, range in ipairs({
   {BYTE_0, BYTE_9, lex_number},
   {BYTE_a, BYTE_z, lex_ident},
   {BYTE_A, BYTE_Z, lex_ident},
}) do
   for b = range[1], range[2] do
      byte_handlers[b] = range[3]
   end
end
677
-- Creates and returns lexer state for source `src`.
-- `line_offsets` and `line_lengths` are optional tables that will be filled with
-- the start offset and length of each line; fresh tables are used when omitted.
-- A first line starting with "#!" is skipped up to (not including) its line break.
function lexer.new_state(src, line_offsets, line_lengths)
   local state = {
      src = src,
      line = 1,
      offset = 1,
      line_offsets = line_offsets or {},
      line_lengths = line_lengths or {},
   }
   state.line_offsets[1] = 1

   local has_shebang = src:get_length() >= 2 and src:get_substring(1, 2) == "#!"

   if has_shebang then
      state.offset = 2
      skip_to_newline(state, next_byte(state))
   end

   return state
end
698
-- Returns the printable source text between `offset` and `end_offset`, wrapped in
-- single quotes, for use in error messages.
-- If the length of line `line` is already known, `end_offset` is clamped to that
-- line's last character, so the quoted text does not run past the line.
function lexer.get_quoted_substring_or_line(state, line, offset, end_offset)
   local length = state.line_lengths[line]

   if length then
      local line_last = state.line_offsets[line] + length - 1

      if end_offset > line_last then
         end_offset = line_last
      end
   end

   local text = state.src:get_printable_substring(offset, end_offset)
   return "'" .. text .. "'"
end
712
-- Looks for the next token starting from state.line, state.offset, skipping
-- whitespace first.
-- Returns the token type, its value (may be nil) and its start location
-- (line, offset). For tokens other than "eof", a fifth value is also returned:
-- false on success.
-- Sets state.line and state.offset to the token end location + 1 and fills
-- state.line_offsets and state.line_lengths.
-- At end of input returns "eof", which is treated as a token of length 1.
-- On error returns nil, the error message, the error location (line, offset) and
-- the error end offset.
function lexer.next_token(state)
   local b = skip_space(state, state.src:get_codepoint(state.offset))
   local token_line, token_offset = state.line, state.offset

   if not b then
      state.offset = token_offset + 1
      state.line_lengths[token_line] = token_offset - state.line_offsets[token_line]
      return "eof", nil, token_line, token_offset
   end

   local handler = byte_handlers[b] or lex_any
   local token, token_value, relative_error_offset = handler(state, b)

   if not relative_error_offset then
      -- Successful tokens fall through here, and so do errors without a location;
      -- such an error spans just the token start.
      return token, token_value, token_line, token_offset, not token and token_offset
   end

   -- The handler reported an error relative to where it stopped.
   local error_offset = state.offset + relative_error_offset
   local error_end_offset = math.min(state.offset, state.src:get_length())
   local quoted = lexer.get_quoted_substring_or_line(state, state.line, error_offset, error_end_offset)
   return nil, token_value .. " " .. quoted, state.line, error_offset, error_end_offset
end
748
749return lexer
750