1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package source
6
7import (
8	"context"
9	"go/ast"
10	"go/token"
11	"go/types"
12	"sort"
13
14	"golang.org/x/tools/internal/event"
15	"golang.org/x/tools/internal/lsp/protocol"
16	"golang.org/x/tools/internal/span"
17	errors "golang.org/x/xerrors"
18)
19
// ReferenceInfo holds information about reference to an identifier in Go source.
// One ReferenceInfo is produced per reported location; the declaration of the
// searched object, when requested, is represented by an entry with
// isDeclaration set.
type ReferenceInfo struct {
	// Name is the name of the identifier at this location (for the
	// declaration entry, the name of the searched object).
	Name string
	// MappedRange is the source range spanned by the identifier.
	MappedRange
	// ident is the identifier node found at this location.
	ident         *ast.Ident
	// obj is the object recorded for this location: the searched object for
	// the declaration entry, otherwise the object from types.Info.Uses. When
	// embedded references are included this may be an embedded field Var
	// rather than the searched type name.
	obj           types.Object
	// pkg is the package in which this identifier was found.
	pkg           Package
	// isDeclaration is true only for the entry describing the declaration.
	isDeclaration bool
}
29
// References returns the references to the identifier at position pp in the
// file f, searching the packages that contain that identifier's object.
//
// Builtin identifiers have no references, so (nil, nil) is returned for them.
// When includeDeclaration is true the declaration is the first element of the
// result; all other references follow it, ordered by CompareURI on their file
// URIs and then by position within the file.
func References(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position, includeDeclaration bool) ([]*ReferenceInfo, error) {
	ctx, done := event.Start(ctx, "source.References")
	defer done()

	qualifiedObjs, err := qualifiedObjsAtProtocolPos(ctx, s, f, pp)
	switch {
	case errors.Is(err, errBuiltin):
		// Don't return references for builtin types.
		return nil, nil
	case err != nil:
		return nil, err
	}

	refs, err := references(ctx, s, qualifiedObjs, includeDeclaration, true, false)
	if err != nil {
		return nil, err
	}

	// The declaration, if present, stays pinned at index 0; only the
	// remaining entries are ordered.
	first := 0
	if includeDeclaration {
		first = 1
	}
	rest := refs[first:]
	sort.Slice(rest, func(a, b int) bool {
		left, right := rest[a], rest[b]
		if cmp := CompareURI(left.URI(), right.URI()); cmp != 0 {
			return cmp < 0
		}
		return left.ident.Pos() < right.ident.Pos()
	})
	return refs, nil
}
63
// references is a helper for References that works from already-resolved
// objects, so callers need not recompute qualifiedObjsAtProtocolPos.
//
// qos must be non-empty. qos[0] is the object whose declaration is looked up,
// and an empty slice yields an error rather than a panic. If
// includeDeclaration is set, that declaration is the first element of the
// result. Uses of every object in qos are collected from the object's own
// package and, when the object is exported, from the packages returned by
// snapshot.GetReverseDependencies. Each source position is reported at most
// once.
//
// If includeEmbeddedRefs is set, uses of an embedded struct field whose type
// is the object (either T or *T) are reported as well, because an embedded
// field is named after its type. If includeInterfaceRefs is set and the
// declared object is not a type name, the results of interfaceReferences for
// the declaration are appended.
func references(ctx context.Context, snapshot Snapshot, qos []qualifiedObject, includeDeclaration, includeInterfaceRefs, includeEmbeddedRefs bool) ([]*ReferenceInfo, error) {
	if len(qos) == 0 {
		return nil, errors.New("no objects to find references for")
	}

	// refs is deliberately not named "references" to avoid shadowing this
	// function.
	var (
		refs []*ReferenceInfo
		seen = make(map[token.Pos]bool)
	)

	primary := qos[0]
	filename := snapshot.FileSet().Position(primary.obj.Pos()).Filename
	pgf, err := primary.pkg.File(span.URIFromPath(filename))
	if err != nil {
		return nil, err
	}
	declIdent, err := findIdentifier(ctx, snapshot, primary.pkg, pgf.File, primary.obj.Pos())
	if err != nil {
		return nil, err
	}
	// Make sure declaration is the first item in the response.
	if includeDeclaration {
		refs = append(refs, &ReferenceInfo{
			MappedRange:   declIdent.MappedRange,
			Name:          primary.obj.Name(),
			ident:         declIdent.ident,
			obj:           primary.obj,
			pkg:           declIdent.pkg,
			isDeclaration: true,
		})
	}

	for _, qo := range qos {
		var searchPkgs []Package

		// Only search dependents if the object is exported.
		if qo.obj.Exported() {
			reverseDeps, err := snapshot.GetReverseDependencies(ctx, qo.pkg.ID())
			if err != nil {
				return nil, err
			}
			searchPkgs = append(searchPkgs, reverseDeps...)
		}
		// Add the package in which the identifier is declared.
		searchPkgs = append(searchPkgs, qo.pkg)
		for _, pkg := range searchPkgs {
			for ident, obj := range pkg.GetTypesInfo().Uses {
				// Besides direct uses of qo.obj, uses of an embedded field can be
				// considered references of the embedded type name.
				if obj != qo.obj && !(includeEmbeddedRefs && isEmbeddedFieldOf(obj, qo.obj)) {
					continue
				}
				if seen[ident.Pos()] {
					continue
				}
				seen[ident.Pos()] = true
				rng, err := posToMappedRange(snapshot, pkg, ident.Pos(), ident.End())
				if err != nil {
					return nil, err
				}
				refs = append(refs, &ReferenceInfo{
					Name:        ident.Name,
					ident:       ident,
					pkg:         pkg,
					obj:         obj,
					MappedRange: rng,
				})
			}
		}
	}

	// When searching on type name, don't include interface references -- they
	// would be things like all references to Stringer for any type that
	// happened to have a String method.
	_, isType := declIdent.Declaration.obj.(*types.TypeName)
	if includeInterfaceRefs && !isType {
		declRange, err := declIdent.Range()
		if err != nil {
			return nil, err
		}
		fh, err := snapshot.GetFile(ctx, declIdent.URI())
		if err != nil {
			return nil, err
		}
		interfaceRefs, err := interfaceReferences(ctx, snapshot, fh, declRange.Start)
		if err != nil {
			return nil, err
		}
		refs = append(refs, interfaceRefs...)
	}

	return refs, nil
}

// isEmbeddedFieldOf reports whether obj is an embedded struct field whose type
// is the named type declared by target, either directly (T) or through a
// pointer (*T).
func isEmbeddedFieldOf(obj, target types.Object) bool {
	v, ok := obj.(*types.Var)
	if !ok || !v.Embedded() {
		return false
	}
	typ := v.Type()
	// An embedded field may be written as *T; its name is still T.
	if ptr, ok := typ.(*types.Pointer); ok {
		typ = ptr.Elem()
	}
	named, ok := typ.(*types.Named)
	return ok && named.Obj() == target
}
164
// interfaceReferences returns the references to the interfaces implemented by
// the type or method at the given position.
//
// Every object reported by implementations for position pp in f is searched
// with references, without adding declarations, further interface references,
// or embedded-field references. An ErrNotAType error from implementations
// means there is nothing to report, and yields (nil, nil).
func interfaceReferences(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]*ReferenceInfo, error) {
	impls, err := implementations(ctx, s, f, pp)
	switch {
	case errors.Is(err, ErrNotAType):
		return nil, nil
	case err != nil:
		return nil, err
	}

	var collected []*ReferenceInfo
	for _, target := range impls {
		found, err := references(ctx, s, []qualifiedObject{target}, false, false, false)
		if err != nil {
			return nil, err
		}
		collected = append(collected, found...)
	}
	return collected, nil
}
186