1// Package grpcurl provides the core functionality exposed by the grpcurl command, for
2// dynamically connecting to a server, using the reflection service to inspect the server,
// and invoking RPCs. The grpcurl command-line tool constructs a DescriptorSource, based
// on the command-line parameters, and supplies an InvocationEventHandler that provides request
// data (which can come from command-line args or the process's stdin) and logs the
// events (to the process's stdout).
7package grpcurl
8
9import (
10	"bytes"
11	"context"
12	"crypto/tls"
13	"crypto/x509"
14	"encoding/base64"
15	"errors"
16	"fmt"
17	"io/ioutil"
18	"net"
19	"os"
20	"regexp"
21	"sort"
22	"strings"
23
24	"github.com/golang/protobuf/proto" //lint:ignore SA1019 we have to import this because it appears in exported API
25	"github.com/jhump/protoreflect/desc"
26	"github.com/jhump/protoreflect/desc/protoprint"
27	"github.com/jhump/protoreflect/dynamic"
28	"google.golang.org/grpc"
29	"google.golang.org/grpc/credentials"
30	"google.golang.org/grpc/metadata"
31	protov2 "google.golang.org/protobuf/proto"
32	"google.golang.org/protobuf/types/descriptorpb"
33	"google.golang.org/protobuf/types/known/anypb"
34	"google.golang.org/protobuf/types/known/emptypb"
35	"google.golang.org/protobuf/types/known/structpb"
36)
37
38// ListServices uses the given descriptor source to return a sorted list of fully-qualified
39// service names.
40func ListServices(source DescriptorSource) ([]string, error) {
41	svcs, err := source.ListServices()
42	if err != nil {
43		return nil, err
44	}
45	sort.Strings(svcs)
46	return svcs, nil
47}
48
49type sourceWithFiles interface {
50	GetAllFiles() ([]*desc.FileDescriptor, error)
51}
52
53var _ sourceWithFiles = (*fileSource)(nil)
54
55// GetAllFiles uses the given descriptor source to return a list of file descriptors.
56func GetAllFiles(source DescriptorSource) ([]*desc.FileDescriptor, error) {
57	var files []*desc.FileDescriptor
58	srcFiles, ok := source.(sourceWithFiles)
59
60	// If an error occurs, we still try to load as many files as we can, so that
61	// caller can decide whether to ignore error or not.
62	var firstError error
63	if ok {
64		files, firstError = srcFiles.GetAllFiles()
65	} else {
66		// Source does not implement GetAllFiles method, so use ListServices
67		// and grab files from there.
68		svcNames, err := source.ListServices()
69		if err != nil {
70			firstError = err
71		} else {
72			allFiles := map[string]*desc.FileDescriptor{}
73			for _, name := range svcNames {
74				d, err := source.FindSymbol(name)
75				if err != nil {
76					if firstError == nil {
77						firstError = err
78					}
79				} else {
80					addAllFilesToSet(d.GetFile(), allFiles)
81				}
82			}
83			files = make([]*desc.FileDescriptor, len(allFiles))
84			i := 0
85			for _, fd := range allFiles {
86				files[i] = fd
87				i++
88			}
89		}
90	}
91
92	sort.Sort(filesByName(files))
93	return files, firstError
94}
95
96type filesByName []*desc.FileDescriptor
97
98func (f filesByName) Len() int {
99	return len(f)
100}
101
102func (f filesByName) Less(i, j int) bool {
103	return f[i].GetName() < f[j].GetName()
104}
105
106func (f filesByName) Swap(i, j int) {
107	f[i], f[j] = f[j], f[i]
108}
109
110func addAllFilesToSet(fd *desc.FileDescriptor, all map[string]*desc.FileDescriptor) {
111	if _, ok := all[fd.GetName()]; ok {
112		// already added
113		return
114	}
115	all[fd.GetName()] = fd
116	for _, dep := range fd.GetDependencies() {
117		addAllFilesToSet(dep, all)
118	}
119}
120
121// ListMethods uses the given descriptor source to return a sorted list of method names
122// for the specified fully-qualified service name.
123func ListMethods(source DescriptorSource, serviceName string) ([]string, error) {
124	dsc, err := source.FindSymbol(serviceName)
125	if err != nil {
126		return nil, err
127	}
128	if sd, ok := dsc.(*desc.ServiceDescriptor); !ok {
129		return nil, notFound("Service", serviceName)
130	} else {
131		methods := make([]string, 0, len(sd.GetMethods()))
132		for _, method := range sd.GetMethods() {
133			methods = append(methods, method.GetFullyQualifiedName())
134		}
135		sort.Strings(methods)
136		return methods, nil
137	}
138}
139
140// MetadataFromHeaders converts a list of header strings (each string in
141// "Header-Name: Header-Value" form) into metadata. If a string has a header
142// name without a value (e.g. does not contain a colon), the value is assumed
143// to be blank. Binary headers (those whose names end in "-bin") should be
144// base64-encoded. But if they cannot be base64-decoded, they will be assumed to
145// be in raw form and used as is.
146func MetadataFromHeaders(headers []string) metadata.MD {
147	md := make(metadata.MD)
148	for _, part := range headers {
149		if part != "" {
150			pieces := strings.SplitN(part, ":", 2)
151			if len(pieces) == 1 {
152				pieces = append(pieces, "") // if no value was specified, just make it "" (maybe the header value doesn't matter)
153			}
154			headerName := strings.ToLower(strings.TrimSpace(pieces[0]))
155			val := strings.TrimSpace(pieces[1])
156			if strings.HasSuffix(headerName, "-bin") {
157				if v, err := decode(val); err == nil {
158					val = v
159				}
160			}
161			md[headerName] = append(md[headerName], val)
162		}
163	}
164	return md
165}
166
var envVarRegex = regexp.MustCompile(`\${\w+}`)

// ExpandHeaders expands environment variables contained in the header string.
// If no corresponding environment variable is found an error is returned.
//
// References have the form ${NAME}, where NAME consists of word characters
// ([0-9A-Za-z_]). Each header is expanded in a single pass: text substituted
// from an environment variable is never itself scanned for references, so a
// value that contains "${...}" is inserted literally. Empty headers stay empty
// strings in the result, which on success has the same length as headers.
// TODO: Add escaping for `${`
func ExpandHeaders(headers []string) ([]string, error) {
	expandedHeaders := make([]string, len(headers))
	for idx, header := range headers {
		if header == "" {
			continue
		}
		// Name of the first referenced variable that is not set, if any.
		// \w+ guarantees names are non-empty, so "" means "none missing".
		var missing string
		expanded := envVarRegex.ReplaceAllStringFunc(header, func(ref string) string {
			envVarName := ref[2 : len(ref)-1] // strip leading `${` and trailing `}`
			envVarValue, ok := os.LookupEnv(envVarName)
			if !ok {
				if missing == "" {
					missing = envVarName
				}
				return ref
			}
			return envVarValue
		})
		if missing != "" {
			return nil, fmt.Errorf("header %q refers to missing environment variable %q", header, missing)
		}
		expandedHeaders[idx] = expanded
	}
	return expandedHeaders, nil
}
196
// base64Codecs lists the base64 flavors accepted for binary header values,
// in the order they are tried.
var base64Codecs = []*base64.Encoding{
	base64.StdEncoding,
	base64.URLEncoding,
	base64.RawStdEncoding,
	base64.RawURLEncoding,
}

// decode leniently base64-decodes val, returning the result of the first codec
// in base64Codecs that accepts it. If none does, it returns the error produced
// by the first codec.
func decode(val string) (string, error) {
	var firstErr error
	for i := range base64Codecs {
		decoded, err := base64Codecs[i].DecodeString(val)
		if err == nil {
			return string(decoded), nil
		}
		if i == 0 {
			firstErr = err
		}
	}
	return "", firstErr
}
216
217// MetadataToString returns a string representation of the given metadata, for
218// displaying to users.
219func MetadataToString(md metadata.MD) string {
220	if len(md) == 0 {
221		return "(empty)"
222	}
223
224	keys := make([]string, 0, len(md))
225	for k := range md {
226		keys = append(keys, k)
227	}
228	sort.Strings(keys)
229
230	var b bytes.Buffer
231	first := true
232	for _, k := range keys {
233		vs := md[k]
234		for _, v := range vs {
235			if first {
236				first = false
237			} else {
238				b.WriteString("\n")
239			}
240			b.WriteString(k)
241			b.WriteString(": ")
242			if strings.HasSuffix(k, "-bin") {
243				v = base64.StdEncoding.EncodeToString([]byte(v))
244			}
245			b.WriteString(v)
246		}
247	}
248	return b.String()
249}
250
251var printer = &protoprint.Printer{
252	Compact:                  true,
253	OmitComments:             protoprint.CommentsNonDoc,
254	SortElements:             true,
255	ForceFullyQualifiedNames: true,
256}
257
258// GetDescriptorText returns a string representation of the given descriptor.
259// This returns a snippet of proto source that describes the given element.
260func GetDescriptorText(dsc desc.Descriptor, _ DescriptorSource) (string, error) {
261	// Note: DescriptorSource is not used, but remains an argument for backwards
262	// compatibility with previous implementation.
263	txt, err := printer.PrintProtoToString(dsc)
264	if err != nil {
265		return "", err
266	}
267	// callers don't expect trailing newlines
268	if txt[len(txt)-1] == '\n' {
269		txt = txt[:len(txt)-1]
270	}
271	return txt, nil
272}
273
274// EnsureExtensions uses the given descriptor source to download extensions for
275// the given message. It returns a copy of the given message, but as a dynamic
276// message that knows about all extensions known to the given descriptor source.
277func EnsureExtensions(source DescriptorSource, msg proto.Message) proto.Message {
278	// load any server extensions so we can properly describe custom options
279	dsc, err := desc.LoadMessageDescriptorForMessage(msg)
280	if err != nil {
281		return msg
282	}
283
284	var ext dynamic.ExtensionRegistry
285	if err = fetchAllExtensions(source, &ext, dsc, map[string]bool{}); err != nil {
286		return msg
287	}
288
289	// convert message into dynamic message that knows about applicable extensions
290	// (that way we can show meaningful info for custom options instead of printing as unknown)
291	msgFactory := dynamic.NewMessageFactoryWithExtensionRegistry(&ext)
292	dm, err := fullyConvertToDynamic(msgFactory, msg)
293	if err != nil {
294		return msg
295	}
296	return dm
297}
298
299// fetchAllExtensions recursively fetches from the server extensions for the given message type as well as
300// for all message types of nested fields. The extensions are added to the given dynamic registry of extensions
301// so that all server-known extensions can be correctly parsed by grpcurl.
302func fetchAllExtensions(source DescriptorSource, ext *dynamic.ExtensionRegistry, md *desc.MessageDescriptor, alreadyFetched map[string]bool) error {
303	msgTypeName := md.GetFullyQualifiedName()
304	if alreadyFetched[msgTypeName] {
305		return nil
306	}
307	alreadyFetched[msgTypeName] = true
308	if len(md.GetExtensionRanges()) > 0 {
309		fds, err := source.AllExtensionsForType(msgTypeName)
310		if err != nil {
311			return fmt.Errorf("failed to query for extensions of type %s: %v", msgTypeName, err)
312		}
313		for _, fd := range fds {
314			if err := ext.AddExtension(fd); err != nil {
315				return fmt.Errorf("could not register extension %s of type %s: %v", fd.GetFullyQualifiedName(), msgTypeName, err)
316			}
317		}
318	}
319	// recursively fetch extensions for the types of any message fields
320	for _, fd := range md.GetFields() {
321		if fd.GetMessageType() != nil {
322			err := fetchAllExtensions(source, ext, fd.GetMessageType(), alreadyFetched)
323			if err != nil {
324				return err
325			}
326		}
327	}
328	return nil
329}
330
331// fullConvertToDynamic attempts to convert the given message to a dynamic message as well
332// as any nested messages it may contain as field values. If the given message factory has
333// extensions registered that were not known when the given message was parsed, this effectively
334// allows re-parsing to identify those extensions.
335func fullyConvertToDynamic(msgFact *dynamic.MessageFactory, msg proto.Message) (proto.Message, error) {
336	if _, ok := msg.(*dynamic.Message); ok {
337		return msg, nil // already a dynamic message
338	}
339	md, err := desc.LoadMessageDescriptorForMessage(msg)
340	if err != nil {
341		return nil, err
342	}
343	newMsg := msgFact.NewMessage(md)
344	dm, ok := newMsg.(*dynamic.Message)
345	if !ok {
346		// if message factory didn't produce a dynamic message, then we should leave msg as is
347		return msg, nil
348	}
349
350	if err := dm.ConvertFrom(msg); err != nil {
351		return nil, err
352	}
353
354	// recursively convert all field values, too
355	for _, fd := range md.GetFields() {
356		if fd.IsMap() {
357			if fd.GetMapValueType().GetMessageType() != nil {
358				m := dm.GetField(fd).(map[interface{}]interface{})
359				for k, v := range m {
360					// keys can't be nested messages; so we only need to recurse through map values, not keys
361					newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
362					if err != nil {
363						return nil, err
364					}
365					dm.PutMapField(fd, k, newVal)
366				}
367			}
368		} else if fd.IsRepeated() {
369			if fd.GetMessageType() != nil {
370				s := dm.GetField(fd).([]interface{})
371				for i, e := range s {
372					newVal, err := fullyConvertToDynamic(msgFact, e.(proto.Message))
373					if err != nil {
374						return nil, err
375					}
376					dm.SetRepeatedField(fd, i, newVal)
377				}
378			}
379		} else {
380			if fd.GetMessageType() != nil {
381				v := dm.GetField(fd)
382				newVal, err := fullyConvertToDynamic(msgFact, v.(proto.Message))
383				if err != nil {
384					return nil, err
385				}
386				dm.SetField(fd, newVal)
387			}
388		}
389	}
390	return dm, nil
391}
392
393// MakeTemplate returns a message instance for the given descriptor that is a
394// suitable template for creating an instance of that message in JSON. In
395// particular, it ensures that any repeated fields (which include map fields)
396// are not empty, so they will render with a single element (to show the types
397// and optionally nested fields). It also ensures that nested messages are not
398// nil by setting them to a message that is also fleshed out as a template
399// message.
400func MakeTemplate(md *desc.MessageDescriptor) proto.Message {
401	return makeTemplate(md, nil)
402}
403
404func makeTemplate(md *desc.MessageDescriptor, path []*desc.MessageDescriptor) proto.Message {
405	switch md.GetFullyQualifiedName() {
406	case "google.protobuf.Any":
407		// empty type URL is not allowed by JSON representation
408		// so we must give it a dummy type
409		var any anypb.Any
410		_ = anypb.MarshalFrom(&any, &emptypb.Empty{}, protov2.MarshalOptions{})
411		return &any
412	case "google.protobuf.Value":
413		// unset kind is not allowed by JSON representation
414		// so we must give it something
415		return &structpb.Value{
416			Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
417				Fields: map[string]*structpb.Value{
418					"google.protobuf.Value": {Kind: &structpb.Value_StringValue{
419						StringValue: "supports arbitrary JSON",
420					}},
421				},
422			}},
423		}
424	case "google.protobuf.ListValue":
425		return &structpb.ListValue{
426			Values: []*structpb.Value{
427				{
428					Kind: &structpb.Value_StructValue{StructValue: &structpb.Struct{
429						Fields: map[string]*structpb.Value{
430							"google.protobuf.ListValue": {Kind: &structpb.Value_StringValue{
431								StringValue: "is an array of arbitrary JSON values",
432							}},
433						},
434					}},
435				},
436			},
437		}
438	case "google.protobuf.Struct":
439		return &structpb.Struct{
440			Fields: map[string]*structpb.Value{
441				"google.protobuf.Struct": {Kind: &structpb.Value_StringValue{
442					StringValue: "supports arbitrary JSON objects",
443				}},
444			},
445		}
446	}
447
448	dm := dynamic.NewMessage(md)
449
450	// if the message is a recursive structure, we don't want to blow the stack
451	for _, seen := range path {
452		if seen == md {
453			// already visited this type; avoid infinite recursion
454			return dm
455		}
456	}
457	path = append(path, dm.GetMessageDescriptor())
458
459	// for repeated fields, add a single element with default value
460	// and for message fields, add a message with all default fields
461	// that also has non-nil message and non-empty repeated fields
462
463	for _, fd := range dm.GetMessageDescriptor().GetFields() {
464		if fd.IsRepeated() {
465			switch fd.GetType() {
466			case descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
467				descriptorpb.FieldDescriptorProto_TYPE_UINT32:
468				dm.AddRepeatedField(fd, uint32(0))
469
470			case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
471				descriptorpb.FieldDescriptorProto_TYPE_SINT32,
472				descriptorpb.FieldDescriptorProto_TYPE_INT32,
473				descriptorpb.FieldDescriptorProto_TYPE_ENUM:
474				dm.AddRepeatedField(fd, int32(0))
475
476			case descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
477				descriptorpb.FieldDescriptorProto_TYPE_UINT64:
478				dm.AddRepeatedField(fd, uint64(0))
479
480			case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
481				descriptorpb.FieldDescriptorProto_TYPE_SINT64,
482				descriptorpb.FieldDescriptorProto_TYPE_INT64:
483				dm.AddRepeatedField(fd, int64(0))
484
485			case descriptorpb.FieldDescriptorProto_TYPE_STRING:
486				dm.AddRepeatedField(fd, "")
487
488			case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
489				dm.AddRepeatedField(fd, []byte{})
490
491			case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
492				dm.AddRepeatedField(fd, false)
493
494			case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
495				dm.AddRepeatedField(fd, float32(0))
496
497			case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
498				dm.AddRepeatedField(fd, float64(0))
499
500			case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
501				descriptorpb.FieldDescriptorProto_TYPE_GROUP:
502				dm.AddRepeatedField(fd, makeTemplate(fd.GetMessageType(), path))
503			}
504		} else if fd.GetMessageType() != nil {
505			dm.SetField(fd, makeTemplate(fd.GetMessageType(), path))
506		}
507	}
508	return dm
509}
510
511// ClientTransportCredentials is a helper function that constructs a TLS config with
512// the given properties (see ClientTLSConfig) and then constructs and returns gRPC
513// transport credentials using that config.
514//
515// Deprecated: Use grpcurl.ClientTLSConfig and credentials.NewTLS instead.
516func ClientTransportCredentials(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (credentials.TransportCredentials, error) {
517	tlsConf, err := ClientTLSConfig(insecureSkipVerify, cacertFile, clientCertFile, clientKeyFile)
518	if err != nil {
519		return nil, err
520	}
521
522	return credentials.NewTLS(tlsConf), nil
523}
524
// ClientTLSConfig builds transport-layer config for a gRPC client using the
// given properties. If cacertFile is blank, only standard trusted certs are used to
// verify the server certs. If clientCertFile is blank, the client will not use a client
// certificate. If clientCertFile is not blank then clientKeyFile must not be blank.
//
// When insecureSkipVerify is true, server verification is disabled and
// cacertFile is ignored (it is not even read).
func ClientTLSConfig(insecureSkipVerify bool, cacertFile, clientCertFile, clientKeyFile string) (*tls.Config, error) {
	cfg := &tls.Config{}

	// Present a client certificate only when one was configured.
	if clientCertFile != "" {
		keyPair, err := tls.LoadX509KeyPair(clientCertFile, clientKeyFile)
		if err != nil {
			return nil, fmt.Errorf("could not load client key pair: %v", err)
		}
		cfg.Certificates = append(cfg.Certificates, keyPair)
	}

	switch {
	case insecureSkipVerify:
		cfg.InsecureSkipVerify = true
	case cacertFile != "":
		// Trust exactly the CAs from the given PEM file.
		pemBytes, err := ioutil.ReadFile(cacertFile)
		if err != nil {
			return nil, fmt.Errorf("could not read ca certificate: %v", err)
		}
		roots := x509.NewCertPool()
		if !roots.AppendCertsFromPEM(pemBytes) {
			return nil, errors.New("failed to append ca certs")
		}
		cfg.RootCAs = roots
	}

	return cfg, nil
}
561
562// ServerTransportCredentials builds transport credentials for a gRPC server using the
563// given properties. If cacertFile is blank, the server will not request client certs
564// unless requireClientCerts is true. When requireClientCerts is false and cacertFile is
565// not blank, the server will verify client certs when presented, but will not require
566// client certs. The serverCertFile and serverKeyFile must both not be blank.
567func ServerTransportCredentials(cacertFile, serverCertFile, serverKeyFile string, requireClientCerts bool) (credentials.TransportCredentials, error) {
568	var tlsConf tls.Config
569	// TODO(jh): Remove this line once https://github.com/golang/go/issues/28779 is fixed
570	// in Go tip. Until then, the recently merged TLS 1.3 support breaks the TLS tests.
571	tlsConf.MaxVersion = tls.VersionTLS12
572
573	// Load the server certificates from disk
574	certificate, err := tls.LoadX509KeyPair(serverCertFile, serverKeyFile)
575	if err != nil {
576		return nil, fmt.Errorf("could not load key pair: %v", err)
577	}
578	tlsConf.Certificates = []tls.Certificate{certificate}
579
580	if cacertFile != "" {
581		// Create a certificate pool from the certificate authority
582		certPool := x509.NewCertPool()
583		ca, err := ioutil.ReadFile(cacertFile)
584		if err != nil {
585			return nil, fmt.Errorf("could not read ca certificate: %v", err)
586		}
587
588		// Append the certificates from the CA
589		if ok := certPool.AppendCertsFromPEM(ca); !ok {
590			return nil, errors.New("failed to append ca certs")
591		}
592
593		tlsConf.ClientCAs = certPool
594	}
595
596	if requireClientCerts {
597		tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
598	} else if cacertFile != "" {
599		tlsConf.ClientAuth = tls.VerifyClientCertIfGiven
600	} else {
601		tlsConf.ClientAuth = tls.NoClientCert
602	}
603
604	return credentials.NewTLS(&tlsConf), nil
605}
606
607// BlockingDial is a helper method to dial the given address, using optional TLS credentials,
608// and blocking until the returned connection is ready. If the given credentials are nil, the
609// connection will be insecure (plain-text).
610func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
611	// grpc.Dial doesn't provide any information on permanent connection errors (like
612	// TLS handshake failures). So in order to provide good error messages, we need a
613	// custom dialer that can provide that info. That means we manage the TLS handshake.
614	result := make(chan interface{}, 1)
615
616	writeResult := func(res interface{}) {
617		// non-blocking write: we only need the first result
618		select {
619		case result <- res:
620		default:
621		}
622	}
623
624	// custom credentials and dialer will notify on error via the
625	// writeResult function
626	if creds != nil {
627		creds = &errSignalingCreds{
628			TransportCredentials: creds,
629			writeResult:          writeResult,
630		}
631	}
632	dialer := func(ctx context.Context, address string) (net.Conn, error) {
633		// NB: We *could* handle the TLS handshake ourselves, in the custom
634		// dialer (instead of customizing both the dialer and the credentials).
635		// But that requires using WithInsecure dial option (so that the gRPC
636		// library doesn't *also* try to do a handshake). And that would mean
637		// that the library would send the wrong ":scheme" metaheader to
638		// servers: it would send "http" instead of "https" because it is
639		// unaware that TLS is actually in use.
640		conn, err := (&net.Dialer{}).DialContext(ctx, network, address)
641		if err != nil {
642			writeResult(err)
643		}
644		return conn, err
645	}
646
647	// Even with grpc.FailOnNonTempDialError, this call will usually timeout in
648	// the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to
649	// know when we're done. So we run it in a goroutine and then use result
650	// channel to either get the connection or fail-fast.
651	go func() {
652		// We put grpc.FailOnNonTempDialError *before* the explicitly provided
653		// options so that it could be overridden.
654		opts = append([]grpc.DialOption{grpc.FailOnNonTempDialError(true)}, opts...)
655		// But we don't want caller to be able to override these two, so we put
656		// them *after* the explicitly provided options.
657		opts = append(opts, grpc.WithBlock(), grpc.WithContextDialer(dialer))
658
659		if creds == nil {
660			opts = append(opts, grpc.WithInsecure())
661		} else {
662			opts = append(opts, grpc.WithTransportCredentials(creds))
663		}
664		conn, err := grpc.DialContext(ctx, address, opts...)
665		var res interface{}
666		if err != nil {
667			res = err
668		} else {
669			res = conn
670		}
671		writeResult(res)
672	}()
673
674	select {
675	case res := <-result:
676		if conn, ok := res.(*grpc.ClientConn); ok {
677			return conn, nil
678		}
679		return nil, res.(error)
680	case <-ctx.Done():
681		return nil, ctx.Err()
682	}
683}
684
685// errSignalingCreds is a wrapper around a TransportCredentials value, but
686// it will use the writeResult function to notify on error.
687type errSignalingCreds struct {
688	credentials.TransportCredentials
689	writeResult func(res interface{})
690}
691
692func (c *errSignalingCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
693	conn, auth, err := c.TransportCredentials.ClientHandshake(ctx, addr, rawConn)
694	if err != nil {
695		c.writeResult(err)
696	}
697	return conn, auth, err
698}
699