1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package cache
6
7import (
8	"bytes"
9	"context"
10	"fmt"
11	"go/ast"
12	"go/parser"
13	"go/scanner"
14	"go/token"
15	"go/types"
16	"reflect"
17	"strconv"
18	"strings"
19
20	"golang.org/x/tools/internal/event"
21	"golang.org/x/tools/internal/lsp/debug/tag"
22	"golang.org/x/tools/internal/lsp/diff"
23	"golang.org/x/tools/internal/lsp/diff/myers"
24	"golang.org/x/tools/internal/lsp/protocol"
25	"golang.org/x/tools/internal/lsp/source"
26	"golang.org/x/tools/internal/memoize"
27	"golang.org/x/tools/internal/span"
28	errors "golang.org/x/xerrors"
29)
30
// parseKey uniquely identifies a parsed Go file.
type parseKey struct {
	// file is the identity of the file contents being parsed.
	file source.FileIdentity
	// mode determines how much of the file is parsed.
	mode source.ParseMode
}
36
// parseGoHandle pairs a memoized parse operation with the file handle
// and mode it was created for.
type parseGoHandle struct {
	// handle is the memoized computation whose result is a *parseGoData.
	handle *memoize.Handle
	file   source.FileHandle
	mode   source.ParseMode
}
42
// parseGoData is the result of parsing a single Go file, stored in the
// memoize cache.
type parseGoData struct {
	parsed *source.ParsedGoFile

	// If true, we adjusted the AST to make it type check better, and
	// it may not match the source code.
	fixed bool
	err   error // any other errors
}
51
52func (s *snapshot) parseGoHandle(ctx context.Context, fh source.FileHandle, mode source.ParseMode) *parseGoHandle {
53	key := parseKey{
54		file: fh.FileIdentity(),
55		mode: mode,
56	}
57	if pgh := s.getGoFile(key); pgh != nil {
58		return pgh
59	}
60	parseHandle := s.generation.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
61		snapshot := arg.(*snapshot)
62		return parseGo(ctx, snapshot.view.session.cache.fset, fh, mode)
63	}, nil)
64
65	pgh := &parseGoHandle{
66		handle: parseHandle,
67		file:   fh,
68		mode:   mode,
69	}
70	return s.addGoFile(key, pgh)
71}
72
73func (pgh *parseGoHandle) String() string {
74	return pgh.File().URI().Filename()
75}
76
// File returns the file handle being parsed.
func (pgh *parseGoHandle) File() source.FileHandle {
	return pgh.file
}
80
// Mode returns the parse mode this handle was created with.
func (pgh *parseGoHandle) Mode() source.ParseMode {
	return pgh.mode
}
84
85func (s *snapshot) ParseGo(ctx context.Context, fh source.FileHandle, mode source.ParseMode) (*source.ParsedGoFile, error) {
86	pgh := s.parseGoHandle(ctx, fh, mode)
87	pgf, _, err := s.parseGo(ctx, pgh)
88	return pgf, err
89}
90
91func (s *snapshot) parseGo(ctx context.Context, pgh *parseGoHandle) (*source.ParsedGoFile, bool, error) {
92	if pgh.mode == source.ParseExported {
93		panic("only type checking should use Exported")
94	}
95	d, err := pgh.handle.Get(ctx, s.generation, s)
96	if err != nil {
97		return nil, false, err
98	}
99	data := d.(*parseGoData)
100	return data.parsed, data.fixed, data.err
101}
102
// astCacheKey uniquely identifies the AST cache for one file within a
// type-checked package.
type astCacheKey struct {
	pkg packageHandleKey
	uri span.URI
}
107
// astCacheData returns the AST cache for the file containing pos within
// spkg, computing and memoizing it on first use.
func (s *snapshot) astCacheData(ctx context.Context, spkg source.Package, pos token.Pos) (*astCacheData, error) {
	pkg := spkg.(*pkg)
	pkgHandle := s.getPackage(pkg.m.id, pkg.mode)
	if pkgHandle == nil {
		return nil, fmt.Errorf("could not reconstruct package handle for %v", pkg.m.id)
	}
	// Find the token.File containing pos so we can cache per file.
	tok := s.FileSet().File(pos)
	if tok == nil {
		return nil, fmt.Errorf("no file for pos %v", pos)
	}
	pgf, err := pkg.File(span.URIFromPath(tok.Name()))
	if err != nil {
		return nil, err
	}
	// Keyed on the package handle key and file URI, so the cached data
	// is recomputed when the package changes.
	astHandle := s.generation.Bind(astCacheKey{pkgHandle.key, pgf.URI}, func(ctx context.Context, arg memoize.Arg) interface{} {
		snapshot := arg.(*snapshot)
		return buildASTCache(ctx, snapshot, pgf)
	}, nil)

	d, err := astHandle.Get(ctx, s.generation, s)
	if err != nil {
		return nil, err
	}
	data := d.(*astCacheData)
	if data.err != nil {
		return nil, data.err
	}
	return data, nil
}
137
138func (s *snapshot) PosToDecl(ctx context.Context, spkg source.Package, pos token.Pos) (ast.Decl, error) {
139	data, err := s.astCacheData(ctx, spkg, pos)
140	if err != nil {
141		return nil, err
142	}
143	return data.posToDecl[pos], nil
144}
145
146func (s *snapshot) PosToField(ctx context.Context, spkg source.Package, pos token.Pos) (*ast.Field, error) {
147	data, err := s.astCacheData(ctx, spkg, pos)
148	if err != nil {
149		return nil, err
150	}
151	return data.posToField[pos], nil
152}
153
// astCacheData holds per-file maps from token positions back to the
// syntax that encloses them.
type astCacheData struct {
	// err records a failure to build the cache.
	err error

	// posToDecl maps a name's position to its enclosing declaration.
	posToDecl  map[token.Pos]ast.Decl
	// posToField maps a position to the *ast.Field enclosing it.
	posToField map[token.Pos]*ast.Field
}
160
// buildASTCache builds caches to aid in quickly going from the typed
// world to the syntactic world: maps from name/field positions to the
// enclosing declaration and *ast.Field.
func buildASTCache(ctx context.Context, snapshot *snapshot, pgf *source.ParsedGoFile) *astCacheData {
	var (
		// path contains all ancestors, including n.
		path []ast.Node
		// decls contains all ancestors that are decls.
		decls []ast.Decl
	)

	data := &astCacheData{
		posToDecl:  make(map[token.Pos]ast.Decl),
		posToField: make(map[token.Pos]*ast.Field),
	}

	ast.Inspect(pgf.File, func(n ast.Node) bool {
		if n == nil {
			// Leaving a subtree: pop it off the ancestor stack, and off
			// the decl stack too if it was a decl.
			lastP := path[len(path)-1]
			path = path[:len(path)-1]
			if len(decls) > 0 && decls[len(decls)-1] == lastP {
				decls = decls[:len(decls)-1]
			}
			return false
		}

		path = append(path, n)

		switch n := n.(type) {
		case *ast.Field:
			// addField maps f's position to this field and to the
			// nearest enclosing declaration, if any.
			addField := func(f ast.Node) {
				if f.Pos().IsValid() {
					data.posToField[f.Pos()] = n
					if len(decls) > 0 {
						data.posToDecl[f.Pos()] = decls[len(decls)-1]
					}
				}
			}

			// Add mapping for *ast.Field itself. This handles embedded
			// fields which have no associated *ast.Ident name.
			addField(n)

			// Add mapping for each field name since you can have
			// multiple names for the same type expression.
			for _, name := range n.Names {
				addField(name)
			}

			// Also map "X" in "...X" to the containing *ast.Field. This
			// makes it easy to format variadic signature params
			// properly.
			if elips, ok := n.Type.(*ast.Ellipsis); ok && elips.Elt != nil {
				addField(elips.Elt)
			}
		case *ast.FuncDecl:
			decls = append(decls, n)

			if n.Name != nil && n.Name.Pos().IsValid() {
				data.posToDecl[n.Name.Pos()] = n
			}
		case *ast.GenDecl:
			decls = append(decls, n)

			// Map each declared type/const/var name back to the
			// containing GenDecl.
			for _, spec := range n.Specs {
				switch spec := spec.(type) {
				case *ast.TypeSpec:
					if spec.Name != nil && spec.Name.Pos().IsValid() {
						data.posToDecl[spec.Name.Pos()] = n
					}
				case *ast.ValueSpec:
					for _, id := range spec.Names {
						if id != nil && id.Pos().IsValid() {
							data.posToDecl[id.Pos()] = n
						}
					}
				}
			}
		}

		return true
	})

	return data
}
245
// parseGo reads and parses the file backed by fh at the given mode. If
// the file has syntax errors, it attempts to repair them — first by
// fixing the AST in place, then by iteratively rewriting the source and
// re-parsing — so the file can still be used for type checking. The
// returned parseGoData records whether such fixing occurred.
func parseGo(ctx context.Context, fset *token.FileSet, fh source.FileHandle, mode source.ParseMode) *parseGoData {
	ctx, done := event.Start(ctx, "cache.parseGo", tag.File.Of(fh.URI().Filename()))
	defer done()

	if fh.Kind() != source.Go {
		return &parseGoData{err: errors.Errorf("cannot parse non-Go file %s", fh.URI())}
	}
	src, err := fh.Read()
	if err != nil {
		return &parseGoData{err: err}
	}

	parserMode := parser.AllErrors | parser.ParseComments
	if mode == source.ParseHeader {
		// Header mode needs only the package clause and imports.
		parserMode = parser.ImportsOnly | parser.ParseComments
	}

	file, err := parser.ParseFile(fset, fh.URI().Filename(), src, parserMode)
	var parseErr scanner.ErrorList
	if err != nil {
		// We passed a byte slice, so the only possible error is a parse error.
		parseErr = err.(scanner.ErrorList)
	}

	tok := fset.File(file.Pos())
	if tok == nil {
		// file.Pos is the location of the package declaration. If there was
		// none, we can't find the token.File that ParseFile created, and we
		// have no choice but to recreate it.
		tok = fset.AddFile(fh.URI().Filename(), -1, len(src))
		tok.SetLinesForContent(src)
	}

	fixed := false
	// If there were parse errors, attempt to fix them up.
	if parseErr != nil {
		// Fix any badly parsed parts of the AST.
		fixed = fixAST(ctx, file, tok, src)

		// One source fix can expose another, so iterate — but bound the
		// iterations to guard against a fix loop.
		for i := 0; i < 10; i++ {
			// Fix certain syntax errors that render the file unparseable.
			newSrc := fixSrc(file, tok, src)
			if newSrc == nil {
				break
			}

			// If we thought there was something to fix 10 times in a row,
			// it is likely we got stuck in a loop somehow. Log out a diff
			// of the last changes we made to aid in debugging.
			if i == 9 {
				edits, err := myers.ComputeEdits(fh.URI(), string(src), string(newSrc))
				if err != nil {
					event.Error(ctx, "error generating fixSrc diff", err, tag.File.Of(tok.Name()))
				} else {
					unified := diff.ToUnified("before", "after", string(src), edits)
					event.Log(ctx, fmt.Sprintf("fixSrc loop - last diff:\n%v", unified), tag.File.Of(tok.Name()))
				}
			}

			newFile, _ := parser.ParseFile(fset, fh.URI().Filename(), newSrc, parserMode)
			if newFile != nil {
				// Maintain the original parseError so we don't try formatting the doctored file.
				file = newFile
				src = newSrc
				tok = fset.File(file.Pos())

				fixed = fixAST(ctx, file, tok, src)
			}
		}
	}

	return &parseGoData{
		parsed: &source.ParsedGoFile{
			URI:  fh.URI(),
			Mode: mode,
			Src:  src,
			File: file,
			Tok:  tok,
			Mapper: &protocol.ColumnMapper{
				URI:       fh.URI(),
				Converter: span.NewTokenConverter(fset, tok),
				Content:   src,
			},
			ParseErr: parseErr,
		},
		fixed: fixed,
	}
}
334
// An unexportedFilter removes as much unexported AST from a set of Files as possible.
type unexportedFilter struct {
	// uses records unexported names that must be retained because a
	// kept declaration refers to them.
	uses map[string]bool
}
339
// Filter records uses of unexported identifiers and filters out all other
// unexported declarations.
func (f *unexportedFilter) Filter(files []*ast.File) {
	// Iterate to fixed point -- unexported types can include other unexported types.
	oldLen := len(f.uses)
	for {
		for _, file := range files {
			f.recordUses(file)
		}
		// Stop once a full pass over all files adds no new uses.
		if len(f.uses) == oldLen {
			break
		}
		oldLen = len(f.uses)
	}

	// Drop declarations that are unexported and unused, then strip
	// everything type checking doesn't need.
	for _, file := range files {
		var newDecls []ast.Decl
		for _, decl := range file.Decls {
			if f.filterDecl(decl) {
				newDecls = append(newDecls, decl)
			}
		}
		file.Decls = newDecls
		file.Scope = nil
		file.Unresolved = nil
		file.Comments = nil
		trimAST(file)
	}
}
369
370func (f *unexportedFilter) keep(ident *ast.Ident) bool {
371	return ast.IsExported(ident.Name) || f.uses[ident.Name]
372}
373
// filterDecl reports whether decl should be kept, filtering the specs
// of a kept GenDecl in place.
func (f *unexportedFilter) filterDecl(decl ast.Decl) bool {
	switch decl := decl.(type) {
	case *ast.FuncDecl:
		// Drop methods whose receiver type was dropped.
		if ident := recvIdent(decl); ident != nil && !f.keep(ident) {
			return false
		}
		return f.keep(decl.Name)
	case *ast.GenDecl:
		if decl.Tok == token.CONST {
			// Constants can involve iota, and iota is hard to deal with.
			return true
		}
		var newSpecs []ast.Spec
		for _, spec := range decl.Specs {
			if f.filterSpec(spec) {
				newSpecs = append(newSpecs, spec)
			}
		}
		decl.Specs = newSpecs
		return len(newSpecs) != 0
	case *ast.BadDecl:
		return false
	}
	panic(fmt.Sprintf("unknown ast.Decl %T", decl))
}
399
400func (f *unexportedFilter) filterSpec(spec ast.Spec) bool {
401	switch spec := spec.(type) {
402	case *ast.ImportSpec:
403		return true
404	case *ast.ValueSpec:
405		var newNames []*ast.Ident
406		for _, name := range spec.Names {
407			if f.keep(name) {
408				newNames = append(newNames, name)
409			}
410		}
411		spec.Names = newNames
412		return len(spec.Names) != 0
413	case *ast.TypeSpec:
414		if !f.keep(spec.Name) {
415			return false
416		}
417		switch typ := spec.Type.(type) {
418		case *ast.StructType:
419			f.filterFieldList(typ.Fields)
420		case *ast.InterfaceType:
421			f.filterFieldList(typ.Methods)
422		}
423		return true
424	}
425	panic(fmt.Sprintf("unknown ast.Spec %T", spec))
426}
427
428func (f *unexportedFilter) filterFieldList(fields *ast.FieldList) {
429	var newFields []*ast.Field
430	for _, field := range fields.List {
431		if len(field.Names) == 0 {
432			// Keep embedded fields: they can export methods and fields.
433			newFields = append(newFields, field)
434		}
435		for _, name := range field.Names {
436			if f.keep(name) {
437				newFields = append(newFields, field)
438				break
439			}
440		}
441	}
442	fields.List = newFields
443}
444
// recordUses records the unexported identifiers referenced by the
// declarations in file that will be retained.
func (f *unexportedFilter) recordUses(file *ast.File) {
	for _, decl := range file.Decls {
		switch decl := decl.(type) {
		case *ast.FuncDecl:
			// Ignore methods on dropped types.
			if ident := recvIdent(decl); ident != nil && !f.keep(ident) {
				break
			}
			// Ignore functions with dropped names.
			if !f.keep(decl.Name) {
				break
			}
			f.recordFuncType(decl.Type)
		case *ast.GenDecl:
			for _, spec := range decl.Specs {
				switch spec := spec.(type) {
				case *ast.ValueSpec:
					for i, name := range spec.Names {
						// Don't mess with constants -- iota is hard.
						if f.keep(name) || decl.Tok == token.CONST {
							f.recordIdents(spec.Type)
							if len(spec.Values) > i {
								f.recordIdents(spec.Values[i])
							}
						}
					}
				case *ast.TypeSpec:
					switch typ := spec.Type.(type) {
					case *ast.StructType:
						f.recordFieldUses(false, typ.Fields)
					case *ast.InterfaceType:
						f.recordFieldUses(false, typ.Methods)
					}
				}
			}
		}
	}
}
483
484// recvIdent returns the identifier of a method receiver, e.g. *int.
485func recvIdent(decl *ast.FuncDecl) *ast.Ident {
486	if decl.Recv == nil || len(decl.Recv.List) == 0 {
487		return nil
488	}
489	x := decl.Recv.List[0].Type
490	if star, ok := x.(*ast.StarExpr); ok {
491		x = star.X
492	}
493	if ident, ok := x.(*ast.Ident); ok {
494		return ident
495	}
496	return nil
497}
498
499// recordIdents records unexported identifiers in an Expr in uses.
500// These may be types, e.g. in map[key]value, function names, e.g. in foo(),
501// or simple variable references. References that will be discarded, such
502// as those in function literal bodies, are ignored.
503func (f *unexportedFilter) recordIdents(x ast.Expr) {
504	ast.Inspect(x, func(n ast.Node) bool {
505		if n == nil {
506			return false
507		}
508		if complit, ok := n.(*ast.CompositeLit); ok {
509			// We clear out composite literal contents; just record their type.
510			f.recordIdents(complit.Type)
511			return false
512		}
513		if flit, ok := n.(*ast.FuncLit); ok {
514			f.recordFuncType(flit.Type)
515			return false
516		}
517		if ident, ok := n.(*ast.Ident); ok && !ast.IsExported(ident.Name) {
518			f.uses[ident.Name] = true
519		}
520		return true
521	})
522}
523
// recordFuncType records the types mentioned by a function type.
func (f *unexportedFilter) recordFuncType(x *ast.FuncType) {
	// Both parameter and result types must survive filtering for the
	// signature to type check.
	f.recordFieldUses(true, x.Params)
	f.recordFieldUses(true, x.Results)
}
529
// recordFieldUses records unexported identifiers used in fields, which may be
// struct members, interface members, or function parameter/results.
// isParams indicates function parameters/results, whose types must
// always be retained.
func (f *unexportedFilter) recordFieldUses(isParams bool, fields *ast.FieldList) {
	if fields == nil {
		return
	}
	for _, field := range fields.List {
		if isParams {
			// Parameter types of retained functions need to be retained.
			f.recordIdents(field.Type)
			continue
		}
		if ft, ok := field.Type.(*ast.FuncType); ok {
			// Function declarations in interfaces need all their types retained.
			f.recordFuncType(ft)
			continue
		}
		if len(field.Names) == 0 {
			// Embedded fields might contribute exported names.
			f.recordIdents(field.Type)
		}
		for _, name := range field.Names {
			// We only need normal fields if they're exported.
			if ast.IsExported(name.Name) {
				f.recordIdents(field.Type)
				break
			}
		}
	}
}
560
561// ProcessErrors records additional uses from errors, returning the new uses
562// and any unexpected errors.
563func (f *unexportedFilter) ProcessErrors(errors []types.Error) (map[string]bool, []types.Error) {
564	var unexpected []types.Error
565	missing := map[string]bool{}
566	for _, err := range errors {
567		if strings.Contains(err.Msg, "missing return") {
568			continue
569		}
570		const undeclared = "undeclared name: "
571		if strings.HasPrefix(err.Msg, undeclared) {
572			missing[strings.TrimPrefix(err.Msg, undeclared)] = true
573			f.uses[strings.TrimPrefix(err.Msg, undeclared)] = true
574			continue
575		}
576		unexpected = append(unexpected, err)
577	}
578	return missing, unexpected
579}
580
581// trimAST clears any part of the AST not relevant to type checking
582// expressions at pos.
583func trimAST(file *ast.File) {
584	ast.Inspect(file, func(n ast.Node) bool {
585		if n == nil {
586			return false
587		}
588		switch n := n.(type) {
589		case *ast.FuncDecl:
590			n.Body = nil
591		case *ast.BlockStmt:
592			n.List = nil
593		case *ast.CaseClause:
594			n.Body = nil
595		case *ast.CommClause:
596			n.Body = nil
597		case *ast.CompositeLit:
598			// types.Info.Types for long slice/array literals are particularly
599			// expensive. Try to clear them out.
600			at, ok := n.Type.(*ast.ArrayType)
601			if !ok {
602				// Composite literal. No harm removing all its fields.
603				n.Elts = nil
604				break
605			}
606			// Removing the elements from an ellipsis array changes its type.
607			// Try to set the length explicitly so we can continue.
608			if _, ok := at.Len.(*ast.Ellipsis); ok {
609				length, ok := arrayLength(n)
610				if !ok {
611					break
612				}
613				at.Len = &ast.BasicLit{
614					Kind:     token.INT,
615					Value:    fmt.Sprint(length),
616					ValuePos: at.Len.Pos(),
617				}
618			}
619			n.Elts = nil
620		}
621		return true
622	})
623}
624
625// arrayLength returns the length of some simple forms of ellipsis array literal.
626// Notably, it handles the tables in golang.org/x/text.
627func arrayLength(array *ast.CompositeLit) (int, bool) {
628	litVal := func(expr ast.Expr) (int, bool) {
629		lit, ok := expr.(*ast.BasicLit)
630		if !ok {
631			return 0, false
632		}
633		val, err := strconv.ParseInt(lit.Value, 10, 64)
634		if err != nil {
635			return 0, false
636		}
637		return int(val), true
638	}
639	largestKey := -1
640	for _, elt := range array.Elts {
641		kve, ok := elt.(*ast.KeyValueExpr)
642		if !ok {
643			continue
644		}
645		switch key := kve.Key.(type) {
646		case *ast.BasicLit:
647			if val, ok := litVal(key); ok && largestKey < val {
648				largestKey = val
649			}
650		case *ast.BinaryExpr:
651			// golang.org/x/text uses subtraction (and only subtraction) in its indices.
652			if key.Op != token.SUB {
653				break
654			}
655			x, ok := litVal(key.X)
656			if !ok {
657				break
658			}
659			y, ok := litVal(key.Y)
660			if !ok {
661				break
662			}
663			if val := x - y; largestKey < val {
664				largestKey = val
665			}
666		}
667	}
668	if largestKey != -1 {
669		return largestKey + 1, true
670	}
671	return len(array.Elts), true
672}
673
674// fixAST inspects the AST and potentially modifies any *ast.BadStmts so that it can be
675// type-checked more effectively.
676func fixAST(ctx context.Context, n ast.Node, tok *token.File, src []byte) (fixed bool) {
677	var err error
678	walkASTWithParent(n, func(n, parent ast.Node) bool {
679		switch n := n.(type) {
680		case *ast.BadStmt:
681			if fixed = fixDeferOrGoStmt(n, parent, tok, src); fixed {
682				// Recursively fix in our fixed node.
683				_ = fixAST(ctx, parent, tok, src)
684			} else {
685				err = errors.Errorf("unable to parse defer or go from *ast.BadStmt: %v", err)
686			}
687			return false
688		case *ast.BadExpr:
689			if fixed = fixArrayType(n, parent, tok, src); fixed {
690				// Recursively fix in our fixed node.
691				_ = fixAST(ctx, parent, tok, src)
692				return false
693			}
694
695			// Fix cases where parser interprets if/for/switch "init"
696			// statement as "cond" expression, e.g.:
697			//
698			//   // "i := foo" is init statement, not condition.
699			//   for i := foo
700			//
701			fixInitStmt(n, parent, tok, src)
702
703			return false
704		case *ast.SelectorExpr:
705			// Fix cases where a keyword prefix results in a phantom "_" selector, e.g.:
706			//
707			//   foo.var<> // want to complete to "foo.variance"
708			//
709			fixPhantomSelector(n, tok, src)
710			return true
711
712		case *ast.BlockStmt:
713			switch parent.(type) {
714			case *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.SelectStmt:
715				// Adjust closing curly brace of empty switch/select
716				// statements so we can complete inside them.
717				fixEmptySwitch(n, tok, src)
718			}
719
720			return true
721		default:
722			return true
723		}
724	})
725	return fixed
726}
727
728// walkASTWithParent walks the AST rooted at n. The semantics are
729// similar to ast.Inspect except it does not call f(nil).
730func walkASTWithParent(n ast.Node, f func(n ast.Node, parent ast.Node) bool) {
731	var ancestors []ast.Node
732	ast.Inspect(n, func(n ast.Node) (recurse bool) {
733		defer func() {
734			if recurse {
735				ancestors = append(ancestors, n)
736			}
737		}()
738
739		if n == nil {
740			ancestors = ancestors[:len(ancestors)-1]
741			return false
742		}
743
744		var parent ast.Node
745		if len(ancestors) > 0 {
746			parent = ancestors[len(ancestors)-1]
747		}
748
749		return f(n, parent)
750	})
751}
752
// fixSrc attempts to modify the file's source code to fix certain
// syntax errors that leave the rest of the file unparsed. It returns
// the rewritten source, or nil if no fix applied. At most one fix is
// performed per call; the caller re-parses and retries.
func fixSrc(f *ast.File, tok *token.File, src []byte) (newSrc []byte) {
	walkASTWithParent(f, func(n, parent ast.Node) bool {
		// A fix was already produced; stop descending.
		if newSrc != nil {
			return false
		}

		switch n := n.(type) {
		case *ast.BlockStmt:
			newSrc = fixMissingCurlies(f, n, parent, tok, src)
		case *ast.SelectorExpr:
			newSrc = fixDanglingSelector(n, tok, src)
		}

		return newSrc == nil
	})

	return newSrc
}
773
// fixMissingCurlies adds in curly braces for block statements that
// are missing curly braces. For example:
//
//   if foo
//
// becomes
//
//   if foo {}
//
// It returns the rewritten source, or nil if no fix was needed.
func fixMissingCurlies(f *ast.File, b *ast.BlockStmt, parent ast.Node, tok *token.File, src []byte) []byte {
	// If the "{" is already in the source code, there isn't anything to
	// fix since we aren't missing curlies.
	if b.Lbrace.IsValid() {
		braceOffset := tok.Offset(b.Lbrace)
		if braceOffset < len(src) && src[braceOffset] == '{' {
			return nil
		}
	}

	parentLine := tok.Line(parent.Pos())

	if parentLine >= tok.LineCount() {
		// If we are the last line in the file, no need to fix anything.
		return nil
	}

	// Insert curlies at the end of parent's starting line. The parent
	// is the statement that contains the block, e.g. *ast.IfStmt. The
	// block's Pos()/End() can't be relied upon because they are based
	// on the (missing) curly braces. We assume the statement is a
	// single line for now and try sticking the curly braces at the end.
	insertPos := tok.LineStart(parentLine+1) - 1

	// Scootch position backwards until it's not in a comment. For example:
	//
	// if foo<> // some amazing comment |
	// someOtherCode()
	//
	// insertPos will be located at "|", so we back it out of the comment.
	didSomething := true
	for didSomething {
		didSomething = false
		for _, c := range f.Comments {
			if c.Pos() < insertPos && insertPos <= c.End() {
				insertPos = c.Pos()
				didSomething = true
			}
		}
	}

	// Bail out if line doesn't end in an ident or ".". This is to avoid
	// cases like below where we end up making things worse by adding
	// curlies:
	//
	//   if foo &&
	//     bar<>
	switch precedingToken(insertPos, tok, src) {
	case token.IDENT, token.PERIOD:
		// ok
	default:
		return nil
	}

	var buf bytes.Buffer
	buf.Grow(len(src) + 3)
	buf.Write(src[:tok.Offset(insertPos)])

	// Detect if we need to insert a semicolon to fix "for" loop situations like:
	//
	//   for i := foo(); foo<>
	//
	// Just adding curlies is not sufficient to make things parse well.
	if fs, ok := parent.(*ast.ForStmt); ok {
		if _, ok := fs.Cond.(*ast.BadExpr); !ok {
			if xs, ok := fs.Post.(*ast.ExprStmt); ok {
				if _, ok := xs.X.(*ast.BadExpr); ok {
					buf.WriteByte(';')
				}
			}
		}
	}

	// Insert "{}" at insertPos.
	buf.WriteByte('{')
	buf.WriteByte('}')
	buf.Write(src[tok.Offset(insertPos):])
	return buf.Bytes()
}
861
862// fixEmptySwitch moves empty switch/select statements' closing curly
863// brace down one line. This allows us to properly detect incomplete
864// "case" and "default" keywords as inside the switch statement. For
865// example:
866//
867//   switch {
868//   def<>
869//   }
870//
871// gets parsed like:
872//
873//   switch {
874//   }
875//
876// Later we manually pull out the "def" token, but we need to detect
877// that our "<>" position is inside the switch block. To do that we
878// move the curly brace so it looks like:
879//
880//   switch {
881//
882//   }
883//
884func fixEmptySwitch(body *ast.BlockStmt, tok *token.File, src []byte) {
885	// We only care about empty switch statements.
886	if len(body.List) > 0 || !body.Rbrace.IsValid() {
887		return
888	}
889
890	// If the right brace is actually in the source code at the
891	// specified position, don't mess with it.
892	braceOffset := tok.Offset(body.Rbrace)
893	if braceOffset < len(src) && src[braceOffset] == '}' {
894		return
895	}
896
897	braceLine := tok.Line(body.Rbrace)
898	if braceLine >= tok.LineCount() {
899		// If we are the last line in the file, no need to fix anything.
900		return
901	}
902
903	// Move the right brace down one line.
904	body.Rbrace = tok.LineStart(braceLine + 1)
905}
906
907// fixDanglingSelector inserts real "_" selector expressions in place
908// of phantom "_" selectors. For example:
909//
910// func _() {
911//   x.<>
912// }
913// var x struct { i int }
914//
915// To fix completion at "<>", we insert a real "_" after the "." so the
916// following declaration of "x" can be parsed and type checked
917// normally.
918func fixDanglingSelector(s *ast.SelectorExpr, tok *token.File, src []byte) []byte {
919	if !isPhantomUnderscore(s.Sel, tok, src) {
920		return nil
921	}
922
923	if !s.X.End().IsValid() {
924		return nil
925	}
926
927	// Insert directly after the selector's ".".
928	insertOffset := tok.Offset(s.X.End()) + 1
929	if src[insertOffset-1] != '.' {
930		return nil
931	}
932
933	var buf bytes.Buffer
934	buf.Grow(len(src) + 1)
935	buf.Write(src[:insertOffset])
936	buf.WriteByte('_')
937	buf.Write(src[insertOffset:])
938	return buf.Bytes()
939}
940
// fixPhantomSelector tries to fix selector expressions with phantom
// "_" selectors. In particular, we check if the selector is a
// keyword, and if so we swap in an *ast.Ident with the keyword text. For example:
//
// foo.var
//
// yields a "_" selector instead of "var" since "var" is a keyword.
// The fix lets completion treat the keyword text as a partial
// identifier (e.g. "foo.var" completing to "foo.variance").
func fixPhantomSelector(sel *ast.SelectorExpr, tok *token.File, src []byte) {
	if !isPhantomUnderscore(sel.Sel, tok, src) {
		return
	}

	// Only consider selectors directly abutting the selector ".". This
	// avoids false positives in cases like:
	//
	//   foo. // don't think "var" is our selector
	//   var bar = 123
	//
	if sel.Sel.Pos() != sel.X.End()+1 {
		return
	}

	maybeKeyword := readKeyword(sel.Sel.Pos(), tok, src)
	if maybeKeyword == "" {
		return
	}

	// Swap the phantom "_" for an ident carrying the keyword text.
	replaceNode(sel, sel.Sel, &ast.Ident{
		Name:    maybeKeyword,
		NamePos: sel.Sel.Pos(),
	})
}
973
974// isPhantomUnderscore reports whether the given ident is a phantom
975// underscore. The parser sometimes inserts phantom underscores when
976// it encounters otherwise unparseable situations.
977func isPhantomUnderscore(id *ast.Ident, tok *token.File, src []byte) bool {
978	if id == nil || id.Name != "_" {
979		return false
980	}
981
982	// Phantom underscore means the underscore is not actually in the
983	// program text.
984	offset := tok.Offset(id.Pos())
985	return len(src) <= offset || src[offset] != '_'
986}
987
// fixInitStmt fixes cases where the parser misinterprets an
// if/for/switch "init" statement as the "cond" conditional. In cases
// like "if i := 0" the user hasn't typed the semicolon yet so the
// parser is looking for the conditional expression. However, "i := 0"
// are not valid expressions, so we get a BadExpr. The fix mutates
// parent in place.
func fixInitStmt(bad *ast.BadExpr, parent ast.Node, tok *token.File, src []byte) {
	if !bad.Pos().IsValid() || !bad.End().IsValid() {
		return
	}

	// Try to extract a statement from the BadExpr.
	// Offset(End-1)+1 avoids calling Offset on a position at EOF.
	stmtBytes := src[tok.Offset(bad.Pos()) : tok.Offset(bad.End()-1)+1]
	stmt, err := parseStmt(bad.Pos(), stmtBytes)
	if err != nil {
		return
	}

	// If the parent statement doesn't already have an "init" statement,
	// move the extracted statement into the "init" field and insert a
	// dummy expression into the required "cond" field.
	switch p := parent.(type) {
	case *ast.IfStmt:
		if p.Init != nil {
			return
		}
		p.Init = stmt
		p.Cond = &ast.Ident{
			Name:    "_",
			NamePos: stmt.End(),
		}
	case *ast.ForStmt:
		if p.Init != nil {
			return
		}
		p.Init = stmt
		p.Cond = &ast.Ident{
			Name:    "_",
			NamePos: stmt.End(),
		}
	case *ast.SwitchStmt:
		if p.Init != nil {
			return
		}
		p.Init = stmt
		// A switch needs no cond; clear the tag instead.
		p.Tag = nil
	}
}
1035
1036// readKeyword reads the keyword starting at pos, if any.
1037func readKeyword(pos token.Pos, tok *token.File, src []byte) string {
1038	var kwBytes []byte
1039	for i := tok.Offset(pos); i < len(src); i++ {
1040		// Use a simplified identifier check since keywords are always lowercase ASCII.
1041		if src[i] < 'a' || src[i] > 'z' {
1042			break
1043		}
1044		kwBytes = append(kwBytes, src[i])
1045
1046		// Stop search at arbitrarily chosen too-long-for-a-keyword length.
1047		if len(kwBytes) > 15 {
1048			return ""
1049		}
1050	}
1051
1052	if kw := string(kwBytes); token.Lookup(kw).IsKeyword() {
1053		return kw
1054	}
1055
1056	return ""
1057}
1058
// fixArrayType tries to parse an *ast.BadExpr into an *ast.ArrayType.
// go/parser often turns lone array types like "[]int" into BadExprs
// if it isn't expecting a type. It reports whether the replacement
// succeeded.
func fixArrayType(bad *ast.BadExpr, parent ast.Node, tok *token.File, src []byte) bool {
	// Our expected input is a bad expression that looks like "[]someExpr".

	from := bad.Pos()
	to := bad.End()

	if !from.IsValid() || !to.IsValid() {
		return false
	}

	exprBytes := make([]byte, 0, int(to-from)+3)
	// Avoid doing tok.Offset(to) since that panics if badExpr ends at EOF.
	exprBytes = append(exprBytes, src[tok.Offset(from):tok.Offset(to-1)+1]...)
	exprBytes = bytes.TrimSpace(exprBytes)

	// If our expression ends in "]" (e.g. "[]"), add a phantom selector
	// so we can complete directly after the "[]".
	if len(exprBytes) > 0 && exprBytes[len(exprBytes)-1] == ']' {
		exprBytes = append(exprBytes, '_')
	}

	// Add "{}" to turn our ArrayType into a CompositeLit. This is to
	// handle the case of "[...]int" where we must make it a composite
	// literal to be parseable.
	exprBytes = append(exprBytes, '{', '}')

	expr, err := parseExpr(from, exprBytes)
	if err != nil {
		return false
	}

	// Unwrap the composite literal we synthesized to get back the
	// array type itself.
	cl, _ := expr.(*ast.CompositeLit)
	if cl == nil {
		return false
	}

	at, _ := cl.Type.(*ast.ArrayType)
	if at == nil {
		return false
	}

	return replaceNode(parent, bad, at)
}
1105
1106// precedingToken scans src to find the token preceding pos.
1107func precedingToken(pos token.Pos, tok *token.File, src []byte) token.Token {
1108	s := &scanner.Scanner{}
1109	s.Init(tok, src, nil, 0)
1110
1111	var lastTok token.Token
1112	for {
1113		p, t, _ := s.Scan()
1114		if t == token.EOF || p >= pos {
1115			break
1116		}
1117
1118		lastTok = t
1119	}
1120	return lastTok
1121}
1122
// fixDeferOrGoStmt tries to parse an *ast.BadStmt into a defer or a go statement.
//
// go/parser packages a statement of the form "defer x." as an *ast.BadStmt because
// it does not include a call expression. This means that go/types skips type-checking
// this statement entirely, and we can't use the type information when completing.
// Here, we try to generate a fake *ast.DeferStmt or *ast.GoStmt to put into the AST,
// instead of the *ast.BadStmt.
func fixDeferOrGoStmt(bad *ast.BadStmt, parent ast.Node, tok *token.File, src []byte) bool {
	// Check if we have a bad statement containing either a "go" or "defer".
	s := &scanner.Scanner{}
	s.Init(tok, src, nil, 0)

	// Scan forward to the first token at or after the bad statement's
	// start; that token should be the "defer" or "go" keyword itself.
	var (
		pos token.Pos
		tkn token.Token
	)
	for {
		if tkn == token.EOF {
			return false
		}
		if pos >= bad.From {
			break
		}
		pos, tkn, _ = s.Scan()
	}

	// Build the matching statement skeleton; the call expression is
	// attached at the end. Anything other than "defer"/"go" means the
	// BadStmt isn't the case we handle.
	var stmt ast.Stmt
	switch tkn {
	case token.DEFER:
		stmt = &ast.DeferStmt{
			Defer: pos,
		}
	case token.GO:
		stmt = &ast.GoStmt{
			Go: pos,
		}
	default:
		return false
	}

	// Scan the tokens after the keyword to find the extent [from, to) of
	// the deferred/go'd expression, tracking brace nesting and recording
	// positions where a phantom "_" selector must be inserted to repair
	// a dangling ".".
	var (
		from, to, last   token.Pos
		lastToken        token.Token
		braceDepth       int
		phantomSelectors []token.Pos
	)
FindTo:
	for {
		to, tkn, _ = s.Scan()

		// The first token scanned marks the start of the expression.
		if from == token.NoPos {
			from = to
		}

		switch tkn {
		case token.EOF:
			break FindTo
		case token.SEMICOLON:
			// If we aren't in nested braces, end of statement means
			// end of expression.
			if braceDepth == 0 {
				break FindTo
			}
		case token.LBRACE:
			braceDepth++
		}

		// This handles the common dangling selector case. For example in
		//
		// defer fmt.
		// y := 1
		//
		// we notice the dangling period and end our expression.
		//
		// If the previous token was a "." and we are looking at a "}",
		// the period is likely a dangling selector and needs a phantom
		// "_". Likewise if the current token is on a different line than
		// the period, the period is likely a dangling selector.
		if lastToken == token.PERIOD && (tkn == token.RBRACE || tok.Line(to) > tok.Line(last)) {
			// Insert phantom "_" selector after the dangling ".".
			phantomSelectors = append(phantomSelectors, last+1)
			// If we aren't in a block then end the expression after the ".".
			if braceDepth == 0 {
				to = last + 1
				// NOTE: this unlabeled break exits the FindTo for loop
				// (the switches above have already ended).
				break
			}
		}

		lastToken = tkn
		last = to

		switch tkn {
		case token.RBRACE:
			braceDepth--
			if braceDepth <= 0 {
				if braceDepth == 0 {
					// +1 to include the "}" itself.
					to += 1
				}
				break FindTo
			}
		}
	}

	// Give up if the computed range is invalid or runs past the end of
	// the available source.
	if !from.IsValid() || tok.Offset(from) >= len(src) {
		return false
	}

	if !to.IsValid() || tok.Offset(to) >= len(src) {
		return false
	}

	// Insert any phantom selectors needed to prevent dangling "." from messing
	// up the AST.
	exprBytes := make([]byte, 0, int(to-from)+len(phantomSelectors))
	for i, b := range src[tok.Offset(from):tok.Offset(to)] {
		if len(phantomSelectors) > 0 && from+token.Pos(i) == phantomSelectors[0] {
			exprBytes = append(exprBytes, '_')
			phantomSelectors = phantomSelectors[1:]
		}
		exprBytes = append(exprBytes, b)
	}

	// A phantom selector positioned at the very end of the range isn't
	// reached by the loop above; append its "_" now.
	if len(phantomSelectors) > 0 {
		exprBytes = append(exprBytes, '_')
	}

	expr, err := parseExpr(from, exprBytes)
	if err != nil {
		return false
	}

	// Package the expression into a fake *ast.CallExpr and re-insert
	// into the function.
	call := &ast.CallExpr{
		Fun:    expr,
		Lparen: to,
		Rparen: to,
	}

	// Attach the call to whichever statement kind was built above.
	switch stmt := stmt.(type) {
	case *ast.DeferStmt:
		stmt.Call = call
	case *ast.GoStmt:
		stmt.Call = call
	}

	return replaceNode(parent, bad, stmt)
}
1272
1273// parseStmt parses the statement in src and updates its position to
1274// start at pos.
1275func parseStmt(pos token.Pos, src []byte) (ast.Stmt, error) {
1276	// Wrap our expression to make it a valid Go file we can pass to ParseFile.
1277	fileSrc := bytes.Join([][]byte{
1278		[]byte("package fake;func _(){"),
1279		src,
1280		[]byte("}"),
1281	}, nil)
1282
1283	// Use ParseFile instead of ParseExpr because ParseFile has
1284	// best-effort behavior, whereas ParseExpr fails hard on any error.
1285	fakeFile, err := parser.ParseFile(token.NewFileSet(), "", fileSrc, 0)
1286	if fakeFile == nil {
1287		return nil, errors.Errorf("error reading fake file source: %v", err)
1288	}
1289
1290	// Extract our expression node from inside the fake file.
1291	if len(fakeFile.Decls) == 0 {
1292		return nil, errors.Errorf("error parsing fake file: %v", err)
1293	}
1294
1295	fakeDecl, _ := fakeFile.Decls[0].(*ast.FuncDecl)
1296	if fakeDecl == nil || len(fakeDecl.Body.List) == 0 {
1297		return nil, errors.Errorf("no statement in %s: %v", src, err)
1298	}
1299
1300	stmt := fakeDecl.Body.List[0]
1301
1302	// parser.ParseFile returns undefined positions.
1303	// Adjust them for the current file.
1304	offsetPositions(stmt, pos-1-(stmt.Pos()-1))
1305
1306	return stmt, nil
1307}
1308
1309// parseExpr parses the expression in src and updates its position to
1310// start at pos.
1311func parseExpr(pos token.Pos, src []byte) (ast.Expr, error) {
1312	stmt, err := parseStmt(pos, src)
1313	if err != nil {
1314		return nil, err
1315	}
1316
1317	exprStmt, ok := stmt.(*ast.ExprStmt)
1318	if !ok {
1319		return nil, errors.Errorf("no expr in %s: %v", src, err)
1320	}
1321
1322	return exprStmt.X, nil
1323}
1324
1325var tokenPosType = reflect.TypeOf(token.NoPos)
1326
1327// offsetPositions applies an offset to the positions in an ast.Node.
1328func offsetPositions(n ast.Node, offset token.Pos) {
1329	ast.Inspect(n, func(n ast.Node) bool {
1330		if n == nil {
1331			return false
1332		}
1333
1334		v := reflect.ValueOf(n).Elem()
1335
1336		switch v.Kind() {
1337		case reflect.Struct:
1338			for i := 0; i < v.NumField(); i++ {
1339				f := v.Field(i)
1340				if f.Type() != tokenPosType {
1341					continue
1342				}
1343
1344				if !f.CanSet() {
1345					continue
1346				}
1347
1348				f.SetInt(f.Int() + int64(offset))
1349			}
1350		}
1351
1352		return true
1353	})
1354}
1355
1356// replaceNode updates parent's child oldChild to be newChild. It
1357// returns whether it replaced successfully.
1358func replaceNode(parent, oldChild, newChild ast.Node) bool {
1359	if parent == nil || oldChild == nil || newChild == nil {
1360		return false
1361	}
1362
1363	parentVal := reflect.ValueOf(parent).Elem()
1364	if parentVal.Kind() != reflect.Struct {
1365		return false
1366	}
1367
1368	newChildVal := reflect.ValueOf(newChild)
1369
1370	tryReplace := func(v reflect.Value) bool {
1371		if !v.CanSet() || !v.CanInterface() {
1372			return false
1373		}
1374
1375		// If the existing value is oldChild, we found our child. Make
1376		// sure our newChild is assignable and then make the swap.
1377		if v.Interface() == oldChild && newChildVal.Type().AssignableTo(v.Type()) {
1378			v.Set(newChildVal)
1379			return true
1380		}
1381
1382		return false
1383	}
1384
1385	// Loop over parent's struct fields.
1386	for i := 0; i < parentVal.NumField(); i++ {
1387		f := parentVal.Field(i)
1388
1389		switch f.Kind() {
1390		// Check interface and pointer fields.
1391		case reflect.Interface, reflect.Ptr:
1392			if tryReplace(f) {
1393				return true
1394			}
1395
1396		// Search through any slice fields.
1397		case reflect.Slice:
1398			for i := 0; i < f.Len(); i++ {
1399				if tryReplace(f.Index(i)) {
1400					return true
1401				}
1402			}
1403		}
1404	}
1405
1406	return false
1407}
1408