1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package cache
6
7import (
8	"context"
9	"go/ast"
10	"go/token"
11	"sync"
12
13	"golang.org/x/tools/internal/lsp/source"
14	"golang.org/x/tools/internal/span"
15)
16
// goFile holds all of the information we know about a Go file.
//
// The methods in this file also use view, token, URI() and View() on a
// goFile; those are not declared here, so they come from the embedded
// fileBase.
type goFile struct {
	fileBase

	// mu protects all mutable state of the Go file,
	// which can be modified during type-checking.
	// In this file it is acquired by isDirty, astIsTrimmed and unexpectedAST,
	// always after the view's mutex has been taken by the caller.
	mu sync.Mutex

	// missingImports is the set of unresolved imports for this package.
	// It contains any packages with `go list` errors.
	// While it is non-empty, isDirty reports the file as needing a re-check.
	missingImports map[packagePath]struct{}

	// justOpened indicates that the file has just been opened.
	// We re-run go/packages.Load on just opened files to make sure
	// that we know about all of their packages.
	// isDirty clears it (after resetting meta and pkgs) the first time it runs.
	justOpened bool

	// imports presumably holds the import specs of the file's current AST;
	// NOTE(review): it is not read in this part of the file — confirm usage.
	imports []*ast.ImportSpec

	// ast is the most recently parsed AST for the file. It may be nil, or
	// trimmed (see astFile.isTrimmed); GetAST, GetToken and GetPackages
	// treat either case as unusable.
	ast  *astFile
	// pkgs holds the type-checked packages that contain this file, keyed by
	// package ID. isDirty reports true while it is empty.
	pkgs map[packageID]*pkg
	// meta holds metadata for the packages that contain this file, keyed by
	// package ID. isDirty reports true while it is empty.
	meta map[packageID]*metadata
}
40
// astFile is a parsed Go file as cached on a goFile (goFile.ast).
type astFile struct {
	// uri identifies the file this AST was parsed from.
	uri       span.URI
	// file is the syntax tree returned by GetAST and GetAnyAST.
	file      *ast.File
	err       error // parse errors
	// ph is the parse handle for this file; presumably the handle that
	// produced file — NOTE(review): confirm against the parse cache.
	ph        source.ParseGoHandle
	// isTrimmed marks a trimmed (not full) AST. Callers needing a full AST
	// re-run loadParseTypecheck when it is set, and unexpectedAST reports it.
	isTrimmed bool
}
48
49func (f *goFile) GetToken(ctx context.Context) *token.File {
50	f.view.mu.Lock()
51	defer f.view.mu.Unlock()
52
53	if f.isDirty() || f.astIsTrimmed() {
54		if _, err := f.view.loadParseTypecheck(ctx, f); err != nil {
55			f.View().Session().Logger().Errorf(ctx, "unable to check package for %s: %v", f.URI(), err)
56			return nil
57		}
58	}
59	if unexpectedAST(ctx, f) {
60		return nil
61	}
62	return f.token
63}
64
65func (f *goFile) GetAnyAST(ctx context.Context) *ast.File {
66	f.view.mu.Lock()
67	defer f.view.mu.Unlock()
68
69	if f.isDirty() {
70		if _, err := f.view.loadParseTypecheck(ctx, f); err != nil {
71			f.View().Session().Logger().Errorf(ctx, "unable to check package for %s: %v", f.URI(), err)
72			return nil
73		}
74	}
75	if f.ast == nil {
76		return nil
77	}
78	return f.ast.file
79}
80
81func (f *goFile) GetAST(ctx context.Context) *ast.File {
82	f.view.mu.Lock()
83	defer f.view.mu.Unlock()
84
85	if f.isDirty() || f.astIsTrimmed() {
86		if _, err := f.view.loadParseTypecheck(ctx, f); err != nil {
87			f.View().Session().Logger().Errorf(ctx, "unable to check package for %s: %v", f.URI(), err)
88			return nil
89		}
90	}
91	if unexpectedAST(ctx, f) {
92		return nil
93	}
94	return f.ast.file
95}
96
97func (f *goFile) GetPackages(ctx context.Context) []source.Package {
98	f.view.mu.Lock()
99	defer f.view.mu.Unlock()
100
101	if f.isDirty() || f.astIsTrimmed() {
102		if errs, err := f.view.loadParseTypecheck(ctx, f); err != nil {
103			f.View().Session().Logger().Errorf(ctx, "unable to check package for %s: %v", f.URI(), err)
104
105			// Create diagnostics for errors if we are able to.
106			if len(errs) > 0 {
107				return []source.Package{&pkg{errors: errs}}
108			}
109			return nil
110		}
111	}
112	if unexpectedAST(ctx, f) {
113		return nil
114	}
115	var pkgs []source.Package
116	for _, pkg := range f.pkgs {
117		pkgs = append(pkgs, pkg)
118	}
119	return pkgs
120}
121
122func (f *goFile) GetPackage(ctx context.Context) source.Package {
123	pkgs := f.GetPackages(ctx)
124	var result source.Package
125
126	// Pick the "narrowest" package, i.e. the package with the fewest number of files.
127	// This solves the problem of test variants,
128	// as the test will have more files than the non-test package.
129	for _, pkg := range pkgs {
130		if result == nil || len(pkg.GetFilenames()) < len(result.GetFilenames()) {
131			result = pkg
132		}
133	}
134	return result
135}
136
137func unexpectedAST(ctx context.Context, f *goFile) bool {
138	f.mu.Lock()
139	defer f.mu.Unlock()
140
141	// If the AST comes back nil, something has gone wrong.
142	if f.ast == nil {
143		f.View().Session().Logger().Errorf(ctx, "expected full AST for %s, returned nil", f.URI())
144		return true
145	}
146	// If the AST comes back trimmed, something has gone wrong.
147	if f.ast.isTrimmed {
148		f.View().Session().Logger().Errorf(ctx, "expected full AST for %s, returned trimmed", f.URI())
149		return true
150	}
151	return false
152}
153
154// isDirty is true if the file needs to be type-checked.
155// It assumes that the file's view's mutex is held by the caller.
156func (f *goFile) isDirty() bool {
157	f.mu.Lock()
158	defer f.mu.Unlock()
159
160	// If the the file has just been opened,
161	// it may be part of more packages than we are aware of.
162	//
163	// Note: This must be the first case, otherwise we may not reset the value of f.justOpened.
164	if f.justOpened {
165		f.meta = make(map[packageID]*metadata)
166		f.pkgs = make(map[packageID]*pkg)
167		f.justOpened = false
168		return true
169	}
170	if len(f.meta) == 0 || len(f.pkgs) == 0 {
171		return true
172	}
173	if len(f.missingImports) > 0 {
174		return true
175	}
176	return f.token == nil || f.ast == nil
177}
178
179func (f *goFile) astIsTrimmed() bool {
180	f.mu.Lock()
181	defer f.mu.Unlock()
182
183	return f.ast != nil && f.ast.isTrimmed
184}
185
186func (f *goFile) GetActiveReverseDeps(ctx context.Context) []source.GoFile {
187	pkg := f.GetPackage(ctx)
188	if pkg == nil {
189		return nil
190	}
191
192	f.view.mu.Lock()
193	defer f.view.mu.Unlock()
194
195	f.view.mcache.mu.Lock()
196	defer f.view.mcache.mu.Unlock()
197
198	id := packageID(pkg.ID())
199
200	seen := make(map[packageID]struct{}) // visited packages
201	results := make(map[*goFile]struct{})
202	f.view.reverseDeps(ctx, seen, results, id)
203
204	var files []source.GoFile
205	for rd := range results {
206		if rd == nil {
207			continue
208		}
209		// Don't return any of the active files in this package.
210		if _, ok := rd.pkgs[id]; ok {
211			continue
212		}
213		files = append(files, rd)
214	}
215	return files
216}
217
218func (v *view) reverseDeps(ctx context.Context, seen map[packageID]struct{}, results map[*goFile]struct{}, id packageID) {
219	if _, ok := seen[id]; ok {
220		return
221	}
222	seen[id] = struct{}{}
223	m, ok := v.mcache.packages[id]
224	if !ok {
225		return
226	}
227	for _, filename := range m.files {
228		uri := span.FileURI(filename)
229		if f, err := v.getFile(uri); err == nil && v.session.IsOpen(uri) {
230			results[f.(*goFile)] = struct{}{}
231		}
232	}
233	for parentID := range m.parents {
234		v.reverseDeps(ctx, seen, results, parentID)
235	}
236}
237