1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package protogen provides support for writing protoc plugins.
6//
7// Plugins for protoc, the Protocol Buffer compiler,
8// are programs which read a CodeGeneratorRequest message from standard input
9// and write a CodeGeneratorResponse message to standard output.
10// This package provides support for writing plugins which generate Go code.
11package protogen
12
13import (
14	"bufio"
15	"bytes"
16	"fmt"
17	"go/ast"
18	"go/parser"
19	"go/printer"
20	"go/token"
21	"go/types"
22	"io/ioutil"
23	"os"
24	"path"
25	"path/filepath"
26	"sort"
27	"strconv"
28	"strings"
29
30	"google.golang.org/protobuf/encoding/prototext"
31	"google.golang.org/protobuf/internal/genid"
32	"google.golang.org/protobuf/internal/strs"
33	"google.golang.org/protobuf/proto"
34	"google.golang.org/protobuf/reflect/protodesc"
35	"google.golang.org/protobuf/reflect/protoreflect"
36	"google.golang.org/protobuf/reflect/protoregistry"
37
38	"google.golang.org/protobuf/types/descriptorpb"
39	"google.golang.org/protobuf/types/pluginpb"
40)
41
// goPackageDocURL is the documentation link that New's errors about a
// missing or malformed Go import path point users to ("See ... for more
// information").
const goPackageDocURL = "https://developers.google.com/protocol-buffers/docs/reference/go-generated#package"
43
44// Run executes a function as a protoc plugin.
45//
46// It reads a CodeGeneratorRequest message from os.Stdin, invokes the plugin
47// function, and writes a CodeGeneratorResponse message to os.Stdout.
48//
49// If a failure occurs while reading or writing, Run prints an error to
50// os.Stderr and calls os.Exit(1).
51func (opts Options) Run(f func(*Plugin) error) {
52	if err := run(opts, f); err != nil {
53		fmt.Fprintf(os.Stderr, "%s: %v\n", filepath.Base(os.Args[0]), err)
54		os.Exit(1)
55	}
56}
57
58func run(opts Options, f func(*Plugin) error) error {
59	if len(os.Args) > 1 {
60		return fmt.Errorf("unknown argument %q (this program should be run by protoc, not directly)", os.Args[1])
61	}
62	in, err := ioutil.ReadAll(os.Stdin)
63	if err != nil {
64		return err
65	}
66	req := &pluginpb.CodeGeneratorRequest{}
67	if err := proto.Unmarshal(in, req); err != nil {
68		return err
69	}
70	gen, err := opts.New(req)
71	if err != nil {
72		return err
73	}
74	if err := f(gen); err != nil {
75		// Errors from the plugin function are reported by setting the
76		// error field in the CodeGeneratorResponse.
77		//
78		// In contrast, errors that indicate a problem in protoc
79		// itself (unparsable input, I/O errors, etc.) are reported
80		// to stderr.
81		gen.Error(err)
82	}
83	resp := gen.Response()
84	out, err := proto.Marshal(resp)
85	if err != nil {
86		return err
87	}
88	if _, err := os.Stdout.Write(out); err != nil {
89		return err
90	}
91	return nil
92}
93
// A Plugin is a protoc plugin invocation.
//
// A Plugin is created by Options.New from a CodeGeneratorRequest. Output
// files are added with NewGeneratedFile and collected by Response.
type Plugin struct {
	// Request is the CodeGeneratorRequest provided by protoc.
	Request *pluginpb.CodeGeneratorRequest

	// Files is the set of files to generate and everything they import.
	// Files appear in topological order, so each file appears before any
	// file that imports it.
	Files       []*File
	// FilesByPath maps each file's name, as it appears in the
	// CodeGeneratorRequest, to its File.
	FilesByPath map[string]*File

	// SupportedFeatures is the set of protobuf language features supported by
	// this generator plugin. See the documentation for
	// google.protobuf.CodeGeneratorResponse.supported_features for details.
	SupportedFeatures uint64

	fileReg        *protoregistry.Files               // registry that newFile builds and registers every file descriptor in
	enumsByName    map[protoreflect.FullName]*Enum    // every Enum created by newEnum, keyed by full name
	messagesByName map[protoreflect.FullName]*Message // every Message created by newMessage, keyed by full name
	annotateCode   bool                               // annotate_code parameter: Response also emits a .meta file per .go file
	pathType       pathType                           // paths= parameter: how generated filename prefixes are derived
	module         string                             // module= parameter: "<module>/" is stripped from output filenames
	genFiles       []*GeneratedFile                   // files created by NewGeneratedFile, in creation order
	opts           Options                            // options this Plugin was created with
	err            error                              // first error recorded by Error
}
120
// Options configures how a Plugin is created from a CodeGeneratorRequest.
// See Options.New and Options.Run.
type Options struct {
	// If ParamFunc is non-nil, it will be called with each unknown
	// generator parameter.
	//
	// Plugins for protoc can accept parameters from the command line.
	// They are passed in the --<lang>_out flag to protoc, separated from
	// the output directory with a colon; e.g.,
	//
	//   --go_out=<param1>=<value1>,<param2>=<value2>:<output_directory>
	//
	// Parameters passed in this fashion as a comma-separated list of
	// key=value pairs will be passed to the ParamFunc. A parameter without
	// an '=' is passed with an empty value. Parameters that protogen
	// handles itself (module, paths, annotate_code, and the
	// M<filename>=<import_path> mappings) are not passed to ParamFunc.
	// A non-nil error returned by ParamFunc makes New fail with that error.
	//
	// The (flag.FlagSet).Set method matches this function signature,
	// so parameters can be converted into flags as in the following:
	//
	//   var flags flag.FlagSet
	//   value := flags.Bool("param", false, "")
	//   opts := &protogen.Options{
	//     ParamFunc: flags.Set,
	//   }
	//   opts.Run(func(p *protogen.Plugin) error {
	//     if *value { ... }
	//     return nil
	//   })
	ParamFunc func(name, value string) error

	// ImportRewriteFunc is called with the import path of each package
	// imported by a generated file. It returns the import path to use
	// for this package.
	ImportRewriteFunc func(GoImportPath) GoImportPath
}
152
153// New returns a new Plugin.
154func (opts Options) New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) {
155	gen := &Plugin{
156		Request:        req,
157		FilesByPath:    make(map[string]*File),
158		fileReg:        new(protoregistry.Files),
159		enumsByName:    make(map[protoreflect.FullName]*Enum),
160		messagesByName: make(map[protoreflect.FullName]*Message),
161		opts:           opts,
162	}
163
164	packageNames := make(map[string]GoPackageName) // filename -> package name
165	importPaths := make(map[string]GoImportPath)   // filename -> import path
166	for _, param := range strings.Split(req.GetParameter(), ",") {
167		var value string
168		if i := strings.Index(param, "="); i >= 0 {
169			value = param[i+1:]
170			param = param[0:i]
171		}
172		switch param {
173		case "":
174			// Ignore.
175		case "module":
176			gen.module = value
177		case "paths":
178			switch value {
179			case "import":
180				gen.pathType = pathTypeImport
181			case "source_relative":
182				gen.pathType = pathTypeSourceRelative
183			default:
184				return nil, fmt.Errorf(`unknown path type %q: want "import" or "source_relative"`, value)
185			}
186		case "annotate_code":
187			switch value {
188			case "true", "":
189				gen.annotateCode = true
190			case "false":
191			default:
192				return nil, fmt.Errorf(`bad value for parameter %q: want "true" or "false"`, param)
193			}
194		default:
195			if param[0] == 'M' {
196				impPath, pkgName := splitImportPathAndPackageName(value)
197				if pkgName != "" {
198					packageNames[param[1:]] = pkgName
199				}
200				if impPath != "" {
201					importPaths[param[1:]] = impPath
202				}
203				continue
204			}
205			if opts.ParamFunc != nil {
206				if err := opts.ParamFunc(param, value); err != nil {
207					return nil, err
208				}
209			}
210		}
211	}
212	// When the module= option is provided, we strip the module name
213	// prefix from generated files. This only makes sense if generated
214	// filenames are based on the import path.
215	if gen.module != "" && gen.pathType == pathTypeSourceRelative {
216		return nil, fmt.Errorf("cannot use module= with paths=source_relative")
217	}
218
219	// Figure out the import path and package name for each file.
220	//
221	// The rules here are complicated and have grown organically over time.
222	// Interactions between different ways of specifying package information
223	// may be surprising.
224	//
225	// The recommended approach is to include a go_package option in every
226	// .proto source file specifying the full import path of the Go package
227	// associated with this file.
228	//
229	//     option go_package = "google.golang.org/protobuf/types/known/anypb";
230	//
231	// Alternatively, build systems which want to exert full control over
232	// import paths may specify M<filename>=<import_path> flags.
233	for _, fdesc := range gen.Request.ProtoFile {
234		// The "M" command-line flags take precedence over
235		// the "go_package" option in the .proto source file.
236		filename := fdesc.GetName()
237		impPath, pkgName := splitImportPathAndPackageName(fdesc.GetOptions().GetGoPackage())
238		if importPaths[filename] == "" && impPath != "" {
239			importPaths[filename] = impPath
240		}
241		if packageNames[filename] == "" && pkgName != "" {
242			packageNames[filename] = pkgName
243		}
244		switch {
245		case importPaths[filename] == "":
246			// The import path must be specified one way or another.
247			return nil, fmt.Errorf(
248				"unable to determine Go import path for %q\n\n"+
249					"Please specify either:\n"+
250					"\t• a \"go_package\" option in the .proto source file, or\n"+
251					"\t• a \"M\" argument on the command line.\n\n"+
252					"See %v for more information.\n",
253				fdesc.GetName(), goPackageDocURL)
254		case !strings.Contains(string(importPaths[filename]), ".") &&
255			!strings.Contains(string(importPaths[filename]), "/"):
256			// Check that import paths contain at least a dot or slash to avoid
257			// a common mistake where import path is confused with package name.
258			return nil, fmt.Errorf(
259				"invalid Go import path %q for %q\n\n"+
260					"The import path must contain at least one period ('.') or forward slash ('/') character.\n\n"+
261					"See %v for more information.\n",
262				string(importPaths[filename]), fdesc.GetName(), goPackageDocURL)
263		case packageNames[filename] == "":
264			// If the package name is not explicitly specified,
265			// then derive a reasonable package name from the import path.
266			//
267			// NOTE: The package name is derived first from the import path in
268			// the "go_package" option (if present) before trying the "M" flag.
269			// The inverted order for this is because the primary use of the "M"
270			// flag is by build systems that have full control over the
271			// import paths all packages, where it is generally expected that
272			// the Go package name still be identical for the Go toolchain and
273			// for custom build systems like Bazel.
274			if impPath == "" {
275				impPath = importPaths[filename]
276			}
277			packageNames[filename] = cleanPackageName(path.Base(string(impPath)))
278		}
279	}
280
281	// Consistency check: Every file with the same Go import path should have
282	// the same Go package name.
283	packageFiles := make(map[GoImportPath][]string)
284	for filename, importPath := range importPaths {
285		if _, ok := packageNames[filename]; !ok {
286			// Skip files mentioned in a M<file>=<import_path> parameter
287			// but which do not appear in the CodeGeneratorRequest.
288			continue
289		}
290		packageFiles[importPath] = append(packageFiles[importPath], filename)
291	}
292	for importPath, filenames := range packageFiles {
293		for i := 1; i < len(filenames); i++ {
294			if a, b := packageNames[filenames[0]], packageNames[filenames[i]]; a != b {
295				return nil, fmt.Errorf("Go package %v has inconsistent names %v (%v) and %v (%v)",
296					importPath, a, filenames[0], b, filenames[i])
297			}
298		}
299	}
300
301	for _, fdesc := range gen.Request.ProtoFile {
302		filename := fdesc.GetName()
303		if gen.FilesByPath[filename] != nil {
304			return nil, fmt.Errorf("duplicate file name: %q", filename)
305		}
306		f, err := newFile(gen, fdesc, packageNames[filename], importPaths[filename])
307		if err != nil {
308			return nil, err
309		}
310		gen.Files = append(gen.Files, f)
311		gen.FilesByPath[filename] = f
312	}
313	for _, filename := range gen.Request.FileToGenerate {
314		f, ok := gen.FilesByPath[filename]
315		if !ok {
316			return nil, fmt.Errorf("no descriptor for generated file: %v", filename)
317		}
318		f.Generate = true
319	}
320	return gen, nil
321}
322
323// Error records an error in code generation. The generator will report the
324// error back to protoc and will not produce output.
325func (gen *Plugin) Error(err error) {
326	if gen.err == nil {
327		gen.err = err
328	}
329}
330
331// Response returns the generator output.
332func (gen *Plugin) Response() *pluginpb.CodeGeneratorResponse {
333	resp := &pluginpb.CodeGeneratorResponse{}
334	if gen.err != nil {
335		resp.Error = proto.String(gen.err.Error())
336		return resp
337	}
338	for _, g := range gen.genFiles {
339		if g.skip {
340			continue
341		}
342		content, err := g.Content()
343		if err != nil {
344			return &pluginpb.CodeGeneratorResponse{
345				Error: proto.String(err.Error()),
346			}
347		}
348		filename := g.filename
349		if gen.module != "" {
350			trim := gen.module + "/"
351			if !strings.HasPrefix(filename, trim) {
352				return &pluginpb.CodeGeneratorResponse{
353					Error: proto.String(fmt.Sprintf("%v: generated file does not match prefix %q", filename, gen.module)),
354				}
355			}
356			filename = strings.TrimPrefix(filename, trim)
357		}
358		resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
359			Name:    proto.String(filename),
360			Content: proto.String(string(content)),
361		})
362		if gen.annotateCode && strings.HasSuffix(g.filename, ".go") {
363			meta, err := g.metaFile(content)
364			if err != nil {
365				return &pluginpb.CodeGeneratorResponse{
366					Error: proto.String(err.Error()),
367				}
368			}
369			resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
370				Name:    proto.String(filename + ".meta"),
371				Content: proto.String(meta),
372			})
373		}
374	}
375	if gen.SupportedFeatures > 0 {
376		resp.SupportedFeatures = proto.Uint64(gen.SupportedFeatures)
377	}
378	return resp
379}
380
// A File describes a .proto source file.
//
// Files are created by Options.New; see Plugin.Files and Plugin.FilesByPath.
type File struct {
	Desc  protoreflect.FileDescriptor       // descriptor built from Proto
	Proto *descriptorpb.FileDescriptorProto // file descriptor as received in the CodeGeneratorRequest

	GoDescriptorIdent GoIdent       // name of Go variable for the file descriptor
	GoPackageName     GoPackageName // name of this file's Go package
	GoImportPath      GoImportPath  // import path of this file's Go package

	Enums      []*Enum      // top-level enum declarations
	Messages   []*Message   // top-level message declarations
	Extensions []*Extension // top-level extension declarations
	Services   []*Service   // top-level service declarations

	Generate bool // true if we should generate code for this file

	// GeneratedFilenamePrefix is used to construct filenames for generated
	// files associated with this source file.
	//
	// For example, the source file "dir/foo.proto" might have a filename prefix
	// of "dir/foo". Appending ".pb.go" produces an output file of "dir/foo.pb.go".
	GeneratedFilenamePrefix string

	location Location // location of this file; top-level declaration locations are derived from it
}
406
407func newFile(gen *Plugin, p *descriptorpb.FileDescriptorProto, packageName GoPackageName, importPath GoImportPath) (*File, error) {
408	desc, err := protodesc.NewFile(p, gen.fileReg)
409	if err != nil {
410		return nil, fmt.Errorf("invalid FileDescriptorProto %q: %v", p.GetName(), err)
411	}
412	if err := gen.fileReg.RegisterFile(desc); err != nil {
413		return nil, fmt.Errorf("cannot register descriptor %q: %v", p.GetName(), err)
414	}
415	f := &File{
416		Desc:          desc,
417		Proto:         p,
418		GoPackageName: packageName,
419		GoImportPath:  importPath,
420		location:      Location{SourceFile: desc.Path()},
421	}
422
423	// Determine the prefix for generated Go files.
424	prefix := p.GetName()
425	if ext := path.Ext(prefix); ext == ".proto" || ext == ".protodevel" {
426		prefix = prefix[:len(prefix)-len(ext)]
427	}
428	switch gen.pathType {
429	case pathTypeImport:
430		// If paths=import, the output filename is derived from the Go import path.
431		prefix = path.Join(string(f.GoImportPath), path.Base(prefix))
432	case pathTypeSourceRelative:
433		// If paths=source_relative, the output filename is derived from
434		// the input filename.
435	}
436	f.GoDescriptorIdent = GoIdent{
437		GoName:       "File_" + strs.GoSanitized(p.GetName()),
438		GoImportPath: f.GoImportPath,
439	}
440	f.GeneratedFilenamePrefix = prefix
441
442	for i, eds := 0, desc.Enums(); i < eds.Len(); i++ {
443		f.Enums = append(f.Enums, newEnum(gen, f, nil, eds.Get(i)))
444	}
445	for i, mds := 0, desc.Messages(); i < mds.Len(); i++ {
446		f.Messages = append(f.Messages, newMessage(gen, f, nil, mds.Get(i)))
447	}
448	for i, xds := 0, desc.Extensions(); i < xds.Len(); i++ {
449		f.Extensions = append(f.Extensions, newField(gen, f, nil, xds.Get(i)))
450	}
451	for i, sds := 0, desc.Services(); i < sds.Len(); i++ {
452		f.Services = append(f.Services, newService(gen, f, sds.Get(i)))
453	}
454	for _, message := range f.Messages {
455		if err := message.resolveDependencies(gen); err != nil {
456			return nil, err
457		}
458	}
459	for _, extension := range f.Extensions {
460		if err := extension.resolveDependencies(gen); err != nil {
461			return nil, err
462		}
463	}
464	for _, service := range f.Services {
465		for _, method := range service.Methods {
466			if err := method.resolveDependencies(gen); err != nil {
467				return nil, err
468			}
469		}
470	}
471	return f, nil
472}
473
474// splitImportPathAndPackageName splits off the optional Go package name
475// from the Go import path when seperated by a ';' delimiter.
476func splitImportPathAndPackageName(s string) (GoImportPath, GoPackageName) {
477	if i := strings.Index(s, ";"); i >= 0 {
478		return GoImportPath(s[:i]), GoPackageName(s[i+1:])
479	}
480	return GoImportPath(s), ""
481}
482
// An Enum describes an enum.
type Enum struct {
	Desc protoreflect.EnumDescriptor // descriptor of this enum

	GoIdent GoIdent // name of the generated Go type

	Values []*EnumValue // enum value declarations, in declaration order

	Location Location   // location of this enum
	Comments CommentSet // comments associated with this enum
}
494
495func newEnum(gen *Plugin, f *File, parent *Message, desc protoreflect.EnumDescriptor) *Enum {
496	var loc Location
497	if parent != nil {
498		loc = parent.Location.appendPath(genid.DescriptorProto_EnumType_field_number, desc.Index())
499	} else {
500		loc = f.location.appendPath(genid.FileDescriptorProto_EnumType_field_number, desc.Index())
501	}
502	enum := &Enum{
503		Desc:     desc,
504		GoIdent:  newGoIdent(f, desc),
505		Location: loc,
506		Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
507	}
508	gen.enumsByName[desc.FullName()] = enum
509	for i, vds := 0, enum.Desc.Values(); i < vds.Len(); i++ {
510		enum.Values = append(enum.Values, newEnumValue(gen, f, parent, enum, vds.Get(i)))
511	}
512	return enum
513}
514
// An EnumValue describes an enum value.
type EnumValue struct {
	Desc protoreflect.EnumValueDescriptor // descriptor of this enum value

	GoIdent GoIdent // name of the generated Go declaration, e.g., "EnumName_ValueName"

	Parent *Enum // enum in which this value is declared

	Location Location   // location of this enum value
	Comments CommentSet // comments associated with this enum value
}
526
527func newEnumValue(gen *Plugin, f *File, message *Message, enum *Enum, desc protoreflect.EnumValueDescriptor) *EnumValue {
528	// A top-level enum value's name is: EnumName_ValueName
529	// An enum value contained in a message is: MessageName_ValueName
530	//
531	// For historical reasons, enum value names are not camel-cased.
532	parentIdent := enum.GoIdent
533	if message != nil {
534		parentIdent = message.GoIdent
535	}
536	name := parentIdent.GoName + "_" + string(desc.Name())
537	loc := enum.Location.appendPath(genid.EnumDescriptorProto_Value_field_number, desc.Index())
538	return &EnumValue{
539		Desc:     desc,
540		GoIdent:  f.GoImportPath.Ident(name),
541		Parent:   enum,
542		Location: loc,
543		Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
544	}
545}
546
// A Message describes a message.
type Message struct {
	Desc protoreflect.MessageDescriptor // descriptor of this message

	GoIdent GoIdent // name of the generated Go type

	Fields []*Field // message field declarations, in declaration order
	Oneofs []*Oneof // message oneof declarations, in declaration order

	Enums      []*Enum      // nested enum declarations
	Messages   []*Message   // nested message declarations
	Extensions []*Extension // nested extension declarations

	Location Location   // location of this message
	Comments CommentSet // comments associated with this message
}
563
// newMessage creates the Message for desc and records it in
// gen.messagesByName.
//
// The message is declared at the top level of f when parent is nil and
// nested inside parent otherwise. newMessage recursively creates the
// message's nested enums and messages, fields, oneofs, and nested
// extensions. It links each field to its containing oneof, and makes the Go
// names of fields, oneofs, and oneof member wrapper types unique within the
// message.
//
// Type references of fields and extensions are not resolved here; that is
// done afterwards by (*Message).resolveDependencies.
func newMessage(gen *Plugin, f *File, parent *Message, desc protoreflect.MessageDescriptor) *Message {
	// The source location path differs for nested messages (nested_type of
	// the parent DescriptorProto) and top-level messages (message_type of the
	// FileDescriptorProto).
	var loc Location
	if parent != nil {
		loc = parent.Location.appendPath(genid.DescriptorProto_NestedType_field_number, desc.Index())
	} else {
		loc = f.location.appendPath(genid.FileDescriptorProto_MessageType_field_number, desc.Index())
	}
	message := &Message{
		Desc:     desc,
		GoIdent:  newGoIdent(f, desc),
		Location: loc,
		Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
	}
	gen.messagesByName[desc.FullName()] = message
	for i, eds := 0, desc.Enums(); i < eds.Len(); i++ {
		message.Enums = append(message.Enums, newEnum(gen, f, message, eds.Get(i)))
	}
	for i, mds := 0, desc.Messages(); i < mds.Len(); i++ {
		message.Messages = append(message.Messages, newMessage(gen, f, message, mds.Get(i)))
	}
	for i, fds := 0, desc.Fields(); i < fds.Len(); i++ {
		message.Fields = append(message.Fields, newField(gen, f, message, fds.Get(i)))
	}
	for i, ods := 0, desc.Oneofs(); i < ods.Len(); i++ {
		message.Oneofs = append(message.Oneofs, newOneof(gen, f, message, ods.Get(i)))
	}
	for i, xds := 0, desc.Extensions(); i < xds.Len(); i++ {
		message.Extensions = append(message.Extensions, newField(gen, f, message, xds.Get(i)))
	}

	// Resolve local references between fields and oneofs.
	// message.Oneofs was filled in descriptor index order above, so
	// od.Index() selects the matching Oneof.
	for _, field := range message.Fields {
		if od := field.Desc.ContainingOneof(); od != nil {
			oneof := message.Oneofs[od.Index()]
			field.Oneof = oneof
			oneof.Fields = append(oneof.Fields, field)
		}
	}

	// Field name conflict resolution.
	//
	// We assume well-known method names that may be attached to a generated
	// message type, as well as a 'Get*' method for each field. For each
	// field in turn, we add _s to its name until there are no conflicts.
	//
	// Any change to the following set of method names is a potential
	// incompatible API change because it may change generated field names.
	//
	// TODO: If we ever support a 'go_name' option to set the Go name of a
	// field, we should consider dropping this entirely. The conflict
	// resolution algorithm is subtle and surprising (changing the order
	// in which fields appear in the .proto source file can change the
	// names of fields in generated code), and does not adapt well to
	// adding new per-field methods such as setters.
	usedNames := map[string]bool{
		"Reset":               true,
		"String":              true,
		"ProtoMessage":        true,
		"Marshal":             true,
		"Unmarshal":           true,
		"ExtensionRangeArray": true,
		"ExtensionMap":        true,
		"Descriptor":          true,
	}
	// makeNameUnique appends '_' to name until neither name nor, when
	// hasGetter is set, "Get"+name is already used. It then records the
	// chosen name (and whether its getter is taken) in usedNames.
	makeNameUnique := func(name string, hasGetter bool) string {
		for usedNames[name] || (hasGetter && usedNames["Get"+name]) {
			name += "_"
		}
		usedNames[name] = true
		// NOTE(review): when hasGetter is false this stores false, which can
		// clear an entry that an earlier field set to true (e.g. a field
		// whose own Go name is "Get"+name). Changing it would alter
		// generated names, so it is left as is.
		usedNames["Get"+name] = hasGetter
		return name
	}
	for _, field := range message.Fields {
		field.GoName = makeNameUnique(field.GoName, true)
		field.GoIdent.GoName = message.GoIdent.GoName + "_" + field.GoName
		// A oneof is named when its first member field is visited, so oneof
		// and field names claim entries in declaration order.
		if field.Oneof != nil && field.Oneof.Fields[0] == field {
			// Make the name for a oneof unique as well. For historical reasons,
			// this assumes that a getter method is not generated for oneofs.
			// This is incorrect, but fixing it breaks existing code.
			field.Oneof.GoName = makeNameUnique(field.Oneof.GoName, false)
			field.Oneof.GoIdent.GoName = message.GoIdent.GoName + "_" + field.Oneof.GoName
		}
	}

	// Oneof field name conflict resolution.
	//
	// A oneof member's GoIdent names a top-level wrapper type, so it must not
	// coincide with the Go name of a nested message or enum. Append '_' and
	// recheck from the start until it matches none of them.
	//
	// This conflict resolution is incomplete as it does not consider collisions
	// with other oneof field types, but fixing it breaks existing code.
	for _, field := range message.Fields {
		if field.Oneof != nil {
		Loop:
			for {
				for _, nestedMessage := range message.Messages {
					if nestedMessage.GoIdent == field.GoIdent {
						field.GoIdent.GoName += "_"
						continue Loop
					}
				}
				for _, nestedEnum := range message.Enums {
					if nestedEnum.GoIdent == field.GoIdent {
						field.GoIdent.GoName += "_"
						continue Loop
					}
				}
				break Loop
			}
		}
	}

	return message
}
675
676func (message *Message) resolveDependencies(gen *Plugin) error {
677	for _, field := range message.Fields {
678		if err := field.resolveDependencies(gen); err != nil {
679			return err
680		}
681	}
682	for _, message := range message.Messages {
683		if err := message.resolveDependencies(gen); err != nil {
684			return err
685		}
686	}
687	for _, extension := range message.Extensions {
688		if err := extension.resolveDependencies(gen); err != nil {
689			return err
690		}
691	}
692	return nil
693}
694
// A Field describes a message field.
//
// Extensions are also represented as Fields (see Extension).
type Field struct {
	Desc protoreflect.FieldDescriptor // descriptor of this field

	// GoName is the base name of this field's Go field and methods.
	// For code generated by protoc-gen-go, this means a field named
	// '{{GoName}}' and a getter method named 'Get{{GoName}}'.
	GoName string // e.g., "FieldName"

	// GoIdent is the base name of a top-level declaration for this field.
	// For code generated by protoc-gen-go, this means a wrapper type named
	// '{{GoIdent}}' for members fields of a oneof, and a variable named
	// 'E_{{GoIdent}}' for extension fields.
	GoIdent GoIdent // e.g., "MessageName_FieldName"

	Parent   *Message // message in which this field is declared; nil if top-level extension
	Oneof    *Oneof   // containing oneof; nil if not part of a oneof
	Extendee *Message // extended message for extension fields; nil otherwise (set by resolveDependencies)

	Enum    *Enum    // type for enum fields; nil otherwise (set by resolveDependencies)
	Message *Message // type for message or group fields; nil otherwise (set by resolveDependencies)

	Location Location   // location of this field
	Comments CommentSet // comments associated with this field
}
720
721func newField(gen *Plugin, f *File, message *Message, desc protoreflect.FieldDescriptor) *Field {
722	var loc Location
723	switch {
724	case desc.IsExtension() && message == nil:
725		loc = f.location.appendPath(genid.FileDescriptorProto_Extension_field_number, desc.Index())
726	case desc.IsExtension() && message != nil:
727		loc = message.Location.appendPath(genid.DescriptorProto_Extension_field_number, desc.Index())
728	default:
729		loc = message.Location.appendPath(genid.DescriptorProto_Field_field_number, desc.Index())
730	}
731	camelCased := strs.GoCamelCase(string(desc.Name()))
732	var parentPrefix string
733	if message != nil {
734		parentPrefix = message.GoIdent.GoName + "_"
735	}
736	field := &Field{
737		Desc:   desc,
738		GoName: camelCased,
739		GoIdent: GoIdent{
740			GoImportPath: f.GoImportPath,
741			GoName:       parentPrefix + camelCased,
742		},
743		Parent:   message,
744		Location: loc,
745		Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
746	}
747	return field
748}
749
750func (field *Field) resolveDependencies(gen *Plugin) error {
751	desc := field.Desc
752	switch desc.Kind() {
753	case protoreflect.EnumKind:
754		name := field.Desc.Enum().FullName()
755		enum, ok := gen.enumsByName[name]
756		if !ok {
757			return fmt.Errorf("field %v: no descriptor for enum %v", desc.FullName(), name)
758		}
759		field.Enum = enum
760	case protoreflect.MessageKind, protoreflect.GroupKind:
761		name := desc.Message().FullName()
762		message, ok := gen.messagesByName[name]
763		if !ok {
764			return fmt.Errorf("field %v: no descriptor for type %v", desc.FullName(), name)
765		}
766		field.Message = message
767	}
768	if desc.IsExtension() {
769		name := desc.ContainingMessage().FullName()
770		message, ok := gen.messagesByName[name]
771		if !ok {
772			return fmt.Errorf("field %v: no descriptor for type %v", desc.FullName(), name)
773		}
774		field.Extendee = message
775	}
776	return nil
777}
778
// A Oneof describes a message oneof.
type Oneof struct {
	Desc protoreflect.OneofDescriptor // descriptor of this oneof

	// GoName is the base name of this oneof's Go field and methods.
	// For code generated by protoc-gen-go, this means a field named
	// '{{GoName}}' and a getter method named 'Get{{GoName}}'.
	GoName string // e.g., "OneofName"

	// GoIdent is the base name of a top-level declaration for this oneof.
	GoIdent GoIdent // e.g., "MessageName_OneofName"

	Parent *Message // message in which this oneof is declared

	Fields []*Field // fields that are part of this oneof, in the order they appear in Parent.Fields

	Location Location   // location of this oneof
	Comments CommentSet // comments associated with this oneof
}
798
799func newOneof(gen *Plugin, f *File, message *Message, desc protoreflect.OneofDescriptor) *Oneof {
800	loc := message.Location.appendPath(genid.DescriptorProto_OneofDecl_field_number, desc.Index())
801	camelCased := strs.GoCamelCase(string(desc.Name()))
802	parentPrefix := message.GoIdent.GoName + "_"
803	return &Oneof{
804		Desc:   desc,
805		Parent: message,
806		GoName: camelCased,
807		GoIdent: GoIdent{
808			GoImportPath: f.GoImportPath,
809			GoName:       parentPrefix + camelCased,
810		},
811		Location: loc,
812		Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
813	}
814}
815
// Extension is an alias of Field for documentation.
//
// An extension is a Field whose Desc reports IsExtension. Its Parent is the
// message it is declared in (nil for a top-level extension), and its
// Extendee is the message it extends.
type Extension = Field
818
// A Service describes a service.
type Service struct {
	Desc protoreflect.ServiceDescriptor // descriptor of this service

	GoName string // camel-cased service name, e.g., "ServiceName"

	Methods []*Method // service method declarations

	Location Location   // location of this service
	Comments CommentSet // comments associated with this service
}
830
831func newService(gen *Plugin, f *File, desc protoreflect.ServiceDescriptor) *Service {
832	loc := f.location.appendPath(genid.FileDescriptorProto_Service_field_number, desc.Index())
833	service := &Service{
834		Desc:     desc,
835		GoName:   strs.GoCamelCase(string(desc.Name())),
836		Location: loc,
837		Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
838	}
839	for i, mds := 0, desc.Methods(); i < mds.Len(); i++ {
840		service.Methods = append(service.Methods, newMethod(gen, f, service, mds.Get(i)))
841	}
842	return service
843}
844
// A Method describes a method in a service.
type Method struct {
	Desc protoreflect.MethodDescriptor // descriptor of this method

	GoName string // camel-cased method name, e.g., "MethodName"

	Parent *Service // service in which this method is declared

	Input  *Message // request message type (set by resolveDependencies)
	Output *Message // response message type (set by resolveDependencies)

	Location Location   // location of this method
	Comments CommentSet // comments associated with this method
}
859
860func newMethod(gen *Plugin, f *File, service *Service, desc protoreflect.MethodDescriptor) *Method {
861	loc := service.Location.appendPath(genid.ServiceDescriptorProto_Method_field_number, desc.Index())
862	method := &Method{
863		Desc:     desc,
864		GoName:   strs.GoCamelCase(string(desc.Name())),
865		Parent:   service,
866		Location: loc,
867		Comments: makeCommentSet(f.Desc.SourceLocations().ByDescriptor(desc)),
868	}
869	return method
870}
871
872func (method *Method) resolveDependencies(gen *Plugin) error {
873	desc := method.Desc
874
875	inName := desc.Input().FullName()
876	in, ok := gen.messagesByName[inName]
877	if !ok {
878		return fmt.Errorf("method %v: no descriptor for type %v", desc.FullName(), inName)
879	}
880	method.Input = in
881
882	outName := desc.Output().FullName()
883	out, ok := gen.messagesByName[outName]
884	if !ok {
885		return fmt.Errorf("method %v: no descriptor for type %v", desc.FullName(), outName)
886	}
887	method.Output = out
888
889	return nil
890}
891
// A GeneratedFile is a generated file.
//
// GeneratedFiles are created with Plugin.NewGeneratedFile. Content is
// written with P, and the file is included in Plugin.Response unless it is
// skipped.
type GeneratedFile struct {
	gen              *Plugin                        // plugin that created this file
	skip             bool                           // when set, Response omits this file
	filename         string                         // output filename, before any module= prefix is stripped
	goImportPath     GoImportPath                   // Go package of this file; its identifiers are written unqualified
	buf              bytes.Buffer                   // content written by P
	packageNames     map[GoImportPath]GoPackageName // package name used to qualify identifiers from each imported package
	usedPackageNames map[GoPackageName]bool         // package names already taken; seeded with Go's predeclared identifiers
	manualImports    map[GoImportPath]bool          // NOTE(review): not used in this chunk; presumably explicitly requested imports
	annotations      map[string][]Location          // NOTE(review): not used in this chunk; presumably annotation locations for the .meta output
}
904
905// NewGeneratedFile creates a new generated file with the given filename
906// and import path.
907func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
908	g := &GeneratedFile{
909		gen:              gen,
910		filename:         filename,
911		goImportPath:     goImportPath,
912		packageNames:     make(map[GoImportPath]GoPackageName),
913		usedPackageNames: make(map[GoPackageName]bool),
914		manualImports:    make(map[GoImportPath]bool),
915		annotations:      make(map[string][]Location),
916	}
917
918	// All predeclared identifiers in Go are already used.
919	for _, s := range types.Universe.Names() {
920		g.usedPackageNames[GoPackageName(s)] = true
921	}
922
923	gen.genFiles = append(gen.genFiles, g)
924	return g
925}
926
927// P prints a line to the generated output. It converts each parameter to a
928// string following the same rules as fmt.Print. It never inserts spaces
929// between parameters.
930func (g *GeneratedFile) P(v ...interface{}) {
931	for _, x := range v {
932		switch x := x.(type) {
933		case GoIdent:
934			fmt.Fprint(&g.buf, g.QualifiedGoIdent(x))
935		default:
936			fmt.Fprint(&g.buf, x)
937		}
938	}
939	fmt.Fprintln(&g.buf)
940}
941
942// QualifiedGoIdent returns the string to use for a Go identifier.
943//
944// If the identifier is from a different Go package than the generated file,
945// the returned name will be qualified (package.name) and an import statement
946// for the identifier's package will be included in the file.
947func (g *GeneratedFile) QualifiedGoIdent(ident GoIdent) string {
948	if ident.GoImportPath == g.goImportPath {
949		return ident.GoName
950	}
951	if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
952		return string(packageName) + "." + ident.GoName
953	}
954	packageName := cleanPackageName(path.Base(string(ident.GoImportPath)))
955	for i, orig := 1, packageName; g.usedPackageNames[packageName]; i++ {
956		packageName = orig + GoPackageName(strconv.Itoa(i))
957	}
958	g.packageNames[ident.GoImportPath] = packageName
959	g.usedPackageNames[packageName] = true
960	return string(packageName) + "." + ident.GoName
961}
962
// Import ensures a package is imported by the generated file.
//
// Packages referenced by QualifiedGoIdent are automatically imported.
// Explicitly importing a package with Import is generally only necessary
// when the import will be blank (import _ "package").
//
// The path is only recorded here; the import declaration itself is built
// when the file's content is produced.
func (g *GeneratedFile) Import(importPath GoImportPath) {
	g.manualImports[importPath] = true
}
971
972// Write implements io.Writer.
973func (g *GeneratedFile) Write(p []byte) (n int, err error) {
974	return g.buf.Write(p)
975}
976
// Skip removes the generated file from the plugin output.
// The file's buffered content is left intact; Unskip reverses this.
func (g *GeneratedFile) Skip() {
	g.skip = true
}
981
// Unskip reverts a previous call to Skip, re-including the generated file in
// the plugin output. Calling it on a file that was never skipped has no
// effect.
func (g *GeneratedFile) Unskip() {
	g.skip = false
}
987
988// Annotate associates a symbol in a generated Go file with a location in a
989// source .proto file.
990//
991// The symbol may refer to a type, constant, variable, function, method, or
992// struct field.  The "T.sel" syntax is used to identify the method or field
993// 'sel' on type 'T'.
994func (g *GeneratedFile) Annotate(symbol string, loc Location) {
995	g.annotations[symbol] = append(g.annotations[symbol], loc)
996}
997
998// Content returns the contents of the generated file.
999func (g *GeneratedFile) Content() ([]byte, error) {
1000	if !strings.HasSuffix(g.filename, ".go") {
1001		return g.buf.Bytes(), nil
1002	}
1003
1004	// Reformat generated code.
1005	original := g.buf.Bytes()
1006	fset := token.NewFileSet()
1007	file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
1008	if err != nil {
1009		// Print out the bad code with line numbers.
1010		// This should never happen in practice, but it can while changing generated code
1011		// so consider this a debugging aid.
1012		var src bytes.Buffer
1013		s := bufio.NewScanner(bytes.NewReader(original))
1014		for line := 1; s.Scan(); line++ {
1015			fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
1016		}
1017		return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
1018	}
1019
1020	// Collect a sorted list of all imports.
1021	var importPaths [][2]string
1022	rewriteImport := func(importPath string) string {
1023		if f := g.gen.opts.ImportRewriteFunc; f != nil {
1024			return string(f(GoImportPath(importPath)))
1025		}
1026		return importPath
1027	}
1028	for importPath := range g.packageNames {
1029		pkgName := string(g.packageNames[GoImportPath(importPath)])
1030		pkgPath := rewriteImport(string(importPath))
1031		importPaths = append(importPaths, [2]string{pkgName, pkgPath})
1032	}
1033	for importPath := range g.manualImports {
1034		if _, ok := g.packageNames[importPath]; !ok {
1035			pkgPath := rewriteImport(string(importPath))
1036			importPaths = append(importPaths, [2]string{"_", pkgPath})
1037		}
1038	}
1039	sort.Slice(importPaths, func(i, j int) bool {
1040		return importPaths[i][1] < importPaths[j][1]
1041	})
1042
1043	// Modify the AST to include a new import block.
1044	if len(importPaths) > 0 {
1045		// Insert block after package statement or
1046		// possible comment attached to the end of the package statement.
1047		pos := file.Package
1048		tokFile := fset.File(file.Package)
1049		pkgLine := tokFile.Line(file.Package)
1050		for _, c := range file.Comments {
1051			if tokFile.Line(c.Pos()) > pkgLine {
1052				break
1053			}
1054			pos = c.End()
1055		}
1056
1057		// Construct the import block.
1058		impDecl := &ast.GenDecl{
1059			Tok:    token.IMPORT,
1060			TokPos: pos,
1061			Lparen: pos,
1062			Rparen: pos,
1063		}
1064		for _, importPath := range importPaths {
1065			impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
1066				Name: &ast.Ident{
1067					Name:    importPath[0],
1068					NamePos: pos,
1069				},
1070				Path: &ast.BasicLit{
1071					Kind:     token.STRING,
1072					Value:    strconv.Quote(importPath[1]),
1073					ValuePos: pos,
1074				},
1075				EndPos: pos,
1076			})
1077		}
1078		file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
1079	}
1080
1081	var out bytes.Buffer
1082	if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, file); err != nil {
1083		return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
1084	}
1085	return out.Bytes(), nil
1086}
1087
1088// metaFile returns the contents of the file's metadata file, which is a
1089// text formatted string of the google.protobuf.GeneratedCodeInfo.
1090func (g *GeneratedFile) metaFile(content []byte) (string, error) {
1091	fset := token.NewFileSet()
1092	astFile, err := parser.ParseFile(fset, "", content, 0)
1093	if err != nil {
1094		return "", err
1095	}
1096	info := &descriptorpb.GeneratedCodeInfo{}
1097
1098	seenAnnotations := make(map[string]bool)
1099	annotate := func(s string, ident *ast.Ident) {
1100		seenAnnotations[s] = true
1101		for _, loc := range g.annotations[s] {
1102			info.Annotation = append(info.Annotation, &descriptorpb.GeneratedCodeInfo_Annotation{
1103				SourceFile: proto.String(loc.SourceFile),
1104				Path:       loc.Path,
1105				Begin:      proto.Int32(int32(fset.Position(ident.Pos()).Offset)),
1106				End:        proto.Int32(int32(fset.Position(ident.End()).Offset)),
1107			})
1108		}
1109	}
1110	for _, decl := range astFile.Decls {
1111		switch decl := decl.(type) {
1112		case *ast.GenDecl:
1113			for _, spec := range decl.Specs {
1114				switch spec := spec.(type) {
1115				case *ast.TypeSpec:
1116					annotate(spec.Name.Name, spec.Name)
1117					switch st := spec.Type.(type) {
1118					case *ast.StructType:
1119						for _, field := range st.Fields.List {
1120							for _, name := range field.Names {
1121								annotate(spec.Name.Name+"."+name.Name, name)
1122							}
1123						}
1124					case *ast.InterfaceType:
1125						for _, field := range st.Methods.List {
1126							for _, name := range field.Names {
1127								annotate(spec.Name.Name+"."+name.Name, name)
1128							}
1129						}
1130					}
1131				case *ast.ValueSpec:
1132					for _, name := range spec.Names {
1133						annotate(name.Name, name)
1134					}
1135				}
1136			}
1137		case *ast.FuncDecl:
1138			if decl.Recv == nil {
1139				annotate(decl.Name.Name, decl.Name)
1140			} else {
1141				recv := decl.Recv.List[0].Type
1142				if s, ok := recv.(*ast.StarExpr); ok {
1143					recv = s.X
1144				}
1145				if id, ok := recv.(*ast.Ident); ok {
1146					annotate(id.Name+"."+decl.Name.Name, decl.Name)
1147				}
1148			}
1149		}
1150	}
1151	for a := range g.annotations {
1152		if !seenAnnotations[a] {
1153			return "", fmt.Errorf("%v: no symbol matching annotation %q", g.filename, a)
1154		}
1155	}
1156
1157	b, err := prototext.Marshal(info)
1158	if err != nil {
1159		return "", err
1160	}
1161	return string(b), nil
1162}
1163
1164// A GoIdent is a Go identifier, consisting of a name and import path.
1165// The name is a single identifier and may not be a dot-qualified selector.
1166type GoIdent struct {
1167	GoName       string
1168	GoImportPath GoImportPath
1169}
1170
1171func (id GoIdent) String() string { return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName) }
1172
1173// newGoIdent returns the Go identifier for a descriptor.
1174func newGoIdent(f *File, d protoreflect.Descriptor) GoIdent {
1175	name := strings.TrimPrefix(string(d.FullName()), string(f.Desc.Package())+".")
1176	return GoIdent{
1177		GoName:       strs.GoCamelCase(name),
1178		GoImportPath: f.GoImportPath,
1179	}
1180}
1181
1182// A GoImportPath is the import path of a Go package.
1183// For example: "google.golang.org/protobuf/compiler/protogen"
1184type GoImportPath string
1185
1186func (p GoImportPath) String() string { return strconv.Quote(string(p)) }
1187
1188// Ident returns a GoIdent with s as the GoName and p as the GoImportPath.
1189func (p GoImportPath) Ident(s string) GoIdent {
1190	return GoIdent{GoName: s, GoImportPath: p}
1191}
1192
1193// A GoPackageName is the name of a Go package. e.g., "protobuf".
1194type GoPackageName string
1195
1196// cleanPackageName converts a string to a valid Go package name.
1197func cleanPackageName(name string) GoPackageName {
1198	return GoPackageName(strs.GoSanitized(name))
1199}
1200
// pathType enumerates the two ways of deriving paths distinguished by the
// constants below: by import path, or relative to the source file.
// NOTE(review): presumably this mirrors the generator's "paths=import" /
// "paths=source_relative" option — confirm against its uses elsewhere in
// this file.
type pathType int

const (
	pathTypeImport pathType = iota
	pathTypeSourceRelative
)
1207
// A Location is a location in a .proto source file.
//
// SourceFile names the .proto file. Path is the descriptor path within it,
// built up by appendPath as (field number, index) pairs.
//
// See the google.protobuf.SourceCodeInfo documentation in descriptor.proto
// for details.
type Location struct {
	SourceFile string
	Path       protoreflect.SourcePath
}
1216
1217// appendPath add elements to a Location's path, returning a new Location.
1218func (loc Location) appendPath(num protoreflect.FieldNumber, idx int) Location {
1219	loc.Path = append(protoreflect.SourcePath(nil), loc.Path...) // make copy
1220	loc.Path = append(loc.Path, int32(num), int32(idx))
1221	return loc
1222}
1223
// CommentSet is a set of leading and trailing comments associated
// with a .proto descriptor declaration.
//
// LeadingDetached holds each detached leading comment block separately.
// Leading and Trailing hold the comments attached before and after the
// declaration. All are built by makeCommentSet from a source location.
type CommentSet struct {
	LeadingDetached []Comments
	Leading         Comments
	Trailing        Comments
}
1231
1232func makeCommentSet(loc protoreflect.SourceLocation) CommentSet {
1233	var leadingDetached []Comments
1234	for _, s := range loc.LeadingDetachedComments {
1235		leadingDetached = append(leadingDetached, Comments(s))
1236	}
1237	return CommentSet{
1238		LeadingDetached: leadingDetached,
1239		Leading:         Comments(loc.LeadingComments),
1240		Trailing:        Comments(loc.TrailingComments),
1241	}
1242}
1243
// Comments is a comments string as provided by protoc.
type Comments string

// String formats the comments as Go line comments: "//" is prefixed to every
// line, and each line ends with a newline. A single trailing newline in the
// input is dropped first so that it does not produce an extra empty comment
// line. An empty comment is formatted as an empty string.
func (c Comments) String() string {
	if len(c) == 0 {
		return ""
	}
	body := strings.TrimSuffix(string(c), "\n")
	var sb strings.Builder
	for _, line := range strings.Split(body, "\n") {
		sb.WriteString("//")
		sb.WriteString(line)
		sb.WriteByte('\n')
	}
	return sb.String()
}
1262