1package main
2
3import (
4	"go/ast"
5	"go/parser"
6	"go/token"
7	"io/ioutil"
8	"os"
9	"path/filepath"
10	"strings"
11	"testing"
12)
13
14func TestFileParser_ParseFile(t *testing.T) {
15	fs := token.NewFileSet()
16	file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
17	if err != nil {
18		t.Fatalf("Unexpected error: %v", err)
19	}
20
21	p := fileParser{
22		fileSet:            fs,
23		imports:            make(map[string]importedPackage),
24		importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
25	}
26
27	pkg, err := p.parseFile("", file)
28	if err != nil {
29		t.Fatalf("Unexpected error: %v", err)
30	}
31
32	checkGreeterImports(t, p.imports)
33
34	expectedName := "greeter"
35	if pkg.Name != expectedName {
36		t.Fatalf("Expected name to be %v but got %v", expectedName, pkg.Name)
37	}
38
39	expectedInterfaceName := "InputMaker"
40	if pkg.Interfaces[0].Name != expectedInterfaceName {
41		t.Fatalf("Expected interface name to be %v but got %v", expectedInterfaceName, pkg.Interfaces[0].Name)
42	}
43}
44
45func TestFileParser_ParsePackage(t *testing.T) {
46	fs := token.NewFileSet()
47	_, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
48	if err != nil {
49		t.Fatalf("Unexpected error: %v", err)
50	}
51
52	p := fileParser{
53		fileSet:            fs,
54		imports:            make(map[string]importedPackage),
55		importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
56	}
57
58	newP, err := p.parsePackage("github.com/golang/mock/mockgen/internal/tests/custom_package_name/greeter")
59	if err != nil {
60		t.Fatalf("Unexpected error: %v", err)
61	}
62
63	checkGreeterImports(t, newP.imports)
64}
65
66func TestImportsOfFile(t *testing.T) {
67	fs := token.NewFileSet()
68	file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
69	if err != nil {
70		t.Fatalf("Unexpected error: %v", err)
71	}
72
73	imports, _ := importsOfFile(file)
74	checkGreeterImports(t, imports)
75}
76
77func checkGreeterImports(t *testing.T, imports map[string]importedPackage) {
78	// check that imports have stdlib package "fmt"
79	if fmtPackage, ok := imports["fmt"]; !ok {
80		t.Errorf("Expected imports to have key \"fmt\"")
81	} else {
82		expectedFmtPackage := "fmt"
83		if fmtPackage.Path() != expectedFmtPackage {
84			t.Errorf("Expected fmt key to have value %s but got %s", expectedFmtPackage, fmtPackage.Path())
85		}
86	}
87
88	// check that imports have package named "validator"
89	if validatorPackage, ok := imports["validator"]; !ok {
90		t.Errorf("Expected imports to have key \"fmt\"")
91	} else {
92		expectedValidatorPackage := "github.com/golang/mock/mockgen/internal/tests/custom_package_name/validator"
93		if validatorPackage.Path() != expectedValidatorPackage {
94			t.Errorf("Expected validator key to have value %s but got %s", expectedValidatorPackage, validatorPackage.Path())
95		}
96	}
97
98	// check that imports have package named "client"
99	if clientPackage, ok := imports["client"]; !ok {
100		t.Errorf("Expected imports to have key \"client\"")
101	} else {
102		expectedClientPackage := "github.com/golang/mock/mockgen/internal/tests/custom_package_name/client/v1"
103		if clientPackage.Path() != expectedClientPackage {
104			t.Errorf("Expected client key to have value %s but got %s", expectedClientPackage, clientPackage.Path())
105		}
106	}
107
108	// check that imports don't have package named "v1"
109	if _, ok := imports["v1"]; ok {
110		t.Errorf("Expected import not to have key \"v1\"")
111	}
112}
113
114func Benchmark_parseFile(b *testing.B) {
115	source := "internal/tests/performance/big_interface/big_interface.go"
116	for n := 0; n < b.N; n++ {
117		sourceMode(source)
118	}
119}
120
121func TestParsePackageImport(t *testing.T) {
122	testRoot, err := ioutil.TempDir("", "test_root")
123	if err != nil {
124		t.Fatal("cannot create tempdir")
125	}
126	defer func() {
127		if err = os.RemoveAll(testRoot); err != nil {
128			t.Errorf("cannot clean up tempdir at %s: %v", testRoot, err)
129		}
130	}()
131	barDir := filepath.Join(testRoot, "gomod/bar")
132	if err = os.MkdirAll(barDir, 0755); err != nil {
133		t.Fatalf("error creating %s: %v", barDir, err)
134	}
135	if err = ioutil.WriteFile(filepath.Join(barDir, "bar.go"), []byte("package bar"), 0644); err != nil {
136		t.Fatalf("error creating gomod/bar/bar.go: %v", err)
137	}
138	if err = ioutil.WriteFile(filepath.Join(testRoot, "gomod/go.mod"), []byte("module github.com/golang/foo"), 0644); err != nil {
139		t.Fatalf("error creating gomod/go.mod: %v", err)
140	}
141	goPath := filepath.Join(testRoot, "gopath")
142	for _, testCase := range []struct {
143		name    string
144		envs    map[string]string
145		dir     string
146		pkgPath string
147		err     error
148	}{
149		{
150			name:    "go mod default",
151			envs:    map[string]string{"GO111MODULE": ""},
152			dir:     barDir,
153			pkgPath: "github.com/golang/foo/bar",
154		},
155		{
156			name:    "go mod off",
157			envs:    map[string]string{"GO111MODULE": "off", "GOPATH": goPath},
158			dir:     filepath.Join(testRoot, "gopath/src/example.com/foo"),
159			pkgPath: "example.com/foo",
160		},
161		{
162			name: "outside GOPATH",
163			envs: map[string]string{"GO111MODULE": "off", "GOPATH": goPath},
164			dir:  "testdata",
165			err:  errOutsideGoPath,
166		},
167	} {
168		t.Run(testCase.name, func(t *testing.T) {
169			for key, value := range testCase.envs {
170				if err := os.Setenv(key, value); err != nil {
171					t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
172				}
173			}
174			pkgPath, err := parsePackageImport(filepath.Clean(testCase.dir))
175			if err != testCase.err {
176				t.Errorf("expect %v, got %v", testCase.err, err)
177			}
178			if pkgPath != testCase.pkgPath {
179				t.Errorf("expect %s, got %s", testCase.pkgPath, pkgPath)
180			}
181		})
182	}
183}
184
185func TestParsePackageImport_FallbackGoPath(t *testing.T) {
186	goPath, err := ioutil.TempDir("", "gopath")
187	if err != nil {
188		t.Error(err)
189	}
190	defer func() {
191		if err = os.RemoveAll(goPath); err != nil {
192			t.Error(err)
193		}
194	}()
195	srcDir := filepath.Join(goPath, "src/example.com/foo")
196	err = os.MkdirAll(srcDir, 0755)
197	if err != nil {
198		t.Error(err)
199	}
200	key := "GOPATH"
201	value := goPath
202	if err := os.Setenv(key, value); err != nil {
203		t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
204	}
205	key = "GO111MODULE"
206	value = "on"
207	if err := os.Setenv(key, value); err != nil {
208		t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
209	}
210	pkgPath, err := parsePackageImport(srcDir)
211	expected := "example.com/foo"
212	if pkgPath != expected {
213		t.Errorf("expect %s, got %s", expected, pkgPath)
214	}
215}
216
217func TestParsePackageImport_FallbackMultiGoPath(t *testing.T) {
218	var goPathList []string
219
220	// first gopath
221	goPath, err := ioutil.TempDir("", "gopath1")
222	if err != nil {
223		t.Error(err)
224	}
225	goPathList = append(goPathList, goPath)
226	defer func() {
227		if err = os.RemoveAll(goPath); err != nil {
228			t.Error(err)
229		}
230	}()
231	srcDir := filepath.Join(goPath, "src/example.com/foo")
232	err = os.MkdirAll(srcDir, 0755)
233	if err != nil {
234		t.Error(err)
235	}
236
237	// second gopath
238	goPath, err = ioutil.TempDir("", "gopath2")
239	if err != nil {
240		t.Error(err)
241	}
242	goPathList = append(goPathList, goPath)
243	defer func() {
244		if err = os.RemoveAll(goPath); err != nil {
245			t.Error(err)
246		}
247	}()
248
249	goPaths := strings.Join(goPathList, string(os.PathListSeparator))
250	key := "GOPATH"
251	value := goPaths
252	if err := os.Setenv(key, value); err != nil {
253		t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
254	}
255	key = "GO111MODULE"
256	value = "on"
257	if err := os.Setenv(key, value); err != nil {
258		t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
259	}
260	pkgPath, err := parsePackageImport(srcDir)
261	expected := "example.com/foo"
262	if pkgPath != expected {
263		t.Errorf("expect %s, got %s", expected, pkgPath)
264	}
265}
266