1package protoparse 2 3import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "strconv" 10 "strings" 11 "unicode/utf8" 12 13 "github.com/jhump/protoreflect/desc/protoparse/ast" 14) 15 16type runeReader struct { 17 rr *bufio.Reader 18 marked []rune 19 unread []rune 20 err error 21} 22 23func (rr *runeReader) readRune() (r rune, size int, err error) { 24 if rr.err != nil { 25 return 0, 0, rr.err 26 } 27 if len(rr.unread) > 0 { 28 r := rr.unread[len(rr.unread)-1] 29 rr.unread = rr.unread[:len(rr.unread)-1] 30 if rr.marked != nil { 31 rr.marked = append(rr.marked, r) 32 } 33 return r, utf8.RuneLen(r), nil 34 } 35 r, sz, err := rr.rr.ReadRune() 36 if err != nil { 37 rr.err = err 38 } else if rr.marked != nil { 39 rr.marked = append(rr.marked, r) 40 } 41 return r, sz, err 42} 43 44func (rr *runeReader) unreadRune(r rune) { 45 if rr.marked != nil { 46 if rr.marked[len(rr.marked)-1] != r { 47 panic("unread rune is not the same as last marked rune!") 48 } 49 rr.marked = rr.marked[:len(rr.marked)-1] 50 } 51 rr.unread = append(rr.unread, r) 52} 53 54func (rr *runeReader) startMark(initial rune) { 55 rr.marked = []rune{initial} 56} 57 58func (rr *runeReader) endMark() string { 59 m := string(rr.marked) 60 rr.marked = rr.marked[:0] 61 return m 62} 63 64type protoLex struct { 65 filename string 66 input *runeReader 67 errs *errorHandler 68 res *ast.FileNode 69 70 lineNo int 71 colNo int 72 offset int 73 74 prevSym ast.TerminalNode 75 eof ast.TerminalNode 76 77 prevLineNo int 78 prevColNo int 79 prevOffset int 80 comments []ast.Comment 81 ws []rune 82} 83 84var utf8Bom = []byte{0xEF, 0xBB, 0xBF} 85 86func newLexer(in io.Reader, filename string, errs *errorHandler) *protoLex { 87 br := bufio.NewReader(in) 88 89 // if file has UTF8 byte order marker preface, consume it 90 marker, err := br.Peek(3) 91 if err == nil && bytes.Equal(marker, utf8Bom) { 92 _, _ = br.Discard(3) 93 } 94 95 return &protoLex{ 96 input: &runeReader{rr: br}, 97 filename: filename, 98 errs: errs, 99 } 100} 101 102var keywords = map[string]int{ 103 "syntax": _SYNTAX, 104 "import": _IMPORT, 105 "weak": _WEAK, 106 "public": _PUBLIC, 107 "package": _PACKAGE, 108 "option": _OPTION, 109 "true": _TRUE, 110 "false": _FALSE, 111 "inf": _INF, 112 "nan": _NAN, 113 "repeated": _REPEATED, 114 "optional": _OPTIONAL, 115 "required": _REQUIRED, 116 "double": _DOUBLE, 117 "float": _FLOAT, 118 "int32": _INT32, 119 "int64": _INT64, 120 "uint32": _UINT32, 121 "uint64": _UINT64, 122 "sint32": _SINT32, 123 "sint64": _SINT64, 124 "fixed32": _FIXED32, 125 "fixed64": _FIXED64, 126 "sfixed32": _SFIXED32, 127 "sfixed64": _SFIXED64, 128 "bool": _BOOL, 129 "string": _STRING, 130 "bytes": _BYTES, 131 "group": _GROUP, 132 "oneof": _ONEOF, 133 "map": _MAP, 134 "extensions": _EXTENSIONS, 135 "to": _TO, 136 "max": _MAX, 137 "reserved": _RESERVED, 138 "enum": _ENUM, 139 "message": _MESSAGE, 140 "extend": _EXTEND, 141 "service": _SERVICE, 142 "rpc": _RPC, 143 "stream": _STREAM, 144 "returns": _RETURNS, 145} 146 147func (l *protoLex) cur() SourcePos { 148 return SourcePos{ 149 Filename: l.filename, 150 Offset: l.offset, 151 Line: l.lineNo + 1, 152 Col: l.colNo + 1, 153 } 154} 155 156func (l *protoLex) adjustPos(consumedChars ...rune) { 157 for _, c := range consumedChars { 158 switch c { 159 case '\n': 160 // new line, back to first column 161 l.colNo = 0 162 l.lineNo++ 163 case '\r': 164 // no adjustment 165 case '\t': 166 // advance to next tab stop 167 mod := l.colNo % 8 168 l.colNo += 8 - mod 169 default: 170 l.colNo++ 171 } 172 } 173} 174 175func (l *protoLex) prev() *SourcePos { 176 if l.prevSym == nil { 177 return &SourcePos{ 178 Filename: l.filename, 179 Offset: 0, 180 Line: 1, 181 Col: 1, 182 } 183 } 184 return l.prevSym.Start() 185} 186 187func (l *protoLex) Lex(lval *protoSymType) int { 188 if l.errs.err != nil { 189 // if error reporter already returned non-nil error, 190 // we can skip the rest of the input 191 return 0 192 } 193 194 l.prevLineNo = l.lineNo 195 l.prevColNo = l.colNo 196 l.prevOffset = l.offset 197 l.comments = nil 198 l.ws = nil 199 l.input.endMark() // reset, just in case 200 201 for { 202 c, n, err := l.input.readRune() 203 if err == io.EOF { 204 // we're not actually returning a rune, but this will associate 205 // accumulated comments as a trailing comment on last symbol 206 // (if appropriate) 207 l.setRune(lval, 0) 208 l.eof = lval.b 209 return 0 210 } else if err != nil { 211 // we don't call setError because we don't want it wrapped 212 // with a source position because it's I/O, not syntax 213 lval.err = err 214 _ = l.errs.handleError(err) 215 return _ERROR 216 } 217 218 l.prevLineNo = l.lineNo 219 l.prevColNo = l.colNo 220 l.prevOffset = l.offset 221 222 l.offset += n 223 l.adjustPos(c) 224 if strings.ContainsRune("\n\r\t\f\v ", c) { 225 l.ws = append(l.ws, c) 226 continue 227 } 228 229 l.input.startMark(c) 230 if c == '.' { 231 // decimal literals could start with a dot 232 cn, _, err := l.input.readRune() 233 if err != nil { 234 l.setRune(lval, c) 235 return int(c) 236 } 237 if cn >= '0' && cn <= '9' { 238 l.adjustPos(cn) 239 token := l.readNumber(c, cn) 240 f, err := strconv.ParseFloat(token, 64) 241 if err != nil { 242 l.setError(lval, numError(err, "float", token)) 243 return _ERROR 244 } 245 l.setFloat(lval, f) 246 return _FLOAT_LIT 247 } 248 l.input.unreadRune(cn) 249 l.setRune(lval, c) 250 return int(c) 251 } 252 253 if c == '_' || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') { 254 // identifier 255 token := []rune{c} 256 token = l.readIdentifier(token) 257 str := string(token) 258 if t, ok := keywords[str]; ok { 259 l.setIdent(lval, str) 260 return t 261 } 262 l.setIdent(lval, str) 263 return _NAME 264 } 265 266 if c >= '0' && c <= '9' { 267 // integer or float literal 268 token := l.readNumber(c) 269 if strings.HasPrefix(token, "0x") || strings.HasPrefix(token, "0X") { 270 // hexadecimal 271 ui, err := strconv.ParseUint(token[2:], 16, 64) 272 if err != nil { 273 l.setError(lval, numError(err, "hexadecimal integer", token[2:])) 274 return _ERROR 275 } 276 l.setInt(lval, ui) 277 return _INT_LIT 278 } 279 if strings.Contains(token, ".") || strings.Contains(token, "e") || strings.Contains(token, "E") { 280 // floating point! 281 f, err := strconv.ParseFloat(token, 64) 282 if err != nil { 283 l.setError(lval, numError(err, "float", token)) 284 return _ERROR 285 } 286 l.setFloat(lval, f) 287 return _FLOAT_LIT 288 } 289 // integer! (decimal or octal) 290 ui, err := strconv.ParseUint(token, 0, 64) 291 if err != nil { 292 kind := "integer" 293 if numErr, ok := err.(*strconv.NumError); ok && numErr.Err == strconv.ErrRange { 294 // if it's too big to be an int, parse it as a float 295 var f float64 296 kind = "float" 297 f, err = strconv.ParseFloat(token, 64) 298 if err == nil { 299 l.setFloat(lval, f) 300 return _FLOAT_LIT 301 } 302 } 303 l.setError(lval, numError(err, kind, token)) 304 return _ERROR 305 } 306 l.setInt(lval, ui) 307 return _INT_LIT 308 } 309 310 if c == '\'' || c == '"' { 311 // string literal 312 str, err := l.readStringLiteral(c) 313 if err != nil { 314 l.setError(lval, err) 315 return _ERROR 316 } 317 l.setString(lval, str) 318 return _STRING_LIT 319 } 320 321 if c == '/' { 322 // comment 323 cn, _, err := l.input.readRune() 324 if err != nil { 325 l.setRune(lval, '/') 326 return int(c) 327 } 328 if cn == '/' { 329 l.adjustPos(cn) 330 hitNewline := l.skipToEndOfLineComment() 331 comment := l.newComment() 332 comment.PosRange.End.Col++ 333 if hitNewline { 334 // we don't do this inside of skipToEndOfLineComment 335 // because we want to know the length of previous 336 // line for calculation above 337 l.adjustPos('\n') 338 } 339 l.comments = append(l.comments, comment) 340 continue 341 } 342 if cn == '*' { 343 l.adjustPos(cn) 344 if ok := l.skipToEndOfBlockComment(); !ok { 345 l.setError(lval, errors.New("block comment never terminates, unexpected EOF")) 346 return _ERROR 347 } else { 348 l.comments = append(l.comments, l.newComment()) 349 } 350 continue 351 } 352 l.input.unreadRune(cn) 353 } 354 355 if c > 255 { 356 l.setError(lval, errors.New("invalid character")) 357 return _ERROR 358 } 359 l.setRune(lval, c) 360 return int(c) 361 } 362} 363 364func (l *protoLex) posRange() ast.PosRange { 365 return ast.PosRange{ 366 Start: SourcePos{ 367 Filename: l.filename, 368 Offset: l.prevOffset, 369 Line: l.prevLineNo + 1, 370 Col: l.prevColNo + 1, 371 }, 372 End: l.cur(), 373 } 374} 375 376func (l *protoLex) newComment() ast.Comment { 377 ws := string(l.ws) 378 l.ws = l.ws[:0] 379 return ast.Comment{ 380 PosRange: l.posRange(), 381 LeadingWhitespace: ws, 382 Text: l.input.endMark(), 383 } 384} 385 386func (l *protoLex) newTokenInfo() ast.TokenInfo { 387 ws := string(l.ws) 388 l.ws = nil 389 return ast.TokenInfo{ 390 PosRange: l.posRange(), 391 LeadingComments: l.comments, 392 LeadingWhitespace: ws, 393 RawText: l.input.endMark(), 394 } 395} 396 397func (l *protoLex) setPrev(n ast.TerminalNode, isDot bool) { 398 nStart := n.Start().Line 399 if _, ok := n.(*ast.RuneNode); ok { 400 // This is really gross, but there are many cases where we don't want 401 // to attribute comments to punctuation (like commas, equals, semicolons) 402 // and would instead prefer to attribute comments to a more meaningful 403 // element in the AST. 404 // 405 // So if it's a simple node OTHER THAN PERIOD (since that is not just 406 // punctuation but typically part of a qualified identifier), don't 407 // attribute comments to it. We do that with this TOTAL HACK: adjusting 408 // the start line makes leading comments appear detached so logic below 409 // will naturally associated trailing comment to previous symbol 410 if !isDot { 411 nStart += 2 412 } 413 } 414 if l.prevSym != nil && len(n.LeadingComments()) > 0 && l.prevSym.End().Line < nStart { 415 // we may need to re-attribute the first comment to 416 // instead be previous node's trailing comment 417 prevEnd := l.prevSym.End().Line 418 comments := n.LeadingComments() 419 c := comments[0] 420 commentStart := c.Start.Line 421 if commentStart == prevEnd { 422 // comment is on same line as previous symbol 423 n.PopLeadingComment() 424 l.prevSym.PushTrailingComment(c) 425 } else if commentStart == prevEnd+1 { 426 // comment is right after previous symbol; see if it is detached 427 // and if so re-attribute 428 singleLineStyle := strings.HasPrefix(c.Text, "//") 429 line := c.End.Line 430 groupEnd := -1 431 for i := 1; i < len(comments); i++ { 432 c := comments[i] 433 newGroup := false 434 if !singleLineStyle || c.Start.Line > line+1 { 435 // we've found a gap between comments, which means the 436 // previous comments were detached 437 newGroup = true 438 } else { 439 line = c.End.Line 440 singleLineStyle = strings.HasPrefix(comments[i].Text, "//") 441 if !singleLineStyle { 442 // we've found a switch from // comments to /* 443 // consider that a new group which means the 444 // previous comments were detached 445 newGroup = true 446 } 447 } 448 if newGroup { 449 groupEnd = i 450 break 451 } 452 } 453 454 if groupEnd == -1 { 455 // just one group of comments; we'll mark it as a trailing 456 // comment if it immediately follows previous symbol and is 457 // detached from current symbol 458 c1 := comments[0] 459 c2 := comments[len(comments)-1] 460 if c1.Start.Line <= prevEnd+1 && c2.End.Line < nStart-1 { 461 groupEnd = len(comments) 462 } 463 } 464 465 for i := 0; i < groupEnd; i++ { 466 l.prevSym.PushTrailingComment(n.PopLeadingComment()) 467 } 468 } 469 } 470 471 l.prevSym = n 472} 473 474func (l *protoLex) setString(lval *protoSymType, val string) { 475 lval.s = ast.NewStringLiteralNode(val, l.newTokenInfo()) 476 l.setPrev(lval.s, false) 477} 478 479func (l *protoLex) setIdent(lval *protoSymType, val string) { 480 lval.id = ast.NewIdentNode(val, l.newTokenInfo()) 481 l.setPrev(lval.id, false) 482} 483 484func (l *protoLex) setInt(lval *protoSymType, val uint64) { 485 lval.i = ast.NewUintLiteralNode(val, l.newTokenInfo()) 486 l.setPrev(lval.i, false) 487} 488 489func (l *protoLex) setFloat(lval *protoSymType, val float64) { 490 lval.f = ast.NewFloatLiteralNode(val, l.newTokenInfo()) 491 l.setPrev(lval.f, false) 492} 493 494func (l *protoLex) setRune(lval *protoSymType, val rune) { 495 lval.b = ast.NewRuneNode(val, l.newTokenInfo()) 496 l.setPrev(lval.b, val == '.') 497} 498 499func (l *protoLex) setError(lval *protoSymType, err error) { 500 lval.err = l.addSourceError(err) 501} 502 503func (l *protoLex) readNumber(sofar ...rune) string { 504 token := sofar 505 allowExpSign := false 506 for { 507 c, _, err := l.input.readRune() 508 if err != nil { 509 break 510 } 511 if (c == '-' || c == '+') && !allowExpSign { 512 l.input.unreadRune(c) 513 break 514 } 515 allowExpSign = false 516 if c != '.' && c != '_' && (c < '0' || c > '9') && 517 (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') && 518 c != '-' && c != '+' { 519 // no more chars in the number token 520 l.input.unreadRune(c) 521 break 522 } 523 if c == 'e' || c == 'E' { 524 // scientific notation char can be followed by 525 // an exponent sign 526 allowExpSign = true 527 } 528 l.adjustPos(c) 529 token = append(token, c) 530 } 531 return string(token) 532} 533 534func numError(err error, kind, s string) error { 535 ne, ok := err.(*strconv.NumError) 536 if !ok { 537 return err 538 } 539 if ne.Err == strconv.ErrRange { 540 return fmt.Errorf("value out of range for %s: %s", kind, s) 541 } 542 // syntax error 543 return fmt.Errorf("invalid syntax in %s value: %s", kind, s) 544} 545 546func (l *protoLex) readIdentifier(sofar []rune) []rune { 547 token := sofar 548 for { 549 c, _, err := l.input.readRune() 550 if err != nil { 551 break 552 } 553 if c != '_' && (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') && (c < '0' || c > '9') { 554 l.input.unreadRune(c) 555 break 556 } 557 l.adjustPos(c) 558 token = append(token, c) 559 } 560 return token 561} 562 563func (l *protoLex) readStringLiteral(quote rune) (string, error) { 564 var buf bytes.Buffer 565 for { 566 c, _, err := l.input.readRune() 567 if err != nil { 568 if err == io.EOF { 569 err = io.ErrUnexpectedEOF 570 } 571 return "", err 572 } 573 if c == '\n' { 574 return "", errors.New("encountered end-of-line before end of string literal") 575 } 576 l.adjustPos(c) 577 if c == quote { 578 break 579 } 580 if c == 0 { 581 return "", errors.New("null character ('\\0') not allowed in string literal") 582 } 583 if c == '\\' { 584 // escape sequence 585 c, _, err = l.input.readRune() 586 if err != nil { 587 return "", err 588 } 589 l.adjustPos(c) 590 if c == 'x' || c == 'X' { 591 // hex escape 592 c, _, err := l.input.readRune() 593 if err != nil { 594 return "", err 595 } 596 l.adjustPos(c) 597 c2, _, err := l.input.readRune() 598 if err != nil { 599 return "", err 600 } 601 var hex string 602 if (c2 < '0' || c2 > '9') && (c2 < 'a' || c2 > 'f') && (c2 < 'A' || c2 > 'F') { 603 l.input.unreadRune(c2) 604 hex = string(c) 605 } else { 606 l.adjustPos(c2) 607 hex = string([]rune{c, c2}) 608 } 609 i, err := strconv.ParseInt(hex, 16, 32) 610 if err != nil { 611 return "", fmt.Errorf("invalid hex escape: \\x%q", hex) 612 } 613 buf.WriteByte(byte(i)) 614 615 } else if c >= '0' && c <= '7' { 616 // octal escape 617 c2, _, err := l.input.readRune() 618 if err != nil { 619 return "", err 620 } 621 var octal string 622 if c2 < '0' || c2 > '7' { 623 l.input.unreadRune(c2) 624 octal = string(c) 625 } else { 626 l.adjustPos(c2) 627 c3, _, err := l.input.readRune() 628 if err != nil { 629 return "", err 630 } 631 if c3 < '0' || c3 > '7' { 632 l.input.unreadRune(c3) 633 octal = string([]rune{c, c2}) 634 } else { 635 l.adjustPos(c3) 636 octal = string([]rune{c, c2, c3}) 637 } 638 } 639 i, err := strconv.ParseInt(octal, 8, 32) 640 if err != nil { 641 return "", fmt.Errorf("invalid octal escape: \\%q", octal) 642 } 643 if i > 0xff { 644 return "", fmt.Errorf("octal escape is out range, must be between 0 and 377: \\%q", octal) 645 } 646 buf.WriteByte(byte(i)) 647 648 } else if c == 'u' { 649 // short unicode escape 650 u := make([]rune, 4) 651 for i := range u { 652 c, _, err := l.input.readRune() 653 if err != nil { 654 return "", err 655 } 656 l.adjustPos(c) 657 u[i] = c 658 } 659 i, err := strconv.ParseInt(string(u), 16, 32) 660 if err != nil { 661 return "", fmt.Errorf("invalid unicode escape: \\u%q", string(u)) 662 } 663 buf.WriteRune(rune(i)) 664 665 } else if c == 'U' { 666 // long unicode escape 667 u := make([]rune, 8) 668 for i := range u { 669 c, _, err := l.input.readRune() 670 if err != nil { 671 return "", err 672 } 673 l.adjustPos(c) 674 u[i] = c 675 } 676 i, err := strconv.ParseInt(string(u), 16, 32) 677 if err != nil { 678 return "", fmt.Errorf("invalid unicode escape: \\U%q", string(u)) 679 } 680 if i > 0x10ffff || i < 0 { 681 return "", fmt.Errorf("unicode escape is out of range, must be between 0 and 0x10ffff: \\U%q", string(u)) 682 } 683 buf.WriteRune(rune(i)) 684 685 } else if c == 'a' { 686 buf.WriteByte('\a') 687 } else if c == 'b' { 688 buf.WriteByte('\b') 689 } else if c == 'f' { 690 buf.WriteByte('\f') 691 } else if c == 'n' { 692 buf.WriteByte('\n') 693 } else if c == 'r' { 694 buf.WriteByte('\r') 695 } else if c == 't' { 696 buf.WriteByte('\t') 697 } else if c == 'v' { 698 buf.WriteByte('\v') 699 } else if c == '\\' { 700 buf.WriteByte('\\') 701 } else if c == '\'' { 702 buf.WriteByte('\'') 703 } else if c == '"' { 704 buf.WriteByte('"') 705 } else if c == '?' { 706 buf.WriteByte('?') 707 } else { 708 return "", fmt.Errorf("invalid escape sequence: %q", "\\"+string(c)) 709 } 710 } else { 711 buf.WriteRune(c) 712 } 713 } 714 return buf.String(), nil 715} 716 717func (l *protoLex) skipToEndOfLineComment() bool { 718 for { 719 c, _, err := l.input.readRune() 720 if err != nil { 721 return false 722 } 723 if c == '\n' { 724 return true 725 } 726 l.adjustPos(c) 727 } 728} 729 730func (l *protoLex) skipToEndOfBlockComment() bool { 731 for { 732 c, _, err := l.input.readRune() 733 if err != nil { 734 return false 735 } 736 l.adjustPos(c) 737 if c == '*' { 738 c, _, err := l.input.readRune() 739 if err != nil { 740 return false 741 } 742 if c == '/' { 743 l.adjustPos(c) 744 return true 745 } 746 l.input.unreadRune(c) 747 } 748 } 749} 750 751func (l *protoLex) addSourceError(err error) ErrorWithPos { 752 ewp, ok := err.(ErrorWithPos) 753 if !ok { 754 ewp = ErrorWithSourcePos{Pos: l.prev(), Underlying: err} 755 } 756 _ = l.errs.handleError(ewp) 757 return ewp 758} 759 760func (l *protoLex) Error(s string) { 761 _ = l.addSourceError(errors.New(s)) 762} 763