1// Copyright 2018 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Package protoregistry provides data structures to register and lookup 6// protobuf descriptor types. 7// 8// The Files registry contains file descriptors and provides the ability 9// to iterate over the files or lookup a specific descriptor within the files. 10// Files only contains protobuf descriptors and has no understanding of Go 11// type information that may be associated with each descriptor. 12// 13// The Types registry contains descriptor types for which there is a known 14// Go type associated with that descriptor. It provides the ability to iterate 15// over the registered types or lookup a type by name. 16package protoregistry 17 18import ( 19 "fmt" 20 "log" 21 "strings" 22 "sync" 23 24 "google.golang.org/protobuf/internal/errors" 25 "google.golang.org/protobuf/reflect/protoreflect" 26) 27 28// ignoreConflict reports whether to ignore a registration conflict 29// given the descriptor being registered and the error. 30// It is a variable so that the behavior is easily overridden in another file. 31var ignoreConflict = func(d protoreflect.Descriptor, err error) bool { 32 log.Printf(""+ 33 "WARNING: %v\n"+ 34 "A future release will panic on registration conflicts. See:\n"+ 35 "https://developers.google.com/protocol-buffers/docs/reference/go/faq#namespace-conflict\n"+ 36 "\n", err) 37 return true 38} 39 40var globalMutex sync.RWMutex 41 42// GlobalFiles is a global registry of file descriptors. 43var GlobalFiles *Files = new(Files) 44 45// GlobalTypes is the registry used by default for type lookups 46// unless a local registry is provided by the user. 47var GlobalTypes *Types = new(Types) 48 49// NotFound is a sentinel error value to indicate that the type was not found. 50// 51// Since registry lookup can happen in the critical performance path, resolvers 52// must return this exact error value, not an error wrapping it. 53var NotFound = errors.New("not found") 54 55// Files is a registry for looking up or iterating over files and the 56// descriptors contained within them. 57// The Find and Range methods are safe for concurrent use. 58type Files struct { 59 // The map of descsByName contains: 60 // EnumDescriptor 61 // EnumValueDescriptor 62 // MessageDescriptor 63 // ExtensionDescriptor 64 // ServiceDescriptor 65 // *packageDescriptor 66 // 67 // Note that files are stored as a slice, since a package may contain 68 // multiple files. Only top-level declarations are registered. 69 // Note that enum values are in the top-level since that are in the same 70 // scope as the parent enum. 71 descsByName map[protoreflect.FullName]interface{} 72 filesByPath map[string]protoreflect.FileDescriptor 73} 74 75type packageDescriptor struct { 76 files []protoreflect.FileDescriptor 77} 78 79// RegisterFile registers the provided file descriptor. 80// 81// If any descriptor within the file conflicts with the descriptor of any 82// previously registered file (e.g., two enums with the same full name), 83// then the file is not registered and an error is returned. 84// 85// It is permitted for multiple files to have the same file path. 86func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error { 87 if r == GlobalFiles { 88 globalMutex.Lock() 89 defer globalMutex.Unlock() 90 } 91 if r.descsByName == nil { 92 r.descsByName = map[protoreflect.FullName]interface{}{ 93 "": &packageDescriptor{}, 94 } 95 r.filesByPath = make(map[string]protoreflect.FileDescriptor) 96 } 97 path := file.Path() 98 if prev := r.filesByPath[path]; prev != nil { 99 err := errors.New("file %q is already registered", file.Path()) 100 err = amendErrorWithCaller(err, prev, file) 101 if r == GlobalFiles && ignoreConflict(file, err) { 102 err = nil 103 } 104 return err 105 } 106 107 for name := file.Package(); name != ""; name = name.Parent() { 108 switch prev := r.descsByName[name]; prev.(type) { 109 case nil, *packageDescriptor: 110 default: 111 err := errors.New("file %q has a package name conflict over %v", file.Path(), name) 112 err = amendErrorWithCaller(err, prev, file) 113 if r == GlobalFiles && ignoreConflict(file, err) { 114 err = nil 115 } 116 return err 117 } 118 } 119 var err error 120 var hasConflict bool 121 rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) { 122 if prev := r.descsByName[d.FullName()]; prev != nil { 123 hasConflict = true 124 err = errors.New("file %q has a name conflict over %v", file.Path(), d.FullName()) 125 err = amendErrorWithCaller(err, prev, file) 126 if r == GlobalFiles && ignoreConflict(d, err) { 127 err = nil 128 } 129 } 130 }) 131 if hasConflict { 132 return err 133 } 134 135 for name := file.Package(); name != ""; name = name.Parent() { 136 if r.descsByName[name] == nil { 137 r.descsByName[name] = &packageDescriptor{} 138 } 139 } 140 p := r.descsByName[file.Package()].(*packageDescriptor) 141 p.files = append(p.files, file) 142 rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) { 143 r.descsByName[d.FullName()] = d 144 }) 145 r.filesByPath[path] = file 146 return nil 147} 148 149// FindDescriptorByName looks up a descriptor by the full name. 150// 151// This returns (nil, NotFound) if not found. 152func (r *Files) FindDescriptorByName(name protoreflect.FullName) (protoreflect.Descriptor, error) { 153 if r == nil { 154 return nil, NotFound 155 } 156 if r == GlobalFiles { 157 globalMutex.RLock() 158 defer globalMutex.RUnlock() 159 } 160 prefix := name 161 suffix := nameSuffix("") 162 for prefix != "" { 163 if d, ok := r.descsByName[prefix]; ok { 164 switch d := d.(type) { 165 case protoreflect.EnumDescriptor: 166 if d.FullName() == name { 167 return d, nil 168 } 169 case protoreflect.EnumValueDescriptor: 170 if d.FullName() == name { 171 return d, nil 172 } 173 case protoreflect.MessageDescriptor: 174 if d.FullName() == name { 175 return d, nil 176 } 177 if d := findDescriptorInMessage(d, suffix); d != nil && d.FullName() == name { 178 return d, nil 179 } 180 case protoreflect.ExtensionDescriptor: 181 if d.FullName() == name { 182 return d, nil 183 } 184 case protoreflect.ServiceDescriptor: 185 if d.FullName() == name { 186 return d, nil 187 } 188 if d := d.Methods().ByName(suffix.Pop()); d != nil && d.FullName() == name { 189 return d, nil 190 } 191 } 192 return nil, NotFound 193 } 194 prefix = prefix.Parent() 195 suffix = nameSuffix(name[len(prefix)+len("."):]) 196 } 197 return nil, NotFound 198} 199 200func findDescriptorInMessage(md protoreflect.MessageDescriptor, suffix nameSuffix) protoreflect.Descriptor { 201 name := suffix.Pop() 202 if suffix == "" { 203 if ed := md.Enums().ByName(name); ed != nil { 204 return ed 205 } 206 for i := md.Enums().Len() - 1; i >= 0; i-- { 207 if vd := md.Enums().Get(i).Values().ByName(name); vd != nil { 208 return vd 209 } 210 } 211 if xd := md.Extensions().ByName(name); xd != nil { 212 return xd 213 } 214 if fd := md.Fields().ByName(name); fd != nil { 215 return fd 216 } 217 if od := md.Oneofs().ByName(name); od != nil { 218 return od 219 } 220 } 221 if md := md.Messages().ByName(name); md != nil { 222 if suffix == "" { 223 return md 224 } 225 return findDescriptorInMessage(md, suffix) 226 } 227 return nil 228} 229 230type nameSuffix string 231 232func (s *nameSuffix) Pop() (name protoreflect.Name) { 233 if i := strings.IndexByte(string(*s), '.'); i >= 0 { 234 name, *s = protoreflect.Name((*s)[:i]), (*s)[i+1:] 235 } else { 236 name, *s = protoreflect.Name((*s)), "" 237 } 238 return name 239} 240 241// FindFileByPath looks up a file by the path. 242// 243// This returns (nil, NotFound) if not found. 244func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error) { 245 if r == nil { 246 return nil, NotFound 247 } 248 if r == GlobalFiles { 249 globalMutex.RLock() 250 defer globalMutex.RUnlock() 251 } 252 if fd, ok := r.filesByPath[path]; ok { 253 return fd, nil 254 } 255 return nil, NotFound 256} 257 258// NumFiles reports the number of registered files. 259func (r *Files) NumFiles() int { 260 if r == nil { 261 return 0 262 } 263 if r == GlobalFiles { 264 globalMutex.RLock() 265 defer globalMutex.RUnlock() 266 } 267 return len(r.filesByPath) 268} 269 270// RangeFiles iterates over all registered files while f returns true. 271// The iteration order is undefined. 272func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) { 273 if r == nil { 274 return 275 } 276 if r == GlobalFiles { 277 globalMutex.RLock() 278 defer globalMutex.RUnlock() 279 } 280 for _, file := range r.filesByPath { 281 if !f(file) { 282 return 283 } 284 } 285} 286 287// NumFilesByPackage reports the number of registered files in a proto package. 288func (r *Files) NumFilesByPackage(name protoreflect.FullName) int { 289 if r == nil { 290 return 0 291 } 292 if r == GlobalFiles { 293 globalMutex.RLock() 294 defer globalMutex.RUnlock() 295 } 296 p, ok := r.descsByName[name].(*packageDescriptor) 297 if !ok { 298 return 0 299 } 300 return len(p.files) 301} 302 303// RangeFilesByPackage iterates over all registered files in a given proto package 304// while f returns true. The iteration order is undefined. 305func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protoreflect.FileDescriptor) bool) { 306 if r == nil { 307 return 308 } 309 if r == GlobalFiles { 310 globalMutex.RLock() 311 defer globalMutex.RUnlock() 312 } 313 p, ok := r.descsByName[name].(*packageDescriptor) 314 if !ok { 315 return 316 } 317 for _, file := range p.files { 318 if !f(file) { 319 return 320 } 321 } 322} 323 324// rangeTopLevelDescriptors iterates over all top-level descriptors in a file 325// which will be directly entered into the registry. 326func rangeTopLevelDescriptors(fd protoreflect.FileDescriptor, f func(protoreflect.Descriptor)) { 327 eds := fd.Enums() 328 for i := eds.Len() - 1; i >= 0; i-- { 329 f(eds.Get(i)) 330 vds := eds.Get(i).Values() 331 for i := vds.Len() - 1; i >= 0; i-- { 332 f(vds.Get(i)) 333 } 334 } 335 mds := fd.Messages() 336 for i := mds.Len() - 1; i >= 0; i-- { 337 f(mds.Get(i)) 338 } 339 xds := fd.Extensions() 340 for i := xds.Len() - 1; i >= 0; i-- { 341 f(xds.Get(i)) 342 } 343 sds := fd.Services() 344 for i := sds.Len() - 1; i >= 0; i-- { 345 f(sds.Get(i)) 346 } 347} 348 349// MessageTypeResolver is an interface for looking up messages. 350// 351// A compliant implementation must deterministically return the same type 352// if no error is encountered. 353// 354// The Types type implements this interface. 355type MessageTypeResolver interface { 356 // FindMessageByName looks up a message by its full name. 357 // E.g., "google.protobuf.Any" 358 // 359 // This return (nil, NotFound) if not found. 360 FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) 361 362 // FindMessageByURL looks up a message by a URL identifier. 363 // See documentation on google.protobuf.Any.type_url for the URL format. 364 // 365 // This returns (nil, NotFound) if not found. 366 FindMessageByURL(url string) (protoreflect.MessageType, error) 367} 368 369// ExtensionTypeResolver is an interface for looking up extensions. 370// 371// A compliant implementation must deterministically return the same type 372// if no error is encountered. 373// 374// The Types type implements this interface. 375type ExtensionTypeResolver interface { 376 // FindExtensionByName looks up a extension field by the field's full name. 377 // Note that this is the full name of the field as determined by 378 // where the extension is declared and is unrelated to the full name of the 379 // message being extended. 380 // 381 // This returns (nil, NotFound) if not found. 382 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) 383 384 // FindExtensionByNumber looks up a extension field by the field number 385 // within some parent message, identified by full name. 386 // 387 // This returns (nil, NotFound) if not found. 388 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) 389} 390 391var ( 392 _ MessageTypeResolver = (*Types)(nil) 393 _ ExtensionTypeResolver = (*Types)(nil) 394) 395 396// Types is a registry for looking up or iterating over descriptor types. 397// The Find and Range methods are safe for concurrent use. 398type Types struct { 399 typesByName typesByName 400 extensionsByMessage extensionsByMessage 401 402 numEnums int 403 numMessages int 404 numExtensions int 405} 406 407type ( 408 typesByName map[protoreflect.FullName]interface{} 409 extensionsByMessage map[protoreflect.FullName]extensionsByNumber 410 extensionsByNumber map[protoreflect.FieldNumber]protoreflect.ExtensionType 411) 412 413// RegisterMessage registers the provided message type. 414// 415// If a naming conflict occurs, the type is not registered and an error is returned. 416func (r *Types) RegisterMessage(mt protoreflect.MessageType) error { 417 // Under rare circumstances getting the descriptor might recursively 418 // examine the registry, so fetch it before locking. 419 md := mt.Descriptor() 420 421 if r == GlobalTypes { 422 globalMutex.Lock() 423 defer globalMutex.Unlock() 424 } 425 426 if err := r.register("message", md, mt); err != nil { 427 return err 428 } 429 r.numMessages++ 430 return nil 431} 432 433// RegisterEnum registers the provided enum type. 434// 435// If a naming conflict occurs, the type is not registered and an error is returned. 436func (r *Types) RegisterEnum(et protoreflect.EnumType) error { 437 // Under rare circumstances getting the descriptor might recursively 438 // examine the registry, so fetch it before locking. 439 ed := et.Descriptor() 440 441 if r == GlobalTypes { 442 globalMutex.Lock() 443 defer globalMutex.Unlock() 444 } 445 446 if err := r.register("enum", ed, et); err != nil { 447 return err 448 } 449 r.numEnums++ 450 return nil 451} 452 453// RegisterExtension registers the provided extension type. 454// 455// If a naming conflict occurs, the type is not registered and an error is returned. 456func (r *Types) RegisterExtension(xt protoreflect.ExtensionType) error { 457 // Under rare circumstances getting the descriptor might recursively 458 // examine the registry, so fetch it before locking. 459 // 460 // A known case where this can happen: Fetching the TypeDescriptor for a 461 // legacy ExtensionDesc can consult the global registry. 462 xd := xt.TypeDescriptor() 463 464 if r == GlobalTypes { 465 globalMutex.Lock() 466 defer globalMutex.Unlock() 467 } 468 469 field := xd.Number() 470 message := xd.ContainingMessage().FullName() 471 if prev := r.extensionsByMessage[message][field]; prev != nil { 472 err := errors.New("extension number %d is already registered on message %v", field, message) 473 err = amendErrorWithCaller(err, prev, xt) 474 if !(r == GlobalTypes && ignoreConflict(xd, err)) { 475 return err 476 } 477 } 478 479 if err := r.register("extension", xd, xt); err != nil { 480 return err 481 } 482 if r.extensionsByMessage == nil { 483 r.extensionsByMessage = make(extensionsByMessage) 484 } 485 if r.extensionsByMessage[message] == nil { 486 r.extensionsByMessage[message] = make(extensionsByNumber) 487 } 488 r.extensionsByMessage[message][field] = xt 489 r.numExtensions++ 490 return nil 491} 492 493func (r *Types) register(kind string, desc protoreflect.Descriptor, typ interface{}) error { 494 name := desc.FullName() 495 prev := r.typesByName[name] 496 if prev != nil { 497 err := errors.New("%v %v is already registered", kind, name) 498 err = amendErrorWithCaller(err, prev, typ) 499 if !(r == GlobalTypes && ignoreConflict(desc, err)) { 500 return err 501 } 502 } 503 if r.typesByName == nil { 504 r.typesByName = make(typesByName) 505 } 506 r.typesByName[name] = typ 507 return nil 508} 509 510// FindEnumByName looks up an enum by its full name. 511// E.g., "google.protobuf.Field.Kind". 512// 513// This returns (nil, NotFound) if not found. 514func (r *Types) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumType, error) { 515 if r == nil { 516 return nil, NotFound 517 } 518 if r == GlobalTypes { 519 globalMutex.RLock() 520 defer globalMutex.RUnlock() 521 } 522 if v := r.typesByName[enum]; v != nil { 523 if et, _ := v.(protoreflect.EnumType); et != nil { 524 return et, nil 525 } 526 return nil, errors.New("found wrong type: got %v, want enum", typeName(v)) 527 } 528 return nil, NotFound 529} 530 531// FindMessageByName looks up a message by its full name. 532// E.g., "google.protobuf.Any" 533// 534// This return (nil, NotFound) if not found. 535func (r *Types) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) { 536 // The full name by itself is a valid URL. 537 return r.FindMessageByURL(string(message)) 538} 539 540// FindMessageByURL looks up a message by a URL identifier. 541// See documentation on google.protobuf.Any.type_url for the URL format. 542// 543// This returns (nil, NotFound) if not found. 544func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) { 545 if r == nil { 546 return nil, NotFound 547 } 548 if r == GlobalTypes { 549 globalMutex.RLock() 550 defer globalMutex.RUnlock() 551 } 552 message := protoreflect.FullName(url) 553 if i := strings.LastIndexByte(url, '/'); i >= 0 { 554 message = message[i+len("/"):] 555 } 556 557 if v := r.typesByName[message]; v != nil { 558 if mt, _ := v.(protoreflect.MessageType); mt != nil { 559 return mt, nil 560 } 561 return nil, errors.New("found wrong type: got %v, want message", typeName(v)) 562 } 563 return nil, NotFound 564} 565 566// FindExtensionByName looks up a extension field by the field's full name. 567// Note that this is the full name of the field as determined by 568// where the extension is declared and is unrelated to the full name of the 569// message being extended. 570// 571// This returns (nil, NotFound) if not found. 572func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { 573 if r == nil { 574 return nil, NotFound 575 } 576 if r == GlobalTypes { 577 globalMutex.RLock() 578 defer globalMutex.RUnlock() 579 } 580 if v := r.typesByName[field]; v != nil { 581 if xt, _ := v.(protoreflect.ExtensionType); xt != nil { 582 return xt, nil 583 } 584 return nil, errors.New("found wrong type: got %v, want extension", typeName(v)) 585 } 586 return nil, NotFound 587} 588 589// FindExtensionByNumber looks up a extension field by the field number 590// within some parent message, identified by full name. 591// 592// This returns (nil, NotFound) if not found. 593func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { 594 if r == nil { 595 return nil, NotFound 596 } 597 if r == GlobalTypes { 598 globalMutex.RLock() 599 defer globalMutex.RUnlock() 600 } 601 if xt, ok := r.extensionsByMessage[message][field]; ok { 602 return xt, nil 603 } 604 return nil, NotFound 605} 606 607// NumEnums reports the number of registered enums. 608func (r *Types) NumEnums() int { 609 if r == nil { 610 return 0 611 } 612 if r == GlobalTypes { 613 globalMutex.RLock() 614 defer globalMutex.RUnlock() 615 } 616 return r.numEnums 617} 618 619// RangeEnums iterates over all registered enums while f returns true. 620// Iteration order is undefined. 621func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) { 622 if r == nil { 623 return 624 } 625 if r == GlobalTypes { 626 globalMutex.RLock() 627 defer globalMutex.RUnlock() 628 } 629 for _, typ := range r.typesByName { 630 if et, ok := typ.(protoreflect.EnumType); ok { 631 if !f(et) { 632 return 633 } 634 } 635 } 636} 637 638// NumMessages reports the number of registered messages. 639func (r *Types) NumMessages() int { 640 if r == nil { 641 return 0 642 } 643 if r == GlobalTypes { 644 globalMutex.RLock() 645 defer globalMutex.RUnlock() 646 } 647 return r.numMessages 648} 649 650// RangeMessages iterates over all registered messages while f returns true. 651// Iteration order is undefined. 652func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) { 653 if r == nil { 654 return 655 } 656 if r == GlobalTypes { 657 globalMutex.RLock() 658 defer globalMutex.RUnlock() 659 } 660 for _, typ := range r.typesByName { 661 if mt, ok := typ.(protoreflect.MessageType); ok { 662 if !f(mt) { 663 return 664 } 665 } 666 } 667} 668 669// NumExtensions reports the number of registered extensions. 670func (r *Types) NumExtensions() int { 671 if r == nil { 672 return 0 673 } 674 if r == GlobalTypes { 675 globalMutex.RLock() 676 defer globalMutex.RUnlock() 677 } 678 return r.numExtensions 679} 680 681// RangeExtensions iterates over all registered extensions while f returns true. 682// Iteration order is undefined. 683func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) { 684 if r == nil { 685 return 686 } 687 if r == GlobalTypes { 688 globalMutex.RLock() 689 defer globalMutex.RUnlock() 690 } 691 for _, typ := range r.typesByName { 692 if xt, ok := typ.(protoreflect.ExtensionType); ok { 693 if !f(xt) { 694 return 695 } 696 } 697 } 698} 699 700// NumExtensionsByMessage reports the number of registered extensions for 701// a given message type. 702func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int { 703 if r == nil { 704 return 0 705 } 706 if r == GlobalTypes { 707 globalMutex.RLock() 708 defer globalMutex.RUnlock() 709 } 710 return len(r.extensionsByMessage[message]) 711} 712 713// RangeExtensionsByMessage iterates over all registered extensions filtered 714// by a given message type while f returns true. Iteration order is undefined. 715func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) { 716 if r == nil { 717 return 718 } 719 if r == GlobalTypes { 720 globalMutex.RLock() 721 defer globalMutex.RUnlock() 722 } 723 for _, xt := range r.extensionsByMessage[message] { 724 if !f(xt) { 725 return 726 } 727 } 728} 729 730func typeName(t interface{}) string { 731 switch t.(type) { 732 case protoreflect.EnumType: 733 return "enum" 734 case protoreflect.MessageType: 735 return "message" 736 case protoreflect.ExtensionType: 737 return "extension" 738 default: 739 return fmt.Sprintf("%T", t) 740 } 741} 742 743func amendErrorWithCaller(err error, prev, curr interface{}) error { 744 prevPkg := goPackage(prev) 745 currPkg := goPackage(curr) 746 if prevPkg == "" || currPkg == "" || prevPkg == currPkg { 747 return err 748 } 749 return errors.New("%s\n\tpreviously from: %q\n\tcurrently from: %q", err, prevPkg, currPkg) 750} 751 752func goPackage(v interface{}) string { 753 switch d := v.(type) { 754 case protoreflect.EnumType: 755 v = d.Descriptor() 756 case protoreflect.MessageType: 757 v = d.Descriptor() 758 case protoreflect.ExtensionType: 759 v = d.TypeDescriptor() 760 } 761 if d, ok := v.(protoreflect.Descriptor); ok { 762 v = d.ParentFile() 763 } 764 if d, ok := v.(interface{ GoPackagePath() string }); ok { 765 return d.GoPackagePath() 766 } 767 return "" 768} 769