/*
 *
 * Copyright 2016 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

//go:generate protoc --go_out=plugins=grpc:. grpc_reflection_v1alpha/reflection.proto

/*
Package reflection implements the gRPC server reflection service.

The service implemented is defined in:
https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.

To register server reflection on a gRPC server:
	import "google.golang.org/grpc/reflection"

	s := grpc.NewServer()
	pb.RegisterYourOwnServer(s, &server{})

	// Register reflection service on gRPC server.
	reflection.Register(s)

	s.Serve(lis)

*/
package reflection // import "google.golang.org/grpc/reflection"

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
	"io/ioutil"
	"reflect"
	"strings"

	"github.com/golang/protobuf/proto"
	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
)

// serverReflectionServer implements the ServerReflection service for one
// *grpc.Server. Register creates a single instance and registers it with the
// server, so that one value handles every reflection stream the server
// accepts.
type serverReflectionServer struct {
	// s is the server whose registered services are described to clients.
	s *grpc.Server
	// TODO add more cache if necessary
	// serviceInfo caches the result of s.GetServiceInfo(). The field has no
	// synchronization of its own: since one instance serves every reflection
	// stream, any code that fills it lazily from a request handler must guard
	// the write (mutex or sync.Once).
	serviceInfo map[string]grpc.ServiceInfo
}

// Register registers the server reflection service on the given gRPC server.
64func Register(s *grpc.Server) { 65 rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ 66 s: s, 67 }) 68} 69 70// protoMessage is used for type assertion on proto messages. 71// Generated proto message implements function Descriptor(), but Descriptor() 72// is not part of interface proto.Message. This interface is needed to 73// call Descriptor(). 74type protoMessage interface { 75 Descriptor() ([]byte, []int) 76} 77 78// fileDescForType gets the file descriptor for the given type. 79// The given type should be a proto message. 80func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) { 81 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage) 82 if !ok { 83 return nil, fmt.Errorf("failed to create message from type: %v", st) 84 } 85 enc, _ := m.Descriptor() 86 87 return s.decodeFileDesc(enc) 88} 89 90// decodeFileDesc does decompression and unmarshalling on the given 91// file descriptor byte slice. 92func (s *serverReflectionServer) decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { 93 raw, err := decompress(enc) 94 if err != nil { 95 return nil, fmt.Errorf("failed to decompress enc: %v", err) 96 } 97 98 fd := new(dpb.FileDescriptorProto) 99 if err := proto.Unmarshal(raw, fd); err != nil { 100 return nil, fmt.Errorf("bad descriptor: %v", err) 101 } 102 return fd, nil 103} 104 105// decompress does gzip decompression. 
106func decompress(b []byte) ([]byte, error) { 107 r, err := gzip.NewReader(bytes.NewReader(b)) 108 if err != nil { 109 return nil, fmt.Errorf("bad gzipped descriptor: %v", err) 110 } 111 out, err := ioutil.ReadAll(r) 112 if err != nil { 113 return nil, fmt.Errorf("bad gzipped descriptor: %v", err) 114 } 115 return out, nil 116} 117 118func (s *serverReflectionServer) typeForName(name string) (reflect.Type, error) { 119 pt := proto.MessageType(name) 120 if pt == nil { 121 return nil, fmt.Errorf("unknown type: %q", name) 122 } 123 st := pt.Elem() 124 125 return st, nil 126} 127 128func (s *serverReflectionServer) fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) { 129 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) 130 if !ok { 131 return nil, fmt.Errorf("failed to create message from type: %v", st) 132 } 133 134 var extDesc *proto.ExtensionDesc 135 for id, desc := range proto.RegisteredExtensions(m) { 136 if id == ext { 137 extDesc = desc 138 break 139 } 140 } 141 142 if extDesc == nil { 143 return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) 144 } 145 146 return s.decodeFileDesc(proto.FileDescriptor(extDesc.Filename)) 147} 148 149func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { 150 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) 151 if !ok { 152 return nil, fmt.Errorf("failed to create message from type: %v", st) 153 } 154 155 exts := proto.RegisteredExtensions(m) 156 out := make([]int32, 0, len(exts)) 157 for id := range exts { 158 out = append(out, id) 159 } 160 return out, nil 161} 162 163// fileDescEncodingByFilename finds the file descriptor for given filename, 164// does marshalling on it and returns the marshalled result. 
165func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) { 166 enc := proto.FileDescriptor(name) 167 if enc == nil { 168 return nil, fmt.Errorf("unknown file: %v", name) 169 } 170 fd, err := s.decodeFileDesc(enc) 171 if err != nil { 172 return nil, err 173 } 174 return proto.Marshal(fd) 175} 176 177// serviceMetadataForSymbol finds the metadata for name in s.serviceInfo. 178// name should be a service name or a method name. 179func (s *serverReflectionServer) serviceMetadataForSymbol(name string) (interface{}, error) { 180 if s.serviceInfo == nil { 181 s.serviceInfo = s.s.GetServiceInfo() 182 } 183 184 // Check if it's a service name. 185 if info, ok := s.serviceInfo[name]; ok { 186 return info.Metadata, nil 187 } 188 189 // Check if it's a method name. 190 pos := strings.LastIndex(name, ".") 191 // Not a valid method name. 192 if pos == -1 { 193 return nil, fmt.Errorf("unknown symbol: %v", name) 194 } 195 196 info, ok := s.serviceInfo[name[:pos]] 197 // Substring before last "." is not a service name. 198 if !ok { 199 return nil, fmt.Errorf("unknown symbol: %v", name) 200 } 201 202 // Search the method name in info.Methods. 203 var found bool 204 for _, m := range info.Methods { 205 if m.Name == name[pos+1:] { 206 found = true 207 break 208 } 209 } 210 if found { 211 return info.Metadata, nil 212 } 213 214 return nil, fmt.Errorf("unknown symbol: %v", name) 215} 216 217// parseMetadata finds the file descriptor bytes specified meta. 218// For SupportPackageIsVersion4, m is the name of the proto file, we 219// call proto.FileDescriptor to get the byte slice. 220// For SupportPackageIsVersion3, m is a byte slice itself. 221func parseMetadata(meta interface{}) ([]byte, bool) { 222 // Check if meta is the file name. 223 if fileNameForMeta, ok := meta.(string); ok { 224 return proto.FileDescriptor(fileNameForMeta), true 225 } 226 227 // Check if meta is the byte slice. 
228 if enc, ok := meta.([]byte); ok { 229 return enc, true 230 } 231 232 return nil, false 233} 234 235// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, 236// does marshalling on it and returns the marshalled result. 237// The given symbol can be a type, a service or a method. 238func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) { 239 var ( 240 fd *dpb.FileDescriptorProto 241 ) 242 // Check if it's a type name. 243 if st, err := s.typeForName(name); err == nil { 244 fd, err = s.fileDescForType(st) 245 if err != nil { 246 return nil, err 247 } 248 } else { // Check if it's a service name or a method name. 249 meta, err := s.serviceMetadataForSymbol(name) 250 251 // Metadata not found. 252 if err != nil { 253 return nil, err 254 } 255 256 // Metadata not valid. 257 enc, ok := parseMetadata(meta) 258 if !ok { 259 return nil, fmt.Errorf("invalid file descriptor for symbol: %v", name) 260 } 261 262 fd, err = s.decodeFileDesc(enc) 263 if err != nil { 264 return nil, err 265 } 266 } 267 268 return proto.Marshal(fd) 269} 270 271// fileDescEncodingContainingExtension finds the file descriptor containing given extension, 272// does marshalling on it and returns the marshalled result. 273func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) { 274 st, err := s.typeForName(typeName) 275 if err != nil { 276 return nil, err 277 } 278 fd, err := s.fileDescContainingExtension(st, extNum) 279 if err != nil { 280 return nil, err 281 } 282 return proto.Marshal(fd) 283} 284 285// allExtensionNumbersForTypeName returns all extension numbers for the given type. 
286func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { 287 st, err := s.typeForName(name) 288 if err != nil { 289 return nil, err 290 } 291 extNums, err := s.allExtensionNumbersForType(st) 292 if err != nil { 293 return nil, err 294 } 295 return extNums, nil 296} 297 298// ServerReflectionInfo is the reflection service handler. 299func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { 300 for { 301 in, err := stream.Recv() 302 if err == io.EOF { 303 return nil 304 } 305 if err != nil { 306 return err 307 } 308 309 out := &rpb.ServerReflectionResponse{ 310 ValidHost: in.Host, 311 OriginalRequest: in, 312 } 313 switch req := in.MessageRequest.(type) { 314 case *rpb.ServerReflectionRequest_FileByFilename: 315 b, err := s.fileDescEncodingByFilename(req.FileByFilename) 316 if err != nil { 317 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 318 ErrorResponse: &rpb.ErrorResponse{ 319 ErrorCode: int32(codes.NotFound), 320 ErrorMessage: err.Error(), 321 }, 322 } 323 } else { 324 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 325 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, 326 } 327 } 328 case *rpb.ServerReflectionRequest_FileContainingSymbol: 329 b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol) 330 if err != nil { 331 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 332 ErrorResponse: &rpb.ErrorResponse{ 333 ErrorCode: int32(codes.NotFound), 334 ErrorMessage: err.Error(), 335 }, 336 } 337 } else { 338 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 339 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, 340 } 341 } 342 case *rpb.ServerReflectionRequest_FileContainingExtension: 343 typeName := req.FileContainingExtension.ContainingType 344 extNum := 
req.FileContainingExtension.ExtensionNumber 345 b, err := s.fileDescEncodingContainingExtension(typeName, extNum) 346 if err != nil { 347 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 348 ErrorResponse: &rpb.ErrorResponse{ 349 ErrorCode: int32(codes.NotFound), 350 ErrorMessage: err.Error(), 351 }, 352 } 353 } else { 354 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 355 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, 356 } 357 } 358 case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType: 359 extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType) 360 if err != nil { 361 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 362 ErrorResponse: &rpb.ErrorResponse{ 363 ErrorCode: int32(codes.NotFound), 364 ErrorMessage: err.Error(), 365 }, 366 } 367 } else { 368 out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{ 369 AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{ 370 BaseTypeName: req.AllExtensionNumbersOfType, 371 ExtensionNumber: extNums, 372 }, 373 } 374 } 375 case *rpb.ServerReflectionRequest_ListServices: 376 if s.serviceInfo == nil { 377 s.serviceInfo = s.s.GetServiceInfo() 378 } 379 serviceResponses := make([]*rpb.ServiceResponse, 0, len(s.serviceInfo)) 380 for n := range s.serviceInfo { 381 serviceResponses = append(serviceResponses, &rpb.ServiceResponse{ 382 Name: n, 383 }) 384 } 385 out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ 386 ListServicesResponse: &rpb.ListServiceResponse{ 387 Service: serviceResponses, 388 }, 389 } 390 default: 391 return grpc.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest) 392 } 393 394 if err := stream.Send(out); err != nil { 395 return err 396 } 397 } 398} 399