1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package source
6
7import (
8	"context"
9	"encoding/json"
10	"fmt"
11	"go/ast"
12	"go/printer"
13	"go/token"
14	"go/types"
15	"path/filepath"
16	"regexp"
17	"sort"
18	"strconv"
19	"strings"
20
21	"golang.org/x/tools/internal/lsp/protocol"
22	"golang.org/x/tools/internal/span"
23	errors "golang.org/x/xerrors"
24)
25
26// MappedRange provides mapped protocol.Range for a span.Range, accounting for
27// UTF-16 code points.
28type MappedRange struct {
29	spanRange span.Range
30	m         *protocol.ColumnMapper
31
32	// protocolRange is the result of converting the spanRange using the mapper.
33	// It is computed on-demand.
34	protocolRange *protocol.Range
35}
36
37// NewMappedRange returns a MappedRange for the given start and end token.Pos.
38func NewMappedRange(fset *token.FileSet, m *protocol.ColumnMapper, start, end token.Pos) MappedRange {
39	return MappedRange{
40		spanRange: span.Range{
41			FileSet:   fset,
42			Start:     start,
43			End:       end,
44			Converter: m.Converter,
45		},
46		m: m,
47	}
48}
49
50func (s MappedRange) Range() (protocol.Range, error) {
51	if s.protocolRange == nil {
52		spn, err := s.spanRange.Span()
53		if err != nil {
54			return protocol.Range{}, err
55		}
56		prng, err := s.m.Range(spn)
57		if err != nil {
58			return protocol.Range{}, err
59		}
60		s.protocolRange = &prng
61	}
62	return *s.protocolRange, nil
63}
64
65func (s MappedRange) Span() (span.Span, error) {
66	return s.spanRange.Span()
67}
68
69func (s MappedRange) SpanRange() span.Range {
70	return s.spanRange
71}
72
73func (s MappedRange) URI() span.URI {
74	return s.m.URI
75}
76
77// GetParsedFile is a convenience function that extracts the Package and
78// ParsedGoFile for a file in a Snapshot. pkgPolicy is one of NarrowestPackage/
79// WidestPackage.
80func GetParsedFile(ctx context.Context, snapshot Snapshot, fh FileHandle, pkgPolicy PackageFilter) (Package, *ParsedGoFile, error) {
81	pkg, err := snapshot.PackageForFile(ctx, fh.URI(), TypecheckWorkspace, pkgPolicy)
82	if err != nil {
83		return nil, nil, err
84	}
85	pgh, err := pkg.File(fh.URI())
86	return pkg, pgh, err
87}
88
89func IsGenerated(ctx context.Context, snapshot Snapshot, uri span.URI) bool {
90	fh, err := snapshot.GetFile(ctx, uri)
91	if err != nil {
92		return false
93	}
94	pgf, err := snapshot.ParseGo(ctx, fh, ParseHeader)
95	if err != nil {
96		return false
97	}
98	tok := snapshot.FileSet().File(pgf.File.Pos())
99	if tok == nil {
100		return false
101	}
102	for _, commentGroup := range pgf.File.Comments {
103		for _, comment := range commentGroup.List {
104			if matched := generatedRx.MatchString(comment.Text); matched {
105				// Check if comment is at the beginning of the line in source.
106				if pos := tok.Position(comment.Slash); pos.Column == 1 {
107					return true
108				}
109			}
110		}
111	}
112	return false
113}
114
115func nodeToProtocolRange(snapshot Snapshot, pkg Package, n ast.Node) (protocol.Range, error) {
116	mrng, err := posToMappedRange(snapshot, pkg, n.Pos(), n.End())
117	if err != nil {
118		return protocol.Range{}, err
119	}
120	return mrng.Range()
121}
122
123func objToMappedRange(snapshot Snapshot, pkg Package, obj types.Object) (MappedRange, error) {
124	if pkgName, ok := obj.(*types.PkgName); ok {
125		// An imported Go package has a package-local, unqualified name.
126		// When the name matches the imported package name, there is no
127		// identifier in the import spec with the local package name.
128		//
129		// For example:
130		// 		import "go/ast" 	// name "ast" matches package name
131		// 		import a "go/ast"  	// name "a" does not match package name
132		//
133		// When the identifier does not appear in the source, have the range
134		// of the object be the import path, including quotes.
135		if pkgName.Imported().Name() == pkgName.Name() {
136			return posToMappedRange(snapshot, pkg, obj.Pos(), obj.Pos()+token.Pos(len(pkgName.Imported().Path())+2))
137		}
138	}
139	return nameToMappedRange(snapshot, pkg, obj.Pos(), obj.Name())
140}
141
142func nameToMappedRange(snapshot Snapshot, pkg Package, pos token.Pos, name string) (MappedRange, error) {
143	return posToMappedRange(snapshot, pkg, pos, pos+token.Pos(len(name)))
144}
145
146func posToMappedRange(snapshot Snapshot, pkg Package, pos, end token.Pos) (MappedRange, error) {
147	logicalFilename := snapshot.FileSet().File(pos).Position(pos).Filename
148	pgf, _, err := findFileInDeps(pkg, span.URIFromPath(logicalFilename))
149	if err != nil {
150		return MappedRange{}, err
151	}
152	if !pos.IsValid() {
153		return MappedRange{}, errors.Errorf("invalid position for %v", pos)
154	}
155	if !end.IsValid() {
156		return MappedRange{}, errors.Errorf("invalid position for %v", end)
157	}
158	return NewMappedRange(snapshot.FileSet(), pgf.Mapper, pos, end), nil
159}
160
// generatedRx matches both the comment emitted by cgo and the proposed
// standard generated-code marker (https://golang.org/s/generatedcode). The
// pattern is unanchored, so callers apply any placement rules themselves.
var (
	generatedRx = regexp.MustCompile(`// .*DO NOT EDIT\.?`)
)
164
165func DetectLanguage(langID, filename string) FileKind {
166	switch langID {
167	case "go":
168		return Go
169	case "go.mod":
170		return Mod
171	case "go.sum":
172		return Sum
173	}
174	// Fallback to detecting the language based on the file extension.
175	switch filepath.Ext(filename) {
176	case ".mod":
177		return Mod
178	case ".sum":
179		return Sum
180	default: // fallback to Go
181		return Go
182	}
183}
184
185func (k FileKind) String() string {
186	switch k {
187	case Mod:
188		return "go.mod"
189	case Sum:
190		return "go.sum"
191	default:
192		return "go"
193	}
194}
195
// nodeAtPos returns the first node in nodes whose extent contains pos, with
// both node.Pos() and node.End() treated as inclusive bounds, together with
// that node's index in nodes. If no node contains pos (including when nodes
// is nil or empty), it returns nil and -1.
func nodeAtPos(nodes []ast.Node, pos token.Pos) (ast.Node, int) {
	for i, node := range nodes {
		if node.Pos() <= pos && pos <= node.End() {
			return node, i
		}
	}
	return nil, -1
}
209
// IsInterface reports whether T is non-nil and is an interface type, as
// determined by types.IsInterface.
func IsInterface(T types.Type) bool {
	if T == nil {
		return false
	}
	return types.IsInterface(T)
}
214
// FormatNode pretty-prints n with go/printer using positions from fset. It
// returns the empty string if printing fails.
func FormatNode(fset *token.FileSet, n ast.Node) string {
	out := &strings.Builder{}
	err := printer.Fprint(out, fset, n)
	if err != nil {
		return ""
	}
	return out.String()
}
223
// Deref returns a pointer's element type, traversing as many levels as needed.
// Otherwise it returns typ.
//
// Cyclic pointer types such as `type T *T` never bottom out in a non-pointer
// type. Deref detects this by remembering the element types it has visited. It
// stops as soon as it would revisit one and returns the last type reached,
// whose underlying type is then still a pointer. (See golang/go#45510.)
func Deref(typ types.Type) types.Type {
	// Allocated lazily so the common non-pointer case does no allocation.
	var seen map[types.Type]bool
	for {
		p, ok := typ.Underlying().(*types.Pointer)
		if !ok {
			return typ
		}
		elem := p.Elem()
		if seen[elem] {
			return typ
		}
		if seen == nil {
			seen = make(map[types.Type]bool)
		}
		seen[elem] = true
		typ = elem
	}
}
235
236func SortDiagnostics(d []*Diagnostic) {
237	sort.Slice(d, func(i int, j int) bool {
238		return CompareDiagnostic(d[i], d[j]) < 0
239	})
240}
241
242func CompareDiagnostic(a, b *Diagnostic) int {
243	if r := protocol.CompareRange(a.Range, b.Range); r != 0 {
244		return r
245	}
246	if a.Source < b.Source {
247		return -1
248	}
249	if a.Message < b.Message {
250		return -1
251	}
252	if a.Message == b.Message {
253		return 0
254	}
255	return 1
256}
257
258// FindPosInPackage finds the parsed file for a position in a given search
259// package.
260func FindPosInPackage(snapshot Snapshot, searchpkg Package, pos token.Pos) (*ParsedGoFile, Package, error) {
261	tok := snapshot.FileSet().File(pos)
262	if tok == nil {
263		return nil, nil, errors.Errorf("no file for pos in package %s", searchpkg.ID())
264	}
265	uri := span.URIFromPath(tok.Name())
266
267	pgf, pkg, err := findFileInDeps(searchpkg, uri)
268	if err != nil {
269		return nil, nil, err
270	}
271	return pgf, pkg, nil
272}
273
274// findFileInDeps finds uri in pkg or its dependencies.
275func findFileInDeps(pkg Package, uri span.URI) (*ParsedGoFile, Package, error) {
276	queue := []Package{pkg}
277	seen := make(map[string]bool)
278
279	for len(queue) > 0 {
280		pkg := queue[0]
281		queue = queue[1:]
282		seen[pkg.ID()] = true
283
284		if pgf, err := pkg.File(uri); err == nil {
285			return pgf, pkg, nil
286		}
287		for _, dep := range pkg.Imports() {
288			if !seen[dep.ID()] {
289				queue = append(queue, dep)
290			}
291		}
292	}
293	return nil, nil, errors.Errorf("no file for %s in package %s", uri, pkg.ID())
294}
295
// MarshalArgs encodes each of the given arguments, in order, to a
// json.RawMessage. This function is used to construct arguments to a
// protocol.Command.
//
// It returns a nil slice when no arguments are given. If any argument fails
// to marshal, it returns nil and that error.
//
// Example usage:
//
//   jsonArgs, err := MarshalArgs(1, "hello", true, StructuredArg{42, 12.6})
//
func MarshalArgs(args ...interface{}) ([]json.RawMessage, error) {
	if len(args) == 0 {
		return nil, nil
	}
	out := make([]json.RawMessage, len(args))
	for i, arg := range args {
		argJSON, err := json.Marshal(arg)
		if err != nil {
			return nil, err
		}
		out[i] = argJSON
	}
	return out, nil
}
314
// UnmarshalArgs decodes the given json.RawMessages to the variables provided
// by args, pairing them by position. Each element of args should be a pointer.
//
// It returns an error if the two lists differ in length, or the first
// decoding error encountered. In the latter case, earlier elements of args may
// already have been populated.
//
// Example usage:
//
//   var (
//       num int
//       str string
//       bul bool
//       structured StructuredArg
//   )
//   err := UnmarshalArgs(args, &num, &str, &bul, &structured)
//
func UnmarshalArgs(jsonArgs []json.RawMessage, args ...interface{}) error {
	if len(args) != len(jsonArgs) {
		return fmt.Errorf("UnmarshalArgs: expected %d input arguments, got %d JSON arguments", len(args), len(jsonArgs))
	}
	for i, arg := range args {
		if err := json.Unmarshal(jsonArgs[i], arg); err != nil {
			return err
		}
	}
	return nil
}
339
// ImportPath returns the import path of s with its quotes removed. It returns
// "" if s.Path.Value is not a valid quoted Go string literal.
func ImportPath(s *ast.ImportSpec) string {
	if path, err := strconv.Unquote(s.Path.Value); err == nil {
		return path
	}
	return ""
}
349
// NodeContains reports whether n is non-nil and pos lies within it. Both
// bounds are inclusive, so pos == n.End() (the position just past the node)
// counts as contained.
func NodeContains(n ast.Node, pos token.Pos) bool {
	if n == nil {
		return false
	}
	return pos >= n.Pos() && pos <= n.End()
}
354
355// CollectScopes returns all scopes in an ast path, ordered as innermost scope
356// first.
357func CollectScopes(info *types.Info, path []ast.Node, pos token.Pos) []*types.Scope {
358	// scopes[i], where i<len(path), is the possibly nil Scope of path[i].
359	var scopes []*types.Scope
360	for _, n := range path {
361		// Include *FuncType scope if pos is inside the function body.
362		switch node := n.(type) {
363		case *ast.FuncDecl:
364			if node.Body != nil && NodeContains(node.Body, pos) {
365				n = node.Type
366			}
367		case *ast.FuncLit:
368			if node.Body != nil && NodeContains(node.Body, pos) {
369				n = node.Type
370			}
371		}
372		scopes = append(scopes, info.Scopes[n])
373	}
374	return scopes
375}
376
// Qualifier returns a types.Qualifier for printing types as they would be
// written in f. Types from pkg itself are unqualified. Packages imported by f
// are qualified by their local import name, whether explicit or implicit.
// Any other package is qualified by its declared name.
func Qualifier(f *ast.File, pkg *types.Package, info *types.Info) types.Qualifier {
	// Map each package imported by f to the name f uses for it.
	localNames := make(map[*types.Package]string)
	for _, spec := range f.Imports {
		var obj types.Object
		if spec.Name == nil {
			obj = info.Implicits[spec]
		} else {
			obj = info.Defs[spec.Name]
		}
		pkgName, ok := obj.(*types.PkgName)
		if !ok {
			continue
		}
		localNames[pkgName.Imported()] = pkgName.Name()
	}

	return func(p *types.Package) string {
		switch name, imported := localNames[p]; {
		case p == pkg:
			return ""
		case imported:
			return name
		default:
			return p.Name()
		}
	}
}
404
// isDirective reports whether the comment text c (including its leading "//")
// is a comment directive. That is either a "//line " directive or text of the
// form "//[a-z0-9]+:[a-z0-9]...", e.g. "//go:generate".
//
// Copied and adapted from go/src/go/ast/ast.go.
func isDirective(c string) bool {
	// Require a //-style comment with at least one character of text. Only the
	// second byte is inspected, so /*-style comments are rejected here.
	if len(c) < 3 || c[1] != '/' {
		return false
	}
	text := c[2:] // drop the leading "//"

	if strings.HasPrefix(text, "line ") {
		return true
	}

	// Expect "[a-z0-9]+:[a-z0-9]" at the start of text.
	colon := strings.IndexByte(text, ':')
	if colon <= 0 || colon+1 >= len(text) {
		return false
	}
	isLowerAlnum := func(b byte) bool {
		return 'a' <= b && b <= 'z' || '0' <= b && b <= '9'
	}
	for i := 0; i < colon; i++ {
		if !isLowerAlnum(text[i]) {
			return false
		}
	}
	return isLowerAlnum(text[colon+1])
}
444
// honorSymlinks selects whether file and directory URI comparisons (for
// example CompareURI and InDir) resolve symbolic links. When false, those
// comparisons are purely lexical.
const (
	honorSymlinks = false
)
448
449func CompareURI(left, right span.URI) int {
450	if honorSymlinks {
451		return span.CompareURI(left, right)
452	}
453	if left == right {
454		return 0
455	}
456	if left < right {
457		return -1
458	}
459	return 1
460}
461
462// InDir checks whether path is in the file tree rooted at dir.
463// InDir makes some effort to succeed even in the presence of symbolic links.
464//
465// Copied and slightly adjusted from go/src/cmd/go/internal/search/search.go.
466func InDir(dir, path string) bool {
467	if inDirLex(dir, path) {
468		return true
469	}
470	if !honorSymlinks {
471		return false
472	}
473	xpath, err := filepath.EvalSymlinks(path)
474	if err != nil || xpath == path {
475		xpath = ""
476	} else {
477		if inDirLex(dir, xpath) {
478			return true
479		}
480	}
481
482	xdir, err := filepath.EvalSymlinks(dir)
483	if err == nil && xdir != dir {
484		if inDirLex(xdir, path) {
485			return true
486		}
487		if xpath != "" {
488			if inDirLex(xdir, xpath) {
489				return true
490			}
491		}
492	}
493	return false
494}
495
// inDirLex reports whether path lies in the tree rooted at dir, judging only
// by the lexical form of the two names. Symbolic links are not consulted.
// Volume names are compared case-insensitively. A dir of "" contains any
// non-empty path, and a path equal to dir counts as inside it.
//
// Copied from go/src/cmd/go/internal/search/search.go.
func inDirLex(dir, path string) bool {
	pathVol := strings.ToUpper(filepath.VolumeName(path))
	dirVol := strings.ToUpper(filepath.VolumeName(dir))
	path, dir = path[len(pathVol):], dir[len(dirVol):]

	if pathVol != dirVol {
		return false
	}
	if len(path) == len(dir) {
		return path == dir
	}
	if dir == "" {
		return path != ""
	}
	if len(path) < len(dir) {
		return false
	}

	// From here on len(path) > len(dir) > 0.
	prefixMatches := path[:len(dir)] == dir
	if dir[len(dir)-1] == filepath.Separator {
		return prefixMatches
	}
	// dir has no trailing separator, so path must continue with one.
	return prefixMatches && path[len(dir)] == filepath.Separator
}
533