1/*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package test
20
21import (
22	"context"
23	"errors"
24	"fmt"
25	"net"
26	"reflect"
27	"testing"
28	"time"
29
30	"github.com/google/go-cmp/cmp"
31	"google.golang.org/grpc"
32	"google.golang.org/grpc/attributes"
33	"google.golang.org/grpc/balancer"
34	"google.golang.org/grpc/balancer/roundrobin"
35	"google.golang.org/grpc/codes"
36	"google.golang.org/grpc/connectivity"
37	"google.golang.org/grpc/credentials"
38	"google.golang.org/grpc/internal/balancer/stub"
39	"google.golang.org/grpc/internal/balancerload"
40	"google.golang.org/grpc/internal/grpcsync"
41	"google.golang.org/grpc/internal/grpcutil"
42	imetadata "google.golang.org/grpc/internal/metadata"
43	"google.golang.org/grpc/internal/stubserver"
44	"google.golang.org/grpc/internal/testutils"
45	"google.golang.org/grpc/metadata"
46	"google.golang.org/grpc/resolver"
47	"google.golang.org/grpc/resolver/manual"
48	"google.golang.org/grpc/status"
49	testpb "google.golang.org/grpc/test/grpc_testing"
50	"google.golang.org/grpc/testdata"
51)
52
53const testBalancerName = "testbalancer"
54
55// testBalancer creates one subconn with the first address from resolved
56// addresses.
57//
58// It's used to test whether options for NewSubConn are applied correctly.
59type testBalancer struct {
60	cc balancer.ClientConn
61	sc balancer.SubConn
62
63	newSubConnOptions balancer.NewSubConnOptions
64	pickInfos         []balancer.PickInfo
65	pickExtraMDs      []metadata.MD
66	doneInfo          []balancer.DoneInfo
67}
68
69func (b *testBalancer) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
70	b.cc = cc
71	return b
72}
73
74func (*testBalancer) Name() string {
75	return testBalancerName
76}
77
78func (*testBalancer) ResolverError(err error) {
79	panic("not implemented")
80}
81
82func (b *testBalancer) UpdateClientConnState(state balancer.ClientConnState) error {
83	// Only create a subconn at the first time.
84	if b.sc == nil {
85		var err error
86		b.sc, err = b.cc.NewSubConn(state.ResolverState.Addresses, b.newSubConnOptions)
87		if err != nil {
88			logger.Errorf("testBalancer: failed to NewSubConn: %v", err)
89			return nil
90		}
91		b.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: &picker{sc: b.sc, bal: b}})
92		b.sc.Connect()
93	}
94	return nil
95}
96
97func (b *testBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) {
98	logger.Infof("testBalancer: UpdateSubConnState: %p, %v", sc, s)
99	if b.sc != sc {
100		logger.Infof("testBalancer: ignored state change because sc is not recognized")
101		return
102	}
103	if s.ConnectivityState == connectivity.Shutdown {
104		b.sc = nil
105		return
106	}
107
108	switch s.ConnectivityState {
109	case connectivity.Ready, connectivity.Idle:
110		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{sc: sc, bal: b}})
111	case connectivity.Connecting:
112		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrNoSubConnAvailable, bal: b}})
113	case connectivity.TransientFailure:
114		b.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: &picker{err: balancer.ErrTransientFailure, bal: b}})
115	}
116}
117
118func (b *testBalancer) Close() {}
119
120type picker struct {
121	err error
122	sc  balancer.SubConn
123	bal *testBalancer
124}
125
126func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
127	if p.err != nil {
128		return balancer.PickResult{}, p.err
129	}
130	extraMD, _ := grpcutil.ExtraMetadata(info.Ctx)
131	info.Ctx = nil // Do not validate context.
132	p.bal.pickInfos = append(p.bal.pickInfos, info)
133	p.bal.pickExtraMDs = append(p.bal.pickExtraMDs, extraMD)
134	return balancer.PickResult{SubConn: p.sc, Done: func(d balancer.DoneInfo) { p.bal.doneInfo = append(p.bal.doneInfo, d) }}, nil
135}
136
137func (s) TestCredsBundleFromBalancer(t *testing.T) {
138	balancer.Register(&testBalancer{
139		newSubConnOptions: balancer.NewSubConnOptions{
140			CredsBundle: &testCredsBundle{},
141		},
142	})
143	te := newTest(t, env{name: "creds-bundle", network: "tcp", balancer: ""})
144	te.tapHandle = authHandle
145	te.customDialOptions = []grpc.DialOption{
146		grpc.WithBalancerName(testBalancerName),
147	}
148	creds, err := credentials.NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
149	if err != nil {
150		t.Fatalf("Failed to generate credentials %v", err)
151	}
152	te.customServerOptions = []grpc.ServerOption{
153		grpc.Creds(creds),
154	}
155	te.startServer(&testServer{})
156	defer te.tearDown()
157
158	cc := te.clientConn()
159	tc := testpb.NewTestServiceClient(cc)
160	if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
161		t.Fatalf("Test failed. Reason: %v", err)
162	}
163}
164
165func (s) TestPickExtraMetadata(t *testing.T) {
166	for _, e := range listTestEnv() {
167		testPickExtraMetadata(t, e)
168	}
169}
170
171func testPickExtraMetadata(t *testing.T, e env) {
172	te := newTest(t, e)
173	b := &testBalancer{}
174	balancer.Register(b)
175	const (
176		testUserAgent      = "test-user-agent"
177		testSubContentType = "proto"
178	)
179
180	te.customDialOptions = []grpc.DialOption{
181		grpc.WithBalancerName(testBalancerName),
182		grpc.WithUserAgent(testUserAgent),
183	}
184	te.startServer(&testServer{security: e.security})
185	defer te.tearDown()
186
187	// Set resolver to xds to trigger the extra metadata code path.
188	r := manual.NewBuilderWithScheme("xds")
189	resolver.Register(r)
190	defer func() {
191		resolver.UnregisterForTesting("xds")
192	}()
193	r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: te.srvAddr}}})
194	te.resolverScheme = "xds"
195	cc := te.clientConn()
196	tc := testpb.NewTestServiceClient(cc)
197
198	// The RPCs will fail, but we don't care. We just need the pick to happen.
199	ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second)
200	defer cancel1()
201	tc.EmptyCall(ctx1, &testpb.Empty{})
202
203	ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
204	defer cancel2()
205	tc.EmptyCall(ctx2, &testpb.Empty{}, grpc.CallContentSubtype(testSubContentType))
206
207	want := []metadata.MD{
208		// First RPC doesn't have sub-content-type.
209		{"content-type": []string{"application/grpc"}},
210		// Second RPC has sub-content-type "proto".
211		{"content-type": []string{"application/grpc+proto"}},
212	}
213
214	if !cmp.Equal(b.pickExtraMDs, want) {
215		t.Fatalf("%s", cmp.Diff(b.pickExtraMDs, want))
216	}
217}
218
// TestDoneInfo runs testDoneInfo in every test environment.
func (s) TestDoneInfo(t *testing.T) {
	for _, e := range listTestEnv() {
		testDoneInfo(t, e)
	}
}

// testDoneInfo checks what the balancer's Done callbacks receive:
//   - a failed RPC reports its error in DoneInfo.Err;
//   - a successful RPC reports the server's trailer in DoneInfo.Trailer;
//   - the number of recorded picks equals the number of Done calls, including
//     while the server is being stopped mid-traffic.
func testDoneInfo(t *testing.T, e env) {
	te := newTest(t, e)
	b := &testBalancer{}
	balancer.Register(b)
	te.customDialOptions = []grpc.DialOption{
		grpc.WithBalancerName(testBalancerName),
	}
	// NOTE(review): failAppUA presumably makes the test server fail EmptyCall
	// with detailedError, which the first check below expects; confirm against
	// the server implementation.
	te.userAgent = failAppUA
	te.startServer(&testServer{security: e.security})
	defer te.tearDown()

	cc := te.clientConn()
	tc := testpb.NewTestServiceClient(cc)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	wantErr := detailedError
	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !testutils.StatusErrEqual(err, wantErr) {
		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr)
	}
	if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
		t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
	}

	// doneInfo[0] comes from the failed EmptyCall and must carry its error;
	// doneInfo[1] comes from the successful UnaryCall and must carry the
	// server's trailer metadata.
	if len(b.doneInfo) < 1 || !testutils.StatusErrEqual(b.doneInfo[0].Err, wantErr) {
		t.Fatalf("b.doneInfo = %v; want b.doneInfo[0].Err = %v", b.doneInfo, wantErr)
	}
	if len(b.doneInfo) < 2 || !reflect.DeepEqual(b.doneInfo[1].Trailer, testTrailerMetadata) {
		t.Fatalf("b.doneInfo = %v; want b.doneInfo[1].Trailer = %v", b.doneInfo, testTrailerMetadata)
	}
	if len(b.pickInfos) != len(b.doneInfo) {
		t.Fatalf("Got %d picks, but %d doneInfo, want equal amount", len(b.pickInfos), len(b.doneInfo))
	}
	// To test done() is always called, even if it's returned with a non-Ready
	// SubConn.
	//
	// Stop server and at the same time send RPCs. There are chances that picker
	// is not updated in time, causing a non-Ready SubConn to be returned.
	finished := make(chan struct{})
	go func() {
		// Results are ignored: RPCs racing with the server shutdown may fail.
		for i := 0; i < 20; i++ {
			tc.UnaryCall(ctx, &testpb.SimpleRequest{})
		}
		close(finished)
	}()
	te.srv.Stop()
	// Waiting on finished orders the goroutine's appends to b's slices before
	// the reads below.
	<-finished
	if len(b.pickInfos) != len(b.doneInfo) {
		t.Fatalf("Got %d picks, %d doneInfo, want equal amount", len(b.pickInfos), len(b.doneInfo))
	}
}
276
277const loadMDKey = "X-Endpoint-Load-Metrics-Bin"
278
279type testLoadParser struct{}
280
281func (*testLoadParser) Parse(md metadata.MD) interface{} {
282	vs := md.Get(loadMDKey)
283	if len(vs) == 0 {
284		return nil
285	}
286	return vs[0]
287}
288
289func init() {
290	balancerload.SetParser(&testLoadParser{})
291}
292
293func (s) TestDoneLoads(t *testing.T) {
294	for _, e := range listTestEnv() {
295		testDoneLoads(t, e)
296	}
297}
298
299func testDoneLoads(t *testing.T, e env) {
300	b := &testBalancer{}
301	balancer.Register(b)
302
303	const testLoad = "test-load-,-should-be-orca"
304
305	ss := &stubserver.StubServer{
306		EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
307			grpc.SetTrailer(ctx, metadata.Pairs(loadMDKey, testLoad))
308			return &testpb.Empty{}, nil
309		},
310	}
311	if err := ss.Start(nil, grpc.WithBalancerName(testBalancerName)); err != nil {
312		t.Fatalf("error starting testing server: %v", err)
313	}
314	defer ss.Stop()
315
316	tc := testpb.NewTestServiceClient(ss.CC)
317
318	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
319	defer cancel()
320	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
321		t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, nil)
322	}
323
324	piWant := []balancer.PickInfo{
325		{FullMethodName: "/grpc.testing.TestService/EmptyCall"},
326	}
327	if !reflect.DeepEqual(b.pickInfos, piWant) {
328		t.Fatalf("b.pickInfos = %v; want %v", b.pickInfos, piWant)
329	}
330
331	if len(b.doneInfo) < 1 {
332		t.Fatalf("b.doneInfo = %v, want length 1", b.doneInfo)
333	}
334	gotLoad, _ := b.doneInfo[0].ServerLoad.(string)
335	if gotLoad != testLoad {
336		t.Fatalf("b.doneInfo[0].ServerLoad = %v; want = %v", b.doneInfo[0].ServerLoad, testLoad)
337	}
338}
339
340const testBalancerKeepAddressesName = "testbalancer-keepingaddresses"
341
342// testBalancerKeepAddresses keeps the addresses in the builder instead of
343// creating SubConns.
344//
345// It's used to test the addresses balancer gets are correct.
346type testBalancerKeepAddresses struct {
347	addrsChan chan []resolver.Address
348}
349
350func newTestBalancerKeepAddresses() *testBalancerKeepAddresses {
351	return &testBalancerKeepAddresses{
352		addrsChan: make(chan []resolver.Address, 10),
353	}
354}
355
356func (testBalancerKeepAddresses) ResolverError(err error) {
357	panic("not implemented")
358}
359
360func (b *testBalancerKeepAddresses) Build(cc balancer.ClientConn, opt balancer.BuildOptions) balancer.Balancer {
361	return b
362}
363
364func (*testBalancerKeepAddresses) Name() string {
365	return testBalancerKeepAddressesName
366}
367
368func (b *testBalancerKeepAddresses) UpdateClientConnState(state balancer.ClientConnState) error {
369	b.addrsChan <- state.ResolverState.Addresses
370	return nil
371}
372
373func (testBalancerKeepAddresses) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) {
374	panic("not used")
375}
376
377func (testBalancerKeepAddresses) Close() {
378}
379
380// Make sure that non-grpclb balancers don't get grpclb addresses even if name
381// resolver sends them
382func (s) TestNonGRPCLBBalancerGetsNoGRPCLBAddress(t *testing.T) {
383	r := manual.NewBuilderWithScheme("whatever")
384
385	b := newTestBalancerKeepAddresses()
386	balancer.Register(b)
387
388	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r),
389		grpc.WithBalancerName(b.Name()))
390	if err != nil {
391		t.Fatalf("failed to dial: %v", err)
392	}
393	defer cc.Close()
394
395	grpclbAddresses := []resolver.Address{{
396		Addr:       "grpc.lb.com",
397		Type:       resolver.GRPCLB,
398		ServerName: "grpc.lb.com",
399	}}
400
401	nonGRPCLBAddresses := []resolver.Address{{
402		Addr: "localhost",
403		Type: resolver.Backend,
404	}}
405
406	r.UpdateState(resolver.State{
407		Addresses: nonGRPCLBAddresses,
408	})
409	if got := <-b.addrsChan; !reflect.DeepEqual(got, nonGRPCLBAddresses) {
410		t.Fatalf("With only backend addresses, balancer got addresses %v, want %v", got, nonGRPCLBAddresses)
411	}
412
413	r.UpdateState(resolver.State{
414		Addresses: grpclbAddresses,
415	})
416	if got := <-b.addrsChan; len(got) != 0 {
417		t.Fatalf("With only grpclb addresses, balancer got addresses %v, want empty", got)
418	}
419
420	r.UpdateState(resolver.State{
421		Addresses: append(grpclbAddresses, nonGRPCLBAddresses...),
422	})
423	if got := <-b.addrsChan; !reflect.DeepEqual(got, nonGRPCLBAddresses) {
424		t.Fatalf("With both backend and grpclb addresses, balancer got addresses %v, want %v", got, nonGRPCLBAddresses)
425	}
426}
427
428type aiPicker struct {
429	result balancer.PickResult
430	err    error
431}
432
433func (aip *aiPicker) Pick(_ balancer.PickInfo) (balancer.PickResult, error) {
434	return aip.result, aip.err
435}
436
437// attrTransportCreds is a transport credential implementation which stores
438// Attributes from the ClientHandshakeInfo struct passed in the context locally
439// for the test to inspect.
440type attrTransportCreds struct {
441	credentials.TransportCredentials
442	attr *attributes.Attributes
443}
444
445func (ac *attrTransportCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
446	ai := credentials.ClientHandshakeInfoFromContext(ctx)
447	ac.attr = ai.Attributes
448	return rawConn, nil, nil
449}
450func (ac *attrTransportCreds) Info() credentials.ProtocolInfo {
451	return credentials.ProtocolInfo{}
452}
453func (ac *attrTransportCreds) Clone() credentials.TransportCredentials {
454	return nil
455}
456
457// TestAddressAttributesInNewSubConn verifies that the Attributes passed from a
458// balancer in the resolver.Address that is passes to NewSubConn reaches all the
459// way to the ClientHandshake method of the credentials configured on the parent
460// channel.
461func (s) TestAddressAttributesInNewSubConn(t *testing.T) {
462	const (
463		testAttrKey      = "foo"
464		testAttrVal      = "bar"
465		attrBalancerName = "attribute-balancer"
466	)
467
468	// Register a stub balancer which adds attributes to the first address that
469	// it receives and then calls NewSubConn on it.
470	bf := stub.BalancerFuncs{
471		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
472			addrs := ccs.ResolverState.Addresses
473			if len(addrs) == 0 {
474				return nil
475			}
476
477			// Only use the first address.
478			attr := attributes.New(testAttrKey, testAttrVal)
479			addrs[0].Attributes = attr
480			sc, err := bd.ClientConn.NewSubConn([]resolver.Address{addrs[0]}, balancer.NewSubConnOptions{})
481			if err != nil {
482				return err
483			}
484			sc.Connect()
485			return nil
486		},
487		UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) {
488			bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
489		},
490	}
491	stub.Register(attrBalancerName, bf)
492	t.Logf("Registered balancer %s...", attrBalancerName)
493
494	r := manual.NewBuilderWithScheme("whatever")
495	t.Logf("Registered manual resolver with scheme %s...", r.Scheme())
496
497	lis, err := net.Listen("tcp", "localhost:0")
498	if err != nil {
499		t.Fatal(err)
500	}
501
502	s := grpc.NewServer()
503	testpb.RegisterTestServiceServer(s, &testServer{})
504	go s.Serve(lis)
505	defer s.Stop()
506	t.Logf("Started gRPC server at %s...", lis.Addr().String())
507
508	creds := &attrTransportCreds{}
509	dopts := []grpc.DialOption{
510		grpc.WithTransportCredentials(creds),
511		grpc.WithResolvers(r),
512		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, attrBalancerName)),
513	}
514	cc, err := grpc.Dial(r.Scheme()+":///test.server", dopts...)
515	if err != nil {
516		t.Fatal(err)
517	}
518	defer cc.Close()
519	tc := testpb.NewTestServiceClient(cc)
520	t.Log("Created a ClientConn...")
521
522	// The first RPC should fail because there's no address.
523	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
524	defer cancel()
525	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
526		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
527	}
528	t.Log("Made an RPC which was expected to fail...")
529
530	state := resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}}
531	r.UpdateState(state)
532	t.Logf("Pushing resolver state update: %v through the manual resolver", state)
533
534	// The second RPC should succeed.
535	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
536	defer cancel()
537	if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
538		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
539	}
540	t.Log("Made an RPC which succeeded...")
541
542	wantAttr := attributes.New(testAttrKey, testAttrVal)
543	if gotAttr := creds.attr; !cmp.Equal(gotAttr, wantAttr, cmp.AllowUnexported(attributes.Attributes{})) {
544		t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr)
545	}
546}
547
548// TestMetadataInAddressAttributes verifies that the metadata added to
549// address.Attributes will be sent with the RPCs.
550func (s) TestMetadataInAddressAttributes(t *testing.T) {
551	const (
552		testMDKey      = "test-md"
553		testMDValue    = "test-md-value"
554		mdBalancerName = "metadata-balancer"
555	)
556
557	// Register a stub balancer which adds metadata to the first address that it
558	// receives and then calls NewSubConn on it.
559	bf := stub.BalancerFuncs{
560		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
561			addrs := ccs.ResolverState.Addresses
562			if len(addrs) == 0 {
563				return nil
564			}
565			// Only use the first address.
566			sc, err := bd.ClientConn.NewSubConn([]resolver.Address{
567				imetadata.Set(addrs[0], metadata.Pairs(testMDKey, testMDValue)),
568			}, balancer.NewSubConnOptions{})
569			if err != nil {
570				return err
571			}
572			sc.Connect()
573			return nil
574		},
575		UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) {
576			bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
577		},
578	}
579	stub.Register(mdBalancerName, bf)
580	t.Logf("Registered balancer %s...", mdBalancerName)
581
582	testMDChan := make(chan []string, 1)
583	ss := &stubserver.StubServer{
584		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
585			md, ok := metadata.FromIncomingContext(ctx)
586			if ok {
587				select {
588				case testMDChan <- md[testMDKey]:
589				case <-ctx.Done():
590					return nil, ctx.Err()
591				}
592			}
593			return &testpb.Empty{}, nil
594		},
595	}
596	if err := ss.Start(nil, grpc.WithDefaultServiceConfig(
597		fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, mdBalancerName),
598	)); err != nil {
599		t.Fatalf("Error starting endpoint server: %v", err)
600	}
601	defer ss.Stop()
602
603	// The RPC should succeed with the expected md.
604	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
605	defer cancel()
606	if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
607		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
608	}
609	t.Log("Made an RPC which succeeded...")
610
611	// The server should receive the test metadata.
612	md1 := <-testMDChan
613	if len(md1) == 0 || md1[0] != testMDValue {
614		t.Fatalf("got md: %v, want %v", md1, []string{testMDValue})
615	}
616}
617
618// TestServersSwap creates two servers and verifies the client switches between
619// them when the name resolver reports the first and then the second.
620func (s) TestServersSwap(t *testing.T) {
621	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
622	defer cancel()
623
624	// Initialize servers
625	reg := func(username string) (addr string, cleanup func()) {
626		lis, err := net.Listen("tcp", "localhost:0")
627		if err != nil {
628			t.Fatalf("Error while listening. Err: %v", err)
629		}
630		s := grpc.NewServer()
631		ts := &funcServer{
632			unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
633				return &testpb.SimpleResponse{Username: username}, nil
634			},
635		}
636		testpb.RegisterTestServiceServer(s, ts)
637		go s.Serve(lis)
638		return lis.Addr().String(), s.Stop
639	}
640	const one = "1"
641	addr1, cleanup := reg(one)
642	defer cleanup()
643	const two = "2"
644	addr2, cleanup := reg(two)
645	defer cleanup()
646
647	// Initialize client
648	r := manual.NewBuilderWithScheme("whatever")
649	r.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: addr1}}})
650	cc, err := grpc.DialContext(ctx, r.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(r))
651	if err != nil {
652		t.Fatalf("Error creating client: %v", err)
653	}
654	defer cc.Close()
655	client := testpb.NewTestServiceClient(cc)
656
657	// Confirm we are connected to the first server
658	if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil || res.Username != one {
659		t.Fatalf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
660	}
661
662	// Update resolver to report only the second server
663	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: addr2}}})
664
665	// Loop until new RPCs talk to server two.
666	for i := 0; i < 2000; i++ {
667		if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
668			t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err)
669		} else if res.Username == two {
670			break // pass
671		}
672		time.Sleep(5 * time.Millisecond)
673	}
674}
675
676// TestEmptyAddrs verifies client behavior when a working connection is
677// removed.  In pick first and round-robin, both will continue using the old
678// connections.
679func (s) TestEmptyAddrs(t *testing.T) {
680	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
681	defer cancel()
682
683	// Initialize server
684	lis, err := net.Listen("tcp", "localhost:0")
685	if err != nil {
686		t.Fatalf("Error while listening. Err: %v", err)
687	}
688	s := grpc.NewServer()
689	defer s.Stop()
690	const one = "1"
691	ts := &funcServer{
692		unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
693			return &testpb.SimpleResponse{Username: one}, nil
694		},
695	}
696	testpb.RegisterTestServiceServer(s, ts)
697	go s.Serve(lis)
698
699	// Initialize pickfirst client
700	pfr := manual.NewBuilderWithScheme("whatever")
701	pfrnCalled := grpcsync.NewEvent()
702	pfr.ResolveNowCallback = func(resolver.ResolveNowOptions) {
703		pfrnCalled.Fire()
704	}
705	pfr.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
706
707	pfcc, err := grpc.DialContext(ctx, pfr.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(pfr))
708	if err != nil {
709		t.Fatalf("Error creating client: %v", err)
710	}
711	defer pfcc.Close()
712	pfclient := testpb.NewTestServiceClient(pfcc)
713
714	// Confirm we are connected to the server
715	if res, err := pfclient.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil || res.Username != one {
716		t.Fatalf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
717	}
718
719	// Remove all addresses.
720	pfr.UpdateState(resolver.State{})
721	// Wait for a ResolveNow call on the pick first client's resolver.
722	<-pfrnCalled.Done()
723
724	// Initialize roundrobin client
725	rrr := manual.NewBuilderWithScheme("whatever")
726
727	rrrnCalled := grpcsync.NewEvent()
728	rrr.ResolveNowCallback = func(resolver.ResolveNowOptions) {
729		rrrnCalled.Fire()
730	}
731	rrr.InitialState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
732
733	rrcc, err := grpc.DialContext(ctx, rrr.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(rrr),
734		grpc.WithDefaultServiceConfig(fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, roundrobin.Name)))
735	if err != nil {
736		t.Fatalf("Error creating client: %v", err)
737	}
738	defer rrcc.Close()
739	rrclient := testpb.NewTestServiceClient(rrcc)
740
741	// Confirm we are connected to the server
742	if res, err := rrclient.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil || res.Username != one {
743		t.Fatalf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
744	}
745
746	// Remove all addresses.
747	rrr.UpdateState(resolver.State{})
748	// Wait for a ResolveNow call on the round robin client's resolver.
749	<-rrrnCalled.Done()
750
751	// Confirm several new RPCs succeed on pick first.
752	for i := 0; i < 10; i++ {
753		if _, err := pfclient.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
754			t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err)
755		}
756		time.Sleep(5 * time.Millisecond)
757	}
758
759	// Confirm several new RPCs succeed on round robin.
760	for i := 0; i < 10; i++ {
761		if _, err := pfclient.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
762			t.Fatalf("UnaryCall(_) = _, %v; want _, nil", err)
763		}
764		time.Sleep(5 * time.Millisecond)
765	}
766}
767
768func (s) TestWaitForReady(t *testing.T) {
769	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
770	defer cancel()
771
772	// Initialize server
773	lis, err := net.Listen("tcp", "localhost:0")
774	if err != nil {
775		t.Fatalf("Error while listening. Err: %v", err)
776	}
777	s := grpc.NewServer()
778	defer s.Stop()
779	const one = "1"
780	ts := &funcServer{
781		unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
782			return &testpb.SimpleResponse{Username: one}, nil
783		},
784	}
785	testpb.RegisterTestServiceServer(s, ts)
786	go s.Serve(lis)
787
788	// Initialize client
789	r := manual.NewBuilderWithScheme("whatever")
790
791	cc, err := grpc.DialContext(ctx, r.Scheme()+":///", grpc.WithInsecure(), grpc.WithResolvers(r))
792	if err != nil {
793		t.Fatalf("Error creating client: %v", err)
794	}
795	defer cc.Close()
796	client := testpb.NewTestServiceClient(cc)
797
798	// Report an error so non-WFR RPCs will give up early.
799	r.CC.ReportError(errors.New("fake resolver error"))
800
801	// Ensure the client is not connected to anything and fails non-WFR RPCs.
802	if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Unavailable {
803		t.Fatalf("UnaryCall(_) = %v, %v; want _, Code()=%v", res, err, codes.Unavailable)
804	}
805
806	errChan := make(chan error, 1)
807	go func() {
808		if res, err := client.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.WaitForReady(true)); err != nil || res.Username != one {
809			errChan <- fmt.Errorf("UnaryCall(_) = %v, %v; want {Username: %q}, nil", res, err, one)
810		}
811		close(errChan)
812	}()
813
814	select {
815	case err := <-errChan:
816		t.Errorf("unexpected receive from errChan before addresses provided")
817		t.Fatal(err.Error())
818	case <-time.After(5 * time.Millisecond):
819	}
820
821	// Resolve the server.  The WFR RPC should unblock and use it.
822	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String()}}})
823
824	if err := <-errChan; err != nil {
825		t.Fatal(err.Error())
826	}
827}
828