1// Copyright 2017 The go-github AUTHORS. All rights reserved.
2//
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6// +build ignore
7
8// gen-accessors generates accessor methods for structs with pointer fields.
9//
10// It is meant to be used by go-github contributors in conjunction with the
11// go generate tool before sending a PR to GitHub.
12// Please see the CONTRIBUTING.md file for more information.
13package main
14
15import (
16	"bytes"
17	"flag"
18	"fmt"
19	"go/ast"
20	"go/format"
21	"go/parser"
22	"go/token"
23	"io/ioutil"
24	"log"
25	"os"
26	"sort"
27	"strings"
28	"text/template"
29)
30
// fileSuffix is appended to a package name to form the name of the generated
// accessor file. Files ending in it are also excluded from parsing so that a
// previous run's output is never fed back into the generator.
const fileSuffix = "-accessors.go"
34
35var (
36	verbose = flag.Bool("v", false, "Print verbose log messages")
37
38	sourceTmpl = template.Must(template.New("source").Parse(source))
39
40	// blacklistStructMethod lists "struct.method" combos to skip.
41	blacklistStructMethod = map[string]bool{
42		"RepositoryContent.GetContent":    true,
43		"Client.GetBaseURL":               true,
44		"Client.GetUploadURL":             true,
45		"ErrorResponse.GetResponse":       true,
46		"RateLimitError.GetResponse":      true,
47		"AbuseRateLimitError.GetResponse": true,
48	}
49	// blacklistStruct lists structs to skip.
50	blacklistStruct = map[string]bool{
51		"Client": true,
52	}
53)
54
55func logf(fmt string, args ...interface{}) {
56	if *verbose {
57		log.Printf(fmt, args...)
58	}
59}
60
61func main() {
62	flag.Parse()
63	fset := token.NewFileSet()
64
65	pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0)
66	if err != nil {
67		log.Fatal(err)
68		return
69	}
70
71	for pkgName, pkg := range pkgs {
72		t := &templateData{
73			filename: pkgName + fileSuffix,
74			Year:     2017,
75			Package:  pkgName,
76			Imports:  map[string]string{},
77		}
78		for filename, f := range pkg.Files {
79			logf("Processing %v...", filename)
80			if err := t.processAST(f); err != nil {
81				log.Fatal(err)
82			}
83		}
84		if err := t.dump(); err != nil {
85			log.Fatal(err)
86		}
87	}
88	logf("Done.")
89}
90
91func (t *templateData) processAST(f *ast.File) error {
92	for _, decl := range f.Decls {
93		gd, ok := decl.(*ast.GenDecl)
94		if !ok {
95			continue
96		}
97		for _, spec := range gd.Specs {
98			ts, ok := spec.(*ast.TypeSpec)
99			if !ok {
100				continue
101			}
102			// Skip unexported identifiers.
103			if !ts.Name.IsExported() {
104				logf("Struct %v is unexported; skipping.", ts.Name)
105				continue
106			}
107			// Check if the struct is blacklisted.
108			if blacklistStruct[ts.Name.Name] {
109				logf("Struct %v is blacklisted; skipping.", ts.Name)
110				continue
111			}
112			st, ok := ts.Type.(*ast.StructType)
113			if !ok {
114				continue
115			}
116			for _, field := range st.Fields.List {
117				se, ok := field.Type.(*ast.StarExpr)
118				if len(field.Names) == 0 || !ok {
119					continue
120				}
121
122				fieldName := field.Names[0]
123				// Skip unexported identifiers.
124				if !fieldName.IsExported() {
125					logf("Field %v is unexported; skipping.", fieldName)
126					continue
127				}
128				// Check if "struct.method" is blacklisted.
129				if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); blacklistStructMethod[key] {
130					logf("Method %v is blacklisted; skipping.", key)
131					continue
132				}
133
134				switch x := se.X.(type) {
135				case *ast.ArrayType:
136					t.addArrayType(x, ts.Name.String(), fieldName.String())
137				case *ast.Ident:
138					t.addIdent(x, ts.Name.String(), fieldName.String())
139				case *ast.MapType:
140					t.addMapType(x, ts.Name.String(), fieldName.String())
141				case *ast.SelectorExpr:
142					t.addSelectorExpr(x, ts.Name.String(), fieldName.String())
143				default:
144					logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x)
145				}
146			}
147		}
148	}
149	return nil
150}
151
152func sourceFilter(fi os.FileInfo) bool {
153	return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix)
154}
155
156func (t *templateData) dump() error {
157	if len(t.Getters) == 0 {
158		logf("No getters for %v; skipping.", t.filename)
159		return nil
160	}
161
162	// Sort getters by ReceiverType.FieldName.
163	sort.Sort(byName(t.Getters))
164
165	var buf bytes.Buffer
166	if err := sourceTmpl.Execute(&buf, t); err != nil {
167		return err
168	}
169	clean, err := format.Source(buf.Bytes())
170	if err != nil {
171		return err
172	}
173
174	logf("Writing %v...", t.filename)
175	return ioutil.WriteFile(t.filename, clean, 0644)
176}
177
178func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter {
179	return &getter{
180		sortVal:      strings.ToLower(receiverType) + "." + strings.ToLower(fieldName),
181		ReceiverVar:  strings.ToLower(receiverType[:1]),
182		ReceiverType: receiverType,
183		FieldName:    fieldName,
184		FieldType:    fieldType,
185		ZeroValue:    zeroValue,
186		NamedStruct:  namedStruct,
187	}
188}
189
190func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) {
191	var eltType string
192	switch elt := x.Elt.(type) {
193	case *ast.Ident:
194		eltType = elt.String()
195	default:
196		logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt)
197		return
198	}
199
200	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil", false))
201}
202
203func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
204	var zeroValue string
205	var namedStruct = false
206	switch x.String() {
207	case "int", "int64":
208		zeroValue = "0"
209	case "string":
210		zeroValue = `""`
211	case "bool":
212		zeroValue = "false"
213	case "Timestamp":
214		zeroValue = "Timestamp{}"
215	default:
216		zeroValue = "nil"
217		namedStruct = true
218	}
219
220	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct))
221}
222
223func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string) {
224	var keyType string
225	switch key := x.Key.(type) {
226	case *ast.Ident:
227		keyType = key.String()
228	default:
229		logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key)
230		return
231	}
232
233	var valueType string
234	switch value := x.Value.(type) {
235	case *ast.Ident:
236		valueType = value.String()
237	default:
238		logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value)
239		return
240	}
241
242	fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType)
243	zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType)
244	t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
245}
246
247func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) {
248	if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field.
249		return
250	}
251
252	var xX string
253	if xx, ok := x.X.(*ast.Ident); ok {
254		xX = xx.String()
255	}
256
257	switch xX {
258	case "time", "json":
259		if xX == "json" {
260			t.Imports["encoding/json"] = "encoding/json"
261		} else {
262			t.Imports[xX] = xX
263		}
264		fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name)
265		zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name)
266		if xX == "time" && x.Sel.Name == "Duration" {
267			zeroValue = "0"
268		}
269		t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false))
270	default:
271		logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x)
272	}
273}
274
// templateData is the value passed to sourceTmpl when rendering the accessor
// file for one package. It also accumulates getters while that package's
// files are processed.
type templateData struct {
	// filename is the output path; main sets it to the package name plus
	// fileSuffix.
	filename string
	// Year is rendered into the generated file's copyright line.
	Year     int
	// Package is the package clause of the generated file.
	Package  string
	// Imports holds the import paths needed by the generated getters. Keys and
	// values are both the import path; the template renders the values.
	Imports  map[string]string
	// Getters lists the accessors to generate; dump sorts it before rendering.
	Getters  []*getter
}
282
// getter describes one generated accessor method, Get<FieldName>, on
// *ReceiverType.
type getter struct {
	sortVal      string // Lower-case version of "ReceiverType.FieldName".
	ReceiverVar  string // The one-letter variable name to match the ReceiverType.
	// ReceiverType is the struct type the method is declared on.
	ReceiverType string
	// FieldName is the pointer field being accessed.
	FieldName    string
	// FieldType is the field's pointee type. It is the getter's return type,
	// or the pointer's element type when NamedStruct is set.
	FieldType    string
	// ZeroValue is the Go expression returned when the receiver (or, for
	// non-NamedStruct getters, the field) is nil.
	ZeroValue    string
	NamedStruct  bool // Getter for named struct.
}
292
293type byName []*getter
294
295func (b byName) Len() int           { return len(b) }
296func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal }
297func (b byName) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
298
// source is the text/template for a generated accessor file. Each getter is
// rendered in one of two shapes:
//
//   - NamedStruct: the method returns the pointer field itself, or ZeroValue
//     when the receiver is nil.
//   - otherwise: the method dereferences the field and returns ZeroValue when
//     either the receiver or the field is nil.
//
// The rendered text is passed through format.Source in dump, so the
// indentation inside this template does not need to be gofmt-clean.
const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Code generated by gen-accessors; DO NOT EDIT.

package {{.Package}}
{{with .Imports}}
import (
  {{- range . -}}
  "{{.}}"
  {{end -}}
)
{{end}}
{{range .Getters}}
{{if .NamedStruct}}
// Get{{.FieldName}} returns the {{.FieldName}} field.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
  if {{.ReceiverVar}} == nil {
    return {{.ZeroValue}}
  }
  return {{.ReceiverVar}}.{{.FieldName}}
}
{{else}}
// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
  if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
    return {{.ZeroValue}}
  }
  return *{{.ReceiverVar}}.{{.FieldName}}
}
{{end}}
{{end}}
`
334