// Package grpcurl provides the core functionality exposed by the grpcurl command, for
// dynamically connecting to a server, using the reflection service to inspect the server,
// and invoking RPCs. The grpcurl command-line tool constructs a DescriptorSource, based
// on the command-line parameters, and supplies an InvocationEventHandler to supply request
// data (which can come from command-line args or the process's stdin) and to log the
// events (to the process's stdout).
package grpcurl

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"os"
	"regexp"
	"sort"
	"strings"

	"github.com/golang/protobuf/proto" //lint:ignore SA1019 we have to import this because it appears in exported API
	"github.com/jhump/protoreflect/desc"
	"github.com/jhump/protoreflect/desc/protoprint"
	"github.com/jhump/protoreflect/dynamic"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/metadata"
	protov2 "google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/descriptorpb"
	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/emptypb"
	"google.golang.org/protobuf/types/known/structpb"
)

// ListServices uses the given descriptor source to return a sorted list of fully-qualified
// service names.
func ListServices(source DescriptorSource) ([]string, error) {
	svcs, err := source.ListServices()
	if err != nil {
		return nil, err
	}
	// sort for deterministic, user-friendly output
	sort.Strings(svcs)
	return svcs, nil
}

// sourceWithFiles is implemented by descriptor sources that can directly
// enumerate every file descriptor they know about, which lets GetAllFiles
// skip the slower symbol-by-symbol discovery path.
type sourceWithFiles interface {
	GetAllFiles() ([]*desc.FileDescriptor, error)
}

// compile-time check that fileSource (declared elsewhere in this package)
// implements sourceWithFiles
var _ sourceWithFiles = (*fileSource)(nil)

// GetAllFiles uses the given descriptor source to return a list of file descriptors.
56func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) { 57 var files []*desc.FileDescriptor 58 srcFiles, ok := source.(sourceWithFiles) 59 60 // If an error occurs, we still try to load as many files as we can, so that 61 // caller can decide whether to ignore error or not. 62 var firstError error 63 if ok { 64 files, firstError = srcFiles.GetAllFiles() 65 } else { 66 // Source does not implement GetAllFiles method, so use ListServices 67 // and grab files from there. 68 svcNames, err := source.ListServices() 69 if err != nil { 70 firstError = err 71 } else { 72 allFiles := map[string]*desc.FileDescriptor{} 73 for _, name := range svcNames { 74 d, err := source.FindSymbol(name) 75 if err != nil { 76 if firstError == nil { 77 firstError = err 78 } 79 } else { 80 addAllFilesToSet(d.GetFile(), allFiles) 81 } 82 } 83 files = make([]*desc.FileDescriptor, len(allFiles)) 84 i := 0 85 for _, fd := range allFiles { 86 files[i] = fd 87 i++ 88 } 89 } 90 } 91 92 sort.Sort(filesByName(files)) 93 return files, firstError 94} 95 96type filesByName []*desc.FileDescriptor 97 98func (f filesByName) Len() int { 99 return len(f) 100} 101 102func (f filesByName) Less(i, j int) bool { 103 return f[i].GetName() < f[j].GetName() 104} 105 106func (f filesByName) Swap(i, j int) { 107 f[i], f[j] = f[j], f[i] 108} 109 110func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) { 111 if _, ok := all[fd.GetName()]; ok { 112 // already added 113 return 114 } 115 all[fd.GetName()] = fd 116 for _, dep := range fd.GetDependencies() { 117 addAllFilesToSet(dep, all) 118 } 119} 120 121// ListMethods uses the given descriptor source to return a sorted list of method names 122// for the specified fully-qualified service name. 
123func ListMethods(source DescriptorSource, serviceName string) ([]string, error) { 124 dsc, err := source.FindSymbol(serviceName) 125 if err != nil { 126 return nil, err 127 } 128 if sd, ok := dsc.(*desc.ServiceDescriptor); !ok { 129 return nil, notFound("Service", serviceName) 130 } else { 131 methods := make([]string, 0, len(sd.GetMethods())) 132 for _, method := range sd.GetMethods() { 133 methods = append(methods, method.GetFullyQualifiedName()) 134 } 135 sort.Strings(methods) 136 return methods, nil 137 } 138} 139 140// MetadataFromHeaders converts a list of header strings (each string in 141// "Header-Name: Header-Value" form) into metadata. If a string has a header 142// name without a value (e.g. does not contain a colon), the value is assumed 143// to be blank. Binary headers (those whose names end in "-bin") should be 144// base64-encoded. But if they cannot be base64-decoded, they will be assumed to 145// be in raw form and used as is. 146func MetadataFromHeaders(headers []string) metadata.MD { 147 md := make(metadata.MD) 148 for _, part := range headers { 149 if part != "" { 150 pieces := strings.SplitN(part, ":", 2) 151 if len(pieces) == 1 { 152 pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter) 153 } 154 headerName := strings.ToLower(strings.TrimSpace(pieces[0])) 155 val := strings.TrimSpace(pieces[1]) 156 if strings.HasSuffix(headerName, "-bin") { 157 if v, err := decode(val); err == nil { 158 val = v 159 } 160 } 161 md[headerName] = append(md[headerName], val) 162 } 163 } 164 return md 165} 166 167var envVarRegex = regexp.MustCompile(`\${\w+}`) 168 169// ExpandHeaders expands environment variables contained in the header string. 170// If no corresponding environment variable is found an error is returned. 
171// TODO: Add escaping for `${` 172func ExpandHeaders(headers []string) ([]string, error) { 173 expandedHeaders := make([]string, len(headers)) 174 for idx, header := range headers { 175 if header == "" { 176 continue 177 } 178 results := envVarRegex.FindAllString(header, -1) 179 if len(results) == 0 { 180 expandedHeaders[idx] = headers[idx] 181 continue 182 } 183 expandedHeader := header 184 for _, result := range results { 185 envVarName := result[2 : len(result)-1] // strip leading `${` and trailing `}` 186 envVarValue, ok := os.LookupEnv(envVarName) 187 if !ok { 188 return nil, fmt.Errorf("header %q refers to missing environment variable %q", header, envVarName) 189 } 190 expandedHeader = strings.Replace(expandedHeader, result, envVarValue, -1) 191 } 192 expandedHeaders[idx] = expandedHeader 193 } 194 return expandedHeaders, nil 195} 196 197var base64Codecs = []*base64.Encoding{base64.StdEncoding, base64.URLEncoding, base64.RawStdEncoding, base64.RawURLEncoding} 198 199func decode(val string) (string, error) { 200 var firstErr error 201 var b []byte 202 // we are lenient and can accept any of the flavors of base64 encoding 203 for _, d := range base64Codecs { 204 var err error 205 b, err = d.DecodeString(val) 206 if err != nil { 207 if firstErr == nil { 208 firstErr = err 209 } 210 continue 211 } 212 return string(b), nil 213 } 214 return "", firstErr 215} 216 217// MetadataToString returns a string representation of the given metadata, for 218// displaying to users. 
219func MetadataToString(md metadata.MD) string { 220 if len(md) == 0 { 221 return "(empty)" 222 } 223 224 keys := make([]string, 0, len(md)) 225 for k := range md { 226 keys = append(keys, k) 227 } 228 sort.Strings(keys) 229 230 var b bytes.Buffer 231 first := true 232 for _, k := range keys { 233 vs := md[k] 234 for _, v := range vs { 235 if first { 236 first = false 237 } else { 238 b.WriteString("\n") 239 } 240 b.WriteString(k) 241 b.WriteString(": ") 242 if strings.HasSuffix(k, "-bin") { 243 v = base64.StdEncoding.EncodeToString([]byte(v)) 244 } 245 b.WriteString(v) 246 } 247 } 248 return b.String() 249} 250 251var printer = &protoprint.Printer{ 252 Compact: true, 253 OmitComments: protoprint.CommentsNonDoc, 254 SortElements: true, 255 ForceFullyQualifiedNames: true, 256} 257 258// GetDescriptorText returns a string representation of the given descriptor. 259// This returns a snippet of proto source that describes the given element. 260func GetDescriptorText(dsc desc.Descriptor, _ DescriptorSource) (string, error) { 261 // Note: DescriptorSource is not used, but remains an argument for backwards 262 // compatibility with previous implementation. 263 txt, err := printer.PrintProtoToString(dsc) 264 if err != nil { 265 return "", err 266 } 267 // callers don't expect trailing newlines 268 if txt[len(txt)-1] == '\n' { 269 txt = txt[:len(txt)-1] 270 } 271 return txt, nil 272} 273 274// EnsureExtensions uses the given descriptor source to download extensions for 275// the given message. It returns a copy of the given message, but as a dynamic 276// message that knows about all extensions known to the given descriptor source. 
277func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message { 278 // load any server extensions so we can properly describe custom options 279 dsc, err := desc.LoadMessageDescriptorForMessage(msg) 280 if err != nil { 281 return msg 282 } 283 284 var ext dynamic.ExtensionRegistry 285 if err = fetchAllExtensions(source, &ext, dsc, map[string]bool{}); err != nil { 286 return msg 287 } 288 289 // convert message into dynamic message that knows about applicable extensions 290 // (that way we can show meaningful info for custom options instead of printing as unknown) 291 msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext) 292 dm, err := fullyConvertToDynamic(msgFactory, msg) 293 if err != nil { 294 return msg 295 } 296 return dm 297} 298 299// fetchAllExtensions recursively fetches from the server extensions for the given message type as well as 300// for all message types of nested fields. The extensions are added to the given dynamic registry of extensions 301// so that all server-known extensions can be correctly parsed by grpcurl. 
func fetchAllExtensions(source DescriptorSource, ext *dynamic.ExtensionRegistry, md *desc.MessageDescriptor, alreadyFetched map[string]bool) error {
	msgTypeName := md.GetFullyQualifiedName()
	if alreadyFetched[msgTypeName] {
		// already processed this type; guards against cycles in recursive message types
		return nil
	}
	alreadyFetched[msgTypeName] = true
	if len(md.GetExtensionRanges()) > 0 {
		// the type is extendable, so ask the source for all extensions it knows for it
		fds, err := source.AllExtensionsForType(msgTypeName)
		if err != nil {
			return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err)
		}
		for _, fd := range fds {
			if err := ext.AddExtension(fd); err != nil {
				return fmt.Errorf("could not register extension %s of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err)
			}
		}
	}
	// recursively fetch extensions for the types of any message fields
	for _, fd := range md.GetFields() {
		if fd.GetMessageType() != nil {
			err := fetchAllExtensions(source, ext, fd.GetMessageType(), alreadyFetched)
			if err != nil {
				return err
			}
		}
	}
	return nil
}

// fullyConvertToDynamic attempts to convert the given message to a dynamic message as well
// as any nested messages it may contain as field values. If the given message factory has
// extensions registered that were not known when the given message was parsed, this effectively
// allows re-parsing to identify those extensions.
func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) (proto.Message, error) {
	if _, ok := msg.(*dynamic.Message); ok {
		return msg, nil // already a dynamic message
	}
	md, err := desc.LoadMessageDescriptorForMessage(msg)
	if err != nil {
		return nil, err
	}
	newMsg := msgFact.NewMessage(md)
	dm, ok := newMsg.(*dynamic.Message)
	if !ok {
		// if message factory didn't produce a dynamic message, then we should leave msg as is
		return msg, nil
	}

	// re-parse msg into the dynamic message, which can resolve the factory's extensions
	if err := dm.ConvertFrom(msg); err != nil {
		return nil, err
	}

	// recursively convert all field values, too
	for _, fd := range md.GetFields() {
		if fd.IsMap() {
			if fd.GetMapValueType().GetMessageType() != nil {
				m := dm.GetField(fd).(map[interface{}]interface{})
				for k, v := range m {
					// keys can't be nested messages; so we only need to recurse through map values, not keys
					newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
					if err != nil {
						return nil, err
					}
					dm.PutMapField(fd, k, newVal)
				}
			}
		} else if fd.IsRepeated() {
			if fd.GetMessageType() != nil {
				s := dm.GetField(fd).([]interface{})
				for i, e := range s {
					newVal, err := fullyConvertToDynamic(msgFact, e.(proto.Message))
					if err != nil {
						return nil, err
					}
					dm.SetRepeatedField(fd, i, newVal)
				}
			}
		} else {
			// singular field: recurse only if it holds a nested message
			if fd.GetMessageType() != nil {
				v := dm.GetField(fd)
				newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
				if err != nil {
					return nil, err
				}
				dm.SetField(fd, newVal)
			}
		}
	}
	return dm, nil
}

// MakeTemplate returns a message instance for the given descriptor that is a
// suitable template for creating an instance of that message in JSON. In
// particular, it ensures that any repeated fields (which include map fields)
// are not empty, so they will render with a single element (to show the types
// and optionally nested fields). It also ensures that nested messages are not
// nil by setting them to a message that is also fleshed out as a template
// message.
func MakeTemplate(md *desc.MessageDescriptor) proto.Message {
	return makeTemplate(md, nil)
}

// makeTemplate is the recursive worker behind MakeTemplate; path tracks the
// message types already being fleshed out on the current recursion branch so
// that recursive message definitions terminate.
func makeTemplate(md *desc.MessageDescriptor, path []*desc.MessageDescriptor) proto.Message {
	// well-known types get hand-crafted placeholders because their natural
	// zero values are not valid (or not helpful) in JSON form
	switch md.GetFullyQualifiedName() {
	case "google.protobuf.Any":
		// empty type URL is not allowed by JSON representation
		// so we must give it a dummy type
		var any anypb.Any
		_ = anypb.MarshalFrom(&any, &emptypb.Empty{}, protov2.MarshalOptions{})
		return &any
	case "google.protobuf.Value":
		// unset kind is not allowed by JSON representation
		// so we must give it something
		return &structpb.Value{
			Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
				Fields: map[string]*structpb.Value{
					"google.protobuf.Value": {Kind: &structpb.Value_StringValue{
						StringValue: "supports arbitrary JSON",
					}},
				},
			}},
		}
	case "google.protobuf.ListValue":
		return &structpb.ListValue{
			Values: []*structpb.Value{
				{
					Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
						Fields: map[string]*structpb.Value{
							"google.protobuf.ListValue": {Kind: &structpb.Value_StringValue{
								StringValue: "is an array of arbitrary JSON values",
							}},
						},
					}},
				},
			},
		}
	case "google.protobuf.Struct":
		return &structpb.Struct{
			Fields: map[string]*structpb.Value{
				"google.protobuf.Struct": {Kind: &structpb.Value_StringValue{
					StringValue: "supports arbitrary JSON objects",
				}},
			},
		}
	}

	dm := dynamic.NewMessage(md)

	// if the message is a recursive structure, we don't want to blow the stack
	for _, seen := range path {
		if seen == md {
			// already visited this type; avoid infinite recursion
			return dm
		}
	}
	path = append(path, dm.GetMessageDescriptor())

	// for repeated fields, add a single element with default value
	// and for message fields, add a message with all default fields
	// that also has non-nil message and non-empty repeated fields

	for _, fd := range dm.GetMessageDescriptor().GetFields() {
		if fd.IsRepeated() {
			// add one zero-valued element of the appropriate Go type so the
			// field renders as a non-empty array in JSON
			switch fd.GetType() {
			case descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
				descriptorpb.FieldDescriptorProto_TYPE_UINT32:
				dm.AddRepeatedField(fd, uint32(0))

			case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
				descriptorpb.FieldDescriptorProto_TYPE_SINT32,
				descriptorpb.FieldDescriptorProto_TYPE_INT32,
				descriptorpb.FieldDescriptorProto_TYPE_ENUM:
				dm.AddRepeatedField(fd, int32(0))

			case descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
				descriptorpb.FieldDescriptorProto_TYPE_UINT64:
				dm.AddRepeatedField(fd, uint64(0))

			case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
				descriptorpb.FieldDescriptorProto_TYPE_SINT64,
				descriptorpb.FieldDescriptorProto_TYPE_INT64:
				dm.AddRepeatedField(fd, int64(0))

			case descriptorpb.FieldDescriptorProto_TYPE_STRING:
				dm.AddRepeatedField(fd, "")

			case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
				dm.AddRepeatedField(fd, []byte{})

			case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
				dm.AddRepeatedField(fd, false)

			case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
				dm.AddRepeatedField(fd, float32(0))

			case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
				dm.AddRepeatedField(fd, float64(0))

			case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
				descriptorpb.FieldDescriptorProto_TYPE_GROUP:
				dm.AddRepeatedField(fd, makeTemplate(fd.GetMessageType(), path))
			}
		} else if fd.GetMessageType() != nil {
			dm.SetField(fd, makeTemplate(fd.GetMessageType(), path))
		}
	}
	return dm
}

// ClientTransportCredentials is a helper function that constructs a TLS config with
// the given properties (see ClientTLSConfig) and then constructs and returns gRPC
// transport credentials using that config.
//
// Deprecated: Use grpcurl.ClientTLSConfig and credentials.NewTLS instead.
func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (credentials.TransportCredentials, error) {
	tlsConf, err := ClientTLSConfig(insecureSkipVerify, cacertFile, clientCertFile, clientKeyFile)
	if err != nil {
		return nil, err
	}

	return credentials.NewTLS(tlsConf), nil
}

// ClientTLSConfig builds transport-layer config for a gRPC client using the
// given properties. If cacertFile is blank, only standard trusted certs are used to
// verify the server certs. If clientCertFile is blank, the client will not use a client
// certificate. If clientCertFile is not blank then clientKeyFile must not be blank.
func ClientTLSConfig(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (*tls.Config, error) {
	var tlsConf tls.Config

	if clientCertFile != "" {
		// Load the client certificates from disk
		certificate, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile)
		if err != nil {
			return nil, fmt.Errorf("could not load client key pair: %v", err)
		}
		tlsConf.Certificates = []tls.Certificate{certificate}
	}

	if insecureSkipVerify {
		// skipping verification takes precedence over any provided CA file
		tlsConf.InsecureSkipVerify = true
	} else if cacertFile != "" {
		// Create a certificate pool from the certificate authority
		certPool := x509.NewCertPool()
		ca, err := ioutil.ReadFile(cacertFile)
		if err != nil {
			return nil, fmt.Errorf("could not read ca certificate: %v", err)
		}

		// Append the certificates from the CA
		if ok := certPool.AppendCertsFromPEM(ca); !ok {
			return nil, errors.New("failed to append ca certs")
		}

		tlsConf.RootCAs = certPool
	}

	return &tlsConf, nil
}

// ServerTransportCredentials builds transport credentials for a gRPC server using the
// given properties. If cacertFile is blank, the server will not request client certs
// unless requireClientCerts is true. When requireClientCerts is false and cacertFile is
// not blank, the server will verify client certs when presented, but will not require
// client certs. The serverCertFile and serverKeyFile must both not be blank.
func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string, requireClientCerts bool) (credentials.TransportCredentials, error) {
	var tlsConf tls.Config
	// TODO(jh): Remove this line once https://github.com/golang/go/issues/28779 is fixed
	// in Go tip. Until then, the recently merged TLS 1.3 support breaks the TLS tests.
	// NOTE(review): that Go issue has since been closed; this TLS 1.2 cap looks
	// removable, but confirm against the project's TLS tests before changing it.
	tlsConf.MaxVersion = tls.VersionTLS12

	// Load the server certificates from disk
	certificate, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile)
	if err != nil {
		return nil, fmt.Errorf("could not load key pair: %v", err)
	}
	tlsConf.Certificates = []tls.Certificate{certificate}

	if cacertFile != "" {
		// Create a certificate pool from the certificate authority
		certPool := x509.NewCertPool()
		ca, err := ioutil.ReadFile(cacertFile)
		if err != nil {
			return nil, fmt.Errorf("could not read ca certificate: %v", err)
		}

		// Append the certificates from the CA
		if ok := certPool.AppendCertsFromPEM(ca); !ok {
			return nil, errors.New("failed to append ca certs")
		}

		tlsConf.ClientCAs = certPool
	}

	if requireClientCerts {
		tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
	} else if cacertFile != "" {
		tlsConf.ClientAuth = tls.VerifyClientCertIfGiven
	} else {
		tlsConf.ClientAuth = tls.NoClientCert
	}

	return credentials.NewTLS(&tlsConf), nil
}

// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
// and blocking until the returned connection is ready. If the given credentials are nil, the
// connection will be insecure (plain-text).
func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
	// grpc.Dial doesn't provide any information on permanent connection errors (like
	// TLS handshake failures). So in order to provide good error messages, we need a
	// custom dialer that can provide that info. That means we manage the TLS handshake.

	// result carries either a *grpc.ClientConn (success) or an error; buffered
	// so the winning writer never blocks
	result := make(chan interface{}, 1)

	writeResult := func(res interface{}) {
		// non-blocking write: we only need the first result
		select {
		case result <- res:
		default:
		}
	}

	// custom credentials and dialer will notify on error via the
	// writeResult function
	if creds != nil {
		creds = &errSignalingCreds{
			TransportCredentials: creds,
			writeResult:          writeResult,
		}
	}
	dialer := func(ctx context.Context, address string) (net.Conn, error) {
		// NB: We *could* handle the TLS handshake ourselves, in the custom
		// dialer (instead of customizing both the dialer and the credentials).
		// But that requires using WithInsecure dial option (so that the gRPC
		// library doesn't *also* try to do a handshake). And that would mean
		// that the library would send the wrong ":scheme" metaheader to
		// servers: it would send "http" instead of "https" because it is
		// unaware that TLS is actually in use.
		conn, err := (&net.Dialer{}).DialContext(ctx, network, address)
		if err != nil {
			writeResult(err)
		}
		return conn, err
	}

	// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
	// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
	// know when we're done. So we run it in a goroutine and then use result
	// channel to either get the connection or fail-fast.
	go func() {
		// We put grpc.FailOnNonTempDialError *before* the explicitly provided
		// options so that it could be overridden.
		opts = append([]grpc.DialOption{grpc.FailOnNonTempDialError(true)}, opts...)
		// But we don't want caller to be able to override these two, so we put
		// them *after* the explicitly provided options.
		opts = append(opts, grpc.WithBlock(), grpc.WithContextDialer(dialer))

		if creds == nil {
			opts = append(opts, grpc.WithInsecure())
		} else {
			opts = append(opts, grpc.WithTransportCredentials(creds))
		}
		conn, err := grpc.DialContext(ctx, address, opts...)
		var res interface{}
		if err != nil {
			res = err
		} else {
			res = conn
		}
		writeResult(res)
	}()

	// wait for the first outcome: a connection, a dial/handshake error, or
	// cancellation of the caller's context
	select {
	case res := <-result:
		if conn, ok := res.(*grpc.ClientConn); ok {
			return conn, nil
		}
		return nil, res.(error)
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

// errSignalingCreds is a wrapper around a TransportCredentials value, but
// it will use the writeResult function to notify on error.
type errSignalingCreds struct {
	credentials.TransportCredentials
	writeResult func(res interface{})
}

// ClientHandshake delegates to the wrapped credentials and reports any
// handshake error via writeResult so BlockingDial can fail fast.
func (c *errSignalingCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, addr, rawConn)
	if err != nil {
		c.writeResult(err)
	}
	return conn, auth, err
}