1// Copyright 2017 The go-github AUTHORS. All rights reserved. 2// 3// Use of this source code is governed by a BSD-style 4// license that can be found in the LICENSE file. 5 6// +build ignore 7 8// gen-accessors generates accessor methods for structs with pointer fields. 9// 10// It is meant to be used by go-github contributors in conjunction with the 11// go generate tool before sending a PR to GitHub. 12// Please see the CONTRIBUTING.md file for more information. 13package main 14 15import ( 16 "bytes" 17 "flag" 18 "fmt" 19 "go/ast" 20 "go/format" 21 "go/parser" 22 "go/token" 23 "io/ioutil" 24 "log" 25 "os" 26 "sort" 27 "strings" 28 "text/template" 29) 30 31const ( 32 fileSuffix = "-accessors.go" 33) 34 35var ( 36 verbose = flag.Bool("v", false, "Print verbose log messages") 37 38 sourceTmpl = template.Must(template.New("source").Parse(source)) 39 40 // blacklistStructMethod lists "struct.method" combos to skip. 41 blacklistStructMethod = map[string]bool{ 42 "RepositoryContent.GetContent": true, 43 "Client.GetBaseURL": true, 44 "Client.GetUploadURL": true, 45 "ErrorResponse.GetResponse": true, 46 "RateLimitError.GetResponse": true, 47 "AbuseRateLimitError.GetResponse": true, 48 } 49 // blacklistStruct lists structs to skip. 50 blacklistStruct = map[string]bool{ 51 "Client": true, 52 } 53) 54 55func logf(fmt string, args ...interface{}) { 56 if *verbose { 57 log.Printf(fmt, args...) 58 } 59} 60 61func main() { 62 flag.Parse() 63 fset := token.NewFileSet() 64 65 pkgs, err := parser.ParseDir(fset, ".", sourceFilter, 0) 66 if err != nil { 67 log.Fatal(err) 68 return 69 } 70 71 for pkgName, pkg := range pkgs { 72 t := &templateData{ 73 filename: pkgName + fileSuffix, 74 Year: 2017, 75 Package: pkgName, 76 Imports: map[string]string{}, 77 } 78 for filename, f := range pkg.Files { 79 logf("Processing %v...", filename) 80 if err := t.processAST(f); err != nil { 81 log.Fatal(err) 82 } 83 } 84 if err := t.dump(); err != nil { 85 log.Fatal(err) 86 } 87 } 88 logf("Done.") 89} 90 91func (t *templateData) processAST(f *ast.File) error { 92 for _, decl := range f.Decls { 93 gd, ok := decl.(*ast.GenDecl) 94 if !ok { 95 continue 96 } 97 for _, spec := range gd.Specs { 98 ts, ok := spec.(*ast.TypeSpec) 99 if !ok { 100 continue 101 } 102 // Skip unexported identifiers. 103 if !ts.Name.IsExported() { 104 logf("Struct %v is unexported; skipping.", ts.Name) 105 continue 106 } 107 // Check if the struct is blacklisted. 108 if blacklistStruct[ts.Name.Name] { 109 logf("Struct %v is blacklisted; skipping.", ts.Name) 110 continue 111 } 112 st, ok := ts.Type.(*ast.StructType) 113 if !ok { 114 continue 115 } 116 for _, field := range st.Fields.List { 117 se, ok := field.Type.(*ast.StarExpr) 118 if len(field.Names) == 0 || !ok { 119 continue 120 } 121 122 fieldName := field.Names[0] 123 // Skip unexported identifiers. 124 if !fieldName.IsExported() { 125 logf("Field %v is unexported; skipping.", fieldName) 126 continue 127 } 128 // Check if "struct.method" is blacklisted. 129 if key := fmt.Sprintf("%v.Get%v", ts.Name, fieldName); blacklistStructMethod[key] { 130 logf("Method %v is blacklisted; skipping.", key) 131 continue 132 } 133 134 switch x := se.X.(type) { 135 case *ast.ArrayType: 136 t.addArrayType(x, ts.Name.String(), fieldName.String()) 137 case *ast.Ident: 138 t.addIdent(x, ts.Name.String(), fieldName.String()) 139 case *ast.MapType: 140 t.addMapType(x, ts.Name.String(), fieldName.String()) 141 case *ast.SelectorExpr: 142 t.addSelectorExpr(x, ts.Name.String(), fieldName.String()) 143 default: 144 logf("processAST: type %q, field %q, unknown %T: %+v", ts.Name, fieldName, x, x) 145 } 146 } 147 } 148 } 149 return nil 150} 151 152func sourceFilter(fi os.FileInfo) bool { 153 return !strings.HasSuffix(fi.Name(), "_test.go") && !strings.HasSuffix(fi.Name(), fileSuffix) 154} 155 156func (t *templateData) dump() error { 157 if len(t.Getters) == 0 { 158 logf("No getters for %v; skipping.", t.filename) 159 return nil 160 } 161 162 // Sort getters by ReceiverType.FieldName. 163 sort.Sort(byName(t.Getters)) 164 165 var buf bytes.Buffer 166 if err := sourceTmpl.Execute(&buf, t); err != nil { 167 return err 168 } 169 clean, err := format.Source(buf.Bytes()) 170 if err != nil { 171 return err 172 } 173 174 logf("Writing %v...", t.filename) 175 return ioutil.WriteFile(t.filename, clean, 0644) 176} 177 178func newGetter(receiverType, fieldName, fieldType, zeroValue string, namedStruct bool) *getter { 179 return &getter{ 180 sortVal: strings.ToLower(receiverType) + "." + strings.ToLower(fieldName), 181 ReceiverVar: strings.ToLower(receiverType[:1]), 182 ReceiverType: receiverType, 183 FieldName: fieldName, 184 FieldType: fieldType, 185 ZeroValue: zeroValue, 186 NamedStruct: namedStruct, 187 } 188} 189 190func (t *templateData) addArrayType(x *ast.ArrayType, receiverType, fieldName string) { 191 var eltType string 192 switch elt := x.Elt.(type) { 193 case *ast.Ident: 194 eltType = elt.String() 195 default: 196 logf("addArrayType: type %q, field %q: unknown elt type: %T %+v; skipping.", receiverType, fieldName, elt, elt) 197 return 198 } 199 200 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, "[]"+eltType, "nil", false)) 201} 202 203func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) { 204 var zeroValue string 205 var namedStruct = false 206 switch x.String() { 207 case "int", "int64": 208 zeroValue = "0" 209 case "string": 210 zeroValue = `""` 211 case "bool": 212 zeroValue = "false" 213 case "Timestamp": 214 zeroValue = "Timestamp{}" 215 default: 216 zeroValue = "nil" 217 namedStruct = true 218 } 219 220 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct)) 221} 222 223func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string) { 224 var keyType string 225 switch key := x.Key.(type) { 226 case *ast.Ident: 227 keyType = key.String() 228 default: 229 logf("addMapType: type %q, field %q: unknown key type: %T %+v; skipping.", receiverType, fieldName, key, key) 230 return 231 } 232 233 var valueType string 234 switch value := x.Value.(type) { 235 case *ast.Ident: 236 valueType = value.String() 237 default: 238 logf("addMapType: type %q, field %q: unknown value type: %T %+v; skipping.", receiverType, fieldName, value, value) 239 return 240 } 241 242 fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType) 243 zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType) 244 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false)) 245} 246 247func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) { 248 if strings.ToLower(fieldName[:1]) == fieldName[:1] { // Non-exported field. 249 return 250 } 251 252 var xX string 253 if xx, ok := x.X.(*ast.Ident); ok { 254 xX = xx.String() 255 } 256 257 switch xX { 258 case "time", "json": 259 if xX == "json" { 260 t.Imports["encoding/json"] = "encoding/json" 261 } else { 262 t.Imports[xX] = xX 263 } 264 fieldType := fmt.Sprintf("%v.%v", xX, x.Sel.Name) 265 zeroValue := fmt.Sprintf("%v.%v{}", xX, x.Sel.Name) 266 if xX == "time" && x.Sel.Name == "Duration" { 267 zeroValue = "0" 268 } 269 t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false)) 270 default: 271 logf("addSelectorExpr: xX %q, type %q, field %q: unknown x=%+v; skipping.", xX, receiverType, fieldName, x) 272 } 273} 274 275type templateData struct { 276 filename string 277 Year int 278 Package string 279 Imports map[string]string 280 Getters []*getter 281} 282 283type getter struct { 284 sortVal string // Lower-case version of "ReceiverType.FieldName". 285 ReceiverVar string // The one-letter variable name to match the ReceiverType. 286 ReceiverType string 287 FieldName string 288 FieldType string 289 ZeroValue string 290 NamedStruct bool // Getter for named struct. 291} 292 293type byName []*getter 294 295func (b byName) Len() int { return len(b) } 296func (b byName) Less(i, j int) bool { return b[i].sortVal < b[j].sortVal } 297func (b byName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } 298 299const source = `// Copyright {{.Year}} The go-github AUTHORS. All rights reserved. 300// 301// Use of this source code is governed by a BSD-style 302// license that can be found in the LICENSE file. 303 304// Code generated by gen-accessors; DO NOT EDIT. 305 306package {{.Package}} 307{{with .Imports}} 308import ( 309 {{- range . -}} 310 "{{.}}" 311 {{end -}} 312) 313{{end}} 314{{range .Getters}} 315{{if .NamedStruct}} 316// Get{{.FieldName}} returns the {{.FieldName}} field. 317func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} { 318 if {{.ReceiverVar}} == nil { 319 return {{.ZeroValue}} 320 } 321 return {{.ReceiverVar}}.{{.FieldName}} 322} 323{{else}} 324// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise. 325func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { 326 if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { 327 return {{.ZeroValue}} 328 } 329 return *{{.ReceiverVar}}.{{.FieldName}} 330} 331{{end}} 332{{end}} 333` 334