1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package source
6
import (
	"context"
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"sort"
	"strings"

	"golang.org/x/tools/go/ast/astutil"
	"golang.org/x/tools/internal/event"
	"golang.org/x/tools/internal/lsp/protocol"
	errors "golang.org/x/xerrors"
)
20
21func Highlight(ctx context.Context, snapshot Snapshot, fh FileHandle, pos protocol.Position) ([]protocol.Range, error) {
22	ctx, done := event.Start(ctx, "source.Highlight")
23	defer done()
24
25	pkg, pgf, err := GetParsedFile(ctx, snapshot, fh, WidestPackage)
26	if err != nil {
27		return nil, errors.Errorf("getting file for Highlight: %w", err)
28	}
29	spn, err := pgf.Mapper.PointSpan(pos)
30	if err != nil {
31		return nil, err
32	}
33	rng, err := spn.Range(pgf.Mapper.Converter)
34	if err != nil {
35		return nil, err
36	}
37	path, _ := astutil.PathEnclosingInterval(pgf.File, rng.Start, rng.Start)
38	if len(path) == 0 {
39		return nil, fmt.Errorf("no enclosing position found for %v:%v", int(pos.Line), int(pos.Character))
40	}
41	// If start == end for astutil.PathEnclosingInterval, the 1-char interval
42	// following start is used instead. As a result, we might not get an exact
43	// match so we should check the 1-char interval to the left of the passed
44	// in position to see if that is an exact match.
45	if _, ok := path[0].(*ast.Ident); !ok {
46		if p, _ := astutil.PathEnclosingInterval(pgf.File, rng.Start-1, rng.Start-1); p != nil {
47			switch p[0].(type) {
48			case *ast.Ident, *ast.SelectorExpr:
49				path = p // use preceding ident/selector
50			}
51		}
52	}
53	result, err := highlightPath(pkg, path)
54	if err != nil {
55		return nil, err
56	}
57	var ranges []protocol.Range
58	for rng := range result {
59		mRng, err := posToMappedRange(snapshot, pkg, rng.start, rng.end)
60		if err != nil {
61			return nil, err
62		}
63		pRng, err := mRng.Range()
64		if err != nil {
65			return nil, err
66		}
67		ranges = append(ranges, pRng)
68	}
69	return ranges, nil
70}
71
72func highlightPath(pkg Package, path []ast.Node) (map[posRange]struct{}, error) {
73	result := make(map[posRange]struct{})
74	switch node := path[0].(type) {
75	case *ast.BasicLit:
76		if len(path) > 1 {
77			if _, ok := path[1].(*ast.ImportSpec); ok {
78				err := highlightImportUses(pkg, path, result)
79				return result, err
80			}
81		}
82		highlightFuncControlFlow(path, result)
83	case *ast.ReturnStmt, *ast.FuncDecl, *ast.FuncType:
84		highlightFuncControlFlow(path, result)
85	case *ast.Ident:
86		highlightIdentifiers(pkg, path, result)
87	case *ast.ForStmt, *ast.RangeStmt:
88		highlightLoopControlFlow(path, result)
89	case *ast.SwitchStmt:
90		highlightSwitchFlow(path, result)
91	case *ast.BranchStmt:
92		// BREAK can exit a loop, switch or select, while CONTINUE exit a loop so
93		// these need to be handled separately. They can also be embedded in any
94		// other loop/switch/select if they have a label. TODO: add support for
95		// GOTO and FALLTHROUGH as well.
96		if node.Label != nil {
97			highlightLabeledFlow(node, result)
98		} else {
99			switch node.Tok {
100			case token.BREAK:
101				highlightUnlabeledBreakFlow(path, result)
102			case token.CONTINUE:
103				highlightLoopControlFlow(path, result)
104			}
105		}
106	default:
107		// If the cursor is in an unidentified area, return empty results.
108		return nil, nil
109	}
110	return result, nil
111}
112
// posRange is a source span from start up to end, taken from an ast.Node's
// Pos and End. It is a comparable value type, so it doubles as the key of the
// set (map[posRange]struct{}) in which highlights are collected and
// de-duplicated.
type posRange struct {
	start token.Pos
	end   token.Pos
}
116
// highlightFuncControlFlow highlights the exit points of the function
// (*ast.FuncDecl or *ast.FuncLit) that encloses path[0]. path is ordered
// innermost-first, so path[0] is the node under the cursor.
//
// When path[0] is a return statement, the enclosing function node itself, or
// an *ast.FuncType, the "func" keyword and every return statement of the
// function are highlighted. When the cursor is inside one of the returned
// values or inside a field of the signature, the field at the matching index
// in the result list is highlighted, along with the value at that index in
// each return statement. Return statements of nested function literals are
// never visited.
//
// Nothing is highlighted when the cursor is outside any function, or when a
// key: value expression or a call argument lies between the cursor and the
// enclosing function.
func highlightFuncControlFlow(path []ast.Node, result map[posRange]struct{}) {
	var enclosingFunc ast.Node
	var returnStmt *ast.ReturnStmt
	var resultsList *ast.FieldList
	inReturnList := false

Outer:
	// Walk outward from the cursor (path[0]) until we reach the enclosing
	// function, noting any return statement or field passed on the way.
	for i, n := range path {
		switch node := n.(type) {
		case *ast.KeyValueExpr:
			// If cursor is in a key: value expr, we don't want control flow highlighting
			return
		case *ast.CallExpr:
			// If cursor is an arg in a callExpr, we don't want control flow highlighting.
			if i > 0 {
				for _, arg := range node.Args {
					if arg == path[i-1] {
						return
					}
				}
			}
		case *ast.Field:
			// NOTE(review): this is set for any field on the path, including
			// parameter fields, not only result fields.
			inReturnList = true
		case *ast.FuncLit:
			enclosingFunc = n
			resultsList = node.Type.Results
			break Outer
		case *ast.FuncDecl:
			enclosingFunc = n
			resultsList = node.Type.Results
			break Outer
		case *ast.ReturnStmt:
			returnStmt = node
			// If the cursor is not directly in a *ast.ReturnStmt, then
			// we need to know if it is within one of the values that is being returned.
			inReturnList = inReturnList || path[0] != returnStmt
		}
	}
	// Cursor is not in a function.
	if enclosingFunc == nil {
		return
	}
	// If the cursor is on a "return" or "func" keyword, we should highlight all of the exit
	// points of the function, including the "return" and "func" keywords.
	highlightAllReturnsAndFunc := path[0] == returnStmt || path[0] == enclosingFunc
	switch path[0].(type) {
	case *ast.Ident, *ast.BasicLit:
		// Cursor is in an identifier and not in a return statement or in the results list.
		if returnStmt == nil && !inReturnList {
			return
		}
	case *ast.FuncType:
		highlightAllReturnsAndFunc = true
	}
	// The user's cursor may be within the return statement of a function,
	// or within the result section of a function's signature. Collect the
	// candidate nodes (returned values, or result fields) so we can find
	// which position the cursor occupies.
	var nodes []ast.Node
	if returnStmt != nil {
		for _, n := range returnStmt.Results {
			nodes = append(nodes, n)
		}
	} else if resultsList != nil {
		for _, n := range resultsList.List {
			nodes = append(nodes, n)
		}
	}
	// nodeAtPos is defined elsewhere; presumably it yields the index of the
	// node containing the position, or -1 if none does (TODO confirm). The
	// bounds checks below guard against out-of-range values either way.
	_, index := nodeAtPos(nodes, path[0].Pos())

	// Highlight the correct argument in the function declaration return types.
	if resultsList != nil && -1 < index && index < len(resultsList.List) {
		rng := posRange{
			start: resultsList.List[index].Pos(),
			end:   resultsList.List[index].End(),
		}
		result[rng] = struct{}{}
	}
	// Add the "func" part of the func declaration.
	if highlightAllReturnsAndFunc {
		r := posRange{
			start: enclosingFunc.Pos(),
			end:   enclosingFunc.Pos() + token.Pos(len("func")),
		}
		result[r] = struct{}{}
	}
	ast.Inspect(enclosingFunc, func(n ast.Node) bool {
		// Don't traverse any other functions.
		switch n.(type) {
		case *ast.FuncDecl, *ast.FuncLit:
			return enclosingFunc == n
		}
		ret, ok := n.(*ast.ReturnStmt)
		if !ok {
			return true
		}
		var toAdd ast.Node
		// Add the entire return statement, applies when highlight the word "return" or "func".
		if highlightAllReturnsAndFunc {
			toAdd = n
		}
		// Add the relevant field within the entire return statement. When an
		// index is known, this takes precedence over the whole statement.
		if -1 < index && index < len(ret.Results) {
			toAdd = ret.Results[index]
		}
		if toAdd != nil {
			result[posRange{start: toAdd.Pos(), end: toAdd.End()}] = struct{}{}
		}
		return false
	})
}
228
229func highlightUnlabeledBreakFlow(path []ast.Node, result map[posRange]struct{}) {
230	// Reverse walk the path until we find closest loop, select, or switch.
231	for _, n := range path {
232		switch n.(type) {
233		case *ast.ForStmt, *ast.RangeStmt:
234			highlightLoopControlFlow(path, result)
235			return // only highlight the innermost statement
236		case *ast.SwitchStmt:
237			highlightSwitchFlow(path, result)
238			return
239		case *ast.SelectStmt:
240			// TODO: add highlight when breaking a select.
241			return
242		}
243	}
244}
245
246func highlightLabeledFlow(node *ast.BranchStmt, result map[posRange]struct{}) {
247	obj := node.Label.Obj
248	if obj == nil || obj.Decl == nil {
249		return
250	}
251	label, ok := obj.Decl.(*ast.LabeledStmt)
252	if !ok {
253		return
254	}
255	switch label.Stmt.(type) {
256	case *ast.ForStmt, *ast.RangeStmt:
257		highlightLoopControlFlow([]ast.Node{label.Stmt, label}, result)
258	case *ast.SwitchStmt:
259		highlightSwitchFlow([]ast.Node{label.Stmt, label}, result)
260	}
261}
262
// labelFor returns the label attached to path[0], i.e. the label of path[1]
// when path[1] is an *ast.LabeledStmt. It returns nil for paths shorter than
// two nodes or when the parent is not a labeled statement.
func labelFor(path []ast.Node) *ast.Ident {
	if len(path) < 2 {
		return nil
	}
	parent, isLabeled := path[1].(*ast.LabeledStmt)
	if !isLabeled {
		return nil
	}
	return parent.Label
}
271
272func highlightLoopControlFlow(path []ast.Node, result map[posRange]struct{}) {
273	var loop ast.Node
274	var loopLabel *ast.Ident
275	stmtLabel := labelFor(path)
276Outer:
277	// Reverse walk the path till we get to the for loop.
278	for i := range path {
279		switch n := path[i].(type) {
280		case *ast.ForStmt, *ast.RangeStmt:
281			loopLabel = labelFor(path[i:])
282
283			if stmtLabel == nil || loopLabel == stmtLabel {
284				loop = n
285				break Outer
286			}
287		}
288	}
289	if loop == nil {
290		return
291	}
292
293	// Add the for statement.
294	rng := posRange{
295		start: loop.Pos(),
296		end:   loop.Pos() + token.Pos(len("for")),
297	}
298	result[rng] = struct{}{}
299
300	// Traverse AST to find branch statements within the same for-loop.
301	ast.Inspect(loop, func(n ast.Node) bool {
302		switch n.(type) {
303		case *ast.ForStmt, *ast.RangeStmt:
304			return loop == n
305		case *ast.SwitchStmt, *ast.SelectStmt:
306			return false
307		}
308		b, ok := n.(*ast.BranchStmt)
309		if !ok {
310			return true
311		}
312		if b.Label == nil || labelDecl(b.Label) == loopLabel {
313			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
314		}
315		return true
316	})
317
318	// Find continue statements in the same loop or switches/selects.
319	ast.Inspect(loop, func(n ast.Node) bool {
320		switch n.(type) {
321		case *ast.ForStmt, *ast.RangeStmt:
322			return loop == n
323		}
324
325		if n, ok := n.(*ast.BranchStmt); ok && n.Tok == token.CONTINUE {
326			result[posRange{start: n.Pos(), end: n.End()}] = struct{}{}
327		}
328		return true
329	})
330
331	// We don't need to check other for loops if we aren't looking for labeled statements.
332	if loopLabel == nil {
333		return
334	}
335
336	// Find labeled branch statements in any loop.
337	ast.Inspect(loop, func(n ast.Node) bool {
338		b, ok := n.(*ast.BranchStmt)
339		if !ok {
340			return true
341		}
342		// statement with labels that matches the loop
343		if b.Label != nil && labelDecl(b.Label) == loopLabel {
344			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
345		}
346		return true
347	})
348}
349
350func highlightSwitchFlow(path []ast.Node, result map[posRange]struct{}) {
351	var switchNode ast.Node
352	var switchNodeLabel *ast.Ident
353	stmtLabel := labelFor(path)
354Outer:
355	// Reverse walk the path till we get to the switch statement.
356	for i := range path {
357		switch n := path[i].(type) {
358		case *ast.SwitchStmt:
359			switchNodeLabel = labelFor(path[i:])
360			if stmtLabel == nil || switchNodeLabel == stmtLabel {
361				switchNode = n
362				break Outer
363			}
364		}
365	}
366	// Cursor is not in a switch statement
367	if switchNode == nil {
368		return
369	}
370
371	// Add the switch statement.
372	rng := posRange{
373		start: switchNode.Pos(),
374		end:   switchNode.Pos() + token.Pos(len("switch")),
375	}
376	result[rng] = struct{}{}
377
378	// Traverse AST to find break statements within the same switch.
379	ast.Inspect(switchNode, func(n ast.Node) bool {
380		switch n.(type) {
381		case *ast.SwitchStmt:
382			return switchNode == n
383		case *ast.ForStmt, *ast.RangeStmt, *ast.SelectStmt:
384			return false
385		}
386
387		b, ok := n.(*ast.BranchStmt)
388		if !ok || b.Tok != token.BREAK {
389			return true
390		}
391
392		if b.Label == nil || labelDecl(b.Label) == switchNodeLabel {
393			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
394		}
395		return true
396	})
397
398	// We don't need to check other switches if we aren't looking for labeled statements.
399	if switchNodeLabel == nil {
400		return
401	}
402
403	// Find labeled break statements in any switch
404	ast.Inspect(switchNode, func(n ast.Node) bool {
405		b, ok := n.(*ast.BranchStmt)
406		if !ok || b.Tok != token.BREAK {
407			return true
408		}
409
410		if b.Label != nil && labelDecl(b.Label) == switchNodeLabel {
411			result[posRange{start: b.Pos(), end: b.End()}] = struct{}{}
412		}
413
414		return true
415	})
416}
417
// labelDecl maps a label reference to the identifier at the label's
// declaration, using the parser-resolved ast.Object. It returns nil when n is
// nil, has no resolved object, or that object was not declared by an
// *ast.LabeledStmt.
func labelDecl(n *ast.Ident) *ast.Ident {
	if n == nil || n.Obj == nil || n.Obj.Decl == nil {
		return nil
	}
	if decl, isLabeled := n.Obj.Decl.(*ast.LabeledStmt); isLabeled {
		return decl.Label
	}
	return nil
}
434
435func highlightImportUses(pkg Package, path []ast.Node, result map[posRange]struct{}) error {
436	basicLit, ok := path[0].(*ast.BasicLit)
437	if !ok {
438		return errors.Errorf("highlightImportUses called with an ast.Node of type %T", basicLit)
439	}
440	ast.Inspect(path[len(path)-1], func(node ast.Node) bool {
441		if imp, ok := node.(*ast.ImportSpec); ok && imp.Path == basicLit {
442			result[posRange{start: node.Pos(), end: node.End()}] = struct{}{}
443			return false
444		}
445		n, ok := node.(*ast.Ident)
446		if !ok {
447			return true
448		}
449		obj, ok := pkg.GetTypesInfo().ObjectOf(n).(*types.PkgName)
450		if !ok {
451			return true
452		}
453		if !strings.Contains(basicLit.Value, obj.Name()) {
454			return true
455		}
456		result[posRange{start: n.Pos(), end: n.End()}] = struct{}{}
457		return false
458	})
459	return nil
460}
461
462func highlightIdentifiers(pkg Package, path []ast.Node, result map[posRange]struct{}) error {
463	id, ok := path[0].(*ast.Ident)
464	if !ok {
465		return errors.Errorf("highlightIdentifiers called with an ast.Node of type %T", id)
466	}
467	// Check if ident is inside return or func decl.
468	highlightFuncControlFlow(path, result)
469
470	// TODO: maybe check if ident is a reserved word, if true then don't continue and return results.
471
472	idObj := pkg.GetTypesInfo().ObjectOf(id)
473	pkgObj, isImported := idObj.(*types.PkgName)
474	ast.Inspect(path[len(path)-1], func(node ast.Node) bool {
475		if imp, ok := node.(*ast.ImportSpec); ok && isImported {
476			highlightImport(pkgObj, imp, result)
477		}
478		n, ok := node.(*ast.Ident)
479		if !ok {
480			return true
481		}
482		if n.Name != id.Name {
483			return false
484		}
485		if nObj := pkg.GetTypesInfo().ObjectOf(n); nObj == idObj {
486			result[posRange{start: n.Pos(), end: n.End()}] = struct{}{}
487		}
488		return false
489	})
490	return nil
491}
492
493func highlightImport(obj *types.PkgName, imp *ast.ImportSpec, result map[posRange]struct{}) {
494	if imp.Name != nil || imp.Path == nil {
495		return
496	}
497	if !strings.Contains(imp.Path.Value, obj.Name()) {
498		return
499	}
500	result[posRange{start: imp.Path.Pos(), end: imp.Path.End()}] = struct{}{}
501}
502