1package mod
2
3import (
4	"context"
5	"fmt"
6	"os"
7	"path/filepath"
8
9	"golang.org/x/mod/modfile"
10	"golang.org/x/tools/internal/lsp/protocol"
11	"golang.org/x/tools/internal/lsp/source"
12	"golang.org/x/tools/internal/span"
13)
14
15// LensFuncs returns the supported lensFuncs for go.mod files.
16func LensFuncs() map[string]source.LensFunc {
17	return map[string]source.LensFunc{
18		source.CommandUpgradeDependency.Name: upgradeLenses,
19		source.CommandTidy.Name:              tidyLens,
20		source.CommandVendor.Name:            vendorLens,
21	}
22}
23
24func upgradeLenses(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle) ([]protocol.CodeLens, error) {
25	pm, err := snapshot.ParseMod(ctx, fh)
26	if err != nil || pm.File == nil {
27		return nil, err
28	}
29	if len(pm.File.Require) == 0 {
30		// Nothing to upgrade.
31		return nil, nil
32	}
33	upgradeTransitiveArgs, err := source.MarshalArgs(fh.URI(), false, []string{"-u", "all"})
34	if err != nil {
35		return nil, err
36	}
37	var requires []string
38	for _, req := range pm.File.Require {
39		requires = append(requires, req.Mod.Path)
40	}
41	upgradeDirectArgs, err := source.MarshalArgs(fh.URI(), false, requires)
42	if err != nil {
43		return nil, err
44	}
45	// Put the upgrade code lenses above the first require block or statement.
46	rng, err := firstRequireRange(fh, pm)
47	if err != nil {
48		return nil, err
49	}
50	return []protocol.CodeLens{
51		{
52			Range: rng,
53			Command: protocol.Command{
54				Title:     "Upgrade transitive dependencies",
55				Command:   source.CommandUpgradeDependency.ID(),
56				Arguments: upgradeTransitiveArgs,
57			},
58		},
59		{
60			Range: rng,
61			Command: protocol.Command{
62				Title:     "Upgrade direct dependencies",
63				Command:   source.CommandUpgradeDependency.ID(),
64				Arguments: upgradeDirectArgs,
65			},
66		},
67	}, nil
68
69}
70
71func tidyLens(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle) ([]protocol.CodeLens, error) {
72	pm, err := snapshot.ParseMod(ctx, fh)
73	if err != nil || pm.File == nil {
74		return nil, err
75	}
76	if len(pm.File.Require) == 0 {
77		// Nothing to vendor.
78		return nil, nil
79	}
80	goModArgs, err := source.MarshalArgs(fh.URI())
81	if err != nil {
82		return nil, err
83	}
84	rng, err := moduleStmtRange(fh, pm)
85	if err != nil {
86		return nil, err
87	}
88	return []protocol.CodeLens{{
89		Range: rng,
90		Command: protocol.Command{
91			Title:     source.CommandTidy.Title,
92			Command:   source.CommandTidy.ID(),
93			Arguments: goModArgs,
94		},
95	}}, nil
96}
97
98func vendorLens(ctx context.Context, snapshot source.Snapshot, fh source.FileHandle) ([]protocol.CodeLens, error) {
99	pm, err := snapshot.ParseMod(ctx, fh)
100	if err != nil || pm.File == nil {
101		return nil, err
102	}
103	rng, err := moduleStmtRange(fh, pm)
104	if err != nil {
105		return nil, err
106	}
107	goModArgs, err := source.MarshalArgs(fh.URI())
108	if err != nil {
109		return nil, err
110	}
111	// Change the message depending on whether or not the module already has a
112	// vendor directory.
113	title := "Create vendor directory"
114	vendorDir := filepath.Join(filepath.Dir(fh.URI().Filename()), "vendor")
115	if info, _ := os.Stat(vendorDir); info != nil && info.IsDir() {
116		title = "Sync vendor directory"
117	}
118	return []protocol.CodeLens{{
119		Range: rng,
120		Command: protocol.Command{
121			Title:     title,
122			Command:   source.CommandVendor.ID(),
123			Arguments: goModArgs,
124		},
125	}}, nil
126}
127
128func moduleStmtRange(fh source.FileHandle, pm *source.ParsedModule) (protocol.Range, error) {
129	if pm.File == nil || pm.File.Module == nil || pm.File.Module.Syntax == nil {
130		return protocol.Range{}, fmt.Errorf("no module statement in %s", fh.URI())
131	}
132	syntax := pm.File.Module.Syntax
133	return lineToRange(pm.Mapper, fh.URI(), syntax.Start, syntax.End)
134}
135
136// firstRequireRange returns the range for the first "require" in the given
137// go.mod file. This is either a require block or an individual require line.
138func firstRequireRange(fh source.FileHandle, pm *source.ParsedModule) (protocol.Range, error) {
139	if len(pm.File.Require) == 0 {
140		return protocol.Range{}, fmt.Errorf("no requires in the file %s", fh.URI())
141	}
142	var start, end modfile.Position
143	for _, stmt := range pm.File.Syntax.Stmt {
144		if b, ok := stmt.(*modfile.LineBlock); ok && len(b.Token) == 1 && b.Token[0] == "require" {
145			start, end = b.Span()
146			break
147		}
148	}
149
150	firstRequire := pm.File.Require[0].Syntax
151	if start.Byte == 0 || firstRequire.Start.Byte < start.Byte {
152		start, end = firstRequire.Start, firstRequire.End
153	}
154	return lineToRange(pm.Mapper, fh.URI(), start, end)
155}
156
157func lineToRange(m *protocol.ColumnMapper, uri span.URI, start, end modfile.Position) (protocol.Range, error) {
158	line, col, err := m.Converter.ToPosition(start.Byte)
159	if err != nil {
160		return protocol.Range{}, err
161	}
162	s := span.NewPoint(line, col, start.Byte)
163	line, col, err = m.Converter.ToPosition(end.Byte)
164	if err != nil {
165		return protocol.Range{}, err
166	}
167	e := span.NewPoint(line, col, end.Byte)
168	return m.Range(span.New(uri, s, e))
169}
170