1// Copyright 2019 The go-github AUTHORS. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6// +build ignore
7
8// gen-stringify-test generates test methods to test the String methods.
9//
10// These tests eliminate most of the code coverage problems so that real
11// code coverage issues can be more readily identified.
12//
13// It is meant to be used by go-github contributors in conjunction with the
14// go generate tool before sending a PR to GitHub.
15// Please see the CONTRIBUTING.md file for more information.
16package main
17
18import (
19	"bytes"
20	"flag"
21	"fmt"
22	"go/ast"
23	"go/format"
24	"go/parser"
25	"go/token"
26	"io/ioutil"
27	"log"
28	"os"
29	"strings"
30	"text/template"
31)
32
33const (
34	ignoreFilePrefix1 = "gen-"
35	ignoreFilePrefix2 = "github-"
36	outputFileSuffix  = "-stringify_test.go"
37)
38
39var (
40	verbose = flag.Bool("v", false, "Print verbose log messages")
41
42	// blacklistStructMethod lists "struct.method" combos to skip.
43	blacklistStructMethod = map[string]bool{}
44	// blacklistStruct lists structs to skip.
45	blacklistStruct = map[string]bool{
46		"RateLimits": true,
47	}
48
49	funcMap = template.FuncMap{
50		"isNotLast": func(index int, slice []*structField) string {
51			if index+1 < len(slice) {
52				return ", "
53			}
54			return ""
55		},
56		"processZeroValue": func(v string) string {
57			switch v {
58			case "Bool(false)":
59				return "false"
60			case "Float64(0.0)":
61				return "0"
62			case "0", "Int(0)", "Int64(0)":
63				return "0"
64			case `""`, `String("")`:
65				return `""`
66			case "Timestamp{}", "&Timestamp{}":
67				return "github.Timestamp{0001-01-01 00:00:00 +0000 UTC}"
68			case "nil":
69				return "map[]"
70			}
71			log.Fatalf("Unhandled zero value: %q", v)
72			return ""
73		},
74	}
75
76	sourceTmpl = template.Must(template.New("source").Funcs(funcMap).Parse(source))
77)
78
79func main() {
80	flag.Parse()
81	fset := token.NewFileSet()
82
83	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
84	if err != nil {
85		log.Fatal(err)
86		return
87	}
88
89	for pkgName, pkg := range pkgs {
90		t := &templateData{
91			filename:     pkgName + outputFileSuffix,
92			Year:         2019, // No need to change this once set (even in following years).
93			Package:      pkgName,
94			Imports:      map[string]string{"testing": "testing"},
95			StringFuncs:  map[string]bool{},
96			StructFields: map[string][]*structField{},
97		}
98		for filename, f := range pkg.Files {
99			logf("Processing %v...", filename)
100			if err := t.processAST(f); err != nil {
101				log.Fatal(err)
102			}
103		}
104		if err := t.dump(); err != nil {
105			log.Fatal(err)
106		}
107	}
108	logf("Done.")
109}
110
111func sourceFilter(fi os.FileInfo) bool {
112	return !strings.HasSuffix(fi.Name(), "_test.go") &&
113		!strings.HasPrefix(fi.Name(), ignoreFilePrefix1) &&
114		!strings.HasPrefix(fi.Name(), ignoreFilePrefix2)
115}
116
117type templateData struct {
118	filename     string
119	Year         int
120	Package      string
121	Imports      map[string]string
122	StringFuncs  map[string]bool
123	StructFields map[string][]*structField
124}
125
126type structField struct {
127	sortVal      string // Lower-case version of "ReceiverType.FieldName".
128	ReceiverVar  string // The one-letter variable name to match the ReceiverType.
129	ReceiverType string
130	FieldName    string
131	FieldType    string
132	ZeroValue    string
133	NamedStruct  bool // Getter for named struct.
134}
135
136func (t *templateData) processAST(f *ast.File) error {
137	for _, decl := range f.Decls {
138		fn, ok := decl.(*ast.FuncDecl)
139		if ok {
140			if fn.Recv != nil && len(fn.Recv.List) > 0 {
141				id, ok := fn.Recv.List[0].Type.(*ast.Ident)
142				if ok && fn.Name.Name == "String" {
143					logf("Got FuncDecl: Name=%q, id.Name=%#v", fn.Name.Name, id.Name)
144					t.StringFuncs[id.Name] = true
145				} else {
146					logf("Ignoring FuncDecl: Name=%q, Type=%T", fn.Name.Name, fn.Recv.List[0].Type)
147				}
148			} else {
149				logf("Ignoring FuncDecl: Name=%q, fn=%#v", fn.Name.Name, fn)
150			}
151			continue
152		}
153
154		gd, ok := decl.(*ast.GenDecl)
155		if !ok {
156			logf("Ignoring AST decl type %T", decl)
157			continue
158		}
159		for _, spec := range gd.Specs {
160			ts, ok := spec.(*ast.TypeSpec)
161			if !ok {
162				continue
163			}
164			// Skip unexported identifiers.
165			if !ts.Name.IsExported() {
166				logf("Struct %v is unexported; skipping.", ts.Name)
167				continue
168			}
169			// Check if the struct is blacklisted.
170			if blacklistStruct[ts.Name.Name] {
171				logf("Struct %v is blacklisted; skipping.", ts.Name)
172				continue
173			}
174			st, ok := ts.Type.(*ast.StructType)
175			if !ok {
176				logf("Ignoring AST type %T, Name=%q", ts.Type, ts.Name.String())
177				continue
178			}
179			for _, field := range st.Fields.List {
180				if len(field.Names) == 0 {
181					continue
182				}
183
184				fieldName := field.Names[0]
185				if id, ok := field.Type.(*ast.Ident); ok {
186					t.addIdent(id, ts.Name.String(), fieldName.String())
187					continue
188				}
189
190				if _, ok := field.Type.(*ast.MapType); ok {
191					t.addMapType(ts.Name.String(), fieldName.String())
192					continue
193				}
194
195				se, ok := field.Type.(*ast.StarExpr)
196				if !ok {
197					logf("Ignoring type %T for Name=%q, FieldName=%q", field.Type, ts.Name.String(), fieldName.String())
198					continue
199				}
200
201				// Skip unexported identifiers.
202				if !fieldName.IsExported() {
203					logf("Field %v is unexported; skipping.", fieldName)
204					continue
205				}
206				// Check if "struct.method" is blacklisted.
207				if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); blacklistStructMethod[key] {
208					logf("Method %v is blacklisted; skipping.", key)
209					continue
210				}
211
212				switch x := se.X.(type) {
213				case *ast.ArrayType:
214				case *ast.Ident:
215					t.addIdentPtr(x, ts.Name.String(), fieldName.String())
216				case *ast.MapType:
217				case *ast.SelectorExpr:
218				default:
219					logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
220				}
221			}
222		}
223	}
224	return nil
225}
226
227func (t *templateData) addMapType(receiverType, fieldName string) {
228	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, "map[]", "nil", false))
229}
230
231func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
232	var zeroValue string
233	var namedStruct = false
234	switch x.String() {
235	case "int":
236		zeroValue = "0"
237	case "int64":
238		zeroValue = "0"
239	case "float64":
240		zeroValue = "0.0"
241	case "string":
242		zeroValue = `""`
243	case "bool":
244		zeroValue = "false"
245	case "Timestamp":
246		zeroValue = "Timestamp{}"
247	default:
248		zeroValue = "nil"
249		namedStruct = true
250	}
251
252	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
253}
254
255func (t *templateData) addIdentPtr(x *ast.Ident, receiverType, fieldName string) {
256	var zeroValue string
257	var namedStruct = false
258	switch x.String() {
259	case "int":
260		zeroValue = "Int(0)"
261	case "int64":
262		zeroValue = "Int64(0)"
263	case "float64":
264		zeroValue = "Float64(0.0)"
265	case "string":
266		zeroValue = `String("")`
267	case "bool":
268		zeroValue = "Bool(false)"
269	case "Timestamp":
270		zeroValue = "&Timestamp{}"
271	default:
272		zeroValue = "nil"
273		namedStruct = true
274	}
275
276	t.StructFields[receiverType] = append(t.StructFields[receiverType], newStructField(receiverType, fieldName, x.String(), zeroValue, namedStruct))
277}
278
279func (t *templateData) dump() error {
280	if len(t.StructFields) == 0 {
281		logf("No StructFields for %v; skipping.", t.filename)
282		return nil
283	}
284
285	// Remove unused structs.
286	var toDelete []string
287	for k := range t.StructFields {
288		if !t.StringFuncs[k] {
289			toDelete = append(toDelete, k)
290			continue
291		}
292	}
293	for _, k := range toDelete {
294		delete(t.StructFields, k)
295	}
296
297	var buf bytes.Buffer
298	if err := sourceTmpl.Execute(&buf, t); err != nil {
299		return err
300	}
301	clean, err := format.Source(buf.Bytes())
302	if err != nil {
303		log.Printf("failed-to-format source:\n%v", buf.String())
304		return err
305	}
306
307	logf("Writing %v...", t.filename)
308	return ioutil.WriteFile(t.filename, clean, 0644)
309}
310
311func newStructField(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *structField {
312	return &structField{
313		sortVal:      strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
314		ReceiverVar:  strings.ToLower(receiverType[:1]),
315		ReceiverType: receiverType,
316		FieldName:    fieldName,
317		FieldType:    fieldType,
318		ZeroValue:    zeroValue,
319		NamedStruct:  namedStruct,
320	}
321}
322
323func logf(fmt string, args ...interface{}) {
324	if *verbose {
325		log.Printf(fmt, args...)
326	}
327}
328
329const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
330//
331// Use of this source code is governed by a BSD-style
332// license that can be found in the LICENSE file.
333
334// Code generated by gen-stringify-tests; DO NOT EDIT.
335
336package {{ $package := .Package}}{{$package}}
337{{with .Imports}}
338import (
339  {{- range . -}}
340  "{{.}}"
341  {{end -}}
342)
343{{end}}
344func Float64(v float64) *float64 { return &v }
345{{range $key, $value := .StructFields}}
346func Test{{ $key }}_String(t *testing.T) {
347  v := {{ $key }}{ {{range .}}{{if .NamedStruct}}
348    {{ .FieldName }}: &{{ .FieldType }}{},{{else}}
349    {{ .FieldName }}: {{.ZeroValue}},{{end}}{{end}}
350  }
351 	want := ` + "`" + `{{ $package }}.{{ $key }}{{ $slice := . }}{
352{{- range $ind, $val := .}}{{if .NamedStruct}}{{ .FieldName }}:{{ $package }}.{{ .FieldType }}{}{{else}}{{ .FieldName }}:{{ processZeroValue .ZeroValue }}{{end}}{{ isNotLast $ind $slice }}{{end}}}` + "`" + `
353	if got := v.String(); got != want {
354		t.Errorf("{{ $key }}.String = %v, want %v", got, want)
355	}
356}
357{{end}}
358`
359