1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package internal_gengo
6
7import (
8	"fmt"
9	"math"
10	"strings"
11	"unicode/utf8"
12
13	"google.golang.org/protobuf/compiler/protogen"
14	"google.golang.org/protobuf/proto"
15	"google.golang.org/protobuf/reflect/protoreflect"
16
17	"google.golang.org/protobuf/types/descriptorpb"
18)
19
// genReflectFileDescriptor emits the reflection support for file f: the
// f.GoDescriptorIdent FileDescriptor variable, the raw descriptor bytes (via
// genFileDescriptor), the per-file enum/message info tables, the goTypes and
// depIdxs tables, and an init function that assembles them with
// protoimpl.TypeBuilder and stores the built file in f.GoDescriptorIdent.
//
// gen is used to look up imported files (to order init calls) and, through
// genFileDescriptor, to report errors; g receives the generated code.
//
// It panics if the dependency-index list has more than math.MaxInt32
// entries, because that list is emitted as a []int32.
func genReflectFileDescriptor(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo) {
	g.P("var ", f.GoDescriptorIdent, " ", protoreflectPackage.Ident("FileDescriptor"))
	g.P()

	genFileDescriptor(gen, g, f)
	// One EnumInfo/MessageInfo slot per enum/message declared in this file.
	// The generated reflection methods index into these tables, and they are
	// handed to TypeBuilder below.
	if len(f.allEnums) > 0 {
		g.P("var ", enumTypesVarName(f), " = make([]", protoimplPackage.Ident("EnumInfo"), ",", len(f.allEnums), ")")
	}
	if len(f.allMessages) > 0 {
		g.P("var ", messageTypesVarName(f), " = make([]", protoimplPackage.Ident("MessageInfo"), ",", len(f.allMessages), ")")
	}

	// Generate a unique list of Go types for all declarations and dependencies,
	// and the associated index into the type list for all dependencies.
	//
	// goTypes and depIdxs hold the source text of each emitted slice element
	// (value plus a trailing comment). seen maps a full name to its index in
	// goTypes; the index is assigned as len(seen) at the same moment the
	// entry is appended to goTypes, so the two always agree.
	var goTypes []string
	var depIdxs []string
	seen := map[protoreflect.FullName]int{}
	// genDep appends one dependency entry: the goTypes index of name,
	// commented with the entry's position in depIdxs and with depSource.
	// Callers register name in seen first; an unregistered name would
	// silently read index 0.
	genDep := func(name protoreflect.FullName, depSource string) {
		if depSource != "" {
			line := fmt.Sprintf("%d, // %d: %s -> %s", seen[name], len(depIdxs), depSource, name)
			depIdxs = append(depIdxs, line)
		}
	}
	// genEnum adds e to goTypes (as a typed zero value) the first time its
	// full name is seen. If depSource is non-empty, it also records a
	// dependency on e. A nil e is ignored.
	genEnum := func(e *protogen.Enum, depSource string) {
		if e != nil {
			name := e.Desc.FullName()
			if _, ok := seen[name]; !ok {
				line := fmt.Sprintf("(%s)(0), // %d: %s", g.QualifiedGoIdent(e.GoIdent), len(goTypes), name)
				goTypes = append(goTypes, line)
				seen[name] = len(seen)
			}
			if depSource != "" {
				genDep(name, depSource)
			}
		}
	}
	// genMessage is the message counterpart of genEnum. Its goTypes entry is a
	// typed nil pointer, or plain nil for map entry messages. A nil m is
	// ignored.
	genMessage := func(m *protogen.Message, depSource string) {
		if m != nil {
			name := m.Desc.FullName()
			if _, ok := seen[name]; !ok {
				line := fmt.Sprintf("(*%s)(nil), // %d: %s", g.QualifiedGoIdent(m.GoIdent), len(goTypes), name)
				if m.Desc.IsMapEntry() {
					// Map entry messages have no associated Go type.
					line = fmt.Sprintf("nil, // %d: %s", len(goTypes), name)
				}
				goTypes = append(goTypes, line)
				seen[name] = len(seen)
			}
			if depSource != "" {
				genDep(name, depSource)
			}
		}
	}

	// This ordering is significant.
	// See filetype.TypeBuilder.DependencyIndexes.
	//
	// depOffsets records, for each dependency section, the index in depIdxs
	// at which that section starts, plus a label used in the emitted comment.
	type offsetEntry struct {
		start int
		name  string
	}
	var depOffsets []offsetEntry
	// Register this file's own enums, then its messages, before anything else.
	// seen is still empty here, so they take the leading goTypes entries in
	// allEnums/allMessages order. These calls add no dependency entries.
	for _, enum := range f.allEnums {
		genEnum(enum.Enum, "")
	}
	for _, message := range f.allMessages {
		genMessage(message.Message, "")
	}
	// Section: enum/message type of every field of every message.
	depOffsets = append(depOffsets, offsetEntry{len(depIdxs), "field type_name"})
	for _, message := range f.allMessages {
		for _, field := range message.Fields {
			// Weak fields are skipped and contribute no dependency entry.
			if field.Desc.IsWeak() {
				continue
			}
			source := string(field.Desc.FullName())
			genEnum(field.Enum, source+":type_name")
			genMessage(field.Message, source+":type_name")
		}
	}
	// Section: message extended by each extension.
	depOffsets = append(depOffsets, offsetEntry{len(depIdxs), "extension extendee"})
	for _, extension := range f.allExtensions {
		source := string(extension.Desc.FullName())
		genMessage(extension.Extendee, source+":extendee")
	}
	// Section: enum/message type of each extension's value.
	depOffsets = append(depOffsets, offsetEntry{len(depIdxs), "extension type_name"})
	for _, extension := range f.allExtensions {
		source := string(extension.Desc.FullName())
		genEnum(extension.Enum, source+":type_name")
		genMessage(extension.Message, source+":type_name")
	}
	// Section: request message of every service method.
	depOffsets = append(depOffsets, offsetEntry{len(depIdxs), "method input_type"})
	for _, service := range f.Services {
		for _, method := range service.Methods {
			source := string(method.Desc.FullName())
			genMessage(method.Input, source+":input_type")
		}
	}
	// Section: response message of every service method.
	depOffsets = append(depOffsets, offsetEntry{len(depIdxs), "method output_type"})
	for _, service := range f.Services {
		for _, method := range service.Methods {
			source := string(method.Desc.FullName())
			genMessage(method.Output, source+":output_type")
		}
	}
	// Sentinel marking the end of the last section.
	depOffsets = append(depOffsets, offsetEntry{len(depIdxs), ""})
	// Append one trailer entry per section, last section first. Each trailer
	// holds that section's start offset in depIdxs, and its comment shows the
	// section's [start:end] range.
	for i := len(depOffsets) - 2; i >= 0; i-- {
		curr, next := depOffsets[i], depOffsets[i+1]
		depIdxs = append(depIdxs, fmt.Sprintf("%d, // [%d:%d] is the sub-list for %s",
			curr.start, curr.start, next.start, curr.name))
	}
	// depIdxs is emitted as []int32 and its trailers are offsets into it.
	if len(depIdxs) > math.MaxInt32 {
		panic("too many dependencies") // sanity check
	}

	g.P("var ", goTypesVarName(f), " = []interface{}{")
	for _, s := range goTypes {
		g.P(s)
	}
	g.P("}")

	g.P("var ", depIdxsVarName(f), " = []int32{")
	for _, s := range depIdxs {
		g.P(s)
	}
	g.P("}")

	// The per-file init function is invoked from a package init func. It is
	// also called by the init funcs of other files in the same Go package
	// that import this file (see the loop below). The generated nil check
	// makes every call after the first return immediately.
	g.P("func init() { ", initFuncName(f.File), "() }")

	g.P("func ", initFuncName(f.File), "() {")
	g.P("if ", f.GoDescriptorIdent, " != nil {")
	g.P("return")
	g.P("}")

	// Ensure that initialization functions for different files in the same Go
	// package run in the correct order: Call the init funcs for every .proto file
	// imported by this one that is in the same Go package.
	for i, imps := 0, f.Desc.Imports(); i < imps.Len(); i++ {
		impFile := gen.FilesByPath[imps.Get(i).Path()]
		if impFile.GoImportPath != f.GoImportPath {
			continue
		}
		g.P(initFuncName(impFile), "()")
	}

	if len(f.allMessages) > 0 {
		// Populate MessageInfo.Exporters.
		//
		// These are emitted only for messages that have unexported struct
		// fields, and are installed only when protoimpl.UnsafeEnabled is
		// false. Each Exporter maps a field index to a pointer to that
		// unexported field, returning nil for any other index.
		g.P("if !", protoimplPackage.Ident("UnsafeEnabled"), " {")
		for _, message := range f.allMessages {
			if sf := f.allMessageFieldsByPtr[message]; len(sf.unexported) > 0 {
				idx := f.allMessagesByPtr[message]
				typesVar := messageTypesVarName(f)

				g.P(typesVar, "[", idx, "].Exporter = func(v interface{}, i int) interface{} {")
				g.P("switch v := v.(*", message.GoIdent, "); i {")
				for i := 0; i < sf.count; i++ {
					if name := sf.unexported[i]; name != "" {
						g.P("case ", i, ": return &v.", name)
					}
				}
				g.P("default: return nil")
				g.P("}")
				g.P("}")
			}
		}
		g.P("}")

		// Populate MessageInfo.OneofWrappers.
		//
		// For each message with oneofs, list a typed nil pointer to the
		// GoIdent of every field of every non-synthetic oneof. Synthetic
		// oneofs are skipped.
		for _, message := range f.allMessages {
			if len(message.Oneofs) > 0 {
				idx := f.allMessagesByPtr[message]
				typesVar := messageTypesVarName(f)

				// Associate the wrapper types by directly passing them to the MessageInfo.
				g.P(typesVar, "[", idx, "].OneofWrappers = []interface{} {")
				for _, oneof := range message.Oneofs {
					if !oneof.Desc.IsSynthetic() {
						for _, field := range oneof.Fields {
							g.P("(*", field.GoIdent, ")(nil),")
						}
					}
				}
				g.P("}")
			}
		}
	}

	// x is a package-local type whose only use is to recover the generated
	// package's import path via reflect.TypeOf(x{}).PkgPath().
	g.P("type x struct{}")
	g.P("out := ", protoimplPackage.Ident("TypeBuilder"), "{")
	g.P("File: ", protoimplPackage.Ident("DescBuilder"), "{")
	g.P("GoPackagePath: ", reflectPackage.Ident("TypeOf"), "(x{}).PkgPath(),")
	g.P("RawDescriptor: ", rawDescVarName(f), ",")
	g.P("NumEnums: ", len(f.allEnums), ",")
	g.P("NumMessages: ", len(f.allMessages), ",")
	g.P("NumExtensions: ", len(f.allExtensions), ",")
	g.P("NumServices: ", len(f.Services), ",")
	g.P("},")
	g.P("GoTypes: ", goTypesVarName(f), ",")
	g.P("DependencyIndexes: ", depIdxsVarName(f), ",")
	// Only reference the info tables that were actually declared above (or,
	// for extensions, elsewhere when the file has any).
	if len(f.allEnums) > 0 {
		g.P("EnumInfos: ", enumTypesVarName(f), ",")
	}
	if len(f.allMessages) > 0 {
		g.P("MessageInfos: ", messageTypesVarName(f), ",")
	}
	if len(f.allExtensions) > 0 {
		g.P("ExtensionInfos: ", extensionTypesVarName(f), ",")
	}
	g.P("}.Build()")
	g.P(f.GoDescriptorIdent, " = out.File")

	// Set inputs to nil to allow GC to reclaim resources.
	g.P(rawDescVarName(f), " = nil")
	g.P(goTypesVarName(f), " = nil")
	g.P(depIdxsVarName(f), " = nil")
	g.P("}")
}
235
236func genFileDescriptor(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo) {
237	descProto := proto.Clone(f.Proto).(*descriptorpb.FileDescriptorProto)
238	descProto.SourceCodeInfo = nil // drop source code information
239
240	b, err := proto.MarshalOptions{AllowPartial: true, Deterministic: true}.Marshal(descProto)
241	if err != nil {
242		gen.Error(err)
243		return
244	}
245
246	g.P("var ", rawDescVarName(f), " = []byte{")
247	for len(b) > 0 {
248		n := 16
249		if n > len(b) {
250			n = len(b)
251		}
252
253		s := ""
254		for _, c := range b[:n] {
255			s += fmt.Sprintf("0x%02x,", c)
256		}
257		g.P(s)
258
259		b = b[n:]
260	}
261	g.P("}")
262	g.P()
263
264	if f.needRawDesc {
265		onceVar := rawDescVarName(f) + "Once"
266		dataVar := rawDescVarName(f) + "Data"
267		g.P("var (")
268		g.P(onceVar, " ", syncPackage.Ident("Once"))
269		g.P(dataVar, " = ", rawDescVarName(f))
270		g.P(")")
271		g.P()
272
273		g.P("func ", rawDescVarName(f), "GZIP() []byte {")
274		g.P(onceVar, ".Do(func() {")
275		g.P(dataVar, " = ", protoimplPackage.Ident("X"), ".CompressGZIP(", dataVar, ")")
276		g.P("})")
277		g.P("return ", dataVar)
278		g.P("}")
279		g.P()
280	}
281}
282
283func genEnumReflectMethods(g *protogen.GeneratedFile, f *fileInfo, e *enumInfo) {
284	idx := f.allEnumsByPtr[e]
285	typesVar := enumTypesVarName(f)
286
287	// Descriptor method.
288	g.P("func (", e.GoIdent, ") Descriptor() ", protoreflectPackage.Ident("EnumDescriptor"), " {")
289	g.P("return ", typesVar, "[", idx, "].Descriptor()")
290	g.P("}")
291	g.P()
292
293	// Type method.
294	g.P("func (", e.GoIdent, ") Type() ", protoreflectPackage.Ident("EnumType"), " {")
295	g.P("return &", typesVar, "[", idx, "]")
296	g.P("}")
297	g.P()
298
299	// Number method.
300	g.P("func (x ", e.GoIdent, ") Number() ", protoreflectPackage.Ident("EnumNumber"), " {")
301	g.P("return ", protoreflectPackage.Ident("EnumNumber"), "(x)")
302	g.P("}")
303	g.P()
304}
305
306func genMessageReflectMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) {
307	idx := f.allMessagesByPtr[m]
308	typesVar := messageTypesVarName(f)
309
310	// ProtoReflect method.
311	g.P("func (x *", m.GoIdent, ") ProtoReflect() ", protoreflectPackage.Ident("Message"), " {")
312	g.P("mi := &", typesVar, "[", idx, "]")
313	g.P("if ", protoimplPackage.Ident("UnsafeEnabled"), " && x != nil {")
314	g.P("ms := ", protoimplPackage.Ident("X"), ".MessageStateOf(", protoimplPackage.Ident("Pointer"), "(x))")
315	g.P("if ms.LoadMessageInfo() == nil {")
316	g.P("ms.StoreMessageInfo(mi)")
317	g.P("}")
318	g.P("return ms")
319	g.P("}")
320	g.P("return mi.MessageOf(x)")
321	g.P("}")
322	g.P()
323}
324
325func fileVarName(f *protogen.File, suffix string) string {
326	prefix := f.GoDescriptorIdent.GoName
327	_, n := utf8.DecodeRuneInString(prefix)
328	prefix = strings.ToLower(prefix[:n]) + prefix[n:]
329	return prefix + "_" + suffix
330}
331func rawDescVarName(f *fileInfo) string {
332	return fileVarName(f.File, "rawDesc")
333}
334func goTypesVarName(f *fileInfo) string {
335	return fileVarName(f.File, "goTypes")
336}
337func depIdxsVarName(f *fileInfo) string {
338	return fileVarName(f.File, "depIdxs")
339}
340func enumTypesVarName(f *fileInfo) string {
341	return fileVarName(f.File, "enumTypes")
342}
343func messageTypesVarName(f *fileInfo) string {
344	return fileVarName(f.File, "msgTypes")
345}
346func extensionTypesVarName(f *fileInfo) string {
347	return fileVarName(f.File, "extTypes")
348}
349func initFuncName(f *protogen.File) string {
350	return fileVarName(f, "init")
351}
352