1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package cache
6
7import (
8	"context"
9	"fmt"
10	"go/ast"
11	"io/ioutil"
12	"os"
13	"path/filepath"
14	"sort"
15	"strconv"
16	"strings"
17
18	"golang.org/x/mod/modfile"
19	"golang.org/x/tools/internal/event"
20	"golang.org/x/tools/internal/lsp/debug/tag"
21	"golang.org/x/tools/internal/lsp/diff"
22	"golang.org/x/tools/internal/lsp/protocol"
23	"golang.org/x/tools/internal/lsp/source"
24	"golang.org/x/tools/internal/memoize"
25	"golang.org/x/tools/internal/span"
26)
27
28type modTidyKey struct {
29	sessionID       string
30	cfg             string
31	gomod           source.FileIdentity
32	imports         string
33	unsavedOverlays string
34	view            string
35}
36
37type modTidyHandle struct {
38	handle *memoize.Handle
39}
40
41type modTidyData struct {
42	tidied *source.TidiedModule
43	err    error
44}
45
46func (mth *modTidyHandle) tidy(ctx context.Context, snapshot *snapshot) (*source.TidiedModule, error) {
47	v, err := mth.handle.Get(ctx, snapshot.generation, snapshot)
48	if err != nil {
49		return nil, err
50	}
51	data := v.(*modTidyData)
52	return data.tidied, data.err
53}
54
55func (s *snapshot) ModTidy(ctx context.Context, fh source.FileHandle) (*source.TidiedModule, error) {
56	if fh.Kind() != source.Mod {
57		return nil, fmt.Errorf("%s is not a go.mod file", fh.URI())
58	}
59	if s.workspaceMode()&tempModfile == 0 {
60		return nil, source.ErrTmpModfileUnsupported
61	}
62	if handle := s.getModTidyHandle(fh.URI()); handle != nil {
63		return handle.tidy(ctx, s)
64	}
65	// If the file handle is an overlay, it may not be written to disk.
66	// The go.mod file has to be on disk for `go mod tidy` to work.
67	if _, ok := fh.(*overlay); ok {
68		if info, _ := os.Stat(fh.URI().Filename()); info == nil {
69			return nil, source.ErrNoModOnDisk
70		}
71	}
72	workspacePkgs, err := s.WorkspacePackages(ctx)
73	if err != nil {
74		return nil, err
75	}
76	importHash, err := hashImports(ctx, workspacePkgs)
77	if err != nil {
78		return nil, err
79	}
80
81	s.mu.Lock()
82	overlayHash := hashUnsavedOverlays(s.files)
83	s.mu.Unlock()
84
85	// Make sure to use the module root in the configuration.
86	cfg := s.config(ctx, filepath.Dir(fh.URI().Filename()))
87	key := modTidyKey{
88		sessionID:       s.view.session.id,
89		view:            s.view.folder.Filename(),
90		imports:         importHash,
91		unsavedOverlays: overlayHash,
92		gomod:           fh.FileIdentity(),
93		cfg:             hashConfig(cfg),
94	}
95	h := s.generation.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
96		ctx, done := event.Start(ctx, "cache.ModTidyHandle", tag.URI.Of(fh.URI()))
97		defer done()
98
99		snapshot := arg.(*snapshot)
100		pm, err := snapshot.ParseMod(ctx, fh)
101		if err != nil || len(pm.ParseErrors) > 0 {
102			if err == nil {
103				err = fmt.Errorf("could not parse module to tidy: %v", pm.ParseErrors)
104			}
105			var errors []source.Error
106			if pm != nil {
107				errors = pm.ParseErrors
108			}
109			return &modTidyData{
110				tidied: &source.TidiedModule{
111					Parsed: pm,
112					Errors: errors,
113				},
114				err: err,
115			}
116		}
117		// Get a new config to avoid races, since it may be modified by
118		// goCommandInvocation.
119		cfg := s.config(ctx, filepath.Dir(fh.URI().Filename()))
120		tmpURI, runner, inv, cleanup, err := snapshot.goCommandInvocation(ctx, cfg, true, "mod", []string{"tidy"})
121		if err != nil {
122			return &modTidyData{err: err}
123		}
124		// Keep the temporary go.mod file around long enough to parse it.
125		defer cleanup()
126
127		if _, err := runner.Run(ctx, *inv); err != nil {
128			return &modTidyData{err: err}
129		}
130		// Go directly to disk to get the temporary mod file, since it is
131		// always on disk.
132		tempContents, err := ioutil.ReadFile(tmpURI.Filename())
133		if err != nil {
134			return &modTidyData{err: err}
135		}
136		ideal, err := modfile.Parse(tmpURI.Filename(), tempContents, nil)
137		if err != nil {
138			// We do not need to worry about the temporary file's parse errors
139			// since it has been "tidied".
140			return &modTidyData{err: err}
141		}
142		// Compare the original and tidied go.mod files to compute errors and
143		// suggested fixes.
144		errors, err := modTidyErrors(ctx, snapshot, pm, ideal, workspacePkgs)
145		if err != nil {
146			return &modTidyData{err: err}
147		}
148		return &modTidyData{
149			tidied: &source.TidiedModule{
150				Errors:        errors,
151				Parsed:        pm,
152				TidiedContent: tempContents,
153			},
154		}
155	})
156
157	mth := &modTidyHandle{handle: h}
158	s.mu.Lock()
159	s.modTidyHandles[fh.URI()] = mth
160	s.mu.Unlock()
161
162	return mth.tidy(ctx, s)
163}
164
165func hashImports(ctx context.Context, wsPackages []source.Package) (string, error) {
166	results := make(map[string]bool)
167	var imports []string
168	for _, pkg := range wsPackages {
169		for _, path := range pkg.Imports() {
170			imp := path.PkgPath()
171			if _, ok := results[imp]; !ok {
172				results[imp] = true
173				imports = append(imports, imp)
174			}
175		}
176	}
177	sort.Strings(imports)
178	hashed := strings.Join(imports, ",")
179	return hashContents([]byte(hashed)), nil
180}
181
182// modTidyErrors computes the differences between the original and tidied
183// go.mod files to produce diagnostic and suggested fixes. Some diagnostics
184// may appear on the Go files that import packages from missing modules.
185func modTidyErrors(ctx context.Context, snapshot source.Snapshot, pm *source.ParsedModule, ideal *modfile.File, workspacePkgs []source.Package) (errors []source.Error, err error) {
186	// First, determine which modules are unused and which are missing from the
187	// original go.mod file.
188	var (
189		unused          = make(map[string]*modfile.Require, len(pm.File.Require))
190		missing         = make(map[string]*modfile.Require, len(ideal.Require))
191		wrongDirectness = make(map[string]*modfile.Require, len(pm.File.Require))
192	)
193	for _, req := range pm.File.Require {
194		unused[req.Mod.Path] = req
195	}
196	for _, req := range ideal.Require {
197		origReq := unused[req.Mod.Path]
198		if origReq == nil {
199			missing[req.Mod.Path] = req
200			continue
201		} else if origReq.Indirect != req.Indirect {
202			wrongDirectness[req.Mod.Path] = origReq
203		}
204		delete(unused, req.Mod.Path)
205	}
206	for _, req := range unused {
207		srcErr, err := unusedError(pm.Mapper, req, snapshot.View().Options().ComputeEdits)
208		if err != nil {
209			return nil, err
210		}
211		errors = append(errors, srcErr)
212	}
213	for _, req := range wrongDirectness {
214		// Handle dependencies that are incorrectly labeled indirect and
215		// vice versa.
216		srcErr, err := directnessError(pm.Mapper, req, snapshot.View().Options().ComputeEdits)
217		if err != nil {
218			return nil, err
219		}
220		errors = append(errors, srcErr)
221	}
222	// Next, compute any diagnostics for modules that are missing from the
223	// go.mod file. The fixes will be for the go.mod file, but the
224	// diagnostics should also appear in both the go.mod file and the import
225	// statements in the Go files in which the dependencies are used.
226	missingModuleFixes := map[*modfile.Require][]source.SuggestedFix{}
227	for _, req := range missing {
228		srcErr, err := missingModuleError(snapshot, pm, req)
229		if err != nil {
230			return nil, err
231		}
232		missingModuleFixes[req] = srcErr.SuggestedFixes
233		errors = append(errors, srcErr)
234	}
235	// Add diagnostics for missing modules anywhere they are imported in the
236	// workspace.
237	for _, pkg := range workspacePkgs {
238		missingImports := map[string]*modfile.Require{}
239		for _, imp := range pkg.Imports() {
240			if req, ok := missing[imp.PkgPath()]; ok {
241				missingImports[imp.PkgPath()] = req
242				break
243			}
244			// If the import is a package of the dependency, then add the
245			// package to the map, this will eliminate the need to do this
246			// prefix package search on each import for each file.
247			// Example:
248			//
249			// import (
250			//   "golang.org/x/tools/go/expect"
251			//   "golang.org/x/tools/go/packages"
252			// )
253			// They both are related to the same module: "golang.org/x/tools".
254			var match string
255			for _, req := range ideal.Require {
256				if strings.HasPrefix(imp.PkgPath(), req.Mod.Path) && len(req.Mod.Path) > len(match) {
257					match = req.Mod.Path
258				}
259			}
260			if req, ok := missing[match]; ok {
261				missingImports[imp.PkgPath()] = req
262			}
263		}
264		// None of this package's imports are from missing modules.
265		if len(missingImports) == 0 {
266			continue
267		}
268		for _, pgf := range pkg.CompiledGoFiles() {
269			file, m := pgf.File, pgf.Mapper
270			if file == nil || m == nil {
271				continue
272			}
273			imports := make(map[string]*ast.ImportSpec)
274			for _, imp := range file.Imports {
275				if imp.Path == nil {
276					continue
277				}
278				if target, err := strconv.Unquote(imp.Path.Value); err == nil {
279					imports[target] = imp
280				}
281			}
282			if len(imports) == 0 {
283				continue
284			}
285			for importPath, req := range missingImports {
286				imp, ok := imports[importPath]
287				if !ok {
288					continue
289				}
290				fixes, ok := missingModuleFixes[req]
291				if !ok {
292					return nil, fmt.Errorf("no missing module fix for %q (%q)", importPath, req.Mod.Path)
293				}
294				srcErr, err := missingModuleForImport(snapshot, m, imp, req, fixes)
295				if err != nil {
296					return nil, err
297				}
298				errors = append(errors, srcErr)
299			}
300		}
301	}
302	return errors, nil
303}
304
305// unusedError returns a source.Error for an unused require.
306func unusedError(m *protocol.ColumnMapper, req *modfile.Require, computeEdits diff.ComputeEdits) (source.Error, error) {
307	rng, err := rangeFromPositions(m, req.Syntax.Start, req.Syntax.End)
308	if err != nil {
309		return source.Error{}, err
310	}
311	edits, err := dropDependency(req, m, computeEdits)
312	if err != nil {
313		return source.Error{}, err
314	}
315	return source.Error{
316		Category: source.GoModTidy,
317		Message:  fmt.Sprintf("%s is not used in this module", req.Mod.Path),
318		Range:    rng,
319		URI:      m.URI,
320		SuggestedFixes: []source.SuggestedFix{{
321			Title: fmt.Sprintf("Remove dependency: %s", req.Mod.Path),
322			Edits: map[span.URI][]protocol.TextEdit{
323				m.URI: edits,
324			},
325		}},
326	}, nil
327}
328
329// directnessError extracts errors when a dependency is labeled indirect when
330// it should be direct and vice versa.
331func directnessError(m *protocol.ColumnMapper, req *modfile.Require, computeEdits diff.ComputeEdits) (source.Error, error) {
332	rng, err := rangeFromPositions(m, req.Syntax.Start, req.Syntax.End)
333	if err != nil {
334		return source.Error{}, err
335	}
336	direction := "indirect"
337	if req.Indirect {
338		direction = "direct"
339
340		// If the dependency should be direct, just highlight the // indirect.
341		if comments := req.Syntax.Comment(); comments != nil && len(comments.Suffix) > 0 {
342			end := comments.Suffix[0].Start
343			end.LineRune += len(comments.Suffix[0].Token)
344			end.Byte += len([]byte(comments.Suffix[0].Token))
345			rng, err = rangeFromPositions(m, comments.Suffix[0].Start, end)
346			if err != nil {
347				return source.Error{}, err
348			}
349		}
350	}
351	// If the dependency should be indirect, add the // indirect.
352	edits, err := switchDirectness(req, m, computeEdits)
353	if err != nil {
354		return source.Error{}, err
355	}
356	return source.Error{
357		Message:  fmt.Sprintf("%s should be %s", req.Mod.Path, direction),
358		Range:    rng,
359		URI:      m.URI,
360		Category: source.GoModTidy,
361		SuggestedFixes: []source.SuggestedFix{{
362			Title: fmt.Sprintf("Change %s to %s", req.Mod.Path, direction),
363			Edits: map[span.URI][]protocol.TextEdit{
364				m.URI: edits,
365			},
366		}},
367	}, nil
368}
369
370func missingModuleError(snapshot source.Snapshot, pm *source.ParsedModule, req *modfile.Require) (source.Error, error) {
371	start, end := pm.File.Module.Syntax.Span()
372	rng, err := rangeFromPositions(pm.Mapper, start, end)
373	if err != nil {
374		return source.Error{}, err
375	}
376	edits, err := addRequireFix(pm.Mapper, req, snapshot.View().Options().ComputeEdits)
377	if err != nil {
378		return source.Error{}, err
379	}
380	fix := &source.SuggestedFix{
381		Title: fmt.Sprintf("Add %s to your go.mod file", req.Mod.Path),
382		Edits: map[span.URI][]protocol.TextEdit{
383			pm.Mapper.URI: edits,
384		},
385	}
386	return source.Error{
387		URI:            pm.Mapper.URI,
388		Range:          rng,
389		Message:        fmt.Sprintf("%s is not in your go.mod file", req.Mod.Path),
390		Category:       source.GoModTidy,
391		Kind:           source.ModTidyError,
392		SuggestedFixes: []source.SuggestedFix{*fix},
393	}, nil
394}
395
396// dropDependency returns the edits to remove the given require from the go.mod
397// file.
398func dropDependency(req *modfile.Require, m *protocol.ColumnMapper, computeEdits diff.ComputeEdits) ([]protocol.TextEdit, error) {
399	// We need a private copy of the parsed go.mod file, since we're going to
400	// modify it.
401	copied, err := modfile.Parse("", m.Content, nil)
402	if err != nil {
403		return nil, err
404	}
405	if err := copied.DropRequire(req.Mod.Path); err != nil {
406		return nil, err
407	}
408	copied.Cleanup()
409	newContent, err := copied.Format()
410	if err != nil {
411		return nil, err
412	}
413	// Calculate the edits to be made due to the change.
414	diff := computeEdits(m.URI, string(m.Content), string(newContent))
415	return source.ToProtocolEdits(m, diff)
416}
417
418// switchDirectness gets the edits needed to change an indirect dependency to
419// direct and vice versa.
420func switchDirectness(req *modfile.Require, m *protocol.ColumnMapper, computeEdits diff.ComputeEdits) ([]protocol.TextEdit, error) {
421	// We need a private copy of the parsed go.mod file, since we're going to
422	// modify it.
423	copied, err := modfile.Parse("", m.Content, nil)
424	if err != nil {
425		return nil, err
426	}
427	// Change the directness in the matching require statement. To avoid
428	// reordering the require statements, rewrite all of them.
429	var requires []*modfile.Require
430	for _, r := range copied.Require {
431		if r.Mod.Path == req.Mod.Path {
432			requires = append(requires, &modfile.Require{
433				Mod:      r.Mod,
434				Syntax:   r.Syntax,
435				Indirect: !r.Indirect,
436			})
437			continue
438		}
439		requires = append(requires, r)
440	}
441	copied.SetRequire(requires)
442	newContent, err := copied.Format()
443	if err != nil {
444		return nil, err
445	}
446	// Calculate the edits to be made due to the change.
447	diff := computeEdits(m.URI, string(m.Content), string(newContent))
448	return source.ToProtocolEdits(m, diff)
449}
450
451// missingModuleForImport creates an error for a given import path that comes
452// from a missing module.
453func missingModuleForImport(snapshot source.Snapshot, m *protocol.ColumnMapper, imp *ast.ImportSpec, req *modfile.Require, fixes []source.SuggestedFix) (source.Error, error) {
454	if req.Syntax == nil {
455		return source.Error{}, fmt.Errorf("no syntax for %v", req)
456	}
457	spn, err := span.NewRange(snapshot.FileSet(), imp.Path.Pos(), imp.Path.End()).Span()
458	if err != nil {
459		return source.Error{}, err
460	}
461	rng, err := m.Range(spn)
462	if err != nil {
463		return source.Error{}, err
464	}
465	return source.Error{
466		Category:       source.GoModTidy,
467		URI:            m.URI,
468		Range:          rng,
469		Message:        fmt.Sprintf("%s is not in your go.mod file", req.Mod.Path),
470		Kind:           source.ModTidyError,
471		SuggestedFixes: fixes,
472	}, nil
473}
474
475// addRequireFix creates edits for adding a given require to a go.mod file.
476func addRequireFix(m *protocol.ColumnMapper, req *modfile.Require, computeEdits diff.ComputeEdits) ([]protocol.TextEdit, error) {
477	// We need a private copy of the parsed go.mod file, since we're going to
478	// modify it.
479	copied, err := modfile.Parse("", m.Content, nil)
480	if err != nil {
481		return nil, err
482	}
483	// Calculate the quick fix edits that need to be made to the go.mod file.
484	if err := copied.AddRequire(req.Mod.Path, req.Mod.Version); err != nil {
485		return nil, err
486	}
487	copied.SortBlocks()
488	newContents, err := copied.Format()
489	if err != nil {
490		return nil, err
491	}
492	// Calculate the edits to be made due to the change.
493	diff := computeEdits(m.URI, string(m.Content), string(newContents))
494	return source.ToProtocolEdits(m, diff)
495}
496
497func rangeFromPositions(m *protocol.ColumnMapper, s, e modfile.Position) (protocol.Range, error) {
498	toPoint := func(offset int) (span.Point, error) {
499		l, c, err := m.Converter.ToPosition(offset)
500		if err != nil {
501			return span.Point{}, err
502		}
503		return span.NewPoint(l, c, offset), nil
504	}
505	start, err := toPoint(s.Byte)
506	if err != nil {
507		return protocol.Range{}, err
508	}
509	end, err := toPoint(e.Byte)
510	if err != nil {
511		return protocol.Range{}, err
512	}
513	return m.Range(span.New(m.URI, start, end))
514}
515