1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package completion
6
7import (
8	"go/ast"
9	"go/token"
10	"go/types"
11
12	"golang.org/x/tools/internal/lsp/source"
13)
14
15// exprAtPos returns the index of the expression containing pos.
16func exprAtPos(pos token.Pos, args []ast.Expr) int {
17	for i, expr := range args {
18		if expr.Pos() <= pos && pos <= expr.End() {
19			return i
20		}
21	}
22	return len(args)
23}
24
25// eachField invokes fn for each field that can be selected from a
26// value of type T.
27func eachField(T types.Type, fn func(*types.Var)) {
28	// TODO(adonovan): this algorithm doesn't exclude ambiguous
29	// selections that match more than one field/method.
30	// types.NewSelectionSet should do that for us.
31
32	// for termination on recursive types
33	var seen map[*types.Struct]bool
34
35	var visit func(T types.Type)
36	visit = func(T types.Type) {
37		if T, ok := source.Deref(T).Underlying().(*types.Struct); ok {
38			if seen[T] {
39				return
40			}
41
42			for i := 0; i < T.NumFields(); i++ {
43				f := T.Field(i)
44				fn(f)
45				if f.Anonymous() {
46					if seen == nil {
47						// Lazily create "seen" since it is only needed for
48						// embedded structs.
49						seen = make(map[*types.Struct]bool)
50					}
51					seen[T] = true
52					visit(f.Type())
53				}
54			}
55		}
56	}
57	visit(T)
58}
59
// typeIsValid reports whether typ is free of Invalid types. Composite
// types (arrays, slices, pointers, maps, channels, signatures, tuples)
// are checked recursively through their component types. Named types,
// structs and interfaces are always reported as valid without further
// inspection. Any other kind of type is reported as invalid.
func typeIsValid(typ types.Type) bool {
	// Named types are accepted as-is: calling Underlying() on them is
	// avoided to prevent problems with recursive types.
	if _, named := typ.(*types.Named); named {
		return true
	}

	switch u := typ.Underlying().(type) {
	case *types.Struct, *types.Interface:
		// Structs and interfaces are not inspected for validity.
		return true
	case *types.Basic:
		return u.Kind() != types.Invalid
	case *types.Pointer:
		return typeIsValid(u.Elem())
	case *types.Slice:
		return typeIsValid(u.Elem())
	case *types.Array:
		return typeIsValid(u.Elem())
	case *types.Chan:
		return typeIsValid(u.Elem())
	case *types.Map:
		if !typeIsValid(u.Key()) {
			return false
		}
		return typeIsValid(u.Elem())
	case *types.Signature:
		if !typeIsValid(u.Params()) {
			return false
		}
		return typeIsValid(u.Results())
	case *types.Tuple:
		for i, n := 0, u.Len(); i < n; i++ {
			if !typeIsValid(u.At(i).Type()) {
				return false
			}
		}
		return true
	}
	return false
}
97
98// resolveInvalid traverses the node of the AST that defines the scope
99// containing the declaration of obj, and attempts to find a user-friendly
100// name for its invalid type. The resulting Object and its Type are fake.
101func resolveInvalid(fset *token.FileSet, obj types.Object, node ast.Node, info *types.Info) types.Object {
102	var resultExpr ast.Expr
103	ast.Inspect(node, func(node ast.Node) bool {
104		switch n := node.(type) {
105		case *ast.ValueSpec:
106			for _, name := range n.Names {
107				if info.Defs[name] == obj {
108					resultExpr = n.Type
109				}
110			}
111			return false
112		case *ast.Field: // This case handles parameters and results of a FuncDecl or FuncLit.
113			for _, name := range n.Names {
114				if info.Defs[name] == obj {
115					resultExpr = n.Type
116				}
117			}
118			return false
119		default:
120			return true
121		}
122	})
123	// Construct a fake type for the object and return a fake object with this type.
124	typename := source.FormatNode(fset, resultExpr)
125	typ := types.NewNamed(types.NewTypeName(token.NoPos, obj.Pkg(), typename, nil), types.Typ[types.Invalid], nil)
126	return types.NewVar(obj.Pos(), obj.Pkg(), obj.Name(), typ)
127}
128
// isPointer reports whether T itself is a *types.Pointer. The
// underlying type is not consulted.
func isPointer(T types.Type) bool {
	switch T.(type) {
	case *types.Pointer:
		return true
	default:
		return false
	}
}
133
// isVar reports whether obj is a *types.Var.
func isVar(obj types.Object) bool {
	switch obj.(type) {
	case *types.Var:
		return true
	default:
		return false
	}
}
138
// isTypeName reports whether obj is a *types.TypeName.
func isTypeName(obj types.Object) bool {
	switch obj.(type) {
	case *types.TypeName:
		return true
	default:
		return false
	}
}
143
// isFunc reports whether obj is a *types.Func.
func isFunc(obj types.Object) bool {
	switch obj.(type) {
	case *types.Func:
		return true
	default:
		return false
	}
}
148
// isEmptyInterface reports whether T itself is a *types.Interface with
// no methods. The underlying type is not consulted, so a named type is
// never reported as an empty interface.
func isEmptyInterface(T types.Type) bool {
	intf, ok := T.(*types.Interface)
	if !ok || intf == nil {
		return false
	}
	return intf.NumMethods() == 0
}
153
// isUntyped reports whether T itself is a basic type carrying the
// types.IsUntyped flag (for example the type of an untyped constant).
func isUntyped(T types.Type) bool {
	basic, ok := T.(*types.Basic)
	return ok && basic.Info()&types.IsUntyped != 0
}
160
// isPkgName reports whether obj is a *types.PkgName.
func isPkgName(obj types.Object) bool {
	switch obj.(type) {
	case *types.PkgName:
		return true
	default:
		return false
	}
}
165
// isASTFile reports whether n is an *ast.File.
func isASTFile(n ast.Node) bool {
	switch n.(type) {
	case *ast.File:
		return true
	default:
		return false
	}
}
170
// deslice returns the element type of T if T's underlying type is a
// slice, and nil otherwise.
func deslice(T types.Type) types.Type {
	slice, ok := T.Underlying().(*types.Slice)
	if !ok {
		return nil
	}
	return slice.Elem()
}
177
// enclosingSelector returns the *ast.SelectorExpr enclosing the
// innermost node of path, or nil if there is none. If path[0] is itself
// a selector expression it is returned directly. If path[0] is an
// identifier whose parent path[1] is a selector expression, the parent
// is returned only when pos is at or after the start of its Sel
// identifier (i.e. not in the X operand).
func enclosingSelector(path []ast.Node, pos token.Pos) *ast.SelectorExpr {
	if len(path) == 0 {
		return nil
	}

	switch innermost := path[0].(type) {
	case *ast.SelectorExpr:
		return innermost
	case *ast.Ident:
		if len(path) < 2 {
			return nil
		}
		parent, ok := path[1].(*ast.SelectorExpr)
		if ok && pos >= parent.Sel.Pos() {
			return parent
		}
	}

	return nil
}
197
// enclosingDeclLHS returns the left-hand-side identifiers of the first
// value spec or assignment statement found in path. For a value spec
// these are its declared names; for an assignment, only the LHS
// operands that are plain identifiers are included. It returns nil if
// path contains neither.
func enclosingDeclLHS(path []ast.Node) []*ast.Ident {
	for _, node := range path {
		if spec, ok := node.(*ast.ValueSpec); ok {
			return spec.Names
		}

		assign, ok := node.(*ast.AssignStmt)
		if !ok {
			continue
		}
		idents := make([]*ast.Ident, 0, len(assign.Lhs))
		for _, lhs := range assign.Lhs {
			if ident, isIdent := lhs.(*ast.Ident); isIdent {
				idents = append(idents, ident)
			}
		}
		return idents
	}

	return nil
}
218
// exprObj returns the types.Object that info associates with e when e
// is an *ast.Ident, or with e's Sel identifier when e is an
// *ast.SelectorExpr. It returns nil for any other kind of expression.
func exprObj(info *types.Info, e ast.Expr) types.Object {
	if id, ok := e.(*ast.Ident); ok {
		return info.ObjectOf(id)
	}
	if sel, ok := e.(*ast.SelectorExpr); ok {
		return info.ObjectOf(sel.Sel)
	}
	return nil
}
234
235// typeConversion returns the type being converted to if call is a type
236// conversion expression.
237func typeConversion(call *ast.CallExpr, info *types.Info) types.Type {
238	// Type conversion (e.g. "float64(foo)").
239	if fun, _ := exprObj(info, call.Fun).(*types.TypeName); fun != nil {
240		return fun.Type()
241	}
242
243	return nil
244}
245
// fieldsAccessible reports whether s has at least one field that is
// accessible from package p: a field that is either exported or
// declared in p itself.
func fieldsAccessible(s *types.Struct, p *types.Package) bool {
	for i, n := 0, s.NumFields(); i < n; i++ {
		field := s.Field(i)
		if field.Exported() {
			return true
		}
		if field.Pkg() == p {
			return true
		}
	}
	return false
}
256
// prevStmt returns the statement that precedes the statement containing
// pos, or nil if there is none. For example:
//
//     foo := 1
//     bar(1 + 2<>)
//
// If "<>" is pos, prevStmt returns "foo := 1".
//
// The candidate statements come from the first node in path that is a
// block statement, comm clause or case clause with a non-nil statement
// list; the result is the last of those statements that ends strictly
// before pos.
func prevStmt(pos token.Pos, path []ast.Node) ast.Stmt {
	var stmts []ast.Stmt
	for _, node := range path {
		switch n := node.(type) {
		case *ast.BlockStmt:
			stmts = n.List
		case *ast.CommClause:
			stmts = n.Body
		case *ast.CaseClause:
			stmts = n.Body
		}
		// Stop at the first enclosing node that supplied statements.
		if stmts != nil {
			break
		}
	}

	// Scan backwards for the closest statement ending before pos.
	for i := len(stmts); i > 0; i-- {
		if stmt := stmts[i-1]; stmt.End() < pos {
			return stmt
		}
	}

	return nil
}
285
// formatZeroValue produces Go code representing the zero value of T,
// using qf to qualify package-level type names. It returns the empty
// string if T is invalid (or is a basic type with no spellable zero
// value, such as untyped nil).
//
// The result depends on T's underlying type:
//   - numeric, string and boolean basic types yield "0", `""` and "false";
//   - unsafe.Pointer, pointers, interfaces, channels, maps, slices and
//     function types yield "nil";
//   - anything else (e.g. structs and arrays) yields the type's string
//     followed by "{}".
func formatZeroValue(T types.Type, qf types.Qualifier) string {
	switch u := T.Underlying().(type) {
	case *types.Basic:
		switch {
		case u.Info()&types.IsNumeric != 0:
			return "0"
		case u.Info()&types.IsString != 0:
			return `""`
		case u.Info()&types.IsBoolean != 0:
			return "false"
		case u.Kind() == types.UnsafePointer:
			// unsafe.Pointer is a valid basic type that carries none of
			// the flags above; its zero value is nil.
			return "nil"
		default:
			return ""
		}
	case *types.Pointer, *types.Interface, *types.Chan, *types.Map, *types.Slice, *types.Signature:
		return "nil"
	default:
		return types.TypeString(T, qf) + "{}"
	}
}
307
// isBasicKind reports whether t's underlying type is a basic type whose
// info flags include at least one of the bits set in k.
func isBasicKind(t types.Type, k types.BasicInfo) bool {
	basic, ok := t.Underlying().(*types.Basic)
	if !ok || basic == nil {
		return false
	}
	return basic.Info()&k != 0
}
313