1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
// Package protoregistry provides data structures to register and look up
// protobuf descriptor types.
7//
// The Files registry contains file descriptors and provides the ability
// to iterate over the files or look up a specific descriptor within the files.
10// Files only contains protobuf descriptors and has no understanding of Go
11// type information that may be associated with each descriptor.
12//
// The Types registry contains descriptor types for which there is a known
// Go type associated with that descriptor. It provides the ability to iterate
// over the registered types or look up a type by name.
16package protoregistry
17
18import (
19	"fmt"
20	"log"
21	"strings"
22	"sync"
23
24	"google.golang.org/protobuf/internal/errors"
25	"google.golang.org/protobuf/reflect/protoreflect"
26)
27
28// ignoreConflict reports whether to ignore a registration conflict
29// given the descriptor being registered and the error.
30// It is a variable so that the behavior is easily overridden in another file.
31var ignoreConflict = func(d protoreflect.Descriptor, err error) bool {
32	log.Printf(""+
33		"WARNING: %v\n"+
34		"A future release will panic on registration conflicts. See:\n"+
35		"https://developers.google.com/protocol-buffers/docs/reference/go/faq#namespace-conflict\n"+
36		"\n", err)
37	return true
38}
39
40var globalMutex sync.RWMutex
41
42// GlobalFiles is a global registry of file descriptors.
43var GlobalFiles *Files = new(Files)
44
45// GlobalTypes is the registry used by default for type lookups
46// unless a local registry is provided by the user.
47var GlobalTypes *Types = new(Types)
48
49// NotFound is a sentinel error value to indicate that the type was not found.
50//
51// Since registry lookup can happen in the critical performance path, resolvers
52// must return this exact error value, not an error wrapping it.
53var NotFound = errors.New("not found")
54
// Files is a registry for looking up or iterating over files and the
// descriptors contained within them.
// The Find and Range methods are safe for concurrent use.
//
// Only GlobalFiles serializes registration against lookups (via a
// package-level lock); any other Files value must not have files registered
// into it concurrently with calls to its other methods.
type Files struct {
	// The map of descsByName contains:
	//	EnumDescriptor
	//	EnumValueDescriptor
	//	MessageDescriptor
	//	ExtensionDescriptor
	//	ServiceDescriptor
	//	*packageDescriptor
	//
	// Note that files are stored as a slice (in packageDescriptor), since a
	// package may contain multiple files. Only top-level declarations are
	// registered; nested declarations are found by searching within their
	// top-level parent.
	// Note that enum values are in the top-level since they are in the same
	// scope as the parent enum.
	// The empty name "" maps to the packageDescriptor holding files that
	// declare no package.
	//
	// filesByPath maps a file path to the file registered under it; at most
	// one file is registered per path.
	descsByName map[protoreflect.FullName]interface{}
	filesByPath map[string]protoreflect.FileDescriptor
}

// packageDescriptor is the descsByName entry for a proto package name.
// files holds the files whose package is exactly this name (files in
// sub-packages are recorded under their own packageDescriptor).
type packageDescriptor struct {
	files []protoreflect.FileDescriptor
}
78
79// RegisterFile registers the provided file descriptor.
80//
81// If any descriptor within the file conflicts with the descriptor of any
82// previously registered file (e.g., two enums with the same full name),
83// then the file is not registered and an error is returned.
84//
85// It is permitted for multiple files to have the same file path.
86func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error {
87	if r == GlobalFiles {
88		globalMutex.Lock()
89		defer globalMutex.Unlock()
90	}
91	if r.descsByName == nil {
92		r.descsByName = map[protoreflect.FullName]interface{}{
93			"": &packageDescriptor{},
94		}
95		r.filesByPath = make(map[string]protoreflect.FileDescriptor)
96	}
97	path := file.Path()
98	if prev := r.filesByPath[path]; prev != nil {
99		err := errors.New("file %q is already registered", file.Path())
100		err = amendErrorWithCaller(err, prev, file)
101		if r == GlobalFiles && ignoreConflict(file, err) {
102			err = nil
103		}
104		return err
105	}
106
107	for name := file.Package(); name != ""; name = name.Parent() {
108		switch prev := r.descsByName[name]; prev.(type) {
109		case nil, *packageDescriptor:
110		default:
111			err := errors.New("file %q has a package name conflict over %v", file.Path(), name)
112			err = amendErrorWithCaller(err, prev, file)
113			if r == GlobalFiles && ignoreConflict(file, err) {
114				err = nil
115			}
116			return err
117		}
118	}
119	var err error
120	var hasConflict bool
121	rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
122		if prev := r.descsByName[d.FullName()]; prev != nil {
123			hasConflict = true
124			err = errors.New("file %q has a name conflict over %v", file.Path(), d.FullName())
125			err = amendErrorWithCaller(err, prev, file)
126			if r == GlobalFiles && ignoreConflict(d, err) {
127				err = nil
128			}
129		}
130	})
131	if hasConflict {
132		return err
133	}
134
135	for name := file.Package(); name != ""; name = name.Parent() {
136		if r.descsByName[name] == nil {
137			r.descsByName[name] = &packageDescriptor{}
138		}
139	}
140	p := r.descsByName[file.Package()].(*packageDescriptor)
141	p.files = append(p.files, file)
142	rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
143		r.descsByName[d.FullName()] = d
144	})
145	r.filesByPath[path] = file
146	return nil
147}
148
149// FindDescriptorByName looks up a descriptor by the full name.
150//
151// This returns (nil, NotFound) if not found.
152func (r *Files) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) {
153	if r == nil {
154		return nil, NotFound
155	}
156	if r == GlobalFiles {
157		globalMutex.RLock()
158		defer globalMutex.RUnlock()
159	}
160	prefix := name
161	suffix := nameSuffix("")
162	for prefix != "" {
163		if d, ok := r.descsByName[prefix]; ok {
164			switch d := d.(type) {
165			case protoreflect.EnumDescriptor:
166				if d.FullName() == name {
167					return d, nil
168				}
169			case protoreflect.EnumValueDescriptor:
170				if d.FullName() == name {
171					return d, nil
172				}
173			case protoreflect.MessageDescriptor:
174				if d.FullName() == name {
175					return d, nil
176				}
177				if d := findDescriptorInMessage(d, suffix); d != nil && d.FullName() == name {
178					return d, nil
179				}
180			case protoreflect.ExtensionDescriptor:
181				if d.FullName() == name {
182					return d, nil
183				}
184			case protoreflect.ServiceDescriptor:
185				if d.FullName() == name {
186					return d, nil
187				}
188				if d := d.Methods().ByName(suffix.Pop()); d != nil && d.FullName() == name {
189					return d, nil
190				}
191			}
192			return nil, NotFound
193		}
194		prefix = prefix.Parent()
195		suffix = nameSuffix(name[len(prefix)+len("."):])
196	}
197	return nil, NotFound
198}
199
200func findDescriptorInMessage(md protoreflect.MessageDescriptor, suffix nameSuffix) protoreflect.Descriptor {
201	name := suffix.Pop()
202	if suffix == "" {
203		if ed := md.Enums().ByName(name); ed != nil {
204			return ed
205		}
206		for i := md.Enums().Len() - 1; i >= 0; i-- {
207			if vd := md.Enums().Get(i).Values().ByName(name); vd != nil {
208				return vd
209			}
210		}
211		if xd := md.Extensions().ByName(name); xd != nil {
212			return xd
213		}
214		if fd := md.Fields().ByName(name); fd != nil {
215			return fd
216		}
217		if od := md.Oneofs().ByName(name); od != nil {
218			return od
219		}
220	}
221	if md := md.Messages().ByName(name); md != nil {
222		if suffix == "" {
223			return md
224		}
225		return findDescriptorInMessage(md, suffix)
226	}
227	return nil
228}
229
230type nameSuffix string
231
232func (s *nameSuffix) Pop() (name protoreflect.Name) {
233	if i := strings.IndexByte(string(*s), '.'); i >= 0 {
234		name, *s = protoreflect.Name((*s)[:i]), (*s)[i+1:]
235	} else {
236		name, *s = protoreflect.Name((*s)), ""
237	}
238	return name
239}
240
241// FindFileByPath looks up a file by the path.
242//
243// This returns (nil, NotFound) if not found.
244func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error) {
245	if r == nil {
246		return nil, NotFound
247	}
248	if r == GlobalFiles {
249		globalMutex.RLock()
250		defer globalMutex.RUnlock()
251	}
252	if fd, ok := r.filesByPath[path]; ok {
253		return fd, nil
254	}
255	return nil, NotFound
256}
257
258// NumFiles reports the number of registered files.
259func (r *Files) NumFiles() int {
260	if r == nil {
261		return 0
262	}
263	if r == GlobalFiles {
264		globalMutex.RLock()
265		defer globalMutex.RUnlock()
266	}
267	return len(r.filesByPath)
268}
269
270// RangeFiles iterates over all registered files while f returns true.
271// The iteration order is undefined.
272func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
273	if r == nil {
274		return
275	}
276	if r == GlobalFiles {
277		globalMutex.RLock()
278		defer globalMutex.RUnlock()
279	}
280	for _, file := range r.filesByPath {
281		if !f(file) {
282			return
283		}
284	}
285}
286
287// NumFilesByPackage reports the number of registered files in a proto package.
288func (r *Files) NumFilesByPackage(name protoreflect.FullName) int {
289	if r == nil {
290		return 0
291	}
292	if r == GlobalFiles {
293		globalMutex.RLock()
294		defer globalMutex.RUnlock()
295	}
296	p, ok := r.descsByName[name].(*packageDescriptor)
297	if !ok {
298		return 0
299	}
300	return len(p.files)
301}
302
303// RangeFilesByPackage iterates over all registered files in a given proto package
304// while f returns true. The iteration order is undefined.
305func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protoreflect.FileDescriptor) bool) {
306	if r == nil {
307		return
308	}
309	if r == GlobalFiles {
310		globalMutex.RLock()
311		defer globalMutex.RUnlock()
312	}
313	p, ok := r.descsByName[name].(*packageDescriptor)
314	if !ok {
315		return
316	}
317	for _, file := range p.files {
318		if !f(file) {
319			return
320		}
321	}
322}
323
324// rangeTopLevelDescriptors iterates over all top-level descriptors in a file
325// which will be directly entered into the registry.
326func rangeTopLevelDescriptors(fd protoreflect.FileDescriptor, f func(protoreflect.Descriptor)) {
327	eds := fd.Enums()
328	for i := eds.Len() - 1; i >= 0; i-- {
329		f(eds.Get(i))
330		vds := eds.Get(i).Values()
331		for i := vds.Len() - 1; i >= 0; i-- {
332			f(vds.Get(i))
333		}
334	}
335	mds := fd.Messages()
336	for i := mds.Len() - 1; i >= 0; i-- {
337		f(mds.Get(i))
338	}
339	xds := fd.Extensions()
340	for i := xds.Len() - 1; i >= 0; i-- {
341		f(xds.Get(i))
342	}
343	sds := fd.Services()
344	for i := sds.Len() - 1; i >= 0; i-- {
345		f(sds.Get(i))
346	}
347}
348
// MessageTypeResolver is an interface for looking up messages.
//
// A compliant implementation must deterministically return the same type
// if no error is encountered. When a message is not found, implementations
// must return the NotFound value itself, not an error wrapping it.
//
// The Types type implements this interface.
type MessageTypeResolver interface {
	// FindMessageByName looks up a message by its full name.
	// E.g., "google.protobuf.Any"
	//
	// This returns (nil, NotFound) if not found.
	FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error)

	// FindMessageByURL looks up a message by a URL identifier.
	// See documentation on google.protobuf.Any.type_url for the URL format.
	//
	// This returns (nil, NotFound) if not found.
	FindMessageByURL(url string) (protoreflect.MessageType, error)
}
368
// ExtensionTypeResolver is an interface for looking up extensions.
//
// A compliant implementation must deterministically return the same type
// if no error is encountered. When an extension is not found, implementations
// must return the NotFound value itself, not an error wrapping it.
//
// The Types type implements this interface.
type ExtensionTypeResolver interface {
	// FindExtensionByName looks up an extension field by the field's full name.
	// Note that this is the full name of the field as determined by
	// where the extension is declared and is unrelated to the full name of the
	// message being extended.
	//
	// This returns (nil, NotFound) if not found.
	FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)

	// FindExtensionByNumber looks up an extension field by the field number
	// within some parent message, identified by full name.
	//
	// This returns (nil, NotFound) if not found.
	FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
}
390
391var (
392	_ MessageTypeResolver   = (*Types)(nil)
393	_ ExtensionTypeResolver = (*Types)(nil)
394)
395
// Types is a registry for looking up or iterating over descriptor types.
// The Find and Range methods are safe for concurrent use.
//
// Only GlobalTypes serializes registration against lookups (via a
// package-level lock); any other Types value must not be registered into
// concurrently with calls to its other methods.
type Types struct {
	// typesByName maps a full name to the EnumType, MessageType, or
	// ExtensionType registered under it. extensionsByMessage indexes
	// registered extensions by the full name of the message they extend,
	// then by field number.
	typesByName         typesByName
	extensionsByMessage extensionsByMessage

	// Per-kind counts reported by NumEnums, NumMessages, and NumExtensions;
	// updated by the Register methods.
	numEnums      int
	numMessages   int
	numExtensions int
}

// Map types backing the Types registry.
type (
	typesByName         map[protoreflect.FullName]interface{}
	extensionsByMessage map[protoreflect.FullName]extensionsByNumber
	extensionsByNumber  map[protoreflect.FieldNumber]protoreflect.ExtensionType
)
412
413// RegisterMessage registers the provided message type.
414//
415// If a naming conflict occurs, the type is not registered and an error is returned.
416func (r *Types) RegisterMessage(mt protoreflect.MessageType) error {
417	// Under rare circumstances getting the descriptor might recursively
418	// examine the registry, so fetch it before locking.
419	md := mt.Descriptor()
420
421	if r == GlobalTypes {
422		globalMutex.Lock()
423		defer globalMutex.Unlock()
424	}
425
426	if err := r.register("message", md, mt); err != nil {
427		return err
428	}
429	r.numMessages++
430	return nil
431}
432
433// RegisterEnum registers the provided enum type.
434//
435// If a naming conflict occurs, the type is not registered and an error is returned.
436func (r *Types) RegisterEnum(et protoreflect.EnumType) error {
437	// Under rare circumstances getting the descriptor might recursively
438	// examine the registry, so fetch it before locking.
439	ed := et.Descriptor()
440
441	if r == GlobalTypes {
442		globalMutex.Lock()
443		defer globalMutex.Unlock()
444	}
445
446	if err := r.register("enum", ed, et); err != nil {
447		return err
448	}
449	r.numEnums++
450	return nil
451}
452
453// RegisterExtension registers the provided extension type.
454//
455// If a naming conflict occurs, the type is not registered and an error is returned.
456func (r *Types) RegisterExtension(xt protoreflect.ExtensionType) error {
457	// Under rare circumstances getting the descriptor might recursively
458	// examine the registry, so fetch it before locking.
459	//
460	// A known case where this can happen: Fetching the TypeDescriptor for a
461	// legacy ExtensionDesc can consult the global registry.
462	xd := xt.TypeDescriptor()
463
464	if r == GlobalTypes {
465		globalMutex.Lock()
466		defer globalMutex.Unlock()
467	}
468
469	field := xd.Number()
470	message := xd.ContainingMessage().FullName()
471	if prev := r.extensionsByMessage[message][field]; prev != nil {
472		err := errors.New("extension number %d is already registered on message %v", field, message)
473		err = amendErrorWithCaller(err, prev, xt)
474		if !(r == GlobalTypes && ignoreConflict(xd, err)) {
475			return err
476		}
477	}
478
479	if err := r.register("extension", xd, xt); err != nil {
480		return err
481	}
482	if r.extensionsByMessage == nil {
483		r.extensionsByMessage = make(extensionsByMessage)
484	}
485	if r.extensionsByMessage[message] == nil {
486		r.extensionsByMessage[message] = make(extensionsByNumber)
487	}
488	r.extensionsByMessage[message][field] = xt
489	r.numExtensions++
490	return nil
491}
492
493func (r *Types) register(kind string, desc protoreflect.Descriptor, typ interface{}) error {
494	name := desc.FullName()
495	prev := r.typesByName[name]
496	if prev != nil {
497		err := errors.New("%v %v is already registered", kind, name)
498		err = amendErrorWithCaller(err, prev, typ)
499		if !(r == GlobalTypes && ignoreConflict(desc, err)) {
500			return err
501		}
502	}
503	if r.typesByName == nil {
504		r.typesByName = make(typesByName)
505	}
506	r.typesByName[name] = typ
507	return nil
508}
509
510// FindEnumByName looks up an enum by its full name.
511// E.g., "google.protobuf.Field.Kind".
512//
513// This returns (nil, NotFound) if not found.
514func (r *Types) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumType, error) {
515	if r == nil {
516		return nil, NotFound
517	}
518	if r == GlobalTypes {
519		globalMutex.RLock()
520		defer globalMutex.RUnlock()
521	}
522	if v := r.typesByName[enum]; v != nil {
523		if et, _ := v.(protoreflect.EnumType); et != nil {
524			return et, nil
525		}
526		return nil, errors.New("found wrong type: got %v, want enum", typeName(v))
527	}
528	return nil, NotFound
529}
530
531// FindMessageByName looks up a message by its full name.
532// E.g., "google.protobuf.Any"
533//
534// This return (nil, NotFound) if not found.
535func (r *Types) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
536	// The full name by itself is a valid URL.
537	return r.FindMessageByURL(string(message))
538}
539
540// FindMessageByURL looks up a message by a URL identifier.
541// See documentation on google.protobuf.Any.type_url for the URL format.
542//
543// This returns (nil, NotFound) if not found.
544func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) {
545	if r == nil {
546		return nil, NotFound
547	}
548	if r == GlobalTypes {
549		globalMutex.RLock()
550		defer globalMutex.RUnlock()
551	}
552	message := protoreflect.FullName(url)
553	if i := strings.LastIndexByte(url, '/'); i >= 0 {
554		message = message[i+len("/"):]
555	}
556
557	if v := r.typesByName[message]; v != nil {
558		if mt, _ := v.(protoreflect.MessageType); mt != nil {
559			return mt, nil
560		}
561		return nil, errors.New("found wrong type: got %v, want message", typeName(v))
562	}
563	return nil, NotFound
564}
565
566// FindExtensionByName looks up a extension field by the field's full name.
567// Note that this is the full name of the field as determined by
568// where the extension is declared and is unrelated to the full name of the
569// message being extended.
570//
571// This returns (nil, NotFound) if not found.
572func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
573	if r == nil {
574		return nil, NotFound
575	}
576	if r == GlobalTypes {
577		globalMutex.RLock()
578		defer globalMutex.RUnlock()
579	}
580	if v := r.typesByName[field]; v != nil {
581		if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
582			return xt, nil
583		}
584		return nil, errors.New("found wrong type: got %v, want extension", typeName(v))
585	}
586	return nil, NotFound
587}
588
589// FindExtensionByNumber looks up a extension field by the field number
590// within some parent message, identified by full name.
591//
592// This returns (nil, NotFound) if not found.
593func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
594	if r == nil {
595		return nil, NotFound
596	}
597	if r == GlobalTypes {
598		globalMutex.RLock()
599		defer globalMutex.RUnlock()
600	}
601	if xt, ok := r.extensionsByMessage[message][field]; ok {
602		return xt, nil
603	}
604	return nil, NotFound
605}
606
607// NumEnums reports the number of registered enums.
608func (r *Types) NumEnums() int {
609	if r == nil {
610		return 0
611	}
612	if r == GlobalTypes {
613		globalMutex.RLock()
614		defer globalMutex.RUnlock()
615	}
616	return r.numEnums
617}
618
619// RangeEnums iterates over all registered enums while f returns true.
620// Iteration order is undefined.
621func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
622	if r == nil {
623		return
624	}
625	if r == GlobalTypes {
626		globalMutex.RLock()
627		defer globalMutex.RUnlock()
628	}
629	for _, typ := range r.typesByName {
630		if et, ok := typ.(protoreflect.EnumType); ok {
631			if !f(et) {
632				return
633			}
634		}
635	}
636}
637
638// NumMessages reports the number of registered messages.
639func (r *Types) NumMessages() int {
640	if r == nil {
641		return 0
642	}
643	if r == GlobalTypes {
644		globalMutex.RLock()
645		defer globalMutex.RUnlock()
646	}
647	return r.numMessages
648}
649
650// RangeMessages iterates over all registered messages while f returns true.
651// Iteration order is undefined.
652func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
653	if r == nil {
654		return
655	}
656	if r == GlobalTypes {
657		globalMutex.RLock()
658		defer globalMutex.RUnlock()
659	}
660	for _, typ := range r.typesByName {
661		if mt, ok := typ.(protoreflect.MessageType); ok {
662			if !f(mt) {
663				return
664			}
665		}
666	}
667}
668
669// NumExtensions reports the number of registered extensions.
670func (r *Types) NumExtensions() int {
671	if r == nil {
672		return 0
673	}
674	if r == GlobalTypes {
675		globalMutex.RLock()
676		defer globalMutex.RUnlock()
677	}
678	return r.numExtensions
679}
680
681// RangeExtensions iterates over all registered extensions while f returns true.
682// Iteration order is undefined.
683func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
684	if r == nil {
685		return
686	}
687	if r == GlobalTypes {
688		globalMutex.RLock()
689		defer globalMutex.RUnlock()
690	}
691	for _, typ := range r.typesByName {
692		if xt, ok := typ.(protoreflect.ExtensionType); ok {
693			if !f(xt) {
694				return
695			}
696		}
697	}
698}
699
700// NumExtensionsByMessage reports the number of registered extensions for
701// a given message type.
702func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int {
703	if r == nil {
704		return 0
705	}
706	if r == GlobalTypes {
707		globalMutex.RLock()
708		defer globalMutex.RUnlock()
709	}
710	return len(r.extensionsByMessage[message])
711}
712
713// RangeExtensionsByMessage iterates over all registered extensions filtered
714// by a given message type while f returns true. Iteration order is undefined.
715func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) {
716	if r == nil {
717		return
718	}
719	if r == GlobalTypes {
720		globalMutex.RLock()
721		defer globalMutex.RUnlock()
722	}
723	for _, xt := range r.extensionsByMessage[message] {
724		if !f(xt) {
725			return
726		}
727	}
728}
729
730func typeName(t interface{}) string {
731	switch t.(type) {
732	case protoreflect.EnumType:
733		return "enum"
734	case protoreflect.MessageType:
735		return "message"
736	case protoreflect.ExtensionType:
737		return "extension"
738	default:
739		return fmt.Sprintf("%T", t)
740	}
741}
742
743func amendErrorWithCaller(err error, prev, curr interface{}) error {
744	prevPkg := goPackage(prev)
745	currPkg := goPackage(curr)
746	if prevPkg == "" || currPkg == "" || prevPkg == currPkg {
747		return err
748	}
749	return errors.New("%s\n\tpreviously from: %q\n\tcurrently from:  %q", err, prevPkg, currPkg)
750}
751
752func goPackage(v interface{}) string {
753	switch d := v.(type) {
754	case protoreflect.EnumType:
755		v = d.Descriptor()
756	case protoreflect.MessageType:
757		v = d.Descriptor()
758	case protoreflect.ExtensionType:
759		v = d.TypeDescriptor()
760	}
761	if d, ok := v.(protoreflect.Descriptor); ok {
762		v = d.ParentFile()
763	}
764	if d, ok := v.(interface{ GoPackagePath() string }); ok {
765		return d.GoPackagePath()
766	}
767	return ""
768}
769