1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package satisfy inspects the type-checked ASTs of Go packages and
6// reports the set of discovered type constraints of the form (lhs, rhs
7// Type) where lhs is a non-trivial interface, rhs satisfies this
8// interface, and this fact is necessary for the package to be
9// well-typed.
10//
11// THIS PACKAGE IS EXPERIMENTAL AND MAY CHANGE AT ANY TIME.
12//
13// It is provided only for the gorename tool.  Ideally this
14// functionality will become part of the type-checker in due course,
15// since it is computing it anyway, and it is robust for ill-typed
16// inputs, which this package is not.
17//
18package satisfy // import "golang.org/x/tools/refactor/satisfy"
19
20// NOTES:
21//
22// We don't care about numeric conversions, so we don't descend into
23// types or constant expressions.  This is unsound because
24// constant expressions can contain arbitrary statements, e.g.
25//   const x = len([1]func(){func() {
26//     ...
27//   }})
28//
29// TODO(adonovan): make this robust against ill-typed input.
30// Or move it into the type-checker.
31//
32// Assignability conversions are possible in the following places:
33// - in assignments y = x, y := x, var y = x.
34// - from call argument types to formal parameter types
35// - in append and delete calls
36// - from return operands to result parameter types
37// - in composite literal T{k:v}, from k and v to T's field/element/key type
38// - in map[key] from key to the map's key type
39// - in comparisons x==y and switch x { case y: }.
40// - in explicit conversions T(x)
41// - in sends ch <- x, from x to the channel element type
42// - in type assertions x.(T) and switch x.(type) { case T: }
43//
44// The results of this pass provide information equivalent to the
45// ssa.MakeInterface and ssa.ChangeInterface instructions.
46
47import (
48	"fmt"
49	"go/ast"
50	"go/token"
51	"go/types"
52
53	"golang.org/x/tools/go/ast/astutil"
54	"golang.org/x/tools/go/types/typeutil"
55)
56
57// A Constraint records the fact that the RHS type does and must
58// satisify the LHS type, which is an interface.
59// The names are suggestive of an assignment statement LHS = RHS.
60type Constraint struct {
61	LHS, RHS types.Type
62}
63
64// A Finder inspects the type-checked ASTs of Go packages and
65// accumulates the set of type constraints (x, y) such that x is
66// assignable to y, y is an interface, and both x and y have methods.
67//
68// In other words, it returns the subset of the "implements" relation
69// that is checked during compilation of a package.  Refactoring tools
70// will need to preserve at least this part of the relation to ensure
71// continued compilation.
72//
73type Finder struct {
74	Result    map[Constraint]bool
75	msetcache typeutil.MethodSetCache
76
77	// per-Find state
78	info *types.Info
79	sig  *types.Signature
80}
81
82// Find inspects a single package, populating Result with its pairs of
83// constrained types.
84//
85// The result is non-canonical and thus may contain duplicates (but this
86// tends to preserves names of interface types better).
87//
88// The package must be free of type errors, and
89// info.{Defs,Uses,Selections,Types} must have been populated by the
90// type-checker.
91//
92func (f *Finder) Find(info *types.Info, files []*ast.File) {
93	if f.Result == nil {
94		f.Result = make(map[Constraint]bool)
95	}
96
97	f.info = info
98	for _, file := range files {
99		for _, d := range file.Decls {
100			switch d := d.(type) {
101			case *ast.GenDecl:
102				if d.Tok == token.VAR { // ignore consts
103					for _, spec := range d.Specs {
104						f.valueSpec(spec.(*ast.ValueSpec))
105					}
106				}
107
108			case *ast.FuncDecl:
109				if d.Body != nil {
110					f.sig = f.info.Defs[d.Name].Type().(*types.Signature)
111					f.stmt(d.Body)
112					f.sig = nil
113				}
114			}
115		}
116	}
117	f.info = nil
118}
119
120var (
121	tInvalid     = types.Typ[types.Invalid]
122	tUntypedBool = types.Typ[types.UntypedBool]
123	tUntypedNil  = types.Typ[types.UntypedNil]
124)
125
// exprN visits an expression e in a multi-value context, such as the
// right-hand side of x, y := e, and returns its type, a *types.Tuple.
//
// e must be a call, a map index, a type assertion or a channel receive,
// possibly parenthesized; exprN panics otherwise.
func (f *Finder) exprN(e ast.Expr) types.Type {
	typ := f.info.Types[e].Type.(*types.Tuple)
	switch e := e.(type) {
	case *ast.ParenExpr:
		return f.exprN(e.X)

	case *ast.CallExpr:
		// x, err := f(args)
		sig := f.expr(e.Fun).Underlying().(*types.Signature)
		if e.Ellipsis.IsValid() {
			// x, err := f(a, bs...): go/ast records the "..." in
			// e.Ellipsis, not among e.Args. Each argument, including
			// the final slice, is assigned directly to its parameter.
			for i, arg := range e.Args {
				f.assign(sig.Params().At(i).Type(), f.expr(arg))
			}
		} else {
			f.call(sig, e.Args)
		}

	case *ast.IndexExpr:
		// y, ok := x[i]: the index is assigned to the map's key type,
		// so the key type is the left-hand (destination) side.
		x := f.expr(e.X)
		f.assign(x.Underlying().(*types.Map).Key(), f.expr(e.Index))

	case *ast.TypeAssertExpr:
		// y, ok := x.(T)
		f.typeAssert(f.expr(e.X), typ.At(0).Type())

	case *ast.UnaryExpr: // must be receive <-
		// y, ok := <-x
		f.expr(e.X)

	default:
		panic(e)
	}
	return typ
}
156
// call records the assignment of the actual arguments args to the
// formal parameters of sig.
//
// A lone argument of tuple type, as in f(g()) where g has several
// results, is visited once and unpacked so that each result is assigned
// to its own parameter. For a variadic sig, the arguments beyond the
// fixed parameters are assigned to the element type of the final slice
// parameter.
//
// call cannot tell whether the call ended in "...": go/ast records that
// in CallExpr.Ellipsis, not among the arguments. If args come from a
// call such as f(x, ys...), the final slice argument is treated as one
// more variadic element.
func (f *Finder) call(sig *types.Signature, args []ast.Expr) {
	if len(args) == 0 {
		return
	}

	// Gather the effective actual parameter types.
	var argtypes []types.Type
	if tuple, ok := f.info.Types[args[0]].Type.(*types.Tuple); ok {
		// f(g()) call where g has multiple results: unpack the tuple.
		f.expr(args[0])
		argtypes = make([]types.Type, 0, tuple.Len())
		for i := 0; i < tuple.Len(); i++ {
			argtypes = append(argtypes, tuple.At(i).Type())
		}
	} else {
		argtypes = make([]types.Type, 0, len(args))
		for _, arg := range args {
			argtypes = append(argtypes, f.expr(arg))
		}
	}

	// Assign the actuals to the formals.
	params := sig.Params()
	if !sig.Variadic() {
		for i, argtype := range argtypes {
			f.assign(params.At(i).Type(), argtype)
		}
		return
	}

	// The first n-1 parameters are assigned normally; the remaining
	// arguments are assigned to elements of the varargs slice.
	nnormals := params.Len() - 1
	for i, argtype := range argtypes[:nnormals] {
		f.assign(params.At(i).Type(), argtype)
	}
	tElem := params.At(nnormals).Type().(*types.Slice).Elem()
	for _, argtype := range argtypes[nnormals:] {
		f.assign(tElem, argtype)
	}
}
205
// builtin visits a call to the predeclared function obj, records the
// conversions the call performs, and returns T, the type of the call.
//
// sig is the signature the type checker recorded for this call's
// function expression.
//   - make and new: the type operand args[0] is not visited.
//   - append: every value after the first is assigned to the slice's
//     element type.
//   - delete: the key is assigned to the map's key type.
//   - any other builtin is visited as an ordinary call of sig.
//
// builtin cannot tell whether an append call ended in "...": go/ast
// records that in CallExpr.Ellipsis, not among args. For a call such as
// append(x, y...), y is therefore assigned to the element type as if it
// were a single appended value.
func (f *Finder) builtin(obj *types.Builtin, sig *types.Signature, args []ast.Expr, T types.Type) types.Type {
	switch obj.Name() {
	case "make", "new":
		// skip the type operand
		for _, arg := range args[1:] {
			f.expr(arg)
		}

	case "append":
		// append(x, y, z)
		s := f.expr(args[0])
		tElem := s.Underlying().(*types.Slice).Elem()
		for _, arg := range args[1:] {
			f.assign(tElem, f.expr(arg))
		}

	case "delete":
		// delete(m, k)
		m := f.expr(args[0])
		k := f.expr(args[1])
		f.assign(m.Underlying().(*types.Map).Key(), k)

	default:
		// ordinary call
		f.call(sig, args)
	}

	return T
}
239
// extract returns the type of element i of tuple, or tInvalid when
// tuple is not a *types.Tuple or has no element i.
func (f *Finder) extract(tuple types.Type, i int) types.Type {
	tup, isTuple := tuple.(*types.Tuple)
	if !isTuple || i >= tup.Len() {
		return tInvalid
	}
	return tup.At(i).Type()
}
246
// valueSpec visits the initializer expressions of a var spec and, when
// the spec has an explicit type, assigns each initial value to it.
//
// Two shapes are handled: one value per name (var x, y T = f(), g())
// and a single multi-valued initializer (var x, y T = f()). A spec with
// no values is ignored.
func (f *Finder) valueSpec(spec *ast.ValueSpec) {
	var declared types.Type
	if spec.Type != nil {
		declared = f.info.Types[spec.Type].Type
	}

	nvalues := len(spec.Values)
	if nvalues == len(spec.Names) {
		// e.g. var x, y = f(), g()
		for _, value := range spec.Values {
			if v := f.expr(value); declared != nil {
				f.assign(declared, v)
			}
		}
		return
	}
	if nvalues != 1 {
		return
	}

	// e.g. var x, y = f()
	results := f.exprN(spec.Values[0])
	if declared == nil {
		return
	}
	for i := range spec.Names {
		f.assign(declared, f.extract(results, i))
	}
}
270
// assign records the pair (lhs, rhs) in f.Result when the two types are
// distinct, lhs is an interface, and both types have non-empty method
// sets. Every other pair is uninteresting and dropped.
//
// It should be called for every assignability check, type assertion,
// explicit conversion and comparison between two types, with lhs the
// destination type and rhs the source type. Callers may skip pairs
// already known to be uninteresting (e.g. lhs is a concrete type or the
// empty interface, or rhs has no methods).
func (f *Finder) assign(lhs, rhs types.Type) {
	interesting := !types.Identical(lhs, rhs) &&
		isInterface(lhs) &&
		f.msetcache.MethodSet(lhs).Len() > 0 &&
		f.msetcache.MethodSet(rhs).Len() > 0
	if interesting {
		f.Result[Constraint{lhs, rhs}] = true
	}
}
297
// typeAssert must be called for each type assertion x.(T) where x has
// interface type I, including each case type of a type switch.
//
// Type assertions are slightly subtle, because they are allowed to be
// "impossible", e.g.
//
//	var x interface{f()}
//	_ = x.(interface{f()int}) // legal
//
// (In hindsight, the language spec should probably not have allowed
// this, but it's too late to fix now.)
//
// So a type assertion from I to T is not exactly a constraint that T is
// assignable to I. For a refactoring tool it is a conditional
// constraint: if T is assignable to I before a refactoring, it should
// remain so after. Only that case is passed on to assign.
func (f *Finder) typeAssert(I, T types.Type) {
	if !types.AssignableTo(T, I) {
		return
	}
	f.assign(I, T)
}
319
// compare must be called for each comparison x==y, including a switch
// tag compared against a case value. The first assignability direction
// that holds (x to y, else y to x) is passed to assign, with the target
// type on the left; if neither holds, nothing is recorded.
func (f *Finder) compare(x, y types.Type) {
	switch {
	case types.AssignableTo(x, y):
		f.assign(y, x)
	case types.AssignableTo(y, x):
		f.assign(x, y)
	}
}
328
// expr visits a true expression (not a type or defining ident) and
// returns its type.
//
// Constant expressions are not descended into. Every assignability
// conversion performed by e or its subexpressions is reported to
// assign, compare or typeAssert. expr panics on an identifier with no
// recorded use (other than "_"), on a type expression, and on an
// expression with no recorded type.
func (f *Finder) expr(e ast.Expr) types.Type {
	tv := f.info.Types[e]
	if tv.Value != nil {
		return tv.Type // prune the descent for constants
	}

	// tv.Type may be nil for an ast.Ident.

	switch e := e.(type) {
	case *ast.BadExpr, *ast.BasicLit:
		// no-op

	case *ast.Ident:
		// (referring idents only)
		if obj, ok := f.info.Uses[e]; ok {
			return obj.Type()
		}
		if e.Name == "_" { // e.g. "for _ = range x"
			return tInvalid
		}
		panic("undefined ident: " + e.Name)

	case *ast.Ellipsis:
		if e.Elt != nil {
			f.expr(e.Elt)
		}

	case *ast.FuncLit:
		// Visit the body with f.sig set to the literal's own signature,
		// since its return statements refer to it.
		saved := f.sig
		f.sig = tv.Type.(*types.Signature)
		f.stmt(e.Body)
		f.sig = saved

	case *ast.CompositeLit:
		switch T := deref(tv.Type).Underlying().(type) {
		case *types.Struct:
			for i, elem := range e.Elts {
				if kv, ok := elem.(*ast.KeyValueExpr); ok {
					// T{Field: v}: the key names the field.
					f.assign(f.info.Uses[kv.Key.(*ast.Ident)].Type(), f.expr(kv.Value))
				} else {
					f.assign(T.Field(i).Type(), f.expr(elem))
				}
			}

		case *types.Map:
			for _, elem := range e.Elts {
				elem := elem.(*ast.KeyValueExpr)
				f.assign(T.Key(), f.expr(elem.Key))
				f.assign(T.Elem(), f.expr(elem.Value))
			}

		case *types.Array, *types.Slice:
			tElem := T.(interface {
				Elem() types.Type
			}).Elem()
			for _, elem := range e.Elts {
				if kv, ok := elem.(*ast.KeyValueExpr); ok {
					// ignore the key
					f.assign(tElem, f.expr(kv.Value))
				} else {
					f.assign(tElem, f.expr(elem))
				}
			}

		default:
			panic("unexpected composite literal type: " + tv.Type.String())
		}

	case *ast.ParenExpr:
		f.expr(e.X)

	case *ast.SelectorExpr:
		if _, ok := f.info.Selections[e]; ok {
			f.expr(e.X) // selection
		} else {
			return f.info.Uses[e.Sel].Type() // qualified identifier
		}

	case *ast.IndexExpr:
		// x[i]: for a map, i is assigned to the key type.
		x := f.expr(e.X)
		i := f.expr(e.Index)
		if ux, ok := x.Underlying().(*types.Map); ok {
			f.assign(ux.Key(), i)
		}

	case *ast.SliceExpr:
		f.expr(e.X)
		if e.Low != nil {
			f.expr(e.Low)
		}
		if e.High != nil {
			f.expr(e.High)
		}
		if e.Max != nil {
			f.expr(e.Max)
		}

	case *ast.TypeAssertExpr:
		x := f.expr(e.X)
		f.typeAssert(x, f.info.Types[e.Type].Type)

	case *ast.CallExpr:
		if tvFun := f.info.Types[e.Fun]; tvFun.IsType() {
			// conversion T(x)
			arg0 := f.expr(e.Args[0])
			f.assign(tvFun.Type, arg0)
			break
		}

		fun := unparen(e.Fun)

		// Call to an unsafe builtin such as unsafe.Add or unsafe.Slice.
		// The object of a qualified builtin has an invalid type, so
		// f.expr(e.Fun) cannot supply a signature; use the one the type
		// checker recorded for e.Fun and treat it as an ordinary call.
		if sel, ok := fun.(*ast.SelectorExpr); ok {
			if _, ok := f.info.Uses[sel.Sel].(*types.Builtin); ok {
				f.call(f.info.Types[e.Fun].Type.(*types.Signature), e.Args)
				break
			}
		}

		// Call to a predeclared builtin such as append or delete.
		if id, ok := fun.(*ast.Ident); ok {
			if obj, ok := f.info.Uses[id].(*types.Builtin); ok {
				if obj.Name() == "append" && e.Ellipsis.IsValid() {
					// append(x, y...), including append([]byte, "foo"...):
					// y is passed whole, so no element is assigned.
					for _, arg := range e.Args {
						f.expr(arg)
					}
					return tv.Type
				}
				sig := f.info.Types[id].Type.(*types.Signature)
				return f.builtin(obj, sig, e.Args, tv.Type)
			}
		}

		// Ordinary call.
		sig := f.expr(e.Fun).Underlying().(*types.Signature)
		if e.Ellipsis.IsValid() {
			// f(x, ys...): go/ast records the "..." in e.Ellipsis, not
			// among e.Args. Each argument, including the final slice,
			// is assigned directly to its own parameter.
			for i, arg := range e.Args {
				f.assign(sig.Params().At(i).Type(), f.expr(arg))
			}
		} else {
			f.call(sig, e.Args)
		}

	case *ast.StarExpr:
		f.expr(e.X)

	case *ast.UnaryExpr:
		f.expr(e.X)

	case *ast.BinaryExpr:
		x := f.expr(e.X)
		y := f.expr(e.Y)
		if e.Op == token.EQL || e.Op == token.NEQ {
			f.compare(x, y)
		}

	case *ast.KeyValueExpr:
		f.expr(e.Key)
		f.expr(e.Value)

	case *ast.ArrayType,
		*ast.StructType,
		*ast.FuncType,
		*ast.InterfaceType,
		*ast.MapType,
		*ast.ChanType:
		panic(e)
	}

	if tv.Type == nil {
		panic(fmt.Sprintf("no type for %T", e))
	}

	return tv.Type
}
481
// stmt visits the statement s and everything nested inside it,
// reporting each assignability conversion it performs to assign,
// compare or typeAssert.
//
// f.sig must be the signature of the function whose body contains s;
// return statements are checked against its results. stmt panics on a
// statement kind it does not recognize.
func (f *Finder) stmt(s ast.Stmt) {
	switch s := s.(type) {
	case *ast.BadStmt,
		*ast.EmptyStmt,
		*ast.BranchStmt:
		// no-op

	case *ast.DeclStmt:
		d := s.Decl.(*ast.GenDecl)
		if d.Tok == token.VAR { // ignore consts
			for _, spec := range d.Specs {
				f.valueSpec(spec.(*ast.ValueSpec))
			}
		}

	case *ast.LabeledStmt:
		f.stmt(s.Stmt)

	case *ast.ExprStmt:
		f.expr(s.X)

	case *ast.SendStmt:
		// ch <- x: x is assigned to the channel's element type.
		ch := f.expr(s.Chan)
		val := f.expr(s.Value)
		f.assign(ch.Underlying().(*types.Chan).Elem(), val)

	case *ast.IncDecStmt:
		f.expr(s.X)

	case *ast.AssignStmt:
		switch s.Tok {
		case token.ASSIGN, token.DEFINE:
			// y := x   or   y = x
			var rhsTuple types.Type
			if len(s.Lhs) != len(s.Rhs) {
				rhsTuple = f.exprN(s.Rhs[0])
			}
			for i := range s.Lhs {
				var lhs, rhs types.Type
				if rhsTuple == nil {
					rhs = f.expr(s.Rhs[i]) // 1:1 assignment
				} else {
					rhs = f.extract(rhsTuple, i) // n:1 assignment
				}

				if id, ok := s.Lhs[i].(*ast.Ident); ok {
					if id.Name != "_" {
						if obj, ok := f.info.Defs[id]; ok {
							lhs = obj.Type() // definition
						}
					}
				}
				if lhs == nil {
					lhs = f.expr(s.Lhs[i]) // assignment
				}
				f.assign(lhs, rhs)
			}

		default:
			// y op= x
			f.expr(s.Lhs[0])
			f.expr(s.Rhs[0])
		}

	case *ast.GoStmt:
		f.expr(s.Call)

	case *ast.DeferStmt:
		f.expr(s.Call)

	case *ast.ReturnStmt:
		formals := f.sig.Results()
		switch len(s.Results) {
		case formals.Len(): // 1:1
			for i, result := range s.Results {
				f.assign(formals.At(i).Type(), f.expr(result))
			}

		case 1: // n:1
			tuple := f.exprN(s.Results[0])
			for i := 0; i < formals.Len(); i++ {
				f.assign(formals.At(i).Type(), f.extract(tuple, i))
			}
		}

	case *ast.SelectStmt:
		f.stmt(s.Body)

	case *ast.BlockStmt:
		for _, s := range s.List {
			f.stmt(s)
		}

	case *ast.IfStmt:
		if s.Init != nil {
			f.stmt(s.Init)
		}
		f.expr(s.Cond)
		f.stmt(s.Body)
		if s.Else != nil {
			f.stmt(s.Else)
		}

	case *ast.SwitchStmt:
		if s.Init != nil {
			f.stmt(s.Init)
		}
		var tag types.Type = tUntypedBool
		if s.Tag != nil {
			tag = f.expr(s.Tag)
		}
		for _, cc := range s.Body.List {
			cc := cc.(*ast.CaseClause)
			for _, cond := range cc.List {
				// Case values are ordinary expressions that may contain
				// calls or function literals, so visit them rather than
				// merely looking up their types.
				f.compare(tag, f.expr(cond))
			}
			for _, s := range cc.Body {
				f.stmt(s)
			}
		}

	case *ast.TypeSwitchStmt:
		if s.Init != nil {
			f.stmt(s.Init)
		}
		var I types.Type
		switch ass := s.Assign.(type) {
		case *ast.ExprStmt: // x.(type)
			I = f.expr(unparen(ass.X).(*ast.TypeAssertExpr).X)
		case *ast.AssignStmt: // y := x.(type)
			I = f.expr(unparen(ass.Rhs[0]).(*ast.TypeAssertExpr).X)
		}
		for _, cc := range s.Body.List {
			cc := cc.(*ast.CaseClause)
			for _, cond := range cc.List {
				tCase := f.info.Types[cond].Type
				if tCase != tUntypedNil {
					f.typeAssert(I, tCase)
				}
			}
			for _, s := range cc.Body {
				f.stmt(s)
			}
		}

	case *ast.CommClause:
		if s.Comm != nil {
			f.stmt(s.Comm)
		}
		for _, s := range s.Body {
			f.stmt(s)
		}

	case *ast.ForStmt:
		if s.Init != nil {
			f.stmt(s.Init)
		}
		if s.Cond != nil {
			f.expr(s.Cond)
		}
		if s.Post != nil {
			f.stmt(s.Post)
		}
		f.stmt(s.Body)

	case *ast.RangeStmt:
		x := f.expr(s.X)
		// No conversions are involved when Tok==DEFINE.
		if s.Tok == token.ASSIGN {
			if s.Key != nil {
				k := f.expr(s.Key)
				var xelem types.Type
				// keys of array, *array, slice, string aren't interesting
				switch ux := x.Underlying().(type) {
				case *types.Chan:
					xelem = ux.Elem()
				case *types.Map:
					xelem = ux.Key()
				}
				if xelem != nil {
					f.assign(xelem, k)
				}
			}
			if s.Value != nil {
				val := f.expr(s.Value)
				var xelem types.Type
				// values of strings aren't interesting
				switch ux := x.Underlying().(type) {
				case *types.Array:
					xelem = ux.Elem()
				case *types.Chan:
					xelem = ux.Elem()
				case *types.Map:
					xelem = ux.Elem()
				case *types.Pointer: // *array; the array type may be named
					xelem = deref(ux).Underlying().(*types.Array).Elem()
				case *types.Slice:
					xelem = ux.Elem()
				}
				if xelem != nil {
					f.assign(xelem, val)
				}
			}
		}
		f.stmt(s.Body)

	default:
		panic(s)
	}
}
692
693// -- Plundered from golang.org/x/tools/go/ssa -----------------
694
// deref returns the element type of typ when typ's underlying type is a
// pointer; otherwise it returns typ unchanged.
func deref(typ types.Type) types.Type {
	ptr, isPtr := typ.Underlying().(*types.Pointer)
	if !isPtr {
		return typ
	}
	return ptr.Elem()
}
702
// unparen returns e with any enclosing parentheses stripped.
func unparen(e ast.Expr) ast.Expr {
	return astutil.Unparen(e)
}
704
// isInterface reports whether T is an interface type, as decided by
// types.IsInterface.
func isInterface(T types.Type) bool {
	return types.IsInterface(T)
}
706