// Protocol Buffers for Go with Gadgets
//
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// http://github.com/gogo/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

28 29package generator 30 31import ( 32 "bytes" 33 "go/parser" 34 "go/printer" 35 "go/token" 36 "path" 37 "strings" 38 39 "github.com/gogo/protobuf/gogoproto" 40 "github.com/gogo/protobuf/proto" 41 descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" 42 plugin "github.com/gogo/protobuf/protoc-gen-gogo/plugin" 43) 44 45func (d *FileDescriptor) Messages() []*Descriptor { 46 return d.desc 47} 48 49func (d *FileDescriptor) Enums() []*EnumDescriptor { 50 return d.enum 51} 52 53func (d *Descriptor) IsGroup() bool { 54 return d.group 55} 56 57func (g *Generator) IsGroup(field *descriptor.FieldDescriptorProto) bool { 58 if d, ok := g.typeNameToObject[field.GetTypeName()].(*Descriptor); ok { 59 return d.IsGroup() 60 } 61 return false 62} 63 64func (g *Generator) TypeNameByObject(typeName string) Object { 65 o, ok := g.typeNameToObject[typeName] 66 if !ok { 67 g.Fail("can't find object with type", typeName) 68 } 69 return o 70} 71 72func (g *Generator) OneOfTypeName(message *Descriptor, field *descriptor.FieldDescriptorProto) string { 73 typeName := message.TypeName() 74 ccTypeName := CamelCaseSlice(typeName) 75 fieldName := g.GetOneOfFieldName(message, field) 76 tname := ccTypeName + "_" + fieldName 77 // It is possible for this to collide with a message or enum 78 // nested in this message. Check for collisions. 
79 ok := true 80 for _, desc := range message.nested { 81 if strings.Join(desc.TypeName(), "_") == tname { 82 ok = false 83 break 84 } 85 } 86 for _, enum := range message.enums { 87 if strings.Join(enum.TypeName(), "_") == tname { 88 ok = false 89 break 90 } 91 } 92 if !ok { 93 tname += "_" 94 } 95 return tname 96} 97 98type PluginImports interface { 99 NewImport(pkg string) Single 100 GenerateImports(file *FileDescriptor) 101} 102 103type pluginImports struct { 104 generator *Generator 105 singles []Single 106} 107 108func NewPluginImports(generator *Generator) *pluginImports { 109 return &pluginImports{generator, make([]Single, 0)} 110} 111 112func (this *pluginImports) NewImport(pkg string) Single { 113 imp := newImportedPackage(this.generator.ImportPrefix, pkg) 114 this.singles = append(this.singles, imp) 115 return imp 116} 117 118func (this *pluginImports) GenerateImports(file *FileDescriptor) { 119 for _, s := range this.singles { 120 if s.IsUsed() { 121 this.generator.PrintImport(GoPackageName(s.Name()), GoImportPath(s.Location())) 122 } 123 } 124} 125 126type Single interface { 127 Use() string 128 IsUsed() bool 129 Name() string 130 Location() string 131} 132 133type importedPackage struct { 134 used bool 135 pkg string 136 name string 137 importPrefix string 138} 139 140func newImportedPackage(importPrefix string, pkg string) *importedPackage { 141 return &importedPackage{ 142 pkg: pkg, 143 importPrefix: importPrefix, 144 } 145} 146 147func (this *importedPackage) Use() string { 148 if !this.used { 149 this.name = string(cleanPackageName(this.pkg)) 150 this.used = true 151 } 152 return this.name 153} 154 155func (this *importedPackage) IsUsed() bool { 156 return this.used 157} 158 159func (this *importedPackage) Name() string { 160 return this.name 161} 162 163func (this *importedPackage) Location() string { 164 return this.importPrefix + this.pkg 165} 166 167func (g *Generator) GetFieldName(message *Descriptor, field *descriptor.FieldDescriptorProto) 
string { 168 goTyp, _ := g.GoType(message, field) 169 fieldname := CamelCase(*field.Name) 170 if gogoproto.IsCustomName(field) { 171 fieldname = gogoproto.GetCustomName(field) 172 } 173 if gogoproto.IsEmbed(field) { 174 fieldname = EmbedFieldName(goTyp) 175 } 176 if field.OneofIndex != nil { 177 fieldname = message.OneofDecl[int(*field.OneofIndex)].GetName() 178 fieldname = CamelCase(fieldname) 179 } 180 for _, f := range methodNames { 181 if f == fieldname { 182 return fieldname + "_" 183 } 184 } 185 if !gogoproto.IsProtoSizer(message.file.FileDescriptorProto, message.DescriptorProto) { 186 if fieldname == "Size" { 187 return fieldname + "_" 188 } 189 } 190 return fieldname 191} 192 193func (g *Generator) GetOneOfFieldName(message *Descriptor, field *descriptor.FieldDescriptorProto) string { 194 goTyp, _ := g.GoType(message, field) 195 fieldname := CamelCase(*field.Name) 196 if gogoproto.IsCustomName(field) { 197 fieldname = gogoproto.GetCustomName(field) 198 } 199 if gogoproto.IsEmbed(field) { 200 fieldname = EmbedFieldName(goTyp) 201 } 202 for _, f := range methodNames { 203 if f == fieldname { 204 return fieldname + "_" 205 } 206 } 207 if !gogoproto.IsProtoSizer(message.file.FileDescriptorProto, message.DescriptorProto) { 208 if fieldname == "Size" { 209 return fieldname + "_" 210 } 211 } 212 return fieldname 213} 214 215func (g *Generator) IsMap(field *descriptor.FieldDescriptorProto) bool { 216 if !field.IsMessage() { 217 return false 218 } 219 byName := g.ObjectNamed(field.GetTypeName()) 220 desc, ok := byName.(*Descriptor) 221 if byName == nil || !ok || !desc.GetOptions().GetMapEntry() { 222 return false 223 } 224 return true 225} 226 227func (g *Generator) GetMapKeyField(field, keyField *descriptor.FieldDescriptorProto) *descriptor.FieldDescriptorProto { 228 if !gogoproto.IsCastKey(field) { 229 return keyField 230 } 231 keyField = proto.Clone(keyField).(*descriptor.FieldDescriptorProto) 232 if keyField.Options == nil { 233 keyField.Options = 
&descriptor.FieldOptions{} 234 } 235 keyType := gogoproto.GetCastKey(field) 236 if err := proto.SetExtension(keyField.Options, gogoproto.E_Casttype, &keyType); err != nil { 237 g.Fail(err.Error()) 238 } 239 return keyField 240} 241 242func (g *Generator) GetMapValueField(field, valField *descriptor.FieldDescriptorProto) *descriptor.FieldDescriptorProto { 243 if gogoproto.IsCustomType(field) && gogoproto.IsCastValue(field) { 244 g.Fail("cannot have a customtype and casttype: ", field.String()) 245 } 246 valField = proto.Clone(valField).(*descriptor.FieldDescriptorProto) 247 if valField.Options == nil { 248 valField.Options = &descriptor.FieldOptions{} 249 } 250 251 stdtime := gogoproto.IsStdTime(field) 252 if stdtime { 253 if err := proto.SetExtension(valField.Options, gogoproto.E_Stdtime, &stdtime); err != nil { 254 g.Fail(err.Error()) 255 } 256 } 257 258 stddur := gogoproto.IsStdDuration(field) 259 if stddur { 260 if err := proto.SetExtension(valField.Options, gogoproto.E_Stdduration, &stddur); err != nil { 261 g.Fail(err.Error()) 262 } 263 } 264 265 wktptr := gogoproto.IsWktPtr(field) 266 if wktptr { 267 if err := proto.SetExtension(valField.Options, gogoproto.E_Wktpointer, &wktptr); err != nil { 268 g.Fail(err.Error()) 269 } 270 } 271 272 if valType := gogoproto.GetCastValue(field); len(valType) > 0 { 273 if err := proto.SetExtension(valField.Options, gogoproto.E_Casttype, &valType); err != nil { 274 g.Fail(err.Error()) 275 } 276 } 277 if valType := gogoproto.GetCustomType(field); len(valType) > 0 { 278 if err := proto.SetExtension(valField.Options, gogoproto.E_Customtype, &valType); err != nil { 279 g.Fail(err.Error()) 280 } 281 } 282 283 nullable := gogoproto.IsNullable(field) 284 if err := proto.SetExtension(valField.Options, gogoproto.E_Nullable, &nullable); err != nil { 285 g.Fail(err.Error()) 286 } 287 return valField 288} 289 290// GoMapValueTypes returns the map value Go type and the alias map value Go type (for casting), taking into 291// account 
whether the map is nullable or the value is a message. 292func GoMapValueTypes(mapField, valueField *descriptor.FieldDescriptorProto, goValueType, goValueAliasType string) (nullable bool, outGoType string, outGoAliasType string) { 293 nullable = gogoproto.IsNullable(mapField) && (valueField.IsMessage() || gogoproto.IsCustomType(mapField)) 294 if nullable { 295 // ensure the non-aliased Go value type is a pointer for consistency 296 if strings.HasPrefix(goValueType, "*") { 297 outGoType = goValueType 298 } else { 299 outGoType = "*" + goValueType 300 } 301 outGoAliasType = goValueAliasType 302 } else { 303 outGoType = strings.Replace(goValueType, "*", "", 1) 304 outGoAliasType = strings.Replace(goValueAliasType, "*", "", 1) 305 } 306 return 307} 308 309func GoTypeToName(goTyp string) string { 310 return strings.Replace(strings.Replace(goTyp, "*", "", -1), "[]", "", -1) 311} 312 313func EmbedFieldName(goTyp string) string { 314 goTyp = GoTypeToName(goTyp) 315 goTyps := strings.Split(goTyp, ".") 316 if len(goTyps) == 1 { 317 return goTyp 318 } 319 if len(goTyps) == 2 { 320 return goTyps[1] 321 } 322 panic("unreachable") 323} 324 325func (g *Generator) GeneratePlugin(p Plugin) { 326 plugins = []Plugin{p} 327 p.Init(g) 328 // Generate the output. The generator runs for every file, even the files 329 // that we don't generate output for, so that we can collate the full list 330 // of exported symbols to support public imports. 
331 genFileMap := make(map[*FileDescriptor]bool, len(g.genFiles)) 332 for _, file := range g.genFiles { 333 genFileMap[file] = true 334 } 335 for _, file := range g.allFiles { 336 g.Reset() 337 g.writeOutput = genFileMap[file] 338 g.generatePlugin(file, p) 339 if !g.writeOutput { 340 continue 341 } 342 g.Response.File = append(g.Response.File, &plugin.CodeGeneratorResponse_File{ 343 Name: proto.String(file.goFileName(g.pathType)), 344 Content: proto.String(g.String()), 345 }) 346 } 347} 348 349func (g *Generator) SetFile(filename string) { 350 g.file = g.fileByName(filename) 351} 352 353func (g *Generator) generatePlugin(file *FileDescriptor, p Plugin) { 354 g.writtenImports = make(map[string]bool) 355 g.usedPackages = make(map[GoImportPath]bool) 356 g.packageNames = make(map[GoImportPath]GoPackageName) 357 g.usedPackageNames = make(map[GoPackageName]bool) 358 g.addedImports = make(map[GoImportPath]bool) 359 g.file = file 360 361 // Run the plugins before the imports so we know which imports are necessary. 362 p.Generate(file) 363 364 // Generate header and imports last, though they appear first in the output. 365 rem := g.Buffer 366 g.Buffer = new(bytes.Buffer) 367 g.generateHeader() 368 // p.GenerateImports(g.file) 369 g.generateImports() 370 if !g.writeOutput { 371 return 372 } 373 g.Write(rem.Bytes()) 374 375 // Reformat generated code. 
376 contents := string(g.Buffer.Bytes()) 377 fset := token.NewFileSet() 378 ast, err := parser.ParseFile(fset, "", g, parser.ParseComments) 379 if err != nil { 380 g.Fail("bad Go source code was generated:", contents, err.Error()) 381 return 382 } 383 g.Reset() 384 err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(g, fset, ast) 385 if err != nil { 386 g.Fail("generated Go source code could not be reformatted:", err.Error()) 387 } 388} 389 390func GetCustomType(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) { 391 return getCustomType(field) 392} 393 394func getCustomType(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) { 395 if field.Options != nil { 396 var v interface{} 397 v, err = proto.GetExtension(field.Options, gogoproto.E_Customtype) 398 if err == nil && v.(*string) != nil { 399 ctype := *(v.(*string)) 400 packageName, typ = splitCPackageType(ctype) 401 return packageName, typ, nil 402 } 403 } 404 return "", "", err 405} 406 407func splitCPackageType(ctype string) (packageName string, typ string) { 408 ss := strings.Split(ctype, ".") 409 if len(ss) == 1 { 410 return "", ctype 411 } 412 packageName = strings.Join(ss[0:len(ss)-1], ".") 413 typeName := ss[len(ss)-1] 414 importStr := strings.Map(badToUnderscore, packageName) 415 typ = importStr + "." 
+ typeName 416 return packageName, typ 417} 418 419func getCastType(field *descriptor.FieldDescriptorProto) (packageName string, typ string, err error) { 420 if field.Options != nil { 421 var v interface{} 422 v, err = proto.GetExtension(field.Options, gogoproto.E_Casttype) 423 if err == nil && v.(*string) != nil { 424 ctype := *(v.(*string)) 425 packageName, typ = splitCPackageType(ctype) 426 return packageName, typ, nil 427 } 428 } 429 return "", "", err 430} 431 432func FileName(file *FileDescriptor) string { 433 fname := path.Base(file.FileDescriptorProto.GetName()) 434 fname = strings.Replace(fname, ".proto", "", -1) 435 fname = strings.Replace(fname, "-", "_", -1) 436 fname = strings.Replace(fname, ".", "_", -1) 437 return CamelCase(fname) 438} 439 440func (g *Generator) AllFiles() *descriptor.FileDescriptorSet { 441 set := &descriptor.FileDescriptorSet{} 442 set.File = make([]*descriptor.FileDescriptorProto, len(g.allFiles)) 443 for i := range g.allFiles { 444 set.File[i] = g.allFiles[i].FileDescriptorProto 445 } 446 return set 447} 448 449func (d *Descriptor) Path() string { 450 return d.path 451} 452 453func (g *Generator) useTypes() string { 454 pkg := strings.Map(badToUnderscore, "github.com/gogo/protobuf/types") 455 g.customImports = append(g.customImports, "github.com/gogo/protobuf/types") 456 return pkg 457} 458 459func (d *FileDescriptor) GoPackageName() string { 460 return string(d.packageName) 461} 462