1/*
2 *
3 * Copyright 2014 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package grpc
20
21import (
22	"context"
23	"errors"
24	"fmt"
25	"math"
26	"net"
27	"strings"
28	"sync/atomic"
29	"testing"
30	"time"
31
32	"golang.org/x/net/http2"
33	"google.golang.org/grpc/backoff"
34	"google.golang.org/grpc/connectivity"
35	"google.golang.org/grpc/credentials"
36	internalbackoff "google.golang.org/grpc/internal/backoff"
37	"google.golang.org/grpc/internal/transport"
38	"google.golang.org/grpc/keepalive"
39	"google.golang.org/grpc/resolver"
40	"google.golang.org/grpc/resolver/manual"
41	"google.golang.org/grpc/testdata"
42)
43
44func (s) TestDialWithTimeout(t *testing.T) {
45	lis, err := net.Listen("tcp", "localhost:0")
46	if err != nil {
47		t.Fatalf("Error while listening. Err: %v", err)
48	}
49	defer lis.Close()
50	lisAddr := resolver.Address{Addr: lis.Addr().String()}
51	lisDone := make(chan struct{})
52	dialDone := make(chan struct{})
53	// 1st listener accepts the connection and then does nothing
54	go func() {
55		defer close(lisDone)
56		conn, err := lis.Accept()
57		if err != nil {
58			t.Errorf("Error while accepting. Err: %v", err)
59			return
60		}
61		framer := http2.NewFramer(conn, conn)
62		if err := framer.WriteSettings(http2.Setting{}); err != nil {
63			t.Errorf("Error while writing settings. Err: %v", err)
64			return
65		}
66		<-dialDone // Close conn only after dial returns.
67	}()
68
69	r := manual.NewBuilderWithScheme("whatever")
70	r.InitialState(resolver.State{Addresses: []resolver.Address{lisAddr}})
71	client, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithResolvers(r), WithTimeout(5*time.Second))
72	close(dialDone)
73	if err != nil {
74		t.Fatalf("Dial failed. Err: %v", err)
75	}
76	defer client.Close()
77	timeout := time.After(1 * time.Second)
78	select {
79	case <-timeout:
80		t.Fatal("timed out waiting for server to finish")
81	case <-lisDone:
82	}
83}
84
85func (s) TestDialWithMultipleBackendsNotSendingServerPreface(t *testing.T) {
86	lis1, err := net.Listen("tcp", "localhost:0")
87	if err != nil {
88		t.Fatalf("Error while listening. Err: %v", err)
89	}
90	defer lis1.Close()
91	lis1Addr := resolver.Address{Addr: lis1.Addr().String()}
92	lis1Done := make(chan struct{})
93	// 1st listener accepts the connection and immediately closes it.
94	go func() {
95		defer close(lis1Done)
96		conn, err := lis1.Accept()
97		if err != nil {
98			t.Errorf("Error while accepting. Err: %v", err)
99			return
100		}
101		conn.Close()
102	}()
103
104	lis2, err := net.Listen("tcp", "localhost:0")
105	if err != nil {
106		t.Fatalf("Error while listening. Err: %v", err)
107	}
108	defer lis2.Close()
109	lis2Done := make(chan struct{})
110	lis2Addr := resolver.Address{Addr: lis2.Addr().String()}
111	// 2nd listener should get a connection attempt since the first one failed.
112	go func() {
113		defer close(lis2Done)
114		_, err := lis2.Accept() // Closing the client will clean up this conn.
115		if err != nil {
116			t.Errorf("Error while accepting. Err: %v", err)
117			return
118		}
119	}()
120
121	r := manual.NewBuilderWithScheme("whatever")
122	r.InitialState(resolver.State{Addresses: []resolver.Address{lis1Addr, lis2Addr}})
123	client, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithResolvers(r))
124	if err != nil {
125		t.Fatalf("Dial failed. Err: %v", err)
126	}
127	defer client.Close()
128	timeout := time.After(5 * time.Second)
129	select {
130	case <-timeout:
131		t.Fatal("timed out waiting for server 1 to finish")
132	case <-lis1Done:
133	}
134	select {
135	case <-timeout:
136		t.Fatal("timed out waiting for server 2 to finish")
137	case <-lis2Done:
138	}
139}
140
141func (s) TestDialWaitsForServerSettings(t *testing.T) {
142	lis, err := net.Listen("tcp", "localhost:0")
143	if err != nil {
144		t.Fatalf("Error while listening. Err: %v", err)
145	}
146	defer lis.Close()
147	done := make(chan struct{})
148	sent := make(chan struct{})
149	dialDone := make(chan struct{})
150	go func() { // Launch the server.
151		defer func() {
152			close(done)
153		}()
154		conn, err := lis.Accept()
155		if err != nil {
156			t.Errorf("Error while accepting. Err: %v", err)
157			return
158		}
159		defer conn.Close()
160		// Sleep for a little bit to make sure that Dial on client
161		// side blocks until settings are received.
162		time.Sleep(100 * time.Millisecond)
163		framer := http2.NewFramer(conn, conn)
164		close(sent)
165		if err := framer.WriteSettings(http2.Setting{}); err != nil {
166			t.Errorf("Error while writing settings. Err: %v", err)
167			return
168		}
169		<-dialDone // Close conn only after dial returns.
170	}()
171	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
172	defer cancel()
173	client, err := DialContext(ctx, lis.Addr().String(), WithInsecure(), WithBlock())
174	close(dialDone)
175	if err != nil {
176		t.Fatalf("Error while dialing. Err: %v", err)
177	}
178	defer client.Close()
179	select {
180	case <-sent:
181	default:
182		t.Fatalf("Dial returned before server settings were sent")
183	}
184	<-done
185}
186
187func (s) TestDialWaitsForServerSettingsAndFails(t *testing.T) {
188	lis, err := net.Listen("tcp", "localhost:0")
189	if err != nil {
190		t.Fatalf("Error while listening. Err: %v", err)
191	}
192	done := make(chan struct{})
193	numConns := 0
194	go func() { // Launch the server.
195		defer func() {
196			close(done)
197		}()
198		for {
199			conn, err := lis.Accept()
200			if err != nil {
201				break
202			}
203			numConns++
204			defer conn.Close()
205		}
206	}()
207	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
208	defer cancel()
209	client, err := DialContext(ctx,
210		lis.Addr().String(),
211		WithInsecure(),
212		WithReturnConnectionError(),
213		withBackoff(noBackoff{}),
214		withMinConnectDeadline(func() time.Duration { return time.Second / 4 }))
215	lis.Close()
216	if err == nil {
217		client.Close()
218		t.Fatalf("Unexpected success (err=nil) while dialing")
219	}
220	expectedMsg := "server handshake"
221	if !strings.Contains(err.Error(), context.DeadlineExceeded.Error()) || !strings.Contains(err.Error(), expectedMsg) {
222		t.Fatalf("DialContext(_) = %v; want a message that includes both %q and %q", err, context.DeadlineExceeded.Error(), expectedMsg)
223	}
224	<-done
225	if numConns < 2 {
226		t.Fatalf("dial attempts: %v; want > 1", numConns)
227	}
228}
229
230// 1. Client connects to a server that doesn't send preface.
231// 2. After minConnectTimeout(500 ms here), client disconnects and retries.
232// 3. The new server sends its preface.
233// 4. Client doesn't kill the connection this time.
234func (s) TestCloseConnectionWhenServerPrefaceNotReceived(t *testing.T) {
235	lis, err := net.Listen("tcp", "localhost:0")
236	if err != nil {
237		t.Fatalf("Error while listening. Err: %v", err)
238	}
239	var (
240		conn2 net.Conn
241		over  uint32
242	)
243	defer func() {
244		lis.Close()
245		// conn2 shouldn't be closed until the client has
246		// observed a successful test.
247		if conn2 != nil {
248			conn2.Close()
249		}
250	}()
251	done := make(chan struct{})
252	accepted := make(chan struct{})
253	go func() { // Launch the server.
254		defer close(done)
255		conn1, err := lis.Accept()
256		if err != nil {
257			t.Errorf("Error while accepting. Err: %v", err)
258			return
259		}
260		defer conn1.Close()
261		// Don't send server settings and the client should close the connection and try again.
262		conn2, err = lis.Accept() // Accept a reconnection request from client.
263		if err != nil {
264			t.Errorf("Error while accepting. Err: %v", err)
265			return
266		}
267		close(accepted)
268		framer := http2.NewFramer(conn2, conn2)
269		if err = framer.WriteSettings(http2.Setting{}); err != nil {
270			t.Errorf("Error while writing settings. Err: %v", err)
271			return
272		}
273		b := make([]byte, 8)
274		for {
275			_, err = conn2.Read(b)
276			if err == nil {
277				continue
278			}
279			if atomic.LoadUint32(&over) == 1 {
280				// The connection stayed alive for the timer.
281				// Success.
282				return
283			}
284			t.Errorf("Unexpected error while reading. Err: %v, want timeout error", err)
285			break
286		}
287	}()
288	client, err := Dial(lis.Addr().String(), WithInsecure(), withMinConnectDeadline(func() time.Duration { return time.Millisecond * 500 }))
289	if err != nil {
290		t.Fatalf("Error while dialing. Err: %v", err)
291	}
292	// wait for connection to be accepted on the server.
293	timer := time.NewTimer(time.Second * 10)
294	select {
295	case <-accepted:
296	case <-timer.C:
297		t.Fatalf("Client didn't make another connection request in time.")
298	}
299	// Make sure the connection stays alive for sometime.
300	time.Sleep(time.Second)
301	atomic.StoreUint32(&over, 1)
302	client.Close()
303	<-done
304}
305
306func (s) TestBackoffWhenNoServerPrefaceReceived(t *testing.T) {
307	lis, err := net.Listen("tcp", "localhost:0")
308	if err != nil {
309		t.Fatalf("Error while listening. Err: %v", err)
310	}
311	defer lis.Close()
312	done := make(chan struct{})
313	go func() { // Launch the server.
314		defer func() {
315			close(done)
316		}()
317		conn, err := lis.Accept() // Accept the connection only to close it immediately.
318		if err != nil {
319			t.Errorf("Error while accepting. Err: %v", err)
320			return
321		}
322		prevAt := time.Now()
323		conn.Close()
324		var prevDuration time.Duration
325		// Make sure the retry attempts are backed off properly.
326		for i := 0; i < 3; i++ {
327			conn, err := lis.Accept()
328			if err != nil {
329				t.Errorf("Error while accepting. Err: %v", err)
330				return
331			}
332			meow := time.Now()
333			conn.Close()
334			dr := meow.Sub(prevAt)
335			if dr <= prevDuration {
336				t.Errorf("Client backoff did not increase with retries. Previous duration: %v, current duration: %v", prevDuration, dr)
337				return
338			}
339			prevDuration = dr
340			prevAt = meow
341		}
342	}()
343	client, err := Dial(lis.Addr().String(), WithInsecure())
344	if err != nil {
345		t.Fatalf("Error while dialing. Err: %v", err)
346	}
347	defer client.Close()
348	<-done
349
350}
351
352func (s) TestWithTimeout(t *testing.T) {
353	conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTimeout(time.Millisecond), WithBlock(), WithInsecure())
354	if err == nil {
355		conn.Close()
356	}
357	if err != context.DeadlineExceeded {
358		t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, context.DeadlineExceeded)
359	}
360}
361
362func (s) TestWithTransportCredentialsTLS(t *testing.T) {
363	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
364	defer cancel()
365	creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
366	if err != nil {
367		t.Fatalf("Failed to create credentials %v", err)
368	}
369	conn, err := DialContext(ctx, "passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithBlock())
370	if err == nil {
371		conn.Close()
372	}
373	if err != context.DeadlineExceeded {
374		t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, context.DeadlineExceeded)
375	}
376}
377
378func (s) TestDefaultAuthority(t *testing.T) {
379	target := "Non-Existent.Server:8080"
380	conn, err := Dial(target, WithInsecure())
381	if err != nil {
382		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err)
383	}
384	defer conn.Close()
385	if conn.authority != target {
386		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, target)
387	}
388}
389
390func (s) TestTLSServerNameOverwrite(t *testing.T) {
391	overwriteServerName := "over.write.server.name"
392	creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), overwriteServerName)
393	if err != nil {
394		t.Fatalf("Failed to create credentials %v", err)
395	}
396	conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds))
397	if err != nil {
398		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err)
399	}
400	defer conn.Close()
401	if conn.authority != overwriteServerName {
402		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName)
403	}
404}
405
406func (s) TestWithAuthority(t *testing.T) {
407	overwriteServerName := "over.write.server.name"
408	conn, err := Dial("passthrough:///Non-Existent.Server:80", WithInsecure(), WithAuthority(overwriteServerName))
409	if err != nil {
410		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err)
411	}
412	defer conn.Close()
413	if conn.authority != overwriteServerName {
414		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName)
415	}
416}
417
418func (s) TestWithAuthorityAndTLS(t *testing.T) {
419	overwriteServerName := "over.write.server.name"
420	creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), overwriteServerName)
421	if err != nil {
422		t.Fatalf("Failed to create credentials %v", err)
423	}
424	conn, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(creds), WithAuthority("no.effect.authority"))
425	if err != nil {
426		t.Fatalf("Dial(_, _) = _, %v, want _, <nil>", err)
427	}
428	defer conn.Close()
429	if conn.authority != overwriteServerName {
430		t.Fatalf("%v.authority = %v, want %v", conn, conn.authority, overwriteServerName)
431	}
432}
433
434// When creating a transport configured with n addresses, only calculate the
435// backoff once per "round" of attempts instead of once per address (n times
436// per "round" of attempts).
437func (s) TestDial_OneBackoffPerRetryGroup(t *testing.T) {
438	var attempts uint32
439	getMinConnectTimeout := func() time.Duration {
440		if atomic.AddUint32(&attempts, 1) == 1 {
441			// Once all addresses are exhausted, hang around and wait for the
442			// client.Close to happen rather than re-starting a new round of
443			// attempts.
444			return time.Hour
445		}
446		t.Error("only one attempt backoff calculation, but got more")
447		return 0
448	}
449
450	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
451	defer cancel()
452
453	lis1, err := net.Listen("tcp", "localhost:0")
454	if err != nil {
455		t.Fatalf("Error while listening. Err: %v", err)
456	}
457	defer lis1.Close()
458
459	lis2, err := net.Listen("tcp", "localhost:0")
460	if err != nil {
461		t.Fatalf("Error while listening. Err: %v", err)
462	}
463	defer lis2.Close()
464
465	server1Done := make(chan struct{})
466	server2Done := make(chan struct{})
467
468	// Launch server 1.
469	go func() {
470		conn, err := lis1.Accept()
471		if err != nil {
472			t.Error(err)
473			return
474		}
475
476		conn.Close()
477		close(server1Done)
478	}()
479	// Launch server 2.
480	go func() {
481		conn, err := lis2.Accept()
482		if err != nil {
483			t.Error(err)
484			return
485		}
486		conn.Close()
487		close(server2Done)
488	}()
489
490	rb := manual.NewBuilderWithScheme("whatever")
491	rb.InitialState(resolver.State{Addresses: []resolver.Address{
492		{Addr: lis1.Addr().String()},
493		{Addr: lis2.Addr().String()},
494	}})
495	client, err := DialContext(ctx, "whatever:///this-gets-overwritten",
496		WithInsecure(),
497		WithBalancerName(stateRecordingBalancerName),
498		WithResolvers(rb),
499		withMinConnectDeadline(getMinConnectTimeout))
500	if err != nil {
501		t.Fatal(err)
502	}
503	defer client.Close()
504
505	timeout := time.After(15 * time.Second)
506
507	select {
508	case <-timeout:
509		t.Fatal("timed out waiting for test to finish")
510	case <-server1Done:
511	}
512
513	select {
514	case <-timeout:
515		t.Fatal("timed out waiting for test to finish")
516	case <-server2Done:
517	}
518}
519
520func (s) TestDialContextCancel(t *testing.T) {
521	ctx, cancel := context.WithCancel(context.Background())
522	cancel()
523	if _, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure()); err != context.Canceled {
524		t.Fatalf("DialContext(%v, _) = _, %v, want _, %v", ctx, err, context.Canceled)
525	}
526}
527
// failFastError is a dial error that reports itself as non-temporary.
// TestDialContextFailFast returns it from a custom dialer together with
// FailOnNonTempDialError(true).
type failFastError struct{}

// Error returns the fixed message "failfast".
func (failFastError) Error() string {
	return "failfast"
}

// Temporary always reports false, which marks the error as permanent.
func (failFastError) Temporary() bool {
	return false
}
532
533func (s) TestDialContextFailFast(t *testing.T) {
534	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
535	defer cancel()
536	failErr := failFastError{}
537	dialer := func(string, time.Duration) (net.Conn, error) {
538		return nil, failErr
539	}
540
541	_, err := DialContext(ctx, "Non-Existent.Server:80", WithBlock(), WithInsecure(), WithDialer(dialer), FailOnNonTempDialError(true))
542	if terr, ok := err.(transport.ConnectionError); !ok || terr.Origin() != failErr {
543		t.Fatalf("DialContext() = _, %v, want _, %v", err, failErr)
544	}
545}
546
// securePerRPCCredentials is a per-RPC credentials value that attaches no
// metadata but always requires transport security. TestCredentialsMisuse
// passes it to WithPerRPCCredentials.
type securePerRPCCredentials struct{}

// GetRequestMetadata returns a nil metadata map and a nil error.
func (securePerRPCCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	var md map[string]string
	return md, nil
}

// RequireTransportSecurity always reports true.
func (securePerRPCCredentials) RequireTransportSecurity() bool {
	const requiresSecurity = true
	return requiresSecurity
}
557
558func (s) TestCredentialsMisuse(t *testing.T) {
559	tlsCreds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
560	if err != nil {
561		t.Fatalf("Failed to create authenticator %v", err)
562	}
563	// Two conflicting credential configurations
564	if _, err := Dial("passthrough:///Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithBlock(), WithInsecure()); err != errCredentialsConflict {
565		t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict)
566	}
567	// security info on insecure connection
568	if _, err := Dial("passthrough:///Non-Existent.Server:80", WithPerRPCCredentials(securePerRPCCredentials{}), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing {
569		t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing)
570	}
571}
572
573func (s) TestWithBackoffConfigDefault(t *testing.T) {
574	testBackoffConfigSet(t, internalbackoff.DefaultExponential)
575}
576
577func (s) TestWithBackoffConfig(t *testing.T) {
578	b := BackoffConfig{MaxDelay: DefaultBackoffConfig.MaxDelay / 2}
579	bc := backoff.DefaultConfig
580	bc.MaxDelay = b.MaxDelay
581	wantBackoff := internalbackoff.Exponential{Config: bc}
582	testBackoffConfigSet(t, wantBackoff, WithBackoffConfig(b))
583}
584
585func (s) TestWithBackoffMaxDelay(t *testing.T) {
586	md := DefaultBackoffConfig.MaxDelay / 2
587	bc := backoff.DefaultConfig
588	bc.MaxDelay = md
589	wantBackoff := internalbackoff.Exponential{Config: bc}
590	testBackoffConfigSet(t, wantBackoff, WithBackoffMaxDelay(md))
591}
592
593func (s) TestWithConnectParams(t *testing.T) {
594	bd := 2 * time.Second
595	mltpr := 2.0
596	jitter := 0.0
597	bc := backoff.Config{BaseDelay: bd, Multiplier: mltpr, Jitter: jitter}
598
599	crt := ConnectParams{Backoff: bc}
600	// MaxDelay is not set in the ConnectParams. So it should not be set on
601	// internalbackoff.Exponential as well.
602	wantBackoff := internalbackoff.Exponential{Config: bc}
603	testBackoffConfigSet(t, wantBackoff, WithConnectParams(crt))
604}
605
606func testBackoffConfigSet(t *testing.T, wantBackoff internalbackoff.Exponential, opts ...DialOption) {
607	opts = append(opts, WithInsecure())
608	conn, err := Dial("passthrough:///foo:80", opts...)
609	if err != nil {
610		t.Fatalf("unexpected error dialing connection: %v", err)
611	}
612	defer conn.Close()
613
614	if conn.dopts.bs == nil {
615		t.Fatalf("backoff config not set")
616	}
617
618	gotBackoff, ok := conn.dopts.bs.(internalbackoff.Exponential)
619	if !ok {
620		t.Fatalf("unexpected type of backoff config: %#v", conn.dopts.bs)
621	}
622
623	if gotBackoff != wantBackoff {
624		t.Fatalf("unexpected backoff config on connection: %v, want %v", gotBackoff, wantBackoff)
625	}
626}
627
628func (s) TestConnectParamsWithMinConnectTimeout(t *testing.T) {
629	// Default value specified for minConnectTimeout in the spec is 20 seconds.
630	mct := 1 * time.Minute
631	conn, err := Dial("passthrough:///foo:80", WithInsecure(), WithConnectParams(ConnectParams{MinConnectTimeout: mct}))
632	if err != nil {
633		t.Fatalf("unexpected error dialing connection: %v", err)
634	}
635	defer conn.Close()
636
637	if got := conn.dopts.minConnectTimeout(); got != mct {
638		t.Errorf("unexpect minConnectTimeout on the connection: %v, want %v", got, mct)
639	}
640}
641
642func (s) TestResolverServiceConfigBeforeAddressNotPanic(t *testing.T) {
643	r := manual.NewBuilderWithScheme("whatever")
644
645	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithResolvers(r))
646	if err != nil {
647		t.Fatalf("failed to dial: %v", err)
648	}
649	defer cc.Close()
650
651	// SwitchBalancer before NewAddress. There was no balancer created, this
652	// makes sure we don't call close on nil balancerWrapper.
653	r.UpdateState(resolver.State{ServiceConfig: parseCfg(r, `{"loadBalancingPolicy": "round_robin"}`)}) // This should not panic.
654
655	time.Sleep(time.Second) // Sleep to make sure the service config is handled by ClientConn.
656}
657
658func (s) TestResolverServiceConfigWhileClosingNotPanic(t *testing.T) {
659	for i := 0; i < 10; i++ { // Run this multiple times to make sure it doesn't panic.
660		r := manual.NewBuilderWithScheme(fmt.Sprintf("whatever-%d", i))
661
662		cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithResolvers(r))
663		if err != nil {
664			t.Fatalf("failed to dial: %v", err)
665		}
666		// Send a new service config while closing the ClientConn.
667		go cc.Close()
668		go r.UpdateState(resolver.State{ServiceConfig: parseCfg(r, `{"loadBalancingPolicy": "round_robin"}`)}) // This should not panic.
669	}
670}
671
672func (s) TestResolverEmptyUpdateNotPanic(t *testing.T) {
673	r := manual.NewBuilderWithScheme("whatever")
674
675	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithResolvers(r))
676	if err != nil {
677		t.Fatalf("failed to dial: %v", err)
678	}
679	defer cc.Close()
680
681	// This make sure we don't create addrConn with empty address list.
682	r.UpdateState(resolver.State{}) // This should not panic.
683
684	time.Sleep(time.Second) // Sleep to make sure the service config is handled by ClientConn.
685}
686
687func (s) TestClientUpdatesParamsAfterGoAway(t *testing.T) {
688	lis, err := net.Listen("tcp", "localhost:0")
689	if err != nil {
690		t.Fatalf("Failed to listen. Err: %v", err)
691	}
692	defer lis.Close()
693	connected := make(chan struct{})
694	go func() {
695		conn, err := lis.Accept()
696		if err != nil {
697			t.Errorf("error accepting connection: %v", err)
698			return
699		}
700		defer conn.Close()
701		f := http2.NewFramer(conn, conn)
702		// Start a goroutine to read from the conn to prevent the client from
703		// blocking after it writes its preface.
704		go func() {
705			for {
706				if _, err := f.ReadFrame(); err != nil {
707					return
708				}
709			}
710		}()
711		if err := f.WriteSettings(http2.Setting{}); err != nil {
712			t.Errorf("error writing settings: %v", err)
713			return
714		}
715		<-connected
716		if err := f.WriteGoAway(0, http2.ErrCodeEnhanceYourCalm, []byte("too_many_pings")); err != nil {
717			t.Errorf("error writing GOAWAY: %v", err)
718			return
719		}
720	}()
721	addr := lis.Addr().String()
722	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
723	defer cancel()
724	cc, err := DialContext(ctx, addr, WithBlock(), WithInsecure(), WithKeepaliveParams(keepalive.ClientParameters{
725		Time:                10 * time.Second,
726		Timeout:             100 * time.Millisecond,
727		PermitWithoutStream: true,
728	}))
729	if err != nil {
730		t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
731	}
732	defer cc.Close()
733	close(connected)
734	for {
735		time.Sleep(10 * time.Millisecond)
736		cc.mu.RLock()
737		v := cc.mkp.Time
738		if v == 20*time.Second {
739			// Success
740			cc.mu.RUnlock()
741			return
742		}
743		if ctx.Err() != nil {
744			// Timeout
745			t.Fatalf("cc.dopts.copts.Keepalive.Time = %v , want 20s", v)
746		}
747		cc.mu.RUnlock()
748	}
749}
750
751func (s) TestDisableServiceConfigOption(t *testing.T) {
752	r := manual.NewBuilderWithScheme("whatever")
753	addr := r.Scheme() + ":///non.existent"
754	cc, err := Dial(addr, WithInsecure(), WithResolvers(r), WithDisableServiceConfig())
755	if err != nil {
756		t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
757	}
758	defer cc.Close()
759	r.UpdateState(resolver.State{ServiceConfig: parseCfg(r, `{
760    "methodConfig": [
761        {
762            "name": [
763                {
764                    "service": "foo",
765                    "method": "Bar"
766                }
767            ],
768            "waitForReady": true
769        }
770    ]
771}`)})
772	time.Sleep(1 * time.Second)
773	m := cc.GetMethodConfig("/foo/Bar")
774	if m.WaitForReady != nil {
775		t.Fatalf("want: method (\"/foo/bar/\") config to be empty, got: %+v", m)
776	}
777}
778
779func (s) TestMethodConfigDefaultService(t *testing.T) {
780	addr := "nonexist:///non.existent"
781	cc, err := Dial(addr, WithInsecure(), WithDefaultServiceConfig(`{
782  "methodConfig": [{
783    "name": [
784      {
785        "service": ""
786      }
787    ],
788    "waitForReady": true
789  }]
790}`))
791	if err != nil {
792		t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
793	}
794	defer cc.Close()
795
796	m := cc.GetMethodConfig("/foo/Bar")
797	if m.WaitForReady == nil {
798		t.Fatalf("want: method (%q) config to fallback to the default service", "/foo/Bar")
799	}
800}
801
802func (s) TestGetClientConnTarget(t *testing.T) {
803	addr := "nonexist:///non.existent"
804	cc, err := Dial(addr, WithInsecure())
805	if err != nil {
806		t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
807	}
808	defer cc.Close()
809	if cc.Target() != addr {
810		t.Fatalf("Target() = %s, want %s", cc.Target(), addr)
811	}
812}
813
// backoffForever is a backoff strategy whose delay is the largest
// representable time.Duration. TestResetConnectBackoff relies on it to keep
// the client from redialing on its own.
type backoffForever struct{}

// Backoff ignores the retry count and always returns math.MaxInt64
// nanoseconds.
func (backoffForever) Backoff(int) time.Duration {
	const forever = time.Duration(math.MaxInt64)
	return forever
}
817
818func (s) TestResetConnectBackoff(t *testing.T) {
819	dials := make(chan struct{})
820	defer func() { // If we fail, let the http2client break out of dialing.
821		select {
822		case <-dials:
823		default:
824		}
825	}()
826	dialer := func(string, time.Duration) (net.Conn, error) {
827		dials <- struct{}{}
828		return nil, errors.New("failed to fake dial")
829	}
830	cc, err := Dial("any", WithInsecure(), WithDialer(dialer), withBackoff(backoffForever{}))
831	if err != nil {
832		t.Fatalf("Dial() = _, %v; want _, nil", err)
833	}
834	defer cc.Close()
835	select {
836	case <-dials:
837	case <-time.NewTimer(10 * time.Second).C:
838		t.Fatal("Failed to call dial within 10s")
839	}
840
841	select {
842	case <-dials:
843		t.Fatal("Dial called unexpectedly before resetting backoff")
844	case <-time.NewTimer(100 * time.Millisecond).C:
845	}
846
847	cc.ResetConnectBackoff()
848
849	select {
850	case <-dials:
851	case <-time.NewTimer(10 * time.Second).C:
852		t.Fatal("Failed to call dial within 10s after resetting backoff")
853	}
854}
855
856func (s) TestBackoffCancel(t *testing.T) {
857	dialStrCh := make(chan string)
858	cc, err := Dial("any", WithInsecure(), WithDialer(func(t string, _ time.Duration) (net.Conn, error) {
859		dialStrCh <- t
860		return nil, fmt.Errorf("test dialer, always error")
861	}))
862	if err != nil {
863		t.Fatalf("Failed to create ClientConn: %v", err)
864	}
865	<-dialStrCh
866	cc.Close()
867	// Should not leak. May need -count 5000 to exercise.
868}
869
// TestUpdateAddresses_RetryFromFirstAddr verifies that UpdateAddresses causes
// the next reconnect to begin from the top of the address list if the
// connection is not READY.
//
// The sequence is:
//  1. The client goes READY on server1.
//  2. server1 drops that connection and immediately closes the next one, so
//     the client moves on to server2, which accepts and hangs.
//  3. While server2 is pending, the test calls ac.acbw.UpdateAddresses with the
//     same list and then lets server2 close.
//  4. The next attempt must go to server1, not to server2 again or on to
//     server3.
//
// The comments below call the addrConn logic under test "tryUpdateAddrs".
// NOTE(review): presumably this is what UpdateAddresses delegates to; confirm.
func (s) TestUpdateAddresses_RetryFromFirstAddr(t *testing.T) {
	lis1, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Error while listening. Err: %v", err)
	}
	defer lis1.Close()

	lis2, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Error while listening. Err: %v", err)
	}
	defer lis2.Close()

	lis3, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Error while listening. Err: %v", err)
	}
	defer lis3.Close()

	// closeServer2 is signalled by the test. The other channels are closed by
	// the servers to report which connection attempts they have seen.
	closeServer2 := make(chan struct{})
	server1ContactedFirstTime := make(chan struct{})
	server1ContactedSecondTime := make(chan struct{})
	server2ContactedFirstTime := make(chan struct{})
	server2ContactedSecondTime := make(chan struct{})
	server3Contacted := make(chan struct{})

	// Launch server 1.
	go func() {
		// First, let's allow the initial connection to go READY. We need to do
		// this because tryUpdateAddrs only works after there's some non-nil
		// address on the ac, and curAddress is only set after READY.
		conn1, err := lis1.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		// keepReading is defined elsewhere. Presumably it drains the client's
		// writes so the client never blocks on this connection.
		go keepReading(conn1)

		framer := http2.NewFramer(conn1, conn1)
		if err := framer.WriteSettings(http2.Setting{}); err != nil {
			t.Errorf("Error while writing settings frame. %v", err)
			return
		}

		// nextStateNotifier() is updated after balancerBuilder.Build(), which is
		// called by grpc.Dial. It's safe to do it here because lis1.Accept blocks
		// until balancer is built to process the addresses.
		stateNotifications := testBalancerBuilder.nextStateNotifier()
		// Wait for the transport to become ready.
		for s := range stateNotifications {
			if s == connectivity.Ready {
				break
			}
		}

		// Once it's ready, curAddress has been set. So let's close this
		// connection prompting the first reconnect cycle.
		conn1.Close()

		// Accept and immediately close, causing it to go to server2.
		conn2, err := lis1.Accept()
		if err != nil {
			t.Error(err)
			return
		}
		close(server1ContactedFirstTime)
		conn2.Close()

		// This is the attempt expected after tryUpdateAddrs. The result is
		// deliberately ignored: Accept also returns (with an error) when the
		// deferred lis1.Close runs at the end of the test.
		lis1.Accept()
		close(server1ContactedSecondTime)
	}()
	// Launch server 2.
	go func() {
		// Accept and then hang, waiting for the test to call tryUpdateAddrs
		// and then signal this server to close. After this server closes, the
		// client should start from the top instead of trying server2 again or
		// continuing to server3.
		conn, err := lis2.Accept()
		if err != nil {
			t.Error(err)
			return
		}

		close(server2ContactedFirstTime)
		<-closeServer2
		conn.Close()

		// After tryUpdateAddrs, it should NOT try server2. As with server1,
		// Accept's result is ignored; it also returns when the deferred
		// lis2.Close runs.
		lis2.Accept()
		close(server2ContactedSecondTime)
	}()
	// Launch server 3.
	go func() {
		// After tryUpdateAddrs, it should NOT try server3 (or at any other
		// time). Accept's result is ignored for the same reason as above.
		lis3.Accept()
		close(server3Contacted)
	}()

	addrsList := []resolver.Address{
		{Addr: lis1.Addr().String()},
		{Addr: lis2.Addr().String()},
		{Addr: lis3.Addr().String()},
	}
	rb := manual.NewBuilderWithScheme("whatever")
	rb.InitialState(resolver.State{Addresses: addrsList})

	// noBackoff presumably removes the delay between attempts. The one-hour
	// min connect deadline keeps the hanging attempt on server2 from timing
	// out during the test.
	client, err := Dial("whatever:///this-gets-overwritten",
		WithInsecure(),
		WithResolvers(rb),
		withBackoff(noBackoff{}),
		WithBalancerName(stateRecordingBalancerName),
		withMinConnectDeadline(func() time.Duration { return time.Hour }))
	if err != nil {
		t.Fatal(err)
	}
	defer client.Close()

	// A single deadline shared by every wait below.
	timeout := time.After(5 * time.Second)

	// Wait for server1 to be contacted (which will immediately fail), then
	// server2 (which will hang waiting for our signal).
	select {
	case <-server1ContactedFirstTime:
	case <-timeout:
		t.Fatal("timed out waiting for server1 to be contacted")
	}
	select {
	case <-server2ContactedFirstTime:
	case <-timeout:
		t.Fatal("timed out waiting for server2 to be contacted")
	}

	// Grab an addrConn from client.conns and update its addresses. The balancer
	// is expected to have created a single addrConn for the whole list;
	// confirm against stateRecordingBalancer.
	var ac *addrConn
	client.mu.Lock()
	for clientAC := range client.conns {
		ac = clientAC
		break
	}
	client.mu.Unlock()

	ac.acbw.UpdateAddresses(addrsList)

	// We've called tryUpdateAddrs - now let's make server2 close the
	// connection and check that it goes back to server1 instead of continuing
	// to server3 or trying server2 again.
	close(closeServer2)

	select {
	case <-server1ContactedSecondTime:
	case <-server2ContactedSecondTime:
		t.Fatal("server2 was contacted a second time, but it after tryUpdateAddrs it should have re-started the list and tried server1")
	case <-server3Contacted:
		t.Fatal("server3 was contacted, but after tryUpdateAddrs it should have re-started the list and tried server1")
	case <-timeout:
		t.Fatal("timed out waiting for any server to be contacted after tryUpdateAddrs")
	}
}
1031
1032func (s) TestDefaultServiceConfig(t *testing.T) {
1033	r := manual.NewBuilderWithScheme("whatever")
1034	addr := r.Scheme() + ":///non.existent"
1035	js := `{
1036    "methodConfig": [
1037        {
1038            "name": [
1039                {
1040                    "service": "foo",
1041                    "method": "bar"
1042                }
1043            ],
1044            "waitForReady": true
1045        }
1046    ]
1047}`
1048	testInvalidDefaultServiceConfig(t)
1049	testDefaultServiceConfigWhenResolverServiceConfigDisabled(t, r, addr, js)
1050	testDefaultServiceConfigWhenResolverDoesNotReturnServiceConfig(t, r, addr, js)
1051	testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t, r, addr, js)
1052}
1053
// verifyWaitForReadyEqualsTrue polls cc's method config for "/foo/bar" and
// reports whether its WaitForReady field is observed as set to true. It makes
// up to 10 checks and sleeps 100ms after each unsuccessful one, roughly 1s in
// total. It returns false if WaitForReady never became true in that window.
func verifyWaitForReadyEqualsTrue(cc *ClientConn) bool {
	const (
		attempts = 10
		interval = 100 * time.Millisecond
	)
	for i := 0; i < attempts; i++ {
		if mc := cc.GetMethodConfig("/foo/bar"); mc.WaitForReady != nil && *mc.WaitForReady {
			return true
		}
		time.Sleep(interval)
	}
	return false
}
1065
// testInvalidDefaultServiceConfig verifies that Dial rejects an unparsable
// default service config (the empty string). The returned error must contain
// invalidDefaultServiceConfigErrPrefix.
func testInvalidDefaultServiceConfig(t *testing.T) {
	cc, err := Dial("fake.com", WithInsecure(), WithDefaultServiceConfig(""))
	if err == nil {
		// Guard against a nil-pointer panic on err.Error() below, and don't
		// leak the unexpectedly created ClientConn.
		cc.Close()
		t.Fatalf("Dial succeeded with an empty default service config, want err contains: %v", invalidDefaultServiceConfigErrPrefix)
	}
	if !strings.Contains(err.Error(), invalidDefaultServiceConfigErrPrefix) {
		t.Fatalf("Dial got err: %v, want err contains: %v", err, invalidDefaultServiceConfigErrPrefix)
	}
}
1072
// testDefaultServiceConfigWhenResolverServiceConfigDisabled dials addr through
// resolver r with resolver-provided service configs disabled and js as the
// default service config. It then pushes a resolver update carrying an empty
// ("{}") service config. The test fails unless WaitForReady for /foo/bar is
// observed as true, i.e. the default config from js is in effect.
func testDefaultServiceConfigWhenResolverServiceConfigDisabled(t *testing.T, r *manual.Resolver, addr string, js string) {
	client, err := Dial(addr,
		WithInsecure(),
		WithDisableServiceConfig(),
		WithResolvers(r),
		WithDefaultServiceConfig(js),
	)
	if err != nil {
		t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
	}
	defer client.Close()

	// The resolver hands back its own (empty) config. Because resolver
	// service configs are disabled, it should be ignored in favor of js.
	update := resolver.State{
		Addresses:     []resolver.Address{{Addr: addr}},
		ServiceConfig: parseCfg(r, "{}"),
	}
	r.UpdateState(update)

	if applied := verifyWaitForReadyEqualsTrue(client); !applied {
		t.Fatal("default service config failed to be applied after 1s")
	}
}
1088
// testDefaultServiceConfigWhenResolverDoesNotReturnServiceConfig dials addr
// through resolver r with js as the default service config. It then pushes a
// resolver update that lists only the address and carries no service config.
// The test fails unless WaitForReady for /foo/bar is observed as true.
func testDefaultServiceConfigWhenResolverDoesNotReturnServiceConfig(t *testing.T, r *manual.Resolver, addr string, js string) {
	client, err := Dial(addr,
		WithInsecure(),
		WithResolvers(r),
		WithDefaultServiceConfig(js),
	)
	if err != nil {
		t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
	}
	defer client.Close()

	// No ServiceConfig in the update, so the default from js should apply.
	update := resolver.State{Addresses: []resolver.Address{{Addr: addr}}}
	r.UpdateState(update)

	if applied := verifyWaitForReadyEqualsTrue(client); !applied {
		t.Fatal("default service config failed to be applied after 1s")
	}
}
1102
// testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig dials addr
// through resolver r with js as the default service config. It then pushes an
// address-only resolver update. The test fails unless WaitForReady for
// /foo/bar is observed as true.
//
// NOTE(review): despite the name, the update sent below carries no service
// config at all (ServiceConfig is nil). That makes this helper a duplicate of
// testDefaultServiceConfigWhenResolverDoesNotReturnServiceConfig. Exercising
// a genuinely invalid resolver config would need a ParseResult with a non-nil
// Err. Confirm the expected client behavior in that case before changing it.
func testDefaultServiceConfigWhenResolverReturnInvalidServiceConfig(t *testing.T, r *manual.Resolver, addr string, js string) {
	client, err := Dial(addr,
		WithInsecure(),
		WithResolvers(r),
		WithDefaultServiceConfig(js),
	)
	if err != nil {
		t.Fatalf("Dial(%s, _) = _, %v, want _, <nil>", addr, err)
	}
	defer client.Close()

	update := resolver.State{Addresses: []resolver.Address{{Addr: addr}}}
	r.UpdateState(update)

	if applied := verifyWaitForReadyEqualsTrue(client); !applied {
		t.Fatal("default service config failed to be applied after 1s")
	}
}
1116