1/*
2 *
3 * Copyright 2017 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package roundrobin_test
20
21import (
22	"context"
23	"fmt"
24	"net"
25	"strings"
26	"sync"
27	"testing"
28	"time"
29
30	"google.golang.org/grpc"
31	"google.golang.org/grpc/balancer/roundrobin"
32	"google.golang.org/grpc/codes"
33	"google.golang.org/grpc/connectivity"
34	"google.golang.org/grpc/internal/grpctest"
35	imetadata "google.golang.org/grpc/internal/metadata"
36	"google.golang.org/grpc/metadata"
37	"google.golang.org/grpc/peer"
38	"google.golang.org/grpc/resolver"
39	"google.golang.org/grpc/resolver/manual"
40	"google.golang.org/grpc/status"
41	testpb "google.golang.org/grpc/test/grpc_testing"
42)
43
// testMDKey is the metadata key that TestUpdateAddressAttributes attaches to
// the resolved backend address; testServer.EmptyCall forwards any values it
// receives under this key on its testMDChan.
const testMDKey = "test-md"
47
48type s struct {
49	grpctest.Tester
50}
51
52func Test(t *testing.T) {
53	grpctest.RunSubTests(t, s{})
54}
55
56type testServer struct {
57	testpb.UnimplementedTestServiceServer
58
59	testMDChan chan []string
60}
61
62func newTestServer() *testServer {
63	return &testServer{testMDChan: make(chan []string, 1)}
64}
65
66func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
67	md, ok := metadata.FromIncomingContext(ctx)
68	if ok && len(md[testMDKey]) != 0 {
69		select {
70		case s.testMDChan <- md[testMDKey]:
71		case <-ctx.Done():
72			return nil, ctx.Err()
73		}
74	}
75	return &testpb.Empty{}, nil
76}
77
78func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
79	return nil
80}
81
82type test struct {
83	servers     []*grpc.Server
84	serverImpls []*testServer
85	addresses   []string
86}
87
88func (t *test) cleanup() {
89	for _, s := range t.servers {
90		s.Stop()
91	}
92}
93
94func startTestServers(count int) (_ *test, err error) {
95	t := &test{}
96
97	defer func() {
98		if err != nil {
99			t.cleanup()
100		}
101	}()
102	for i := 0; i < count; i++ {
103		lis, err := net.Listen("tcp", "localhost:0")
104		if err != nil {
105			return nil, fmt.Errorf("failed to listen %v", err)
106		}
107
108		s := grpc.NewServer()
109		sImpl := newTestServer()
110		testpb.RegisterTestServiceServer(s, sImpl)
111		t.servers = append(t.servers, s)
112		t.serverImpls = append(t.serverImpls, sImpl)
113		t.addresses = append(t.addresses, lis.Addr().String())
114
115		go func(s *grpc.Server, l net.Listener) {
116			s.Serve(l)
117		}(s, lis)
118	}
119
120	return t, nil
121}
122
123func (s) TestOneBackend(t *testing.T) {
124	r := manual.NewBuilderWithScheme("whatever")
125
126	test, err := startTestServers(1)
127	if err != nil {
128		t.Fatalf("failed to start servers: %v", err)
129	}
130	defer test.cleanup()
131
132	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
133	if err != nil {
134		t.Fatalf("failed to dial: %v", err)
135	}
136	defer cc.Close()
137	testc := testpb.NewTestServiceClient(cc)
138	// The first RPC should fail because there's no address.
139	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
140	defer cancel()
141	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
142		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
143	}
144
145	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
146	// The second RPC should succeed.
147	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
148		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
149	}
150}
151
152func (s) TestBackendsRoundRobin(t *testing.T) {
153	r := manual.NewBuilderWithScheme("whatever")
154
155	backendCount := 5
156	test, err := startTestServers(backendCount)
157	if err != nil {
158		t.Fatalf("failed to start servers: %v", err)
159	}
160	defer test.cleanup()
161
162	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
163	if err != nil {
164		t.Fatalf("failed to dial: %v", err)
165	}
166	defer cc.Close()
167	testc := testpb.NewTestServiceClient(cc)
168	// The first RPC should fail because there's no address.
169	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
170	defer cancel()
171	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
172		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
173	}
174
175	var resolvedAddrs []resolver.Address
176	for i := 0; i < backendCount; i++ {
177		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
178	}
179
180	r.UpdateState(resolver.State{Addresses: resolvedAddrs})
181	var p peer.Peer
182	// Make sure connections to all servers are up.
183	for si := 0; si < backendCount; si++ {
184		var connected bool
185		for i := 0; i < 1000; i++ {
186			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
187				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
188			}
189			if p.Addr.String() == test.addresses[si] {
190				connected = true
191				break
192			}
193			time.Sleep(time.Millisecond)
194		}
195		if !connected {
196			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
197		}
198	}
199
200	for i := 0; i < 3*backendCount; i++ {
201		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
202			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
203		}
204		if p.Addr.String() != test.addresses[i%backendCount] {
205			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
206		}
207	}
208}
209
210func (s) TestAddressesRemoved(t *testing.T) {
211	r := manual.NewBuilderWithScheme("whatever")
212
213	test, err := startTestServers(1)
214	if err != nil {
215		t.Fatalf("failed to start servers: %v", err)
216	}
217	defer test.cleanup()
218
219	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
220	if err != nil {
221		t.Fatalf("failed to dial: %v", err)
222	}
223	defer cc.Close()
224	testc := testpb.NewTestServiceClient(cc)
225	// The first RPC should fail because there's no address.
226	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
227	defer cancel()
228	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
229		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
230	}
231
232	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
233	// The second RPC should succeed.
234	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
235		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
236	}
237
238	r.UpdateState(resolver.State{Addresses: []resolver.Address{}})
239
240	ctx2, cancel2 := context.WithTimeout(context.Background(), 500*time.Millisecond)
241	defer cancel2()
242	// Wait for state to change to transient failure.
243	for src := cc.GetState(); src != connectivity.TransientFailure; src = cc.GetState() {
244		if !cc.WaitForStateChange(ctx2, src) {
245			t.Fatalf("timed out waiting for state change.  got %v; want %v", src, connectivity.TransientFailure)
246		}
247	}
248
249	const msgWant = "produced zero addresses"
250	if _, err := testc.EmptyCall(ctx2, &testpb.Empty{}); err == nil || !strings.Contains(status.Convert(err).Message(), msgWant) {
251		t.Fatalf("EmptyCall() = _, %v, want _, Contains(Message(), %q)", err, msgWant)
252	}
253}
254
255func (s) TestCloseWithPendingRPC(t *testing.T) {
256	r := manual.NewBuilderWithScheme("whatever")
257
258	test, err := startTestServers(1)
259	if err != nil {
260		t.Fatalf("failed to start servers: %v", err)
261	}
262	defer test.cleanup()
263
264	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
265	if err != nil {
266		t.Fatalf("failed to dial: %v", err)
267	}
268	testc := testpb.NewTestServiceClient(cc)
269
270	var wg sync.WaitGroup
271	for i := 0; i < 3; i++ {
272		wg.Add(1)
273		go func() {
274			defer wg.Done()
275			// This RPC blocks until cc is closed.
276			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
277			if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) == codes.DeadlineExceeded {
278				t.Errorf("RPC failed because of deadline after cc is closed; want error the client connection is closing")
279			}
280			cancel()
281		}()
282	}
283	cc.Close()
284	wg.Wait()
285}
286
287func (s) TestNewAddressWhileBlocking(t *testing.T) {
288	r := manual.NewBuilderWithScheme("whatever")
289
290	test, err := startTestServers(1)
291	if err != nil {
292		t.Fatalf("failed to start servers: %v", err)
293	}
294	defer test.cleanup()
295
296	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
297	if err != nil {
298		t.Fatalf("failed to dial: %v", err)
299	}
300	defer cc.Close()
301	testc := testpb.NewTestServiceClient(cc)
302	// The first RPC should fail because there's no address.
303	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
304	defer cancel()
305	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
306		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
307	}
308
309	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
310	// The second RPC should succeed.
311	ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
312	defer cancel()
313	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
314		t.Fatalf("EmptyCall() = _, %v, want _, nil", err)
315	}
316
317	r.UpdateState(resolver.State{Addresses: []resolver.Address{}})
318
319	var wg sync.WaitGroup
320	for i := 0; i < 3; i++ {
321		wg.Add(1)
322		go func() {
323			defer wg.Done()
324			// This RPC blocks until NewAddress is called.
325			testc.EmptyCall(context.Background(), &testpb.Empty{})
326		}()
327	}
328	time.Sleep(50 * time.Millisecond)
329	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
330	wg.Wait()
331}
332
// TestOneServerDown verifies that round_robin cycles through three backends
// in address order. It then checks that, after the last backend is stopped,
// RPCs keep cycling in order through the two that remain.
//
// The assertions depend on the exact number and order of RPCs relative to
// the picker's round-robin position; keep that in mind when editing.
func (s) TestOneServerDown(t *testing.T) {
	r := manual.NewBuilderWithScheme("whatever")

	backendCount := 3
	test, err := startTestServers(backendCount)
	if err != nil {
		t.Fatalf("failed to start servers: %v", err)
	}
	defer test.cleanup()

	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
	if err != nil {
		t.Fatalf("failed to dial: %v", err)
	}
	defer cc.Close()
	testc := testpb.NewTestServiceClient(cc)
	// The first RPC should fail because there's no address.
	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()
	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
	}

	var resolvedAddrs []resolver.Address
	for i := 0; i < backendCount; i++ {
		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
	}

	r.UpdateState(resolver.State{Addresses: resolvedAddrs})
	// p is shared by every RPC below; grpc.Peer(&p) records which backend
	// served the most recent call.
	var p peer.Peer
	// Make sure connections to all servers are up: for each backend in turn,
	// retry (up to 1000 times, 1ms apart) until an RPC lands on it.
	for si := 0; si < backendCount; si++ {
		var connected bool
		for i := 0; i < 1000; i++ {
			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
			}
			if p.Addr.String() == test.addresses[si] {
				connected = true
				break
			}
			time.Sleep(time.Millisecond)
		}
		if !connected {
			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
		}
	}

	// Consecutive RPCs must visit the backends in address order, wrapping
	// around.
	for i := 0; i < 3*backendCount; i++ {
		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
		}
		if p.Addr.String() != test.addresses[i%backendCount] {
			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
		}
	}

	// Stop the last server (index backendCount after the decrement); RPCs
	// should then round-robin among the remaining backendCount servers.
	backendCount--
	test.servers[backendCount].Stop()
	// Loop until server[backendCount-1] (the last remaining one) has been
	// picked twice, with no pick of the stopped server[backendCount] and no
	// failed RPC in between; either of those resets the count.
	var targetSeen int
	for i := 0; i < 1000; i++ {
		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
			targetSeen = 0
			t.Logf("EmptyCall() = _, %v, want _, <nil>", err)
			// Due to a race, this RPC could possibly get the connection that
			// was closing, and this RPC may fail. Keep trying when this
			// happens.
			continue
		}
		switch p.Addr.String() {
		case test.addresses[backendCount-1]:
			targetSeen++
		case test.addresses[backendCount]:
			// Reset targetSeen if peer is server[backendCount].
			targetSeen = 0
		}
		// Break right after a pick of server[backendCount-1], so the next pick
		// should wrap around to server[0], which is what the loop below
		// expects at index 0. Without this the following loop would be flaky.
		if targetSeen >= 2 {
			break
		}
	}
	if targetSeen != 2 {
		t.Fatal("Failed to see server[backendCount-1] twice without seeing server[backendCount]")
	}
	// NOTE(review): unlike the equivalent loop above, mismatches here are
	// reported with t.Errorf, so the loop keeps running after a failure.
	for i := 0; i < 3*backendCount; i++ {
		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
		}
		if p.Addr.String() != test.addresses[i%backendCount] {
			t.Errorf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
		}
	}
}
428
429func (s) TestAllServersDown(t *testing.T) {
430	r := manual.NewBuilderWithScheme("whatever")
431
432	backendCount := 3
433	test, err := startTestServers(backendCount)
434	if err != nil {
435		t.Fatalf("failed to start servers: %v", err)
436	}
437	defer test.cleanup()
438
439	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
440	if err != nil {
441		t.Fatalf("failed to dial: %v", err)
442	}
443	defer cc.Close()
444	testc := testpb.NewTestServiceClient(cc)
445	// The first RPC should fail because there's no address.
446	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
447	defer cancel()
448	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
449		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
450	}
451
452	var resolvedAddrs []resolver.Address
453	for i := 0; i < backendCount; i++ {
454		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
455	}
456
457	r.UpdateState(resolver.State{Addresses: resolvedAddrs})
458	var p peer.Peer
459	// Make sure connections to all servers are up.
460	for si := 0; si < backendCount; si++ {
461		var connected bool
462		for i := 0; i < 1000; i++ {
463			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
464				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
465			}
466			if p.Addr.String() == test.addresses[si] {
467				connected = true
468				break
469			}
470			time.Sleep(time.Millisecond)
471		}
472		if !connected {
473			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
474		}
475	}
476
477	for i := 0; i < 3*backendCount; i++ {
478		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
479			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
480		}
481		if p.Addr.String() != test.addresses[i%backendCount] {
482			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
483		}
484	}
485
486	// All servers are stopped, failfast RPC should fail with unavailable.
487	for i := 0; i < backendCount; i++ {
488		test.servers[i].Stop()
489	}
490	time.Sleep(100 * time.Millisecond)
491	for i := 0; i < 1000; i++ {
492		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); status.Code(err) == codes.Unavailable {
493			return
494		}
495		time.Sleep(time.Millisecond)
496	}
497	t.Fatalf("Failfast RPCs didn't fail with Unavailable after all servers are stopped")
498}
499
500func (s) TestUpdateAddressAttributes(t *testing.T) {
501	r := manual.NewBuilderWithScheme("whatever")
502
503	test, err := startTestServers(1)
504	if err != nil {
505		t.Fatalf("failed to start servers: %v", err)
506	}
507	defer test.cleanup()
508
509	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
510	if err != nil {
511		t.Fatalf("failed to dial: %v", err)
512	}
513	defer cc.Close()
514	testc := testpb.NewTestServiceClient(cc)
515	// The first RPC should fail because there's no address.
516	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
517	defer cancel()
518	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
519		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
520	}
521
522	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
523	// The second RPC should succeed.
524	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
525		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
526	}
527	// The second RPC should not set metadata, so there's no md in the channel.
528	select {
529	case md1 := <-test.serverImpls[0].testMDChan:
530		t.Fatalf("got md: %v, want empty metadata", md1)
531	case <-time.After(time.Microsecond * 100):
532	}
533
534	const testMDValue = "test-md-value"
535	// Update metadata in address.
536	r.UpdateState(resolver.State{Addresses: []resolver.Address{
537		imetadata.Set(resolver.Address{Addr: test.addresses[0]}, metadata.Pairs(testMDKey, testMDValue)),
538	}})
539	// The third RPC should succeed.
540	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
541		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
542	}
543
544	// The third RPC should send metadata with it.
545	md2 := <-test.serverImpls[0].testMDChan
546	if len(md2) == 0 || md2[0] != testMDValue {
547		t.Fatalf("got md: %v, want %v", md2, []string{testMDValue})
548	}
549}
550