1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package cache
6
7import (
8	"context"
9	"fmt"
10	"go/ast"
11	"go/parser"
12	"go/scanner"
13	"go/token"
14
15	"golang.org/x/tools/internal/lsp/source"
16	"golang.org/x/tools/internal/lsp/telemetry/trace"
17	"golang.org/x/tools/internal/memoize"
18)
19
// parseLimit limits the number of parallel parser calls per process.
// It is used as a counting semaphore: parseGo sends a value before calling
// the parser and receives one when it returns, so the channel's capacity (20)
// is the maximum number of parses that can be in flight at once.
var parseLimit = make(chan bool, 20)
22
// parseKey uniquely identifies a parsed Go file.
// It is the key under which ParseGoHandle binds the parse function in the
// cache's store.
type parseKey struct {
	// file is the identity of the file being parsed.
	file source.FileIdentity
	// mode is the parse mode requested for the file.
	mode source.ParseMode
}
28
// parseGoHandle is the value returned by (*cache).ParseGoHandle. It pairs a
// memoized parse computation with the file and mode it was created for.
type parseGoHandle struct {
	// handle is the memoize handle whose value is a *parseGoData.
	handle *memoize.Handle
	// file is the file handle that is read and parsed.
	file source.FileHandle
	// mode is the parse mode the file is parsed with.
	mode source.ParseMode
}
34
// parseGoData is the value stored behind a parseGoHandle's memoize handle.
// It holds the two results of a single parseGo call.
type parseGoData struct {
	// NOTE(review): presumably marks the value as one that must not be
	// copied; confirm against the memoize package.
	memoize.NoCopy

	// ast is the parsed file returned by parseGo; it may be nil.
	ast *ast.File
	// err is the error returned by parseGo, if any.
	err error
}
41
42func (c *cache) ParseGoHandle(fh source.FileHandle, mode source.ParseMode) source.ParseGoHandle {
43	key := parseKey{
44		file: fh.Identity(),
45		mode: mode,
46	}
47	h := c.store.Bind(key, func(ctx context.Context) interface{} {
48		data := &parseGoData{}
49		data.ast, data.err = parseGo(ctx, c, fh, mode)
50		return data
51	})
52	return &parseGoHandle{
53		handle: h,
54		file:   fh,
55		mode:   mode,
56	}
57}
58
// File returns the file handle that this parse handle was created with.
func (h *parseGoHandle) File() source.FileHandle {
	return h.file
}
62
// Mode returns the parse mode that this parse handle was created with.
func (h *parseGoHandle) Mode() source.ParseMode {
	return h.mode
}
66
67func (h *parseGoHandle) Parse(ctx context.Context) (*ast.File, error) {
68	v := h.handle.Get(ctx)
69	if v == nil {
70		return nil, ctx.Err()
71	}
72	data := v.(*parseGoData)
73	return data.ast, data.err
74}
75
76func parseGo(ctx context.Context, c *cache, fh source.FileHandle, mode source.ParseMode) (*ast.File, error) {
77	ctx, ts := trace.StartSpan(ctx, "cache.parseGo")
78	defer ts.End()
79	buf, _, err := fh.Read(ctx)
80	if err != nil {
81		return nil, err
82	}
83	parseLimit <- true
84	defer func() { <-parseLimit }()
85	parserMode := parser.AllErrors | parser.ParseComments
86	if mode == source.ParseHeader {
87		parserMode = parser.ImportsOnly
88	}
89	ast, err := parser.ParseFile(c.fset, fh.Identity().URI.Filename(), buf, parserMode)
90	if ast != nil {
91		if mode == source.ParseExported {
92			trimAST(ast)
93		}
94		// Fix any badly parsed parts of the AST.
95		tok := c.fset.File(ast.Pos())
96		if err := fix(ctx, ast, tok, buf); err != nil {
97			// TODO: Do something with the error (need access to a logger in here).
98		}
99	}
100	if ast == nil {
101		return nil, err
102	}
103	return ast, err
104}
105
106// trimAST clears any part of the AST not relevant to type checking
107// expressions at pos.
108func trimAST(file *ast.File) {
109	ast.Inspect(file, func(n ast.Node) bool {
110		if n == nil {
111			return false
112		}
113		switch n := n.(type) {
114		case *ast.FuncDecl:
115			n.Body = nil
116		case *ast.BlockStmt:
117			n.List = nil
118		case *ast.CaseClause:
119			n.Body = nil
120		case *ast.CommClause:
121			n.Body = nil
122		case *ast.CompositeLit:
123			// Leave elts in place for [...]T
124			// array literals, because they can
125			// affect the expression's type.
126			if !isEllipsisArray(n.Type) {
127				n.Elts = nil
128			}
129		}
130		return true
131	})
132}
133
134func isEllipsisArray(n ast.Expr) bool {
135	at, ok := n.(*ast.ArrayType)
136	if !ok {
137		return false
138	}
139	_, ok = at.Len.(*ast.Ellipsis)
140	return ok
141}
142
143// fix inspects and potentially modifies any *ast.BadStmts or *ast.BadExprs in the AST.
144// We attempt to modify the AST such that we can type-check it more effectively.
145func fix(ctx context.Context, file *ast.File, tok *token.File, src []byte) error {
146	var parent ast.Node
147	var err error
148	ast.Inspect(file, func(n ast.Node) bool {
149		if n == nil {
150			return false
151		}
152		switch n := n.(type) {
153		case *ast.BadStmt:
154			err = parseDeferOrGoStmt(n, parent, tok, src) // don't shadow err
155			if err != nil {
156				err = fmt.Errorf("unable to parse defer or go from *ast.BadStmt: %v", err)
157			}
158			return false
159		default:
160			parent = n
161			return true
162		}
163	})
164	return err
165}
166
// parseDeferOrGoStmt tries to parse an *ast.BadStmt into a defer or a go statement.
//
// go/parser packages a statement of the form "defer x." as an *ast.BadStmt because
// it does not include a call expression. This means that go/types skips type-checking
// this statement entirely, and we can't use the type information when completing.
// Here, we try to generate a fake *ast.DeferStmt or *ast.GoStmt to put into the AST,
// instead of the *ast.BadStmt.
//
// bad is the statement to replace, parent is the node expected to contain it,
// tok is the token.File for the source being scanned, and src is that source.
// The replacement is only written back when parent is an *ast.BlockStmt whose
// List contains bad; for any other parent the function still returns nil
// without modifying the AST. An error is returned when no "defer" or "go" is
// found at the statement, when the expression bounds are invalid, or when the
// expression text cannot be parsed.
func parseDeferOrGoStmt(bad *ast.BadStmt, parent ast.Node, tok *token.File, src []byte) error {
	// Check if we have a bad statement containing either a "go" or "defer".
	// The scanner starts at the beginning of src, with no error handler and
	// the default mode.
	s := &scanner.Scanner{}
	s.Init(tok, src, nil, 0)

	var pos token.Pos
	var tkn token.Token
	var lit string
	// Advance to the first token positioned at or after the start of the bad
	// statement; pos, tkn and lit describe that token when the loop exits.
	for {
		if tkn == token.EOF {
			return fmt.Errorf("reached the end of the file")
		}
		if pos >= bad.From {
			break
		}
		pos, tkn, lit = s.Scan()
	}
	var stmt ast.Stmt
	switch lit {
	case "defer":
		stmt = &ast.DeferStmt{
			Defer: pos,
		}
	case "go":
		stmt = &ast.GoStmt{
			Go: pos,
		}
	default:
		return fmt.Errorf("no defer or go statement found")
	}

	// The expression after the "defer" or "go" starts at this position.
	from, _, _ := s.Scan()
	var to, curr token.Pos
FindTo:
	for {
		curr, tkn, lit = s.Scan()
		// TODO(rstambler): This still needs more handling to work correctly.
		// We encounter a specific issue with code that looks like this:
		//
		//      defer x.<>
		//      y := 1
		//
		// In this scenario, we parse it as "defer x.y", which then fails to
		// type-check, and we don't get completions as expected.
		switch tkn {
		case token.COMMENT, token.EOF, token.SEMICOLON, token.DEFINE:
			break FindTo
		}
		// to is the end of expression that should become the Fun part of the call.
		to = curr
	}
	if !from.IsValid() || tok.Offset(from) >= len(src) {
		return fmt.Errorf("invalid from position")
	}
	// to stays invalid (zero) when the first token scanned by the loop above
	// was already a terminator, i.e. only one token followed the keyword.
	// NOTE(review): the ">=" also rejects an expression whose last byte is the
	// last byte of src, although the slice below would be in range; confirm
	// whether that is intended.
	if !to.IsValid() || tok.Offset(to)+1 >= len(src) {
		return fmt.Errorf("invalid to position")
	}
	// NOTE(review): to is the position reported by Scan for the last token, so
	// this slice ends one byte past that position. If that is the token's
	// start, a final token longer than one byte appears to be cut off after
	// its first byte; confirm.
	exprstr := string(src[tok.Offset(from) : tok.Offset(to)+1])
	expr, err := parser.ParseExpr(exprstr)
	if expr == nil {
		return fmt.Errorf("no expr in %s: %v", exprstr, err)
	}
	// parser.ParseExpr returns undefined positions.
	// Adjust them for the current file.
	offsetPositions(expr, from-1)

	// Package the expression into a fake *ast.CallExpr and re-insert into the function.
	// Both parentheses are placed at to, since the source has none.
	call := &ast.CallExpr{
		Fun:    expr,
		Lparen: to,
		Rparen: to,
	}
	switch stmt := stmt.(type) {
	case *ast.DeferStmt:
		stmt.Call = call
	case *ast.GoStmt:
		stmt.Call = call
	}
	// Swap the bad statement for the fake one in the enclosing block.
	switch parent := parent.(type) {
	case *ast.BlockStmt:
		for i, s := range parent.List {
			if s == bad {
				parent.List[i] = stmt
				break
			}
		}
	}
	return nil
}
264
265// offsetPositions applies an offset to the positions in an ast.Node.
266// TODO(rstambler): Add more cases here as they become necessary.
267func offsetPositions(expr ast.Expr, offset token.Pos) {
268	ast.Inspect(expr, func(n ast.Node) bool {
269		switch n := n.(type) {
270		case *ast.Ident:
271			n.NamePos += offset
272			return false
273		default:
274			return true
275		}
276	})
277}
278