1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package load
6
7import (
8	"bytes"
9	"cmd/go/internal/base"
10	"cmd/go/internal/str"
11	"errors"
12	"fmt"
13	"go/ast"
14	"go/build"
15	"go/doc"
16	"go/parser"
17	"go/token"
18	"path/filepath"
19	"sort"
20	"strings"
21	"text/template"
22	"unicode"
23	"unicode/utf8"
24)
25
26var TestMainDeps = []string{
27	// Dependencies for testmain.
28	"os",
29	"testing",
30	"testing/internal/testdeps",
31}
32
// TestCover holds the coverage settings that GetTestPackagesFor applies
// when building a test binary. A nil *TestCover disables all coverage
// handling in GetTestPackagesFor and in the generated testmain.
type TestCover struct {
	// Mode is the coverage mode. It is recorded as ptest's CoverMode when
	// Local is set, and emitted as the Mode of the generated testing.Cover.
	Mode     string
	// Local requests that the package under test itself be instrumented;
	// it forces GetTestPackagesFor to build a separate ptest copy.
	Local    bool
	// Pkgs are packages added to pmain's imports (without duplicates) so
	// that their coverage variables are linked into the test binary.
	Pkgs     []*Package
	// Paths are listed by (*testFuncs).Covered in the coverage message.
	Paths    []string
	// Vars is filled in by GetTestPackagesFor with the coverage variables
	// of every package imported by pmain.
	Vars     []coverInfo
	// DeclVars declares the coverage variables for the given files of a
	// package and returns them keyed by name.
	// NOTE(review): the key meaning is inferred from map[string]*CoverVar
	// and is not established in this file — confirm.
	DeclVars func(*Package, ...string) map[string]*CoverVar
}
41
// GetTestPackagesFor returns the packages needed to build the test binary
// for p, in this order:
//	- pmain, the package main corresponding to the test binary (running tests in ptest and pxtest).
//	- ptest, the package p compiled with added "package p" test files.
//	- pxtest, the result of compiling any "package p_test" (external) test files.
//
// If the package has no "package p_test" test files, pxtest will be nil.
// If the non-test compilation of package p can be reused
// (for example, if there are no "package p" test files and
// package p need not be instrumented for coverage or any other reason),
// then the returned ptest == p.
//
// cover configures coverage; it may be nil, which disables coverage.
//
// On failure, all three packages are nil and err describes the first
// problem found. Possible failures are:
//	- a test import that fails to load or has dependency errors;
//	- an import cycle through a "package p" test import;
//	- a test file that does not parse or has a malformed test signature;
//	- a testmain template error.
//
// As a side effect, p.TestImports and p.XTestImports are overwritten in
// place with the import paths of the packages returned by LoadImport.
//
// The caller is expected to have checked that len(p.TestGoFiles)+len(p.XTestGoFiles) > 0,
// or else there's no point in any of this.
func GetTestPackagesFor(p *Package, cover *TestCover) (pmain, ptest, pxtest *Package, err error) {
	// imports and ximports collect the loaded packages for the
	// "package p" and "package p_test" test imports, respectively.
	var imports, ximports []*Package
	// stk is the import stack handed to LoadImport; its first entry names
	// the set of test files currently being processed.
	var stk ImportStack
	stk.Push(p.ImportPath + " (test)")
	// Snapshot the import strings as written in the source, because the
	// loop below overwrites p.TestImports[i] with the resolved path.
	rawTestImports := str.StringList(p.TestImports)
	for i, path := range p.TestImports {
		p1 := LoadImport(path, p.Dir, p, &stk, p.Internal.Build.TestImportPos[path], ResolveImport)
		if p1.Error != nil {
			return nil, nil, nil, p1.Error
		}
		if len(p1.DepsErrors) > 0 {
			err := p1.DepsErrors[0]
			err.Pos = "" // show full import stack
			return nil, nil, nil, err
		}
		if str.Contains(p1.Deps, p.ImportPath) || p1.ImportPath == p.ImportPath {
			// Same error that loadPackage returns (via reusePackage) in pkg.go.
			// Can't change that code, because that code is only for loading the
			// non-test copy of a package.
			err := &PackageError{
				ImportStack:   testImportStack(stk[0], p1, p.ImportPath),
				Err:           "import cycle not allowed in test",
				IsImportCycle: true,
			}
			return nil, nil, nil, err
		}
		p.TestImports[i] = p1.ImportPath
		imports = append(imports, p1)
	}
	stk.Pop()
	stk.Push(p.ImportPath + "_test")
	// An external test may import p itself. That import is not added to
	// ximports; instead pxtest is wired to ptest further down.
	pxtestNeedsPtest := false
	rawXTestImports := str.StringList(p.XTestImports)
	for i, path := range p.XTestImports {
		p1 := LoadImport(path, p.Dir, p, &stk, p.Internal.Build.XTestImportPos[path], ResolveImport)
		if p1.Error != nil {
			return nil, nil, nil, p1.Error
		}
		if len(p1.DepsErrors) > 0 {
			err := p1.DepsErrors[0]
			err.Pos = "" // show full import stack
			return nil, nil, nil, err
		}
		if p1.ImportPath == p.ImportPath {
			pxtestNeedsPtest = true
		} else {
			ximports = append(ximports, p1)
		}
		p.XTestImports[i] = p1.ImportPath
	}
	stk.Pop()

	// Test package.
	// A separate ptest is built when p has in-package test files, when p is
	// package main, or when p is instrumented for local coverage;
	// otherwise p itself is reused.
	if len(p.TestGoFiles) > 0 || p.Name == "main" || cover != nil && cover.Local {
		ptest = new(Package)
		*ptest = *p
		ptest.ForTest = p.ImportPath
		ptest.GoFiles = nil
		ptest.GoFiles = append(ptest.GoFiles, p.GoFiles...)
		ptest.GoFiles = append(ptest.GoFiles, p.TestGoFiles...)
		ptest.Target = ""
		// Note: The preparation of the vet config requires that common
		// indexes in ptest.Imports and ptest.Internal.RawImports
		// all line up (but RawImports can be shorter than the others).
		// That is, for 0 ≤ i < len(RawImports),
		// RawImports[i] is the import string in the program text, and
		// Imports[i] is the expanded import string (vendoring applied or relative path expanded away).
		// Any implicitly added imports appear in Imports and Internal.Imports
		// but not RawImports (because they were not in the source code).
		// We insert TestImports, imports, and rawTestImports at the start of
		// these lists to preserve the alignment.
		// Note that p.Internal.Imports may not be aligned with p.Imports/p.Internal.RawImports,
		// but we insert at the beginning there too just for consistency.
		ptest.Imports = str.StringList(p.TestImports, p.Imports)
		ptest.Internal.Imports = append(imports, p.Internal.Imports...)
		ptest.Internal.RawImports = str.StringList(rawTestImports, p.Internal.RawImports)
		ptest.Internal.ForceLibrary = true
		ptest.Internal.BuildInfo = ""
		// Give ptest its own build.Package so that p's ImportPos map is
		// not modified.
		ptest.Internal.Build = new(build.Package)
		*ptest.Internal.Build = *p.Internal.Build
		// Merge the import positions of the regular and the test files.
		m := map[string][]token.Position{}
		for k, v := range p.Internal.Build.ImportPos {
			m[k] = append(m[k], v...)
		}
		for k, v := range p.Internal.Build.TestImportPos {
			m[k] = append(m[k], v...)
		}
		ptest.Internal.Build.ImportPos = m
	} else {
		ptest = p
	}

	// External test package.
	if len(p.XTestGoFiles) > 0 {
		pxtest = &Package{
			PackagePublic: PackagePublic{
				Name:       p.Name + "_test",
				ImportPath: p.ImportPath + "_test",
				Root:       p.Root,
				Dir:        p.Dir,
				GoFiles:    p.XTestGoFiles,
				Imports:    p.XTestImports,
				ForTest:    p.ImportPath,
			},
			Internal: PackageInternal{
				LocalPrefix: p.Internal.LocalPrefix,
				Build: &build.Package{
					ImportPos: p.Internal.Build.XTestImportPos,
				},
				Imports:    ximports,
				RawImports: rawXTestImports,

				Asmflags:   p.Internal.Asmflags,
				Gcflags:    p.Internal.Gcflags,
				Ldflags:    p.Internal.Ldflags,
				Gccgoflags: p.Internal.Gccgoflags,
			},
		}
		// The external test's import of p is satisfied by ptest.
		if pxtestNeedsPtest {
			pxtest.Internal.Imports = append(pxtest.Internal.Imports, ptest)
		}
	}

	// Build main package.
	pmain = &Package{
		PackagePublic: PackagePublic{
			Name:       "main",
			Dir:        p.Dir,
			GoFiles:    []string{"_testmain.go"},
			ImportPath: p.ImportPath + ".test",
			Root:       p.Root,
			Imports:    str.StringList(TestMainDeps),
		},
		Internal: PackageInternal{
			Build:      &build.Package{Name: "main"},
			BuildInfo:  p.Internal.BuildInfo,
			Asmflags:   p.Internal.Asmflags,
			Gcflags:    p.Internal.Gcflags,
			Ldflags:    p.Internal.Ldflags,
			Gccgoflags: p.Internal.Gccgoflags,
		},
	}

	// The generated main imports the TestMainDeps packages
	// (os, testing, and testing/internal/testdeps).
	// Also the linker introduces implicit dependencies reported by LinkerDeps.
	stk.Push("testmain")
	deps := TestMainDeps // cap==len, so safe for append
	for _, d := range LinkerDeps(p) {
		deps = append(deps, d)
	}
	for _, dep := range deps {
		// If the package under test is itself one of these dependencies,
		// link its test copy rather than loading it again.
		if dep == ptest.ImportPath {
			pmain.Internal.Imports = append(pmain.Internal.Imports, ptest)
		} else {
			p1 := LoadImport(dep, "", nil, &stk, nil, 0)
			if p1.Error != nil {
				return nil, nil, nil, p1.Error
			}
			pmain.Internal.Imports = append(pmain.Internal.Imports, p1)
		}
	}
	stk.Pop()

	// Link in the packages selected for coverage, so that their coverage
	// variables are collected from pmain's imports below.
	if cover != nil && cover.Pkgs != nil {
		// Add imports, but avoid duplicates.
		seen := map[*Package]bool{p: true, ptest: true}
		for _, p1 := range pmain.Internal.Imports {
			seen[p1] = true
		}
		for _, p1 := range cover.Pkgs {
			if !seen[p1] {
				seen[p1] = true
				pmain.Internal.Imports = append(pmain.Internal.Imports, p1)
			}
		}
	}

	// NOTE(review): setToolFlags is defined elsewhere; presumably it
	// assigns per-package tool flags to these directly imported packages,
	// which were loaded outside the normal command-line set — confirm.
	allTestImports := make([]*Package, 0, len(pmain.Internal.Imports)+len(imports)+len(ximports))
	allTestImports = append(allTestImports, pmain.Internal.Imports...)
	allTestImports = append(allTestImports, imports...)
	allTestImports = append(allTestImports, ximports...)
	setToolFlags(allTestImports...)

	// Do initial scan for metadata needed for writing _testmain.go
	// Use that metadata to update the list of imports for package main.
	// The list of imports is used by recompileForTest and by the loop
	// afterward that gathers t.Cover information.
	t, err := loadTestFuncs(ptest)
	if err != nil {
		return nil, nil, nil, err
	}
	t.Cover = cover
	// Import ptest only if it has source files of its own.
	if len(ptest.GoFiles)+len(ptest.CgoFiles) > 0 {
		pmain.Internal.Imports = append(pmain.Internal.Imports, ptest)
		pmain.Imports = append(pmain.Imports, ptest.ImportPath)
		t.ImportTest = true
	}
	if pxtest != nil {
		pmain.Internal.Imports = append(pmain.Internal.Imports, pxtest)
		pmain.Imports = append(pmain.Imports, pxtest.ImportPath)
		t.ImportXtest = true
	}

	// Sort and dedup pmain.Imports.
	// Only matters for go list -test output.
	sort.Strings(pmain.Imports)
	w := 0
	for _, path := range pmain.Imports {
		if w == 0 || path != pmain.Imports[w-1] {
			pmain.Imports[w] = path
			w++
		}
	}
	pmain.Imports = pmain.Imports[:w]
	pmain.Internal.RawImports = str.StringList(pmain.Imports)

	// Replace pmain's transitive dependencies with test copies, as necessary.
	recompileForTest(pmain, p, ptest, pxtest)

	// Should we apply coverage analysis locally,
	// only for this package and only for this test?
	// Yes, if -cover is on but -coverpkg has not specified
	// a list of packages for global coverage.
	if cover != nil && cover.Local {
		ptest.Internal.CoverMode = cover.Mode
		var coverFiles []string
		coverFiles = append(coverFiles, ptest.GoFiles...)
		coverFiles = append(coverFiles, ptest.CgoFiles...)
		ptest.Internal.CoverVars = cover.DeclVars(ptest, coverFiles...)
	}

	// Collect the coverage variables of every package linked into pmain so
	// the generated testmain can register them.
	// NOTE(review): t.Cover is nil when cover == nil; this loop assumes no
	// imported package carries CoverVars in that case.
	for _, cp := range pmain.Internal.Imports {
		if len(cp.Internal.CoverVars) > 0 {
			t.Cover.Vars = append(t.Cover.Vars, coverInfo{cp, cp.Internal.CoverVars})
		}
	}

	data, err := formatTestmain(t)
	if err != nil {
		return nil, nil, nil, err
	}
	pmain.Internal.TestmainGo = &data

	return pmain, ptest, pxtest, nil
}
300
301func testImportStack(top string, p *Package, target string) []string {
302	stk := []string{top, p.ImportPath}
303Search:
304	for p.ImportPath != target {
305		for _, p1 := range p.Internal.Imports {
306			if p1.ImportPath == target || str.Contains(p1.Deps, target) {
307				stk = append(stk, p1.ImportPath)
308				p = p1
309				continue Search
310			}
311		}
312		// Can't happen, but in case it does...
313		stk = append(stk, "<lost path to cycle>")
314		break
315	}
316	return stk
317}
318
// recompileForTest copies and replaces certain packages in pmain's dependency
// graph. This is necessary for two reasons. First, if ptest is different than
// preal, packages that import the package under test should get ptest instead
// of preal. This is particularly important if pxtest depends on functionality
// exposed in test sources in ptest. Second, if there is a main package
// (other than pmain) anywhere, we need to clear p.Internal.BuildInfo in
// the test copy to prevent link conflicts. This may happen if both -coverpkg
// and the command line patterns include multiple main packages.
//
// pmain and pxtest are modified in place. Every other package that must
// change is replaced by a copy with ForTest set to preal.ImportPath.
func recompileForTest(pmain, preal, ptest, pxtest *Package) {
	// The "test copy" of preal is ptest.
	// For each package that depends on preal, make a "test copy"
	// that depends on ptest. And so on, up the dependency tree.
	testCopy := map[*Package]*Package{preal: ptest}
	// NOTE(review): this relies on PackageList (defined elsewhere) listing
	// every package after its dependencies, so that testCopy already holds
	// any replacement for an import when its importer is visited — confirm.
	for _, p := range PackageList([]*Package{pmain}) {
		if p == preal {
			continue
		}
		// Copy on write.
		// pmain and pxtest are test-only packages and are edited in place.
		// Any other package is copied the first time it needs to change.
		// split then rebinds p to the copy, so later edits in this
		// iteration go to the copy.
		didSplit := p == pmain || p == pxtest
		split := func() {
			if didSplit {
				return
			}
			didSplit = true
			// A package that already has a test copy was split on an
			// earlier iteration; seeing it again is an internal error.
			if testCopy[p] != nil {
				panic("recompileForTest loop")
			}
			p1 := new(Package)
			testCopy[p] = p1
			*p1 = *p
			p1.ForTest = preal.ImportPath
			// Give the copy its own import slices so edits don't leak back
			// into the original package.
			p1.Internal.Imports = make([]*Package, len(p.Internal.Imports))
			copy(p1.Internal.Imports, p.Internal.Imports)
			p1.Imports = make([]string, len(p.Imports))
			copy(p1.Imports, p.Imports)
			p = p1
			// NOTE(review): Target is presumably cleared so the copy is
			// never installed in place of the original — confirm.
			// BuildInfo is cleared to avoid link conflicts (see above).
			p.Target = ""
			p.Internal.BuildInfo = ""
		}

		// Update p.Internal.Imports to use test copies.
		for i, imp := range p.Internal.Imports {
			if p1 := testCopy[imp]; p1 != nil && p1 != imp {
				split()
				p.Internal.Imports[i] = p1
			}
		}

		// Don't compile build info from a main package. This can happen
		// if -coverpkg patterns include main packages, since those packages
		// are imported by pmain. See golang.org/issue/30907.
		if p.Internal.BuildInfo != "" && p != pmain {
			split()
		}
	}
}
375
// isTestFunc reports whether fn has the shape of a testing function whose
// single parameter is a pointer to the type named arg ("B", "M" or "T").
// fn must return nothing and take exactly one parameter, named or unnamed,
// of the form *arg or *pkg.arg.
func isTestFunc(fn *ast.FuncDecl, arg string) bool {
	sig := fn.Type
	if res := sig.Results; res != nil && len(res.List) > 0 {
		return false
	}
	params := sig.Params.List
	if len(params) != 1 || len(params[0].Names) > 1 {
		return false
	}
	star, isPtr := params[0].Type.(*ast.StarExpr)
	if !isPtr {
		return false
	}
	// Whether the type is really testing.<arg> cannot be checked, because
	// the testing package may have been imported under any name. Accepting
	// *arg or *something.arg is the best approximation.
	switch elem := star.X.(type) {
	case *ast.Ident:
		return elem.Name == arg
	case *ast.SelectorExpr:
		return elem.Sel.Name == arg
	}
	return false
}
401
// isTest reports whether name looks like a test (or benchmark, according
// to prefix). The name must start with prefix and either end there or
// continue with a character that is not a lower-case letter, so that
// "TestFoo" and "Test" qualify but "TesticularCancer" does not.
func isTest(name, prefix string) bool {
	switch {
	case !strings.HasPrefix(name, prefix):
		return false
	case len(name) == len(prefix):
		// The bare prefix, such as "Test", is accepted.
		return true
	}
	first, _ := utf8.DecodeRuneInString(name[len(prefix):])
	return !unicode.IsLower(first)
}
415
// coverInfo pairs a package linked into the test binary with the coverage
// variables declared for it (its Internal.CoverVars). The generated testmain
// passes each variable's counters and positions to coverRegisterFile.
type coverInfo struct {
	Package *Package
	Vars    map[string]*CoverVar
}
420
421// loadTestFuncs returns the testFuncs describing the tests that will be run.
422func loadTestFuncs(ptest *Package) (*testFuncs, error) {
423	t := &testFuncs{
424		Package: ptest,
425	}
426	for _, file := range ptest.TestGoFiles {
427		if err := t.load(filepath.Join(ptest.Dir, file), "_test", &t.ImportTest, &t.NeedTest); err != nil {
428			return nil, err
429		}
430	}
431	for _, file := range ptest.XTestGoFiles {
432		if err := t.load(filepath.Join(ptest.Dir, file), "_xtest", &t.ImportXtest, &t.NeedXtest); err != nil {
433			return nil, err
434		}
435	}
436	return t, nil
437}
438
439// formatTestmain returns the content of the _testmain.go file for t.
440func formatTestmain(t *testFuncs) ([]byte, error) {
441	var buf bytes.Buffer
442	if err := testmainTmpl.Execute(&buf, t); err != nil {
443		return nil, err
444	}
445	return buf.Bytes(), nil
446}
447
// testFuncs is the data passed to testmainTmpl. It describes everything the
// generated _testmain.go needs to know about the package under test.
type testFuncs struct {
	// Tests, Benchmarks, and Examples are the functions listed in the
	// generated tests, benchmarks, and examples tables.
	Tests       []testFunc
	Benchmarks  []testFunc
	Examples    []testFunc
	// TestMain is the user's TestMain(m *testing.M), or nil if none.
	TestMain    *testFunc
	// Package is the package under test (ptest).
	Package     *Package
	// ImportTest reports whether testmain imports the package under test.
	// NeedTest reports whether that import is named _test rather than
	// being a blank import.
	ImportTest  bool
	NeedTest    bool
	// ImportXtest and NeedXtest are the same for the external test
	// package, imported as _xtest.
	ImportXtest bool
	NeedXtest   bool
	// Cover is the coverage configuration, or nil if coverage is off.
	Cover       *TestCover
}
460
461// ImportPath returns the import path of the package being tested, if it is within GOPATH.
462// This is printed by the testing package when running benchmarks.
463func (t *testFuncs) ImportPath() string {
464	pkg := t.Package.ImportPath
465	if strings.HasPrefix(pkg, "_/") {
466		return ""
467	}
468	if pkg == "command-line-arguments" {
469		return ""
470	}
471	return pkg
472}
473
474// Covered returns a string describing which packages are being tested for coverage.
475// If the covered package is the same as the tested package, it returns the empty string.
476// Otherwise it is a comma-separated human-readable list of packages beginning with
477// " in", ready for use in the coverage message.
478func (t *testFuncs) Covered() string {
479	if t.Cover == nil || t.Cover.Paths == nil {
480		return ""
481	}
482	return " in " + strings.Join(t.Cover.Paths, ", ")
483}
484
485// Tested returns the name of the package being tested.
486func (t *testFuncs) Tested() string {
487	return t.Package.Name
488}
489
// testFunc describes one function that the generated testmain refers to,
// qualified as Package.Name.
type testFunc struct {
	Package   string // imported package name (_test or _xtest)
	Name      string // function name
	Output    string // output, for examples
	Unordered bool   // output is allowed to be unordered.
}
496
// testFileSet is the file set shared by every test file parsed in
// (*testFuncs).load. checkTestFunc uses it to turn a declaration's Pos
// into a printable position for error messages.
var testFileSet = token.NewFileSet()
498
499func (t *testFuncs) load(filename, pkg string, doImport, seen *bool) error {
500	f, err := parser.ParseFile(testFileSet, filename, nil, parser.ParseComments)
501	if err != nil {
502		return base.ExpandScanner(err)
503	}
504	for _, d := range f.Decls {
505		n, ok := d.(*ast.FuncDecl)
506		if !ok {
507			continue
508		}
509		if n.Recv != nil {
510			continue
511		}
512		name := n.Name.String()
513		switch {
514		case name == "TestMain":
515			if isTestFunc(n, "T") {
516				t.Tests = append(t.Tests, testFunc{pkg, name, "", false})
517				*doImport, *seen = true, true
518				continue
519			}
520			err := checkTestFunc(n, "M")
521			if err != nil {
522				return err
523			}
524			if t.TestMain != nil {
525				return errors.New("multiple definitions of TestMain")
526			}
527			t.TestMain = &testFunc{pkg, name, "", false}
528			*doImport, *seen = true, true
529		case isTest(name, "Test"):
530			err := checkTestFunc(n, "T")
531			if err != nil {
532				return err
533			}
534			t.Tests = append(t.Tests, testFunc{pkg, name, "", false})
535			*doImport, *seen = true, true
536		case isTest(name, "Benchmark"):
537			err := checkTestFunc(n, "B")
538			if err != nil {
539				return err
540			}
541			t.Benchmarks = append(t.Benchmarks, testFunc{pkg, name, "", false})
542			*doImport, *seen = true, true
543		}
544	}
545	ex := doc.Examples(f)
546	sort.Slice(ex, func(i, j int) bool { return ex[i].Order < ex[j].Order })
547	for _, e := range ex {
548		*doImport = true // import test file whether executed or not
549		if e.Output == "" && !e.EmptyOutput {
550			// Don't run examples with no output.
551			continue
552		}
553		t.Examples = append(t.Examples, testFunc{pkg, "Example" + e.Name, e.Output, e.Unordered})
554		*seen = true
555	}
556	return nil
557}
558
559func checkTestFunc(fn *ast.FuncDecl, arg string) error {
560	if !isTestFunc(fn, arg) {
561		name := fn.Name.String()
562		pos := testFileSet.Position(fn.Pos())
563		return fmt.Errorf("%s: wrong signature for %s, must be: func %s(%s *testing.%s)", pos, name, name, strings.ToLower(arg), arg)
564	}
565	return nil
566}
567
// testmainTmpl is the template for _testmain.go, the package main of a test
// binary. formatTestmain executes it with a *testFuncs.
//
// The generated program does the following:
//   - It imports the package under test as _test and the external test
//     package as _xtest. Each import is blank instead when NeedTest or
//     NeedXtest is false.
//   - When .Cover is set, it imports each covered package as _cover<i> and
//     registers its coverage counters.
//   - It builds the tests, benchmarks, and examples tables.
//   - It calls the user's TestMain, or runs m.Run and exits with its result.
//
// The template text is emitted verbatim as Go source, so it must not be
// edited casually.
var testmainTmpl = template.Must(template.New("main").Parse(`
package main

import (
{{if not .TestMain}}
	"os"
{{end}}
	"testing"
	"testing/internal/testdeps"

{{if .ImportTest}}
	{{if .NeedTest}}_test{{else}}_{{end}} {{.Package.ImportPath | printf "%q"}}
{{end}}
{{if .ImportXtest}}
	{{if .NeedXtest}}_xtest{{else}}_{{end}} {{.Package.ImportPath | printf "%s_test" | printf "%q"}}
{{end}}
{{if .Cover}}
{{range $i, $p := .Cover.Vars}}
	_cover{{$i}} {{$p.Package.ImportPath | printf "%q"}}
{{end}}
{{end}}
)

var tests = []testing.InternalTest{
{{range .Tests}}
	{"{{.Name}}", {{.Package}}.{{.Name}}},
{{end}}
}

var benchmarks = []testing.InternalBenchmark{
{{range .Benchmarks}}
	{"{{.Name}}", {{.Package}}.{{.Name}}},
{{end}}
}

var examples = []testing.InternalExample{
{{range .Examples}}
	{"{{.Name}}", {{.Package}}.{{.Name}}, {{.Output | printf "%q"}}, {{.Unordered}}},
{{end}}
}

func init() {
	testdeps.ImportPath = {{.ImportPath | printf "%q"}}
}

{{if .Cover}}

// Only updated by init functions, so no need for atomicity.
var (
	coverCounters = make(map[string][]uint32)
	coverBlocks = make(map[string][]testing.CoverBlock)
)

func init() {
	{{range $i, $p := .Cover.Vars}}
	{{range $file, $cover := $p.Vars}}
	coverRegisterFile({{printf "%q" $cover.File}}, _cover{{$i}}.{{$cover.Var}}.Count[:], _cover{{$i}}.{{$cover.Var}}.Pos[:], _cover{{$i}}.{{$cover.Var}}.NumStmt[:])
	{{end}}
	{{end}}
}

func coverRegisterFile(fileName string, counter []uint32, pos []uint32, numStmts []uint16) {
	if 3*len(counter) != len(pos) || len(counter) != len(numStmts) {
		panic("coverage: mismatched sizes")
	}
	if coverCounters[fileName] != nil {
		// Already registered.
		return
	}
	coverCounters[fileName] = counter
	block := make([]testing.CoverBlock, len(counter))
	for i := range counter {
		block[i] = testing.CoverBlock{
			Line0: pos[3*i+0],
			Col0: uint16(pos[3*i+2]),
			Line1: pos[3*i+1],
			Col1: uint16(pos[3*i+2]>>16),
			Stmts: numStmts[i],
		}
	}
	coverBlocks[fileName] = block
}
{{end}}

func main() {
{{if .Cover}}
	testing.RegisterCover(testing.Cover{
		Mode: {{printf "%q" .Cover.Mode}},
		Counters: coverCounters,
		Blocks: coverBlocks,
		CoveredPackages: {{printf "%q" .Covered}},
	})
{{end}}
	m := testing.MainStart(testdeps.TestDeps{}, tests, benchmarks, examples)
{{with .TestMain}}
	{{.Package}}.{{.Name}}(m)
{{else}}
	os.Exit(m.Run())
{{end}}
}

`))
670