1/*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19//go:generate protoc --go_out=plugins=grpc:. grpc_reflection_v1alpha/reflection.proto
20
21/*
Package reflection implements the gRPC server reflection service.
23
24The service implemented is defined in:
25https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.
26
27To register server reflection on a gRPC server:
28	import "google.golang.org/grpc/reflection"
29
30	s := grpc.NewServer()
31	pb.RegisterYourOwnServer(s, &server{})
32
33	// Register reflection service on gRPC server.
34	reflection.Register(s)
35
36	s.Serve(lis)
37
38*/
39package reflection // import "google.golang.org/grpc/reflection"
40
41import (
42	"bytes"
43	"compress/gzip"
44	"fmt"
45	"io"
46	"io/ioutil"
47	"reflect"
48	"strings"
49
50	"github.com/golang/protobuf/proto"
51	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
52	"google.golang.org/grpc"
53	"google.golang.org/grpc/codes"
54	rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
55)
56
// serverReflectionServer implements the ServerReflection service, answering
// queries about the services registered on s and the proto files behind them.
type serverReflectionServer struct {
	s *grpc.Server
	// TODO add more cache if necessary
	// serviceInfo, when non-nil, is a cached copy of s.GetServiceInfo().
	// A single serverReflectionServer value is shared by every reflection
	// stream (see Register), and gRPC may serve those streams concurrently,
	// so any write to this field after registration must be synchronized.
	serviceInfo map[string]grpc.ServiceInfo // cache for s.GetServiceInfo()
}
62
63// Register registers the server reflection service on the given gRPC server.
64func Register(s *grpc.Server) {
65	rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
66		s: s,
67	})
68}
69
// protoMessage lets code call Descriptor() on a generated proto message.
// Generated messages implement Descriptor(), but that method is not part of
// the proto.Message interface, so this narrower interface is used in type
// assertions to reach it. The first result is the encoded (gzipped) file
// descriptor; the second is the message's index path within that file.
type protoMessage interface {
	Descriptor() (fileDesc []byte, indexPath []int)
}
77
78// fileDescForType gets the file descriptor for the given type.
79// The given type should be a proto message.
80func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) {
81	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage)
82	if !ok {
83		return nil, fmt.Errorf("failed to create message from type: %v", st)
84	}
85	enc, _ := m.Descriptor()
86
87	return s.decodeFileDesc(enc)
88}
89
90// decodeFileDesc does decompression and unmarshalling on the given
91// file descriptor byte slice.
92func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
93	raw, err := decompress(enc)
94	if err != nil {
95		return nil, fmt.Errorf("failed to decompress enc: %v", err)
96	}
97
98	fd := new(dpb.FileDescriptorProto)
99	if err := proto.Unmarshal(raw, fd); err != nil {
100		return nil, fmt.Errorf("bad descriptor: %v", err)
101	}
102	return fd, nil
103}
104
// decompress returns the gzip-decompressed contents of b. A malformed
// header or a truncated/corrupt body yields an error prefixed with
// "bad gzipped descriptor: " and a nil slice.
func decompress(b []byte) ([]byte, error) {
	zr, err := gzip.NewReader(bytes.NewReader(b))
	if err == nil {
		var raw []byte
		if raw, err = ioutil.ReadAll(zr); err == nil {
			return raw, nil
		}
	}
	return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
}
117
118func (s *serverReflectionServer) typeForName(name string) (reflect.Type, error) {
119	pt := proto.MessageType(name)
120	if pt == nil {
121		return nil, fmt.Errorf("unknown type: %q", name)
122	}
123	st := pt.Elem()
124
125	return st, nil
126}
127
128func (s *serverReflectionServer) fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) {
129	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
130	if !ok {
131		return nil, fmt.Errorf("failed to create message from type: %v", st)
132	}
133
134	var extDesc *proto.ExtensionDesc
135	for id, desc := range proto.RegisteredExtensions(m) {
136		if id == ext {
137			extDesc = desc
138			break
139		}
140	}
141
142	if extDesc == nil {
143		return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext)
144	}
145
146	return s.decodeFileDesc(proto.FileDescriptor(extDesc.Filename))
147}
148
149func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
150	m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
151	if !ok {
152		return nil, fmt.Errorf("failed to create message from type: %v", st)
153	}
154
155	exts := proto.RegisteredExtensions(m)
156	out := make([]int32, 0, len(exts))
157	for id := range exts {
158		out = append(out, id)
159	}
160	return out, nil
161}
162
163// fileDescEncodingByFilename finds the file descriptor for given filename,
164// does marshalling on it and returns the marshalled result.
165func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) {
166	enc := proto.FileDescriptor(name)
167	if enc == nil {
168		return nil, fmt.Errorf("unknown file: %v", name)
169	}
170	fd, err := s.decodeFileDesc(enc)
171	if err != nil {
172		return nil, err
173	}
174	return proto.Marshal(fd)
175}
176
177// serviceMetadataForSymbol finds the metadata for name in s.serviceInfo.
178// name should be a service name or a method name.
179func (s *serverReflectionServer) serviceMetadataForSymbol(name string) (interface{}, error) {
180	if s.serviceInfo == nil {
181		s.serviceInfo = s.s.GetServiceInfo()
182	}
183
184	// Check if it's a service name.
185	if info, ok := s.serviceInfo[name]; ok {
186		return info.Metadata, nil
187	}
188
189	// Check if it's a method name.
190	pos := strings.LastIndex(name, ".")
191	// Not a valid method name.
192	if pos == -1 {
193		return nil, fmt.Errorf("unknown symbol: %v", name)
194	}
195
196	info, ok := s.serviceInfo[name[:pos]]
197	// Substring before last "." is not a service name.
198	if !ok {
199		return nil, fmt.Errorf("unknown symbol: %v", name)
200	}
201
202	// Search the method name in info.Methods.
203	var found bool
204	for _, m := range info.Methods {
205		if m.Name == name[pos+1:] {
206			found = true
207			break
208		}
209	}
210	if found {
211		return info.Metadata, nil
212	}
213
214	return nil, fmt.Errorf("unknown symbol: %v", name)
215}
216
217// parseMetadata finds the file descriptor bytes specified meta.
218// For SupportPackageIsVersion4, m is the name of the proto file, we
219// call proto.FileDescriptor to get the byte slice.
220// For SupportPackageIsVersion3, m is a byte slice itself.
221func parseMetadata(meta interface{}) ([]byte, bool) {
222	// Check if meta is the file name.
223	if fileNameForMeta, ok := meta.(string); ok {
224		return proto.FileDescriptor(fileNameForMeta), true
225	}
226
227	// Check if meta is the byte slice.
228	if enc, ok := meta.([]byte); ok {
229		return enc, true
230	}
231
232	return nil, false
233}
234
235// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
236// does marshalling on it and returns the marshalled result.
237// The given symbol can be a type, a service or a method.
238func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) {
239	var (
240		fd *dpb.FileDescriptorProto
241	)
242	// Check if it's a type name.
243	if st, err := s.typeForName(name); err == nil {
244		fd, err = s.fileDescForType(st)
245		if err != nil {
246			return nil, err
247		}
248	} else { // Check if it's a service name or a method name.
249		meta, err := s.serviceMetadataForSymbol(name)
250
251		// Metadata not found.
252		if err != nil {
253			return nil, err
254		}
255
256		// Metadata not valid.
257		enc, ok := parseMetadata(meta)
258		if !ok {
259			return nil, fmt.Errorf("invalid file descriptor for symbol: %v", name)
260		}
261
262		fd, err = s.decodeFileDesc(enc)
263		if err != nil {
264			return nil, err
265		}
266	}
267
268	return proto.Marshal(fd)
269}
270
271// fileDescEncodingContainingExtension finds the file descriptor containing given extension,
272// does marshalling on it and returns the marshalled result.
273func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) {
274	st, err := s.typeForName(typeName)
275	if err != nil {
276		return nil, err
277	}
278	fd, err := s.fileDescContainingExtension(st, extNum)
279	if err != nil {
280		return nil, err
281	}
282	return proto.Marshal(fd)
283}
284
285// allExtensionNumbersForTypeName returns all extension numbers for the given type.
286func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
287	st, err := s.typeForName(name)
288	if err != nil {
289		return nil, err
290	}
291	extNums, err := s.allExtensionNumbersForType(st)
292	if err != nil {
293		return nil, err
294	}
295	return extNums, nil
296}
297
298// ServerReflectionInfo is the reflection service handler.
299func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
300	for {
301		in, err := stream.Recv()
302		if err == io.EOF {
303			return nil
304		}
305		if err != nil {
306			return err
307		}
308
309		out := &rpb.ServerReflectionResponse{
310			ValidHost:       in.Host,
311			OriginalRequest: in,
312		}
313		switch req := in.MessageRequest.(type) {
314		case *rpb.ServerReflectionRequest_FileByFilename:
315			b, err := s.fileDescEncodingByFilename(req.FileByFilename)
316			if err != nil {
317				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
318					ErrorResponse: &rpb.ErrorResponse{
319						ErrorCode:    int32(codes.NotFound),
320						ErrorMessage: err.Error(),
321					},
322				}
323			} else {
324				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
325					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
326				}
327			}
328		case *rpb.ServerReflectionRequest_FileContainingSymbol:
329			b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol)
330			if err != nil {
331				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
332					ErrorResponse: &rpb.ErrorResponse{
333						ErrorCode:    int32(codes.NotFound),
334						ErrorMessage: err.Error(),
335					},
336				}
337			} else {
338				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
339					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
340				}
341			}
342		case *rpb.ServerReflectionRequest_FileContainingExtension:
343			typeName := req.FileContainingExtension.ContainingType
344			extNum := req.FileContainingExtension.ExtensionNumber
345			b, err := s.fileDescEncodingContainingExtension(typeName, extNum)
346			if err != nil {
347				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
348					ErrorResponse: &rpb.ErrorResponse{
349						ErrorCode:    int32(codes.NotFound),
350						ErrorMessage: err.Error(),
351					},
352				}
353			} else {
354				out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
355					FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
356				}
357			}
358		case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
359			extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
360			if err != nil {
361				out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
362					ErrorResponse: &rpb.ErrorResponse{
363						ErrorCode:    int32(codes.NotFound),
364						ErrorMessage: err.Error(),
365					},
366				}
367			} else {
368				out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
369					AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
370						BaseTypeName:    req.AllExtensionNumbersOfType,
371						ExtensionNumber: extNums,
372					},
373				}
374			}
375		case *rpb.ServerReflectionRequest_ListServices:
376			if s.serviceInfo == nil {
377				s.serviceInfo = s.s.GetServiceInfo()
378			}
379			serviceResponses := make([]*rpb.ServiceResponse, 0, len(s.serviceInfo))
380			for n := range s.serviceInfo {
381				serviceResponses = append(serviceResponses, &rpb.ServiceResponse{
382					Name: n,
383				})
384			}
385			out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
386				ListServicesResponse: &rpb.ListServiceResponse{
387					Service: serviceResponses,
388				},
389			}
390		default:
391			return grpc.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
392		}
393
394		if err := stream.Send(out); err != nil {
395			return err
396		}
397	}
398}
399