1// Copyright 2011 The Go Authors.  All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"fmt"
9	"go/ast"
10	"go/token"
11	"os"
12	"reflect"
13	"strings"
14)
15
16// Partial type checker.
17//
18// The fact that it is partial is very important: the input is
19// an AST and a description of some type information to
20// assume about one or more packages, but not all the
21// packages that the program imports.  The checker is
22// expected to do as much as it can with what it has been
23// given.  There is not enough information supplied to do
24// a full type check, but the type checker is expected to
25// apply information that can be derived from variable
26// declarations, function and method returns, and type switches
27// as far as it can, so that the caller can still tell the types
// of expressions relevant to a particular fix.
29//
30// TODO(rsc,gri): Replace with go/typechecker.
31// Doing that could be an interesting test case for go/typechecker:
32// the constraints about working with partial information will
33// likely exercise it in interesting ways.  The ideal interface would
34// be to pass typecheck a map from importpath to package API text
35// (Go source code), but for now we use data structures (TypeConfig, Type).
36//
37// The strings mostly use gofmt form.
38//
39// A Field or FieldList has as its type a comma-separated list
40// of the types of the fields.  For example, the field list
41//	x, y, z int
42// has type "int, int, int".
43
44// The prefix "type " is the type of a type.
45// For example, given
46//	var x int
47//	type T int
48// x's type is "int" but T's type is "type int".
49// mkType inserts the "type " prefix.
50// getType removes it.
51// isType tests for it.
52
// mkType turns the type string t into the "type of a type" form by
// prepending the "type " marker.  getType undoes it; isType detects it.
func mkType(t string) string {
	const prefix = "type "
	return prefix + t
}
56
// getType is the inverse of mkType: it strips the leading "type "
// marker from t.  Strings that do not carry the marker (i.e. that
// describe a value rather than a type) yield "".
func getType(t string) string {
	const prefix = "type "
	if strings.HasPrefix(t, prefix) {
		return t[len(prefix):]
	}
	return ""
}
63
// isType reports whether t is in the "type of a type" form produced
// by mkType, i.e. whether it begins with the "type " marker.
func isType(t string) bool {
	return len(t) >= len("type ") && t[:len("type ")] == "type "
}
67
// TypeConfig describes the universe of relevant types.
// For ease of creation, the types are all referred to by string
// name (e.g., "reflect.Value").  The strings are resolved by looking
// them up directly in the Type map (for example in Type.dot and in
// the selector, index, star and range handling of typecheck1).
type TypeConfig struct {
	Type map[string]*Type  // map type name (e.g. "reflect.Value") to its fields, methods, embeddings and definition
	Var  map[string]string // map variable name ("x" or "p.X") to its type string
	Func map[string]string // map function name ("x" or "p.X") to the text that typeof appends after "func()" (its results)
}
78
79// typeof returns the type of the given name, which may be of
80// the form "x" or "p.X".
81func (cfg *TypeConfig) typeof(name string) string {
82	if cfg.Var != nil {
83		if t := cfg.Var[name]; t != "" {
84			return t
85		}
86	}
87	if cfg.Func != nil {
88		if t := cfg.Func[name]; t != "" {
89			return "func()" + t
90		}
91	}
92	return ""
93}
94
// Type describes the Fields and Methods of a type.
// If the field or method cannot be found there, it is next
// looked for in the Embed list (see Type.dot).
// Def, when set, is the definition of a named type (for example
// "[]T", "*T" or "map[K]V"); typecheck1 substitutes it for the name
// when indexing, dereferencing or ranging over a value of the type.
type Type struct {
	Field  map[string]string // map field name to type
	Method map[string]string // map method name to its full func type string, e.g. "func() string" (calls are typed only if it starts with "func(")
	Embed  []string          // list of types this type embeds (searched for extra fields and methods)
	Def    string            // definition of named type
}
104
105// dot returns the type of "typ.name", making its decision
106// using the type information in cfg.
107func (typ *Type) dot(cfg *TypeConfig, name string) string {
108	if typ.Field != nil {
109		if t := typ.Field[name]; t != "" {
110			return t
111		}
112	}
113	if typ.Method != nil {
114		if t := typ.Method[name]; t != "" {
115			return t
116		}
117	}
118
119	for _, e := range typ.Embed {
120		etyp := cfg.Type[e]
121		if etyp != nil {
122			if t := etyp.dot(cfg, name); t != "" {
123				return t
124			}
125		}
126	}
127
128	return ""
129}
130
131// typecheck type checks the AST f assuming the information in cfg.
132// It returns two maps with type information:
133// typeof maps AST nodes to type information in gofmt string form.
134// assign maps type strings to lists of expressions that were assigned
135// to values of another type that were assigned to that type.
136func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, assign map[string][]interface{}) {
137	typeof = make(map[interface{}]string)
138	assign = make(map[string][]interface{})
139	cfg1 := &TypeConfig{}
140	*cfg1 = *cfg // make copy so we can add locally
141	copied := false
142
143	// gather function declarations
144	for _, decl := range f.Decls {
145		fn, ok := decl.(*ast.FuncDecl)
146		if !ok {
147			continue
148		}
149		typecheck1(cfg, fn.Type, typeof, assign)
150		t := typeof[fn.Type]
151		if fn.Recv != nil {
152			// The receiver must be a type.
153			rcvr := typeof[fn.Recv]
154			if !isType(rcvr) {
155				if len(fn.Recv.List) != 1 {
156					continue
157				}
158				rcvr = mkType(gofmt(fn.Recv.List[0].Type))
159				typeof[fn.Recv.List[0].Type] = rcvr
160			}
161			rcvr = getType(rcvr)
162			if rcvr != "" && rcvr[0] == '*' {
163				rcvr = rcvr[1:]
164			}
165			typeof[rcvr+"."+fn.Name.Name] = t
166		} else {
167			if isType(t) {
168				t = getType(t)
169			} else {
170				t = gofmt(fn.Type)
171			}
172			typeof[fn.Name] = t
173
174			// Record typeof[fn.Name.Obj] for future references to fn.Name.
175			typeof[fn.Name.Obj] = t
176		}
177	}
178
179	// gather struct declarations
180	for _, decl := range f.Decls {
181		d, ok := decl.(*ast.GenDecl)
182		if ok {
183			for _, s := range d.Specs {
184				switch s := s.(type) {
185				case *ast.TypeSpec:
186					if cfg1.Type[s.Name.Name] != nil {
187						break
188					}
189					if !copied {
190						copied = true
191						// Copy map lazily: it's time.
192						cfg1.Type = make(map[string]*Type)
193						for k, v := range cfg.Type {
194							cfg1.Type[k] = v
195						}
196					}
197					t := &Type{Field: map[string]string{}}
198					cfg1.Type[s.Name.Name] = t
199					switch st := s.Type.(type) {
200					case *ast.StructType:
201						for _, f := range st.Fields.List {
202							for _, n := range f.Names {
203								t.Field[n.Name] = gofmt(f.Type)
204							}
205						}
206					case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
207						t.Def = gofmt(st)
208					}
209				}
210			}
211		}
212	}
213
214	typecheck1(cfg1, f, typeof, assign)
215	return typeof, assign
216}
217
// makeExprList converts a list of identifiers into the equivalent list
// of expressions, preserving order.  An empty or nil input yields nil.
func makeExprList(a []*ast.Ident) []ast.Expr {
	if len(a) == 0 {
		return nil
	}
	list := make([]ast.Expr, len(a))
	for i, id := range a {
		list[i] = id
	}
	return list
}
225
226// Typecheck1 is the recursive form of typecheck.
227// It is like typecheck but adds to the information in typeof
228// instead of allocating a new map.
229func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string, assign map[string][]interface{}) {
230	// set sets the type of n to typ.
231	// If isDecl is true, n is being declared.
232	set := func(n ast.Expr, typ string, isDecl bool) {
233		if typeof[n] != "" || typ == "" {
234			if typeof[n] != typ {
235				assign[typ] = append(assign[typ], n)
236			}
237			return
238		}
239		typeof[n] = typ
240
241		// If we obtained typ from the declaration of x
242		// propagate the type to all the uses.
243		// The !isDecl case is a cheat here, but it makes
244		// up in some cases for not paying attention to
245		// struct fields.  The real type checker will be
246		// more accurate so we won't need the cheat.
247		if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
248			typeof[id.Obj] = typ
249		}
250	}
251
252	// Type-check an assignment lhs = rhs.
253	// If isDecl is true, this is := so we can update
254	// the types of the objects that lhs refers to.
255	typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
256		if len(lhs) > 1 && len(rhs) == 1 {
257			if _, ok := rhs[0].(*ast.CallExpr); ok {
258				t := split(typeof[rhs[0]])
259				// Lists should have same length but may not; pair what can be paired.
260				for i := 0; i < len(lhs) && i < len(t); i++ {
261					set(lhs[i], t[i], isDecl)
262				}
263				return
264			}
265		}
266		if len(lhs) == 1 && len(rhs) == 2 {
267			// x = y, ok
268			rhs = rhs[:1]
269		} else if len(lhs) == 2 && len(rhs) == 1 {
270			// x, ok = y
271			lhs = lhs[:1]
272		}
273
274		// Match as much as we can.
275		for i := 0; i < len(lhs) && i < len(rhs); i++ {
276			x, y := lhs[i], rhs[i]
277			if typeof[y] != "" {
278				set(x, typeof[y], isDecl)
279			} else {
280				set(y, typeof[x], false)
281			}
282		}
283	}
284
285	expand := func(s string) string {
286		typ := cfg.Type[s]
287		if typ != nil && typ.Def != "" {
288			return typ.Def
289		}
290		return s
291	}
292
293	// The main type check is a recursive algorithm implemented
294	// by walkBeforeAfter(n, before, after).
295	// Most of it is bottom-up, but in a few places we need
296	// to know the type of the function we are checking.
297	// The before function records that information on
298	// the curfn stack.
299	var curfn []*ast.FuncType
300
301	before := func(n interface{}) {
302		// push function type on stack
303		switch n := n.(type) {
304		case *ast.FuncDecl:
305			curfn = append(curfn, n.Type)
306		case *ast.FuncLit:
307			curfn = append(curfn, n.Type)
308		}
309	}
310
311	// After is the real type checker.
312	after := func(n interface{}) {
313		if n == nil {
314			return
315		}
316		if false && reflect.TypeOf(n).Kind() == reflect.Ptr { // debugging trace
317			defer func() {
318				if t := typeof[n]; t != "" {
319					pos := fset.Position(n.(ast.Node).Pos())
320					fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
321				}
322			}()
323		}
324
325		switch n := n.(type) {
326		case *ast.FuncDecl, *ast.FuncLit:
327			// pop function type off stack
328			curfn = curfn[:len(curfn)-1]
329
330		case *ast.FuncType:
331			typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))
332
333		case *ast.FieldList:
334			// Field list is concatenation of sub-lists.
335			t := ""
336			for _, field := range n.List {
337				if t != "" {
338					t += ", "
339				}
340				t += typeof[field]
341			}
342			typeof[n] = t
343
344		case *ast.Field:
345			// Field is one instance of the type per name.
346			all := ""
347			t := typeof[n.Type]
348			if !isType(t) {
349				// Create a type, because it is typically *T or *p.T
350				// and we might care about that type.
351				t = mkType(gofmt(n.Type))
352				typeof[n.Type] = t
353			}
354			t = getType(t)
355			if len(n.Names) == 0 {
356				all = t
357			} else {
358				for _, id := range n.Names {
359					if all != "" {
360						all += ", "
361					}
362					all += t
363					typeof[id.Obj] = t
364					typeof[id] = t
365				}
366			}
367			typeof[n] = all
368
369		case *ast.ValueSpec:
370			// var declaration.  Use type if present.
371			if n.Type != nil {
372				t := typeof[n.Type]
373				if !isType(t) {
374					t = mkType(gofmt(n.Type))
375					typeof[n.Type] = t
376				}
377				t = getType(t)
378				for _, id := range n.Names {
379					set(id, t, true)
380				}
381			}
382			// Now treat same as assignment.
383			typecheckAssign(makeExprList(n.Names), n.Values, true)
384
385		case *ast.AssignStmt:
386			typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)
387
388		case *ast.Ident:
389			// Identifier can take its type from underlying object.
390			if t := typeof[n.Obj]; t != "" {
391				typeof[n] = t
392			}
393
394		case *ast.SelectorExpr:
395			// Field or method.
396			name := n.Sel.Name
397			if t := typeof[n.X]; t != "" {
398				if strings.HasPrefix(t, "*") {
399					t = t[1:] // implicit *
400				}
401				if typ := cfg.Type[t]; typ != nil {
402					if t := typ.dot(cfg, name); t != "" {
403						typeof[n] = t
404						return
405					}
406				}
407				tt := typeof[t+"."+name]
408				if isType(tt) {
409					typeof[n] = getType(tt)
410					return
411				}
412			}
413			// Package selector.
414			if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
415				str := x.Name + "." + name
416				if cfg.Type[str] != nil {
417					typeof[n] = mkType(str)
418					return
419				}
420				if t := cfg.typeof(x.Name + "." + name); t != "" {
421					typeof[n] = t
422					return
423				}
424			}
425
426		case *ast.CallExpr:
427			// make(T) has type T.
428			if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
429				typeof[n] = gofmt(n.Args[0])
430				return
431			}
432			// new(T) has type *T
433			if isTopName(n.Fun, "new") && len(n.Args) == 1 {
434				typeof[n] = "*" + gofmt(n.Args[0])
435				return
436			}
437			// Otherwise, use type of function to determine arguments.
438			t := typeof[n.Fun]
439			in, out := splitFunc(t)
440			if in == nil && out == nil {
441				return
442			}
443			typeof[n] = join(out)
444			for i, arg := range n.Args {
445				if i >= len(in) {
446					break
447				}
448				if typeof[arg] == "" {
449					typeof[arg] = in[i]
450				}
451			}
452
453		case *ast.TypeAssertExpr:
454			// x.(type) has type of x.
455			if n.Type == nil {
456				typeof[n] = typeof[n.X]
457				return
458			}
459			// x.(T) has type T.
460			if t := typeof[n.Type]; isType(t) {
461				typeof[n] = getType(t)
462			} else {
463				typeof[n] = gofmt(n.Type)
464			}
465
466		case *ast.SliceExpr:
467			// x[i:j] has type of x.
468			typeof[n] = typeof[n.X]
469
470		case *ast.IndexExpr:
471			// x[i] has key type of x's type.
472			t := expand(typeof[n.X])
473			if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
474				// Lazy: assume there are no nested [] in the array
475				// length or map key type.
476				if i := strings.Index(t, "]"); i >= 0 {
477					typeof[n] = t[i+1:]
478				}
479			}
480
481		case *ast.StarExpr:
482			// *x for x of type *T has type T when x is an expr.
483			// We don't use the result when *x is a type, but
484			// compute it anyway.
485			t := expand(typeof[n.X])
486			if isType(t) {
487				typeof[n] = "type *" + getType(t)
488			} else if strings.HasPrefix(t, "*") {
489				typeof[n] = t[len("*"):]
490			}
491
492		case *ast.UnaryExpr:
493			// &x for x of type T has type *T.
494			t := typeof[n.X]
495			if t != "" && n.Op == token.AND {
496				typeof[n] = "*" + t
497			}
498
499		case *ast.CompositeLit:
500			// T{...} has type T.
501			typeof[n] = gofmt(n.Type)
502
503		case *ast.ParenExpr:
504			// (x) has type of x.
505			typeof[n] = typeof[n.X]
506
507		case *ast.RangeStmt:
508			t := expand(typeof[n.X])
509			if t == "" {
510				return
511			}
512			var key, value string
513			if t == "string" {
514				key, value = "int", "rune"
515			} else if strings.HasPrefix(t, "[") {
516				key = "int"
517				if i := strings.Index(t, "]"); i >= 0 {
518					value = t[i+1:]
519				}
520			} else if strings.HasPrefix(t, "map[") {
521				if i := strings.Index(t, "]"); i >= 0 {
522					key, value = t[4:i], t[i+1:]
523				}
524			}
525			changed := false
526			if n.Key != nil && key != "" {
527				changed = true
528				set(n.Key, key, n.Tok == token.DEFINE)
529			}
530			if n.Value != nil && value != "" {
531				changed = true
532				set(n.Value, value, n.Tok == token.DEFINE)
533			}
534			// Ugly failure of vision: already type-checked body.
535			// Do it again now that we have that type info.
536			if changed {
537				typecheck1(cfg, n.Body, typeof, assign)
538			}
539
540		case *ast.TypeSwitchStmt:
541			// Type of variable changes for each case in type switch,
542			// but go/parser generates just one variable.
543			// Repeat type check for each case with more precise
544			// type information.
545			as, ok := n.Assign.(*ast.AssignStmt)
546			if !ok {
547				return
548			}
549			varx, ok := as.Lhs[0].(*ast.Ident)
550			if !ok {
551				return
552			}
553			t := typeof[varx]
554			for _, cas := range n.Body.List {
555				cas := cas.(*ast.CaseClause)
556				if len(cas.List) == 1 {
557					// Variable has specific type only when there is
558					// exactly one type in the case list.
559					if tt := typeof[cas.List[0]]; isType(tt) {
560						tt = getType(tt)
561						typeof[varx] = tt
562						typeof[varx.Obj] = tt
563						typecheck1(cfg, cas.Body, typeof, assign)
564					}
565				}
566			}
567			// Restore t.
568			typeof[varx] = t
569			typeof[varx.Obj] = t
570
571		case *ast.ReturnStmt:
572			if len(curfn) == 0 {
573				// Probably can't happen.
574				return
575			}
576			f := curfn[len(curfn)-1]
577			res := n.Results
578			if f.Results != nil {
579				t := split(typeof[f.Results])
580				for i := 0; i < len(res) && i < len(t); i++ {
581					set(res[i], t[i], false)
582				}
583			}
584		}
585	}
586	walkBeforeAfter(f, before, after)
587}
588
589// Convert between function type strings and lists of types.
590// Using strings makes this a little harder, but it makes
591// a lot of the rest of the code easier.  This will all go away
592// when we can use go/typechecker directly.
593
594// splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"].
595func splitFunc(s string) (in, out []string) {
596	if !strings.HasPrefix(s, "func(") {
597		return nil, nil
598	}
599
600	i := len("func(") // index of beginning of 'in' arguments
601	nparen := 0
602	for j := i; j < len(s); j++ {
603		switch s[j] {
604		case '(':
605			nparen++
606		case ')':
607			nparen--
608			if nparen < 0 {
609				// found end of parameter list
610				out := strings.TrimSpace(s[j+1:])
611				if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
612					out = out[1 : len(out)-1]
613				}
614				return split(s[i:j]), split(out)
615			}
616		}
617	}
618	return nil, nil
619}
620
621// joinFunc is the inverse of splitFunc.
622func joinFunc(in, out []string) string {
623	outs := ""
624	if len(out) == 1 {
625		outs = " " + out[0]
626	} else if len(out) > 1 {
627		outs = " (" + join(out) + ")"
628	}
629	return "func(" + join(in) + ")" + outs
630}
631
// split splits "int, float" into ["int", "float"] and splits "" into [].
// Commas nested inside (), [] or {} do not separate types.  Function
// types, array/map/generic types and anonymous struct or interface
// types that contain commas (e.g. "struct{ x, y int }") therefore stay
// intact.  Leading spaces of each element are dropped.  split returns
// nil if the parentheses in s are unbalanced.
func split(s string) []string {
	out := []string{}
	i := 0      // current type being scanned is s[i:j].
	nparen := 0 // ( ) depth; must balance or the result is nil
	nbrack := 0 // [ ] and { } depth; only used to protect commas
	for j := 0; j < len(s); j++ {
		switch s[j] {
		case ' ':
			if i == j {
				i++
			}
		case '(':
			nparen++
		case ')':
			nparen--
			if nparen < 0 {
				// probably can't happen
				return nil
			}
		case '[', '{':
			nbrack++
		case ']', '}':
			nbrack--
		case ',':
			// A stray closing bracket (nbrack < 0) is tolerated so that
			// malformed input splits exactly as it did before brackets
			// were tracked.
			if nparen == 0 && nbrack <= 0 {
				if i < j {
					out = append(out, s[i:j])
				}
				i = j + 1
			}
		}
	}
	if nparen != 0 {
		// probably can't happen
		return nil
	}
	if i < len(s) {
		out = append(out, s[i:])
	}
	return out
}
669
// join is the inverse of split: it concatenates the type strings in x,
// separated by ", ".  An empty or nil list yields "".
func join(x []string) string {
	const sep = ", "
	return strings.Join(x, sep)
}
674