1package main
2
3import (
4	"fmt"
5	"reflect"
6	"regexp"
7	"strings"
8	"testing"
9
10	"github.com/golang/mock/mockgen/model"
11)
12
13func TestMakeArgString(t *testing.T) {
14	testCases := []struct {
15		argNames  []string
16		argTypes  []string
17		argString string
18	}{
19		{
20			argNames:  nil,
21			argTypes:  nil,
22			argString: "",
23		},
24		{
25			argNames:  []string{"arg0"},
26			argTypes:  []string{"int"},
27			argString: "arg0 int",
28		},
29		{
30			argNames:  []string{"arg0", "arg1"},
31			argTypes:  []string{"int", "bool"},
32			argString: "arg0 int, arg1 bool",
33		},
34		{
35			argNames:  []string{"arg0", "arg1"},
36			argTypes:  []string{"int", "int"},
37			argString: "arg0, arg1 int",
38		},
39		{
40			argNames:  []string{"arg0", "arg1", "arg2"},
41			argTypes:  []string{"bool", "int", "int"},
42			argString: "arg0 bool, arg1, arg2 int",
43		},
44		{
45			argNames:  []string{"arg0", "arg1", "arg2"},
46			argTypes:  []string{"int", "bool", "int"},
47			argString: "arg0 int, arg1 bool, arg2 int",
48		},
49		{
50			argNames:  []string{"arg0", "arg1", "arg2"},
51			argTypes:  []string{"int", "int", "bool"},
52			argString: "arg0, arg1 int, arg2 bool",
53		},
54		{
55			argNames:  []string{"arg0", "arg1", "arg2"},
56			argTypes:  []string{"int", "int", "int"},
57			argString: "arg0, arg1, arg2 int",
58		},
59		{
60			argNames:  []string{"arg0", "arg1", "arg2", "arg3"},
61			argTypes:  []string{"bool", "int", "int", "int"},
62			argString: "arg0 bool, arg1, arg2, arg3 int",
63		},
64		{
65			argNames:  []string{"arg0", "arg1", "arg2", "arg3"},
66			argTypes:  []string{"int", "bool", "int", "int"},
67			argString: "arg0 int, arg1 bool, arg2, arg3 int",
68		},
69		{
70			argNames:  []string{"arg0", "arg1", "arg2", "arg3"},
71			argTypes:  []string{"int", "int", "bool", "int"},
72			argString: "arg0, arg1 int, arg2 bool, arg3 int",
73		},
74		{
75			argNames:  []string{"arg0", "arg1", "arg2", "arg3"},
76			argTypes:  []string{"int", "int", "int", "bool"},
77			argString: "arg0, arg1, arg2 int, arg3 bool",
78		},
79		{
80			argNames:  []string{"arg0", "arg1", "arg2", "arg3", "arg4"},
81			argTypes:  []string{"bool", "int", "int", "int", "bool"},
82			argString: "arg0 bool, arg1, arg2, arg3 int, arg4 bool",
83		},
84		{
85			argNames:  []string{"arg0", "arg1", "arg2", "arg3", "arg4"},
86			argTypes:  []string{"int", "bool", "int", "int", "bool"},
87			argString: "arg0 int, arg1 bool, arg2, arg3 int, arg4 bool",
88		},
89		{
90			argNames:  []string{"arg0", "arg1", "arg2", "arg3", "arg4"},
91			argTypes:  []string{"int", "int", "bool", "int", "bool"},
92			argString: "arg0, arg1 int, arg2 bool, arg3 int, arg4 bool",
93		},
94		{
95			argNames:  []string{"arg0", "arg1", "arg2", "arg3", "arg4"},
96			argTypes:  []string{"int", "int", "int", "bool", "bool"},
97			argString: "arg0, arg1, arg2 int, arg3, arg4 bool",
98		},
99		{
100			argNames:  []string{"arg0", "arg1", "arg2", "arg3", "arg4"},
101			argTypes:  []string{"int", "int", "bool", "bool", "int"},
102			argString: "arg0, arg1 int, arg2, arg3 bool, arg4 int",
103		},
104	}
105
106	for i, tc := range testCases {
107		t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
108			s := makeArgString(tc.argNames, tc.argTypes)
109			if s != tc.argString {
110				t.Errorf("result == %q, want %q", s, tc.argString)
111			}
112		})
113	}
114}
115
116func TestNewIdentifierAllocator(t *testing.T) {
117	a := newIdentifierAllocator([]string{"taken1", "taken2"})
118	if len(a) != 2 {
119		t.Fatalf("expected 2 items, got %v", len(a))
120	}
121
122	_, ok := a["taken1"]
123	if !ok {
124		t.Errorf("allocator doesn't contain 'taken1': %#v", a)
125	}
126
127	_, ok = a["taken2"]
128	if !ok {
129		t.Errorf("allocator doesn't contain 'taken2': %#v", a)
130	}
131}
132
133func allocatorContainsIdentifiers(a identifierAllocator, ids []string) bool {
134	if len(a) != len(ids) {
135		return false
136	}
137
138	for _, id := range ids {
139		_, ok := a[id]
140		if !ok {
141			return false
142		}
143	}
144
145	return true
146}
147
148func TestIdentifierAllocator_allocateIdentifier(t *testing.T) {
149	a := newIdentifierAllocator([]string{"taken"})
150
151	t2 := a.allocateIdentifier("taken_2")
152	if t2 != "taken_2" {
153		t.Fatalf("expected 'taken_2', got %q", t2)
154	}
155	expected := []string{"taken", "taken_2"}
156	if !allocatorContainsIdentifiers(a, expected) {
157		t.Fatalf("allocator doesn't contain the expected items - allocator: %#v, expected items: %#v", a, expected)
158	}
159
160	t3 := a.allocateIdentifier("taken")
161	if t3 != "taken_3" {
162		t.Fatalf("expected 'taken_3', got %q", t3)
163	}
164	expected = []string{"taken", "taken_2", "taken_3"}
165	if !allocatorContainsIdentifiers(a, expected) {
166		t.Fatalf("allocator doesn't contain the expected items - allocator: %#v, expected items: %#v", a, expected)
167	}
168
169	t4 := a.allocateIdentifier("taken")
170	if t4 != "taken_4" {
171		t.Fatalf("expected 'taken_4', got %q", t4)
172	}
173	expected = []string{"taken", "taken_2", "taken_3", "taken_4"}
174	if !allocatorContainsIdentifiers(a, expected) {
175		t.Fatalf("allocator doesn't contain the expected items - allocator: %#v, expected items: %#v", a, expected)
176	}
177
178	id := a.allocateIdentifier("id")
179	if id != "id" {
180		t.Fatalf("expected 'id', got %q", id)
181	}
182	expected = []string{"taken", "taken_2", "taken_3", "taken_4", "id"}
183	if !allocatorContainsIdentifiers(a, expected) {
184		t.Fatalf("allocator doesn't contain the expected items - allocator: %#v, expected items: %#v", a, expected)
185	}
186}
187
188func TestGenerateMockInterface_Helper(t *testing.T) {
189	for _, test := range []struct {
190		Name       string
191		Identifier string
192		HelperLine string
193		Methods    []*model.Method
194	}{
195		{Name: "mock", Identifier: "MockSomename", HelperLine: "m.ctrl.T.Helper()"},
196		{Name: "recorder", Identifier: "MockSomenameMockRecorder", HelperLine: "mr.mock.ctrl.T.Helper()"},
197		{
198			Name:       "mock identifier conflict",
199			Identifier: "MockSomename",
200			HelperLine: "m_2.ctrl.T.Helper()",
201			Methods: []*model.Method{
202				{
203					Name: "MethodA",
204					In: []*model.Parameter{
205						{
206							Name: "m",
207							Type: &model.NamedType{Type: "int"},
208						},
209					},
210				},
211			},
212		},
213		{
214			Name:       "recorder identifier conflict",
215			Identifier: "MockSomenameMockRecorder",
216			HelperLine: "mr_2.mock.ctrl.T.Helper()",
217			Methods: []*model.Method{
218				{
219					Name: "MethodA",
220					In: []*model.Parameter{
221						{
222							Name: "mr",
223							Type: &model.NamedType{Type: "int"},
224						},
225					},
226				},
227			},
228		},
229	} {
230		t.Run(test.Name, func(t *testing.T) {
231			g := generator{}
232
233			if len(test.Methods) == 0 {
234				test.Methods = []*model.Method{
235					{Name: "MethodA"},
236					{Name: "MethodB"},
237				}
238			}
239
240			intf := &model.Interface{Name: "Somename"}
241			for _, m := range test.Methods {
242				intf.AddMethod(m)
243			}
244
245			if err := g.GenerateMockInterface(intf, "somepackage"); err != nil {
246				t.Fatal(err)
247			}
248
249			lines := strings.Split(g.buf.String(), "\n")
250
251			// T.Helper() should be the first line
252			for _, method := range test.Methods {
253				if strings.TrimSpace(lines[findMethod(t, test.Identifier, method.Name, lines)+1]) != test.HelperLine {
254					t.Fatalf("method %s.%s did not declare itself a Helper method", test.Identifier, method.Name)
255				}
256			}
257		})
258	}
259}
260
// findMethod returns the index within lines of the declaration of method
// methodName on a receiver whose type text ends in identifier, such as
// "func (m *MockFoo) MethodA(". It fails the test via t.Fatalf if no line
// matches.
//
// Both names are regexp-quoted so they are matched literally. The method name
// must be followed by "(" so that looking up "MethodA" does not match a
// declaration of "MethodAB".
func findMethod(t *testing.T, identifier, methodName string, lines []string) int {
	t.Helper()
	r := regexp.MustCompile(fmt.Sprintf(`func\s+\(.+%s\)\s*%s\s*\(`,
		regexp.QuoteMeta(identifier), regexp.QuoteMeta(methodName)))
	for i, line := range lines {
		if r.MatchString(line) {
			return i
		}
	}

	t.Fatalf("unable to find 'func (m %s) %s'", identifier, methodName)
	// t.Fatalf stops the goroutine; this only satisfies the compiler.
	panic("unreachable")
}
273
274func TestGetArgNames(t *testing.T) {
275	for _, testCase := range []struct {
276		name     string
277		method   *model.Method
278		expected []string
279	}{
280		{
281			name: "NamedArg",
282			method: &model.Method{
283				In: []*model.Parameter{
284					{
285						Name: "firstArg",
286						Type: &model.NamedType{Type: "int"},
287					},
288					{
289						Name: "secondArg",
290						Type: &model.NamedType{Type: "string"},
291					},
292				},
293			},
294			expected: []string{"firstArg", "secondArg"},
295		},
296		{
297			name: "NotNamedArg",
298			method: &model.Method{
299				In: []*model.Parameter{
300					{
301						Name: "",
302						Type: &model.NamedType{Type: "int"},
303					},
304					{
305						Name: "",
306						Type: &model.NamedType{Type: "string"},
307					},
308				},
309			},
310			expected: []string{"arg0", "arg1"},
311		},
312		{
313			name: "MixedNameArg",
314			method: &model.Method{
315				In: []*model.Parameter{
316					{
317						Name: "firstArg",
318						Type: &model.NamedType{Type: "int"},
319					},
320					{
321						Name: "_",
322						Type: &model.NamedType{Type: "string"},
323					},
324				},
325			},
326			expected: []string{"firstArg", "arg1"},
327		},
328	} {
329		t.Run(testCase.name, func(t *testing.T) {
330			g := generator{}
331
332			result := g.getArgNames(testCase.method)
333			if !reflect.DeepEqual(result, testCase.expected) {
334				t.Fatalf("expected %s, got %s", result, testCase.expected)
335			}
336		})
337	}
338}
339
340func Test_createPackageMap(t *testing.T) {
341	tests := []struct {
342		name            string
343		importPath      string
344		wantPackageName string
345		wantOK          bool
346	}{
347		{"golang package", "context", "context", true},
348		{"third party", "golang.org/x/tools/present", "present", true},
349	}
350	var importPaths []string
351	for _, t := range tests {
352		importPaths = append(importPaths, t.importPath)
353	}
354	packages := createPackageMap(importPaths)
355	for _, tt := range tests {
356		t.Run(tt.name, func(t *testing.T) {
357			gotPackageName, gotOk := packages[tt.importPath]
358			if gotPackageName != tt.wantPackageName {
359				t.Errorf("createPackageMap() gotPackageName = %v, wantPackageName = %v", gotPackageName, tt.wantPackageName)
360			}
361			if gotOk != tt.wantOK {
362				t.Errorf("createPackageMap() gotOk = %v, wantOK = %v", gotOk, tt.wantOK)
363			}
364		})
365	}
366}
367