1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package internal_gengo is internal to the protobuf module.
6package internal_gengo
7
8import (
9	"fmt"
10	"go/ast"
11	"go/parser"
12	"go/token"
13	"math"
14	"strconv"
15	"strings"
16	"unicode"
17	"unicode/utf8"
18
19	"google.golang.org/protobuf/compiler/protogen"
20	"google.golang.org/protobuf/internal/encoding/tag"
21	"google.golang.org/protobuf/internal/genid"
22	"google.golang.org/protobuf/internal/version"
23	"google.golang.org/protobuf/reflect/protoreflect"
24	"google.golang.org/protobuf/runtime/protoimpl"
25
26	"google.golang.org/protobuf/types/descriptorpb"
27	"google.golang.org/protobuf/types/pluginpb"
28)
29
30// SupportedFeatures reports the set of supported protobuf language features.
31var SupportedFeatures = uint64(pluginpb.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL)
32
33// GenerateVersionMarkers specifies whether to generate version markers.
34var GenerateVersionMarkers = true
35
36// Standard library dependencies.
37const (
38	base64Package  = protogen.GoImportPath("encoding/base64")
39	mathPackage    = protogen.GoImportPath("math")
40	reflectPackage = protogen.GoImportPath("reflect")
41	sortPackage    = protogen.GoImportPath("sort")
42	stringsPackage = protogen.GoImportPath("strings")
43	syncPackage    = protogen.GoImportPath("sync")
44	timePackage    = protogen.GoImportPath("time")
45	utf8Package    = protogen.GoImportPath("unicode/utf8")
46)
47
48// Protobuf library dependencies.
49//
50// These are declared as an interface type so that they can be more easily
51// patched to support unique build environments that impose restrictions
52// on the dependencies of generated source code.
53var (
54	protoPackage         goImportPath = protogen.GoImportPath("google.golang.org/protobuf/proto")
55	protoifacePackage    goImportPath = protogen.GoImportPath("google.golang.org/protobuf/runtime/protoiface")
56	protoimplPackage     goImportPath = protogen.GoImportPath("google.golang.org/protobuf/runtime/protoimpl")
57	protojsonPackage     goImportPath = protogen.GoImportPath("google.golang.org/protobuf/encoding/protojson")
58	protoreflectPackage  goImportPath = protogen.GoImportPath("google.golang.org/protobuf/reflect/protoreflect")
59	protoregistryPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/reflect/protoregistry")
60)
61
62type goImportPath interface {
63	String() string
64	Ident(string) protogen.GoIdent
65}
66
67// GenerateFile generates the contents of a .pb.go file.
68func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
69	filename := file.GeneratedFilenamePrefix + ".pb.go"
70	g := gen.NewGeneratedFile(filename, file.GoImportPath)
71	f := newFileInfo(file)
72
73	genStandaloneComments(g, f, int32(genid.FileDescriptorProto_Syntax_field_number))
74	genGeneratedHeader(gen, g, f)
75	genStandaloneComments(g, f, int32(genid.FileDescriptorProto_Package_field_number))
76
77	packageDoc := genPackageKnownComment(f)
78	g.P(packageDoc, "package ", f.GoPackageName)
79	g.P()
80
81	// Emit a static check that enforces a minimum version of the proto package.
82	if GenerateVersionMarkers {
83		g.P("const (")
84		g.P("// Verify that this generated code is sufficiently up-to-date.")
85		g.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimpl.GenVersion, " - ", protoimplPackage.Ident("MinVersion"), ")")
86		g.P("// Verify that runtime/protoimpl is sufficiently up-to-date.")
87		g.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimplPackage.Ident("MaxVersion"), " - ", protoimpl.GenVersion, ")")
88		g.P(")")
89		g.P()
90	}
91
92	for i, imps := 0, f.Desc.Imports(); i < imps.Len(); i++ {
93		genImport(gen, g, f, imps.Get(i))
94	}
95	for _, enum := range f.allEnums {
96		genEnum(g, f, enum)
97	}
98	for _, message := range f.allMessages {
99		genMessage(g, f, message)
100	}
101	genExtensions(g, f)
102
103	genReflectFileDescriptor(gen, g, f)
104
105	return g
106}
107
108// genStandaloneComments prints all leading comments for a FileDescriptorProto
109// location identified by the field number n.
110func genStandaloneComments(g *protogen.GeneratedFile, f *fileInfo, n int32) {
111	loc := f.Desc.SourceLocations().ByPath(protoreflect.SourcePath{n})
112	for _, s := range loc.LeadingDetachedComments {
113		g.P(protogen.Comments(s))
114		g.P()
115	}
116	if s := loc.LeadingComments; s != "" {
117		g.P(protogen.Comments(s))
118		g.P()
119	}
120}
121
122func genGeneratedHeader(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo) {
123	g.P("// Code generated by protoc-gen-go. DO NOT EDIT.")
124
125	if GenerateVersionMarkers {
126		g.P("// versions:")
127		protocGenGoVersion := version.String()
128		protocVersion := "(unknown)"
129		if v := gen.Request.GetCompilerVersion(); v != nil {
130			protocVersion = fmt.Sprintf("v%v.%v.%v", v.GetMajor(), v.GetMinor(), v.GetPatch())
131			if s := v.GetSuffix(); s != "" {
132				protocVersion += "-" + s
133			}
134		}
135		g.P("// \tprotoc-gen-go ", protocGenGoVersion)
136		g.P("// \tprotoc        ", protocVersion)
137	}
138
139	if f.Proto.GetOptions().GetDeprecated() {
140		g.P("// ", f.Desc.Path(), " is a deprecated file.")
141	} else {
142		g.P("// source: ", f.Desc.Path())
143	}
144	g.P()
145}
146
// genImport handles a single import statement of the .proto file being
// generated.
//
// Imports that are not present in gen.FilesByPath, or that resolve to the
// same Go package as the current file, are ignored. Every other non-weak
// import is recorded as a Go import of the dependency's package.
//
// For public imports, the imported file is additionally generated in memory
// (Skip is called on it), parsed with go/parser, and a forwarding
// declaration of the form "<tok> Name = pkg.Name" is emitted for each
// exported top-level type/const/var it declares. Parse or generation errors
// are reported through gen.Error and abort the forwarding.
func genImport(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, imp protoreflect.FileImport) {
	impFile, ok := gen.FilesByPath[imp.Path()]
	if !ok {
		return
	}
	if impFile.GoImportPath == f.GoImportPath {
		// Don't generate imports or aliases for types in the same Go package.
		return
	}
	// Generate imports for all non-weak dependencies, even if they are not
	// referenced, because other code and tools depend on having the
	// full transitive closure of protocol buffer types in the binary.
	if !imp.IsWeak {
		g.Import(impFile.GoImportPath)
	}
	if !imp.IsPublic {
		return
	}

	// Generate public imports by generating the imported file, parsing it,
	// and extracting every symbol that should receive a forwarding declaration.
	// Skip presumably keeps this in-memory copy from being emitted as an
	// output file of its own.
	impGen := GenerateFile(gen, impFile)
	impGen.Skip()
	b, err := impGen.Content()
	if err != nil {
		gen.Error(err)
		return
	}
	fset := token.NewFileSet()
	astFile, err := parser.ParseFile(fset, "", b, parser.ParseComments)
	if err != nil {
		gen.Error(err)
		return
	}
	// genForward emits "tok name = impPkg.name" for one top-level symbol,
	// where tok is the declaration keyword (type, const or var) and expr is
	// the symbol's type (for types) or initializer (for values; may be nil).
	genForward := func(tok token.Token, name string, expr ast.Expr) {
		// Don't import unexported symbols.
		r, _ := utf8.DecodeRuneInString(name)
		if !unicode.IsUpper(r) {
			return
		}
		// Don't import the FileDescriptor.
		if name == impFile.GoDescriptorIdent.GoName {
			return
		}
		// Don't import decls referencing a symbol defined in another package.
		// i.e., don't import decls which are themselves public imports:
		//
		//	type T = somepackage.T
		if _, ok := expr.(*ast.SelectorExpr); ok {
			return
		}
		g.P(tok, " ", name, " = ", impFile.GoImportPath.Ident(name))
	}
	g.P("// Symbols defined in public import of ", imp.Path(), ".")
	g.P()
	// Only general declarations (type/const/var/import) are forwarded;
	// function declarations are not matched by the switch below.
	for _, decl := range astFile.Decls {
		switch decl := decl.(type) {
		case *ast.GenDecl:
			for _, spec := range decl.Specs {
				switch spec := spec.(type) {
				case *ast.TypeSpec:
					genForward(decl.Tok, spec.Name.Name, spec.Type)
				case *ast.ValueSpec:
					// A ValueSpec may declare names without initializers;
					// expr stays nil for those, which genForward treats as a
					// symbol defined locally in the imported package.
					for i, name := range spec.Names {
						var expr ast.Expr
						if i < len(spec.Values) {
							expr = spec.Values[i]
						}
						genForward(decl.Tok, name.Name, expr)
					}
				case *ast.ImportSpec:
					// Imports declare no symbols of their own.
				default:
					panic(fmt.Sprintf("can't generate forward for spec type %T", spec))
				}
			}
		}
	}
	g.P()
}
226
// genEnum generates the Go declarations for the enum e, in this order:
//   - the named int32 type;
//   - one constant per enum value;
//   - the <Enum>_name and <Enum>_value lookup maps;
//   - the Enum and String methods and the reflection methods;
//   - optionally, the legacy UnmarshalJSON method (proto2 only, when
//     e.genJSONMethod is set);
//   - optionally, the legacy EnumDescriptor method (when e.genRawDescMethod
//     is set).
func genEnum(g *protogen.GeneratedFile, f *fileInfo, e *enumInfo) {
	// Enum type declaration.
	g.Annotate(e.GoIdent.GoName, e.Location)
	leadingComments := appendDeprecationSuffix(e.Comments.Leading,
		e.Desc.Options().(*descriptorpb.EnumOptions).GetDeprecated())
	g.P(leadingComments,
		"type ", e.GoIdent, " int32")

	// Enum value constants.
	g.P("const (")
	for _, value := range e.Values {
		g.Annotate(value.GoIdent.GoName, value.Location)
		leadingComments := appendDeprecationSuffix(value.Comments.Leading,
			value.Desc.Options().(*descriptorpb.EnumValueOptions).GetDeprecated())
		g.P(leadingComments,
			value.GoIdent, " ", e.GoIdent, " = ", value.Desc.Number(),
			trailingComment(value.Comments.Trailing))
	}
	g.P(")")
	g.P()

	// Enum value maps.
	g.P("// Enum value maps for ", e.GoIdent, ".")
	g.P("var (")
	g.P(e.GoIdent.GoName+"_name", " = map[int32]string{")
	for _, value := range e.Values {
		// Several values may share one number (aliases). A Go map literal
		// cannot repeat a key, so every value other than the one ByNumber
		// returns for that number is emitted commented out.
		duplicate := ""
		if value.Desc != e.Desc.Values().ByNumber(value.Desc.Number()) {
			duplicate = "// Duplicate value: "
		}
		g.P(duplicate, value.Desc.Number(), ": ", strconv.Quote(string(value.Desc.Name())), ",")
	}
	g.P("}")
	// No duplicate filtering here: value names are expected to be unique.
	g.P(e.GoIdent.GoName+"_value", " = map[string]int32{")
	for _, value := range e.Values {
		g.P(strconv.Quote(string(value.Desc.Name())), ": ", value.Desc.Number(), ",")
	}
	g.P("}")
	g.P(")")
	g.P()

	// Enum method.
	//
	// NOTE: A pointer value is needed to represent presence in proto2.
	// Since a proto2 message can reference a proto3 enum, it is useful to
	// always generate this method (even on proto3 enums) to support that case.
	g.P("func (x ", e.GoIdent, ") Enum() *", e.GoIdent, " {")
	g.P("p := new(", e.GoIdent, ")")
	g.P("*p = x")
	g.P("return p")
	g.P("}")
	g.P()

	// String method.
	g.P("func (x ", e.GoIdent, ") String() string {")
	g.P("return ", protoimplPackage.Ident("X"), ".EnumStringOf(x.Descriptor(), ", protoreflectPackage.Ident("EnumNumber"), "(x))")
	g.P("}")
	g.P()

	genEnumReflectMethods(g, f, e)

	// UnmarshalJSON method.
	if e.genJSONMethod && e.Desc.Syntax() == protoreflect.Proto2 {
		g.P("// Deprecated: Do not use.")
		g.P("func (x *", e.GoIdent, ") UnmarshalJSON(b []byte) error {")
		g.P("num, err := ", protoimplPackage.Ident("X"), ".UnmarshalJSONEnum(x.Descriptor(), b)")
		g.P("if err != nil {")
		g.P("return err")
		g.P("}")
		g.P("*x = ", e.GoIdent, "(num)")
		g.P("return nil")
		g.P("}")
		g.P()
	}

	// EnumDescriptor method.
	if e.genRawDescMethod {
		// The returned index path is built from the odd positions of the
		// enum's source location path (presumably the per-level declaration
		// indexes).
		var indexes []string
		for i := 1; i < len(e.Location.Path); i += 2 {
			indexes = append(indexes, strconv.Itoa(int(e.Location.Path[i])))
		}
		g.P("// Deprecated: Use ", e.GoIdent, ".Descriptor instead.")
		g.P("func (", e.GoIdent, ") EnumDescriptor() ([]byte, []int) {")
		g.P("return ", rawDescVarName(f), "GZIP(), []int{", strings.Join(indexes, ","), "}")
		g.P("}")
		g.P()
		// The method above references the GZIP accessor, so flag that the
		// file needs it (presumably consumed when the raw descriptor is
		// generated).
		f.needRawDesc = true
	}
}
316
317func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
318	if m.Desc.IsMapEntry() {
319		return
320	}
321
322	// Message type declaration.
323	g.Annotate(m.GoIdent.GoName, m.Location)
324	leadingComments := appendDeprecationSuffix(m.Comments.Leading,
325		m.Desc.Options().(*descriptorpb.MessageOptions).GetDeprecated())
326	g.P(leadingComments,
327		"type ", m.GoIdent, " struct {")
328	genMessageFields(g, f, m)
329	g.P("}")
330	g.P()
331
332	genMessageKnownFunctions(g, f, m)
333	genMessageDefaultDecls(g, f, m)
334	genMessageMethods(g, f, m)
335	genMessageOneofWrapperTypes(g, f, m)
336}
337
338func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
339	sf := f.allMessageFieldsByPtr[m]
340	genMessageInternalFields(g, f, m, sf)
341	for _, field := range m.Fields {
342		genMessageField(g, f, m, field, sf)
343	}
344}
345
346func genMessageInternalFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, sf *structFields) {
347	g.P(genid.State_goname, " ", protoimplPackage.Ident("MessageState"))
348	sf.append(genid.State_goname)
349	g.P(genid.SizeCache_goname, " ", protoimplPackage.Ident("SizeCache"))
350	sf.append(genid.SizeCache_goname)
351	if m.hasWeak {
352		g.P(genid.WeakFields_goname, " ", protoimplPackage.Ident("WeakFields"))
353		sf.append(genid.WeakFields_goname)
354	}
355	g.P(genid.UnknownFields_goname, " ", protoimplPackage.Ident("UnknownFields"))
356	sf.append(genid.UnknownFields_goname)
357	if m.Desc.ExtensionRanges().Len() > 0 {
358		g.P(genid.ExtensionFields_goname, " ", protoimplPackage.Ident("ExtensionFields"))
359		sf.append(genid.ExtensionFields_goname)
360	}
361	if sf.count > 0 {
362		g.P()
363	}
364}
365
// genMessageField emits the struct field declaration for field inside the
// struct of message m, and records the emitted name via sf.append.
//
// A field that belongs to a non-synthetic oneof is not declared directly.
// When it is the oneof's first member, a single field named after the oneof
// is emitted instead. That field has the oneof's interface type, and its
// comment lists the wrapper types assignable to it. Later members of the
// same oneof emit nothing.
//
// Other fields get "protobuf" and "json" struct tags. Map fields also get
// "protobuf_key"/"protobuf_val" tags, and messages tracked for field usage
// get the go:"track" tag.
func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field, sf *structFields) {
	if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() {
		// It would be a bit simpler to iterate over the oneofs below,
		// but generating the field here keeps the contents of the Go
		// struct in the same order as the contents of the source
		// .proto file.
		if oneof.Fields[0] != field {
			return // only generate for first appearance
		}

		tags := structTags{
			{"protobuf_oneof", string(oneof.Desc.Name())},
		}
		if m.isTracked {
			tags = append(tags, gotrackTags...)
		}

		g.Annotate(m.GoIdent.GoName+"."+oneof.GoName, oneof.Location)
		leadingComments := oneof.Comments.Leading
		if leadingComments != "" {
			leadingComments += "\n"
		}
		// Document which wrapper types may be stored in the oneof field.
		ss := []string{fmt.Sprintf(" Types that are assignable to %s:\n", oneof.GoName)}
		for _, field := range oneof.Fields {
			ss = append(ss, "\t*"+field.GoIdent.GoName+"\n")
		}
		leadingComments += protogen.Comments(strings.Join(ss, ""))
		g.P(leadingComments,
			oneof.GoName, " ", oneofInterfaceName(oneof), tags)
		sf.append(oneof.GoName)
		return
	}
	goType, pointer := fieldGoType(g, f, field)
	if pointer {
		goType = "*" + goType
	}
	tags := structTags{
		{"protobuf", fieldProtobufTagValue(field)},
		{"json", fieldJSONTagValue(field)},
	}
	if field.Desc.IsMap() {
		// Fields[0] and Fields[1] of the map entry message are the key and
		// value fields, as also assumed by fieldGoType.
		key := field.Message.Fields[0]
		val := field.Message.Fields[1]
		tags = append(tags, structTags{
			{"protobuf_key", fieldProtobufTagValue(key)},
			{"protobuf_val", fieldProtobufTagValue(val)},
		}...)
	}
	if m.isTracked {
		tags = append(tags, gotrackTags...)
	}

	// Weak fields are declared under a prefixed name (their Go type is
	// struct{}, see fieldGoType); the real value lives in the weak-field table.
	name := field.GoName
	if field.Desc.IsWeak() {
		name = genid.WeakFieldPrefix_goname + name
	}
	g.Annotate(m.GoIdent.GoName+"."+name, field.Location)
	leadingComments := appendDeprecationSuffix(field.Comments.Leading,
		field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
	g.P(leadingComments,
		name, " ", goType, tags,
		trailingComment(field.Comments.Trailing))
	// NOTE(review): for weak fields this records field.GoName rather than
	// the prefixed name declared above; confirm structFields only depends on
	// the count and exportedness of the names it is given.
	sf.append(field.GoName)
}
430
431// genMessageDefaultDecls generates consts and vars holding the default
432// values of fields.
433func genMessageDefaultDecls(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
434	var consts, vars []string
435	for _, field := range m.Fields {
436		if !field.Desc.HasDefault() {
437			continue
438		}
439		name := "Default_" + m.GoIdent.GoName + "_" + field.GoName
440		goType, _ := fieldGoType(g, f, field)
441		defVal := field.Desc.Default()
442		switch field.Desc.Kind() {
443		case protoreflect.StringKind:
444			consts = append(consts, fmt.Sprintf("%s = %s(%q)", name, goType, defVal.String()))
445		case protoreflect.BytesKind:
446			vars = append(vars, fmt.Sprintf("%s = %s(%q)", name, goType, defVal.Bytes()))
447		case protoreflect.EnumKind:
448			idx := field.Desc.DefaultEnumValue().Index()
449			val := field.Enum.Values[idx]
450			if val.GoIdent.GoImportPath == f.GoImportPath {
451				consts = append(consts, fmt.Sprintf("%s = %s", name, g.QualifiedGoIdent(val.GoIdent)))
452			} else {
453				// If the enum value is declared in a different Go package,
454				// reference it by number since the name may not be correct.
455				// See https://github.com/golang/protobuf/issues/513.
456				consts = append(consts, fmt.Sprintf("%s = %s(%d) // %s",
457					name, g.QualifiedGoIdent(field.Enum.GoIdent), val.Desc.Number(), g.QualifiedGoIdent(val.GoIdent)))
458			}
459		case protoreflect.FloatKind, protoreflect.DoubleKind:
460			if f := defVal.Float(); math.IsNaN(f) || math.IsInf(f, 0) {
461				var fn, arg string
462				switch f := defVal.Float(); {
463				case math.IsInf(f, -1):
464					fn, arg = g.QualifiedGoIdent(mathPackage.Ident("Inf")), "-1"
465				case math.IsInf(f, +1):
466					fn, arg = g.QualifiedGoIdent(mathPackage.Ident("Inf")), "+1"
467				case math.IsNaN(f):
468					fn, arg = g.QualifiedGoIdent(mathPackage.Ident("NaN")), ""
469				}
470				vars = append(vars, fmt.Sprintf("%s = %s(%s(%s))", name, goType, fn, arg))
471			} else {
472				consts = append(consts, fmt.Sprintf("%s = %s(%v)", name, goType, f))
473			}
474		default:
475			consts = append(consts, fmt.Sprintf("%s = %s(%v)", name, goType, defVal.Interface()))
476		}
477	}
478	if len(consts) > 0 {
479		g.P("// Default values for ", m.GoIdent, " fields.")
480		g.P("const (")
481		for _, s := range consts {
482			g.P(s)
483		}
484		g.P(")")
485	}
486	if len(vars) > 0 {
487		g.P("// Default values for ", m.GoIdent, " fields.")
488		g.P("var (")
489		for _, s := range vars {
490			g.P(s)
491		}
492		g.P(")")
493	}
494	g.P()
495}
496
497func genMessageMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
498	genMessageBaseMethods(g, f, m)
499	genMessageGetterMethods(g, f, m)
500	genMessageSetterMethods(g, f, m)
501}
502
503func genMessageBaseMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
504	// Reset method.
505	g.P("func (x *", m.GoIdent, ") Reset() {")
506	g.P("*x = ", m.GoIdent, "{}")
507	g.P("if ", protoimplPackage.Ident("UnsafeEnabled"), " {")
508	g.P("mi := &", messageTypesVarName(f), "[", f.allMessagesByPtr[m], "]")
509	g.P("ms := ", protoimplPackage.Ident("X"), ".MessageStateOf(", protoimplPackage.Ident("Pointer"), "(x))")
510	g.P("ms.StoreMessageInfo(mi)")
511	g.P("}")
512	g.P("}")
513	g.P()
514
515	// String method.
516	g.P("func (x *", m.GoIdent, ") String() string {")
517	g.P("return ", protoimplPackage.Ident("X"), ".MessageStringOf(x)")
518	g.P("}")
519	g.P()
520
521	// ProtoMessage method.
522	g.P("func (*", m.GoIdent, ") ProtoMessage() {}")
523	g.P()
524
525	// ProtoReflect method.
526	genMessageReflectMethods(g, f, m)
527
528	// Descriptor method.
529	if m.genRawDescMethod {
530		var indexes []string
531		for i := 1; i < len(m.Location.Path); i += 2 {
532			indexes = append(indexes, strconv.Itoa(int(m.Location.Path[i])))
533		}
534		g.P("// Deprecated: Use ", m.GoIdent, ".ProtoReflect.Descriptor instead.")
535		g.P("func (*", m.GoIdent, ") Descriptor() ([]byte, []int) {")
536		g.P("return ", rawDescVarName(f), "GZIP(), []int{", strings.Join(indexes, ","), "}")
537		g.P("}")
538		g.P()
539		f.needRawDesc = true
540	}
541}
542
543func genMessageGetterMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
544	for _, field := range m.Fields {
545		genNoInterfacePragma(g, m.isTracked)
546
547		// Getter for parent oneof.
548		if oneof := field.Oneof; oneof != nil && oneof.Fields[0] == field && !oneof.Desc.IsSynthetic() {
549			g.Annotate(m.GoIdent.GoName+".Get"+oneof.GoName, oneof.Location)
550			g.P("func (m *", m.GoIdent.GoName, ") Get", oneof.GoName, "() ", oneofInterfaceName(oneof), " {")
551			g.P("if m != nil {")
552			g.P("return m.", oneof.GoName)
553			g.P("}")
554			g.P("return nil")
555			g.P("}")
556			g.P()
557		}
558
559		// Getter for message field.
560		goType, pointer := fieldGoType(g, f, field)
561		defaultValue := fieldDefaultValue(g, f, m, field)
562		g.Annotate(m.GoIdent.GoName+".Get"+field.GoName, field.Location)
563		leadingComments := appendDeprecationSuffix("",
564			field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
565		switch {
566		case field.Desc.IsWeak():
567			g.P(leadingComments, "func (x *", m.GoIdent, ") Get", field.GoName, "() ", protoPackage.Ident("Message"), "{")
568			g.P("var w ", protoimplPackage.Ident("WeakFields"))
569			g.P("if x != nil {")
570			g.P("w = x.", genid.WeakFields_goname)
571			if m.isTracked {
572				g.P("_ = x.", genid.WeakFieldPrefix_goname+field.GoName)
573			}
574			g.P("}")
575			g.P("return ", protoimplPackage.Ident("X"), ".GetWeak(w, ", field.Desc.Number(), ", ", strconv.Quote(string(field.Message.Desc.FullName())), ")")
576			g.P("}")
577		case field.Oneof != nil && !field.Oneof.Desc.IsSynthetic():
578			g.P(leadingComments, "func (x *", m.GoIdent, ") Get", field.GoName, "() ", goType, " {")
579			g.P("if x, ok := x.Get", field.Oneof.GoName, "().(*", field.GoIdent, "); ok {")
580			g.P("return x.", field.GoName)
581			g.P("}")
582			g.P("return ", defaultValue)
583			g.P("}")
584		default:
585			g.P(leadingComments, "func (x *", m.GoIdent, ") Get", field.GoName, "() ", goType, " {")
586			if !field.Desc.HasPresence() || defaultValue == "nil" {
587				g.P("if x != nil {")
588			} else {
589				g.P("if x != nil && x.", field.GoName, " != nil {")
590			}
591			star := ""
592			if pointer {
593				star = "*"
594			}
595			g.P("return ", star, " x.", field.GoName)
596			g.P("}")
597			g.P("return ", defaultValue)
598			g.P("}")
599		}
600		g.P()
601	}
602}
603
604func genMessageSetterMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
605	for _, field := range m.Fields {
606		if !field.Desc.IsWeak() {
607			continue
608		}
609
610		genNoInterfacePragma(g, m.isTracked)
611
612		g.Annotate(m.GoIdent.GoName+".Set"+field.GoName, field.Location)
613		leadingComments := appendDeprecationSuffix("",
614			field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
615		g.P(leadingComments, "func (x *", m.GoIdent, ") Set", field.GoName, "(v ", protoPackage.Ident("Message"), ") {")
616		g.P("var w *", protoimplPackage.Ident("WeakFields"))
617		g.P("if x != nil {")
618		g.P("w = &x.", genid.WeakFields_goname)
619		if m.isTracked {
620			g.P("_ = x.", genid.WeakFieldPrefix_goname+field.GoName)
621		}
622		g.P("}")
623		g.P(protoimplPackage.Ident("X"), ".SetWeak(w, ", field.Desc.Number(), ", ", strconv.Quote(string(field.Message.Desc.FullName())), ", v)")
624		g.P("}")
625		g.P()
626	}
627}
628
629// fieldGoType returns the Go type used for a field.
630//
631// If it returns pointer=true, the struct field is a pointer to the type.
632func fieldGoType(g *protogen.GeneratedFile, f *fileInfo, field *protogen.Field) (goType string, pointer bool) {
633	if field.Desc.IsWeak() {
634		return "struct{}", false
635	}
636
637	pointer = field.Desc.HasPresence()
638	switch field.Desc.Kind() {
639	case protoreflect.BoolKind:
640		goType = "bool"
641	case protoreflect.EnumKind:
642		goType = g.QualifiedGoIdent(field.Enum.GoIdent)
643	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
644		goType = "int32"
645	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
646		goType = "uint32"
647	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
648		goType = "int64"
649	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
650		goType = "uint64"
651	case protoreflect.FloatKind:
652		goType = "float32"
653	case protoreflect.DoubleKind:
654		goType = "float64"
655	case protoreflect.StringKind:
656		goType = "string"
657	case protoreflect.BytesKind:
658		goType = "[]byte"
659		pointer = false // rely on nullability of slices for presence
660	case protoreflect.MessageKind, protoreflect.GroupKind:
661		goType = "*" + g.QualifiedGoIdent(field.Message.GoIdent)
662		pointer = false // pointer captured as part of the type
663	}
664	switch {
665	case field.Desc.IsList():
666		return "[]" + goType, false
667	case field.Desc.IsMap():
668		keyType, _ := fieldGoType(g, f, field.Message.Fields[0])
669		valType, _ := fieldGoType(g, f, field.Message.Fields[1])
670		return fmt.Sprintf("map[%v]%v", keyType, valType), false
671	}
672	return goType, pointer
673}
674
675func fieldProtobufTagValue(field *protogen.Field) string {
676	var enumName string
677	if field.Desc.Kind() == protoreflect.EnumKind {
678		enumName = protoimpl.X.LegacyEnumName(field.Enum.Desc)
679	}
680	return tag.Marshal(field.Desc, enumName)
681}
682
683func fieldDefaultValue(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field) string {
684	if field.Desc.IsList() {
685		return "nil"
686	}
687	if field.Desc.HasDefault() {
688		defVarName := "Default_" + m.GoIdent.GoName + "_" + field.GoName
689		if field.Desc.Kind() == protoreflect.BytesKind {
690			return "append([]byte(nil), " + defVarName + "...)"
691		}
692		return defVarName
693	}
694	switch field.Desc.Kind() {
695	case protoreflect.BoolKind:
696		return "false"
697	case protoreflect.StringKind:
698		return `""`
699	case protoreflect.MessageKind, protoreflect.GroupKind, protoreflect.BytesKind:
700		return "nil"
701	case protoreflect.EnumKind:
702		val := field.Enum.Values[0]
703		if val.GoIdent.GoImportPath == f.GoImportPath {
704			return g.QualifiedGoIdent(val.GoIdent)
705		} else {
706			// If the enum value is declared in a different Go package,
707			// reference it by number since the name may not be correct.
708			// See https://github.com/golang/protobuf/issues/513.
709			return g.QualifiedGoIdent(field.Enum.GoIdent) + "(" + strconv.FormatInt(int64(val.Desc.Number()), 10) + ")"
710		}
711	default:
712		return "0"
713	}
714}
715
716func fieldJSONTagValue(field *protogen.Field) string {
717	return string(field.Desc.Name()) + ",omitempty"
718}
719
// genExtensions generates the extension declarations of the file.
//
// It first emits one slice of protoimpl.ExtensionInfo with an element per
// extension, in f.allExtensions order. It then emits one var block per
// extended message, in order of first appearance. Each block contains an
// E_<name> variable pointing at that extension's element of the slice.
// Nothing is emitted when the file declares no extensions.
func genExtensions(g *protogen.GeneratedFile, f *fileInfo) {
	if len(f.allExtensions) == 0 {
		return
	}

	g.P("var ", extensionTypesVarName(f), " = []", protoimplPackage.Ident("ExtensionInfo"), "{")
	for _, x := range f.allExtensions {
		g.P("{")
		// Typed nil values; presumably only their static types are used.
		g.P("ExtendedType: (*", x.Extendee.GoIdent, ")(nil),")
		goType, pointer := fieldGoType(g, f, x.Extension)
		if pointer {
			goType = "*" + goType
		}
		g.P("ExtensionType: (", goType, ")(nil),")
		g.P("Field: ", x.Desc.Number(), ",")
		g.P("Name: ", strconv.Quote(string(x.Desc.FullName())), ",")
		g.P("Tag: ", strconv.Quote(fieldProtobufTagValue(x.Extension)), ",")
		g.P("Filename: ", strconv.Quote(f.Desc.Path()), ",")
		g.P("},")
	}
	g.P("}")
	g.P()

	// Group extensions by the target message.
	// Go map iteration order is random, so orderedTargets records the
	// first-seen order of targets to keep the output deterministic.
	var orderedTargets []protogen.GoIdent
	allExtensionsByTarget := make(map[protogen.GoIdent][]*extensionInfo)
	// allExtensionsByPtr maps each extension to its index in the slice above.
	allExtensionsByPtr := make(map[*extensionInfo]int)
	for i, x := range f.allExtensions {
		target := x.Extendee.GoIdent
		if len(allExtensionsByTarget[target]) == 0 {
			orderedTargets = append(orderedTargets, target)
		}
		allExtensionsByTarget[target] = append(allExtensionsByTarget[target], x)
		allExtensionsByPtr[x] = i
	}
	for _, target := range orderedTargets {
		g.P("// Extension fields to ", target, ".")
		g.P("var (")
		for _, x := range allExtensionsByTarget[target] {
			// Build a one-line summary of the declaration for the comment:
			// cardinality, type name (full name for enums/messages, kind name
			// otherwise), field name and number.
			xd := x.Desc
			typeName := xd.Kind().String()
			switch xd.Kind() {
			case protoreflect.EnumKind:
				typeName = string(xd.Enum().FullName())
			case protoreflect.MessageKind, protoreflect.GroupKind:
				typeName = string(xd.Message().FullName())
			}
			fieldName := string(xd.Name())

			leadingComments := x.Comments.Leading
			if leadingComments != "" {
				leadingComments += "\n"
			}
			leadingComments += protogen.Comments(fmt.Sprintf(" %v %v %v = %v;\n",
				xd.Cardinality(), typeName, fieldName, xd.Number()))
			leadingComments = appendDeprecationSuffix(leadingComments,
				x.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
			g.P(leadingComments,
				"E_", x.GoIdent, " = &", extensionTypesVarName(f), "[", allExtensionsByPtr[x], "]",
				trailingComment(x.Comments.Trailing))
		}
		g.P(")")
		g.P()
	}
}
785
786// genMessageOneofWrapperTypes generates the oneof wrapper types and
787// associates the types with the parent message type.
788func genMessageOneofWrapperTypes(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
789	for _, oneof := range m.Oneofs {
790		if oneof.Desc.IsSynthetic() {
791			continue
792		}
793		ifName := oneofInterfaceName(oneof)
794		g.P("type ", ifName, " interface {")
795		g.P(ifName, "()")
796		g.P("}")
797		g.P()
798		for _, field := range oneof.Fields {
799			g.Annotate(field.GoIdent.GoName, field.Location)
800			g.Annotate(field.GoIdent.GoName+"."+field.GoName, field.Location)
801			g.P("type ", field.GoIdent, " struct {")
802			goType, _ := fieldGoType(g, f, field)
803			tags := structTags{
804				{"protobuf", fieldProtobufTagValue(field)},
805			}
806			if m.isTracked {
807				tags = append(tags, gotrackTags...)
808			}
809			leadingComments := appendDeprecationSuffix(field.Comments.Leading,
810				field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated())
811			g.P(leadingComments,
812				field.GoName, " ", goType, tags,
813				trailingComment(field.Comments.Trailing))
814			g.P("}")
815			g.P()
816		}
817		for _, field := range oneof.Fields {
818			g.P("func (*", field.GoIdent, ") ", ifName, "() {}")
819			g.P()
820		}
821	}
822}
823
824// oneofInterfaceName returns the name of the interface type implemented by
825// the oneof field value types.
826func oneofInterfaceName(oneof *protogen.Oneof) string {
827	return "is" + oneof.GoIdent.GoName
828}
829
830// genNoInterfacePragma generates a standalone "nointerface" pragma to
831// decorate methods with field-tracking support.
832func genNoInterfacePragma(g *protogen.GeneratedFile, tracked bool) {
833	if tracked {
834		g.P("//go:nointerface")
835		g.P()
836	}
837}
838
// gotrackTags are the extra struct tags added to generated fields of
// messages that are tracked for field usage.
var gotrackTags = structTags{{"go", "track"}}

// structTags is a data structure for building idiomatic Go struct tags.
// Each [2]string is a key-value pair, where value is the unescaped string.
//
// Example: structTags{{"key", "value"}}.String() -> `key:"value"`
type structTags [][2]string

// String renders the tags as a complete Go raw string literal (backticks
// included), with space-separated key:"value" pairs. An empty set of tags
// renders as "".
func (tags structTags) String() string {
	if len(tags) == 0 {
		return ""
	}
	parts := make([]string, 0, len(tags))
	for _, kv := range tags {
		// The result is placed inside a raw string literal, which cannot
		// contain a backtick, so any backtick in the quoted value is written
		// in its escaped hex form instead.
		quoted := strings.Replace(strconv.Quote(kv[1]), "`", `\x60`, -1)
		parts = append(parts, kv[0]+":"+quoted)
	}
	return "`" + strings.Join(parts, " ") + "`"
}
861
862// appendDeprecationSuffix optionally appends a deprecation notice as a suffix.
863func appendDeprecationSuffix(prefix protogen.Comments, deprecated bool) protogen.Comments {
864	if !deprecated {
865		return prefix
866	}
867	if prefix != "" {
868		prefix += "\n"
869	}
870	return prefix + " Deprecated: Do not use.\n"
871}
872
873// trailingComment is like protogen.Comments, but lacks a trailing newline.
874type trailingComment protogen.Comments
875
876func (c trailingComment) String() string {
877	s := strings.TrimSuffix(protogen.Comments(c).String(), "\n")
878	if strings.Contains(s, "\n") {
879		// We don't support multi-lined trailing comments as it is unclear
880		// how to best render them in the generated code.
881		return ""
882	}
883	return s
884}
885