1package main
2
3import (
4	"fmt"
5	"go/ast"
6	"go/parser"
7	"go/token"
8	"io/ioutil"
9	"os"
10	"path/filepath"
11	"reflect"
12	"sort"
13	"testing"
14	"text/template"
15
16	"github.com/tinylib/msgp/gen"
17)
18
// debugTemp, when true, makes loadVars keep the temporary directory that holds
// the template and the generated code, printing its path instead of deleting
// it. Flip it locally when these tests fail and the generated output needs to
// be inspected by hand.
const debugTemp bool = false
21
// TestIssue185Idents ensures that msgp generates consistent identifiers on a
// per-method basis: regenerating with an extra field in one struct must change
// the declared variable names only inside that struct's methods, and every
// other method must keep exactly the same declarations.
//
// Duplicate identifiers within a method are checked by TestIssue185Overlap.
//
// Structs are currently processed alphabetically by msgp; this test relies on
// that property.
func TestIssue185Idents(t *testing.T) {
	var identCases = []struct {
		tpl             *template.Template
		expectedChanged []string
	}{
		{tpl: issue185IdentsTpl, expectedChanged: []string{"Test1"}},
		{tpl: issue185ComplexIdentsTpl, expectedChanged: []string{"Test2"}},
	}

	methods := []string{"DecodeMsg", "EncodeMsg", "Msgsize", "MarshalMsg", "UnmarshalMsg"}

	for idx, identCase := range identCases {
		// Generate the code and extract the declared variable names, keyed by
		// "Type.Method".
		var tplData issue185TplData
		varsBefore, err := loadVars(identCase.tpl, tplData)
		if err != nil {
			t.Fatalf("%d: could not extract before vars: %v", idx, err)
		}

		// Regenerate the code with the extra field(s) enabled and extract the
		// declared variable names again.
		tplData.Extra = true
		varsAfter, err := loadVars(identCase.tpl, tplData)
		if err != nil {
			t.Fatalf("%d: could not extract after vars: %v", idx, err)
		}

		// Every method of a struct that gained a field must have changed its
		// declarations. Keys are looked up directly (rather than through
		// extractedVars.Value, which panics) so that a missing method is
		// reported as a test failure.
		for _, stct := range identCase.expectedChanged {
			for _, method := range methods {
				fn := fmt.Sprintf("%s.%s", stct, method)

				bv, okBefore := varsBefore[fn]
				av, okAfter := varsAfter[fn]
				if !okBefore || !okAfter {
					t.Fatalf("%d: method %s missing from generated code (before: %t, after: %t)", idx, fn, okBefore, okAfter)
				}
				// A method that declares nothing on either side has nothing
				// that could have changed, so only compare non-empty sets.
				if len(bv) > 0 && len(av) > 0 && reflect.DeepEqual(bv, av) {
					t.Fatalf("%d: vars identical! expected vars to change for %s", idx, fn)
				}
				delete(varsBefore, fn)
				delete(varsAfter, fn)
			}
		}

		// All remaining methods belong to structs that did not change and must
		// declare exactly the same variables. Deleting map entries while
		// ranging over the map is permitted in Go.
		for bmethod, bvars := range varsBefore {
			avars, ok := varsAfter[bmethod]
			if !ok {
				t.Fatalf("%d: method %s missing after regeneration", idx, bmethod)
			}
			if !reflect.DeepEqual(bvars, avars) {
				t.Fatalf("%d: vars changed! expected vars identical for %s", idx, bmethod)
			}
			delete(varsBefore, bmethod)
			delete(varsAfter, bmethod)
		}

		// varsBefore is drained by the loop above; anything left in varsAfter
		// is a method that only exists after regeneration.
		if len(varsBefore) > 0 || len(varsAfter) > 0 {
			t.Fatalf("%d: unexpected methods remaining: before=%v after=%v", idx, varsBefore, varsAfter)
		}
	}
}
87
// issue185TplData is the data value the issue 185 templates are executed
// with. The templates guard their optional field(s) with {{ if .Extra }}.
type issue185TplData struct {
	Extra bool // true renders the optional field(s); false omits them
}
91
// TestIssue185Overlap ensures that no generated method declares the same
// variable name more than once. Every template is checked both with and
// without its optional extra field(s).
func TestIssue185Overlap(t *testing.T) {
	var overlapCases = []struct {
		tpl  *template.Template
		data issue185TplData
	}{
		{tpl: issue185IdentsTpl, data: issue185TplData{Extra: false}},
		{tpl: issue185IdentsTpl, data: issue185TplData{Extra: true}},
		{tpl: issue185ComplexIdentsTpl, data: issue185TplData{Extra: false}},
		{tpl: issue185ComplexIdentsTpl, data: issue185TplData{Extra: true}},
	}

	for idx, o := range overlapCases {
		// Generate the code for this case and extract the declared variable
		// names, keyed by "Type.Method".
		mvars, err := loadVars(o.tpl, o.data)
		if err != nil {
			t.Fatalf("%d: could not extract vars: %v", idx, err)
		}

		identCnt := 0
		for fn, vars := range mvars {
			// Sorting puts any duplicate names next to each other.
			sort.Strings(vars)

			// Loose sanity check to make sure the test's expectations aren't
			// broken: generated identifiers are expected to start with 'z'.
			// If the prefix ever changes, this needs to change.
			for _, v := range vars {
				if v[0] == 'z' {
					identCnt++
				}
			}

			for i := 1; i < len(vars); i++ {
				if vars[i-1] == vars[i] {
					t.Fatalf("%d: duplicate var %s in function %s", idx, vars[i], fn)
				}
			}
		}

		// One last sanity check: if no var starts with 'z', this test's
		// expectations are unsatisfiable.
		if identCnt == 0 {
			t.Fatalf("%d: no generated identifiers found", idx)
		}
	}
}
137
// loadVars renders tpl with tplData into a file in a fresh temporary
// directory, runs the msgp generator on it, and returns the variable names
// declared in each function of the generated file, keyed as extractVars keys
// them.
//
// The temporary directory is removed before returning unless debugTemp is set.
// In that case its path is printed and the directory is kept for inspection.
// This helper serves both the "before" and "after" generations, so its error
// messages name neither; callers add that context.
func loadVars(tpl *template.Template, tplData interface{}) (vars extractedVars, err error) {
	tempDir, err := ioutil.TempDir("", "msgp-")
	if err != nil {
		return nil, fmt.Errorf("could not create temp dir: %v", err)
	}

	if debugTemp {
		fmt.Println(tempDir)
	} else {
		defer os.RemoveAll(tempDir)
	}

	tfile := filepath.Join(tempDir, "msg.go")
	genFile := newFilename(tfile, "")

	if err = goGenerateTpl(tempDir, tfile, tpl, tplData); err != nil {
		return nil, fmt.Errorf("could not generate code: %v", err)
	}

	vars, err = extractVars(genFile)
	if err != nil {
		return nil, fmt.Errorf("could not extract vars: %v", err)
	}
	return vars, nil
}
166
167type varVisitor struct {
168	vars []string
169	fset *token.FileSet
170}
171
172func (v *varVisitor) Visit(node ast.Node) (w ast.Visitor) {
173	gen, ok := node.(*ast.GenDecl)
174	if !ok {
175		return v
176	}
177	for _, spec := range gen.Specs {
178		if vspec, ok := spec.(*ast.ValueSpec); ok {
179			for _, n := range vspec.Names {
180				v.vars = append(v.vars, n.Name)
181			}
182		}
183	}
184	return v
185}
186
187type extractedVars map[string][]string
188
189func (e extractedVars) Value(key string) []string {
190	if v, ok := e[key]; ok {
191		return v
192	}
193	panic(fmt.Errorf("unknown key %s", key))
194}
195
196func extractVars(file string) (extractedVars, error) {
197	fset := token.NewFileSet()
198
199	f, err := parser.ParseFile(fset, file, nil, 0)
200	if err != nil {
201		return nil, err
202	}
203
204	vars := make(map[string][]string)
205	for _, d := range f.Decls {
206		switch d := d.(type) {
207		case *ast.FuncDecl:
208			sn := ""
209			switch rt := d.Recv.List[0].Type.(type) {
210			case *ast.Ident:
211				sn = rt.Name
212			case *ast.StarExpr:
213				sn = rt.X.(*ast.Ident).Name
214			default:
215				panic("unknown receiver type")
216			}
217
218			key := fmt.Sprintf("%s.%s", sn, d.Name.Name)
219			vis := &varVisitor{fset: fset}
220			ast.Walk(vis, d.Body)
221			vars[key] = vis.vars
222		}
223	}
224	return vars, nil
225}
226
// goGenerateTpl executes tpl with tplData, writes the result to tfile, and then
// runs the msgp generator (Run) on tfile. Code generation is enabled for
// encode, decode, size, marshal and unmarshal.
//
// cwd is not used by this function.
// NOTE(review): presumably kept for call-site compatibility; confirm before removing.
func goGenerateTpl(cwd, tfile string, tpl *template.Template, tplData interface{}) error {
	outf, err := os.OpenFile(tfile, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}

	if err = tpl.Execute(outf, tplData); err != nil {
		// The template error is the one worth reporting; a close error here
		// would only obscure it.
		_ = outf.Close()
		return err
	}

	// Close explicitly, before the generator runs, so that a Close error is
	// reported rather than silently dropped by a deferred Close, and so that
	// the handle is not held open while Run works on the file.
	if err = outf.Close(); err != nil {
		return err
	}

	mode := gen.Encode | gen.Decode | gen.Size | gen.Marshal | gen.Unmarshal

	return Run(tfile, mode, false)
}
242
// issue185IdentsTpl renders a package with two flat structs, Test1 and Test2.
// When .Extra is true, Test1 gains a `Baz []string` field between Bar and Qux.
// Test2 is identical in both renderings, so regenerating is expected to change
// only Test1's generated methods (see TestIssue185Idents).
var issue185IdentsTpl = template.Must(template.New("").Parse(`
package issue185

//go:generate msgp

type Test1 struct {
	Foo string
	Bar string
	{{ if .Extra }}Baz []string{{ end }}
	Qux string
}

type Test2 struct {
	Foo string
	Bar string
	Baz string
}
`))
261
// issue185ComplexIdentsTpl renders a package with three structs. Test2
// exercises slices, maps, nested maps and several levels of anonymous nested
// structs. When .Extra is true, an `Extra []string` field is added to the
// innermost Quack struct of Test2. Test1 and Test3 are identical in both
// renderings, so only Test2's generated methods are expected to change (see
// TestIssue185Idents). Presumably Test1 and Test3 bracket Test2 alphabetically
// so that identifier changes leaking into a struct processed before or after
// it would be detected — NOTE(review): confirm intent.
var issue185ComplexIdentsTpl = template.Must(template.New("").Parse(`
package issue185

//go:generate msgp

type Test1 struct {
	Foo string
	Bar string
	Baz string
}

type Test2 struct {
	Foo string
	Bar string
	Baz []string
	Qux map[string]string
	Yep map[string]map[string]string
	Quack struct {
		Quack struct {
			Quack struct {
				{{ if .Extra }}Extra []string{{ end }}
				Quack string
			}
		}
	}
	Nup struct {
		Foo string
		Bar string
		Baz []string
		Qux map[string]string
		Yep map[string]map[string]string
	}
	Ding struct {
		Dong struct {
			Dung struct {
				Thing string
			}
		}
	}
}

type Test3 struct {
	Foo string
	Bar string
	Baz string
}
`))
309