1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mod
6
7import (
8	"context"
9	"fmt"
10	"os"
11	"path/filepath"
12
13	"golang.org/x/mod/modfile"
14	"golang.org/x/tools/internal/lsp/protocol"
15	"golang.org/x/tools/internal/lsp/source"
16	"golang.org/x/tools/internal/span"
17)
18
19// LensFuncs returns the supported lensFuncs for go.mod files.
20func LensFuncs() map[string]source.LensFunc {
21	return map[string]source.LensFunc{
22		source.CommandUpgradeDependency.Name: upgradeLenses,
23		source.CommandTidy.Name:              tidyLens,
24		source.CommandVendor.Name:            vendorLens,
25	}
26}
27
28func upgradeLenses(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle) ([]protocol.CodeLens, error) {
29	pm, err := snapshot.ParseMod(ctx, fh)
30	if err != nil || pm.File == nil {
31		return nil, err
32	}
33	if len(pm.File.Require) == 0 {
34		// Nothing to upgrade.
35		return nil, nil
36	}
37	upgradeTransitiveArgs, err := source.MarshalArgs(fh.URI(), false, []string{"-u", "all"})
38	if err != nil {
39		return nil, err
40	}
41	var requires []string
42	for _, req := range pm.File.Require {
43		requires = append(requires, req.Mod.Path)
44	}
45	upgradeDirectArgs, err := source.MarshalArgs(fh.URI(), false, requires)
46	if err != nil {
47		return nil, err
48	}
49	// Put the upgrade code lenses above the first require block or statement.
50	rng, err := firstRequireRange(fh, pm)
51	if err != nil {
52		return nil, err
53	}
54	return []protocol.CodeLens{
55		{
56			Range: rng,
57			Command: protocol.Command{
58				Title:     "Upgrade transitive dependencies",
59				Command:   source.CommandUpgradeDependency.ID(),
60				Arguments: upgradeTransitiveArgs,
61			},
62		},
63		{
64			Range: rng,
65			Command: protocol.Command{
66				Title:     "Upgrade direct dependencies",
67				Command:   source.CommandUpgradeDependency.ID(),
68				Arguments: upgradeDirectArgs,
69			},
70		},
71	}, nil
72
73}
74
75func tidyLens(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle) ([]protocol.CodeLens, error) {
76	pm, err := snapshot.ParseMod(ctx, fh)
77	if err != nil || pm.File == nil {
78		return nil, err
79	}
80	if len(pm.File.Require) == 0 {
81		// Nothing to vendor.
82		return nil, nil
83	}
84	goModArgs, err := source.MarshalArgs(fh.URI())
85	if err != nil {
86		return nil, err
87	}
88	rng, err := moduleStmtRange(fh, pm)
89	if err != nil {
90		return nil, err
91	}
92	return []protocol.CodeLens{{
93		Range: rng,
94		Command: protocol.Command{
95			Title:     source.CommandTidy.Title,
96			Command:   source.CommandTidy.ID(),
97			Arguments: goModArgs,
98		},
99	}}, nil
100}
101
102func vendorLens(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle) ([]protocol.CodeLens, error) {
103	pm, err := snapshot.ParseMod(ctx, fh)
104	if err != nil || pm.File == nil {
105		return nil, err
106	}
107	rng, err := moduleStmtRange(fh, pm)
108	if err != nil {
109		return nil, err
110	}
111	goModArgs, err := source.MarshalArgs(fh.URI())
112	if err != nil {
113		return nil, err
114	}
115	// Change the message depending on whether or not the module already has a
116	// vendor directory.
117	title := "Create vendor directory"
118	vendorDir := filepath.Join(filepath.Dir(fh.URI().Filename()), "vendor")
119	if info, _ := os.Stat(vendorDir); info != nil && info.IsDir() {
120		title = "Sync vendor directory"
121	}
122	return []protocol.CodeLens{{
123		Range: rng,
124		Command: protocol.Command{
125			Title:     title,
126			Command:   source.CommandVendor.ID(),
127			Arguments: goModArgs,
128		},
129	}}, nil
130}
131
132func moduleStmtRange(fh source.FileHandle, pm *source.ParsedModule) (protocol.Range, error) {
133	if pm.File == nil || pm.File.Module == nil || pm.File.Module.Syntax == nil {
134		return protocol.Range{}, fmt.Errorf("no module statement in %s", fh.URI())
135	}
136	syntax := pm.File.Module.Syntax
137	return lineToRange(pm.Mapper, fh.URI(), syntax.Start, syntax.End)
138}
139
140// firstRequireRange returns the range for the first "require" in the given
141// go.mod file. This is either a require block or an individual require line.
142func firstRequireRange(fh source.FileHandle, pm *source.ParsedModule) (protocol.Range, error) {
143	if len(pm.File.Require) == 0 {
144		return protocol.Range{}, fmt.Errorf("no requires in the file %s", fh.URI())
145	}
146	var start, end modfile.Position
147	for _, stmt := range pm.File.Syntax.Stmt {
148		if b, ok := stmt.(*modfile.LineBlock); ok && len(b.Token) == 1 && b.Token[0] == "require" {
149			start, end = b.Span()
150			break
151		}
152	}
153
154	firstRequire := pm.File.Require[0].Syntax
155	if start.Byte == 0 || firstRequire.Start.Byte < start.Byte {
156		start, end = firstRequire.Start, firstRequire.End
157	}
158	return lineToRange(pm.Mapper, fh.URI(), start, end)
159}
160
161func lineToRange(m *protocol.ColumnMapper, uri span.URI, start, end modfile.Position) (protocol.Range, error) {
162	line, col, err := m.Converter.ToPosition(start.Byte)
163	if err != nil {
164		return protocol.Range{}, err
165	}
166	s := span.NewPoint(line, col, start.Byte)
167	line, col, err = m.Converter.ToPosition(end.Byte)
168	if err != nil {
169		return protocol.Range{}, err
170	}
171	e := span.NewPoint(line, col, end.Byte)
172	return m.Range(span.New(uri, s, e))
173}
174