1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssa_test
6
7// This file defines tests of source-level debugging utilities.
8
9import (
10	"fmt"
11	"go/ast"
12	"go/constant"
13	"go/parser"
14	"go/token"
15	"go/types"
16	"io/ioutil"
17	"os"
18	"runtime"
19	"strings"
20	"testing"
21
22	"golang.org/x/tools/go/ast/astutil"
23	"golang.org/x/tools/go/expect"
24	"golang.org/x/tools/go/loader"
25	"golang.org/x/tools/go/ssa"
26	"golang.org/x/tools/go/ssa/ssautil"
27)
28
29func TestObjValueLookup(t *testing.T) {
30	if runtime.GOOS == "android" {
31		t.Skipf("no testdata directory on %s", runtime.GOOS)
32	}
33
34	conf := loader.Config{ParserMode: parser.ParseComments}
35	src, err := ioutil.ReadFile("testdata/objlookup.go")
36	if err != nil {
37		t.Fatal(err)
38	}
39	readFile := func(filename string) ([]byte, error) { return src, nil }
40	f, err := conf.ParseFile("testdata/objlookup.go", src)
41	if err != nil {
42		t.Fatal(err)
43	}
44	conf.CreateFromFiles("main", f)
45
46	// Maps each var Ident (represented "name:linenum") to the
47	// kind of ssa.Value we expect (represented "Constant", "&Alloc").
48	expectations := make(map[string]string)
49
50	// Each note of the form @ssa(x, "BinOp") in testdata/objlookup.go
51	// specifies an expectation that an object named x declared on the
52	// same line is associated with an an ssa.Value of type *ssa.BinOp.
53	notes, err := expect.ExtractGo(conf.Fset, f)
54	if err != nil {
55		t.Fatal(err)
56	}
57	for _, n := range notes {
58		if n.Name != "ssa" {
59			t.Errorf("%v: unexpected note type %q, want \"ssa\"", conf.Fset.Position(n.Pos), n.Name)
60			continue
61		}
62		if len(n.Args) != 2 {
63			t.Errorf("%v: ssa has %d args, want 2", conf.Fset.Position(n.Pos), len(n.Args))
64			continue
65		}
66		ident, ok := n.Args[0].(expect.Identifier)
67		if !ok {
68			t.Errorf("%v: got %v for arg 1, want identifier", conf.Fset.Position(n.Pos), n.Args[0])
69			continue
70		}
71		exp, ok := n.Args[1].(string)
72		if !ok {
73			t.Errorf("%v: got %v for arg 2, want string", conf.Fset.Position(n.Pos), n.Args[1])
74			continue
75		}
76		p, _, err := expect.MatchBefore(conf.Fset, readFile, n.Pos, string(ident))
77		if err != nil {
78			t.Error(err)
79			continue
80		}
81		pos := conf.Fset.Position(p)
82		key := fmt.Sprintf("%s:%d", ident, pos.Line)
83		expectations[key] = exp
84	}
85
86	iprog, err := conf.Load()
87	if err != nil {
88		t.Error(err)
89		return
90	}
91
92	prog := ssautil.CreateProgram(iprog, 0 /*|ssa.PrintFunctions*/)
93	mainInfo := iprog.Created[0]
94	mainPkg := prog.Package(mainInfo.Pkg)
95	mainPkg.SetDebugMode(true)
96	mainPkg.Build()
97
98	var varIds []*ast.Ident
99	var varObjs []*types.Var
100	for id, obj := range mainInfo.Defs {
101		// Check invariants for func and const objects.
102		switch obj := obj.(type) {
103		case *types.Func:
104			checkFuncValue(t, prog, obj)
105
106		case *types.Const:
107			checkConstValue(t, prog, obj)
108
109		case *types.Var:
110			if id.Name == "_" {
111				continue
112			}
113			varIds = append(varIds, id)
114			varObjs = append(varObjs, obj)
115		}
116	}
117	for id, obj := range mainInfo.Uses {
118		if obj, ok := obj.(*types.Var); ok {
119			varIds = append(varIds, id)
120			varObjs = append(varObjs, obj)
121		}
122	}
123
124	// Check invariants for var objects.
125	// The result varies based on the specific Ident.
126	for i, id := range varIds {
127		obj := varObjs[i]
128		ref, _ := astutil.PathEnclosingInterval(f, id.Pos(), id.Pos())
129		pos := prog.Fset.Position(id.Pos())
130		exp := expectations[fmt.Sprintf("%s:%d", id.Name, pos.Line)]
131		if exp == "" {
132			t.Errorf("%s: no expectation for var ident %s ", pos, id.Name)
133			continue
134		}
135		wantAddr := false
136		if exp[0] == '&' {
137			wantAddr = true
138			exp = exp[1:]
139		}
140		checkVarValue(t, prog, mainPkg, ref, obj, exp, wantAddr)
141	}
142}
143
144func checkFuncValue(t *testing.T, prog *ssa.Program, obj *types.Func) {
145	fn := prog.FuncValue(obj)
146	// fmt.Printf("FuncValue(%s) = %s\n", obj, fn) // debugging
147	if fn == nil {
148		if obj.Name() != "interfaceMethod" {
149			t.Errorf("FuncValue(%s) == nil", obj)
150		}
151		return
152	}
153	if fnobj := fn.Object(); fnobj != obj {
154		t.Errorf("FuncValue(%s).Object() == %s; value was %s",
155			obj, fnobj, fn.Name())
156		return
157	}
158	if !types.Identical(fn.Type(), obj.Type()) {
159		t.Errorf("FuncValue(%s).Type() == %s", obj, fn.Type())
160		return
161	}
162}
163
164func checkConstValue(t *testing.T, prog *ssa.Program, obj *types.Const) {
165	c := prog.ConstValue(obj)
166	// fmt.Printf("ConstValue(%s) = %s\n", obj, c) // debugging
167	if c == nil {
168		t.Errorf("ConstValue(%s) == nil", obj)
169		return
170	}
171	if !types.Identical(c.Type(), obj.Type()) {
172		t.Errorf("ConstValue(%s).Type() == %s", obj, c.Type())
173		return
174	}
175	if obj.Name() != "nil" {
176		if !constant.Compare(c.Value, token.EQL, obj.Val()) {
177			t.Errorf("ConstValue(%s).Value (%s) != %s",
178				obj, c.Value, obj.Val())
179			return
180		}
181	}
182}
183
184func checkVarValue(t *testing.T, prog *ssa.Program, pkg *ssa.Package, ref []ast.Node, obj *types.Var, expKind string, wantAddr bool) {
185	// The prefix of all assertions messages.
186	prefix := fmt.Sprintf("VarValue(%s @ L%d)",
187		obj, prog.Fset.Position(ref[0].Pos()).Line)
188
189	v, gotAddr := prog.VarValue(obj, pkg, ref)
190
191	// Kind is the concrete type of the ssa Value.
192	gotKind := "nil"
193	if v != nil {
194		gotKind = fmt.Sprintf("%T", v)[len("*ssa."):]
195	}
196
197	// fmt.Printf("%s = %v (kind %q; expect %q) wantAddr=%t gotAddr=%t\n", prefix, v, gotKind, expKind, wantAddr, gotAddr) // debugging
198
199	// Check the kinds match.
200	// "nil" indicates expected failure (e.g. optimized away).
201	if expKind != gotKind {
202		t.Errorf("%s concrete type == %s, want %s", prefix, gotKind, expKind)
203	}
204
205	// Check the types match.
206	// If wantAddr, the expected type is the object's address.
207	if v != nil {
208		expType := obj.Type()
209		if wantAddr {
210			expType = types.NewPointer(expType)
211			if !gotAddr {
212				t.Errorf("%s: got value, want address", prefix)
213			}
214		} else if gotAddr {
215			t.Errorf("%s: got address, want value", prefix)
216		}
217		if !types.Identical(v.Type(), expType) {
218			t.Errorf("%s.Type() == %s, want %s", prefix, v.Type(), expType)
219		}
220	}
221}
222
223// Ensure that, in debug mode, we can determine the ssa.Value
224// corresponding to every ast.Expr.
225func TestValueForExpr(t *testing.T) {
226	testValueForExpr(t, "testdata/valueforexpr.go")
227}
228
229func testValueForExpr(t *testing.T, testfile string) {
230	if runtime.GOOS == "android" {
231		t.Skipf("no testdata dir on %s", runtime.GOOS)
232	}
233
234	conf := loader.Config{ParserMode: parser.ParseComments}
235	f, err := conf.ParseFile(testfile, nil)
236	if err != nil {
237		t.Error(err)
238		return
239	}
240	conf.CreateFromFiles("main", f)
241
242	iprog, err := conf.Load()
243	if err != nil {
244		t.Error(err)
245		return
246	}
247
248	mainInfo := iprog.Created[0]
249
250	prog := ssautil.CreateProgram(iprog, 0)
251	mainPkg := prog.Package(mainInfo.Pkg)
252	mainPkg.SetDebugMode(true)
253	mainPkg.Build()
254
255	if false {
256		// debugging
257		for _, mem := range mainPkg.Members {
258			if fn, ok := mem.(*ssa.Function); ok {
259				fn.WriteTo(os.Stderr)
260			}
261		}
262	}
263
264	var parenExprs []*ast.ParenExpr
265	ast.Inspect(f, func(n ast.Node) bool {
266		if n != nil {
267			if e, ok := n.(*ast.ParenExpr); ok {
268				parenExprs = append(parenExprs, e)
269			}
270		}
271		return true
272	})
273
274	notes, err := expect.ExtractGo(prog.Fset, f)
275	if err != nil {
276		t.Fatal(err)
277	}
278	for _, n := range notes {
279		want := n.Name
280		if want == "nil" {
281			want = "<nil>"
282		}
283		position := prog.Fset.Position(n.Pos)
284		var e ast.Expr
285		for _, paren := range parenExprs {
286			if paren.Pos() > n.Pos {
287				e = paren.X
288				break
289			}
290		}
291		if e == nil {
292			t.Errorf("%s: note doesn't precede ParenExpr: %q", position, want)
293			continue
294		}
295
296		path, _ := astutil.PathEnclosingInterval(f, n.Pos, n.Pos)
297		if path == nil {
298			t.Errorf("%s: can't find AST path from root to comment: %s", position, want)
299			continue
300		}
301
302		fn := ssa.EnclosingFunction(mainPkg, path)
303		if fn == nil {
304			t.Errorf("%s: can't find enclosing function", position)
305			continue
306		}
307
308		v, gotAddr := fn.ValueForExpr(e) // (may be nil)
309		got := strings.TrimPrefix(fmt.Sprintf("%T", v), "*ssa.")
310		if got != want {
311			t.Errorf("%s: got value %q, want %q", position, got, want)
312		}
313		if v != nil {
314			T := v.Type()
315			if gotAddr {
316				T = T.Underlying().(*types.Pointer).Elem() // deref
317			}
318			if !types.Identical(T, mainInfo.TypeOf(e)) {
319				t.Errorf("%s: got type %s, want %s", position, mainInfo.TypeOf(e), T)
320			}
321		}
322	}
323}
324
// findInterval parses input and returns the [start, end) positions of
// the first occurrence of substr in input. f==nil indicates failure;
// an error has already been reported in that case.
func findInterval(t *testing.T, fset *token.FileSet, input, substr string) (f *ast.File, start, end token.Pos) {
	t.Helper()

	f, err := parser.ParseFile(fset, "<input>", input, 0)
	if err != nil {
		// ParseFile may return a partial AST alongside the error; discard
		// it so the documented f==nil failure signal holds.
		t.Errorf("parse error: %s", err)
		return nil, token.NoPos, token.NoPos
	}

	i := strings.Index(input, substr)
	if i < 0 {
		t.Errorf("%q is not a substring of input", substr)
		return nil, token.NoPos, token.NoPos
	}

	// input is the whole file, so byte offsets into it are file offsets.
	tokFile := fset.File(f.Package)
	return f, tokFile.Pos(i), tokFile.Pos(i + len(substr))
}
346
347func TestEnclosingFunction(t *testing.T) {
348	tests := []struct {
349		input  string // the input file
350		substr string // first occurrence of this string denotes interval
351		fn     string // name of expected containing function
352	}{
353		// We use distinctive numbers as syntactic landmarks.
354
355		// Ordinary function:
356		{`package main
357		  func f() { println(1003) }`,
358			"100", "main.f"},
359		// Methods:
360		{`package main
361                  type T int
362		  func (t T) f() { println(200) }`,
363			"200", "(main.T).f"},
364		// Function literal:
365		{`package main
366		  func f() { println(func() { print(300) }) }`,
367			"300", "main.f$1"},
368		// Doubly nested
369		{`package main
370		  func f() { println(func() { print(func() { print(350) })})}`,
371			"350", "main.f$1$1"},
372		// Implicit init for package-level var initializer.
373		{"package main; var a = 400", "400", "main.init"},
374		// No code for constants:
375		{"package main; const a = 500", "500", "(none)"},
376		// Explicit init()
377		{"package main; func init() { println(600) }", "600", "main.init#1"},
378		// Multiple explicit init functions:
379		{`package main
380		  func init() { println("foo") }
381		  func init() { println(800) }`,
382			"800", "main.init#2"},
383		// init() containing FuncLit.
384		{`package main
385		  func init() { println(func(){print(900)}) }`,
386			"900", "main.init#1$1"},
387	}
388	for _, test := range tests {
389		conf := loader.Config{Fset: token.NewFileSet()}
390		f, start, end := findInterval(t, conf.Fset, test.input, test.substr)
391		if f == nil {
392			continue
393		}
394		path, exact := astutil.PathEnclosingInterval(f, start, end)
395		if !exact {
396			t.Errorf("EnclosingFunction(%q) not exact", test.substr)
397			continue
398		}
399
400		conf.CreateFromFiles("main", f)
401
402		iprog, err := conf.Load()
403		if err != nil {
404			t.Error(err)
405			continue
406		}
407		prog := ssautil.CreateProgram(iprog, 0)
408		pkg := prog.Package(iprog.Created[0].Pkg)
409		pkg.Build()
410
411		name := "(none)"
412		fn := ssa.EnclosingFunction(pkg, path)
413		if fn != nil {
414			name = fn.String()
415		}
416
417		if name != test.fn {
418			t.Errorf("EnclosingFunction(%q in %q) got %s, want %s",
419				test.substr, test.input, name, test.fn)
420			continue
421		}
422
423		// While we're here: test HasEnclosingFunction.
424		if has := ssa.HasEnclosingFunction(pkg, path); has != (fn != nil) {
425			t.Errorf("HasEnclosingFunction(%q in %q) got %v, want %v",
426				test.substr, test.input, has, fn != nil)
427			continue
428		}
429	}
430}
431