1package parse
2
3import (
4	"bufio"
5	"bytes"
6	"fmt"
7	"github.com/yuin/gopher-lua/ast"
8	"io"
9	"reflect"
10	"strconv"
11	"strings"
12)
13
const (
	// EOF is the value the scanner's read helpers return at end of input. It
	// is also stored as the Line of a position past the end of input, and as
	// the Type of the end-of-input token.
	EOF = -1

	// whitespace1 and whitespace2 are bitsets indexed by character code, used
	// by skipWhiteSpace: whitespace1 holds tab and space, whitespace2 adds
	// '\n' and '\r'.
	whitespace1 = 1<<'\t' | 1<<' '
	whitespace2 = 1<<'\t' | 1<<'\n' | 1<<'\r' | 1<<' '
)
17
18type Error struct {
19	Pos     ast.Position
20	Message string
21	Token   string
22}
23
24func (e *Error) Error() string {
25	pos := e.Pos
26	if pos.Line == EOF {
27		return fmt.Sprintf("%v at EOF:   %s\n", pos.Source, e.Message)
28	} else {
29		return fmt.Sprintf("%v line:%d(column:%d) near '%v':   %s\n", pos.Source, pos.Line, pos.Column, e.Token, e.Message)
30	}
31}
32
// writeChar appends the low 8 bits of c to buf.
func writeChar(buf *bytes.Buffer, c int) {
	b := byte(c)
	buf.WriteByte(b)
}
34
// isDecimal reports whether ch is an ASCII decimal digit.
func isDecimal(ch int) bool {
	return ch >= '0' && ch <= '9'
}

// isIdent reports whether ch may appear at offset pos of an identifier:
// letters and '_' anywhere, decimal digits only after the first character.
func isIdent(ch int, pos int) bool {
	switch {
	case ch == '_', ch >= 'a' && ch <= 'z', ch >= 'A' && ch <= 'Z':
		return true
	default:
		return pos > 0 && isDecimal(ch)
	}
}

// isDigit reports whether ch is a hexadecimal digit (0-9, a-f, A-F).
func isDigit(ch int) bool {
	switch {
	case isDecimal(ch):
		return true
	case ch >= 'a' && ch <= 'f', ch >= 'A' && ch <= 'F':
		return true
	}
	return false
}
44
// Scanner is the byte-oriented tokenizer for Lua source. It reads through a
// buffered reader and tracks the current source position.
type Scanner struct {
	Pos    ast.Position  // advanced by Next; Pos.Line is set to EOF at end of input
	reader *bufio.Reader // buffered input
}
49
50func NewScanner(reader io.Reader, source string) *Scanner {
51	return &Scanner{
52		Pos: ast.Position{
53			Source: source,
54			Line:   1,
55			Column: 0,
56		},
57		reader: bufio.NewReaderSize(reader, 4096),
58	}
59}
60
61func (sc *Scanner) Error(tok string, msg string) *Error { return &Error{sc.Pos, msg, tok} }
62
63func (sc *Scanner) TokenError(tok ast.Token, msg string) *Error { return &Error{tok.Pos, msg, tok.Str} }
64
65func (sc *Scanner) readNext() int {
66	ch, err := sc.reader.ReadByte()
67	if err == io.EOF {
68		return EOF
69	}
70	return int(ch)
71}
72
73func (sc *Scanner) Newline(ch int) {
74	if ch < 0 {
75		return
76	}
77	sc.Pos.Line += 1
78	sc.Pos.Column = 0
79	next := sc.Peek()
80	if ch == '\n' && next == '\r' || ch == '\r' && next == '\n' {
81		sc.reader.ReadByte()
82	}
83}
84
85func (sc *Scanner) Next() int {
86	ch := sc.readNext()
87	switch ch {
88	case '\n', '\r':
89		sc.Newline(ch)
90		ch = int('\n')
91	case EOF:
92		sc.Pos.Line = EOF
93		sc.Pos.Column = 0
94	default:
95		sc.Pos.Column++
96	}
97	return ch
98}
99
100func (sc *Scanner) Peek() int {
101	ch := sc.readNext()
102	if ch != EOF {
103		sc.reader.UnreadByte()
104	}
105	return ch
106}
107
108func (sc *Scanner) skipWhiteSpace(whitespace int64) int {
109	ch := sc.Next()
110	for ; whitespace&(1<<uint(ch)) != 0; ch = sc.Next() {
111	}
112	return ch
113}
114
115func (sc *Scanner) skipComments(ch int) error {
116	// multiline comment
117	if sc.Peek() == '[' {
118		ch = sc.Next()
119		if sc.Peek() == '[' || sc.Peek() == '=' {
120			var buf bytes.Buffer
121			if err := sc.scanMultilineString(sc.Next(), &buf); err != nil {
122				return sc.Error(buf.String(), "invalid multiline comment")
123			}
124			return nil
125		}
126	}
127	for {
128		if ch == '\n' || ch == '\r' || ch < 0 {
129			break
130		}
131		ch = sc.Next()
132	}
133	return nil
134}
135
136func (sc *Scanner) scanIdent(ch int, buf *bytes.Buffer) error {
137	writeChar(buf, ch)
138	for isIdent(sc.Peek(), 1) {
139		writeChar(buf, sc.Next())
140	}
141	return nil
142}
143
144func (sc *Scanner) scanDecimal(ch int, buf *bytes.Buffer) error {
145	writeChar(buf, ch)
146	for isDecimal(sc.Peek()) {
147		writeChar(buf, sc.Next())
148	}
149	return nil
150}
151
152func (sc *Scanner) scanNumber(ch int, buf *bytes.Buffer) error {
153	if ch == '0' { // octal
154		if sc.Peek() == 'x' || sc.Peek() == 'X' {
155			writeChar(buf, ch)
156			writeChar(buf, sc.Next())
157			hasvalue := false
158			for isDigit(sc.Peek()) {
159				writeChar(buf, sc.Next())
160				hasvalue = true
161			}
162			if !hasvalue {
163				return sc.Error(buf.String(), "illegal hexadecimal number")
164			}
165			return nil
166		} else if sc.Peek() != '.' && isDecimal(sc.Peek()) {
167			ch = sc.Next()
168		}
169	}
170	sc.scanDecimal(ch, buf)
171	if sc.Peek() == '.' {
172		sc.scanDecimal(sc.Next(), buf)
173	}
174	if ch = sc.Peek(); ch == 'e' || ch == 'E' {
175		writeChar(buf, sc.Next())
176		if ch = sc.Peek(); ch == '-' || ch == '+' {
177			writeChar(buf, sc.Next())
178		}
179		sc.scanDecimal(sc.Next(), buf)
180	}
181
182	return nil
183}
184
185func (sc *Scanner) scanString(quote int, buf *bytes.Buffer) error {
186	ch := sc.Next()
187	for ch != quote {
188		if ch == '\n' || ch == '\r' || ch < 0 {
189			return sc.Error(buf.String(), "unterminated string")
190		}
191		if ch == '\\' {
192			if err := sc.scanEscape(ch, buf); err != nil {
193				return err
194			}
195		} else {
196			writeChar(buf, ch)
197		}
198		ch = sc.Next()
199	}
200	return nil
201}
202
203func (sc *Scanner) scanEscape(ch int, buf *bytes.Buffer) error {
204	ch = sc.Next()
205	switch ch {
206	case 'a':
207		buf.WriteByte('\a')
208	case 'b':
209		buf.WriteByte('\b')
210	case 'f':
211		buf.WriteByte('\f')
212	case 'n':
213		buf.WriteByte('\n')
214	case 'r':
215		buf.WriteByte('\r')
216	case 't':
217		buf.WriteByte('\t')
218	case 'v':
219		buf.WriteByte('\v')
220	case '\\':
221		buf.WriteByte('\\')
222	case '"':
223		buf.WriteByte('"')
224	case '\'':
225		buf.WriteByte('\'')
226	case '\n':
227		buf.WriteByte('\n')
228	case '\r':
229		buf.WriteByte('\n')
230		sc.Newline('\r')
231	default:
232		if '0' <= ch && ch <= '9' {
233			bytes := []byte{byte(ch)}
234			for i := 0; i < 2 && isDecimal(sc.Peek()); i++ {
235				bytes = append(bytes, byte(sc.Next()))
236			}
237			val, _ := strconv.ParseInt(string(bytes), 10, 32)
238			writeChar(buf, int(val))
239		} else {
240			writeChar(buf, ch)
241		}
242	}
243	return nil
244}
245
246func (sc *Scanner) countSep(ch int) (int, int) {
247	count := 0
248	for ; ch == '='; count = count + 1 {
249		ch = sc.Next()
250	}
251	return count, ch
252}
253
254func (sc *Scanner) scanMultilineString(ch int, buf *bytes.Buffer) error {
255	var count1, count2 int
256	count1, ch = sc.countSep(ch)
257	if ch != '[' {
258		return sc.Error(string(ch), "invalid multiline string")
259	}
260	ch = sc.Next()
261	if ch == '\n' || ch == '\r' {
262		ch = sc.Next()
263	}
264	for {
265		if ch < 0 {
266			return sc.Error(buf.String(), "unterminated multiline string")
267		} else if ch == ']' {
268			count2, ch = sc.countSep(sc.Next())
269			if count1 == count2 && ch == ']' {
270				goto finally
271			}
272			buf.WriteByte(']')
273			buf.WriteString(strings.Repeat("=", count2))
274			continue
275		}
276		writeChar(buf, ch)
277		ch = sc.Next()
278	}
279
280finally:
281	return nil
282}
283
284var reservedWords = map[string]int{
285	"and": TAnd, "break": TBreak, "do": TDo, "else": TElse, "elseif": TElseIf,
286	"end": TEnd, "false": TFalse, "for": TFor, "function": TFunction,
287	"if": TIf, "in": TIn, "local": TLocal, "nil": TNil, "not": TNot, "or": TOr,
288	"return": TReturn, "repeat": TRepeat, "then": TThen, "true": TTrue,
289	"until": TUntil, "while": TWhile}
290
291func (sc *Scanner) Scan(lexer *Lexer) (ast.Token, error) {
292redo:
293	var err error
294	tok := ast.Token{}
295	newline := false
296
297	ch := sc.skipWhiteSpace(whitespace1)
298	if ch == '\n' || ch == '\r' {
299		newline = true
300		ch = sc.skipWhiteSpace(whitespace2)
301	}
302
303	if ch == '(' && lexer.PrevTokenType == ')' {
304		lexer.PNewLine = newline
305	} else {
306		lexer.PNewLine = false
307	}
308
309	var _buf bytes.Buffer
310	buf := &_buf
311	tok.Pos = sc.Pos
312
313	switch {
314	case isIdent(ch, 0):
315		tok.Type = TIdent
316		err = sc.scanIdent(ch, buf)
317		tok.Str = buf.String()
318		if err != nil {
319			goto finally
320		}
321		if typ, ok := reservedWords[tok.Str]; ok {
322			tok.Type = typ
323		}
324	case isDecimal(ch):
325		tok.Type = TNumber
326		err = sc.scanNumber(ch, buf)
327		tok.Str = buf.String()
328	default:
329		switch ch {
330		case EOF:
331			tok.Type = EOF
332		case '-':
333			if sc.Peek() == '-' {
334				err = sc.skipComments(sc.Next())
335				if err != nil {
336					goto finally
337				}
338				goto redo
339			} else {
340				tok.Type = ch
341				tok.Str = string(ch)
342			}
343		case '"', '\'':
344			tok.Type = TString
345			err = sc.scanString(ch, buf)
346			tok.Str = buf.String()
347		case '[':
348			if c := sc.Peek(); c == '[' || c == '=' {
349				tok.Type = TString
350				err = sc.scanMultilineString(sc.Next(), buf)
351				tok.Str = buf.String()
352			} else {
353				tok.Type = ch
354				tok.Str = string(ch)
355			}
356		case '=':
357			if sc.Peek() == '=' {
358				tok.Type = TEqeq
359				tok.Str = "=="
360				sc.Next()
361			} else {
362				tok.Type = ch
363				tok.Str = string(ch)
364			}
365		case '~':
366			if sc.Peek() == '=' {
367				tok.Type = TNeq
368				tok.Str = "~="
369				sc.Next()
370			} else {
371				err = sc.Error("~", "Invalid '~' token")
372			}
373		case '<':
374			if sc.Peek() == '=' {
375				tok.Type = TLte
376				tok.Str = "<="
377				sc.Next()
378			} else {
379				tok.Type = ch
380				tok.Str = string(ch)
381			}
382		case '>':
383			if sc.Peek() == '=' {
384				tok.Type = TGte
385				tok.Str = ">="
386				sc.Next()
387			} else {
388				tok.Type = ch
389				tok.Str = string(ch)
390			}
391		case '.':
392			ch2 := sc.Peek()
393			switch {
394			case isDecimal(ch2):
395				tok.Type = TNumber
396				err = sc.scanNumber(ch, buf)
397				tok.Str = buf.String()
398			case ch2 == '.':
399				writeChar(buf, ch)
400				writeChar(buf, sc.Next())
401				if sc.Peek() == '.' {
402					writeChar(buf, sc.Next())
403					tok.Type = T3Comma
404				} else {
405					tok.Type = T2Comma
406				}
407			default:
408				tok.Type = '.'
409			}
410			tok.Str = buf.String()
411		case '+', '*', '/', '%', '^', '#', '(', ')', '{', '}', ']', ';', ':', ',':
412			tok.Type = ch
413			tok.Str = string(ch)
414		default:
415			writeChar(buf, ch)
416			err = sc.Error(buf.String(), "Invalid token")
417			goto finally
418		}
419	}
420
421finally:
422	tok.Name = TokenName(int(tok.Type))
423	return tok, err
424}
425
426// yacc interface {{{
427
// Lexer adapts Scanner to the yacc-generated parser (yyParse): the parser
// pulls tokens through Lex and reports syntax errors through Error.
type Lexer struct {
	scanner       *Scanner   // token source
	Stmts         []ast.Stmt // parse result, returned by Parse after yyParse (presumably set by grammar actions)
	PNewLine      bool       // set by Scanner.Scan: '(' directly after ')' but on a new line
	Token         ast.Token  // last token returned by Lex; its text is quoted by Error
	PrevTokenType int        // type of the token before Token, saved at the start of Lex
}
435
436func (lx *Lexer) Lex(lval *yySymType) int {
437	lx.PrevTokenType = lx.Token.Type
438	tok, err := lx.scanner.Scan(lx)
439	if err != nil {
440		panic(err)
441	}
442	if tok.Type < 0 {
443		return 0
444	}
445	lval.token = tok
446	lx.Token = tok
447	return int(tok.Type)
448}
449
450func (lx *Lexer) Error(message string) {
451	panic(lx.scanner.Error(lx.Token.Str, message))
452}
453
454func (lx *Lexer) TokenError(tok ast.Token, message string) {
455	panic(lx.scanner.TokenError(tok, message))
456}
457
458func Parse(reader io.Reader, name string) (chunk []ast.Stmt, err error) {
459	lexer := &Lexer{NewScanner(reader, name), nil, false, ast.Token{Str: ""}, TNil}
460	chunk = nil
461	defer func() {
462		if e := recover(); e != nil {
463			err, _ = e.(error)
464		}
465	}()
466	yyParse(lexer)
467	chunk = lexer.Stmts
468	return
469}
470
471// }}}
472
473// Dump {{{
474
// isInlineDumpNode reports whether dump prints rv on the same line as its
// field name. Structs, slices, interfaces and pointers are expanded onto the
// following lines instead.
func isInlineDumpNode(rv reflect.Value) bool {
	switch rv.Kind() {
	case reflect.Struct, reflect.Slice, reflect.Interface, reflect.Ptr:
		return false
	default:
		return true
	}
}

// dump renders node as an indented tree, prefixing lines with level copies
// of s.
//
// Slices are dumped element by element. Pointers to structs print as
// "- Node$<TypeName>" followed by their fields; fields whose name contains
// "Base" are skipped. Nil values print "<nil>", empty slices and field-less
// structs print "<empty>", and everything else uses fmt.Sprint. Struct fields
// must be exported, because their values are read via reflect.Value.Interface.
func dump(node interface{}, level int, s string) string {
	indent := strings.Repeat(s, level)
	rt := reflect.TypeOf(node)
	if rt == nil {
		return indent + "<nil>"
	}

	rv := reflect.ValueOf(node)
	var buf []string
	switch rt.Kind() {
	case reflect.Slice:
		if rv.Len() == 0 {
			return indent + "<empty>"
		}
		for i := 0; i < rv.Len(); i++ {
			buf = append(buf, dump(rv.Index(i).Interface(), level, s))
		}
	case reflect.Ptr:
		// A typed nil pointer (e.g. an unset optional child node) has no
		// fields to read; vt.Field would panic on the zero Value.
		if rv.IsNil() {
			return indent + "<nil>"
		}
		vt := rv.Elem()
		tt := rt.Elem()
		if tt.Kind() != reflect.Struct {
			// NumField only works on structs; dump what the pointer refers to.
			return dump(vt.Interface(), level, s)
		}
		var indices []int
		for i := 0; i < tt.NumField(); i++ {
			if strings.Contains(tt.Field(i).Name, "Base") {
				continue
			}
			indices = append(indices, i)
		}
		switch {
		case len(indices) == 0:
			return indent + "<empty>"
		case len(indices) == 1 && isInlineDumpNode(vt.Field(indices[0])):
			i := indices[0]
			buf = append(buf, indent+"- Node$"+tt.Name()+": "+dump(vt.Field(i).Interface(), 0, s))
		default:
			buf = append(buf, indent+"- Node$"+tt.Name())
			child := indent + s
			for _, i := range indices {
				name := tt.Field(i).Name
				if isInlineDumpNode(vt.Field(i)) {
					buf = append(buf, child+name+": "+dump(vt.Field(i).Interface(), 0, s))
				} else {
					buf = append(buf, child+name+": ")
					buf = append(buf, dump(vt.Field(i).Interface(), level+2, s))
				}
			}
		}
	default:
		buf = append(buf, indent+fmt.Sprint(node))
	}
	return strings.Join(buf, "\n")
}
534
535func Dump(chunk []ast.Stmt) string {
536	return dump(chunk, 0, "   ")
537}
538
// }}}
540