1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package cache
6
7import (
8	"context"
9	"fmt"
10	"go/ast"
11	"io/ioutil"
12	"os"
13	"path/filepath"
14	"sort"
15	"strconv"
16	"strings"
17
18	"golang.org/x/mod/modfile"
19	"golang.org/x/tools/internal/event"
20	"golang.org/x/tools/internal/gocommand"
21	"golang.org/x/tools/internal/lsp/command"
22	"golang.org/x/tools/internal/lsp/debug/tag"
23	"golang.org/x/tools/internal/lsp/diff"
24	"golang.org/x/tools/internal/lsp/protocol"
25	"golang.org/x/tools/internal/lsp/source"
26	"golang.org/x/tools/internal/memoize"
27	"golang.org/x/tools/internal/span"
28)
29
30type modTidyKey struct {
31	sessionID       string
32	env             string
33	gomod           source.FileIdentity
34	imports         string
35	unsavedOverlays string
36	view            string
37}
38
39type modTidyHandle struct {
40	handle *memoize.Handle
41}
42
43type modTidyData struct {
44	tidied *source.TidiedModule
45	err    error
46}
47
48func (mth *modTidyHandle) tidy(ctx context.Context, snapshot *snapshot) (*source.TidiedModule, error) {
49	v, err := mth.handle.Get(ctx, snapshot.generation, snapshot)
50	if err != nil {
51		return nil, err
52	}
53	data := v.(*modTidyData)
54	return data.tidied, data.err
55}
56
57func (s *snapshot) ModTidy(ctx context.Context, pm *source.ParsedModule) (*source.TidiedModule, error) {
58	if pm.File == nil {
59		return nil, fmt.Errorf("cannot tidy unparseable go.mod file: %v", pm.URI)
60	}
61	if handle := s.getModTidyHandle(pm.URI); handle != nil {
62		return handle.tidy(ctx, s)
63	}
64	fh, err := s.GetFile(ctx, pm.URI)
65	if err != nil {
66		return nil, err
67	}
68	// If the file handle is an overlay, it may not be written to disk.
69	// The go.mod file has to be on disk for `go mod tidy` to work.
70	if _, ok := fh.(*overlay); ok {
71		if info, _ := os.Stat(fh.URI().Filename()); info == nil {
72			return nil, source.ErrNoModOnDisk
73		}
74	}
75	if criticalErr := s.GetCriticalError(ctx); criticalErr != nil {
76		return &source.TidiedModule{
77			Diagnostics: criticalErr.DiagList,
78		}, nil
79	}
80	workspacePkgs, err := s.workspacePackageHandles(ctx)
81	if err != nil {
82		return nil, err
83	}
84	importHash, err := s.hashImports(ctx, workspacePkgs)
85	if err != nil {
86		return nil, err
87	}
88
89	s.mu.Lock()
90	overlayHash := hashUnsavedOverlays(s.files)
91	s.mu.Unlock()
92
93	key := modTidyKey{
94		sessionID:       s.view.session.id,
95		view:            s.view.folder.Filename(),
96		imports:         importHash,
97		unsavedOverlays: overlayHash,
98		gomod:           fh.FileIdentity(),
99		env:             hashEnv(s),
100	}
101	h := s.generation.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
102		ctx, done := event.Start(ctx, "cache.ModTidyHandle", tag.URI.Of(fh.URI()))
103		defer done()
104
105		snapshot := arg.(*snapshot)
106		inv := &gocommand.Invocation{
107			Verb:       "mod",
108			Args:       []string{"tidy"},
109			WorkingDir: filepath.Dir(fh.URI().Filename()),
110		}
111		tmpURI, inv, cleanup, err := snapshot.goCommandInvocation(ctx, source.WriteTemporaryModFile, inv)
112		if err != nil {
113			return &modTidyData{err: err}
114		}
115		// Keep the temporary go.mod file around long enough to parse it.
116		defer cleanup()
117
118		if _, err := s.view.session.gocmdRunner.Run(ctx, *inv); err != nil {
119			return &modTidyData{err: err}
120		}
121		// Go directly to disk to get the temporary mod file, since it is
122		// always on disk.
123		tempContents, err := ioutil.ReadFile(tmpURI.Filename())
124		if err != nil {
125			return &modTidyData{err: err}
126		}
127		ideal, err := modfile.Parse(tmpURI.Filename(), tempContents, nil)
128		if err != nil {
129			// We do not need to worry about the temporary file's parse errors
130			// since it has been "tidied".
131			return &modTidyData{err: err}
132		}
133		// Compare the original and tidied go.mod files to compute errors and
134		// suggested fixes.
135		diagnostics, err := modTidyDiagnostics(ctx, snapshot, pm, ideal, workspacePkgs)
136		if err != nil {
137			return &modTidyData{err: err}
138		}
139		return &modTidyData{
140			tidied: &source.TidiedModule{
141				Diagnostics:   diagnostics,
142				TidiedContent: tempContents,
143			},
144		}
145	}, nil)
146
147	mth := &modTidyHandle{handle: h}
148	s.mu.Lock()
149	s.modTidyHandles[fh.URI()] = mth
150	s.mu.Unlock()
151
152	return mth.tidy(ctx, s)
153}
154
155func (s *snapshot) uriToModDecl(ctx context.Context, uri span.URI) (protocol.Range, error) {
156	fh, err := s.GetFile(ctx, uri)
157	if err != nil {
158		return protocol.Range{}, nil
159	}
160	pmf, err := s.ParseMod(ctx, fh)
161	if err != nil {
162		return protocol.Range{}, nil
163	}
164	if pmf.File.Module == nil || pmf.File.Module.Syntax == nil {
165		return protocol.Range{}, nil
166	}
167	return rangeFromPositions(pmf.Mapper, pmf.File.Module.Syntax.Start, pmf.File.Module.Syntax.End)
168}
169
170func (s *snapshot) hashImports(ctx context.Context, wsPackages []*packageHandle) (string, error) {
171	seen := map[string]struct{}{}
172	var imports []string
173	for _, ph := range wsPackages {
174		for _, imp := range ph.imports(ctx, s) {
175			if _, ok := seen[imp]; !ok {
176				imports = append(imports, imp)
177				seen[imp] = struct{}{}
178			}
179		}
180	}
181	sort.Strings(imports)
182	hashed := strings.Join(imports, ",")
183	return hashContents([]byte(hashed)), nil
184}
185
186// modTidyDiagnostics computes the differences between the original and tidied
187// go.mod files to produce diagnostic and suggested fixes. Some diagnostics
188// may appear on the Go files that import packages from missing modules.
189func modTidyDiagnostics(ctx context.Context, snapshot source.Snapshot, pm *source.ParsedModule, ideal *modfile.File, workspacePkgs []*packageHandle) (diagnostics []*source.Diagnostic, err error) {
190	// First, determine which modules are unused and which are missing from the
191	// original go.mod file.
192	var (
193		unused          = make(map[string]*modfile.Require, len(pm.File.Require))
194		missing         = make(map[string]*modfile.Require, len(ideal.Require))
195		wrongDirectness = make(map[string]*modfile.Require, len(pm.File.Require))
196	)
197	for _, req := range pm.File.Require {
198		unused[req.Mod.Path] = req
199	}
200	for _, req := range ideal.Require {
201		origReq := unused[req.Mod.Path]
202		if origReq == nil {
203			missing[req.Mod.Path] = req
204			continue
205		} else if origReq.Indirect != req.Indirect {
206			wrongDirectness[req.Mod.Path] = origReq
207		}
208		delete(unused, req.Mod.Path)
209	}
210	for _, req := range wrongDirectness {
211		// Handle dependencies that are incorrectly labeled indirect and
212		// vice versa.
213		srcDiag, err := directnessDiagnostic(pm.Mapper, req, snapshot.View().Options().ComputeEdits)
214		if err != nil {
215			return nil, err
216		}
217		diagnostics = append(diagnostics, srcDiag)
218	}
219	// Next, compute any diagnostics for modules that are missing from the
220	// go.mod file. The fixes will be for the go.mod file, but the
221	// diagnostics should also appear in both the go.mod file and the import
222	// statements in the Go files in which the dependencies are used.
223	missingModuleFixes := map[*modfile.Require][]source.SuggestedFix{}
224	for _, req := range missing {
225		srcDiag, err := missingModuleDiagnostic(pm, req)
226		if err != nil {
227			return nil, err
228		}
229		missingModuleFixes[req] = srcDiag.SuggestedFixes
230		diagnostics = append(diagnostics, srcDiag)
231	}
232	// Add diagnostics for missing modules anywhere they are imported in the
233	// workspace.
234	for _, ph := range workspacePkgs {
235		missingImports := map[string]*modfile.Require{}
236
237		// If -mod=readonly is not set we may have successfully imported
238		// packages from missing modules. Otherwise they'll be in
239		// MissingDependencies. Combine both.
240		importedPkgs := ph.imports(ctx, snapshot)
241
242		for _, imp := range importedPkgs {
243			if req, ok := missing[imp]; ok {
244				missingImports[imp] = req
245				break
246			}
247			// If the import is a package of the dependency, then add the
248			// package to the map, this will eliminate the need to do this
249			// prefix package search on each import for each file.
250			// Example:
251			//
252			// import (
253			//   "golang.org/x/tools/go/expect"
254			//   "golang.org/x/tools/go/packages"
255			// )
256			// They both are related to the same module: "golang.org/x/tools".
257			var match string
258			for _, req := range ideal.Require {
259				if strings.HasPrefix(imp, req.Mod.Path) && len(req.Mod.Path) > len(match) {
260					match = req.Mod.Path
261				}
262			}
263			if req, ok := missing[match]; ok {
264				missingImports[imp] = req
265			}
266		}
267		// None of this package's imports are from missing modules.
268		if len(missingImports) == 0 {
269			continue
270		}
271		for _, pgh := range ph.compiledGoFiles {
272			pgf, err := snapshot.ParseGo(ctx, pgh.file, source.ParseHeader)
273			if err != nil {
274				continue
275			}
276			file, m := pgf.File, pgf.Mapper
277			if file == nil || m == nil {
278				continue
279			}
280			imports := make(map[string]*ast.ImportSpec)
281			for _, imp := range file.Imports {
282				if imp.Path == nil {
283					continue
284				}
285				if target, err := strconv.Unquote(imp.Path.Value); err == nil {
286					imports[target] = imp
287				}
288			}
289			if len(imports) == 0 {
290				continue
291			}
292			for importPath, req := range missingImports {
293				imp, ok := imports[importPath]
294				if !ok {
295					continue
296				}
297				fixes, ok := missingModuleFixes[req]
298				if !ok {
299					return nil, fmt.Errorf("no missing module fix for %q (%q)", importPath, req.Mod.Path)
300				}
301				srcErr, err := missingModuleForImport(snapshot, m, imp, req, fixes)
302				if err != nil {
303					return nil, err
304				}
305				diagnostics = append(diagnostics, srcErr)
306			}
307		}
308	}
309	// Finally, add errors for any unused dependencies.
310	onlyDiagnostic := len(diagnostics) == 0 && len(unused) == 1
311	for _, req := range unused {
312		srcErr, err := unusedDiagnostic(pm.Mapper, req, onlyDiagnostic)
313		if err != nil {
314			return nil, err
315		}
316		diagnostics = append(diagnostics, srcErr)
317	}
318	return diagnostics, nil
319}
320
321// unusedDiagnostic returns a source.Diagnostic for an unused require.
322func unusedDiagnostic(m *protocol.ColumnMapper, req *modfile.Require, onlyDiagnostic bool) (*source.Diagnostic, error) {
323	rng, err := rangeFromPositions(m, req.Syntax.Start, req.Syntax.End)
324	if err != nil {
325		return nil, err
326	}
327	title := fmt.Sprintf("Remove dependency: %s", req.Mod.Path)
328	cmd, err := command.NewRemoveDependencyCommand(title, command.RemoveDependencyArgs{
329		URI:            protocol.URIFromSpanURI(m.URI),
330		OnlyDiagnostic: onlyDiagnostic,
331		ModulePath:     req.Mod.Path,
332	})
333	if err != nil {
334		return nil, err
335	}
336	return &source.Diagnostic{
337		URI:            m.URI,
338		Range:          rng,
339		Severity:       protocol.SeverityWarning,
340		Source:         source.ModTidyError,
341		Message:        fmt.Sprintf("%s is not used in this module", req.Mod.Path),
342		SuggestedFixes: []source.SuggestedFix{source.SuggestedFixFromCommand(cmd, protocol.QuickFix)},
343	}, nil
344}
345
346// directnessDiagnostic extracts errors when a dependency is labeled indirect when
347// it should be direct and vice versa.
348func directnessDiagnostic(m *protocol.ColumnMapper, req *modfile.Require, computeEdits diff.ComputeEdits) (*source.Diagnostic, error) {
349	rng, err := rangeFromPositions(m, req.Syntax.Start, req.Syntax.End)
350	if err != nil {
351		return nil, err
352	}
353	direction := "indirect"
354	if req.Indirect {
355		direction = "direct"
356
357		// If the dependency should be direct, just highlight the // indirect.
358		if comments := req.Syntax.Comment(); comments != nil && len(comments.Suffix) > 0 {
359			end := comments.Suffix[0].Start
360			end.LineRune += len(comments.Suffix[0].Token)
361			end.Byte += len([]byte(comments.Suffix[0].Token))
362			rng, err = rangeFromPositions(m, comments.Suffix[0].Start, end)
363			if err != nil {
364				return nil, err
365			}
366		}
367	}
368	// If the dependency should be indirect, add the // indirect.
369	edits, err := switchDirectness(req, m, computeEdits)
370	if err != nil {
371		return nil, err
372	}
373	return &source.Diagnostic{
374		URI:      m.URI,
375		Range:    rng,
376		Severity: protocol.SeverityWarning,
377		Source:   source.ModTidyError,
378		Message:  fmt.Sprintf("%s should be %s", req.Mod.Path, direction),
379		SuggestedFixes: []source.SuggestedFix{{
380			Title: fmt.Sprintf("Change %s to %s", req.Mod.Path, direction),
381			Edits: map[span.URI][]protocol.TextEdit{
382				m.URI: edits,
383			},
384			ActionKind: protocol.QuickFix,
385		}},
386	}, nil
387}
388
389func missingModuleDiagnostic(pm *source.ParsedModule, req *modfile.Require) (*source.Diagnostic, error) {
390	var rng protocol.Range
391	// Default to the start of the file if there is no module declaration.
392	if pm.File != nil && pm.File.Module != nil && pm.File.Module.Syntax != nil {
393		start, end := pm.File.Module.Syntax.Span()
394		var err error
395		rng, err = rangeFromPositions(pm.Mapper, start, end)
396		if err != nil {
397			return nil, err
398		}
399	}
400	title := fmt.Sprintf("Add %s to your go.mod file", req.Mod.Path)
401	cmd, err := command.NewAddDependencyCommand(title, command.DependencyArgs{
402		URI:        protocol.URIFromSpanURI(pm.Mapper.URI),
403		AddRequire: !req.Indirect,
404		GoCmdArgs:  []string{req.Mod.Path + "@" + req.Mod.Version},
405	})
406	if err != nil {
407		return nil, err
408	}
409	return &source.Diagnostic{
410		URI:            pm.Mapper.URI,
411		Range:          rng,
412		Severity:       protocol.SeverityError,
413		Source:         source.ModTidyError,
414		Message:        fmt.Sprintf("%s is not in your go.mod file", req.Mod.Path),
415		SuggestedFixes: []source.SuggestedFix{source.SuggestedFixFromCommand(cmd, protocol.QuickFix)},
416	}, nil
417}
418
419// switchDirectness gets the edits needed to change an indirect dependency to
420// direct and vice versa.
421func switchDirectness(req *modfile.Require, m *protocol.ColumnMapper, computeEdits diff.ComputeEdits) ([]protocol.TextEdit, error) {
422	// We need a private copy of the parsed go.mod file, since we're going to
423	// modify it.
424	copied, err := modfile.Parse("", m.Content, nil)
425	if err != nil {
426		return nil, err
427	}
428	// Change the directness in the matching require statement. To avoid
429	// reordering the require statements, rewrite all of them.
430	var requires []*modfile.Require
431	for _, r := range copied.Require {
432		if r.Mod.Path == req.Mod.Path {
433			requires = append(requires, &modfile.Require{
434				Mod:      r.Mod,
435				Syntax:   r.Syntax,
436				Indirect: !r.Indirect,
437			})
438			continue
439		}
440		requires = append(requires, r)
441	}
442	copied.SetRequire(requires)
443	newContent, err := copied.Format()
444	if err != nil {
445		return nil, err
446	}
447	// Calculate the edits to be made due to the change.
448	diff, err := computeEdits(m.URI, string(m.Content), string(newContent))
449	if err != nil {
450		return nil, err
451	}
452	return source.ToProtocolEdits(m, diff)
453}
454
455// missingModuleForImport creates an error for a given import path that comes
456// from a missing module.
457func missingModuleForImport(snapshot source.Snapshot, m *protocol.ColumnMapper, imp *ast.ImportSpec, req *modfile.Require, fixes []source.SuggestedFix) (*source.Diagnostic, error) {
458	if req.Syntax == nil {
459		return nil, fmt.Errorf("no syntax for %v", req)
460	}
461	spn, err := span.NewRange(snapshot.FileSet(), imp.Path.Pos(), imp.Path.End()).Span()
462	if err != nil {
463		return nil, err
464	}
465	rng, err := m.Range(spn)
466	if err != nil {
467		return nil, err
468	}
469	return &source.Diagnostic{
470		URI:            m.URI,
471		Range:          rng,
472		Severity:       protocol.SeverityError,
473		Source:         source.ModTidyError,
474		Message:        fmt.Sprintf("%s is not in your go.mod file", req.Mod.Path),
475		SuggestedFixes: fixes,
476	}, nil
477}
478
479func rangeFromPositions(m *protocol.ColumnMapper, s, e modfile.Position) (protocol.Range, error) {
480	spn, err := spanFromPositions(m, s, e)
481	if err != nil {
482		return protocol.Range{}, err
483	}
484	return m.Range(spn)
485}
486
487func spanFromPositions(m *protocol.ColumnMapper, s, e modfile.Position) (span.Span, error) {
488	toPoint := func(offset int) (span.Point, error) {
489		l, c, err := m.Converter.ToPosition(offset)
490		if err != nil {
491			return span.Point{}, err
492		}
493		return span.NewPoint(l, c, offset), nil
494	}
495	start, err := toPoint(s.Byte)
496	if err != nil {
497		return span.Span{}, err
498	}
499	end, err := toPoint(e.Byte)
500	if err != nil {
501		return span.Span{}, err
502	}
503	return span.New(m.URI, start, end), nil
504}
505