1package grpcreflect
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7	"reflect"
8	"runtime"
9	"sync"
10
11	"github.com/golang/protobuf/proto"
12	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
13	"golang.org/x/net/context"
14	"google.golang.org/grpc/codes"
15	rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
16	"google.golang.org/grpc/status"
17
18	"github.com/jhump/protoreflect/desc"
19	"github.com/jhump/protoreflect/internal"
20)
21
// elementNotFoundError is the error produced by reflective operations when the
// server does not know about a requested file name, symbol name, or extension.
type elementNotFoundError struct {
	name    string
	kind    elementKind
	symType symbolType // meaningful only when kind == elementKindSymbol
	tag     int32      // meaningful only when kind == elementKindExtension

	// cause, when non-nil, is another not-found error explaining why this
	// element could not be resolved (for example, a file whose dependency
	// could not be found). Error renders the whole chain.
	cause *elementNotFoundError
}

// elementKind identifies what sort of element an elementNotFoundError names.
type elementKind int

const (
	elementKindSymbol elementKind = iota
	elementKindFile
	elementKindExtension
)

// symbolType is the human-readable label used as the prefix of the message for
// symbol-kind errors.
type symbolType string

const (
	symbolTypeService = "Service"
	symbolTypeMessage = "Message"
	symbolTypeEnum    = "Enum"
	symbolTypeUnknown = "Symbol"
)

// symbolNotFound builds a not-found error for the given symbol name, labelled
// with symType and optionally chained to cause.
func symbolNotFound(symbol string, symType symbolType, cause *elementNotFoundError) error {
	return &elementNotFoundError{
		kind:    elementKindSymbol,
		name:    symbol,
		symType: symType,
		cause:   cause,
	}
}

// extensionNotFound builds a not-found error for extension number tag of the
// message named extendee, optionally chained to cause.
func extensionNotFound(extendee string, tag int32, cause *elementNotFoundError) error {
	return &elementNotFoundError{
		kind:  elementKindExtension,
		name:  extendee,
		tag:   tag,
		cause: cause,
	}
}

// fileNotFound builds a not-found error for the named file, optionally chained
// to cause.
func fileNotFound(file string, cause *elementNotFoundError) error {
	return &elementNotFoundError{
		kind:  elementKindFile,
		name:  file,
		cause: cause,
	}
}

// Error formats this error followed by every error in its cause chain, one per
// line, with each cause introduced by "caused by: ". A nil receiver yields the
// empty string.
func (e *elementNotFoundError) Error() string {
	var b bytes.Buffer
	for cur := e; cur != nil; cur = cur.cause {
		// every entry writes at least its fixed prefix, so a non-empty
		// buffer means this is not the first entry in the chain
		if b.Len() > 0 {
			b.WriteString("\ncaused by: ")
		}
		switch cur.kind {
		case elementKindExtension:
			fmt.Fprintf(&b, "Extension not found: tag %d for %s", cur.tag, cur.name)
		case elementKindSymbol:
			fmt.Fprintf(&b, "%s not found: %s", cur.symType, cur.name)
		default:
			fmt.Fprintf(&b, "File not found: %s", cur.name)
		}
	}
	return b.String()
}
85
86// IsElementNotFoundError determines if the given error indicates that a file
87// name, symbol name, or extension field was could not be found by the server.
88func IsElementNotFoundError(err error) bool {
89	_, ok := err.(*elementNotFoundError)
90	return ok
91}
92
// ProtocolError is an error returned when the server sends a response of the
// wrong type.
type ProtocolError struct {
	// missingType is the response type that was expected but not present.
	missingType reflect.Type
}

// Error implements the error interface, naming the response type that was
// missing from the server's reply.
func (p ProtocolError) Error() string {
	const format = "Protocol error: response was missing %v"
	return fmt.Sprintf(format, p.missingType)
}
102
// extDesc identifies an extension by the fully-qualified name of the message it
// extends and its field number. It is used as the key of the Client's
// by-extension file cache.
type extDesc struct {
	extendedMessageName string
	extensionNumber     int32
}
107
// Client is a client connection to a server for performing reflection calls
// and resolving remote symbols.
type Client struct {
	// ctx is the root context; each stream is opened with a cancellable
	// context derived from it.
	ctx  context.Context
	stub rpb.ServerReflectionClient

	// connMu is held while sending/receiving on the stream and while
	// resetting it; it guards cancel and stream.
	connMu sync.Mutex
	cancel context.CancelFunc
	stream rpb.ServerReflection_ServerReflectionInfoClient

	// cacheMu guards the four caches below.
	cacheMu sync.RWMutex
	// protosByName holds raw file descriptor protos received from the
	// server, keyed by file name.
	protosByName map[string]*dpb.FileDescriptorProto
	// filesByName holds fully built file descriptors, keyed by file name.
	filesByName map[string]*desc.FileDescriptor
	// filesBySymbol maps the fully-qualified name of each element declared
	// in a cached file to that file.
	filesBySymbol map[string]*desc.FileDescriptor
	// filesByExtension maps (extended message, extension number) to the
	// cached file that declares the extension.
	filesByExtension map[extDesc]*desc.FileDescriptor
}
124
125// NewClient creates a new Client with the given root context and using the
126// given RPC stub for talking to the server.
127func NewClient(ctx context.Context, stub rpb.ServerReflectionClient) *Client {
128	cr := &Client{
129		ctx:              ctx,
130		stub:             stub,
131		protosByName:     map[string]*dpb.FileDescriptorProto{},
132		filesByName:      map[string]*desc.FileDescriptor{},
133		filesBySymbol:    map[string]*desc.FileDescriptor{},
134		filesByExtension: map[extDesc]*desc.FileDescriptor{},
135	}
136	// don't leak a grpc stream
137	runtime.SetFinalizer(cr, (*Client).Reset)
138	return cr
139}
140
141// FileByFilename asks the server for a file descriptor for the proto file with
142// the given name.
143func (cr *Client) FileByFilename(filename string) (*desc.FileDescriptor, error) {
144	// hit the cache first
145	cr.cacheMu.RLock()
146	if fd, ok := cr.filesByName[filename]; ok {
147		cr.cacheMu.RUnlock()
148		return fd, nil
149	}
150	fdp, ok := cr.protosByName[filename]
151	cr.cacheMu.RUnlock()
152	// not there? see if we've downloaded the proto
153	if ok {
154		return cr.descriptorFromProto(fdp)
155	}
156
157	req := &rpb.ServerReflectionRequest{
158		MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
159			FileByFilename: filename,
160		},
161	}
162	fd, err := cr.getAndCacheFileDescriptors(req, filename, "")
163	if isNotFound(err) {
164		// file not found? see if we can look up via alternate name
165		if alternate, ok := internal.StdFileAliases[filename]; ok {
166			req := &rpb.ServerReflectionRequest{
167				MessageRequest: &rpb.ServerReflectionRequest_FileByFilename{
168					FileByFilename: alternate,
169				},
170			}
171			fd, err = cr.getAndCacheFileDescriptors(req, alternate, filename)
172			if isNotFound(err) {
173				err = fileNotFound(filename, nil)
174			}
175		} else {
176			err = fileNotFound(filename, nil)
177		}
178	} else if e, ok := err.(*elementNotFoundError); ok {
179		err = fileNotFound(filename, e)
180	}
181	return fd, err
182}
183
184// FileContainingSymbol asks the server for a file descriptor for the proto file
185// that declares the given fully-qualified symbol.
186func (cr *Client) FileContainingSymbol(symbol string) (*desc.FileDescriptor, error) {
187	// hit the cache first
188	cr.cacheMu.RLock()
189	fd, ok := cr.filesBySymbol[symbol]
190	cr.cacheMu.RUnlock()
191	if ok {
192		return fd, nil
193	}
194
195	req := &rpb.ServerReflectionRequest{
196		MessageRequest: &rpb.ServerReflectionRequest_FileContainingSymbol{
197			FileContainingSymbol: symbol,
198		},
199	}
200	fd, err := cr.getAndCacheFileDescriptors(req, "", "")
201	if isNotFound(err) {
202		err = symbolNotFound(symbol, symbolTypeUnknown, nil)
203	} else if e, ok := err.(*elementNotFoundError); ok {
204		err = symbolNotFound(symbol, symbolTypeUnknown, e)
205	}
206	return fd, err
207}
208
209// FileContainingExtension asks the server for a file descriptor for the proto
210// file that declares an extension with the given number for the given
211// fully-qualified message name.
212func (cr *Client) FileContainingExtension(extendedMessageName string, extensionNumber int32) (*desc.FileDescriptor, error) {
213	// hit the cache first
214	cr.cacheMu.RLock()
215	fd, ok := cr.filesByExtension[extDesc{extendedMessageName, extensionNumber}]
216	cr.cacheMu.RUnlock()
217	if ok {
218		return fd, nil
219	}
220
221	req := &rpb.ServerReflectionRequest{
222		MessageRequest: &rpb.ServerReflectionRequest_FileContainingExtension{
223			FileContainingExtension: &rpb.ExtensionRequest{
224				ContainingType:  extendedMessageName,
225				ExtensionNumber: extensionNumber,
226			},
227		},
228	}
229	fd, err := cr.getAndCacheFileDescriptors(req, "", "")
230	if isNotFound(err) {
231		err = extensionNotFound(extendedMessageName, extensionNumber, nil)
232	} else if e, ok := err.(*elementNotFoundError); ok {
233		err = extensionNotFound(extendedMessageName, extensionNumber, e)
234	}
235	return fd, err
236}
237
// getAndCacheFileDescriptors sends req to the server, stores every file
// descriptor proto in the response in the raw-proto cache, and returns a built
// descriptor for the first file in the response.
//
// expectedName and alias are only used together: when both are non-empty and
// differ, a returned file named expectedName is renamed to alias before it is
// cached. Callers that are not looking up a file by name pass "" for both.
//
// Errors from the server are returned as-is. A response that is not a file
// descriptor response, or that contains no files, yields a *ProtocolError.
func (cr *Client) getAndCacheFileDescriptors(req *rpb.ServerReflectionRequest, expectedName, alias string) (*desc.FileDescriptor, error) {
	resp, err := cr.send(req)
	if err != nil {
		return nil, err
	}

	fdResp := resp.GetFileDescriptorResponse()
	if fdResp == nil {
		// fdResp is a typed nil pointer, so TypeOf still yields its pointer
		// type and Elem the message type that was expected
		return nil, &ProtocolError{reflect.TypeOf(fdResp).Elem()}
	}

	// Response can contain the result file descriptor, but also its transitive
	// deps. Furthermore, protocol states that subsequent requests do not need
	// to send transitive deps that have been sent in prior responses. So we
	// need to cache all file descriptors and then return the first one (which
	// should be the answer).
	// NOTE(review): an earlier version of this comment said that lookups by
	// name pick the matching file instead of the first one, but the loop below
	// always uses the first entry; expectedName only drives alias renaming.
	var firstFd *dpb.FileDescriptorProto
	for _, fdBytes := range fdResp.FileDescriptorProto {
		fd := &dpb.FileDescriptorProto{}
		if err = proto.Unmarshal(fdBytes, fd); err != nil {
			return nil, err
		}

		if expectedName != "" && alias != "" && expectedName != alias && fd.GetName() == expectedName {
			// we found a file was aliased, so we need to update the proto to reflect that
			fd.Name = proto.String(alias)
		}

		cr.cacheMu.Lock()
		// see if this file was created and cached concurrently; this is only
		// checked for the first file in the response, and returns right away
		// without caching the remaining entries
		if firstFd == nil {
			if d, ok := cr.filesByName[fd.GetName()]; ok {
				cr.cacheMu.Unlock()
				return d, nil
			}
		}
		// store in cache of raw descriptor protos, but don't overwrite existing protos
		if existingFd, ok := cr.protosByName[fd.GetName()]; ok {
			fd = existingFd
		} else {
			cr.protosByName[fd.GetName()] = fd
		}
		cr.cacheMu.Unlock()
		if firstFd == nil {
			firstFd = fd
		}
	}
	if firstFd == nil {
		// the response carried no file descriptors at all
		return nil, &ProtocolError{reflect.TypeOf(firstFd).Elem()}
	}

	return cr.descriptorFromProto(firstFd)
}
293
294func (cr *Client) descriptorFromProto(fd *dpb.FileDescriptorProto) (*desc.FileDescriptor, error) {
295	deps := make([]*desc.FileDescriptor, len(fd.GetDependency()))
296	for i, depName := range fd.GetDependency() {
297		if dep, err := cr.FileByFilename(depName); err != nil {
298			return nil, err
299		} else {
300			deps[i] = dep
301		}
302	}
303	d, err := desc.CreateFileDescriptor(fd, deps...)
304	if err != nil {
305		return nil, err
306	}
307	d = cr.cacheFile(d)
308	return d, nil
309}
310
311func (cr *Client) cacheFile(fd *desc.FileDescriptor) *desc.FileDescriptor {
312	cr.cacheMu.Lock()
313	defer cr.cacheMu.Unlock()
314
315	// cache file descriptor by name, but don't overwrite existing entry
316	// (existing entry could come from concurrent caller)
317	if existingFd, ok := cr.filesByName[fd.GetName()]; ok {
318		return existingFd
319	}
320	cr.filesByName[fd.GetName()] = fd
321
322	// also cache by symbols and extensions
323	for _, m := range fd.GetMessageTypes() {
324		cr.cacheMessageLocked(fd, m)
325	}
326	for _, e := range fd.GetEnumTypes() {
327		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
328		for _, v := range e.GetValues() {
329			cr.filesBySymbol[v.GetFullyQualifiedName()] = fd
330		}
331	}
332	for _, e := range fd.GetExtensions() {
333		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
334		cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd
335	}
336	for _, s := range fd.GetServices() {
337		cr.filesBySymbol[s.GetFullyQualifiedName()] = fd
338		for _, m := range s.GetMethods() {
339			cr.filesBySymbol[m.GetFullyQualifiedName()] = fd
340		}
341	}
342
343	return fd
344}
345
346func (cr *Client) cacheMessageLocked(fd *desc.FileDescriptor, md *desc.MessageDescriptor) {
347	cr.filesBySymbol[md.GetFullyQualifiedName()] = fd
348	for _, f := range md.GetFields() {
349		cr.filesBySymbol[f.GetFullyQualifiedName()] = fd
350	}
351	for _, o := range md.GetOneOfs() {
352		cr.filesBySymbol[o.GetFullyQualifiedName()] = fd
353	}
354	for _, e := range md.GetNestedEnumTypes() {
355		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
356		for _, v := range e.GetValues() {
357			cr.filesBySymbol[v.GetFullyQualifiedName()] = fd
358		}
359	}
360	for _, e := range md.GetNestedExtensions() {
361		cr.filesBySymbol[e.GetFullyQualifiedName()] = fd
362		cr.filesByExtension[extDesc{e.GetOwner().GetFullyQualifiedName(), e.GetNumber()}] = fd
363	}
364	for _, m := range md.GetNestedMessageTypes() {
365		cr.cacheMessageLocked(fd, m) // recurse
366	}
367}
368
369// AllExtensionNumbersForType asks the server for all known extension numbers
370// for the given fully-qualified message name.
371func (cr *Client) AllExtensionNumbersForType(extendedMessageName string) ([]int32, error) {
372	req := &rpb.ServerReflectionRequest{
373		MessageRequest: &rpb.ServerReflectionRequest_AllExtensionNumbersOfType{
374			AllExtensionNumbersOfType: extendedMessageName,
375		},
376	}
377	resp, err := cr.send(req)
378	if err != nil {
379		if isNotFound(err) {
380			return nil, symbolNotFound(extendedMessageName, symbolTypeMessage, nil)
381		}
382		return nil, err
383	}
384
385	extResp := resp.GetAllExtensionNumbersResponse()
386	if extResp == nil {
387		return nil, &ProtocolError{reflect.TypeOf(extResp).Elem()}
388	}
389	return extResp.ExtensionNumber, nil
390}
391
392// ListServices asks the server for the fully-qualified names of all exposed
393// services.
394func (cr *Client) ListServices() ([]string, error) {
395	req := &rpb.ServerReflectionRequest{
396		MessageRequest: &rpb.ServerReflectionRequest_ListServices{
397			// proto doesn't indicate any purpose for this value and server impl
398			// doesn't actually use it...
399			ListServices: "*",
400		},
401	}
402	resp, err := cr.send(req)
403	if err != nil {
404		return nil, err
405	}
406
407	listResp := resp.GetListServicesResponse()
408	if listResp == nil {
409		return nil, &ProtocolError{reflect.TypeOf(listResp).Elem()}
410	}
411	serviceNames := make([]string, len(listResp.Service))
412	for i, s := range listResp.Service {
413		serviceNames[i] = s.Name
414	}
415	return serviceNames, nil
416}
417
418func (cr *Client) send(req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) {
419	// we allow one immediate retry, in case we have a stale stream
420	// (e.g. closed by server)
421	resp, err := cr.doSend(true, req)
422	if err != nil {
423		return nil, err
424	}
425
426	// convert error response messages into errors
427	errResp := resp.GetErrorResponse()
428	if errResp != nil {
429		return nil, status.Errorf(codes.Code(errResp.ErrorCode), "%s", errResp.ErrorMessage)
430	}
431
432	return resp, nil
433}
434
435func isNotFound(err error) bool {
436	if err == nil {
437		return false
438	}
439	s, ok := status.FromError(err)
440	return ok && s.Code() == codes.NotFound
441}
442
// doSend performs one request/response exchange on the stream while holding
// connMu, so that exchanges issued by concurrent callers cannot interleave.
// retry is passed through to doSendLocked.
func (cr *Client) doSend(retry bool, req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) {
	// TODO: Streams are thread-safe, so we shouldn't need to lock. But without locking, we'll need more machinery
	// (goroutines and channels) to ensure that responses are correctly correlated with their requests and thus
	// delivered in correct order.
	cr.connMu.Lock()
	defer cr.connMu.Unlock()
	return cr.doSendLocked(retry, req)
}
451
452func (cr *Client) doSendLocked(retry bool, req *rpb.ServerReflectionRequest) (*rpb.ServerReflectionResponse, error) {
453	if err := cr.initStreamLocked(); err != nil {
454		return nil, err
455	}
456
457	if err := cr.stream.Send(req); err != nil {
458		if err == io.EOF {
459			// if send returns EOF, must call Recv to get real underlying error
460			_, err = cr.stream.Recv()
461		}
462		cr.resetLocked()
463		if retry {
464			return cr.doSendLocked(false, req)
465		}
466		return nil, err
467	}
468
469	if resp, err := cr.stream.Recv(); err != nil {
470		cr.resetLocked()
471		if retry {
472			return cr.doSendLocked(false, req)
473		}
474		return nil, err
475	} else {
476		return resp, nil
477	}
478}
479
480func (cr *Client) initStreamLocked() error {
481	if cr.stream != nil {
482		return nil
483	}
484	var newCtx context.Context
485	newCtx, cr.cancel = context.WithCancel(cr.ctx)
486	var err error
487	cr.stream, err = cr.stub.ServerReflectionInfo(newCtx)
488	return err
489}
490
// Reset ensures that any active stream with the server is closed, releasing any
// resources. It acquires connMu, so it waits for an exchange that is in
// progress to finish. NewClient also registers it as the client's finalizer.
func (cr *Client) Reset() {
	cr.connMu.Lock()
	defer cr.connMu.Unlock()
	cr.resetLocked()
}
498
499func (cr *Client) resetLocked() {
500	if cr.stream != nil {
501		cr.stream.CloseSend()
502		for {
503			// drain the stream, this covers io.EOF too
504			if _, err := cr.stream.Recv(); err != nil {
505				break
506			}
507		}
508		cr.stream = nil
509	}
510	if cr.cancel != nil {
511		cr.cancel()
512		cr.cancel = nil
513	}
514}
515
516// ResolveService asks the server to resolve the given fully-qualified service
517// name into a service descriptor.
518func (cr *Client) ResolveService(serviceName string) (*desc.ServiceDescriptor, error) {
519	file, err := cr.FileContainingSymbol(serviceName)
520	if err != nil {
521		return nil, setSymbolType(err, serviceName, symbolTypeService)
522	}
523	d := file.FindSymbol(serviceName)
524	if d == nil {
525		return nil, symbolNotFound(serviceName, symbolTypeService, nil)
526	}
527	if s, ok := d.(*desc.ServiceDescriptor); ok {
528		return s, nil
529	} else {
530		return nil, symbolNotFound(serviceName, symbolTypeService, nil)
531	}
532}
533
534// ResolveMessage asks the server to resolve the given fully-qualified message
535// name into a message descriptor.
536func (cr *Client) ResolveMessage(messageName string) (*desc.MessageDescriptor, error) {
537	file, err := cr.FileContainingSymbol(messageName)
538	if err != nil {
539		return nil, setSymbolType(err, messageName, symbolTypeMessage)
540	}
541	d := file.FindSymbol(messageName)
542	if d == nil {
543		return nil, symbolNotFound(messageName, symbolTypeMessage, nil)
544	}
545	if s, ok := d.(*desc.MessageDescriptor); ok {
546		return s, nil
547	} else {
548		return nil, symbolNotFound(messageName, symbolTypeMessage, nil)
549	}
550}
551
552// ResolveEnum asks the server to resolve the given fully-qualified enum name
553// into an enum descriptor.
554func (cr *Client) ResolveEnum(enumName string) (*desc.EnumDescriptor, error) {
555	file, err := cr.FileContainingSymbol(enumName)
556	if err != nil {
557		return nil, setSymbolType(err, enumName, symbolTypeEnum)
558	}
559	d := file.FindSymbol(enumName)
560	if d == nil {
561		return nil, symbolNotFound(enumName, symbolTypeEnum, nil)
562	}
563	if s, ok := d.(*desc.EnumDescriptor); ok {
564		return s, nil
565	} else {
566		return nil, symbolNotFound(enumName, symbolTypeEnum, nil)
567	}
568}
569
570func setSymbolType(err error, name string, symType symbolType) error {
571	if e, ok := err.(*elementNotFoundError); ok {
572		if e.kind == elementKindSymbol && e.name == name && e.symType == symbolTypeUnknown {
573			e.symType = symType
574		}
575	}
576	return err
577}
578
579// ResolveEnumValues asks the server to resolve the given fully-qualified enum
580// name into a map of names to numbers that represents the enum's values.
581func (cr *Client) ResolveEnumValues(enumName string) (map[string]int32, error) {
582	enumDesc, err := cr.ResolveEnum(enumName)
583	if err != nil {
584		return nil, err
585	}
586	vals := map[string]int32{}
587	for _, valDesc := range enumDesc.GetValues() {
588		vals[valDesc.GetName()] = valDesc.GetNumber()
589	}
590	return vals, nil
591}
592
593// ResolveExtension asks the server to resolve the given extension number and
594// fully-qualified message name into a field descriptor.
595func (cr *Client) ResolveExtension(extendedType string, extensionNumber int32) (*desc.FieldDescriptor, error) {
596	file, err := cr.FileContainingExtension(extendedType, extensionNumber)
597	if err != nil {
598		return nil, err
599	}
600	d := findExtension(extendedType, extensionNumber, fileDescriptorExtensions{file})
601	if d == nil {
602		return nil, extensionNotFound(extendedType, extensionNumber, nil)
603	} else {
604		return d, nil
605	}
606}
607
608func findExtension(extendedType string, extensionNumber int32, scope extensionScope) *desc.FieldDescriptor {
609	// search extensions in this scope
610	for _, ext := range scope.extensions() {
611		if ext.GetNumber() == extensionNumber && ext.GetOwner().GetFullyQualifiedName() == extendedType {
612			return ext
613		}
614	}
615
616	// if not found, search nested scopes
617	for _, nested := range scope.nestedScopes() {
618		ext := findExtension(extendedType, extensionNumber, nested)
619		if ext != nil {
620			return ext
621		}
622	}
623
624	return nil
625}
626
// extensionScope is a place where extensions can be declared. findExtension
// walks a tree of these scopes when searching for an extension.
type extensionScope interface {
	// extensions returns the extensions declared directly in this scope.
	extensions() []*desc.FieldDescriptor
	// nestedScopes returns the scopes declared directly inside this one.
	nestedScopes() []extensionScope
}
631
632// fileDescriptorExtensions implements extensionHolder interface on top of
633// FileDescriptorProto
634type fileDescriptorExtensions struct {
635	proto *desc.FileDescriptor
636}
637
638func (fde fileDescriptorExtensions) extensions() []*desc.FieldDescriptor {
639	return fde.proto.GetExtensions()
640}
641
642func (fde fileDescriptorExtensions) nestedScopes() []extensionScope {
643	scopes := make([]extensionScope, len(fde.proto.GetMessageTypes()))
644	for i, m := range fde.proto.GetMessageTypes() {
645		scopes[i] = msgDescriptorExtensions{m}
646	}
647	return scopes
648}
649
650// msgDescriptorExtensions implements extensionHolder interface on top of
651// DescriptorProto
652type msgDescriptorExtensions struct {
653	proto *desc.MessageDescriptor
654}
655
656func (mde msgDescriptorExtensions) extensions() []*desc.FieldDescriptor {
657	return mde.proto.GetNestedExtensions()
658}
659
660func (mde msgDescriptorExtensions) nestedScopes() []extensionScope {
661	scopes := make([]extensionScope, len(mde.proto.GetNestedMessageTypes()))
662	for i, m := range mde.proto.GetNestedMessageTypes() {
663		scopes[i] = msgDescriptorExtensions{m}
664	}
665	return scopes
666}
667