1package msgregistry 2 3import ( 4 "bytes" 5 "fmt" 6 "reflect" 7 "sort" 8 "strings" 9 "sync" 10 "sync/atomic" 11 12 "github.com/golang/protobuf/proto" 13 "github.com/golang/protobuf/protoc-gen-go/descriptor" 14 "github.com/golang/protobuf/ptypes/wrappers" 15 "golang.org/x/net/context" 16 "google.golang.org/genproto/protobuf/api" 17 "google.golang.org/genproto/protobuf/ptype" 18 19 "github.com/jhump/protoreflect/desc" 20 "github.com/jhump/protoreflect/dynamic" 21) 22 23var ( 24 enumOptionsDesc, enumValueOptionsDesc *desc.MessageDescriptor 25 msgOptionsDesc, fieldOptionsDesc *desc.MessageDescriptor 26 svcOptionsDesc, methodOptionsDesc *desc.MessageDescriptor 27) 28 29func init() { 30 var err error 31 enumOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.EnumOptions)(nil)) 32 if err != nil { 33 panic("Failed to load descriptor for EnumOptions") 34 } 35 enumValueOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.EnumValueOptions)(nil)) 36 if err != nil { 37 panic("Failed to load descriptor for EnumValueOptions") 38 } 39 msgOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.MessageOptions)(nil)) 40 if err != nil { 41 panic("Failed to load descriptor for MessageOptions") 42 } 43 fieldOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.FieldOptions)(nil)) 44 if err != nil { 45 panic("Failed to load descriptor for FieldOptions") 46 } 47 svcOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.ServiceOptions)(nil)) 48 if err != nil { 49 panic("Failed to load descriptor for ServiceOptions") 50 } 51 methodOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.MethodOptions)(nil)) 52 if err != nil { 53 panic("Failed to load descriptor for MethodOptions") 54 } 55} 56 57func ensureScheme(url string) string { 58 pos := strings.Index(url, "://") 59 if pos < 0 { 60 return "https://" + url 61 } 62 return url 63} 64 65// typeResolver is used by MessageRegistry to resolve message types. It uses a given TypeFetcher 66// to retrieve type definitions and caches resulting descriptor objects. 67type typeResolver struct { 68 fetcher TypeFetcher 69 mr *MessageRegistry 70 mu sync.RWMutex 71 cache map[string]desc.Descriptor 72} 73 74// resolveUrlToMessageDescriptor returns a message descriptor that represents the type at the given URL. 75func (r *typeResolver) resolveUrlToMessageDescriptor(url string) (*desc.MessageDescriptor, error) { 76 url = ensureScheme(url) 77 r.mu.RLock() 78 cached := r.cache[url] 79 r.mu.RUnlock() 80 if cached != nil { 81 if md, ok := cached.(*desc.MessageDescriptor); ok { 82 return md, nil 83 } else { 84 return nil, fmt.Errorf("type for URL %v is the wrong type: wanted message, got enum", url) 85 } 86 } 87 88 rc := newResolutionContext(r) 89 if err := rc.addType(url, false); err != nil { 90 return nil, err 91 } 92 93 var files map[string]*desc.FileDescriptor 94 files, err := rc.toFileDescriptors(r.mr) 95 if err != nil { 96 return nil, err 97 } 98 r.mu.Lock() 99 defer r.mu.Unlock() 100 var md *desc.MessageDescriptor 101 if len(rc.typeLocations) > 0 { 102 if r.cache == nil { 103 r.cache = map[string]desc.Descriptor{} 104 } 105 } 106 for typeUrl, fileName := range rc.typeLocations { 107 fd := files[fileName] 108 sym := fd.FindSymbol(typeName(typeUrl)) 109 r.cache[typeUrl] = sym 110 if url == typeUrl { 111 md = sym.(*desc.MessageDescriptor) 112 } 113 } 114 return md, nil 115} 116 117// resolveUrlsToMessageDescriptors returns a map of the given URLs to corresponding 118// message descriptors that represent the types at those URLs. 119func (r *typeResolver) resolveUrlsToMessageDescriptors(urls ...string) (map[string]*desc.MessageDescriptor, error) { 120 ret := map[string]*desc.MessageDescriptor{} 121 var unresolved []string 122 r.mu.RLock() 123 for _, u := range urls { 124 u = ensureScheme(u) 125 cached := r.cache[u] 126 if cached != nil { 127 if md, ok := cached.(*desc.MessageDescriptor); ok { 128 ret[u] = md 129 } else { 130 r.mu.RUnlock() 131 return nil, fmt.Errorf("type for URL %v is the wrong type: wanted message, got enum", u) 132 } 133 } else { 134 ret[u] = nil 135 unresolved = append(unresolved, u) 136 } 137 } 138 r.mu.RUnlock() 139 140 if len(unresolved) == 0 { 141 return ret, nil 142 } 143 144 rc := newResolutionContext(r) 145 for _, u := range unresolved { 146 if err := rc.addType(u, false); err != nil { 147 return nil, err 148 } 149 } 150 151 var files map[string]*desc.FileDescriptor 152 files, err := rc.toFileDescriptors(r.mr) 153 if err != nil { 154 return nil, err 155 } 156 r.mu.Lock() 157 defer r.mu.Unlock() 158 if len(rc.typeLocations) > 0 { 159 if r.cache == nil { 160 r.cache = map[string]desc.Descriptor{} 161 } 162 } 163 for typeUrl, fileName := range rc.typeLocations { 164 fd := files[fileName] 165 sym := fd.FindSymbol(typeName(typeUrl)) 166 r.cache[typeUrl] = sym 167 if _, ok := ret[typeUrl]; ok { 168 ret[typeUrl] = sym.(*desc.MessageDescriptor) 169 } 170 } 171 return ret, nil 172} 173 174// resolveUrlToEnumDescriptor returns an enum descriptor that represents the enum type at the given URL. 175func (r *typeResolver) resolveUrlToEnumDescriptor(url string) (*desc.EnumDescriptor, error) { 176 url = ensureScheme(url) 177 r.mu.RLock() 178 cached := r.cache[url] 179 r.mu.RUnlock() 180 if cached != nil { 181 if ed, ok := cached.(*desc.EnumDescriptor); ok { 182 return ed, nil 183 } else { 184 return nil, fmt.Errorf("type for URL %v is the wrong type: wanted enum, got message", url) 185 } 186 } 187 188 rc := newResolutionContext(r) 189 if err := rc.addType(url, true); err != nil { 190 return nil, err 191 } 192 193 var files map[string]*desc.FileDescriptor 194 files, err := rc.toFileDescriptors(r.mr) 195 if err != nil { 196 return nil, err 197 } 198 r.mu.Lock() 199 defer r.mu.Unlock() 200 var ed *desc.EnumDescriptor 201 if len(rc.typeLocations) > 0 { 202 if r.cache == nil { 203 r.cache = map[string]desc.Descriptor{} 204 } 205 } 206 for typeUrl, fileName := range rc.typeLocations { 207 fd := files[fileName] 208 sym := fd.FindSymbol(typeName(typeUrl)) 209 r.cache[typeUrl] = sym 210 if url == typeUrl { 211 ed = sym.(*desc.EnumDescriptor) 212 } 213 } 214 return ed, nil 215} 216 217type tracker func(d desc.Descriptor) bool 218 219func newNameTracker() tracker { 220 names := map[string]struct{}{} 221 return func(d desc.Descriptor) bool { 222 name := d.GetFullyQualifiedName() 223 if _, ok := names[name]; ok { 224 return false 225 } 226 names[name] = struct{}{} 227 return true 228 } 229} 230 231func addDescriptors(ref string, files map[string]*fileEntry, d desc.Descriptor, msgs map[string]*desc.MessageDescriptor, onAdd tracker) { 232 name := d.GetFullyQualifiedName() 233 234 fileName := d.GetFile().GetName() 235 if fileName != ref { 236 dependee := files[ref] 237 if dependee.deps == nil { 238 dependee.deps = map[string]struct{}{} 239 } 240 dependee.deps[fileName] = struct{}{} 241 } 242 243 if !onAdd(d) { 244 // already added this one 245 return 246 } 247 248 fe := files[fileName] 249 if fe == nil { 250 fe = &fileEntry{} 251 fe.proto3 = d.GetFile().IsProto3() 252 files[fileName] = fe 253 } 254 fe.types.addType(name, d.AsProto()) 255 256 if md, ok := d.(*desc.MessageDescriptor); ok { 257 for _, fld := range md.GetFields() { 258 if fld.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE || fld.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP { 259 // prefer descriptor in msgs map over what the field descriptor indicates 260 md := msgs[fld.GetMessageType().GetFullyQualifiedName()] 261 if md == nil { 262 md = fld.GetMessageType() 263 } 264 addDescriptors(fileName, files, md, msgs, onAdd) 265 } else if fld.GetType() == descriptor.FieldDescriptorProto_TYPE_ENUM { 266 addDescriptors(fileName, files, fld.GetEnumType(), msgs, onAdd) 267 } 268 } 269 } 270} 271 272// resolutionContext provides the state for a resolution operation, accumulating details about 273// type descriptions and the files that contain them. 274type resolutionContext struct { 275 // The context and cancel function, used to coordinate multiple goroutines when there are multiple 276 // type or enum descriptions to download. 277 ctx context.Context 278 cancel func() 279 res *typeResolver 280 281 mu sync.Mutex 282 // map of file names to details regarding the files' contents 283 files map[string]*fileEntry 284 // map of type URLs to the file name that defines them 285 typeLocations map[string]string 286 // count of source contexts that do not indicate a file name (used to generate unique file names 287 // when synthesizing file descriptors) 288 unknownCount int 289} 290 291func newResolutionContext(res *typeResolver) *resolutionContext { 292 ctx, cancel := context.WithCancel(context.Background()) 293 return &resolutionContext{ 294 ctx: ctx, 295 cancel: cancel, 296 res: res, 297 typeLocations: map[string]string{}, 298 files: map[string]*fileEntry{}, 299 } 300} 301 302// addType adds the type at the given URL to the context, using the given fetcher to download the type's 303// description. This function will recursively add dependencies (e.g. types referenced by the given type's 304// fields if it is a message type), fetching their type descriptions concurrently. 305func (rc *resolutionContext) addType(url string, enum bool) error { 306 if err := rc.ctx.Err(); err != nil { 307 return err 308 } 309 310 m, err := rc.res.fetcher(url, enum) 311 if err != nil { 312 return err 313 } else if m == nil { 314 return fmt.Errorf("failed to locate type for %s", url) 315 } 316 317 if enum { 318 rc.recordEnum(url, m.(*ptype.Enum)) 319 return nil 320 } 321 322 // for messages, resolve dependencies in parallel 323 t := m.(*ptype.Type) 324 fe, fileName := rc.recordType(url, t) 325 if fe == nil { 326 // already resolved this one 327 return nil 328 } 329 330 var wg sync.WaitGroup 331 var failed int32 332 for _, f := range t.Fields { 333 if f.Kind == ptype.Field_TYPE_GROUP || f.Kind == ptype.Field_TYPE_MESSAGE || f.Kind == ptype.Field_TYPE_ENUM { 334 typeUrl := ensureScheme(f.TypeUrl) 335 kind := f.Kind 336 wg.Add(1) 337 go func() { 338 defer wg.Done() 339 // first check the registry for descriptors 340 var d desc.Descriptor 341 var innerErr error 342 if kind == ptype.Field_TYPE_ENUM { 343 var ed *desc.EnumDescriptor 344 ed, innerErr = rc.res.mr.getRegisteredEnumTypeByUrl(typeUrl) 345 if ed != nil { 346 d = ed 347 } 348 } else { 349 var md *desc.MessageDescriptor 350 md, innerErr = rc.res.mr.getRegisteredMessageTypeByUrl(typeUrl) 351 if md != nil { 352 d = md 353 } 354 } 355 356 if innerErr == nil { 357 if d != nil { 358 // found it! 359 rc.recordDescriptor(typeUrl, fileName, d) 360 } else { 361 // not in registry, so we have to recursively fetch 362 innerErr = rc.addType(typeUrl, kind == ptype.Field_TYPE_ENUM) 363 } 364 } 365 366 // We want the "real" error to ultimately propagate to root, not 367 // one of the resulting cancellations (from any concurrent goroutines 368 // working in the same resolution context). 369 if innerErr != nil && (rc.ctx.Err() == nil || innerErr != context.Canceled) { 370 if atomic.CompareAndSwapInt32(&failed, 0, 1) { 371 err = innerErr 372 } 373 rc.cancel() 374 } 375 }() 376 } 377 } 378 wg.Wait() 379 if err != nil { 380 return err 381 } 382 // double-check if context has been cancelled 383 if err = rc.ctx.Err(); err != nil { 384 return err 385 } 386 387 rc.mu.Lock() 388 defer rc.mu.Unlock() 389 390 for _, f := range t.Fields { 391 if f.Kind == ptype.Field_TYPE_GROUP || f.Kind == ptype.Field_TYPE_MESSAGE || f.Kind == ptype.Field_TYPE_ENUM { 392 typeUrl := ensureScheme(f.TypeUrl) 393 if fe.deps == nil { 394 fe.deps = map[string]struct{}{} 395 } 396 dep := rc.typeLocations[typeUrl] 397 if dep != fileName { 398 fe.deps[dep] = struct{}{} 399 } 400 } 401 } 402 return nil 403} 404 405func (rc *resolutionContext) recordEnum(url string, e *ptype.Enum) { 406 rc.mu.Lock() 407 defer rc.mu.Unlock() 408 409 var fileName string 410 if e.SourceContext != nil && e.SourceContext.FileName != "" { 411 fileName = e.SourceContext.FileName 412 } else { 413 fileName = fmt.Sprintf("--unknown--%d.proto", rc.unknownCount) 414 rc.unknownCount++ 415 } 416 rc.typeLocations[url] = fileName 417 418 fe := rc.files[fileName] 419 if fe == nil { 420 fe = &fileEntry{} 421 rc.files[fileName] = fe 422 } 423 fe.types.addType(e.Name, e) 424 if e.Syntax == ptype.Syntax_SYNTAX_PROTO3 { 425 fe.proto3 = true 426 } 427} 428 429func (rc *resolutionContext) recordType(url string, t *ptype.Type) (*fileEntry, string) { 430 rc.mu.Lock() 431 defer rc.mu.Unlock() 432 433 if _, ok := rc.typeLocations[url]; ok { 434 return nil, "" 435 } 436 437 var fileName string 438 if t.SourceContext != nil && t.SourceContext.FileName != "" { 439 fileName = t.SourceContext.FileName 440 } else { 441 fileName = fmt.Sprintf("--unknown--%d.proto", rc.unknownCount) 442 rc.unknownCount++ 443 } 444 rc.typeLocations[url] = fileName 445 446 fe := rc.files[fileName] 447 if fe == nil { 448 fe = &fileEntry{} 449 rc.files[fileName] = fe 450 } 451 fe.types.addType(t.Name, t) 452 if t.Syntax == ptype.Syntax_SYNTAX_PROTO3 { 453 fe.proto3 = true 454 } 455 456 return fe, fileName 457} 458 459func (rc *resolutionContext) recordDescriptor(url, ref string, d desc.Descriptor) { 460 rc.mu.Lock() 461 defer rc.mu.Unlock() 462 463 addDescriptors(ref, rc.files, d, nil, func(dsc desc.Descriptor) bool { 464 u := ensureScheme(rc.res.mr.ComputeUrl(dsc)) 465 if _, ok := rc.typeLocations[u]; ok { 466 // already seen this one 467 return false 468 } 469 fileName := dsc.GetFile().GetName() 470 rc.typeLocations[u] = fileName 471 if dsc == d { 472 // make sure we're also adding the actual URL reference used 473 rc.typeLocations[url] = fileName 474 } 475 return true 476 }) 477} 478 479// toFileDescriptors converts the information in the context into a map of file names to file descriptors. 480func (rc *resolutionContext) toFileDescriptors(mr *MessageRegistry) (map[string]*desc.FileDescriptor, error) { 481 return toFileDescriptors(rc.files, func(tt *typeTrie, name string) (proto.Message, error) { 482 mdp, edp := tt.ptypeToDescriptor(name, mr) 483 if mdp != nil { 484 return mdp, nil 485 } else { 486 return edp, nil 487 } 488 }) 489} 490 491// converts a map of file entries into a map of file descriptors using the given function to convert 492// each trie node into a descriptor proto. 493func toFileDescriptors(files map[string]*fileEntry, trieFn func(*typeTrie, string) (proto.Message, error)) (map[string]*desc.FileDescriptor, error) { 494 fdps := map[string]*descriptor.FileDescriptorProto{} 495 for name, file := range files { 496 fdp, err := file.toFileDescriptor(name, trieFn) 497 if err != nil { 498 return nil, err 499 } 500 fdps[name] = fdp 501 } 502 fds := map[string]*desc.FileDescriptor{} 503 for name, fdp := range fdps { 504 if _, ok := fds[name]; ok { 505 continue 506 } 507 var err error 508 if fds[name], err = makeFileDesc(fdp, fds, fdps); err != nil { 509 return nil, err 510 } 511 } 512 return fds, nil 513} 514 515func makeFileDesc(fdp *descriptor.FileDescriptorProto, fds map[string]*desc.FileDescriptor, fdps map[string]*descriptor.FileDescriptorProto) (*desc.FileDescriptor, error) { 516 deps := make([]*desc.FileDescriptor, len(fdp.Dependency)) 517 for i, dep := range fdp.Dependency { 518 d := fds[dep] 519 if d == nil { 520 var err error 521 depFd := fdps[dep] 522 if depFd == nil { 523 return nil, fmt.Errorf("missing dependency: %s", dep) 524 } 525 d, err = makeFileDesc(depFd, fds, fdps) 526 if err != nil { 527 return nil, err 528 } 529 } 530 deps[i] = d 531 } 532 if fd, err := desc.CreateFileDescriptor(fdp, deps...); err != nil { 533 return nil, err 534 } else { 535 fds[fdp.GetName()] = fd 536 return fd, nil 537 } 538} 539 540// fileEntry represents the contents of a single file. 541type fileEntry struct { 542 types typeTrie 543 deps map[string]struct{} 544 proto3 bool 545} 546 547// toFileDescriptor converts this file entry into a file descriptor proto. The given function 548// is used to transform nodes in a typeTrie into message and/or enum descriptor protos. 549func (fe *fileEntry) toFileDescriptor(name string, trieFn func(*typeTrie, string) (proto.Message, error)) (*descriptor.FileDescriptorProto, error) { 550 var pkg bytes.Buffer 551 tt := &fe.types 552 first := true 553 last := "" 554 for tt.typ == nil { 555 if last != "" { 556 if first { 557 first = false 558 } else { 559 pkg.WriteByte('.') 560 } 561 pkg.WriteString(last) 562 } 563 if len(tt.children) != 1 { 564 break 565 } 566 for last, tt = range tt.children { 567 } 568 } 569 fd := createFileDescriptor(name, pkg.String(), fe.proto3, fe.deps) 570 if tt.typ != nil { 571 pm, err := trieFn(tt, last) 572 if err != nil { 573 return nil, err 574 } 575 if mdp, ok := pm.(*descriptor.DescriptorProto); ok { 576 fd.MessageType = append(fd.MessageType, mdp) 577 } else if edp, ok := pm.(*descriptor.EnumDescriptorProto); ok { 578 fd.EnumType = append(fd.EnumType, edp) 579 } else { 580 sdp := pm.(*descriptor.ServiceDescriptorProto) 581 fd.Service = append(fd.Service, sdp) 582 } 583 } else { 584 for name, nested := range tt.children { 585 pm, err := trieFn(nested, name) 586 if err != nil { 587 return nil, err 588 } 589 if mdp, ok := pm.(*descriptor.DescriptorProto); ok { 590 fd.MessageType = append(fd.MessageType, mdp) 591 } else if edp, ok := pm.(*descriptor.EnumDescriptorProto); ok { 592 fd.EnumType = append(fd.EnumType, edp) 593 } else { 594 sdp := pm.(*descriptor.ServiceDescriptorProto) 595 fd.Service = append(fd.Service, sdp) 596 } 597 } 598 } 599 return fd, nil 600} 601 602// typeTrie is a prefix trie where each key component is part of a fully-qualified type name. So key components 603// will either be package name components or element names. 604type typeTrie struct { 605 // successor key components 606 children map[string]*typeTrie 607 // if non-nil, the element whose fully-qualified name is the path from the trie root to this node 608 typ proto.Message 609} 610 611// addType recursively adds an element to the trie. 612func (t *typeTrie) addType(key string, typ proto.Message) { 613 if key == "" { 614 t.typ = typ 615 return 616 } 617 if t.children == nil { 618 t.children = map[string]*typeTrie{} 619 } 620 curr, rest := split(key) 621 child := t.children[curr] 622 if child == nil { 623 child = &typeTrie{} 624 t.children[curr] = child 625 } 626 child.addType(rest, typ) 627} 628 629// ptypeToDescriptor converts this level of the trie into a message or enum 630// descriptor proto, requiring that the element stored in t.typ is a *ptype.Type 631// or *ptype.Enum. If t.typ is nil, a placeholder message (with no fields) is 632// returned that contains the trie's children as nested message and/or enum 633// types. 634// 635// If the value in t.typ is already a *descriptor.DescriptorProto or a 636// *descriptor.EnumDescriptorProto then it is returned as is. This function 637// should not be used in type tries that may have service descriptors. That will 638// result in a panic. 639func (t *typeTrie) ptypeToDescriptor(name string, mr *MessageRegistry) (*descriptor.DescriptorProto, *descriptor.EnumDescriptorProto) { 640 switch typ := t.typ.(type) { 641 case *descriptor.EnumDescriptorProto: 642 return nil, typ 643 case *ptype.Enum: 644 return nil, createEnumDescriptor(typ, mr) 645 case *descriptor.DescriptorProto: 646 return typ, nil 647 default: 648 var msg *descriptor.DescriptorProto 649 if t.typ == nil { 650 msg = createIntermediateMessageDescriptor(name) 651 } else { 652 msg = createMessageDescriptor(t.typ.(*ptype.Type), mr) 653 } 654 // sort children for deterministic output 655 var keys []string 656 for k := range t.children { 657 keys = append(keys, k) 658 } 659 for _, name := range keys { 660 nested := t.children[name] 661 chMsg, chEnum := nested.ptypeToDescriptor(name, mr) 662 if chMsg != nil { 663 msg.NestedType = append(msg.NestedType, chMsg) 664 } 665 if chEnum != nil { 666 msg.EnumType = append(msg.EnumType, chEnum) 667 } 668 } 669 return msg, nil 670 } 671} 672 673// rewriteDescriptor converts this level of the trie into a new descriptor 674// proto, requiring that the element stored in t.type is already a service, 675// message, or enum descriptor proto. If this trie has children then t.typ must 676// be a message descriptor proto. The returned descriptor proto is the same as 677// .type but with possibly new nested elements to represent this trie node's 678// children. 679func (t *typeTrie) rewriteDescriptor(name string) (proto.Message, error) { 680 if len(t.children) == 0 && t.typ != nil { 681 if mdp, ok := t.typ.(*descriptor.DescriptorProto); ok { 682 if len(mdp.NestedType) == 0 && len(mdp.EnumType) == 0 { 683 return mdp, nil 684 } 685 mdp = proto.Clone(mdp).(*descriptor.DescriptorProto) 686 mdp.NestedType = nil 687 mdp.EnumType = nil 688 return mdp, nil 689 } 690 return t.typ, nil 691 } 692 var mdp *descriptor.DescriptorProto 693 if t.typ == nil { 694 mdp = createIntermediateMessageDescriptor(name) 695 } else { 696 mdp = t.typ.(*descriptor.DescriptorProto) 697 mdp = proto.Clone(mdp).(*descriptor.DescriptorProto) 698 mdp.NestedType = nil 699 mdp.EnumType = nil 700 } 701 // sort children for deterministic output 702 var keys []string 703 for k := range t.children { 704 keys = append(keys, k) 705 } 706 for _, n := range keys { 707 ch := t.children[n] 708 typ, err := ch.rewriteDescriptor(n) 709 if err != nil { 710 return nil, err 711 } 712 switch typ := typ.(type) { 713 case (*descriptor.DescriptorProto): 714 mdp.NestedType = append(mdp.NestedType, typ) 715 case (*descriptor.EnumDescriptorProto): 716 mdp.EnumType = append(mdp.EnumType, typ) 717 default: 718 // TODO: this should probably panic instead 719 return nil, fmt.Errorf("invalid descriptor trie: message cannot have child of type %v", reflect.TypeOf(typ)) 720 } 721 } 722 return mdp, nil 723} 724 725func split(s string) (string, string) { 726 pos := strings.Index(s, ".") 727 if pos >= 0 { 728 return s[:pos], s[pos+1:] 729 } else { 730 return s, "" 731 } 732} 733 734func createEnumDescriptor(e *ptype.Enum, mr *MessageRegistry) *descriptor.EnumDescriptorProto { 735 var opts *descriptor.EnumOptions 736 if len(e.Options) > 0 { 737 dopts := createOptions(e.Options, enumOptionsDesc, mr) 738 opts = &descriptor.EnumOptions{} 739 dopts.ConvertTo(opts) // ignore any error 740 } 741 742 var vals []*descriptor.EnumValueDescriptorProto 743 for _, v := range e.Enumvalue { 744 evd := createEnumValueDescriptor(v, mr) 745 vals = append(vals, evd) 746 } 747 748 return &descriptor.EnumDescriptorProto{ 749 Name: proto.String(base(e.Name)), 750 Options: opts, 751 Value: vals, 752 } 753} 754 755func createEnumValueDescriptor(v *ptype.EnumValue, mr *MessageRegistry) *descriptor.EnumValueDescriptorProto { 756 var opts *descriptor.EnumValueOptions 757 if len(v.Options) > 0 { 758 dopts := createOptions(v.Options, enumValueOptionsDesc, mr) 759 opts = &descriptor.EnumValueOptions{} 760 dopts.ConvertTo(opts) // ignore any error 761 } 762 763 return &descriptor.EnumValueDescriptorProto{ 764 Name: proto.String(v.Name), 765 Number: proto.Int32(v.Number), 766 Options: opts, 767 } 768} 769 770func createMessageDescriptor(m *ptype.Type, mr *MessageRegistry) *descriptor.DescriptorProto { 771 var opts *descriptor.MessageOptions 772 if len(m.Options) > 0 { 773 dopts := createOptions(m.Options, msgOptionsDesc, mr) 774 opts = &descriptor.MessageOptions{} 775 dopts.ConvertTo(opts) // ignore any error 776 } 777 778 var fields []*descriptor.FieldDescriptorProto 779 for _, f := range m.Fields { 780 fields = append(fields, createFieldDescriptor(f, mr)) 781 } 782 783 var oneOfs []*descriptor.OneofDescriptorProto 784 for _, o := range m.Oneofs { 785 oneOfs = append(oneOfs, &descriptor.OneofDescriptorProto{ 786 Name: proto.String(o), 787 }) 788 } 789 790 return &descriptor.DescriptorProto{ 791 Name: proto.String(base(m.Name)), 792 Options: opts, 793 Field: fields, 794 OneofDecl: oneOfs, 795 } 796} 797 798func createFieldDescriptor(f *ptype.Field, mr *MessageRegistry) *descriptor.FieldDescriptorProto { 799 var opts *descriptor.FieldOptions 800 if len(f.Options) > 0 { 801 dopts := createOptions(f.Options, fieldOptionsDesc, mr) 802 opts = &descriptor.FieldOptions{} 803 dopts.ConvertTo(opts) // ignore any error 804 } 805 if f.Packed { 806 if opts == nil { 807 opts = &descriptor.FieldOptions{Packed: proto.Bool(true)} 808 } else { 809 opts.Packed = proto.Bool(true) 810 } 811 } 812 813 var oneOf *int32 814 if f.OneofIndex > 0 { 815 oneOf = proto.Int32(f.OneofIndex - 1) 816 } 817 818 var typeName string 819 if f.Kind == ptype.Field_TYPE_GROUP || f.Kind == ptype.Field_TYPE_MESSAGE || f.Kind == ptype.Field_TYPE_ENUM { 820 pos := strings.LastIndex(f.TypeUrl, "/") 821 typeName = "." + f.TypeUrl[pos+1:] 822 } 823 824 var label descriptor.FieldDescriptorProto_Label 825 switch f.Cardinality { 826 case ptype.Field_CARDINALITY_OPTIONAL: 827 label = descriptor.FieldDescriptorProto_LABEL_OPTIONAL 828 case ptype.Field_CARDINALITY_REPEATED: 829 label = descriptor.FieldDescriptorProto_LABEL_REPEATED 830 case ptype.Field_CARDINALITY_REQUIRED: 831 label = descriptor.FieldDescriptorProto_LABEL_REQUIRED 832 } 833 834 var typ descriptor.FieldDescriptorProto_Type 835 switch f.Kind { 836 case ptype.Field_TYPE_ENUM: 837 typ = descriptor.FieldDescriptorProto_TYPE_ENUM 838 case ptype.Field_TYPE_GROUP: 839 typ = descriptor.FieldDescriptorProto_TYPE_GROUP 840 case ptype.Field_TYPE_MESSAGE: 841 typ = descriptor.FieldDescriptorProto_TYPE_MESSAGE 842 case ptype.Field_TYPE_BYTES: 843 typ = descriptor.FieldDescriptorProto_TYPE_BYTES 844 case ptype.Field_TYPE_STRING: 845 typ = descriptor.FieldDescriptorProto_TYPE_STRING 846 case ptype.Field_TYPE_BOOL: 847 typ = descriptor.FieldDescriptorProto_TYPE_BOOL 848 case ptype.Field_TYPE_DOUBLE: 849 typ = descriptor.FieldDescriptorProto_TYPE_DOUBLE 850 case ptype.Field_TYPE_FLOAT: 851 typ = descriptor.FieldDescriptorProto_TYPE_FLOAT 852 case ptype.Field_TYPE_FIXED32: 853 typ = descriptor.FieldDescriptorProto_TYPE_FIXED32 854 case ptype.Field_TYPE_FIXED64: 855 typ = descriptor.FieldDescriptorProto_TYPE_FIXED64 856 case ptype.Field_TYPE_INT32: 857 typ = descriptor.FieldDescriptorProto_TYPE_INT32 858 case ptype.Field_TYPE_INT64: 859 typ = descriptor.FieldDescriptorProto_TYPE_INT64 860 case ptype.Field_TYPE_SFIXED32: 861 typ = descriptor.FieldDescriptorProto_TYPE_SFIXED32 862 case ptype.Field_TYPE_SFIXED64: 863 typ = descriptor.FieldDescriptorProto_TYPE_SFIXED64 864 case ptype.Field_TYPE_SINT32: 865 typ = descriptor.FieldDescriptorProto_TYPE_SINT32 866 case ptype.Field_TYPE_SINT64: 867 typ = descriptor.FieldDescriptorProto_TYPE_SINT64 868 case ptype.Field_TYPE_UINT32: 869 typ = descriptor.FieldDescriptorProto_TYPE_UINT32 870 case ptype.Field_TYPE_UINT64: 871 typ = descriptor.FieldDescriptorProto_TYPE_UINT64 872 } 873 874 return &descriptor.FieldDescriptorProto{ 875 Name: proto.String(f.Name), 876 Number: proto.Int32(f.Number), 877 DefaultValue: proto.String(f.DefaultValue), 878 JsonName: proto.String(f.JsonName), 879 OneofIndex: oneOf, 880 TypeName: proto.String(typeName), 881 Label: label.Enum(), 882 Type: typ.Enum(), 883 Options: opts, 884 } 885} 886 887func createServiceDescriptor(a *api.Api, mr *MessageRegistry) *descriptor.ServiceDescriptorProto { 888 var opts *descriptor.ServiceOptions 889 if len(a.Options) > 0 { 890 dopts := createOptions(a.Options, svcOptionsDesc, mr) 891 opts = &descriptor.ServiceOptions{} 892 dopts.ConvertTo(opts) // ignore any error 893 } 894 895 methods := make([]*descriptor.MethodDescriptorProto, len(a.Methods)) 896 for i, m := range a.Methods { 897 methods[i] = createMethodDescriptor(m, mr) 898 } 899 900 return &descriptor.ServiceDescriptorProto{ 901 Name: proto.String(base(a.Name)), 902 Method: methods, 903 Options: opts, 904 } 905} 906 907func createMethodDescriptor(m *api.Method, mr *MessageRegistry) *descriptor.MethodDescriptorProto { 908 var opts *descriptor.MethodOptions 909 if len(m.Options) > 0 { 910 dopts := createOptions(m.Options, methodOptionsDesc, mr) 911 opts = &descriptor.MethodOptions{} 912 dopts.ConvertTo(opts) // ignore any error 913 } 914 915 var reqType, respType string 916 pos := strings.LastIndex(m.RequestTypeUrl, "/") 917 reqType = "." + m.RequestTypeUrl[pos+1:] 918 pos = strings.LastIndex(m.ResponseTypeUrl, "/") 919 respType = "." + m.ResponseTypeUrl[pos+1:] 920 921 return &descriptor.MethodDescriptorProto{ 922 Name: proto.String(m.Name), 923 Options: opts, 924 ClientStreaming: proto.Bool(m.RequestStreaming), 925 ServerStreaming: proto.Bool(m.ResponseStreaming), 926 InputType: proto.String(reqType), 927 OutputType: proto.String(respType), 928 } 929} 930 931func createIntermediateMessageDescriptor(name string) *descriptor.DescriptorProto { 932 return &descriptor.DescriptorProto{ 933 Name: proto.String(name), 934 } 935} 936 937func createFileDescriptor(name, pkg string, proto3 bool, deps map[string]struct{}) *descriptor.FileDescriptorProto { 938 imports := make([]string, 0, len(deps)) 939 for k := range deps { 940 imports = append(imports, k) 941 } 942 sort.Strings(imports) 943 var syntax string 944 if proto3 { 945 syntax = "proto3" 946 } else { 947 syntax = "proto2" 948 } 949 return &descriptor.FileDescriptorProto{ 950 Name: proto.String(name), 951 Package: proto.String(pkg), 952 Syntax: proto.String(syntax), 953 Dependency: imports, 954 } 955} 956 957func createOptions(options []*ptype.Option, optionsDesc *desc.MessageDescriptor, mr *MessageRegistry) *dynamic.Message { 958 // these are created "best effort" so entries which are unresolvable 959 // (or seemingly invalid) are simply ignored... 960 dopts := mr.mf.NewDynamicMessage(optionsDesc) 961 for _, o := range options { 962 field := optionsDesc.FindFieldByName(o.Name) 963 if field == nil { 964 field = mr.er.FindExtensionByName(optionsDesc.GetFullyQualifiedName(), o.Name) 965 if field == nil && o.Name[0] != '[' { 966 field = mr.er.FindExtensionByName(optionsDesc.GetFullyQualifiedName(), fmt.Sprintf("[%s]", o.Name)) 967 } 968 if field == nil { 969 // can't resolve option name? skip it 970 continue 971 } 972 } 973 v, err := mr.unmarshalAny(o.Value, func(url string) (*desc.MessageDescriptor, error) { 974 // we don't want to try to recursively fetch this value's type, so if it doesn't 975 // match the type of the extension field, we'll skip it 976 if (field.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP || 977 field.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE) && 978 typeName(url) == field.GetMessageType().GetFullyQualifiedName() { 979 980 return field.GetMessageType(), nil 981 } 982 return nil, nil 983 }) 984 if err != nil { 985 // can't interpret value? skip it 986 continue 987 } 988 var fv interface{} 989 if field.GetType() != descriptor.FieldDescriptorProto_TYPE_MESSAGE && field.GetType() != descriptor.FieldDescriptorProto_TYPE_GROUP { 990 fv = unwrap(v) 991 if v == nil { 992 // non-wrapper type for scalar field? skip it 993 continue 994 } 995 } else { 996 fv = v 997 } 998 if field.IsRepeated() { 999 dopts.TryAddRepeatedField(field, fv) // ignore any error 1000 } else { 1001 dopts.TrySetField(field, fv) // ignore any error 1002 } 1003 } 1004 return dopts 1005} 1006 1007func base(name string) string { 1008 pos := strings.LastIndex(name, ".") 1009 if pos >= 0 { 1010 return name[pos+1:] 1011 } 1012 return name 1013} 1014 1015func unwrap(msg proto.Message) interface{} { 1016 switch m := msg.(type) { 1017 case (*wrappers.BoolValue): 1018 return m.Value 1019 case (*wrappers.FloatValue): 1020 return m.Value 1021 case (*wrappers.DoubleValue): 1022 return m.Value 1023 case (*wrappers.Int32Value): 1024 return m.Value 1025 case (*wrappers.Int64Value): 1026 return m.Value 1027 case (*wrappers.UInt32Value): 1028 return m.Value 1029 case (*wrappers.UInt64Value): 1030 return m.Value 1031 case (*wrappers.BytesValue): 1032 return m.Value 1033 case (*wrappers.StringValue): 1034 return m.Value 1035 default: 1036 return nil 1037 } 1038} 1039 1040func typeName(url string) string { 1041 return url[strings.LastIndex(url, "/")+1:] 1042} 1043