1package msgregistry
2
3import (
4	"bytes"
5	"fmt"
6	"reflect"
7	"sort"
8	"strings"
9	"sync"
10	"sync/atomic"
11
12	"github.com/golang/protobuf/proto"
13	"github.com/golang/protobuf/protoc-gen-go/descriptor"
14	"github.com/golang/protobuf/ptypes/wrappers"
15	"golang.org/x/net/context"
16	"google.golang.org/genproto/protobuf/api"
17	"google.golang.org/genproto/protobuf/ptype"
18
19	"github.com/jhump/protoreflect/desc"
20	"github.com/jhump/protoreflect/dynamic"
21)
22
23var (
24	enumOptionsDesc, enumValueOptionsDesc *desc.MessageDescriptor
25	msgOptionsDesc, fieldOptionsDesc      *desc.MessageDescriptor
26	svcOptionsDesc, methodOptionsDesc     *desc.MessageDescriptor
27)
28
29func init() {
30	var err error
31	enumOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.EnumOptions)(nil))
32	if err != nil {
33		panic("Failed to load descriptor for EnumOptions")
34	}
35	enumValueOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.EnumValueOptions)(nil))
36	if err != nil {
37		panic("Failed to load descriptor for EnumValueOptions")
38	}
39	msgOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.MessageOptions)(nil))
40	if err != nil {
41		panic("Failed to load descriptor for MessageOptions")
42	}
43	fieldOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.FieldOptions)(nil))
44	if err != nil {
45		panic("Failed to load descriptor for FieldOptions")
46	}
47	svcOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.ServiceOptions)(nil))
48	if err != nil {
49		panic("Failed to load descriptor for ServiceOptions")
50	}
51	methodOptionsDesc, err = desc.LoadMessageDescriptorForMessage((*descriptor.MethodOptions)(nil))
52	if err != nil {
53		panic("Failed to load descriptor for MethodOptions")
54	}
55}
56
// ensureScheme returns url unchanged when it already contains a "://" scheme separator;
// otherwise it prefixes the URL with "https://".
func ensureScheme(url string) string {
	if strings.Contains(url, "://") {
		return url
	}
	return "https://" + url
}
64
// typeResolver is used by MessageRegistry to resolve message types. It uses a given TypeFetcher
// to retrieve type definitions and caches resulting descriptor objects.
type typeResolver struct {
	// fetcher downloads the type (or, when asked, enum) definition for a type URL.
	fetcher TypeFetcher
	// mr is the owning registry; it is consulted for already-registered descriptors and is
	// used when converting fetched definitions into descriptors.
	mr      *MessageRegistry
	// mu guards cache.
	mu      sync.RWMutex
	// cache maps type URLs (normalized with ensureScheme) to resolved message or enum
	// descriptors. It is created lazily on first write.
	cache   map[string]desc.Descriptor
}
73
74// resolveUrlToMessageDescriptor returns a message descriptor that represents the type at the given URL.
75func (r *typeResolver) resolveUrlToMessageDescriptor(url string) (*desc.MessageDescriptor, error) {
76	url = ensureScheme(url)
77	r.mu.RLock()
78	cached := r.cache[url]
79	r.mu.RUnlock()
80	if cached != nil {
81		if md, ok := cached.(*desc.MessageDescriptor); ok {
82			return md, nil
83		} else {
84			return nil, fmt.Errorf("type for URL %v is the wrong type: wanted message, got enum", url)
85		}
86	}
87
88	rc := newResolutionContext(r)
89	if err := rc.addType(url, false); err != nil {
90		return nil, err
91	}
92
93	var files map[string]*desc.FileDescriptor
94	files, err := rc.toFileDescriptors(r.mr)
95	if err != nil {
96		return nil, err
97	}
98	r.mu.Lock()
99	defer r.mu.Unlock()
100	var md *desc.MessageDescriptor
101	if len(rc.typeLocations) > 0 {
102		if r.cache == nil {
103			r.cache = map[string]desc.Descriptor{}
104		}
105	}
106	for typeUrl, fileName := range rc.typeLocations {
107		fd := files[fileName]
108		sym := fd.FindSymbol(typeName(typeUrl))
109		r.cache[typeUrl] = sym
110		if url == typeUrl {
111			md = sym.(*desc.MessageDescriptor)
112		}
113	}
114	return md, nil
115}
116
117// resolveUrlsToMessageDescriptors returns a map of the given URLs to corresponding
118// message descriptors that represent the types at those URLs.
119func (r *typeResolver) resolveUrlsToMessageDescriptors(urls ...string) (map[string]*desc.MessageDescriptor, error) {
120	ret := map[string]*desc.MessageDescriptor{}
121	var unresolved []string
122	r.mu.RLock()
123	for _, u := range urls {
124		u = ensureScheme(u)
125		cached := r.cache[u]
126		if cached != nil {
127			if md, ok := cached.(*desc.MessageDescriptor); ok {
128				ret[u] = md
129			} else {
130				r.mu.RUnlock()
131				return nil, fmt.Errorf("type for URL %v is the wrong type: wanted message, got enum", u)
132			}
133		} else {
134			ret[u] = nil
135			unresolved = append(unresolved, u)
136		}
137	}
138	r.mu.RUnlock()
139
140	if len(unresolved) == 0 {
141		return ret, nil
142	}
143
144	rc := newResolutionContext(r)
145	for _, u := range unresolved {
146		if err := rc.addType(u, false); err != nil {
147			return nil, err
148		}
149	}
150
151	var files map[string]*desc.FileDescriptor
152	files, err := rc.toFileDescriptors(r.mr)
153	if err != nil {
154		return nil, err
155	}
156	r.mu.Lock()
157	defer r.mu.Unlock()
158	if len(rc.typeLocations) > 0 {
159		if r.cache == nil {
160			r.cache = map[string]desc.Descriptor{}
161		}
162	}
163	for typeUrl, fileName := range rc.typeLocations {
164		fd := files[fileName]
165		sym := fd.FindSymbol(typeName(typeUrl))
166		r.cache[typeUrl] = sym
167		if _, ok := ret[typeUrl]; ok {
168			ret[typeUrl] = sym.(*desc.MessageDescriptor)
169		}
170	}
171	return ret, nil
172}
173
174// resolveUrlToEnumDescriptor returns an enum descriptor that represents the enum type at the given URL.
175func (r *typeResolver) resolveUrlToEnumDescriptor(url string) (*desc.EnumDescriptor, error) {
176	url = ensureScheme(url)
177	r.mu.RLock()
178	cached := r.cache[url]
179	r.mu.RUnlock()
180	if cached != nil {
181		if ed, ok := cached.(*desc.EnumDescriptor); ok {
182			return ed, nil
183		} else {
184			return nil, fmt.Errorf("type for URL %v is the wrong type: wanted enum, got message", url)
185		}
186	}
187
188	rc := newResolutionContext(r)
189	if err := rc.addType(url, true); err != nil {
190		return nil, err
191	}
192
193	var files map[string]*desc.FileDescriptor
194	files, err := rc.toFileDescriptors(r.mr)
195	if err != nil {
196		return nil, err
197	}
198	r.mu.Lock()
199	defer r.mu.Unlock()
200	var ed *desc.EnumDescriptor
201	if len(rc.typeLocations) > 0 {
202		if r.cache == nil {
203			r.cache = map[string]desc.Descriptor{}
204		}
205	}
206	for typeUrl, fileName := range rc.typeLocations {
207		fd := files[fileName]
208		sym := fd.FindSymbol(typeName(typeUrl))
209		r.cache[typeUrl] = sym
210		if url == typeUrl {
211			ed = sym.(*desc.EnumDescriptor)
212		}
213	}
214	return ed, nil
215}
216
217type tracker func(d desc.Descriptor) bool
218
219func newNameTracker() tracker {
220	names := map[string]struct{}{}
221	return func(d desc.Descriptor) bool {
222		name := d.GetFullyQualifiedName()
223		if _, ok := names[name]; ok {
224			return false
225		}
226		names[name] = struct{}{}
227		return true
228	}
229}
230
231func addDescriptors(ref string, files map[string]*fileEntry, d desc.Descriptor, msgs map[string]*desc.MessageDescriptor, onAdd tracker) {
232	name := d.GetFullyQualifiedName()
233
234	fileName := d.GetFile().GetName()
235	if fileName != ref {
236		dependee := files[ref]
237		if dependee.deps == nil {
238			dependee.deps = map[string]struct{}{}
239		}
240		dependee.deps[fileName] = struct{}{}
241	}
242
243	if !onAdd(d) {
244		// already added this one
245		return
246	}
247
248	fe := files[fileName]
249	if fe == nil {
250		fe = &fileEntry{}
251		fe.proto3 = d.GetFile().IsProto3()
252		files[fileName] = fe
253	}
254	fe.types.addType(name, d.AsProto())
255
256	if md, ok := d.(*desc.MessageDescriptor); ok {
257		for _, fld := range md.GetFields() {
258			if fld.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE || fld.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP {
259				// prefer descriptor in msgs map over what the field descriptor indicates
260				md := msgs[fld.GetMessageType().GetFullyQualifiedName()]
261				if md == nil {
262					md = fld.GetMessageType()
263				}
264				addDescriptors(fileName, files, md, msgs, onAdd)
265			} else if fld.GetType() == descriptor.FieldDescriptorProto_TYPE_ENUM {
266				addDescriptors(fileName, files, fld.GetEnumType(), msgs, onAdd)
267			}
268		}
269	}
270}
271
// resolutionContext provides the state for a resolution operation, accumulating details about
// type descriptions and the files that contain them.
type resolutionContext struct {
	// The context and cancel function, used to coordinate multiple goroutines when there are multiple
	// type or enum descriptions to download. The first real failure cancels ctx so that other
	// in-flight fetches stop early.
	ctx    context.Context
	cancel func()
	// res is the resolver driving this operation; its fetcher and registry are used to find types.
	res    *typeResolver

	// mu guards files, typeLocations, and unknownCount, which are updated by the concurrent
	// goroutines started in addType.
	mu sync.Mutex
	// map of file names to details regarding the files' contents
	files map[string]*fileEntry
	// map of type URLs to the file name that defines them
	typeLocations map[string]string
	// count of source contexts that do not indicate a file name (used to generate unique file names
	// when synthesizing file descriptors)
	unknownCount int
}
290
291func newResolutionContext(res *typeResolver) *resolutionContext {
292	ctx, cancel := context.WithCancel(context.Background())
293	return &resolutionContext{
294		ctx:           ctx,
295		cancel:        cancel,
296		res:           res,
297		typeLocations: map[string]string{},
298		files:         map[string]*fileEntry{},
299	}
300}
301
302// addType adds the type at the given URL to the context, using the given fetcher to download the type's
303// description. This function will recursively add dependencies (e.g. types referenced by the given type's
304// fields if it is a message type), fetching their type descriptions concurrently.
305func (rc *resolutionContext) addType(url string, enum bool) error {
306	if err := rc.ctx.Err(); err != nil {
307		return err
308	}
309
310	m, err := rc.res.fetcher(url, enum)
311	if err != nil {
312		return err
313	} else if m == nil {
314		return fmt.Errorf("failed to locate type for %s", url)
315	}
316
317	if enum {
318		rc.recordEnum(url, m.(*ptype.Enum))
319		return nil
320	}
321
322	// for messages, resolve dependencies in parallel
323	t := m.(*ptype.Type)
324	fe, fileName := rc.recordType(url, t)
325	if fe == nil {
326		// already resolved this one
327		return nil
328	}
329
330	var wg sync.WaitGroup
331	var failed int32
332	for _, f := range t.Fields {
333		if f.Kind == ptype.Field_TYPE_GROUP || f.Kind == ptype.Field_TYPE_MESSAGE || f.Kind == ptype.Field_TYPE_ENUM {
334			typeUrl := ensureScheme(f.TypeUrl)
335			kind := f.Kind
336			wg.Add(1)
337			go func() {
338				defer wg.Done()
339				// first check the registry for descriptors
340				var d desc.Descriptor
341				var innerErr error
342				if kind == ptype.Field_TYPE_ENUM {
343					var ed *desc.EnumDescriptor
344					ed, innerErr = rc.res.mr.getRegisteredEnumTypeByUrl(typeUrl)
345					if ed != nil {
346						d = ed
347					}
348				} else {
349					var md *desc.MessageDescriptor
350					md, innerErr = rc.res.mr.getRegisteredMessageTypeByUrl(typeUrl)
351					if md != nil {
352						d = md
353					}
354				}
355
356				if innerErr == nil {
357					if d != nil {
358						// found it!
359						rc.recordDescriptor(typeUrl, fileName, d)
360					} else {
361						// not in registry, so we have to recursively fetch
362						innerErr = rc.addType(typeUrl, kind == ptype.Field_TYPE_ENUM)
363					}
364				}
365
366				// We want the "real" error to ultimately propagate to root, not
367				// one of the resulting cancellations (from any concurrent goroutines
368				// working in the same resolution context).
369				if innerErr != nil && (rc.ctx.Err() == nil || innerErr != context.Canceled) {
370					if atomic.CompareAndSwapInt32(&failed, 0, 1) {
371						err = innerErr
372					}
373					rc.cancel()
374				}
375			}()
376		}
377	}
378	wg.Wait()
379	if err != nil {
380		return err
381	}
382	// double-check if context has been cancelled
383	if err = rc.ctx.Err(); err != nil {
384		return err
385	}
386
387	rc.mu.Lock()
388	defer rc.mu.Unlock()
389
390	for _, f := range t.Fields {
391		if f.Kind == ptype.Field_TYPE_GROUP || f.Kind == ptype.Field_TYPE_MESSAGE || f.Kind == ptype.Field_TYPE_ENUM {
392			typeUrl := ensureScheme(f.TypeUrl)
393			if fe.deps == nil {
394				fe.deps = map[string]struct{}{}
395			}
396			dep := rc.typeLocations[typeUrl]
397			if dep != fileName {
398				fe.deps[dep] = struct{}{}
399			}
400		}
401	}
402	return nil
403}
404
405func (rc *resolutionContext) recordEnum(url string, e *ptype.Enum) {
406	rc.mu.Lock()
407	defer rc.mu.Unlock()
408
409	var fileName string
410	if e.SourceContext != nil && e.SourceContext.FileName != "" {
411		fileName = e.SourceContext.FileName
412	} else {
413		fileName = fmt.Sprintf("--unknown--%d.proto", rc.unknownCount)
414		rc.unknownCount++
415	}
416	rc.typeLocations[url] = fileName
417
418	fe := rc.files[fileName]
419	if fe == nil {
420		fe = &fileEntry{}
421		rc.files[fileName] = fe
422	}
423	fe.types.addType(e.Name, e)
424	if e.Syntax == ptype.Syntax_SYNTAX_PROTO3 {
425		fe.proto3 = true
426	}
427}
428
429func (rc *resolutionContext) recordType(url string, t *ptype.Type) (*fileEntry, string) {
430	rc.mu.Lock()
431	defer rc.mu.Unlock()
432
433	if _, ok := rc.typeLocations[url]; ok {
434		return nil, ""
435	}
436
437	var fileName string
438	if t.SourceContext != nil && t.SourceContext.FileName != "" {
439		fileName = t.SourceContext.FileName
440	} else {
441		fileName = fmt.Sprintf("--unknown--%d.proto", rc.unknownCount)
442		rc.unknownCount++
443	}
444	rc.typeLocations[url] = fileName
445
446	fe := rc.files[fileName]
447	if fe == nil {
448		fe = &fileEntry{}
449		rc.files[fileName] = fe
450	}
451	fe.types.addType(t.Name, t)
452	if t.Syntax == ptype.Syntax_SYNTAX_PROTO3 {
453		fe.proto3 = true
454	}
455
456	return fe, fileName
457}
458
459func (rc *resolutionContext) recordDescriptor(url, ref string, d desc.Descriptor) {
460	rc.mu.Lock()
461	defer rc.mu.Unlock()
462
463	addDescriptors(ref, rc.files, d, nil, func(dsc desc.Descriptor) bool {
464		u := ensureScheme(rc.res.mr.ComputeUrl(dsc))
465		if _, ok := rc.typeLocations[u]; ok {
466			// already seen this one
467			return false
468		}
469		fileName := dsc.GetFile().GetName()
470		rc.typeLocations[u] = fileName
471		if dsc == d {
472			// make sure we're also adding the actual URL reference used
473			rc.typeLocations[url] = fileName
474		}
475		return true
476	})
477}
478
479// toFileDescriptors converts the information in the context into a map of file names to file descriptors.
480func (rc *resolutionContext) toFileDescriptors(mr *MessageRegistry) (map[string]*desc.FileDescriptor, error) {
481	return toFileDescriptors(rc.files, func(tt *typeTrie, name string) (proto.Message, error) {
482		mdp, edp := tt.ptypeToDescriptor(name, mr)
483		if mdp != nil {
484			return mdp, nil
485		} else {
486			return edp, nil
487		}
488	})
489}
490
491// converts a map of file entries into a map of file descriptors using the given function to convert
492// each trie node into a descriptor proto.
493func toFileDescriptors(files map[string]*fileEntry, trieFn func(*typeTrie, string) (proto.Message, error)) (map[string]*desc.FileDescriptor, error) {
494	fdps := map[string]*descriptor.FileDescriptorProto{}
495	for name, file := range files {
496		fdp, err := file.toFileDescriptor(name, trieFn)
497		if err != nil {
498			return nil, err
499		}
500		fdps[name] = fdp
501	}
502	fds := map[string]*desc.FileDescriptor{}
503	for name, fdp := range fdps {
504		if _, ok := fds[name]; ok {
505			continue
506		}
507		var err error
508		if fds[name], err = makeFileDesc(fdp, fds, fdps); err != nil {
509			return nil, err
510		}
511	}
512	return fds, nil
513}
514
515func makeFileDesc(fdp *descriptor.FileDescriptorProto, fds map[string]*desc.FileDescriptor, fdps map[string]*descriptor.FileDescriptorProto) (*desc.FileDescriptor, error) {
516	deps := make([]*desc.FileDescriptor, len(fdp.Dependency))
517	for i, dep := range fdp.Dependency {
518		d := fds[dep]
519		if d == nil {
520			var err error
521			depFd := fdps[dep]
522			if depFd == nil {
523				return nil, fmt.Errorf("missing dependency: %s", dep)
524			}
525			d, err = makeFileDesc(depFd, fds, fdps)
526			if err != nil {
527				return nil, err
528			}
529		}
530		deps[i] = d
531	}
532	if fd, err := desc.CreateFileDescriptor(fdp, deps...); err != nil {
533		return nil, err
534	} else {
535		fds[fdp.GetName()] = fd
536		return fd, nil
537	}
538}
539
// fileEntry represents the contents of a single file.
type fileEntry struct {
	// types holds the file's elements in a trie keyed by the dot-separated components of
	// their fully-qualified names.
	types  typeTrie
	// deps is the set of names of other files this file depends on (its imports).
	deps   map[string]struct{}
	// proto3 selects proto3 (rather than proto2) syntax when the file descriptor is created.
	proto3 bool
}
546
547// toFileDescriptor converts this file entry into a file descriptor proto. The given function
548// is used to transform nodes in a typeTrie into message and/or enum descriptor protos.
549func (fe *fileEntry) toFileDescriptor(name string, trieFn func(*typeTrie, string) (proto.Message, error)) (*descriptor.FileDescriptorProto, error) {
550	var pkg bytes.Buffer
551	tt := &fe.types
552	first := true
553	last := ""
554	for tt.typ == nil {
555		if last != "" {
556			if first {
557				first = false
558			} else {
559				pkg.WriteByte('.')
560			}
561			pkg.WriteString(last)
562		}
563		if len(tt.children) != 1 {
564			break
565		}
566		for last, tt = range tt.children {
567		}
568	}
569	fd := createFileDescriptor(name, pkg.String(), fe.proto3, fe.deps)
570	if tt.typ != nil {
571		pm, err := trieFn(tt, last)
572		if err != nil {
573			return nil, err
574		}
575		if mdp, ok := pm.(*descriptor.DescriptorProto); ok {
576			fd.MessageType = append(fd.MessageType, mdp)
577		} else if edp, ok := pm.(*descriptor.EnumDescriptorProto); ok {
578			fd.EnumType = append(fd.EnumType, edp)
579		} else {
580			sdp := pm.(*descriptor.ServiceDescriptorProto)
581			fd.Service = append(fd.Service, sdp)
582		}
583	} else {
584		for name, nested := range tt.children {
585			pm, err := trieFn(nested, name)
586			if err != nil {
587				return nil, err
588			}
589			if mdp, ok := pm.(*descriptor.DescriptorProto); ok {
590				fd.MessageType = append(fd.MessageType, mdp)
591			} else if edp, ok := pm.(*descriptor.EnumDescriptorProto); ok {
592				fd.EnumType = append(fd.EnumType, edp)
593			} else {
594				sdp := pm.(*descriptor.ServiceDescriptorProto)
595				fd.Service = append(fd.Service, sdp)
596			}
597		}
598	}
599	return fd, nil
600}
601
// typeTrie is a prefix trie where each key component is part of a fully-qualified type name. So key components
// will either be package name components or element names.
type typeTrie struct {
	// successor key components, keyed by the next dot-separated component of the name
	children map[string]*typeTrie
	// if non-nil, the element whose fully-qualified name is the path from the trie root to this node
	// (a *ptype.Type, *ptype.Enum, or a descriptor proto, depending on how the trie was populated)
	typ proto.Message
}
610
611// addType recursively adds an element to the trie.
612func (t *typeTrie) addType(key string, typ proto.Message) {
613	if key == "" {
614		t.typ = typ
615		return
616	}
617	if t.children == nil {
618		t.children = map[string]*typeTrie{}
619	}
620	curr, rest := split(key)
621	child := t.children[curr]
622	if child == nil {
623		child = &typeTrie{}
624		t.children[curr] = child
625	}
626	child.addType(rest, typ)
627}
628
629// ptypeToDescriptor converts this level of the trie into a message or enum
630// descriptor proto, requiring that the element stored in t.typ is a *ptype.Type
631// or *ptype.Enum. If t.typ is nil, a placeholder message (with no fields) is
632// returned that contains the trie's children as nested message and/or enum
633// types.
634//
635// If the value in t.typ is already a *descriptor.DescriptorProto or a
636// *descriptor.EnumDescriptorProto then it is returned as is. This function
637// should not be used in type tries that may have service descriptors. That will
638// result in a panic.
639func (t *typeTrie) ptypeToDescriptor(name string, mr *MessageRegistry) (*descriptor.DescriptorProto, *descriptor.EnumDescriptorProto) {
640	switch typ := t.typ.(type) {
641	case *descriptor.EnumDescriptorProto:
642		return nil, typ
643	case *ptype.Enum:
644		return nil, createEnumDescriptor(typ, mr)
645	case *descriptor.DescriptorProto:
646		return typ, nil
647	default:
648		var msg *descriptor.DescriptorProto
649		if t.typ == nil {
650			msg = createIntermediateMessageDescriptor(name)
651		} else {
652			msg = createMessageDescriptor(t.typ.(*ptype.Type), mr)
653		}
654		// sort children for deterministic output
655		var keys []string
656		for k := range t.children {
657			keys = append(keys, k)
658		}
659		for _, name := range keys {
660			nested := t.children[name]
661			chMsg, chEnum := nested.ptypeToDescriptor(name, mr)
662			if chMsg != nil {
663				msg.NestedType = append(msg.NestedType, chMsg)
664			}
665			if chEnum != nil {
666				msg.EnumType = append(msg.EnumType, chEnum)
667			}
668		}
669		return msg, nil
670	}
671}
672
673// rewriteDescriptor converts this level of the trie into a new descriptor
674// proto, requiring that the element stored in t.type is already a service,
675// message, or enum descriptor proto. If this trie has children then t.typ must
676// be a message descriptor proto. The returned descriptor proto is the same as
677// .type but with possibly new nested elements to represent this trie node's
678// children.
679func (t *typeTrie) rewriteDescriptor(name string) (proto.Message, error) {
680	if len(t.children) == 0 && t.typ != nil {
681		if mdp, ok := t.typ.(*descriptor.DescriptorProto); ok {
682			if len(mdp.NestedType) == 0 && len(mdp.EnumType) == 0 {
683				return mdp, nil
684			}
685			mdp = proto.Clone(mdp).(*descriptor.DescriptorProto)
686			mdp.NestedType = nil
687			mdp.EnumType = nil
688			return mdp, nil
689		}
690		return t.typ, nil
691	}
692	var mdp *descriptor.DescriptorProto
693	if t.typ == nil {
694		mdp = createIntermediateMessageDescriptor(name)
695	} else {
696		mdp = t.typ.(*descriptor.DescriptorProto)
697		mdp = proto.Clone(mdp).(*descriptor.DescriptorProto)
698		mdp.NestedType = nil
699		mdp.EnumType = nil
700	}
701	// sort children for deterministic output
702	var keys []string
703	for k := range t.children {
704		keys = append(keys, k)
705	}
706	for _, n := range keys {
707		ch := t.children[n]
708		typ, err := ch.rewriteDescriptor(n)
709		if err != nil {
710			return nil, err
711		}
712		switch typ := typ.(type) {
713		case (*descriptor.DescriptorProto):
714			mdp.NestedType = append(mdp.NestedType, typ)
715		case (*descriptor.EnumDescriptorProto):
716			mdp.EnumType = append(mdp.EnumType, typ)
717		default:
718			// TODO: this should probably panic instead
719			return nil, fmt.Errorf("invalid descriptor trie: message cannot have child of type %v", reflect.TypeOf(typ))
720		}
721	}
722	return mdp, nil
723}
724
// split separates s at its first '.' into the leading component and the remainder.
// If s contains no '.', the whole string is the leading component and the remainder
// is empty.
func split(s string) (string, string) {
	if dot := strings.IndexByte(s, '.'); dot >= 0 {
		return s[:dot], s[dot+1:]
	}
	return s, ""
}
733
734func createEnumDescriptor(e *ptype.Enum, mr *MessageRegistry) *descriptor.EnumDescriptorProto {
735	var opts *descriptor.EnumOptions
736	if len(e.Options) > 0 {
737		dopts := createOptions(e.Options, enumOptionsDesc, mr)
738		opts = &descriptor.EnumOptions{}
739		dopts.ConvertTo(opts) // ignore any error
740	}
741
742	var vals []*descriptor.EnumValueDescriptorProto
743	for _, v := range e.Enumvalue {
744		evd := createEnumValueDescriptor(v, mr)
745		vals = append(vals, evd)
746	}
747
748	return &descriptor.EnumDescriptorProto{
749		Name:    proto.String(base(e.Name)),
750		Options: opts,
751		Value:   vals,
752	}
753}
754
755func createEnumValueDescriptor(v *ptype.EnumValue, mr *MessageRegistry) *descriptor.EnumValueDescriptorProto {
756	var opts *descriptor.EnumValueOptions
757	if len(v.Options) > 0 {
758		dopts := createOptions(v.Options, enumValueOptionsDesc, mr)
759		opts = &descriptor.EnumValueOptions{}
760		dopts.ConvertTo(opts) // ignore any error
761	}
762
763	return &descriptor.EnumValueDescriptorProto{
764		Name:    proto.String(v.Name),
765		Number:  proto.Int32(v.Number),
766		Options: opts,
767	}
768}
769
770func createMessageDescriptor(m *ptype.Type, mr *MessageRegistry) *descriptor.DescriptorProto {
771	var opts *descriptor.MessageOptions
772	if len(m.Options) > 0 {
773		dopts := createOptions(m.Options, msgOptionsDesc, mr)
774		opts = &descriptor.MessageOptions{}
775		dopts.ConvertTo(opts) // ignore any error
776	}
777
778	var fields []*descriptor.FieldDescriptorProto
779	for _, f := range m.Fields {
780		fields = append(fields, createFieldDescriptor(f, mr))
781	}
782
783	var oneOfs []*descriptor.OneofDescriptorProto
784	for _, o := range m.Oneofs {
785		oneOfs = append(oneOfs, &descriptor.OneofDescriptorProto{
786			Name: proto.String(o),
787		})
788	}
789
790	return &descriptor.DescriptorProto{
791		Name:      proto.String(base(m.Name)),
792		Options:   opts,
793		Field:     fields,
794		OneofDecl: oneOfs,
795	}
796}
797
798func createFieldDescriptor(f *ptype.Field, mr *MessageRegistry) *descriptor.FieldDescriptorProto {
799	var opts *descriptor.FieldOptions
800	if len(f.Options) > 0 {
801		dopts := createOptions(f.Options, fieldOptionsDesc, mr)
802		opts = &descriptor.FieldOptions{}
803		dopts.ConvertTo(opts) // ignore any error
804	}
805	if f.Packed {
806		if opts == nil {
807			opts = &descriptor.FieldOptions{Packed: proto.Bool(true)}
808		} else {
809			opts.Packed = proto.Bool(true)
810		}
811	}
812
813	var oneOf *int32
814	if f.OneofIndex > 0 {
815		oneOf = proto.Int32(f.OneofIndex - 1)
816	}
817
818	var typeName string
819	if f.Kind == ptype.Field_TYPE_GROUP || f.Kind == ptype.Field_TYPE_MESSAGE || f.Kind == ptype.Field_TYPE_ENUM {
820		pos := strings.LastIndex(f.TypeUrl, "/")
821		typeName = "." + f.TypeUrl[pos+1:]
822	}
823
824	var label descriptor.FieldDescriptorProto_Label
825	switch f.Cardinality {
826	case ptype.Field_CARDINALITY_OPTIONAL:
827		label = descriptor.FieldDescriptorProto_LABEL_OPTIONAL
828	case ptype.Field_CARDINALITY_REPEATED:
829		label = descriptor.FieldDescriptorProto_LABEL_REPEATED
830	case ptype.Field_CARDINALITY_REQUIRED:
831		label = descriptor.FieldDescriptorProto_LABEL_REQUIRED
832	}
833
834	var typ descriptor.FieldDescriptorProto_Type
835	switch f.Kind {
836	case ptype.Field_TYPE_ENUM:
837		typ = descriptor.FieldDescriptorProto_TYPE_ENUM
838	case ptype.Field_TYPE_GROUP:
839		typ = descriptor.FieldDescriptorProto_TYPE_GROUP
840	case ptype.Field_TYPE_MESSAGE:
841		typ = descriptor.FieldDescriptorProto_TYPE_MESSAGE
842	case ptype.Field_TYPE_BYTES:
843		typ = descriptor.FieldDescriptorProto_TYPE_BYTES
844	case ptype.Field_TYPE_STRING:
845		typ = descriptor.FieldDescriptorProto_TYPE_STRING
846	case ptype.Field_TYPE_BOOL:
847		typ = descriptor.FieldDescriptorProto_TYPE_BOOL
848	case ptype.Field_TYPE_DOUBLE:
849		typ = descriptor.FieldDescriptorProto_TYPE_DOUBLE
850	case ptype.Field_TYPE_FLOAT:
851		typ = descriptor.FieldDescriptorProto_TYPE_FLOAT
852	case ptype.Field_TYPE_FIXED32:
853		typ = descriptor.FieldDescriptorProto_TYPE_FIXED32
854	case ptype.Field_TYPE_FIXED64:
855		typ = descriptor.FieldDescriptorProto_TYPE_FIXED64
856	case ptype.Field_TYPE_INT32:
857		typ = descriptor.FieldDescriptorProto_TYPE_INT32
858	case ptype.Field_TYPE_INT64:
859		typ = descriptor.FieldDescriptorProto_TYPE_INT64
860	case ptype.Field_TYPE_SFIXED32:
861		typ = descriptor.FieldDescriptorProto_TYPE_SFIXED32
862	case ptype.Field_TYPE_SFIXED64:
863		typ = descriptor.FieldDescriptorProto_TYPE_SFIXED64
864	case ptype.Field_TYPE_SINT32:
865		typ = descriptor.FieldDescriptorProto_TYPE_SINT32
866	case ptype.Field_TYPE_SINT64:
867		typ = descriptor.FieldDescriptorProto_TYPE_SINT64
868	case ptype.Field_TYPE_UINT32:
869		typ = descriptor.FieldDescriptorProto_TYPE_UINT32
870	case ptype.Field_TYPE_UINT64:
871		typ = descriptor.FieldDescriptorProto_TYPE_UINT64
872	}
873
874	return &descriptor.FieldDescriptorProto{
875		Name:         proto.String(f.Name),
876		Number:       proto.Int32(f.Number),
877		DefaultValue: proto.String(f.DefaultValue),
878		JsonName:     proto.String(f.JsonName),
879		OneofIndex:   oneOf,
880		TypeName:     proto.String(typeName),
881		Label:        label.Enum(),
882		Type:         typ.Enum(),
883		Options:      opts,
884	}
885}
886
887func createServiceDescriptor(a *api.Api, mr *MessageRegistry) *descriptor.ServiceDescriptorProto {
888	var opts *descriptor.ServiceOptions
889	if len(a.Options) > 0 {
890		dopts := createOptions(a.Options, svcOptionsDesc, mr)
891		opts = &descriptor.ServiceOptions{}
892		dopts.ConvertTo(opts) // ignore any error
893	}
894
895	methods := make([]*descriptor.MethodDescriptorProto, len(a.Methods))
896	for i, m := range a.Methods {
897		methods[i] = createMethodDescriptor(m, mr)
898	}
899
900	return &descriptor.ServiceDescriptorProto{
901		Name:    proto.String(base(a.Name)),
902		Method:  methods,
903		Options: opts,
904	}
905}
906
907func createMethodDescriptor(m *api.Method, mr *MessageRegistry) *descriptor.MethodDescriptorProto {
908	var opts *descriptor.MethodOptions
909	if len(m.Options) > 0 {
910		dopts := createOptions(m.Options, methodOptionsDesc, mr)
911		opts = &descriptor.MethodOptions{}
912		dopts.ConvertTo(opts) // ignore any error
913	}
914
915	var reqType, respType string
916	pos := strings.LastIndex(m.RequestTypeUrl, "/")
917	reqType = "." + m.RequestTypeUrl[pos+1:]
918	pos = strings.LastIndex(m.ResponseTypeUrl, "/")
919	respType = "." + m.ResponseTypeUrl[pos+1:]
920
921	return &descriptor.MethodDescriptorProto{
922		Name:            proto.String(m.Name),
923		Options:         opts,
924		ClientStreaming: proto.Bool(m.RequestStreaming),
925		ServerStreaming: proto.Bool(m.ResponseStreaming),
926		InputType:       proto.String(reqType),
927		OutputType:      proto.String(respType),
928	}
929}
930
931func createIntermediateMessageDescriptor(name string) *descriptor.DescriptorProto {
932	return &descriptor.DescriptorProto{
933		Name: proto.String(name),
934	}
935}
936
937func createFileDescriptor(name, pkg string, proto3 bool, deps map[string]struct{}) *descriptor.FileDescriptorProto {
938	imports := make([]string, 0, len(deps))
939	for k := range deps {
940		imports = append(imports, k)
941	}
942	sort.Strings(imports)
943	var syntax string
944	if proto3 {
945		syntax = "proto3"
946	} else {
947		syntax = "proto2"
948	}
949	return &descriptor.FileDescriptorProto{
950		Name:       proto.String(name),
951		Package:    proto.String(pkg),
952		Syntax:     proto.String(syntax),
953		Dependency: imports,
954	}
955}
956
957func createOptions(options []*ptype.Option, optionsDesc *desc.MessageDescriptor, mr *MessageRegistry) *dynamic.Message {
958	// these are created "best effort" so entries which are unresolvable
959	// (or seemingly invalid) are simply ignored...
960	dopts := mr.mf.NewDynamicMessage(optionsDesc)
961	for _, o := range options {
962		field := optionsDesc.FindFieldByName(o.Name)
963		if field == nil {
964			field = mr.er.FindExtensionByName(optionsDesc.GetFullyQualifiedName(), o.Name)
965			if field == nil && o.Name[0] != '[' {
966				field = mr.er.FindExtensionByName(optionsDesc.GetFullyQualifiedName(), fmt.Sprintf("[%s]", o.Name))
967			}
968			if field == nil {
969				// can't resolve option name? skip it
970				continue
971			}
972		}
973		v, err := mr.unmarshalAny(o.Value, func(url string) (*desc.MessageDescriptor, error) {
974			// we don't want to try to recursively fetch this value's type, so if it doesn't
975			// match the type of the extension field, we'll skip it
976			if (field.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP ||
977				field.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE) &&
978				typeName(url) == field.GetMessageType().GetFullyQualifiedName() {
979
980				return field.GetMessageType(), nil
981			}
982			return nil, nil
983		})
984		if err != nil {
985			// can't interpret value? skip it
986			continue
987		}
988		var fv interface{}
989		if field.GetType() != descriptor.FieldDescriptorProto_TYPE_MESSAGE && field.GetType() != descriptor.FieldDescriptorProto_TYPE_GROUP {
990			fv = unwrap(v)
991			if v == nil {
992				// non-wrapper type for scalar field? skip it
993				continue
994			}
995		} else {
996			fv = v
997		}
998		if field.IsRepeated() {
999			dopts.TryAddRepeatedField(field, fv) // ignore any error
1000		} else {
1001			dopts.TrySetField(field, fv) // ignore any error
1002		}
1003	}
1004	return dopts
1005}
1006
// base returns the last dot-separated component of a fully-qualified name, or the whole name
// when it contains no '.'.
func base(name string) string {
	// LastIndexByte yields -1 when there is no dot, so the slice starts at 0
	dot := strings.LastIndexByte(name, '.')
	return name[dot+1:]
}
1014
1015func unwrap(msg proto.Message) interface{} {
1016	switch m := msg.(type) {
1017	case (*wrappers.BoolValue):
1018		return m.Value
1019	case (*wrappers.FloatValue):
1020		return m.Value
1021	case (*wrappers.DoubleValue):
1022		return m.Value
1023	case (*wrappers.Int32Value):
1024		return m.Value
1025	case (*wrappers.Int64Value):
1026		return m.Value
1027	case (*wrappers.UInt32Value):
1028		return m.Value
1029	case (*wrappers.UInt64Value):
1030		return m.Value
1031	case (*wrappers.BytesValue):
1032		return m.Value
1033	case (*wrappers.StringValue):
1034		return m.Value
1035	default:
1036		return nil
1037	}
1038}
1039
// typeName extracts the fully-qualified type name from a type URL: everything after the last
// '/'. A URL with no '/' is returned unchanged.
func typeName(url string) string {
	if slash := strings.LastIndexByte(url, '/'); slash >= 0 {
		return url[slash+1:]
	}
	return url
}
1043