1// Copyright 2019 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package cache 6 7import ( 8 "bytes" 9 "context" 10 "fmt" 11 "go/ast" 12 "go/parser" 13 "go/scanner" 14 "go/token" 15 "go/types" 16 "reflect" 17 "strconv" 18 "strings" 19 20 "golang.org/x/tools/internal/event" 21 "golang.org/x/tools/internal/lsp/debug/tag" 22 "golang.org/x/tools/internal/lsp/diff" 23 "golang.org/x/tools/internal/lsp/diff/myers" 24 "golang.org/x/tools/internal/lsp/protocol" 25 "golang.org/x/tools/internal/lsp/source" 26 "golang.org/x/tools/internal/memoize" 27 "golang.org/x/tools/internal/span" 28 errors "golang.org/x/xerrors" 29) 30 31// parseKey uniquely identifies a parsed Go file. 32type parseKey struct { 33 file source.FileIdentity 34 mode source.ParseMode 35} 36 37type parseGoHandle struct { 38 handle *memoize.Handle 39 file source.FileHandle 40 mode source.ParseMode 41} 42 43type parseGoData struct { 44 parsed *source.ParsedGoFile 45 46 // If true, we adjusted the AST to make it type check better, and 47 // it may not match the source code. 48 fixed bool 49 err error // any other errors 50} 51 52func (s *snapshot) parseGoHandle(ctx context.Context, fh source.FileHandle, mode source.ParseMode) *parseGoHandle { 53 key := parseKey{ 54 file: fh.FileIdentity(), 55 mode: mode, 56 } 57 if pgh := s.getGoFile(key); pgh != nil { 58 return pgh 59 } 60 parseHandle := s.generation.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} { 61 snapshot := arg.(*snapshot) 62 return parseGo(ctx, snapshot.view.session.cache.fset, fh, mode) 63 }, nil) 64 65 pgh := &parseGoHandle{ 66 handle: parseHandle, 67 file: fh, 68 mode: mode, 69 } 70 return s.addGoFile(key, pgh) 71} 72 73func (pgh *parseGoHandle) String() string { 74 return pgh.File().URI().Filename() 75} 76 77func (pgh *parseGoHandle) File() source.FileHandle { 78 return pgh.file 79} 80 81func (pgh *parseGoHandle) Mode() source.ParseMode { 82 return pgh.mode 83} 84 85func (s *snapshot) ParseGo(ctx context.Context, fh source.FileHandle, mode source.ParseMode) (*source.ParsedGoFile, error) { 86 pgh := s.parseGoHandle(ctx, fh, mode) 87 pgf, _, err := s.parseGo(ctx, pgh) 88 return pgf, err 89} 90 91func (s *snapshot) parseGo(ctx context.Context, pgh *parseGoHandle) (*source.ParsedGoFile, bool, error) { 92 if pgh.mode == source.ParseExported { 93 panic("only type checking should use Exported") 94 } 95 d, err := pgh.handle.Get(ctx, s.generation, s) 96 if err != nil { 97 return nil, false, err 98 } 99 data := d.(*parseGoData) 100 return data.parsed, data.fixed, data.err 101} 102 103type astCacheKey struct { 104 pkg packageHandleKey 105 uri span.URI 106} 107 108func (s *snapshot) astCacheData(ctx context.Context, spkg source.Package, pos token.Pos) (*astCacheData, error) { 109 pkg := spkg.(*pkg) 110 pkgHandle := s.getPackage(pkg.m.id, pkg.mode) 111 if pkgHandle == nil { 112 return nil, fmt.Errorf("could not reconstruct package handle for %v", pkg.m.id) 113 } 114 tok := s.FileSet().File(pos) 115 if tok == nil { 116 return nil, fmt.Errorf("no file for pos %v", pos) 117 } 118 pgf, err := pkg.File(span.URIFromPath(tok.Name())) 119 if err != nil { 120 return nil, err 121 } 122 astHandle := s.generation.Bind(astCacheKey{pkgHandle.key, pgf.URI}, func(ctx context.Context, arg memoize.Arg) interface{} { 123 snapshot := arg.(*snapshot) 124 return buildASTCache(ctx, snapshot, pgf) 125 }, nil) 126 127 d, err := astHandle.Get(ctx, s.generation, s) 128 if err != nil { 129 return nil, err 130 } 131 data := d.(*astCacheData) 132 if data.err != nil { 133 return nil, data.err 134 } 135 return data, nil 136} 137 138func (s *snapshot) PosToDecl(ctx context.Context, spkg source.Package, pos token.Pos) (ast.Decl, error) { 139 data, err := s.astCacheData(ctx, spkg, pos) 140 if err != nil { 141 return nil, err 142 } 143 return data.posToDecl[pos], nil 144} 145 146func (s *snapshot) PosToField(ctx context.Context, spkg source.Package, pos token.Pos) (*ast.Field, error) { 147 data, err := s.astCacheData(ctx, spkg, pos) 148 if err != nil { 149 return nil, err 150 } 151 return data.posToField[pos], nil 152} 153 154type astCacheData struct { 155 err error 156 157 posToDecl map[token.Pos]ast.Decl 158 posToField map[token.Pos]*ast.Field 159} 160 161// buildASTCache builds caches to aid in quickly going from the typed 162// world to the syntactic world. 163func buildASTCache(ctx context.Context, snapshot *snapshot, pgf *source.ParsedGoFile) *astCacheData { 164 var ( 165 // path contains all ancestors, including n. 166 path []ast.Node 167 // decls contains all ancestors that are decls. 168 decls []ast.Decl 169 ) 170 171 data := &astCacheData{ 172 posToDecl: make(map[token.Pos]ast.Decl), 173 posToField: make(map[token.Pos]*ast.Field), 174 } 175 176 ast.Inspect(pgf.File, func(n ast.Node) bool { 177 if n == nil { 178 lastP := path[len(path)-1] 179 path = path[:len(path)-1] 180 if len(decls) > 0 && decls[len(decls)-1] == lastP { 181 decls = decls[:len(decls)-1] 182 } 183 return false 184 } 185 186 path = append(path, n) 187 188 switch n := n.(type) { 189 case *ast.Field: 190 addField := func(f ast.Node) { 191 if f.Pos().IsValid() { 192 data.posToField[f.Pos()] = n 193 if len(decls) > 0 { 194 data.posToDecl[f.Pos()] = decls[len(decls)-1] 195 } 196 } 197 } 198 199 // Add mapping for *ast.Field itself. This handles embedded 200 // fields which have no associated *ast.Ident name. 201 addField(n) 202 203 // Add mapping for each field name since you can have 204 // multiple names for the same type expression. 205 for _, name := range n.Names { 206 addField(name) 207 } 208 209 // Also map "X" in "...X" to the containing *ast.Field. This 210 // makes it easy to format variadic signature params 211 // properly. 212 if elips, ok := n.Type.(*ast.Ellipsis); ok && elips.Elt != nil { 213 addField(elips.Elt) 214 } 215 case *ast.FuncDecl: 216 decls = append(decls, n) 217 218 if n.Name != nil && n.Name.Pos().IsValid() { 219 data.posToDecl[n.Name.Pos()] = n 220 } 221 case *ast.GenDecl: 222 decls = append(decls, n) 223 224 for _, spec := range n.Specs { 225 switch spec := spec.(type) { 226 case *ast.TypeSpec: 227 if spec.Name != nil && spec.Name.Pos().IsValid() { 228 data.posToDecl[spec.Name.Pos()] = n 229 } 230 case *ast.ValueSpec: 231 for _, id := range spec.Names { 232 if id != nil && id.Pos().IsValid() { 233 data.posToDecl[id.Pos()] = n 234 } 235 } 236 } 237 } 238 } 239 240 return true 241 }) 242 243 return data 244} 245 246func parseGo(ctx context.Context, fset *token.FileSet, fh source.FileHandle, mode source.ParseMode) *parseGoData { 247 ctx, done := event.Start(ctx, "cache.parseGo", tag.File.Of(fh.URI().Filename())) 248 defer done() 249 250 if fh.Kind() != source.Go { 251 return &parseGoData{err: errors.Errorf("cannot parse non-Go file %s", fh.URI())} 252 } 253 src, err := fh.Read() 254 if err != nil { 255 return &parseGoData{err: err} 256 } 257 258 parserMode := parser.AllErrors | parser.ParseComments 259 if mode == source.ParseHeader { 260 parserMode = parser.ImportsOnly | parser.ParseComments 261 } 262 263 file, err := parser.ParseFile(fset, fh.URI().Filename(), src, parserMode) 264 var parseErr scanner.ErrorList 265 if err != nil { 266 // We passed a byte slice, so the only possible error is a parse error. 267 parseErr = err.(scanner.ErrorList) 268 } 269 270 tok := fset.File(file.Pos()) 271 if tok == nil { 272 // file.Pos is the location of the package declaration. If there was 273 // none, we can't find the token.File that ParseFile created, and we 274 // have no choice but to recreate it. 275 tok = fset.AddFile(fh.URI().Filename(), -1, len(src)) 276 tok.SetLinesForContent(src) 277 } 278 279 fixed := false 280 // If there were parse errors, attempt to fix them up. 281 if parseErr != nil { 282 // Fix any badly parsed parts of the AST. 283 fixed = fixAST(ctx, file, tok, src) 284 285 for i := 0; i < 10; i++ { 286 // Fix certain syntax errors that render the file unparseable. 287 newSrc := fixSrc(file, tok, src) 288 if newSrc == nil { 289 break 290 } 291 292 // If we thought there was something to fix 10 times in a row, 293 // it is likely we got stuck in a loop somehow. Log out a diff 294 // of the last changes we made to aid in debugging. 295 if i == 9 { 296 edits, err := myers.ComputeEdits(fh.URI(), string(src), string(newSrc)) 297 if err != nil { 298 event.Error(ctx, "error generating fixSrc diff", err, tag.File.Of(tok.Name())) 299 } else { 300 unified := diff.ToUnified("before", "after", string(src), edits) 301 event.Log(ctx, fmt.Sprintf("fixSrc loop - last diff:\n%v", unified), tag.File.Of(tok.Name())) 302 } 303 } 304 305 newFile, _ := parser.ParseFile(fset, fh.URI().Filename(), newSrc, parserMode) 306 if newFile != nil { 307 // Maintain the original parseError so we don't try formatting the doctored file. 308 file = newFile 309 src = newSrc 310 tok = fset.File(file.Pos()) 311 312 fixed = fixAST(ctx, file, tok, src) 313 } 314 } 315 } 316 317 return &parseGoData{ 318 parsed: &source.ParsedGoFile{ 319 URI: fh.URI(), 320 Mode: mode, 321 Src: src, 322 File: file, 323 Tok: tok, 324 Mapper: &protocol.ColumnMapper{ 325 URI: fh.URI(), 326 Converter: span.NewTokenConverter(fset, tok), 327 Content: src, 328 }, 329 ParseErr: parseErr, 330 }, 331 fixed: fixed, 332 } 333} 334 335// An unexportedFilter removes as much unexported AST from a set of Files as possible. 336type unexportedFilter struct { 337 uses map[string]bool 338} 339 340// Filter records uses of unexported identifiers and filters out all other 341// unexported declarations. 342func (f *unexportedFilter) Filter(files []*ast.File) { 343 // Iterate to fixed point -- unexported types can include other unexported types. 344 oldLen := len(f.uses) 345 for { 346 for _, file := range files { 347 f.recordUses(file) 348 } 349 if len(f.uses) == oldLen { 350 break 351 } 352 oldLen = len(f.uses) 353 } 354 355 for _, file := range files { 356 var newDecls []ast.Decl 357 for _, decl := range file.Decls { 358 if f.filterDecl(decl) { 359 newDecls = append(newDecls, decl) 360 } 361 } 362 file.Decls = newDecls 363 file.Scope = nil 364 file.Unresolved = nil 365 file.Comments = nil 366 trimAST(file) 367 } 368} 369 370func (f *unexportedFilter) keep(ident *ast.Ident) bool { 371 return ast.IsExported(ident.Name) || f.uses[ident.Name] 372} 373 374func (f *unexportedFilter) filterDecl(decl ast.Decl) bool { 375 switch decl := decl.(type) { 376 case *ast.FuncDecl: 377 if ident := recvIdent(decl); ident != nil && !f.keep(ident) { 378 return false 379 } 380 return f.keep(decl.Name) 381 case *ast.GenDecl: 382 if decl.Tok == token.CONST { 383 // Constants can involve iota, and iota is hard to deal with. 384 return true 385 } 386 var newSpecs []ast.Spec 387 for _, spec := range decl.Specs { 388 if f.filterSpec(spec) { 389 newSpecs = append(newSpecs, spec) 390 } 391 } 392 decl.Specs = newSpecs 393 return len(newSpecs) != 0 394 case *ast.BadDecl: 395 return false 396 } 397 panic(fmt.Sprintf("unknown ast.Decl %T", decl)) 398} 399 400func (f *unexportedFilter) filterSpec(spec ast.Spec) bool { 401 switch spec := spec.(type) { 402 case *ast.ImportSpec: 403 return true 404 case *ast.ValueSpec: 405 var newNames []*ast.Ident 406 for _, name := range spec.Names { 407 if f.keep(name) { 408 newNames = append(newNames, name) 409 } 410 } 411 spec.Names = newNames 412 return len(spec.Names) != 0 413 case *ast.TypeSpec: 414 if !f.keep(spec.Name) { 415 return false 416 } 417 switch typ := spec.Type.(type) { 418 case *ast.StructType: 419 f.filterFieldList(typ.Fields) 420 case *ast.InterfaceType: 421 f.filterFieldList(typ.Methods) 422 } 423 return true 424 } 425 panic(fmt.Sprintf("unknown ast.Spec %T", spec)) 426} 427 428func (f *unexportedFilter) filterFieldList(fields *ast.FieldList) { 429 var newFields []*ast.Field 430 for _, field := range fields.List { 431 if len(field.Names) == 0 { 432 // Keep embedded fields: they can export methods and fields. 433 newFields = append(newFields, field) 434 } 435 for _, name := range field.Names { 436 if f.keep(name) { 437 newFields = append(newFields, field) 438 break 439 } 440 } 441 } 442 fields.List = newFields 443} 444 445func (f *unexportedFilter) recordUses(file *ast.File) { 446 for _, decl := range file.Decls { 447 switch decl := decl.(type) { 448 case *ast.FuncDecl: 449 // Ignore methods on dropped types. 450 if ident := recvIdent(decl); ident != nil && !f.keep(ident) { 451 break 452 } 453 // Ignore functions with dropped names. 454 if !f.keep(decl.Name) { 455 break 456 } 457 f.recordFuncType(decl.Type) 458 case *ast.GenDecl: 459 for _, spec := range decl.Specs { 460 switch spec := spec.(type) { 461 case *ast.ValueSpec: 462 for i, name := range spec.Names { 463 // Don't mess with constants -- iota is hard. 464 if f.keep(name) || decl.Tok == token.CONST { 465 f.recordIdents(spec.Type) 466 if len(spec.Values) > i { 467 f.recordIdents(spec.Values[i]) 468 } 469 } 470 } 471 case *ast.TypeSpec: 472 switch typ := spec.Type.(type) { 473 case *ast.StructType: 474 f.recordFieldUses(false, typ.Fields) 475 case *ast.InterfaceType: 476 f.recordFieldUses(false, typ.Methods) 477 } 478 } 479 } 480 } 481 } 482} 483 484// recvIdent returns the identifier of a method receiver, e.g. *int. 485func recvIdent(decl *ast.FuncDecl) *ast.Ident { 486 if decl.Recv == nil || len(decl.Recv.List) == 0 { 487 return nil 488 } 489 x := decl.Recv.List[0].Type 490 if star, ok := x.(*ast.StarExpr); ok { 491 x = star.X 492 } 493 if ident, ok := x.(*ast.Ident); ok { 494 return ident 495 } 496 return nil 497} 498 499// recordIdents records unexported identifiers in an Expr in uses. 500// These may be types, e.g. in map[key]value, function names, e.g. in foo(), 501// or simple variable references. References that will be discarded, such 502// as those in function literal bodies, are ignored. 503func (f *unexportedFilter) recordIdents(x ast.Expr) { 504 ast.Inspect(x, func(n ast.Node) bool { 505 if n == nil { 506 return false 507 } 508 if complit, ok := n.(*ast.CompositeLit); ok { 509 // We clear out composite literal contents; just record their type. 510 f.recordIdents(complit.Type) 511 return false 512 } 513 if flit, ok := n.(*ast.FuncLit); ok { 514 f.recordFuncType(flit.Type) 515 return false 516 } 517 if ident, ok := n.(*ast.Ident); ok && !ast.IsExported(ident.Name) { 518 f.uses[ident.Name] = true 519 } 520 return true 521 }) 522} 523 524// recordFuncType records the types mentioned by a function type. 525func (f *unexportedFilter) recordFuncType(x *ast.FuncType) { 526 f.recordFieldUses(true, x.Params) 527 f.recordFieldUses(true, x.Results) 528} 529 530// recordFieldUses records unexported identifiers used in fields, which may be 531// struct members, interface members, or function parameter/results. 532func (f *unexportedFilter) recordFieldUses(isParams bool, fields *ast.FieldList) { 533 if fields == nil { 534 return 535 } 536 for _, field := range fields.List { 537 if isParams { 538 // Parameter types of retained functions need to be retained. 539 f.recordIdents(field.Type) 540 continue 541 } 542 if ft, ok := field.Type.(*ast.FuncType); ok { 543 // Function declarations in interfaces need all their types retained. 544 f.recordFuncType(ft) 545 continue 546 } 547 if len(field.Names) == 0 { 548 // Embedded fields might contribute exported names. 549 f.recordIdents(field.Type) 550 } 551 for _, name := range field.Names { 552 // We only need normal fields if they're exported. 553 if ast.IsExported(name.Name) { 554 f.recordIdents(field.Type) 555 break 556 } 557 } 558 } 559} 560 561// ProcessErrors records additional uses from errors, returning the new uses 562// and any unexpected errors. 563func (f *unexportedFilter) ProcessErrors(errors []types.Error) (map[string]bool, []types.Error) { 564 var unexpected []types.Error 565 missing := map[string]bool{} 566 for _, err := range errors { 567 if strings.Contains(err.Msg, "missing return") { 568 continue 569 } 570 const undeclared = "undeclared name: " 571 if strings.HasPrefix(err.Msg, undeclared) { 572 missing[strings.TrimPrefix(err.Msg, undeclared)] = true 573 f.uses[strings.TrimPrefix(err.Msg, undeclared)] = true 574 continue 575 } 576 unexpected = append(unexpected, err) 577 } 578 return missing, unexpected 579} 580 581// trimAST clears any part of the AST not relevant to type checking 582// expressions at pos. 583func trimAST(file *ast.File) { 584 ast.Inspect(file, func(n ast.Node) bool { 585 if n == nil { 586 return false 587 } 588 switch n := n.(type) { 589 case *ast.FuncDecl: 590 n.Body = nil 591 case *ast.BlockStmt: 592 n.List = nil 593 case *ast.CaseClause: 594 n.Body = nil 595 case *ast.CommClause: 596 n.Body = nil 597 case *ast.CompositeLit: 598 // types.Info.Types for long slice/array literals are particularly 599 // expensive. Try to clear them out. 600 at, ok := n.Type.(*ast.ArrayType) 601 if !ok { 602 // Composite literal. No harm removing all its fields. 603 n.Elts = nil 604 break 605 } 606 // Removing the elements from an ellipsis array changes its type. 607 // Try to set the length explicitly so we can continue. 608 if _, ok := at.Len.(*ast.Ellipsis); ok { 609 length, ok := arrayLength(n) 610 if !ok { 611 break 612 } 613 at.Len = &ast.BasicLit{ 614 Kind: token.INT, 615 Value: fmt.Sprint(length), 616 ValuePos: at.Len.Pos(), 617 } 618 } 619 n.Elts = nil 620 } 621 return true 622 }) 623} 624 625// arrayLength returns the length of some simple forms of ellipsis array literal. 626// Notably, it handles the tables in golang.org/x/text. 627func arrayLength(array *ast.CompositeLit) (int, bool) { 628 litVal := func(expr ast.Expr) (int, bool) { 629 lit, ok := expr.(*ast.BasicLit) 630 if !ok { 631 return 0, false 632 } 633 val, err := strconv.ParseInt(lit.Value, 10, 64) 634 if err != nil { 635 return 0, false 636 } 637 return int(val), true 638 } 639 largestKey := -1 640 for _, elt := range array.Elts { 641 kve, ok := elt.(*ast.KeyValueExpr) 642 if !ok { 643 continue 644 } 645 switch key := kve.Key.(type) { 646 case *ast.BasicLit: 647 if val, ok := litVal(key); ok && largestKey < val { 648 largestKey = val 649 } 650 case *ast.BinaryExpr: 651 // golang.org/x/text uses subtraction (and only subtraction) in its indices. 652 if key.Op != token.SUB { 653 break 654 } 655 x, ok := litVal(key.X) 656 if !ok { 657 break 658 } 659 y, ok := litVal(key.Y) 660 if !ok { 661 break 662 } 663 if val := x - y; largestKey < val { 664 largestKey = val 665 } 666 } 667 } 668 if largestKey != -1 { 669 return largestKey + 1, true 670 } 671 return len(array.Elts), true 672} 673 674// fixAST inspects the AST and potentially modifies any *ast.BadStmts so that it can be 675// type-checked more effectively. 676func fixAST(ctx context.Context, n ast.Node, tok *token.File, src []byte) (fixed bool) { 677 var err error 678 walkASTWithParent(n, func(n, parent ast.Node) bool { 679 switch n := n.(type) { 680 case *ast.BadStmt: 681 if fixed = fixDeferOrGoStmt(n, parent, tok, src); fixed { 682 // Recursively fix in our fixed node. 683 _ = fixAST(ctx, parent, tok, src) 684 } else { 685 err = errors.Errorf("unable to parse defer or go from *ast.BadStmt: %v", err) 686 } 687 return false 688 case *ast.BadExpr: 689 if fixed = fixArrayType(n, parent, tok, src); fixed { 690 // Recursively fix in our fixed node. 691 _ = fixAST(ctx, parent, tok, src) 692 return false 693 } 694 695 // Fix cases where parser interprets if/for/switch "init" 696 // statement as "cond" expression, e.g.: 697 // 698 // // "i := foo" is init statement, not condition. 699 // for i := foo 700 // 701 fixInitStmt(n, parent, tok, src) 702 703 return false 704 case *ast.SelectorExpr: 705 // Fix cases where a keyword prefix results in a phantom "_" selector, e.g.: 706 // 707 // foo.var<> // want to complete to "foo.variance" 708 // 709 fixPhantomSelector(n, tok, src) 710 return true 711 712 case *ast.BlockStmt: 713 switch parent.(type) { 714 case *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt: 715 // Adjust closing curly brace of empty switch/select 716 // statements so we can complete inside them. 717 fixEmptySwitch(n, tok, src) 718 } 719 720 return true 721 default: 722 return true 723 } 724 }) 725 return fixed 726} 727 728// walkASTWithParent walks the AST rooted at n. The semantics are 729// similar to ast.Inspect except it does not call f(nil). 730func walkASTWithParent(n ast.Node, f func(n ast.Node, parent ast.Node) bool) { 731 var ancestors []ast.Node 732 ast.Inspect(n, func(n ast.Node) (recurse bool) { 733 defer func() { 734 if recurse { 735 ancestors = append(ancestors, n) 736 } 737 }() 738 739 if n == nil { 740 ancestors = ancestors[:len(ancestors)-1] 741 return false 742 } 743 744 var parent ast.Node 745 if len(ancestors) > 0 { 746 parent = ancestors[len(ancestors)-1] 747 } 748 749 return f(n, parent) 750 }) 751} 752 753// fixSrc attempts to modify the file's source code to fix certain 754// syntax errors that leave the rest of the file unparsed. 755func fixSrc(f *ast.File, tok *token.File, src []byte) (newSrc []byte) { 756 walkASTWithParent(f, func(n, parent ast.Node) bool { 757 if newSrc != nil { 758 return false 759 } 760 761 switch n := n.(type) { 762 case *ast.BlockStmt: 763 newSrc = fixMissingCurlies(f, n, parent, tok, src) 764 case *ast.SelectorExpr: 765 newSrc = fixDanglingSelector(n, tok, src) 766 } 767 768 return newSrc == nil 769 }) 770 771 return newSrc 772} 773 774// fixMissingCurlies adds in curly braces for block statements that 775// are missing curly braces. For example: 776// 777// if foo 778// 779// becomes 780// 781// if foo {} 782func fixMissingCurlies(f *ast.File, b *ast.BlockStmt, parent ast.Node, tok *token.File, src []byte) []byte { 783 // If the "{" is already in the source code, there isn't anything to 784 // fix since we aren't missing curlies. 785 if b.Lbrace.IsValid() { 786 braceOffset := tok.Offset(b.Lbrace) 787 if braceOffset < len(src) && src[braceOffset] == '{' { 788 return nil 789 } 790 } 791 792 parentLine := tok.Line(parent.Pos()) 793 794 if parentLine >= tok.LineCount() { 795 // If we are the last line in the file, no need to fix anything. 796 return nil 797 } 798 799 // Insert curlies at the end of parent's starting line. The parent 800 // is the statement that contains the block, e.g. *ast.IfStmt. The 801 // block's Pos()/End() can't be relied upon because they are based 802 // on the (missing) curly braces. We assume the statement is a 803 // single line for now and try sticking the curly braces at the end. 804 insertPos := tok.LineStart(parentLine+1) - 1 805 806 // Scootch position backwards until it's not in a comment. For example: 807 // 808 // if foo<> // some amazing comment | 809 // someOtherCode() 810 // 811 // insertPos will be located at "|", so we back it out of the comment. 812 didSomething := true 813 for didSomething { 814 didSomething = false 815 for _, c := range f.Comments { 816 if c.Pos() < insertPos && insertPos <= c.End() { 817 insertPos = c.Pos() 818 didSomething = true 819 } 820 } 821 } 822 823 // Bail out if line doesn't end in an ident or ".". This is to avoid 824 // cases like below where we end up making things worse by adding 825 // curlies: 826 // 827 // if foo && 828 // bar<> 829 switch precedingToken(insertPos, tok, src) { 830 case token.IDENT, token.PERIOD: 831 // ok 832 default: 833 return nil 834 } 835 836 var buf bytes.Buffer 837 buf.Grow(len(src) + 3) 838 buf.Write(src[:tok.Offset(insertPos)]) 839 840 // Detect if we need to insert a semicolon to fix "for" loop situations like: 841 // 842 // for i := foo(); foo<> 843 // 844 // Just adding curlies is not sufficient to make things parse well. 845 if fs, ok := parent.(*ast.ForStmt); ok { 846 if _, ok := fs.Cond.(*ast.BadExpr); !ok { 847 if xs, ok := fs.Post.(*ast.ExprStmt); ok { 848 if _, ok := xs.X.(*ast.BadExpr); ok { 849 buf.WriteByte(';') 850 } 851 } 852 } 853 } 854 855 // Insert "{}" at insertPos. 856 buf.WriteByte('{') 857 buf.WriteByte('}') 858 buf.Write(src[tok.Offset(insertPos):]) 859 return buf.Bytes() 860} 861 862// fixEmptySwitch moves empty switch/select statements' closing curly 863// brace down one line. This allows us to properly detect incomplete 864// "case" and "default" keywords as inside the switch statement. For 865// example: 866// 867// switch { 868// def<> 869// } 870// 871// gets parsed like: 872// 873// switch { 874// } 875// 876// Later we manually pull out the "def" token, but we need to detect 877// that our "<>" position is inside the switch block. To do that we 878// move the curly brace so it looks like: 879// 880// switch { 881// 882// } 883// 884func fixEmptySwitch(body *ast.BlockStmt, tok *token.File, src []byte) { 885 // We only care about empty switch statements. 886 if len(body.List) > 0 || !body.Rbrace.IsValid() { 887 return 888 } 889 890 // If the right brace is actually in the source code at the 891 // specified position, don't mess with it. 892 braceOffset := tok.Offset(body.Rbrace) 893 if braceOffset < len(src) && src[braceOffset] == '}' { 894 return 895 } 896 897 braceLine := tok.Line(body.Rbrace) 898 if braceLine >= tok.LineCount() { 899 // If we are the last line in the file, no need to fix anything. 900 return 901 } 902 903 // Move the right brace down one line. 904 body.Rbrace = tok.LineStart(braceLine + 1) 905} 906 907// fixDanglingSelector inserts real "_" selector expressions in place 908// of phantom "_" selectors. For example: 909// 910// func _() { 911// x.<> 912// } 913// var x struct { i int } 914// 915// To fix completion at "<>", we insert a real "_" after the "." so the 916// following declaration of "x" can be parsed and type checked 917// normally. 918func fixDanglingSelector(s *ast.SelectorExpr, tok *token.File, src []byte) []byte { 919 if !isPhantomUnderscore(s.Sel, tok, src) { 920 return nil 921 } 922 923 if !s.X.End().IsValid() { 924 return nil 925 } 926 927 // Insert directly after the selector's ".". 928 insertOffset := tok.Offset(s.X.End()) + 1 929 if src[insertOffset-1] != '.' { 930 return nil 931 } 932 933 var buf bytes.Buffer 934 buf.Grow(len(src) + 1) 935 buf.Write(src[:insertOffset]) 936 buf.WriteByte('_') 937 buf.Write(src[insertOffset:]) 938 return buf.Bytes() 939} 940 941// fixPhantomSelector tries to fix selector expressions with phantom 942// "_" selectors. In particular, we check if the selector is a 943// keyword, and if so we swap in an *ast.Ident with the keyword text. For example: 944// 945// foo.var 946// 947// yields a "_" selector instead of "var" since "var" is a keyword. 948func fixPhantomSelector(sel *ast.SelectorExpr, tok *token.File, src []byte) { 949 if !isPhantomUnderscore(sel.Sel, tok, src) { 950 return 951 } 952 953 // Only consider selectors directly abutting the selector ".". This 954 // avoids false positives in cases like: 955 // 956 // foo. // don't think "var" is our selector 957 // var bar = 123 958 // 959 if sel.Sel.Pos() != sel.X.End()+1 { 960 return 961 } 962 963 maybeKeyword := readKeyword(sel.Sel.Pos(), tok, src) 964 if maybeKeyword == "" { 965 return 966 } 967 968 replaceNode(sel, sel.Sel, &ast.Ident{ 969 Name: maybeKeyword, 970 NamePos: sel.Sel.Pos(), 971 }) 972} 973 974// isPhantomUnderscore reports whether the given ident is a phantom 975// underscore. The parser sometimes inserts phantom underscores when 976// it encounters otherwise unparseable situations. 977func isPhantomUnderscore(id *ast.Ident, tok *token.File, src []byte) bool { 978 if id == nil || id.Name != "_" { 979 return false 980 } 981 982 // Phantom underscore means the underscore is not actually in the 983 // program text. 984 offset := tok.Offset(id.Pos()) 985 return len(src) <= offset || src[offset] != '_' 986} 987 988// fixInitStmt fixes cases where the parser misinterprets an 989// if/for/switch "init" statement as the "cond" conditional. In cases 990// like "if i := 0" the user hasn't typed the semicolon yet so the 991// parser is looking for the conditional expression. However, "i := 0" 992// are not valid expressions, so we get a BadExpr. 993func fixInitStmt(bad *ast.BadExpr, parent ast.Node, tok *token.File, src []byte) { 994 if !bad.Pos().IsValid() || !bad.End().IsValid() { 995 return 996 } 997 998 // Try to extract a statement from the BadExpr. 999 stmtBytes := src[tok.Offset(bad.Pos()) : tok.Offset(bad.End()-1)+1] 1000 stmt, err := parseStmt(bad.Pos(), stmtBytes) 1001 if err != nil { 1002 return 1003 } 1004 1005 // If the parent statement doesn't already have an "init" statement, 1006 // move the extracted statement into the "init" field and insert a 1007 // dummy expression into the required "cond" field. 1008 switch p := parent.(type) { 1009 case *ast.IfStmt: 1010 if p.Init != nil { 1011 return 1012 } 1013 p.Init = stmt 1014 p.Cond = &ast.Ident{ 1015 Name: "_", 1016 NamePos: stmt.End(), 1017 } 1018 case *ast.ForStmt: 1019 if p.Init != nil { 1020 return 1021 } 1022 p.Init = stmt 1023 p.Cond = &ast.Ident{ 1024 Name: "_", 1025 NamePos: stmt.End(), 1026 } 1027 case *ast.SwitchStmt: 1028 if p.Init != nil { 1029 return 1030 } 1031 p.Init = stmt 1032 p.Tag = nil 1033 } 1034} 1035 1036// readKeyword reads the keyword starting at pos, if any. 1037func readKeyword(pos token.Pos, tok *token.File, src []byte) string { 1038 var kwBytes []byte 1039 for i := tok.Offset(pos); i < len(src); i++ { 1040 // Use a simplified identifier check since keywords are always lowercase ASCII. 1041 if src[i] < 'a' || src[i] > 'z' { 1042 break 1043 } 1044 kwBytes = append(kwBytes, src[i]) 1045 1046 // Stop search at arbitrarily chosen too-long-for-a-keyword length. 1047 if len(kwBytes) > 15 { 1048 return "" 1049 } 1050 } 1051 1052 if kw := string(kwBytes); token.Lookup(kw).IsKeyword() { 1053 return kw 1054 } 1055 1056 return "" 1057} 1058 1059// fixArrayType tries to parse an *ast.BadExpr into an *ast.ArrayType. 1060// go/parser often turns lone array types like "[]int" into BadExprs 1061// if it isn't expecting a type. 1062func fixArrayType(bad *ast.BadExpr, parent ast.Node, tok *token.File, src []byte) bool { 1063 // Our expected input is a bad expression that looks like "[]someExpr". 1064 1065 from := bad.Pos() 1066 to := bad.End() 1067 1068 if !from.IsValid() || !to.IsValid() { 1069 return false 1070 } 1071 1072 exprBytes := make([]byte, 0, int(to-from)+3) 1073 // Avoid doing tok.Offset(to) since that panics if badExpr ends at EOF. 1074 exprBytes = append(exprBytes, src[tok.Offset(from):tok.Offset(to-1)+1]...) 1075 exprBytes = bytes.TrimSpace(exprBytes) 1076 1077 // If our expression ends in "]" (e.g. "[]"), add a phantom selector 1078 // so we can complete directly after the "[]". 1079 if len(exprBytes) > 0 && exprBytes[len(exprBytes)-1] == ']' { 1080 exprBytes = append(exprBytes, '_') 1081 } 1082 1083 // Add "{}" to turn our ArrayType into a CompositeLit. This is to 1084 // handle the case of "[...]int" where we must make it a composite 1085 // literal to be parseable. 1086 exprBytes = append(exprBytes, '{', '}') 1087 1088 expr, err := parseExpr(from, exprBytes) 1089 if err != nil { 1090 return false 1091 } 1092 1093 cl, _ := expr.(*ast.CompositeLit) 1094 if cl == nil { 1095 return false 1096 } 1097 1098 at, _ := cl.Type.(*ast.ArrayType) 1099 if at == nil { 1100 return false 1101 } 1102 1103 return replaceNode(parent, bad, at) 1104} 1105 1106// precedingToken scans src to find the token preceding pos. 1107func precedingToken(pos token.Pos, tok *token.File, src []byte) token.Token { 1108 s := &scanner.Scanner{} 1109 s.Init(tok, src, nil, 0) 1110 1111 var lastTok token.Token 1112 for { 1113 p, t, _ := s.Scan() 1114 if t == token.EOF || p >= pos { 1115 break 1116 } 1117 1118 lastTok = t 1119 } 1120 return lastTok 1121} 1122 1123// fixDeferOrGoStmt tries to parse an *ast.BadStmt into a defer or a go statement. 1124// 1125// go/parser packages a statement of the form "defer x." as an *ast.BadStmt because 1126// it does not include a call expression. This means that go/types skips type-checking 1127// this statement entirely, and we can't use the type information when completing. 1128// Here, we try to generate a fake *ast.DeferStmt or *ast.GoStmt to put into the AST, 1129// instead of the *ast.BadStmt. 1130func fixDeferOrGoStmt(bad *ast.BadStmt, parent ast.Node, tok *token.File, src []byte) bool { 1131 // Check if we have a bad statement containing either a "go" or "defer". 1132 s := &scanner.Scanner{} 1133 s.Init(tok, src, nil, 0) 1134 1135 var ( 1136 pos token.Pos 1137 tkn token.Token 1138 ) 1139 for { 1140 if tkn == token.EOF { 1141 return false 1142 } 1143 if pos >= bad.From { 1144 break 1145 } 1146 pos, tkn, _ = s.Scan() 1147 } 1148 1149 var stmt ast.Stmt 1150 switch tkn { 1151 case token.DEFER: 1152 stmt = &ast.DeferStmt{ 1153 Defer: pos, 1154 } 1155 case token.GO: 1156 stmt = &ast.GoStmt{ 1157 Go: pos, 1158 } 1159 default: 1160 return false 1161 } 1162 1163 var ( 1164 from, to, last token.Pos 1165 lastToken token.Token 1166 braceDepth int 1167 phantomSelectors []token.Pos 1168 ) 1169FindTo: 1170 for { 1171 to, tkn, _ = s.Scan() 1172 1173 if from == token.NoPos { 1174 from = to 1175 } 1176 1177 switch tkn { 1178 case token.EOF: 1179 break FindTo 1180 case token.SEMICOLON: 1181 // If we aren't in nested braces, end of statement means 1182 // end of expression. 1183 if braceDepth == 0 { 1184 break FindTo 1185 } 1186 case token.LBRACE: 1187 braceDepth++ 1188 } 1189 1190 // This handles the common dangling selector case. For example in 1191 // 1192 // defer fmt. 1193 // y := 1 1194 // 1195 // we notice the dangling period and end our expression. 1196 // 1197 // If the previous token was a "." and we are looking at a "}", 1198 // the period is likely a dangling selector and needs a phantom 1199 // "_". Likewise if the current token is on a different line than 1200 // the period, the period is likely a dangling selector. 1201 if lastToken == token.PERIOD && (tkn == token.RBRACE || tok.Line(to) > tok.Line(last)) { 1202 // Insert phantom "_" selector after the dangling ".". 1203 phantomSelectors = append(phantomSelectors, last+1) 1204 // If we aren't in a block then end the expression after the ".". 1205 if braceDepth == 0 { 1206 to = last + 1 1207 break 1208 } 1209 } 1210 1211 lastToken = tkn 1212 last = to 1213 1214 switch tkn { 1215 case token.RBRACE: 1216 braceDepth-- 1217 if braceDepth <= 0 { 1218 if braceDepth == 0 { 1219 // +1 to include the "}" itself. 1220 to += 1 1221 } 1222 break FindTo 1223 } 1224 } 1225 } 1226 1227 if !from.IsValid() || tok.Offset(from) >= len(src) { 1228 return false 1229 } 1230 1231 if !to.IsValid() || tok.Offset(to) >= len(src) { 1232 return false 1233 } 1234 1235 // Insert any phantom selectors needed to prevent dangling "." from messing 1236 // up the AST. 1237 exprBytes := make([]byte, 0, int(to-from)+len(phantomSelectors)) 1238 for i, b := range src[tok.Offset(from):tok.Offset(to)] { 1239 if len(phantomSelectors) > 0 && from+token.Pos(i) == phantomSelectors[0] { 1240 exprBytes = append(exprBytes, '_') 1241 phantomSelectors = phantomSelectors[1:] 1242 } 1243 exprBytes = append(exprBytes, b) 1244 } 1245 1246 if len(phantomSelectors) > 0 { 1247 exprBytes = append(exprBytes, '_') 1248 } 1249 1250 expr, err := parseExpr(from, exprBytes) 1251 if err != nil { 1252 return false 1253 } 1254 1255 // Package the expression into a fake *ast.CallExpr and re-insert 1256 // into the function. 1257 call := &ast.CallExpr{ 1258 Fun: expr, 1259 Lparen: to, 1260 Rparen: to, 1261 } 1262 1263 switch stmt := stmt.(type) { 1264 case *ast.DeferStmt: 1265 stmt.Call = call 1266 case *ast.GoStmt: 1267 stmt.Call = call 1268 } 1269 1270 return replaceNode(parent, bad, stmt) 1271} 1272 1273// parseStmt parses the statement in src and updates its position to 1274// start at pos. 1275func parseStmt(pos token.Pos, src []byte) (ast.Stmt, error) { 1276 // Wrap our expression to make it a valid Go file we can pass to ParseFile. 1277 fileSrc := bytes.Join([][]byte{ 1278 []byte("package fake;func _(){"), 1279 src, 1280 []byte("}"), 1281 }, nil) 1282 1283 // Use ParseFile instead of ParseExpr because ParseFile has 1284 // best-effort behavior, whereas ParseExpr fails hard on any error. 1285 fakeFile, err := parser.ParseFile(token.NewFileSet(), "", fileSrc, 0) 1286 if fakeFile == nil { 1287 return nil, errors.Errorf("error reading fake file source: %v", err) 1288 } 1289 1290 // Extract our expression node from inside the fake file. 1291 if len(fakeFile.Decls) == 0 { 1292 return nil, errors.Errorf("error parsing fake file: %v", err) 1293 } 1294 1295 fakeDecl, _ := fakeFile.Decls[0].(*ast.FuncDecl) 1296 if fakeDecl == nil || len(fakeDecl.Body.List) == 0 { 1297 return nil, errors.Errorf("no statement in %s: %v", src, err) 1298 } 1299 1300 stmt := fakeDecl.Body.List[0] 1301 1302 // parser.ParseFile returns undefined positions. 1303 // Adjust them for the current file. 1304 offsetPositions(stmt, pos-1-(stmt.Pos()-1)) 1305 1306 return stmt, nil 1307} 1308 1309// parseExpr parses the expression in src and updates its position to 1310// start at pos. 1311func parseExpr(pos token.Pos, src []byte) (ast.Expr, error) { 1312 stmt, err := parseStmt(pos, src) 1313 if err != nil { 1314 return nil, err 1315 } 1316 1317 exprStmt, ok := stmt.(*ast.ExprStmt) 1318 if !ok { 1319 return nil, errors.Errorf("no expr in %s: %v", src, err) 1320 } 1321 1322 return exprStmt.X, nil 1323} 1324 1325var tokenPosType = reflect.TypeOf(token.NoPos) 1326 1327// offsetPositions applies an offset to the positions in an ast.Node. 1328func offsetPositions(n ast.Node, offset token.Pos) { 1329 ast.Inspect(n, func(n ast.Node) bool { 1330 if n == nil { 1331 return false 1332 } 1333 1334 v := reflect.ValueOf(n).Elem() 1335 1336 switch v.Kind() { 1337 case reflect.Struct: 1338 for i := 0; i < v.NumField(); i++ { 1339 f := v.Field(i) 1340 if f.Type() != tokenPosType { 1341 continue 1342 } 1343 1344 if !f.CanSet() { 1345 continue 1346 } 1347 1348 f.SetInt(f.Int() + int64(offset)) 1349 } 1350 } 1351 1352 return true 1353 }) 1354} 1355 1356// replaceNode updates parent's child oldChild to be newChild. It 1357// returns whether it replaced successfully. 1358func replaceNode(parent, oldChild, newChild ast.Node) bool { 1359 if parent == nil || oldChild == nil || newChild == nil { 1360 return false 1361 } 1362 1363 parentVal := reflect.ValueOf(parent).Elem() 1364 if parentVal.Kind() != reflect.Struct { 1365 return false 1366 } 1367 1368 newChildVal := reflect.ValueOf(newChild) 1369 1370 tryReplace := func(v reflect.Value) bool { 1371 if !v.CanSet() || !v.CanInterface() { 1372 return false 1373 } 1374 1375 // If the existing value is oldChild, we found our child. Make 1376 // sure our newChild is assignable and then make the swap. 1377 if v.Interface() == oldChild && newChildVal.Type().AssignableTo(v.Type()) { 1378 v.Set(newChildVal) 1379 return true 1380 } 1381 1382 return false 1383 } 1384 1385 // Loop over parent's struct fields. 1386 for i := 0; i < parentVal.NumField(); i++ { 1387 f := parentVal.Field(i) 1388 1389 switch f.Kind() { 1390 // Check interface and pointer fields. 1391 case reflect.Interface, reflect.Ptr: 1392 if tryReplace(f) { 1393 return true 1394 } 1395 1396 // Search through any slice fields. 1397 case reflect.Slice: 1398 for i := 0; i < f.Len(); i++ { 1399 if tryReplace(f.Index(i)) { 1400 return true 1401 } 1402 } 1403 } 1404 } 1405 1406 return false 1407} 1408