1/*
2 *
3 * Copyright 2017 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package roundrobin_test
20
21import (
22	"context"
23	"fmt"
24	"net"
25	"strings"
26	"sync"
27	"testing"
28	"time"
29
30	"google.golang.org/grpc"
31	"google.golang.org/grpc/balancer/roundrobin"
32	"google.golang.org/grpc/codes"
33	"google.golang.org/grpc/connectivity"
34	"google.golang.org/grpc/internal/grpctest"
35	"google.golang.org/grpc/peer"
36	"google.golang.org/grpc/resolver"
37	"google.golang.org/grpc/resolver/manual"
38	"google.golang.org/grpc/status"
39	testpb "google.golang.org/grpc/test/grpc_testing"
40)
41
// s is the receiver type on which the test methods in this file are declared;
// Test hands an s value to grpctest.RunSubTests. It embeds grpctest.Tester.
type s struct {
	grpctest.Tester
}
45
// Test is the only plain go-test entry point in this file. It passes an s
// value to grpctest.RunSubTests (presumably so that each Test* method declared
// on s is executed as a subtest — confirm against the grpctest package).
func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}
49
// testServer is the TestService implementation that startTestServers registers
// on every backend. It embeds testpb.UnimplementedTestServiceServer and
// overrides EmptyCall and FullDuplexCall.
type testServer struct {
	testpb.UnimplementedTestServiceServer
}
53
54func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
55	return &testpb.Empty{}, nil
56}
57
// FullDuplexCall never reads from or writes to the stream; it returns nil
// immediately.
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
	return nil
}
61
// test bundles the backends created by startTestServers. Both slices are
// appended to together, so servers[i] is the server listening on addresses[i].
type test struct {
	// servers holds every gRPC server that was started.
	servers   []*grpc.Server
	// addresses holds the listener address of each started server.
	addresses []string
}
66
67func (t *test) cleanup() {
68	for _, s := range t.servers {
69		s.Stop()
70	}
71}
72
73func startTestServers(count int) (_ *test, err error) {
74	t := &test{}
75
76	defer func() {
77		if err != nil {
78			t.cleanup()
79		}
80	}()
81	for i := 0; i < count; i++ {
82		lis, err := net.Listen("tcp", "localhost:0")
83		if err != nil {
84			return nil, fmt.Errorf("failed to listen %v", err)
85		}
86
87		s := grpc.NewServer()
88		testpb.RegisterTestServiceServer(s, &testServer{})
89		t.servers = append(t.servers, s)
90		t.addresses = append(t.addresses, lis.Addr().String())
91
92		go func(s *grpc.Server, l net.Listener) {
93			s.Serve(l)
94		}(s, lis)
95	}
96
97	return t, nil
98}
99
100func (s) TestOneBackend(t *testing.T) {
101	r, cleanup := manual.GenerateAndRegisterManualResolver()
102	defer cleanup()
103
104	test, err := startTestServers(1)
105	if err != nil {
106		t.Fatalf("failed to start servers: %v", err)
107	}
108	defer test.cleanup()
109
110	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
111	if err != nil {
112		t.Fatalf("failed to dial: %v", err)
113	}
114	defer cc.Close()
115	testc := testpb.NewTestServiceClient(cc)
116	// The first RPC should fail because there's no address.
117	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
118	defer cancel()
119	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
120		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
121	}
122
123	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
124	// The second RPC should succeed.
125	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
126		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
127	}
128}
129
130func (s) TestBackendsRoundRobin(t *testing.T) {
131	r, cleanup := manual.GenerateAndRegisterManualResolver()
132	defer cleanup()
133
134	backendCount := 5
135	test, err := startTestServers(backendCount)
136	if err != nil {
137		t.Fatalf("failed to start servers: %v", err)
138	}
139	defer test.cleanup()
140
141	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
142	if err != nil {
143		t.Fatalf("failed to dial: %v", err)
144	}
145	defer cc.Close()
146	testc := testpb.NewTestServiceClient(cc)
147	// The first RPC should fail because there's no address.
148	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
149	defer cancel()
150	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
151		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
152	}
153
154	var resolvedAddrs []resolver.Address
155	for i := 0; i < backendCount; i++ {
156		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
157	}
158
159	r.UpdateState(resolver.State{Addresses: resolvedAddrs})
160	var p peer.Peer
161	// Make sure connections to all servers are up.
162	for si := 0; si < backendCount; si++ {
163		var connected bool
164		for i := 0; i < 1000; i++ {
165			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
166				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
167			}
168			if p.Addr.String() == test.addresses[si] {
169				connected = true
170				break
171			}
172			time.Sleep(time.Millisecond)
173		}
174		if !connected {
175			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
176		}
177	}
178
179	for i := 0; i < 3*backendCount; i++ {
180		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
181			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
182		}
183		if p.Addr.String() != test.addresses[i%backendCount] {
184			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
185		}
186	}
187}
188
189func (s) TestAddressesRemoved(t *testing.T) {
190	r, cleanup := manual.GenerateAndRegisterManualResolver()
191	defer cleanup()
192
193	test, err := startTestServers(1)
194	if err != nil {
195		t.Fatalf("failed to start servers: %v", err)
196	}
197	defer test.cleanup()
198
199	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
200	if err != nil {
201		t.Fatalf("failed to dial: %v", err)
202	}
203	defer cc.Close()
204	testc := testpb.NewTestServiceClient(cc)
205	// The first RPC should fail because there's no address.
206	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
207	defer cancel()
208	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
209		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
210	}
211
212	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
213	// The second RPC should succeed.
214	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
215		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
216	}
217
218	r.UpdateState(resolver.State{Addresses: []resolver.Address{}})
219
220	ctx2, cancel2 := context.WithTimeout(context.Background(), 500*time.Millisecond)
221	defer cancel2()
222	// Wait for state to change to transient failure.
223	for src := cc.GetState(); src != connectivity.TransientFailure; src = cc.GetState() {
224		if !cc.WaitForStateChange(ctx2, src) {
225			t.Fatalf("timed out waiting for state change.  got %v; want %v", src, connectivity.TransientFailure)
226		}
227	}
228
229	const msgWant = "produced zero addresses"
230	if _, err := testc.EmptyCall(ctx2, &testpb.Empty{}); err == nil || !strings.Contains(status.Convert(err).Message(), msgWant) {
231		t.Fatalf("EmptyCall() = _, %v, want _, Contains(Message(), %q)", err, msgWant)
232	}
233}
234
235func (s) TestCloseWithPendingRPC(t *testing.T) {
236	r, cleanup := manual.GenerateAndRegisterManualResolver()
237	defer cleanup()
238
239	test, err := startTestServers(1)
240	if err != nil {
241		t.Fatalf("failed to start servers: %v", err)
242	}
243	defer test.cleanup()
244
245	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
246	if err != nil {
247		t.Fatalf("failed to dial: %v", err)
248	}
249	testc := testpb.NewTestServiceClient(cc)
250
251	var wg sync.WaitGroup
252	for i := 0; i < 3; i++ {
253		wg.Add(1)
254		go func() {
255			defer wg.Done()
256			// This RPC blocks until cc is closed.
257			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
258			if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) == codes.DeadlineExceeded {
259				t.Errorf("RPC failed because of deadline after cc is closed; want error the client connection is closing")
260			}
261			cancel()
262		}()
263	}
264	cc.Close()
265	wg.Wait()
266}
267
268func (s) TestNewAddressWhileBlocking(t *testing.T) {
269	r, cleanup := manual.GenerateAndRegisterManualResolver()
270	defer cleanup()
271
272	test, err := startTestServers(1)
273	if err != nil {
274		t.Fatalf("failed to start servers: %v", err)
275	}
276	defer test.cleanup()
277
278	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
279	if err != nil {
280		t.Fatalf("failed to dial: %v", err)
281	}
282	defer cc.Close()
283	testc := testpb.NewTestServiceClient(cc)
284	// The first RPC should fail because there's no address.
285	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
286	defer cancel()
287	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
288		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
289	}
290
291	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
292	// The second RPC should succeed.
293	ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
294	defer cancel()
295	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
296		t.Fatalf("EmptyCall() = _, %v, want _, nil", err)
297	}
298
299	r.UpdateState(resolver.State{Addresses: []resolver.Address{}})
300
301	var wg sync.WaitGroup
302	for i := 0; i < 3; i++ {
303		wg.Add(1)
304		go func() {
305			defer wg.Done()
306			// This RPC blocks until NewAddress is called.
307			testc.EmptyCall(context.Background(), &testpb.Empty{})
308		}()
309	}
310	time.Sleep(50 * time.Millisecond)
311	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
312	wg.Wait()
313}
314
315func (s) TestOneServerDown(t *testing.T) {
316	r, cleanup := manual.GenerateAndRegisterManualResolver()
317	defer cleanup()
318
319	backendCount := 3
320	test, err := startTestServers(backendCount)
321	if err != nil {
322		t.Fatalf("failed to start servers: %v", err)
323	}
324	defer test.cleanup()
325
326	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
327	if err != nil {
328		t.Fatalf("failed to dial: %v", err)
329	}
330	defer cc.Close()
331	testc := testpb.NewTestServiceClient(cc)
332	// The first RPC should fail because there's no address.
333	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
334	defer cancel()
335	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
336		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
337	}
338
339	var resolvedAddrs []resolver.Address
340	for i := 0; i < backendCount; i++ {
341		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
342	}
343
344	r.UpdateState(resolver.State{Addresses: resolvedAddrs})
345	var p peer.Peer
346	// Make sure connections to all servers are up.
347	for si := 0; si < backendCount; si++ {
348		var connected bool
349		for i := 0; i < 1000; i++ {
350			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
351				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
352			}
353			if p.Addr.String() == test.addresses[si] {
354				connected = true
355				break
356			}
357			time.Sleep(time.Millisecond)
358		}
359		if !connected {
360			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
361		}
362	}
363
364	for i := 0; i < 3*backendCount; i++ {
365		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
366			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
367		}
368		if p.Addr.String() != test.addresses[i%backendCount] {
369			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
370		}
371	}
372
373	// Stop one server, RPCs should roundrobin among the remaining servers.
374	backendCount--
375	test.servers[backendCount].Stop()
376	// Loop until see server[backendCount-1] twice without seeing server[backendCount].
377	var targetSeen int
378	for i := 0; i < 1000; i++ {
379		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
380			targetSeen = 0
381			t.Logf("EmptyCall() = _, %v, want _, <nil>", err)
382			// Due to a race, this RPC could possibly get the connection that
383			// was closing, and this RPC may fail. Keep trying when this
384			// happens.
385			continue
386		}
387		switch p.Addr.String() {
388		case test.addresses[backendCount-1]:
389			targetSeen++
390		case test.addresses[backendCount]:
391			// Reset targetSeen if peer is server[backendCount].
392			targetSeen = 0
393		}
394		// Break to make sure the last picked address is server[-1], so the following for loop won't be flaky.
395		if targetSeen >= 2 {
396			break
397		}
398	}
399	if targetSeen != 2 {
400		t.Fatal("Failed to see server[backendCount-1] twice without seeing server[backendCount]")
401	}
402	for i := 0; i < 3*backendCount; i++ {
403		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
404			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
405		}
406		if p.Addr.String() != test.addresses[i%backendCount] {
407			t.Errorf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
408		}
409	}
410}
411
412func (s) TestAllServersDown(t *testing.T) {
413	r, cleanup := manual.GenerateAndRegisterManualResolver()
414	defer cleanup()
415
416	backendCount := 3
417	test, err := startTestServers(backendCount)
418	if err != nil {
419		t.Fatalf("failed to start servers: %v", err)
420	}
421	defer test.cleanup()
422
423	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithBalancerName(roundrobin.Name))
424	if err != nil {
425		t.Fatalf("failed to dial: %v", err)
426	}
427	defer cc.Close()
428	testc := testpb.NewTestServiceClient(cc)
429	// The first RPC should fail because there's no address.
430	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
431	defer cancel()
432	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
433		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
434	}
435
436	var resolvedAddrs []resolver.Address
437	for i := 0; i < backendCount; i++ {
438		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
439	}
440
441	r.UpdateState(resolver.State{Addresses: resolvedAddrs})
442	var p peer.Peer
443	// Make sure connections to all servers are up.
444	for si := 0; si < backendCount; si++ {
445		var connected bool
446		for i := 0; i < 1000; i++ {
447			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
448				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
449			}
450			if p.Addr.String() == test.addresses[si] {
451				connected = true
452				break
453			}
454			time.Sleep(time.Millisecond)
455		}
456		if !connected {
457			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
458		}
459	}
460
461	for i := 0; i < 3*backendCount; i++ {
462		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
463			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
464		}
465		if p.Addr.String() != test.addresses[i%backendCount] {
466			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
467		}
468	}
469
470	// All servers are stopped, failfast RPC should fail with unavailable.
471	for i := 0; i < backendCount; i++ {
472		test.servers[i].Stop()
473	}
474	time.Sleep(100 * time.Millisecond)
475	for i := 0; i < 1000; i++ {
476		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); status.Code(err) == codes.Unavailable {
477			return
478		}
479		time.Sleep(time.Millisecond)
480	}
481	t.Fatalf("Failfast RPCs didn't fail with Unavailable after all servers are stopped")
482}
483