1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package main
6
7import (
8	"bytes"
9	"fmt"
10	"go/parser"
11	"go/token"
12	"html/template"
13	"strings"
14	"testing"
15)
16
17func TestImports(t *testing.T) {
18	t.Run("importName", func(t *testing.T) {
19		cases := []struct {
20			src   string
21			ident string
22		}{
23			{`"syscall"`, "syscall"},
24			{`. "foobar"`, "."},
25			{`"go/ast"`, "ast"},
26			{`moo "go/format"`, "moo"},
27			{`. "go/token"`, "."},
28			{`"golang.org/x/sys/unix"`, "unix"},
29			{`nix "golang.org/x/sys/unix"`, "nix"},
30			{`_ "golang.org/x/sys/unix"`, "_"},
31		}
32
33		for _, c := range cases {
34			pkgSrc := fmt.Sprintf("package main\nimport %s", c.src)
35
36			f, err := parser.ParseFile(token.NewFileSet(), "", pkgSrc, parser.ImportsOnly)
37			if err != nil {
38				t.Error(err)
39				continue
40			}
41			if len(f.Imports) != 1 {
42				t.Errorf("Got %d imports, expected 1", len(f.Imports))
43				continue
44			}
45
46			got, err := importName(f.Imports[0])
47			if err != nil {
48				t.Fatal(err)
49			}
50			if got != c.ident {
51				t.Errorf("Got %q, expected %q", got, c.ident)
52			}
53		}
54	})
55
56	t.Run("filterImports", func(t *testing.T) {
57		cases := []struct{ before, after string }{
58			{`package test
59
60			import (
61				"foo"
62				"bar"
63			)`,
64				"package test\n"},
65			{`package test
66
67			import (
68				"foo"
69				"bar"
70			)
71
72			func useFoo() { foo.Usage() }`,
73				`package test
74
75import (
76	"foo"
77)
78
79func useFoo() { foo.Usage() }
80`},
81		}
82		for _, c := range cases {
83			got, err := filterImports([]byte(c.before))
84			if err != nil {
85				t.Error(err)
86			}
87
88			if string(got) != c.after {
89				t.Errorf("Got:\n%s\nExpected:\n%s\n", got, c.after)
90			}
91		}
92	})
93}
94
95func TestMerge(t *testing.T) {
96	// Input architecture files
97	inTmpl := template.Must(template.New("input").Parse(`
98// Package comments
99
100// build directives for arch{{.}}
101
102//go:build goos && arch{{.}}
103// +build goos,arch{{.}}
104
105package main
106
107/*
108#include <stdint.h>
109#include <stddef.h>
110int utimes(uintptr_t, uintptr_t);
111int utimensat(int, uintptr_t, uintptr_t, int);
112*/
113import "C"
114
115// The imports
116import (
117	"commonDep"
118	"uniqueDep{{.}}"
119)
120
121// Vars
122var (
123	commonVar = commonDep.Use("common")
124
125	uniqueVar{{.}} = "unique{{.}}"
126)
127
128// Common free standing comment
129
130// Common comment
131const COMMON_INDEPENDENT = 1234
132const UNIQUE_INDEPENDENT_{{.}} = "UNIQUE_INDEPENDENT_{{.}}"
133
134// Group comment
135const (
136	COMMON_GROUP = "COMMON_GROUP"
137	UNIQUE_GROUP_{{.}} = "UNIQUE_GROUP_{{.}}"
138)
139
140// Group2 comment
141const (
142	UNIQUE_GROUP21_{{.}} = "UNIQUE_GROUP21_{{.}}"
143	UNIQUE_GROUP22_{{.}} = "UNIQUE_GROUP22_{{.}}"
144)
145
146// Group3 comment
147const (
148	sub1Common1 = 11
149	sub1Unique2{{.}} = 12
150	sub1Common3_LONG = 13
151
152	sub2Unique1{{.}} = 21
153	sub2Common2 = 22
154	sub2Common3 = 23
155	sub2Unique4{{.}} = 24
156)
157
158type commonInt int
159
160type uniqueInt{{.}} int
161
162func commonF() string {
163	return commonDep.Use("common")
164	}
165
166func uniqueF() string {
167	C.utimes(0, 0)
168	return uniqueDep{{.}}.Use("{{.}}")
169	}
170
171// Group4 comment
172const (
173	sub3Common1 = 31
174	sub3Unique2{{.}} = 32
175	sub3Unique3{{.}} = 33
176	sub3Common4 = 34
177
178	sub4Common1, sub4Unique2{{.}} = 41, 42
179	sub4Unique3{{.}}, sub4Common4 = 43, 44
180)
181`))
182
183	// Filtered architecture files
184	outTmpl := template.Must(template.New("output").Parse(`// Package comments
185
186// build directives for arch{{.}}
187
188//go:build goos && arch{{.}}
189// +build goos,arch{{.}}
190
191package main
192
193/*
194#include <stdint.h>
195#include <stddef.h>
196int utimes(uintptr_t, uintptr_t);
197int utimensat(int, uintptr_t, uintptr_t, int);
198*/
199import "C"
200
201// The imports
202import (
203	"commonDep"
204	"uniqueDep{{.}}"
205)
206
207// Vars
208var (
209	commonVar = commonDep.Use("common")
210
211	uniqueVar{{.}} = "unique{{.}}"
212)
213
214const UNIQUE_INDEPENDENT_{{.}} = "UNIQUE_INDEPENDENT_{{.}}"
215
216// Group comment
217const (
218	UNIQUE_GROUP_{{.}} = "UNIQUE_GROUP_{{.}}"
219)
220
221// Group2 comment
222const (
223	UNIQUE_GROUP21_{{.}} = "UNIQUE_GROUP21_{{.}}"
224	UNIQUE_GROUP22_{{.}} = "UNIQUE_GROUP22_{{.}}"
225)
226
227// Group3 comment
228const (
229	sub1Unique2{{.}} = 12
230
231	sub2Unique1{{.}} = 21
232	sub2Unique4{{.}} = 24
233)
234
235type uniqueInt{{.}} int
236
237func uniqueF() string {
238	C.utimes(0, 0)
239	return uniqueDep{{.}}.Use("{{.}}")
240}
241
242// Group4 comment
243const (
244	sub3Unique2{{.}} = 32
245	sub3Unique3{{.}} = 33
246
247	sub4Common1, sub4Unique2{{.}} = 41, 42
248	sub4Unique3{{.}}, sub4Common4 = 43, 44
249)
250`))
251
252	const mergedFile = `// Package comments
253
254package main
255
256// The imports
257import (
258	"commonDep"
259)
260
261// Common free standing comment
262
263// Common comment
264const COMMON_INDEPENDENT = 1234
265
266// Group comment
267const (
268	COMMON_GROUP = "COMMON_GROUP"
269)
270
271// Group3 comment
272const (
273	sub1Common1      = 11
274	sub1Common3_LONG = 13
275
276	sub2Common2 = 22
277	sub2Common3 = 23
278)
279
280type commonInt int
281
282func commonF() string {
283	return commonDep.Use("common")
284}
285
286// Group4 comment
287const (
288	sub3Common1 = 31
289	sub3Common4 = 34
290)
291`
292
293	// Generate source code for different "architectures"
294	var inFiles, outFiles []srcFile
295	for _, arch := range strings.Fields("A B C D") {
296		buf := new(bytes.Buffer)
297		err := inTmpl.Execute(buf, arch)
298		if err != nil {
299			t.Fatal(err)
300		}
301		inFiles = append(inFiles, srcFile{"file" + arch, buf.Bytes()})
302
303		buf = new(bytes.Buffer)
304		err = outTmpl.Execute(buf, arch)
305		if err != nil {
306			t.Fatal(err)
307		}
308		outFiles = append(outFiles, srcFile{"file" + arch, buf.Bytes()})
309	}
310
311	t.Run("getCodeSet", func(t *testing.T) {
312		got, err := getCodeSet(inFiles[0].src)
313		if err != nil {
314			t.Fatal(err)
315		}
316
317		expectedElems := []codeElem{
318			{token.COMMENT, "Package comments\n"},
319			{token.COMMENT, "build directives for archA\n"},
320			{token.COMMENT, "+build goos,archA\n"},
321			{token.CONST, `COMMON_INDEPENDENT = 1234`},
322			{token.CONST, `UNIQUE_INDEPENDENT_A = "UNIQUE_INDEPENDENT_A"`},
323			{token.CONST, `COMMON_GROUP = "COMMON_GROUP"`},
324			{token.CONST, `UNIQUE_GROUP_A = "UNIQUE_GROUP_A"`},
325			{token.CONST, `UNIQUE_GROUP21_A = "UNIQUE_GROUP21_A"`},
326			{token.CONST, `UNIQUE_GROUP22_A = "UNIQUE_GROUP22_A"`},
327			{token.CONST, `sub1Common1 = 11`},
328			{token.CONST, `sub1Unique2A = 12`},
329			{token.CONST, `sub1Common3_LONG = 13`},
330			{token.CONST, `sub2Unique1A = 21`},
331			{token.CONST, `sub2Common2 = 22`},
332			{token.CONST, `sub2Common3 = 23`},
333			{token.CONST, `sub2Unique4A = 24`},
334			{token.CONST, `sub3Common1 = 31`},
335			{token.CONST, `sub3Unique2A = 32`},
336			{token.CONST, `sub3Unique3A = 33`},
337			{token.CONST, `sub3Common4 = 34`},
338			{token.CONST, `sub4Common1, sub4Unique2A = 41, 42`},
339			{token.CONST, `sub4Unique3A, sub4Common4 = 43, 44`},
340			{token.TYPE, `commonInt int`},
341			{token.TYPE, `uniqueIntA int`},
342			{token.FUNC, `func commonF() string {
343	return commonDep.Use("common")
344}`},
345			{token.FUNC, `func uniqueF() string {
346	C.utimes(0, 0)
347	return uniqueDepA.Use("A")
348}`},
349		}
350		expected := newCodeSet()
351		for _, d := range expectedElems {
352			expected.add(d)
353		}
354
355		if len(got.set) != len(expected.set) {
356			t.Errorf("Got %d codeElems, expected %d", len(got.set), len(expected.set))
357		}
358		for expElem := range expected.set {
359			if !got.has(expElem) {
360				t.Errorf("Didn't get expected codeElem %#v", expElem)
361			}
362		}
363		for gotElem := range got.set {
364			if !expected.has(gotElem) {
365				t.Errorf("Got unexpected codeElem %#v", gotElem)
366			}
367		}
368	})
369
370	t.Run("getCommonSet", func(t *testing.T) {
371		got, err := getCommonSet(inFiles)
372		if err != nil {
373			t.Fatal(err)
374		}
375
376		expected := newCodeSet()
377		expected.add(codeElem{token.COMMENT, "Package comments\n"})
378		expected.add(codeElem{token.CONST, `COMMON_INDEPENDENT = 1234`})
379		expected.add(codeElem{token.CONST, `COMMON_GROUP = "COMMON_GROUP"`})
380		expected.add(codeElem{token.CONST, `sub1Common1 = 11`})
381		expected.add(codeElem{token.CONST, `sub1Common3_LONG = 13`})
382		expected.add(codeElem{token.CONST, `sub2Common2 = 22`})
383		expected.add(codeElem{token.CONST, `sub2Common3 = 23`})
384		expected.add(codeElem{token.CONST, `sub3Common1 = 31`})
385		expected.add(codeElem{token.CONST, `sub3Common4 = 34`})
386		expected.add(codeElem{token.TYPE, `commonInt int`})
387		expected.add(codeElem{token.FUNC, `func commonF() string {
388	return commonDep.Use("common")
389}`})
390
391		if len(got.set) != len(expected.set) {
392			t.Errorf("Got %d codeElems, expected %d", len(got.set), len(expected.set))
393		}
394		for expElem := range expected.set {
395			if !got.has(expElem) {
396				t.Errorf("Didn't get expected codeElem %#v", expElem)
397			}
398		}
399		for gotElem := range got.set {
400			if !expected.has(gotElem) {
401				t.Errorf("Got unexpected codeElem %#v", gotElem)
402			}
403		}
404	})
405
406	t.Run("filter(keepCommon)", func(t *testing.T) {
407		commonSet, err := getCommonSet(inFiles)
408		if err != nil {
409			t.Fatal(err)
410		}
411
412		got, err := filter(inFiles[0].src, commonSet.keepCommon)
413		expected := []byte(mergedFile)
414
415		if !bytes.Equal(got, expected) {
416			t.Errorf("Got:\n%s\nExpected:\n%s", addLineNr(got), addLineNr(expected))
417			diffLines(t, got, expected)
418		}
419	})
420
421	t.Run("filter(keepArchSpecific)", func(t *testing.T) {
422		commonSet, err := getCommonSet(inFiles)
423		if err != nil {
424			t.Fatal(err)
425		}
426
427		for i := range inFiles {
428			got, err := filter(inFiles[i].src, commonSet.keepArchSpecific)
429			if err != nil {
430				t.Fatal(err)
431			}
432
433			expected := outFiles[i].src
434
435			if !bytes.Equal(got, expected) {
436				t.Errorf("Got:\n%s\nExpected:\n%s", addLineNr(got), addLineNr(expected))
437				diffLines(t, got, expected)
438			}
439		}
440	})
441}
442
443func TestMergedName(t *testing.T) {
444	t.Run("getValidGOOS", func(t *testing.T) {
445		testcases := []struct {
446			filename, goos string
447			ok             bool
448		}{
449			{"zerrors_aix.go", "aix", true},
450			{"zerrors_darwin.go", "darwin", true},
451			{"zerrors_dragonfly.go", "dragonfly", true},
452			{"zerrors_freebsd.go", "freebsd", true},
453			{"zerrors_linux.go", "linux", true},
454			{"zerrors_netbsd.go", "netbsd", true},
455			{"zerrors_openbsd.go", "openbsd", true},
456			{"zerrors_solaris.go", "solaris", true},
457			{"zerrors_multics.go", "", false},
458		}
459		for _, tc := range testcases {
460			goos, ok := getValidGOOS(tc.filename)
461			if goos != tc.goos {
462				t.Errorf("got GOOS %q, expected %q", goos, tc.goos)
463			}
464			if ok != tc.ok {
465				t.Errorf("got ok %v, expected %v", ok, tc.ok)
466			}
467		}
468	})
469}
470
471// Helper functions to diff test sources
472
// diffLines reports, via t.Errorf, the first line at which got and expected
// differ. This includes the case where one input runs out of lines before
// the other. It reports nothing when the inputs are identical line-for-line.
func diffLines(t *testing.T, got, expected []byte) {
	t.Helper()

	if msg, differs := firstLineDiff(got, expected); differs {
		t.Errorf("%s", msg)
	}
}

// firstLineDiff splits got and expected on '\n' and describes the first
// differing line. differs is false, and msg empty, when both inputs have
// the same lines.
func firstLineDiff(got, expected []byte) (msg string, differs bool) {
	gotLines := bytes.Split(got, []byte{'\n'})
	expLines := bytes.Split(expected, []byte{'\n'})

	i := 0
	for ; i < len(gotLines) && i < len(expLines); i++ {
		if !bytes.Equal(gotLines[i], expLines[i]) {
			return fmt.Sprintf("Line %d: Got:\n%q\nExpected:\n%q", i+1, gotLines[i], expLines[i]), true
		}
	}

	// The common prefix matched; at most one side has lines left over.
	switch {
	case i < len(gotLines):
		return fmt.Sprintf("Line %d: got %q, expected EOF", i+1, gotLines[i]), true
	case i < len(expLines):
		// Index expLines here: gotLines is exhausted, so gotLines[i] would panic.
		return fmt.Sprintf("Line %d: got EOF, expected %q", i+1, expLines[i]), true
	}
	return "", false
}
495
// addLineNr returns a copy of src with every '\n'-separated line prefixed by
// its 1-based number and ": ". Lines are rejoined with '\n', so an empty
// input yields "1: " and a trailing newline yields a final numbered empty line.
func addLineNr(src []byte) []byte {
	var out bytes.Buffer
	for idx, text := range bytes.Split(src, []byte{'\n'}) {
		if idx > 0 {
			out.WriteByte('\n')
		}
		fmt.Fprintf(&out, "%d: %s", idx+1, text)
	}
	return out.Bytes()
}
503