// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package protogen provides support for writing protoc plugins.
//
// Plugins for protoc, the Protocol Buffer compiler,
// are programs which read a CodeGeneratorRequest message from standard input
// and write a CodeGeneratorResponse message to standard output.
// This package provides support for writing plugins which generate Go code.
package protogen

import (
	"bufio"
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/printer"
	"go/token"
	"go/types"
	"io/ioutil"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strconv"
	"strings"

	"google.golang.org/protobuf/encoding/prototext"
	"google.golang.org/protobuf/internal/genid"
	"google.golang.org/protobuf/internal/strs"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protodesc"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"

	"google.golang.org/protobuf/types/descriptorpb"
	"google.golang.org/protobuf/types/pluginpb"
)

// goPackageDocURL is the documentation link embedded in the error messages
// reported when a file's Go import path is missing or malformed.
const goPackageDocURL = "https://developers.google.com/protocol-buffers/docs/reference/go-generated#package"

// Run executes a function as a protoc plugin.
//
// It reads a CodeGeneratorRequest message from os.Stdin, invokes the plugin
// function, and writes a CodeGeneratorResponse message to os.Stdout.
//
// An error returned by the plugin function is not fatal: it is reported back
// to protoc through the error field of the CodeGeneratorResponse.
//
// If a failure occurs while reading or writing, Run prints an error to
// os.Stderr and calls os.Exit(1). The same happens if the program is given
// any command-line arguments (it is meant to be run by protoc) or if the
// request cannot be turned into a Plugin.
51func (opts Options) Run(f func(*Plugin) error) { 52 if err := run(opts, f); err != nil { 53 fmt.Fprintf(os.Stderr, "%s: %v\n", filepath.Base(os.Args[0]), err) 54 os.Exit(1) 55 } 56} 57 58func run(opts Options, f func(*Plugin) error) error { 59 if len(os.Args) > 1 { 60 return fmt.Errorf("unknown argument %q (this program should be run by protoc, not directly)", os.Args[1]) 61 } 62 in, err := ioutil.ReadAll(os.Stdin) 63 if err != nil { 64 return err 65 } 66 req := &pluginpb.CodeGeneratorRequest{} 67 if err := proto.Unmarshal(in, req); err != nil { 68 return err 69 } 70 gen, err := opts.New(req) 71 if err != nil { 72 return err 73 } 74 if err := f(gen); err != nil { 75 // Errors from the plugin function are reported by setting the 76 // error field in the CodeGeneratorResponse. 77 // 78 // In contrast, errors that indicate a problem in protoc 79 // itself (unparsable input, I/O errors, etc.) are reported 80 // to stderr. 81 gen.Error(err) 82 } 83 resp := gen.Response() 84 out, err := proto.Marshal(resp) 85 if err != nil { 86 return err 87 } 88 if _, err := os.Stdout.Write(out); err != nil { 89 return err 90 } 91 return nil 92} 93 94// A Plugin is a protoc plugin invocation. 95type Plugin struct { 96 // Request is the CodeGeneratorRequest provided by protoc. 97 Request *pluginpb.CodeGeneratorRequest 98 99 // Files is the set of files to generate and everything they import. 100 // Files appear in topological order, so each file appears before any 101 // file that imports it. 102 Files []*File 103 FilesByPath map[string]*File 104 105 // SupportedFeatures is the set of protobuf language features supported by 106 // this generator plugin. See the documentation for 107 // google.protobuf.CodeGeneratorResponse.supported_features for details. 
108 SupportedFeatures uint64 109 110 fileReg *protoregistry.Files 111 enumsByName map[protoreflect.FullName]*Enum 112 messagesByName map[protoreflect.FullName]*Message 113 annotateCode bool 114 pathType pathType 115 module string 116 genFiles []*GeneratedFile 117 opts Options 118 err error 119} 120 121type Options struct { 122 // If ParamFunc is non-nil, it will be called with each unknown 123 // generator parameter. 124 // 125 // Plugins for protoc can accept parameters from the command line, 126 // passed in the --<lang>_out protoc, separated from the output 127 // directory with a colon; e.g., 128 // 129 // --go_out=<param1>=<value1>,<param2>=<value2>:<output_directory> 130 // 131 // Parameters passed in this fashion as a comma-separated list of 132 // key=value pairs will be passed to the ParamFunc. 133 // 134 // The (flag.FlagSet).Set method matches this function signature, 135 // so parameters can be converted into flags as in the following: 136 // 137 // var flags flag.FlagSet 138 // value := flags.Bool("param", false, "") 139 // opts := &protogen.Options{ 140 // ParamFunc: flags.Set, 141 // } 142 // protogen.Run(opts, func(p *protogen.Plugin) error { 143 // if *value { ... } 144 // }) 145 ParamFunc func(name, value string) error 146 147 // ImportRewriteFunc is called with the import path of each package 148 // imported by a generated file. It returns the import path to use 149 // for this package. 150 ImportRewriteFunc func(GoImportPath) GoImportPath 151} 152 153// New returns a new Plugin. 
func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) {
	gen := &Plugin{
		Request:        req,
		FilesByPath:    make(map[string]*File),
		fileReg:        new(protoregistry.Files),
		enumsByName:    make(map[protoreflect.FullName]*Enum),
		messagesByName: make(map[protoreflect.FullName]*Message),
		opts:           opts,
	}

	// Parse the comma-separated generator parameters. Each entry is either a
	// bare key or a key=value pair; everything after the first '=' is the
	// value.
	packageNames := make(map[string]GoPackageName) // filename -> package name
	importPaths := make(map[string]GoImportPath)   // filename -> import path
	for _, param := range strings.Split(req.GetParameter(), ",") {
		var value string
		if i := strings.Index(param, "="); i >= 0 {
			value = param[i+1:]
			param = param[0:i]
		}
		switch param {
		case "":
			// Ignore.
		case "module":
			gen.module = value
		case "paths":
			switch value {
			case "import":
				gen.pathType = pathTypeImport
			case "source_relative":
				gen.pathType = pathTypeSourceRelative
			default:
				return nil, fmt.Errorf(`unknown path type %q: want "import" or "source_relative"`, value)
			}
		case "annotate_code":
			switch value {
			case "true", "":
				gen.annotateCode = true
			case "false":
				// annotateCode is already false (its zero value).
			default:
				return nil, fmt.Errorf(`bad value for parameter %q: want "true" or "false"`, param)
			}
		default:
			// M<filename>=<import_path>[;<package_name>] sets the Go import
			// path, and optionally the package name, for filename. param is
			// non-empty here because the empty case is handled above.
			if param[0] == 'M' {
				impPath, pkgName := splitImportPathAndPackageName(value)
				if pkgName != "" {
					packageNames[param[1:]] = pkgName
				}
				if impPath != "" {
					importPaths[param[1:]] = impPath
				}
				continue
			}
			// Any other parameter goes to ParamFunc if one is set; otherwise
			// it is ignored.
			if opts.ParamFunc != nil {
				if err := opts.ParamFunc(param, value); err != nil {
					return nil, err
				}
			}
		}
	}
	// When the module= option is provided, we strip the module name
	// prefix from generated files. This only makes sense if generated
	// filenames are based on the import path.
	if gen.module != "" && gen.pathType == pathTypeSourceRelative {
		return nil, fmt.Errorf("cannot use module= with paths=source_relative")
	}

	// Figure out the import path and package name for each file.
	//
	// The rules here are complicated and have grown organically over time.
	// Interactions between different ways of specifying package information
	// may be surprising.
	//
	// The recommended approach is to include a go_package option in every
	// .proto source file specifying the full import path of the Go package
	// associated with this file.
	//
	//     option go_package = "google.golang.org/protobuf/types/known/anypb";
	//
	// Alternatively, build systems which want to exert full control over
	// import paths may specify M<filename>=<import_path> flags.
	for _, fdesc := range gen.Request.ProtoFile {
		// The "M" command-line flags take precedence over
		// the "go_package" option in the .proto source file.
		filename := fdesc.GetName()
		impPath, pkgName := splitImportPathAndPackageName(fdesc.GetOptions().GetGoPackage())
		if importPaths[filename] == "" && impPath != "" {
			importPaths[filename] = impPath
		}
		if packageNames[filename] == "" && pkgName != "" {
			packageNames[filename] = pkgName
		}
		switch {
		case importPaths[filename] == "":
			// The import path must be specified one way or another.
			return nil, fmt.Errorf(
				"unable to determine Go import path for %q\n\n"+
					"Please specify either:\n"+
					"\t• a \"go_package\" option in the .proto source file, or\n"+
					"\t• a \"M\" argument on the command line.\n\n"+
					"See %v for more information.\n",
				fdesc.GetName(), goPackageDocURL)
		case !strings.Contains(string(importPaths[filename]), "/"):
			// Check that import paths contain at least one slash to avoid a
			// common mistake where import path is confused with package name.
			return nil, fmt.Errorf(
				"invalid Go import path %q for %q\n\n"+
					"The import path must contain at least one forward slash ('/') character.\n\n"+
					"See %v for more information.\n",
				string(importPaths[filename]), fdesc.GetName(), goPackageDocURL)
		case packageNames[filename] == "":
			// If the package name is not explicitly specified,
			// then derive a reasonable package name from the import path.
			//
			// NOTE: The package name is derived first from the import path in
			// the "go_package" option (if present) before trying the "M" flag.
			// The inverted order for this is because the primary use of the "M"
			// flag is by build systems that have full control over the
			// import paths of all packages, where it is generally expected that
			// the Go package name still be identical for the Go toolchain and
			// for custom build systems like Bazel.
			if impPath == "" {
				impPath = importPaths[filename]
			}
			packageNames[filename] = cleanPackageName(path.Base(string(impPath)))
		}
	}

	// Consistency check: Every file with the same Go import path should have
	// the same Go package name.
	//
	// NOTE(review): importPaths and packageFiles are maps, so the order of
	// filenames within each group is unspecified; when names disagree, which
	// pair of files appears in the error message may vary between runs.
	packageFiles := make(map[GoImportPath][]string)
	for filename, importPath := range importPaths {
		if _, ok := packageNames[filename]; !ok {
			// Skip files mentioned in a M<file>=<import_path> parameter
			// but which do not appear in the CodeGeneratorRequest.
			//
			// NOTE(review): an M flag that also names a package
			// (M<file>=<path>;<name>) populates packageNames, so such a
			// file is not skipped here even when it is absent from the
			// request. Confirm whether that is intended.
			continue
		}
		packageFiles[importPath] = append(packageFiles[importPath], filename)
	}
	for importPath, filenames := range packageFiles {
		for i := 1; i < len(filenames); i++ {
			if a, b := packageNames[filenames[0]], packageNames[filenames[i]]; a != b {
				return nil, fmt.Errorf("Go package %v has inconsistent names %v (%v) and %v (%v)",
					importPath, a, filenames[0], b, filenames[i])
			}
		}
	}

	// Build a File for every descriptor in the request, in request order,
	// rejecting duplicate file names.
	for _, fdesc := range gen.Request.ProtoFile {
		filename := fdesc.GetName()
		if gen.FilesByPath[filename] != nil {
			return nil, fmt.Errorf("duplicate file name: %q", filename)
		}
		f, err := newFile(gen, fdesc, packageNames[filename], importPaths[filename])
		if err != nil {
			return nil, err
		}
		gen.Files = append(gen.Files, f)
		gen.FilesByPath[filename] = f
	}
	// Mark the files protoc asked us to generate; each one must have been
	// supplied as a descriptor above.
	for _, filename := range gen.Request.FileToGenerate {
		f, ok := gen.FilesByPath[filename]
		if !ok {
			return nil, fmt.Errorf("no descriptor for generated file: %v", filename)
		}
		f.Generate = true
	}
	return gen, nil
}

// Error records an error in code generation. The generator will report the
// error back to protoc and will not produce output.
//
// Only the first recorded error is kept; later calls have no effect.
func (gen *Plugin) Error(err error) {
	if gen.err == nil {
		gen.err = err
	}
}

// Response returns the generator output.
//
// If an error was recorded with Error, the response carries only that error.
// Otherwise it contains one entry per generated file that has not been
// skipped, plus a ".meta" entry per Go file when annotate_code is enabled.
// A failure while rendering those files produces a response that carries
// only that error.
331func (gen *Plugin) Response() *pluginpb.CodeGeneratorResponse { 332 resp := &pluginpb.CodeGeneratorResponse{} 333 if gen.err != nil { 334 resp.Error = proto.String(gen.err.Error()) 335 return resp 336 } 337 for _, g := range gen.genFiles { 338 if g.skip { 339 continue 340 } 341 content, err := g.Content() 342 if err != nil { 343 return &pluginpb.CodeGeneratorResponse{ 344 Error: proto.String(err.Error()), 345 } 346 } 347 filename := g.filename 348 if gen.module != "" { 349 trim := gen.module + "/" 350 if !strings.HasPrefix(filename, trim) { 351 return &pluginpb.CodeGeneratorResponse{ 352 Error: proto.String(fmt.Sprintf("%v: generated file does not match prefix %q", filename, gen.module)), 353 } 354 } 355 filename = strings.TrimPrefix(filename, trim) 356 } 357 resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{ 358 Name: proto.String(filename), 359 Content: proto.String(string(content)), 360 }) 361 if gen.annotateCode && strings.HasSuffix(g.filename, ".go") { 362 meta, err := g.metaFile(content) 363 if err != nil { 364 return &pluginpb.CodeGeneratorResponse{ 365 Error: proto.String(err.Error()), 366 } 367 } 368 resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{ 369 Name: proto.String(filename + ".meta"), 370 Content: proto.String(meta), 371 }) 372 } 373 } 374 if gen.SupportedFeatures > 0 { 375 resp.SupportedFeatures = proto.Uint64(gen.SupportedFeatures) 376 } 377 return resp 378} 379 380// A File describes a .proto source file. 
381type File struct { 382 Desc protoreflect.FileDescriptor 383 Proto *descriptorpb.FileDescriptorProto 384 385 GoDescriptorIdent GoIdent // name of Go variable for the file descriptor 386 GoPackageName GoPackageName // name of this file's Go package 387 GoImportPath GoImportPath // import path of this file's Go package 388 389 Enums []*Enum // top-level enum declarations 390 Messages []*Message // top-level message declarations 391 Extensions []*Extension // top-level extension declarations 392 Services []*Service // top-level service declarations 393 394 Generate bool // true if we should generate code for this file 395 396 // GeneratedFilenamePrefix is used to construct filenames for generated 397 // files associated with this source file. 398 // 399 // For example, the source file "dir/foo.proto" might have a filename prefix 400 // of "dir/foo". Appending ".pb.go" produces an output file of "dir/foo.pb.go". 401 GeneratedFilenamePrefix string 402 403 location Location 404} 405 406func newFile(gen *Plugin, p *descriptorpb.FileDescriptorProto, packageName GoPackageName, importPath GoImportPath) (*File, error) { 407 desc, err := protodesc.NewFile(p, gen.fileReg) 408 if err != nil { 409 return nil, fmt.Errorf("invalid FileDescriptorProto %q: %v", p.GetName(), err) 410 } 411 if err := gen.fileReg.RegisterFile(desc); err != nil { 412 return nil, fmt.Errorf("cannot register descriptor %q: %v", p.GetName(), err) 413 } 414 f := &File{ 415 Desc: desc, 416 Proto: p, 417 GoPackageName: packageName, 418 GoImportPath: importPath, 419 location: Location{SourceFile: desc.Path()}, 420 } 421 422 // Determine the prefix for generated Go files. 423 prefix := p.GetName() 424 if ext := path.Ext(prefix); ext == ".proto" || ext == ".protodevel" { 425 prefix = prefix[:len(prefix)-len(ext)] 426 } 427 switch gen.pathType { 428 case pathTypeImport: 429 // If paths=import, the output filename is derived from the Go import path. 
430 prefix = path.Join(string(f.GoImportPath), path.Base(prefix)) 431 case pathTypeSourceRelative: 432 // If paths=source_relative, the output filename is derived from 433 // the input filename. 434 } 435 f.GoDescriptorIdent = GoIdent{ 436 GoName: "File_" + strs.GoSanitized(p.GetName()), 437 GoImportPath: f.GoImportPath, 438 } 439 f.GeneratedFilenamePrefix = prefix 440 441 for i, eds := 0, desc.Enums(); i < eds.Len(); i++ { 442 f.Enums = append(f.Enums, newEnum(gen, f, nil, eds.Get(i))) 443 } 444 for i, mds := 0, desc.Messages(); i < mds.Len(); i++ { 445 f.Messages = append(f.Messages, newMessage(gen, f, nil, mds.Get(i))) 446 } 447 for i, xds := 0, desc.Extensions(); i < xds.Len(); i++ { 448 f.Extensions = append(f.Extensions, newField(gen, f, nil, xds.Get(i))) 449 } 450 for i, sds := 0, desc.Services(); i < sds.Len(); i++ { 451 f.Services = append(f.Services, newService(gen, f, sds.Get(i))) 452 } 453 for _, message := range f.Messages { 454 if err := message.resolveDependencies(gen); err != nil { 455 return nil, err 456 } 457 } 458 for _, extension := range f.Extensions { 459 if err := extension.resolveDependencies(gen); err != nil { 460 return nil, err 461 } 462 } 463 for _, service := range f.Services { 464 for _, method := range service.Methods { 465 if err := method.resolveDependencies(gen); err != nil { 466 return nil, err 467 } 468 } 469 } 470 return f, nil 471} 472 473// splitImportPathAndPackageName splits off the optional Go package name 474// from the Go import path when seperated by a ';' delimiter. 475func splitImportPathAndPackageName(s string) (GoImportPath, GoPackageName) { 476 if i := strings.Index(s, ";"); i >= 0 { 477 return GoImportPath(s[:i]), GoPackageName(s[i+1:]) 478 } 479 return GoImportPath(s), "" 480} 481 482// An Enum describes an enum. 
483type Enum struct { 484 Desc protoreflect.EnumDescriptor 485 486 GoIdent GoIdent // name of the generated Go type 487 488 Values []*EnumValue // enum value declarations 489 490 Location Location // location of this enum 491 Comments CommentSet // comments associated with this enum 492} 493 494func newEnum(gen *Plugin, f *File, parent *Message, desc protoreflect.EnumDescriptor) *Enum { 495 var loc Location 496 if parent != nil { 497 loc = parent.Location.appendPath(genid.DescriptorProto_EnumType_field_number, desc.Index()) 498 } else { 499 loc = f.location.appendPath(genid.FileDescriptorProto_EnumType_field_number, desc.Index()) 500 } 501 enum := &Enum{ 502 Desc: desc, 503 GoIdent: newGoIdent(f, desc), 504 Location: loc, 505 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)), 506 } 507 gen.enumsByName[desc.FullName()] = enum 508 for i, vds := 0, enum.Desc.Values(); i < vds.Len(); i++ { 509 enum.Values = append(enum.Values, newEnumValue(gen, f, parent, enum, vds.Get(i))) 510 } 511 return enum 512} 513 514// An EnumValue describes an enum value. 515type EnumValue struct { 516 Desc protoreflect.EnumValueDescriptor 517 518 GoIdent GoIdent // name of the generated Go declaration 519 520 Parent *Enum // enum in which this value is declared 521 522 Location Location // location of this enum value 523 Comments CommentSet // comments associated with this enum value 524} 525 526func newEnumValue(gen *Plugin, f *File, message *Message, enum *Enum, desc protoreflect.EnumValueDescriptor) *EnumValue { 527 // A top-level enum value's name is: EnumName_ValueName 528 // An enum value contained in a message is: MessageName_ValueName 529 // 530 // For historical reasons, enum value names are not camel-cased. 
531 parentIdent := enum.GoIdent 532 if message != nil { 533 parentIdent = message.GoIdent 534 } 535 name := parentIdent.GoName + "_" + string(desc.Name()) 536 loc := enum.Location.appendPath(genid.EnumDescriptorProto_Value_field_number, desc.Index()) 537 return &EnumValue{ 538 Desc: desc, 539 GoIdent: f.GoImportPath.Ident(name), 540 Parent: enum, 541 Location: loc, 542 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)), 543 } 544} 545 546// A Message describes a message. 547type Message struct { 548 Desc protoreflect.MessageDescriptor 549 550 GoIdent GoIdent // name of the generated Go type 551 552 Fields []*Field // message field declarations 553 Oneofs []*Oneof // message oneof declarations 554 555 Enums []*Enum // nested enum declarations 556 Messages []*Message // nested message declarations 557 Extensions []*Extension // nested extension declarations 558 559 Location Location // location of this message 560 Comments CommentSet // comments associated with this message 561} 562 563func newMessage(gen *Plugin, f *File, parent *Message, desc protoreflect.MessageDescriptor) *Message { 564 var loc Location 565 if parent != nil { 566 loc = parent.Location.appendPath(genid.DescriptorProto_NestedType_field_number, desc.Index()) 567 } else { 568 loc = f.location.appendPath(genid.FileDescriptorProto_MessageType_field_number, desc.Index()) 569 } 570 message := &Message{ 571 Desc: desc, 572 GoIdent: newGoIdent(f, desc), 573 Location: loc, 574 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)), 575 } 576 gen.messagesByName[desc.FullName()] = message 577 for i, eds := 0, desc.Enums(); i < eds.Len(); i++ { 578 message.Enums = append(message.Enums, newEnum(gen, f, message, eds.Get(i))) 579 } 580 for i, mds := 0, desc.Messages(); i < mds.Len(); i++ { 581 message.Messages = append(message.Messages, newMessage(gen, f, message, mds.Get(i))) 582 } 583 for i, fds := 0, desc.Fields(); i < fds.Len(); i++ { 584 message.Fields = 
append(message.Fields, newField(gen, f, message, fds.Get(i))) 585 } 586 for i, ods := 0, desc.Oneofs(); i < ods.Len(); i++ { 587 message.Oneofs = append(message.Oneofs, newOneof(gen, f, message, ods.Get(i))) 588 } 589 for i, xds := 0, desc.Extensions(); i < xds.Len(); i++ { 590 message.Extensions = append(message.Extensions, newField(gen, f, message, xds.Get(i))) 591 } 592 593 // Resolve local references between fields and oneofs. 594 for _, field := range message.Fields { 595 if od := field.Desc.ContainingOneof(); od != nil { 596 oneof := message.Oneofs[od.Index()] 597 field.Oneof = oneof 598 oneof.Fields = append(oneof.Fields, field) 599 } 600 } 601 602 // Field name conflict resolution. 603 // 604 // We assume well-known method names that may be attached to a generated 605 // message type, as well as a 'Get*' method for each field. For each 606 // field in turn, we add _s to its name until there are no conflicts. 607 // 608 // Any change to the following set of method names is a potential 609 // incompatible API change because it may change generated field names. 610 // 611 // TODO: If we ever support a 'go_name' option to set the Go name of a 612 // field, we should consider dropping this entirely. The conflict 613 // resolution algorithm is subtle and surprising (changing the order 614 // in which fields appear in the .proto source file can change the 615 // names of fields in generated code), and does not adapt well to 616 // adding new per-field methods such as setters. 
617 usedNames := map[string]bool{ 618 "Reset": true, 619 "String": true, 620 "ProtoMessage": true, 621 "Marshal": true, 622 "Unmarshal": true, 623 "ExtensionRangeArray": true, 624 "ExtensionMap": true, 625 "Descriptor": true, 626 } 627 makeNameUnique := func(name string, hasGetter bool) string { 628 for usedNames[name] || (hasGetter && usedNames["Get"+name]) { 629 name += "_" 630 } 631 usedNames[name] = true 632 usedNames["Get"+name] = hasGetter 633 return name 634 } 635 for _, field := range message.Fields { 636 field.GoName = makeNameUnique(field.GoName, true) 637 field.GoIdent.GoName = message.GoIdent.GoName + "_" + field.GoName 638 if field.Oneof != nil && field.Oneof.Fields[0] == field { 639 // Make the name for a oneof unique as well. For historical reasons, 640 // this assumes that a getter method is not generated for oneofs. 641 // This is incorrect, but fixing it breaks existing code. 642 field.Oneof.GoName = makeNameUnique(field.Oneof.GoName, false) 643 field.Oneof.GoIdent.GoName = message.GoIdent.GoName + "_" + field.Oneof.GoName 644 } 645 } 646 647 // Oneof field name conflict resolution. 648 // 649 // This conflict resolution is incomplete as it does not consider collisions 650 // with other oneof field types, but fixing it breaks existing code. 
651 for _, field := range message.Fields { 652 if field.Oneof != nil { 653 Loop: 654 for { 655 for _, nestedMessage := range message.Messages { 656 if nestedMessage.GoIdent == field.GoIdent { 657 field.GoIdent.GoName += "_" 658 continue Loop 659 } 660 } 661 for _, nestedEnum := range message.Enums { 662 if nestedEnum.GoIdent == field.GoIdent { 663 field.GoIdent.GoName += "_" 664 continue Loop 665 } 666 } 667 break Loop 668 } 669 } 670 } 671 672 return message 673} 674 675func (message *Message) resolveDependencies(gen *Plugin) error { 676 for _, field := range message.Fields { 677 if err := field.resolveDependencies(gen); err != nil { 678 return err 679 } 680 } 681 for _, message := range message.Messages { 682 if err := message.resolveDependencies(gen); err != nil { 683 return err 684 } 685 } 686 for _, extension := range message.Extensions { 687 if err := extension.resolveDependencies(gen); err != nil { 688 return err 689 } 690 } 691 return nil 692} 693 694// A Field describes a message field. 695type Field struct { 696 Desc protoreflect.FieldDescriptor 697 698 // GoName is the base name of this field's Go field and methods. 699 // For code generated by protoc-gen-go, this means a field named 700 // '{{GoName}}' and a getter method named 'Get{{GoName}}'. 701 GoName string // e.g., "FieldName" 702 703 // GoIdent is the base name of a top-level declaration for this field. 704 // For code generated by protoc-gen-go, this means a wrapper type named 705 // '{{GoIdent}}' for members fields of a oneof, and a variable named 706 // 'E_{{GoIdent}}' for extension fields. 
707 GoIdent GoIdent // e.g., "MessageName_FieldName" 708 709 Parent *Message // message in which this field is declared; nil if top-level extension 710 Oneof *Oneof // containing oneof; nil if not part of a oneof 711 Extendee *Message // extended message for extension fields; nil otherwise 712 713 Enum *Enum // type for enum fields; nil otherwise 714 Message *Message // type for message or group fields; nil otherwise 715 716 Location Location // location of this field 717 Comments CommentSet // comments associated with this field 718} 719 720func newField(gen *Plugin, f *File, message *Message, desc protoreflect.FieldDescriptor) *Field { 721 var loc Location 722 switch { 723 case desc.IsExtension() && message == nil: 724 loc = f.location.appendPath(genid.FileDescriptorProto_Extension_field_number, desc.Index()) 725 case desc.IsExtension() && message != nil: 726 loc = message.Location.appendPath(genid.DescriptorProto_Extension_field_number, desc.Index()) 727 default: 728 loc = message.Location.appendPath(genid.DescriptorProto_Field_field_number, desc.Index()) 729 } 730 camelCased := strs.GoCamelCase(string(desc.Name())) 731 var parentPrefix string 732 if message != nil { 733 parentPrefix = message.GoIdent.GoName + "_" 734 } 735 field := &Field{ 736 Desc: desc, 737 GoName: camelCased, 738 GoIdent: GoIdent{ 739 GoImportPath: f.GoImportPath, 740 GoName: parentPrefix + camelCased, 741 }, 742 Parent: message, 743 Location: loc, 744 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)), 745 } 746 return field 747} 748 749func (field *Field) resolveDependencies(gen *Plugin) error { 750 desc := field.Desc 751 switch desc.Kind() { 752 case protoreflect.EnumKind: 753 name := field.Desc.Enum().FullName() 754 enum, ok := gen.enumsByName[name] 755 if !ok { 756 return fmt.Errorf("field %v: no descriptor for enum %v", desc.FullName(), name) 757 } 758 field.Enum = enum 759 case protoreflect.MessageKind, protoreflect.GroupKind: 760 name := desc.Message().FullName() 
761 message, ok := gen.messagesByName[name] 762 if !ok { 763 return fmt.Errorf("field %v: no descriptor for type %v", desc.FullName(), name) 764 } 765 field.Message = message 766 } 767 if desc.IsExtension() { 768 name := desc.ContainingMessage().FullName() 769 message, ok := gen.messagesByName[name] 770 if !ok { 771 return fmt.Errorf("field %v: no descriptor for type %v", desc.FullName(), name) 772 } 773 field.Extendee = message 774 } 775 return nil 776} 777 778// A Oneof describes a message oneof. 779type Oneof struct { 780 Desc protoreflect.OneofDescriptor 781 782 // GoName is the base name of this oneof's Go field and methods. 783 // For code generated by protoc-gen-go, this means a field named 784 // '{{GoName}}' and a getter method named 'Get{{GoName}}'. 785 GoName string // e.g., "OneofName" 786 787 // GoIdent is the base name of a top-level declaration for this oneof. 788 GoIdent GoIdent // e.g., "MessageName_OneofName" 789 790 Parent *Message // message in which this oneof is declared 791 792 Fields []*Field // fields that are part of this oneof 793 794 Location Location // location of this oneof 795 Comments CommentSet // comments associated with this oneof 796} 797 798func newOneof(gen *Plugin, f *File, message *Message, desc protoreflect.OneofDescriptor) *Oneof { 799 loc := message.Location.appendPath(genid.DescriptorProto_OneofDecl_field_number, desc.Index()) 800 camelCased := strs.GoCamelCase(string(desc.Name())) 801 parentPrefix := message.GoIdent.GoName + "_" 802 return &Oneof{ 803 Desc: desc, 804 Parent: message, 805 GoName: camelCased, 806 GoIdent: GoIdent{ 807 GoImportPath: f.GoImportPath, 808 GoName: parentPrefix + camelCased, 809 }, 810 Location: loc, 811 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)), 812 } 813} 814 815// Extension is an alias of Field for documentation. 816type Extension = Field 817 818// A Service describes a service. 
819type Service struct { 820 Desc protoreflect.ServiceDescriptor 821 822 GoName string 823 824 Methods []*Method // service method declarations 825 826 Location Location // location of this service 827 Comments CommentSet // comments associated with this service 828} 829 830func newService(gen *Plugin, f *File, desc protoreflect.ServiceDescriptor) *Service { 831 loc := f.location.appendPath(genid.FileDescriptorProto_Service_field_number, desc.Index()) 832 service := &Service{ 833 Desc: desc, 834 GoName: strs.GoCamelCase(string(desc.Name())), 835 Location: loc, 836 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)), 837 } 838 for i, mds := 0, desc.Methods(); i < mds.Len(); i++ { 839 service.Methods = append(service.Methods, newMethod(gen, f, service, mds.Get(i))) 840 } 841 return service 842} 843 844// A Method describes a method in a service. 845type Method struct { 846 Desc protoreflect.MethodDescriptor 847 848 GoName string 849 850 Parent *Service // service in which this method is declared 851 852 Input *Message 853 Output *Message 854 855 Location Location // location of this method 856 Comments CommentSet // comments associated with this method 857} 858 859func newMethod(gen *Plugin, f *File, service *Service, desc protoreflect.MethodDescriptor) *Method { 860 loc := service.Location.appendPath(genid.ServiceDescriptorProto_Method_field_number, desc.Index()) 861 method := &Method{ 862 Desc: desc, 863 GoName: strs.GoCamelCase(string(desc.Name())), 864 Parent: service, 865 Location: loc, 866 Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)), 867 } 868 return method 869} 870 871func (method *Method) resolveDependencies(gen *Plugin) error { 872 desc := method.Desc 873 874 inName := desc.Input().FullName() 875 in, ok := gen.messagesByName[inName] 876 if !ok { 877 return fmt.Errorf("method %v: no descriptor for type %v", desc.FullName(), inName) 878 } 879 method.Input = in 880 881 outName := desc.Output().FullName() 882 out, ok 
:= gen.messagesByName[outName] 883 if !ok { 884 return fmt.Errorf("method %v: no descriptor for type %v", desc.FullName(), outName) 885 } 886 method.Output = out 887 888 return nil 889} 890 891// A GeneratedFile is a generated file. 892type GeneratedFile struct { 893 gen *Plugin 894 skip bool 895 filename string 896 goImportPath GoImportPath 897 buf bytes.Buffer 898 packageNames map[GoImportPath]GoPackageName 899 usedPackageNames map[GoPackageName]bool 900 manualImports map[GoImportPath]bool 901 annotations map[string][]Location 902} 903 904// NewGeneratedFile creates a new generated file with the given filename 905// and import path. 906func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile { 907 g := &GeneratedFile{ 908 gen: gen, 909 filename: filename, 910 goImportPath: goImportPath, 911 packageNames: make(map[GoImportPath]GoPackageName), 912 usedPackageNames: make(map[GoPackageName]bool), 913 manualImports: make(map[GoImportPath]bool), 914 annotations: make(map[string][]Location), 915 } 916 917 // All predeclared identifiers in Go are already used. 918 for _, s := range types.Universe.Names() { 919 g.usedPackageNames[GoPackageName(s)] = true 920 } 921 922 gen.genFiles = append(gen.genFiles, g) 923 return g 924} 925 926// P prints a line to the generated output. It converts each parameter to a 927// string following the same rules as fmt.Print. It never inserts spaces 928// between parameters. 929func (g *GeneratedFile) P(v ...interface{}) { 930 for _, x := range v { 931 switch x := x.(type) { 932 case GoIdent: 933 fmt.Fprint(&g.buf, g.QualifiedGoIdent(x)) 934 default: 935 fmt.Fprint(&g.buf, x) 936 } 937 } 938 fmt.Fprintln(&g.buf) 939} 940 941// QualifiedGoIdent returns the string to use for a Go identifier. 
942// 943// If the identifier is from a different Go package than the generated file, 944// the returned name will be qualified (package.name) and an import statement 945// for the identifier's package will be included in the file. 946func (g *GeneratedFile) QualifiedGoIdent(ident GoIdent) string { 947 if ident.GoImportPath == g.goImportPath { 948 return ident.GoName 949 } 950 if packageName, ok := g.packageNames[ident.GoImportPath]; ok { 951 return string(packageName) + "." + ident.GoName 952 } 953 packageName := cleanPackageName(path.Base(string(ident.GoImportPath))) 954 for i, orig := 1, packageName; g.usedPackageNames[packageName]; i++ { 955 packageName = orig + GoPackageName(strconv.Itoa(i)) 956 } 957 g.packageNames[ident.GoImportPath] = packageName 958 g.usedPackageNames[packageName] = true 959 return string(packageName) + "." + ident.GoName 960} 961 962// Import ensures a package is imported by the generated file. 963// 964// Packages referenced by QualifiedGoIdent are automatically imported. 965// Explicitly importing a package with Import is generally only necessary 966// when the import will be blank (import _ "package"). 967func (g *GeneratedFile) Import(importPath GoImportPath) { 968 g.manualImports[importPath] = true 969} 970 971// Write implements io.Writer. 972func (g *GeneratedFile) Write(p []byte) (n int, err error) { 973 return g.buf.Write(p) 974} 975 976// Skip removes the generated file from the plugin output. 977func (g *GeneratedFile) Skip() { 978 g.skip = true 979} 980 981// Unskip reverts a previous call to Skip, re-including the generated file in 982// the plugin output. 983func (g *GeneratedFile) Unskip() { 984 g.skip = false 985} 986 987// Annotate associates a symbol in a generated Go file with a location in a 988// source .proto file. 989// 990// The symbol may refer to a type, constant, variable, function, method, or 991// struct field. The "T.sel" syntax is used to identify the method or field 992// 'sel' on type 'T'. 
993func (g *GeneratedFile) Annotate(symbol string, loc Location) { 994 g.annotations[symbol] = append(g.annotations[symbol], loc) 995} 996 997// Content returns the contents of the generated file. 998func (g *GeneratedFile) Content() ([]byte, error) { 999 if !strings.HasSuffix(g.filename, ".go") { 1000 return g.buf.Bytes(), nil 1001 } 1002 1003 // Reformat generated code. 1004 original := g.buf.Bytes() 1005 fset := token.NewFileSet() 1006 file, err := parser.ParseFile(fset, "", original, parser.ParseComments) 1007 if err != nil { 1008 // Print out the bad code with line numbers. 1009 // This should never happen in practice, but it can while changing generated code 1010 // so consider this a debugging aid. 1011 var src bytes.Buffer 1012 s := bufio.NewScanner(bytes.NewReader(original)) 1013 for line := 1; s.Scan(); line++ { 1014 fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes()) 1015 } 1016 return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String()) 1017 } 1018 1019 // Collect a sorted list of all imports. 1020 var importPaths [][2]string 1021 rewriteImport := func(importPath string) string { 1022 if f := g.gen.opts.ImportRewriteFunc; f != nil { 1023 return string(f(GoImportPath(importPath))) 1024 } 1025 return importPath 1026 } 1027 for importPath := range g.packageNames { 1028 pkgName := string(g.packageNames[GoImportPath(importPath)]) 1029 pkgPath := rewriteImport(string(importPath)) 1030 importPaths = append(importPaths, [2]string{pkgName, pkgPath}) 1031 } 1032 for importPath := range g.manualImports { 1033 if _, ok := g.packageNames[importPath]; !ok { 1034 pkgPath := rewriteImport(string(importPath)) 1035 importPaths = append(importPaths, [2]string{"_", pkgPath}) 1036 } 1037 } 1038 sort.Slice(importPaths, func(i, j int) bool { 1039 return importPaths[i][1] < importPaths[j][1] 1040 }) 1041 1042 // Modify the AST to include a new import block. 
1043 if len(importPaths) > 0 { 1044 // Insert block after package statement or 1045 // possible comment attached to the end of the package statement. 1046 pos := file.Package 1047 tokFile := fset.File(file.Package) 1048 pkgLine := tokFile.Line(file.Package) 1049 for _, c := range file.Comments { 1050 if tokFile.Line(c.Pos()) > pkgLine { 1051 break 1052 } 1053 pos = c.End() 1054 } 1055 1056 // Construct the import block. 1057 impDecl := &ast.GenDecl{ 1058 Tok: token.IMPORT, 1059 TokPos: pos, 1060 Lparen: pos, 1061 Rparen: pos, 1062 } 1063 for _, importPath := range importPaths { 1064 impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{ 1065 Name: &ast.Ident{ 1066 Name: importPath[0], 1067 NamePos: pos, 1068 }, 1069 Path: &ast.BasicLit{ 1070 Kind: token.STRING, 1071 Value: strconv.Quote(importPath[1]), 1072 ValuePos: pos, 1073 }, 1074 EndPos: pos, 1075 }) 1076 } 1077 file.Decls = append([]ast.Decl{impDecl}, file.Decls...) 1078 } 1079 1080 var out bytes.Buffer 1081 if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, file); err != nil { 1082 return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err) 1083 } 1084 return out.Bytes(), nil 1085} 1086 1087// metaFile returns the contents of the file's metadata file, which is a 1088// text formatted string of the google.protobuf.GeneratedCodeInfo. 
1089func (g *GeneratedFile) metaFile(content []byte) (string, error) { 1090 fset := token.NewFileSet() 1091 astFile, err := parser.ParseFile(fset, "", content, 0) 1092 if err != nil { 1093 return "", err 1094 } 1095 info := &descriptorpb.GeneratedCodeInfo{} 1096 1097 seenAnnotations := make(map[string]bool) 1098 annotate := func(s string, ident *ast.Ident) { 1099 seenAnnotations[s] = true 1100 for _, loc := range g.annotations[s] { 1101 info.Annotation = append(info.Annotation, &descriptorpb.GeneratedCodeInfo_Annotation{ 1102 SourceFile: proto.String(loc.SourceFile), 1103 Path: loc.Path, 1104 Begin: proto.Int32(int32(fset.Position(ident.Pos()).Offset)), 1105 End: proto.Int32(int32(fset.Position(ident.End()).Offset)), 1106 }) 1107 } 1108 } 1109 for _, decl := range astFile.Decls { 1110 switch decl := decl.(type) { 1111 case *ast.GenDecl: 1112 for _, spec := range decl.Specs { 1113 switch spec := spec.(type) { 1114 case *ast.TypeSpec: 1115 annotate(spec.Name.Name, spec.Name) 1116 switch st := spec.Type.(type) { 1117 case *ast.StructType: 1118 for _, field := range st.Fields.List { 1119 for _, name := range field.Names { 1120 annotate(spec.Name.Name+"."+name.Name, name) 1121 } 1122 } 1123 case *ast.InterfaceType: 1124 for _, field := range st.Methods.List { 1125 for _, name := range field.Names { 1126 annotate(spec.Name.Name+"."+name.Name, name) 1127 } 1128 } 1129 } 1130 case *ast.ValueSpec: 1131 for _, name := range spec.Names { 1132 annotate(name.Name, name) 1133 } 1134 } 1135 } 1136 case *ast.FuncDecl: 1137 if decl.Recv == nil { 1138 annotate(decl.Name.Name, decl.Name) 1139 } else { 1140 recv := decl.Recv.List[0].Type 1141 if s, ok := recv.(*ast.StarExpr); ok { 1142 recv = s.X 1143 } 1144 if id, ok := recv.(*ast.Ident); ok { 1145 annotate(id.Name+"."+decl.Name.Name, decl.Name) 1146 } 1147 } 1148 } 1149 } 1150 for a := range g.annotations { 1151 if !seenAnnotations[a] { 1152 return "", fmt.Errorf("%v: no symbol matching annotation %q", g.filename, a) 1153 } 1154 } 
1155 1156 b, err := prototext.Marshal(info) 1157 if err != nil { 1158 return "", err 1159 } 1160 return string(b), nil 1161} 1162 1163// A GoIdent is a Go identifier, consisting of a name and import path. 1164// The name is a single identifier and may not be a dot-qualified selector. 1165type GoIdent struct { 1166 GoName string 1167 GoImportPath GoImportPath 1168} 1169 1170func (id GoIdent) String() string { return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName) } 1171 1172// newGoIdent returns the Go identifier for a descriptor. 1173func newGoIdent(f *File, d protoreflect.Descriptor) GoIdent { 1174 name := strings.TrimPrefix(string(d.FullName()), string(f.Desc.Package())+".") 1175 return GoIdent{ 1176 GoName: strs.GoCamelCase(name), 1177 GoImportPath: f.GoImportPath, 1178 } 1179} 1180 1181// A GoImportPath is the import path of a Go package. 1182// For example: "google.golang.org/protobuf/compiler/protogen" 1183type GoImportPath string 1184 1185func (p GoImportPath) String() string { return strconv.Quote(string(p)) } 1186 1187// Ident returns a GoIdent with s as the GoName and p as the GoImportPath. 1188func (p GoImportPath) Ident(s string) GoIdent { 1189 return GoIdent{GoName: s, GoImportPath: p} 1190} 1191 1192// A GoPackageName is the name of a Go package. e.g., "protobuf". 1193type GoPackageName string 1194 1195// cleanPackageName converts a string to a valid Go package name. 1196func cleanPackageName(name string) GoPackageName { 1197 return GoPackageName(strs.GoSanitized(name)) 1198} 1199 1200type pathType int 1201 1202const ( 1203 pathTypeImport pathType = iota 1204 pathTypeSourceRelative 1205) 1206 1207// A Location is a location in a .proto source file. 1208// 1209// See the google.protobuf.SourceCodeInfo documentation in descriptor.proto 1210// for details. 1211type Location struct { 1212 SourceFile string 1213 Path protoreflect.SourcePath 1214} 1215 1216// appendPath add elements to a Location's path, returning a new Location. 
1217func (loc Location) appendPath(num protoreflect.FieldNumber, idx int) Location { 1218 loc.Path = append(protoreflect.SourcePath(nil), loc.Path...) // make copy 1219 loc.Path = append(loc.Path, int32(num), int32(idx)) 1220 return loc 1221} 1222 1223// CommentSet is a set of leading and trailing comments associated 1224// with a .proto descriptor declaration. 1225type CommentSet struct { 1226 LeadingDetached []Comments 1227 Leading Comments 1228 Trailing Comments 1229} 1230 1231func makeCommentSet(loc protoreflect.SourceLocation) CommentSet { 1232 var leadingDetached []Comments 1233 for _, s := range loc.LeadingDetachedComments { 1234 leadingDetached = append(leadingDetached, Comments(s)) 1235 } 1236 return CommentSet{ 1237 LeadingDetached: leadingDetached, 1238 Leading: Comments(loc.LeadingComments), 1239 Trailing: Comments(loc.TrailingComments), 1240 } 1241} 1242 1243// Comments is a comments string as provided by protoc. 1244type Comments string 1245 1246// String formats the comments by inserting // to the start of each line, 1247// ensuring that there is a trailing newline. 1248// An empty comment is formatted as an empty string. 1249func (c Comments) String() string { 1250 if c == "" { 1251 return "" 1252 } 1253 var b []byte 1254 for _, line := range strings.Split(strings.TrimSuffix(string(c), "\n"), "\n") { 1255 b = append(b, "//"...) 1256 b = append(b, line...) 1257 b = append(b, "\n"...) 1258 } 1259 return string(b) 1260} 1261