/*
 *
 * Copyright 2016 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

//go:generate protoc --go_out=plugins=grpc:. grpc_reflection_v1alpha/reflection.proto

/*
Package reflection implements the gRPC server reflection service.

The service implemented is defined in:
https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.

To register server reflection on a gRPC server:
	import "google.golang.org/grpc/reflection"

	s := grpc.NewServer()
	pb.RegisterYourOwnServer(s, &server{})

	// Register the reflection service on the gRPC server.
	reflection.Register(s)

	s.Serve(lis)

*/
package reflection // import "google.golang.org/grpc/reflection"

import (
	"bytes"
	"compress/gzip"
	"fmt"
	"io"
	"io/ioutil"
	"reflect"
	"sort"
	"sync"

	"github.com/golang/protobuf/proto"
	dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
	"google.golang.org/grpc/status"
)

// serverReflectionServer implements the ServerReflection service for the
// gRPC server s. The service-name list and the symbol index are built lazily,
// exactly once, from the services registered on s.
type serverReflectionServer struct {
	// s is the server whose registered services are reflected.
	s *grpc.Server

	// initSymbols guards the one-time population of serviceNames and symbols.
	initSymbols sync.Once
	// serviceNames holds the names of all registered services, sorted.
	serviceNames []string
	symbols      map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files
}

// Register registers the server reflection service on the given gRPC server.
68func Register(s *grpc.Server) { 69 rpb.RegisterServerReflectionServer(s, &serverReflectionServer{ 70 s: s, 71 }) 72} 73 74// protoMessage is used for type assertion on proto messages. 75// Generated proto message implements function Descriptor(), but Descriptor() 76// is not part of interface proto.Message. This interface is needed to 77// call Descriptor(). 78type protoMessage interface { 79 Descriptor() ([]byte, []int) 80} 81 82func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) { 83 s.initSymbols.Do(func() { 84 serviceInfo := s.s.GetServiceInfo() 85 86 s.symbols = map[string]*dpb.FileDescriptorProto{} 87 s.serviceNames = make([]string, 0, len(serviceInfo)) 88 processed := map[string]struct{}{} 89 for svc, info := range serviceInfo { 90 s.serviceNames = append(s.serviceNames, svc) 91 fdenc, ok := parseMetadata(info.Metadata) 92 if !ok { 93 continue 94 } 95 fd, err := decodeFileDesc(fdenc) 96 if err != nil { 97 continue 98 } 99 s.processFile(fd, processed) 100 } 101 sort.Strings(s.serviceNames) 102 }) 103 104 return s.serviceNames, s.symbols 105} 106 107func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) { 108 filename := fd.GetName() 109 if _, ok := processed[filename]; ok { 110 return 111 } 112 processed[filename] = struct{}{} 113 114 prefix := fd.GetPackage() 115 116 for _, msg := range fd.MessageType { 117 s.processMessage(fd, prefix, msg) 118 } 119 for _, en := range fd.EnumType { 120 s.processEnum(fd, prefix, en) 121 } 122 for _, ext := range fd.Extension { 123 s.processField(fd, prefix, ext) 124 } 125 for _, svc := range fd.Service { 126 svcName := fqn(prefix, svc.GetName()) 127 s.symbols[svcName] = fd 128 for _, meth := range svc.Method { 129 name := fqn(svcName, meth.GetName()) 130 s.symbols[name] = fd 131 } 132 } 133 134 for _, dep := range fd.Dependency { 135 fdenc := proto.FileDescriptor(dep) 136 fdDep, err := decodeFileDesc(fdenc) 137 
if err != nil { 138 continue 139 } 140 s.processFile(fdDep, processed) 141 } 142} 143 144func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) { 145 msgName := fqn(prefix, msg.GetName()) 146 s.symbols[msgName] = fd 147 148 for _, nested := range msg.NestedType { 149 s.processMessage(fd, msgName, nested) 150 } 151 for _, en := range msg.EnumType { 152 s.processEnum(fd, msgName, en) 153 } 154 for _, ext := range msg.Extension { 155 s.processField(fd, msgName, ext) 156 } 157 for _, fld := range msg.Field { 158 s.processField(fd, msgName, fld) 159 } 160 for _, oneof := range msg.OneofDecl { 161 oneofName := fqn(msgName, oneof.GetName()) 162 s.symbols[oneofName] = fd 163 } 164} 165 166func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) { 167 enName := fqn(prefix, en.GetName()) 168 s.symbols[enName] = fd 169 170 for _, val := range en.Value { 171 valName := fqn(enName, val.GetName()) 172 s.symbols[valName] = fd 173 } 174} 175 176func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) { 177 fldName := fqn(prefix, fld.GetName()) 178 s.symbols[fldName] = fd 179} 180 181func fqn(prefix, name string) string { 182 if prefix == "" { 183 return name 184 } 185 return prefix + "." + name 186} 187 188// fileDescForType gets the file descriptor for the given type. 189// The given type should be a proto message. 190func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) { 191 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage) 192 if !ok { 193 return nil, fmt.Errorf("failed to create message from type: %v", st) 194 } 195 enc, _ := m.Descriptor() 196 197 return decodeFileDesc(enc) 198} 199 200// decodeFileDesc does decompression and unmarshalling on the given 201// file descriptor byte slice. 
202func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) { 203 raw, err := decompress(enc) 204 if err != nil { 205 return nil, fmt.Errorf("failed to decompress enc: %v", err) 206 } 207 208 fd := new(dpb.FileDescriptorProto) 209 if err := proto.Unmarshal(raw, fd); err != nil { 210 return nil, fmt.Errorf("bad descriptor: %v", err) 211 } 212 return fd, nil 213} 214 215// decompress does gzip decompression. 216func decompress(b []byte) ([]byte, error) { 217 r, err := gzip.NewReader(bytes.NewReader(b)) 218 if err != nil { 219 return nil, fmt.Errorf("bad gzipped descriptor: %v", err) 220 } 221 out, err := ioutil.ReadAll(r) 222 if err != nil { 223 return nil, fmt.Errorf("bad gzipped descriptor: %v", err) 224 } 225 return out, nil 226} 227 228func typeForName(name string) (reflect.Type, error) { 229 pt := proto.MessageType(name) 230 if pt == nil { 231 return nil, fmt.Errorf("unknown type: %q", name) 232 } 233 st := pt.Elem() 234 235 return st, nil 236} 237 238func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) { 239 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) 240 if !ok { 241 return nil, fmt.Errorf("failed to create message from type: %v", st) 242 } 243 244 var extDesc *proto.ExtensionDesc 245 for id, desc := range proto.RegisteredExtensions(m) { 246 if id == ext { 247 extDesc = desc 248 break 249 } 250 } 251 252 if extDesc == nil { 253 return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext) 254 } 255 256 return decodeFileDesc(proto.FileDescriptor(extDesc.Filename)) 257} 258 259func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) { 260 m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message) 261 if !ok { 262 return nil, fmt.Errorf("failed to create message from type: %v", st) 263 } 264 265 exts := proto.RegisteredExtensions(m) 266 out := make([]int32, 0, len(exts)) 267 for id := range exts { 268 out = 
append(out, id) 269 } 270 return out, nil 271} 272 273// fileDescEncodingByFilename finds the file descriptor for given filename, 274// does marshalling on it and returns the marshalled result. 275func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) { 276 enc := proto.FileDescriptor(name) 277 if enc == nil { 278 return nil, fmt.Errorf("unknown file: %v", name) 279 } 280 fd, err := decodeFileDesc(enc) 281 if err != nil { 282 return nil, err 283 } 284 return proto.Marshal(fd) 285} 286 287// parseMetadata finds the file descriptor bytes specified meta. 288// For SupportPackageIsVersion4, m is the name of the proto file, we 289// call proto.FileDescriptor to get the byte slice. 290// For SupportPackageIsVersion3, m is a byte slice itself. 291func parseMetadata(meta interface{}) ([]byte, bool) { 292 // Check if meta is the file name. 293 if fileNameForMeta, ok := meta.(string); ok { 294 return proto.FileDescriptor(fileNameForMeta), true 295 } 296 297 // Check if meta is the byte slice. 298 if enc, ok := meta.([]byte); ok { 299 return enc, true 300 } 301 302 return nil, false 303} 304 305// fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol, 306// does marshalling on it and returns the marshalled result. 307// The given symbol can be a type, a service or a method. 308func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) { 309 _, symbols := s.getSymbols() 310 fd := symbols[name] 311 if fd == nil { 312 // Check if it's a type name that was not present in the 313 // transitive dependencies of the registered services. 
314 if st, err := typeForName(name); err == nil { 315 fd, err = s.fileDescForType(st) 316 if err != nil { 317 return nil, err 318 } 319 } 320 } 321 322 if fd == nil { 323 return nil, fmt.Errorf("unknown symbol: %v", name) 324 } 325 326 return proto.Marshal(fd) 327} 328 329// fileDescEncodingContainingExtension finds the file descriptor containing given extension, 330// does marshalling on it and returns the marshalled result. 331func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) { 332 st, err := typeForName(typeName) 333 if err != nil { 334 return nil, err 335 } 336 fd, err := fileDescContainingExtension(st, extNum) 337 if err != nil { 338 return nil, err 339 } 340 return proto.Marshal(fd) 341} 342 343// allExtensionNumbersForTypeName returns all extension numbers for the given type. 344func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) { 345 st, err := typeForName(name) 346 if err != nil { 347 return nil, err 348 } 349 extNums, err := s.allExtensionNumbersForType(st) 350 if err != nil { 351 return nil, err 352 } 353 return extNums, nil 354} 355 356// ServerReflectionInfo is the reflection service handler. 
357func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error { 358 for { 359 in, err := stream.Recv() 360 if err == io.EOF { 361 return nil 362 } 363 if err != nil { 364 return err 365 } 366 367 out := &rpb.ServerReflectionResponse{ 368 ValidHost: in.Host, 369 OriginalRequest: in, 370 } 371 switch req := in.MessageRequest.(type) { 372 case *rpb.ServerReflectionRequest_FileByFilename: 373 b, err := s.fileDescEncodingByFilename(req.FileByFilename) 374 if err != nil { 375 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 376 ErrorResponse: &rpb.ErrorResponse{ 377 ErrorCode: int32(codes.NotFound), 378 ErrorMessage: err.Error(), 379 }, 380 } 381 } else { 382 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 383 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, 384 } 385 } 386 case *rpb.ServerReflectionRequest_FileContainingSymbol: 387 b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol) 388 if err != nil { 389 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 390 ErrorResponse: &rpb.ErrorResponse{ 391 ErrorCode: int32(codes.NotFound), 392 ErrorMessage: err.Error(), 393 }, 394 } 395 } else { 396 out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{ 397 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, 398 } 399 } 400 case *rpb.ServerReflectionRequest_FileContainingExtension: 401 typeName := req.FileContainingExtension.ContainingType 402 extNum := req.FileContainingExtension.ExtensionNumber 403 b, err := s.fileDescEncodingContainingExtension(typeName, extNum) 404 if err != nil { 405 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 406 ErrorResponse: &rpb.ErrorResponse{ 407 ErrorCode: int32(codes.NotFound), 408 ErrorMessage: err.Error(), 409 }, 410 } 411 } else { 412 out.MessageResponse = 
&rpb.ServerReflectionResponse_FileDescriptorResponse{ 413 FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}}, 414 } 415 } 416 case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType: 417 extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType) 418 if err != nil { 419 out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{ 420 ErrorResponse: &rpb.ErrorResponse{ 421 ErrorCode: int32(codes.NotFound), 422 ErrorMessage: err.Error(), 423 }, 424 } 425 } else { 426 out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{ 427 AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{ 428 BaseTypeName: req.AllExtensionNumbersOfType, 429 ExtensionNumber: extNums, 430 }, 431 } 432 } 433 case *rpb.ServerReflectionRequest_ListServices: 434 svcNames, _ := s.getSymbols() 435 serviceResponses := make([]*rpb.ServiceResponse, len(svcNames)) 436 for i, n := range svcNames { 437 serviceResponses[i] = &rpb.ServiceResponse{ 438 Name: n, 439 } 440 } 441 out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{ 442 ListServicesResponse: &rpb.ListServiceResponse{ 443 Service: serviceResponses, 444 }, 445 } 446 default: 447 return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest) 448 } 449 450 if err := stream.Send(out); err != nil { 451 return err 452 } 453 } 454} 455