1package protoparse
2
3import (
4	"bufio"
5	"bytes"
6	"errors"
7	"fmt"
8	"io"
9	"strconv"
10	"strings"
11	"unicode/utf8"
12
13	"github.com/jhump/protoreflect/desc/protoparse/ast"
14)
15
// runeReader wraps a bufio.Reader with an unbounded pushback stack and an
// optional "mark" that records every rune returned since startMark was called.
// The lexer uses the mark to recover the raw text of a token.
type runeReader struct {
	rr     *bufio.Reader
	marked []rune // runes returned since startMark; nil when no mark is active
	unread []rune // pushback stack; the last element is returned next
	err    error  // sticky error from the underlying reader
}

// readRune returns the next rune and its size in bytes. Runes pushed back by
// unreadRune are returned first (most recently pushed first), even after the
// underlying reader has failed, so lookahead done right before EOF can still
// be undone. When no rune is pushed back, an error from the underlying reader
// is remembered and returned again by every later call. While a mark is
// active, each rune returned is appended to it.
//
// The size of a pushed-back rune is recomputed with utf8.RuneLen.
func (rr *runeReader) readRune() (rune, int, error) {
	var (
		r  rune
		sz int
	)
	if n := len(rr.unread); n > 0 {
		r = rr.unread[n-1]
		rr.unread = rr.unread[:n-1]
		sz = utf8.RuneLen(r)
	} else {
		if rr.err != nil {
			return 0, 0, rr.err
		}
		var err error
		r, sz, err = rr.rr.ReadRune()
		if err != nil {
			rr.err = err
			return r, sz, err
		}
	}
	if rr.marked != nil {
		rr.marked = append(rr.marked, r)
	}
	return r, sz, nil
}

// unreadRune pushes r back so that the next readRune returns it. While a mark
// is active, r must be the last rune recorded in the mark and is removed from
// it; anything else is a programming error and panics.
func (rr *runeReader) unreadRune(r rune) {
	if rr.marked != nil {
		if n := len(rr.marked); n == 0 || rr.marked[n-1] != r {
			panic("unread rune is not the same as last marked rune!")
		}
		rr.marked = rr.marked[:len(rr.marked)-1]
	}
	rr.unread = append(rr.unread, r)
}

// startMark begins recording runes, starting with initial (a rune that the
// caller has already read).
func (rr *runeReader) startMark(initial rune) {
	rr.marked = []rune{initial}
}

// endMark stops recording and returns the recorded text ("" if no mark was
// active). Clearing the mark to nil also stops readRune from buffering runes
// that are read before the next startMark.
func (rr *runeReader) endMark() string {
	m := string(rr.marked)
	rr.marked = nil
	return m
}
63
// protoLex turns proto source text into tokens for the parser. Besides the
// input and the current position, it keeps the previously lexed symbol and the
// comments/whitespace seen since then, so comments can be attached to the
// appropriate AST nodes.
type protoLex struct {
	filename string        // file name recorded in every SourcePos produced
	input    *runeReader   // rune source with pushback and raw-text marking
	errs     *errorHandler // lexing errors are reported here
	res      *ast.FileNode // NOTE(review): not used in this file; presumably the parse result set elsewhere — confirm

	// Current position: 0-based line and column (maintained by adjustPos)
	// and offset, which is meant to be the byte offset into the input.
	lineNo int
	colNo  int
	offset int

	prevSym ast.TerminalNode // most recently lexed symbol (recorded by setPrev)
	eof     ast.TerminalNode // rune node produced when the end of input is reached

	// Start position of the token currently being lexed.
	prevLineNo int
	prevColNo  int
	prevOffset int
	comments   []ast.Comment // comments seen since the previous token
	ws         []rune        // whitespace seen since the last token or comment
}
83
84var utf8Bom = []byte{0xEF, 0xBB, 0xBF}
85
86func newLexer(in io.Reader, filename string, errs *errorHandler) *protoLex {
87	br := bufio.NewReader(in)
88
89	// if file has UTF8 byte order marker preface, consume it
90	marker, err := br.Peek(3)
91	if err == nil && bytes.Equal(marker, utf8Bom) {
92		_, _ = br.Discard(3)
93	}
94
95	return &protoLex{
96		input:    &runeReader{rr: br},
97		filename: filename,
98		errs:     errs,
99	}
100}
101
// keywords maps each reserved word of the proto language to the parser token
// code that Lex returns for it. Identifiers not found here are lexed as _NAME.
var keywords = map[string]int{
	"syntax":     _SYNTAX,
	"import":     _IMPORT,
	"weak":       _WEAK,
	"public":     _PUBLIC,
	"package":    _PACKAGE,
	"option":     _OPTION,
	"true":       _TRUE,
	"false":      _FALSE,
	"inf":        _INF,
	"nan":        _NAN,
	"repeated":   _REPEATED,
	"optional":   _OPTIONAL,
	"required":   _REQUIRED,
	"double":     _DOUBLE,
	"float":      _FLOAT,
	"int32":      _INT32,
	"int64":      _INT64,
	"uint32":     _UINT32,
	"uint64":     _UINT64,
	"sint32":     _SINT32,
	"sint64":     _SINT64,
	"fixed32":    _FIXED32,
	"fixed64":    _FIXED64,
	"sfixed32":   _SFIXED32,
	"sfixed64":   _SFIXED64,
	"bool":       _BOOL,
	"string":     _STRING,
	"bytes":      _BYTES,
	"group":      _GROUP,
	"oneof":      _ONEOF,
	"map":        _MAP,
	"extensions": _EXTENSIONS,
	"to":         _TO,
	"max":        _MAX,
	"reserved":   _RESERVED,
	"enum":       _ENUM,
	"message":    _MESSAGE,
	"extend":     _EXTEND,
	"service":    _SERVICE,
	"rpc":        _RPC,
	"stream":     _STREAM,
	"returns":    _RETURNS,
}
146
147func (l *protoLex) cur() SourcePos {
148	return SourcePos{
149		Filename: l.filename,
150		Offset:   l.offset,
151		Line:     l.lineNo + 1,
152		Col:      l.colNo + 1,
153	}
154}
155
156func (l *protoLex) adjustPos(consumedChars ...rune) {
157	for _, c := range consumedChars {
158		switch c {
159		case '\n':
160			// new line, back to first column
161			l.colNo = 0
162			l.lineNo++
163		case '\r':
164			// no adjustment
165		case '\t':
166			// advance to next tab stop
167			mod := l.colNo % 8
168			l.colNo += 8 - mod
169		default:
170			l.colNo++
171		}
172	}
173}
174
175func (l *protoLex) prev() *SourcePos {
176	if l.prevSym == nil {
177		return &SourcePos{
178			Filename: l.filename,
179			Offset:   0,
180			Line:     1,
181			Col:      1,
182		}
183	}
184	return l.prevSym.Start()
185}
186
187func (l *protoLex) Lex(lval *protoSymType) int {
188	if l.errs.err != nil {
189		// if error reporter already returned non-nil error,
190		// we can skip the rest of the input
191		return 0
192	}
193
194	l.prevLineNo = l.lineNo
195	l.prevColNo = l.colNo
196	l.prevOffset = l.offset
197	l.comments = nil
198	l.ws = nil
199	l.input.endMark() // reset, just in case
200
201	for {
202		c, n, err := l.input.readRune()
203		if err == io.EOF {
204			// we're not actually returning a rune, but this will associate
205			// accumulated comments as a trailing comment on last symbol
206			// (if appropriate)
207			l.setRune(lval, 0)
208			l.eof = lval.b
209			return 0
210		} else if err != nil {
211			// we don't call setError because we don't want it wrapped
212			// with a source position because it's I/O, not syntax
213			lval.err = err
214			_ = l.errs.handleError(err)
215			return _ERROR
216		}
217
218		l.prevLineNo = l.lineNo
219		l.prevColNo = l.colNo
220		l.prevOffset = l.offset
221
222		l.offset += n
223		l.adjustPos(c)
224		if strings.ContainsRune("\n\r\t\f\v ", c) {
225			l.ws = append(l.ws, c)
226			continue
227		}
228
229		l.input.startMark(c)
230		if c == '.' {
231			// decimal literals could start with a dot
232			cn, _, err := l.input.readRune()
233			if err != nil {
234				l.setRune(lval, c)
235				return int(c)
236			}
237			if cn >= '0' && cn <= '9' {
238				l.adjustPos(cn)
239				token := l.readNumber(c, cn)
240				f, err := strconv.ParseFloat(token, 64)
241				if err != nil {
242					l.setError(lval, numError(err, "float", token))
243					return _ERROR
244				}
245				l.setFloat(lval, f)
246				return _FLOAT_LIT
247			}
248			l.input.unreadRune(cn)
249			l.setRune(lval, c)
250			return int(c)
251		}
252
253		if c == '_' || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') {
254			// identifier
255			token := []rune{c}
256			token = l.readIdentifier(token)
257			str := string(token)
258			if t, ok := keywords[str]; ok {
259				l.setIdent(lval, str)
260				return t
261			}
262			l.setIdent(lval, str)
263			return _NAME
264		}
265
266		if c >= '0' && c <= '9' {
267			// integer or float literal
268			token := l.readNumber(c)
269			if strings.HasPrefix(token, "0x") || strings.HasPrefix(token, "0X") {
270				// hexadecimal
271				ui, err := strconv.ParseUint(token[2:], 16, 64)
272				if err != nil {
273					l.setError(lval, numError(err, "hexadecimal integer", token[2:]))
274					return _ERROR
275				}
276				l.setInt(lval, ui)
277				return _INT_LIT
278			}
279			if strings.Contains(token, ".") || strings.Contains(token, "e") || strings.Contains(token, "E") {
280				// floating point!
281				f, err := strconv.ParseFloat(token, 64)
282				if err != nil {
283					l.setError(lval, numError(err, "float", token))
284					return _ERROR
285				}
286				l.setFloat(lval, f)
287				return _FLOAT_LIT
288			}
289			// integer! (decimal or octal)
290			ui, err := strconv.ParseUint(token, 0, 64)
291			if err != nil {
292				kind := "integer"
293				if numErr, ok := err.(*strconv.NumError); ok && numErr.Err == strconv.ErrRange {
294					// if it's too big to be an int, parse it as a float
295					var f float64
296					kind = "float"
297					f, err = strconv.ParseFloat(token, 64)
298					if err == nil {
299						l.setFloat(lval, f)
300						return _FLOAT_LIT
301					}
302				}
303				l.setError(lval, numError(err, kind, token))
304				return _ERROR
305			}
306			l.setInt(lval, ui)
307			return _INT_LIT
308		}
309
310		if c == '\'' || c == '"' {
311			// string literal
312			str, err := l.readStringLiteral(c)
313			if err != nil {
314				l.setError(lval, err)
315				return _ERROR
316			}
317			l.setString(lval, str)
318			return _STRING_LIT
319		}
320
321		if c == '/' {
322			// comment
323			cn, _, err := l.input.readRune()
324			if err != nil {
325				l.setRune(lval, '/')
326				return int(c)
327			}
328			if cn == '/' {
329				l.adjustPos(cn)
330				hitNewline := l.skipToEndOfLineComment()
331				comment := l.newComment()
332				comment.PosRange.End.Col++
333				if hitNewline {
334					// we don't do this inside of skipToEndOfLineComment
335					// because we want to know the length of previous
336					// line for calculation above
337					l.adjustPos('\n')
338				}
339				l.comments = append(l.comments, comment)
340				continue
341			}
342			if cn == '*' {
343				l.adjustPos(cn)
344				if ok := l.skipToEndOfBlockComment(); !ok {
345					l.setError(lval, errors.New("block comment never terminates, unexpected EOF"))
346					return _ERROR
347				} else {
348					l.comments = append(l.comments, l.newComment())
349				}
350				continue
351			}
352			l.input.unreadRune(cn)
353		}
354
355		if c > 255 {
356			l.setError(lval, errors.New("invalid character"))
357			return _ERROR
358		}
359		l.setRune(lval, c)
360		return int(c)
361	}
362}
363
364func (l *protoLex) posRange() ast.PosRange {
365	return ast.PosRange{
366		Start: SourcePos{
367			Filename: l.filename,
368			Offset:   l.prevOffset,
369			Line:     l.prevLineNo + 1,
370			Col:      l.prevColNo + 1,
371		},
372		End: l.cur(),
373	}
374}
375
376func (l *protoLex) newComment() ast.Comment {
377	ws := string(l.ws)
378	l.ws = l.ws[:0]
379	return ast.Comment{
380		PosRange:          l.posRange(),
381		LeadingWhitespace: ws,
382		Text:              l.input.endMark(),
383	}
384}
385
386func (l *protoLex) newTokenInfo() ast.TokenInfo {
387	ws := string(l.ws)
388	l.ws = nil
389	return ast.TokenInfo{
390		PosRange:          l.posRange(),
391		LeadingComments:   l.comments,
392		LeadingWhitespace: ws,
393		RawText:           l.input.endMark(),
394	}
395}
396
// setPrev records n as the most recently lexed symbol. Before doing so, it may
// move some of n's leading comments onto the previous symbol as trailing
// comments:
//   - a first comment that starts on the line where the previous symbol ends
//     is always moved;
//   - if the first comment starts on the line right after the previous symbol,
//     the first group of comments is moved when a later comment starts a new
//     group (a blank line between "//" comments, any comment following a "/*"
//     comment, or a "/*" comment following "//" comments); if there is only
//     one group, it is moved when it ends more than one line before n starts.
//
// For rune nodes other than '.', n's start line is treated as two lines later
// than it is, so comments favor the previous symbol over punctuation.
// isDot reports whether n is the '.' rune.
func (l *protoLex) setPrev(n ast.TerminalNode, isDot bool) {
	nStart := n.Start().Line
	if _, ok := n.(*ast.RuneNode); ok {
		// This is really gross, but there are many cases where we don't want
		// to attribute comments to punctuation (like commas, equals, semicolons)
		// and would instead prefer to attribute comments to a more meaningful
		// element in the AST.
		//
		// So if it's a simple node OTHER THAN PERIOD (since that is not just
		// punctuation but typically part of a qualified identifier), don't
		// attribute comments to it. We do that with this TOTAL HACK: adjusting
		// the start line makes leading comments appear detached so logic below
		// will naturally associate trailing comments with the previous symbol.
		if !isDot {
			nStart += 2
		}
	}
	// Only consider re-attribution when n (as adjusted above) starts on a later
	// line than the previous symbol ends and actually has leading comments.
	if l.prevSym != nil && len(n.LeadingComments()) > 0 && l.prevSym.End().Line < nStart {
		// we may need to re-attribute the first comment to
		// instead be previous node's trailing comment
		prevEnd := l.prevSym.End().Line
		comments := n.LeadingComments()
		c := comments[0]
		commentStart := c.Start.Line
		if commentStart == prevEnd {
			// comment is on same line as previous symbol
			n.PopLeadingComment()
			l.prevSym.PushTrailingComment(c)
		} else if commentStart == prevEnd+1 {
			// comment is right after previous symbol; see if it is detached
			// and if so re-attribute
			singleLineStyle := strings.HasPrefix(c.Text, "//")
			line := c.End.Line
			// groupEnd becomes the index of the first comment that starts a
			// second group; comments[0:groupEnd] are then moved.
			groupEnd := -1
			for i := 1; i < len(comments); i++ {
				c := comments[i]
				newGroup := false
				if !singleLineStyle || c.Start.Line > line+1 {
					// we've found a gap between comments, which means the
					// previous comments were detached
					newGroup = true
				} else {
					line = c.End.Line
					singleLineStyle = strings.HasPrefix(comments[i].Text, "//")
					if !singleLineStyle {
						// we've found a switch from // comments to /*
						// consider that a new group which means the
						// previous comments were detached
						newGroup = true
					}
				}
				if newGroup {
					groupEnd = i
					break
				}
			}

			if groupEnd == -1 {
				// just one group of comments; we'll mark it as a trailing
				// comment if it immediately follows previous symbol and is
				// detached from current symbol
				c1 := comments[0]
				c2 := comments[len(comments)-1]
				if c1.Start.Line <= prevEnd+1 && c2.End.Line < nStart-1 {
					groupEnd = len(comments)
				}
			}

			// a groupEnd still at -1 moves nothing
			for i := 0; i < groupEnd; i++ {
				l.prevSym.PushTrailingComment(n.PopLeadingComment())
			}
		}
	}

	l.prevSym = n
}
473
474func (l *protoLex) setString(lval *protoSymType, val string) {
475	lval.s = ast.NewStringLiteralNode(val, l.newTokenInfo())
476	l.setPrev(lval.s, false)
477}
478
479func (l *protoLex) setIdent(lval *protoSymType, val string) {
480	lval.id = ast.NewIdentNode(val, l.newTokenInfo())
481	l.setPrev(lval.id, false)
482}
483
484func (l *protoLex) setInt(lval *protoSymType, val uint64) {
485	lval.i = ast.NewUintLiteralNode(val, l.newTokenInfo())
486	l.setPrev(lval.i, false)
487}
488
489func (l *protoLex) setFloat(lval *protoSymType, val float64) {
490	lval.f = ast.NewFloatLiteralNode(val, l.newTokenInfo())
491	l.setPrev(lval.f, false)
492}
493
494func (l *protoLex) setRune(lval *protoSymType, val rune) {
495	lval.b = ast.NewRuneNode(val, l.newTokenInfo())
496	l.setPrev(lval.b, val == '.')
497}
498
499func (l *protoLex) setError(lval *protoSymType, err error) {
500	lval.err = l.addSourceError(err)
501}
502
503func (l *protoLex) readNumber(sofar ...rune) string {
504	token := sofar
505	allowExpSign := false
506	for {
507		c, _, err := l.input.readRune()
508		if err != nil {
509			break
510		}
511		if (c == '-' || c == '+') && !allowExpSign {
512			l.input.unreadRune(c)
513			break
514		}
515		allowExpSign = false
516		if c != '.' && c != '_' && (c < '0' || c > '9') &&
517			(c < 'a' || c > 'z') && (c < 'A' || c > 'Z') &&
518			c != '-' && c != '+' {
519			// no more chars in the number token
520			l.input.unreadRune(c)
521			break
522		}
523		if c == 'e' || c == 'E' {
524			// scientific notation char can be followed by
525			// an exponent sign
526			allowExpSign = true
527		}
528		l.adjustPos(c)
529		token = append(token, c)
530	}
531	return string(token)
532}
533
// numError converts an error from strconv number parsing into a message about
// a malformed numeric literal. kind names the expected kind of number (for
// example "float") and s is the literal text. A *strconv.NumError becomes an
// "out of range" message for strconv.ErrRange and an "invalid syntax" message
// otherwise; any other error is returned unchanged.
func numError(err error, kind, s string) error {
	ne, isNumErr := err.(*strconv.NumError)
	switch {
	case !isNumErr:
		return err
	case ne.Err == strconv.ErrRange:
		return fmt.Errorf("value out of range for %s: %s", kind, s)
	default:
		// syntax error
		return fmt.Errorf("invalid syntax in %s value: %s", kind, s)
	}
}
545
546func (l *protoLex) readIdentifier(sofar []rune) []rune {
547	token := sofar
548	for {
549		c, _, err := l.input.readRune()
550		if err != nil {
551			break
552		}
553		if c != '_' && (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') && (c < '0' || c > '9') {
554			l.input.unreadRune(c)
555			break
556		}
557		l.adjustPos(c)
558		token = append(token, c)
559	}
560	return token
561}
562
563func (l *protoLex) readStringLiteral(quote rune) (string, error) {
564	var buf bytes.Buffer
565	for {
566		c, _, err := l.input.readRune()
567		if err != nil {
568			if err == io.EOF {
569				err = io.ErrUnexpectedEOF
570			}
571			return "", err
572		}
573		if c == '\n' {
574			return "", errors.New("encountered end-of-line before end of string literal")
575		}
576		l.adjustPos(c)
577		if c == quote {
578			break
579		}
580		if c == 0 {
581			return "", errors.New("null character ('\\0') not allowed in string literal")
582		}
583		if c == '\\' {
584			// escape sequence
585			c, _, err = l.input.readRune()
586			if err != nil {
587				return "", err
588			}
589			l.adjustPos(c)
590			if c == 'x' || c == 'X' {
591				// hex escape
592				c, _, err := l.input.readRune()
593				if err != nil {
594					return "", err
595				}
596				l.adjustPos(c)
597				c2, _, err := l.input.readRune()
598				if err != nil {
599					return "", err
600				}
601				var hex string
602				if (c2 < '0' || c2 > '9') && (c2 < 'a' || c2 > 'f') && (c2 < 'A' || c2 > 'F') {
603					l.input.unreadRune(c2)
604					hex = string(c)
605				} else {
606					l.adjustPos(c2)
607					hex = string([]rune{c, c2})
608				}
609				i, err := strconv.ParseInt(hex, 16, 32)
610				if err != nil {
611					return "", fmt.Errorf("invalid hex escape: \\x%q", hex)
612				}
613				buf.WriteByte(byte(i))
614
615			} else if c >= '0' && c <= '7' {
616				// octal escape
617				c2, _, err := l.input.readRune()
618				if err != nil {
619					return "", err
620				}
621				var octal string
622				if c2 < '0' || c2 > '7' {
623					l.input.unreadRune(c2)
624					octal = string(c)
625				} else {
626					l.adjustPos(c2)
627					c3, _, err := l.input.readRune()
628					if err != nil {
629						return "", err
630					}
631					if c3 < '0' || c3 > '7' {
632						l.input.unreadRune(c3)
633						octal = string([]rune{c, c2})
634					} else {
635						l.adjustPos(c3)
636						octal = string([]rune{c, c2, c3})
637					}
638				}
639				i, err := strconv.ParseInt(octal, 8, 32)
640				if err != nil {
641					return "", fmt.Errorf("invalid octal escape: \\%q", octal)
642				}
643				if i > 0xff {
644					return "", fmt.Errorf("octal escape is out range, must be between 0 and 377: \\%q", octal)
645				}
646				buf.WriteByte(byte(i))
647
648			} else if c == 'u' {
649				// short unicode escape
650				u := make([]rune, 4)
651				for i := range u {
652					c, _, err := l.input.readRune()
653					if err != nil {
654						return "", err
655					}
656					l.adjustPos(c)
657					u[i] = c
658				}
659				i, err := strconv.ParseInt(string(u), 16, 32)
660				if err != nil {
661					return "", fmt.Errorf("invalid unicode escape: \\u%q", string(u))
662				}
663				buf.WriteRune(rune(i))
664
665			} else if c == 'U' {
666				// long unicode escape
667				u := make([]rune, 8)
668				for i := range u {
669					c, _, err := l.input.readRune()
670					if err != nil {
671						return "", err
672					}
673					l.adjustPos(c)
674					u[i] = c
675				}
676				i, err := strconv.ParseInt(string(u), 16, 32)
677				if err != nil {
678					return "", fmt.Errorf("invalid unicode escape: \\U%q", string(u))
679				}
680				if i > 0x10ffff || i < 0 {
681					return "", fmt.Errorf("unicode escape is out of range, must be between 0 and 0x10ffff: \\U%q", string(u))
682				}
683				buf.WriteRune(rune(i))
684
685			} else if c == 'a' {
686				buf.WriteByte('\a')
687			} else if c == 'b' {
688				buf.WriteByte('\b')
689			} else if c == 'f' {
690				buf.WriteByte('\f')
691			} else if c == 'n' {
692				buf.WriteByte('\n')
693			} else if c == 'r' {
694				buf.WriteByte('\r')
695			} else if c == 't' {
696				buf.WriteByte('\t')
697			} else if c == 'v' {
698				buf.WriteByte('\v')
699			} else if c == '\\' {
700				buf.WriteByte('\\')
701			} else if c == '\'' {
702				buf.WriteByte('\'')
703			} else if c == '"' {
704				buf.WriteByte('"')
705			} else if c == '?' {
706				buf.WriteByte('?')
707			} else {
708				return "", fmt.Errorf("invalid escape sequence: %q", "\\"+string(c))
709			}
710		} else {
711			buf.WriteRune(c)
712		}
713	}
714	return buf.String(), nil
715}
716
717func (l *protoLex) skipToEndOfLineComment() bool {
718	for {
719		c, _, err := l.input.readRune()
720		if err != nil {
721			return false
722		}
723		if c == '\n' {
724			return true
725		}
726		l.adjustPos(c)
727	}
728}
729
730func (l *protoLex) skipToEndOfBlockComment() bool {
731	for {
732		c, _, err := l.input.readRune()
733		if err != nil {
734			return false
735		}
736		l.adjustPos(c)
737		if c == '*' {
738			c, _, err := l.input.readRune()
739			if err != nil {
740				return false
741			}
742			if c == '/' {
743				l.adjustPos(c)
744				return true
745			}
746			l.input.unreadRune(c)
747		}
748	}
749}
750
751func (l *protoLex) addSourceError(err error) ErrorWithPos {
752	ewp, ok := err.(ErrorWithPos)
753	if !ok {
754		ewp = ErrorWithSourcePos{Pos: l.prev(), Underlying: err}
755	}
756	_ = l.errs.handleError(ewp)
757	return ewp
758}
759
760func (l *protoLex) Error(s string) {
761	_ = l.addSourceError(errors.New(s))
762}
763