1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8	"bytes"
9	"compress/gzip"
10	"fmt"
11	"io/ioutil"
12	"reflect"
13	"strings"
14	"sync"
15
16	"google.golang.org/protobuf/reflect/protodesc"
17	"google.golang.org/protobuf/reflect/protoreflect"
18	"google.golang.org/protobuf/reflect/protoregistry"
19	"google.golang.org/protobuf/runtime/protoimpl"
20)
21
type (
	// filePath is the path to a proto source file,
	// e.g., "google/protobuf/descriptor.proto".
	filePath = string

	// fileDescGZIP is the gzip-compressed encoding of a FileDescriptorProto.
	fileDescGZIP = []byte
)

// fileCache holds compressed file descriptors keyed by source path
// (map[filePath]fileDescGZIP).
var fileCache sync.Map
29
30// RegisterFile is called from generated code to register the compressed
31// FileDescriptorProto with the file path for a proto source file.
32//
33// Deprecated: Use protoregistry.GlobalFiles.RegisterFile instead.
34func RegisterFile(s filePath, d fileDescGZIP) {
35	// Decompress the descriptor.
36	zr, err := gzip.NewReader(bytes.NewReader(d))
37	if err != nil {
38		panic(fmt.Sprintf("proto: invalid compressed file descriptor: %v", err))
39	}
40	b, err := ioutil.ReadAll(zr)
41	if err != nil {
42		panic(fmt.Sprintf("proto: invalid compressed file descriptor: %v", err))
43	}
44
45	// Construct a protoreflect.FileDescriptor from the raw descriptor.
46	// Note that DescBuilder.Build automatically registers the constructed
47	// file descriptor with the v2 registry.
48	protoimpl.DescBuilder{RawDescriptor: b}.Build()
49
50	// Locally cache the raw descriptor form for the file.
51	fileCache.Store(s, d)
52}
53
54// FileDescriptor returns the compressed FileDescriptorProto given the file path
55// for a proto source file. It returns nil if not found.
56//
57// Deprecated: Use protoregistry.GlobalFiles.FindFileByPath instead.
58func FileDescriptor(s filePath) fileDescGZIP {
59	if v, ok := fileCache.Load(s); ok {
60		return v.(fileDescGZIP)
61	}
62
63	// Find the descriptor in the v2 registry.
64	var b []byte
65	if fd, _ := protoregistry.GlobalFiles.FindFileByPath(s); fd != nil {
66		b, _ = Marshal(protodesc.ToFileDescriptorProto(fd))
67	}
68
69	// Locally cache the raw descriptor form for the file.
70	if len(b) > 0 {
71		v, _ := fileCache.LoadOrStore(s, protoimpl.X.CompressGZIP(b))
72		return v.(fileDescGZIP)
73	}
74	return nil
75}
76
type (
	// enumName is the name of an enum. For historical reasons, the enum name
	// is neither the full Go name nor the full protobuf name of the enum.
	// It is the dot-separated combination of just the proto package that the
	// enum is declared within followed by the Go type name of the generated
	// enum, e.g., "my.proto.package.GoMessage_GoEnum".
	enumName = string

	// enumsByName maps enum value names to their numeric values.
	enumsByName = map[string]int32

	// enumsByNumber maps enum numbers to their value names.
	enumsByNumber = map[int32]string
)

var (
	// enumCache maps an enumName to its enumsByName table. It is filled by
	// RegisterEnum and lazily by EnumValueMap.
	enumCache sync.Map

	// numFilesCache maps a proto package (protoreflect.FullName) to the
	// number of files EnumValueMap counted in that package on its last scan.
	numFilesCache sync.Map
)
91
92// RegisterEnum is called from the generated code to register the mapping of
93// enum value names to enum numbers for the enum identified by s.
94//
95// Deprecated: Use protoregistry.GlobalTypes.RegisterEnum instead.
96func RegisterEnum(s enumName, _ enumsByNumber, m enumsByName) {
97	if _, ok := enumCache.Load(s); ok {
98		panic("proto: duplicate enum registered: " + s)
99	}
100	enumCache.Store(s, m)
101
102	// This does not forward registration to the v2 registry since this API
103	// lacks sufficient information to construct a complete v2 enum descriptor.
104}
105
106// EnumValueMap returns the mapping from enum value names to enum numbers for
107// the enum of the given name. It returns nil if not found.
108//
109// Deprecated: Use protoregistry.GlobalTypes.FindEnumByName instead.
110func EnumValueMap(s enumName) enumsByName {
111	if v, ok := enumCache.Load(s); ok {
112		return v.(enumsByName)
113	}
114
115	// Check whether the cache is stale. If the number of files in the current
116	// package differs, then it means that some enums may have been recently
117	// registered upstream that we do not know about.
118	var protoPkg protoreflect.FullName
119	if i := strings.LastIndexByte(s, '.'); i >= 0 {
120		protoPkg = protoreflect.FullName(s[:i])
121	}
122	v, _ := numFilesCache.Load(protoPkg)
123	numFiles, _ := v.(int)
124	if protoregistry.GlobalFiles.NumFilesByPackage(protoPkg) == numFiles {
125		return nil // cache is up-to-date; was not found earlier
126	}
127
128	// Update the enum cache for all enums declared in the given proto package.
129	numFiles = 0
130	protoregistry.GlobalFiles.RangeFilesByPackage(protoPkg, func(fd protoreflect.FileDescriptor) bool {
131		walkEnums(fd, func(ed protoreflect.EnumDescriptor) {
132			name := protoimpl.X.LegacyEnumName(ed)
133			if _, ok := enumCache.Load(name); !ok {
134				m := make(enumsByName)
135				evs := ed.Values()
136				for i := evs.Len() - 1; i >= 0; i-- {
137					ev := evs.Get(i)
138					m[string(ev.Name())] = int32(ev.Number())
139				}
140				enumCache.LoadOrStore(name, m)
141			}
142		})
143		numFiles++
144		return true
145	})
146	numFilesCache.Store(protoPkg, numFiles)
147
148	// Check cache again for enum map.
149	if v, ok := enumCache.Load(s); ok {
150		return v.(enumsByName)
151	}
152	return nil
153}
154
155// walkEnums recursively walks all enums declared in d.
156func walkEnums(d interface {
157	Enums() protoreflect.EnumDescriptors
158	Messages() protoreflect.MessageDescriptors
159}, f func(protoreflect.EnumDescriptor)) {
160	eds := d.Enums()
161	for i := eds.Len() - 1; i >= 0; i-- {
162		f(eds.Get(i))
163	}
164	mds := d.Messages()
165	for i := mds.Len() - 1; i >= 0; i-- {
166		walkEnums(mds.Get(i), f)
167	}
168}
169
type (
	// messageName is the full protobuf name of a message.
	messageName = string
)

var (
	// messageTypeCache maps a messageName to the reflect.Type recorded for
	// it: a generated message type, or a Go map type for map entries
	// (map[messageName]reflect.Type).
	messageTypeCache sync.Map
)
174
175// RegisterType is called from generated code to register the message Go type
176// for a message of the given name.
177//
178// Deprecated: Use protoregistry.GlobalTypes.RegisterMessage instead.
179func RegisterType(m Message, s messageName) {
180	mt := protoimpl.X.LegacyMessageTypeOf(m, protoreflect.FullName(s))
181	if err := protoregistry.GlobalTypes.RegisterMessage(mt); err != nil {
182		panic(err)
183	}
184	messageTypeCache.Store(s, reflect.TypeOf(m))
185}
186
187// RegisterMapType is called from generated code to register the Go map type
188// for a protobuf message representing a map entry.
189//
190// Deprecated: Do not use.
191func RegisterMapType(m interface{}, s messageName) {
192	t := reflect.TypeOf(m)
193	if t.Kind() != reflect.Map {
194		panic(fmt.Sprintf("invalid map kind: %v", t))
195	}
196	if _, ok := messageTypeCache.Load(s); ok {
197		panic(fmt.Errorf("proto: duplicate proto message registered: %s", s))
198	}
199	messageTypeCache.Store(s, t)
200}
201
202// MessageType returns the message type for a named message.
203// It returns nil if not found.
204//
205// Deprecated: Use protoregistry.GlobalTypes.FindMessageByName instead.
206func MessageType(s messageName) reflect.Type {
207	if v, ok := messageTypeCache.Load(s); ok {
208		return v.(reflect.Type)
209	}
210
211	// Derive the message type from the v2 registry.
212	var t reflect.Type
213	if mt, _ := protoregistry.GlobalTypes.FindMessageByName(protoreflect.FullName(s)); mt != nil {
214		t = messageGoType(mt)
215	}
216
217	// If we could not get a concrete type, it is possible that it is a
218	// pseudo-message for a map entry.
219	if t == nil {
220		d, _ := protoregistry.GlobalFiles.FindDescriptorByName(protoreflect.FullName(s))
221		if md, _ := d.(protoreflect.MessageDescriptor); md != nil && md.IsMapEntry() {
222			kt := goTypeForField(md.Fields().ByNumber(1))
223			vt := goTypeForField(md.Fields().ByNumber(2))
224			t = reflect.MapOf(kt, vt)
225		}
226	}
227
228	// Locally cache the message type for the given name.
229	if t != nil {
230		v, _ := messageTypeCache.LoadOrStore(s, t)
231		return v.(reflect.Type)
232	}
233	return nil
234}
235
236func goTypeForField(fd protoreflect.FieldDescriptor) reflect.Type {
237	switch k := fd.Kind(); k {
238	case protoreflect.EnumKind:
239		if et, _ := protoregistry.GlobalTypes.FindEnumByName(fd.Enum().FullName()); et != nil {
240			return enumGoType(et)
241		}
242		return reflect.TypeOf(protoreflect.EnumNumber(0))
243	case protoreflect.MessageKind, protoreflect.GroupKind:
244		if mt, _ := protoregistry.GlobalTypes.FindMessageByName(fd.Message().FullName()); mt != nil {
245			return messageGoType(mt)
246		}
247		return reflect.TypeOf((*protoreflect.Message)(nil)).Elem()
248	default:
249		return reflect.TypeOf(fd.Default().Interface())
250	}
251}
252
253func enumGoType(et protoreflect.EnumType) reflect.Type {
254	return reflect.TypeOf(et.New(0))
255}
256
257func messageGoType(mt protoreflect.MessageType) reflect.Type {
258	return reflect.TypeOf(MessageV1(mt.Zero().Interface()))
259}
260
261// MessageName returns the full protobuf name for the given message type.
262//
263// Deprecated: Use protoreflect.MessageDescriptor.FullName instead.
264func MessageName(m Message) messageName {
265	if m == nil {
266		return ""
267	}
268	if m, ok := m.(interface{ XXX_MessageName() messageName }); ok {
269		return m.XXX_MessageName()
270	}
271	return messageName(protoimpl.X.MessageDescriptorOf(m).FullName())
272}
273
274// RegisterExtension is called from the generated code to register
275// the extension descriptor.
276//
277// Deprecated: Use protoregistry.GlobalTypes.RegisterExtension instead.
278func RegisterExtension(d *ExtensionDesc) {
279	if err := protoregistry.GlobalTypes.RegisterExtension(d); err != nil {
280		panic(err)
281	}
282}
283
284type extensionsByNumber = map[int32]*ExtensionDesc
285
286var extensionCache sync.Map // map[messageName]extensionsByNumber
287
288// RegisteredExtensions returns a map of the registered extensions for the
289// provided protobuf message, indexed by the extension field number.
290//
291// Deprecated: Use protoregistry.GlobalTypes.RangeExtensionsByMessage instead.
292func RegisteredExtensions(m Message) extensionsByNumber {
293	// Check whether the cache is stale. If the number of extensions for
294	// the given message differs, then it means that some extensions were
295	// recently registered upstream that we do not know about.
296	s := MessageName(m)
297	v, _ := extensionCache.Load(s)
298	xs, _ := v.(extensionsByNumber)
299	if protoregistry.GlobalTypes.NumExtensionsByMessage(protoreflect.FullName(s)) == len(xs) {
300		return xs // cache is up-to-date
301	}
302
303	// Cache is stale, re-compute the extensions map.
304	xs = make(extensionsByNumber)
305	protoregistry.GlobalTypes.RangeExtensionsByMessage(protoreflect.FullName(s), func(xt protoreflect.ExtensionType) bool {
306		if xd, ok := xt.(*ExtensionDesc); ok {
307			xs[int32(xt.TypeDescriptor().Number())] = xd
308		} else {
309			// TODO: This implies that the protoreflect.ExtensionType is a
310			// custom type not generated by protoc-gen-go. We could try and
311			// convert the type to an ExtensionDesc.
312		}
313		return true
314	})
315	extensionCache.Store(s, xs)
316	return xs
317}
318