1/*
2 *
3 * Copyright 2017 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package roundrobin_test
20
21import (
22	"context"
23	"fmt"
24	"net"
25	"strings"
26	"sync"
27	"testing"
28	"time"
29
30	"google.golang.org/grpc"
31	"google.golang.org/grpc/balancer/roundrobin"
32	"google.golang.org/grpc/codes"
33	"google.golang.org/grpc/connectivity"
34	"google.golang.org/grpc/internal/grpctest"
35	"google.golang.org/grpc/peer"
36	"google.golang.org/grpc/resolver"
37	"google.golang.org/grpc/resolver/manual"
38	"google.golang.org/grpc/status"
39	testpb "google.golang.org/grpc/test/grpc_testing"
40)
41
42type s struct {
43	grpctest.Tester
44}
45
46func Test(t *testing.T) {
47	grpctest.RunSubTests(t, s{})
48}
49
50type testServer struct {
51	testpb.UnimplementedTestServiceServer
52}
53
54func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
55	return &testpb.Empty{}, nil
56}
57
58func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
59	return nil
60}
61
62type test struct {
63	servers   []*grpc.Server
64	addresses []string
65}
66
67func (t *test) cleanup() {
68	for _, s := range t.servers {
69		s.Stop()
70	}
71}
72
73func startTestServers(count int) (_ *test, err error) {
74	t := &test{}
75
76	defer func() {
77		if err != nil {
78			t.cleanup()
79		}
80	}()
81	for i := 0; i < count; i++ {
82		lis, err := net.Listen("tcp", "localhost:0")
83		if err != nil {
84			return nil, fmt.Errorf("failed to listen %v", err)
85		}
86
87		s := grpc.NewServer()
88		testpb.RegisterTestServiceServer(s, &testServer{})
89		t.servers = append(t.servers, s)
90		t.addresses = append(t.addresses, lis.Addr().String())
91
92		go func(s *grpc.Server, l net.Listener) {
93			s.Serve(l)
94		}(s, lis)
95	}
96
97	return t, nil
98}
99
100func (s) TestOneBackend(t *testing.T) {
101	r := manual.NewBuilderWithScheme("whatever")
102
103	test, err := startTestServers(1)
104	if err != nil {
105		t.Fatalf("failed to start servers: %v", err)
106	}
107	defer test.cleanup()
108
109	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
110	if err != nil {
111		t.Fatalf("failed to dial: %v", err)
112	}
113	defer cc.Close()
114	testc := testpb.NewTestServiceClient(cc)
115	// The first RPC should fail because there's no address.
116	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
117	defer cancel()
118	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
119		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
120	}
121
122	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
123	// The second RPC should succeed.
124	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
125		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
126	}
127}
128
129func (s) TestBackendsRoundRobin(t *testing.T) {
130	r := manual.NewBuilderWithScheme("whatever")
131
132	backendCount := 5
133	test, err := startTestServers(backendCount)
134	if err != nil {
135		t.Fatalf("failed to start servers: %v", err)
136	}
137	defer test.cleanup()
138
139	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
140	if err != nil {
141		t.Fatalf("failed to dial: %v", err)
142	}
143	defer cc.Close()
144	testc := testpb.NewTestServiceClient(cc)
145	// The first RPC should fail because there's no address.
146	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
147	defer cancel()
148	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
149		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
150	}
151
152	var resolvedAddrs []resolver.Address
153	for i := 0; i < backendCount; i++ {
154		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
155	}
156
157	r.UpdateState(resolver.State{Addresses: resolvedAddrs})
158	var p peer.Peer
159	// Make sure connections to all servers are up.
160	for si := 0; si < backendCount; si++ {
161		var connected bool
162		for i := 0; i < 1000; i++ {
163			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
164				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
165			}
166			if p.Addr.String() == test.addresses[si] {
167				connected = true
168				break
169			}
170			time.Sleep(time.Millisecond)
171		}
172		if !connected {
173			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
174		}
175	}
176
177	for i := 0; i < 3*backendCount; i++ {
178		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
179			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
180		}
181		if p.Addr.String() != test.addresses[i%backendCount] {
182			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
183		}
184	}
185}
186
187func (s) TestAddressesRemoved(t *testing.T) {
188	r := manual.NewBuilderWithScheme("whatever")
189
190	test, err := startTestServers(1)
191	if err != nil {
192		t.Fatalf("failed to start servers: %v", err)
193	}
194	defer test.cleanup()
195
196	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
197	if err != nil {
198		t.Fatalf("failed to dial: %v", err)
199	}
200	defer cc.Close()
201	testc := testpb.NewTestServiceClient(cc)
202	// The first RPC should fail because there's no address.
203	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
204	defer cancel()
205	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
206		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
207	}
208
209	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
210	// The second RPC should succeed.
211	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
212		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
213	}
214
215	r.UpdateState(resolver.State{Addresses: []resolver.Address{}})
216
217	ctx2, cancel2 := context.WithTimeout(context.Background(), 500*time.Millisecond)
218	defer cancel2()
219	// Wait for state to change to transient failure.
220	for src := cc.GetState(); src != connectivity.TransientFailure; src = cc.GetState() {
221		if !cc.WaitForStateChange(ctx2, src) {
222			t.Fatalf("timed out waiting for state change.  got %v; want %v", src, connectivity.TransientFailure)
223		}
224	}
225
226	const msgWant = "produced zero addresses"
227	if _, err := testc.EmptyCall(ctx2, &testpb.Empty{}); err == nil || !strings.Contains(status.Convert(err).Message(), msgWant) {
228		t.Fatalf("EmptyCall() = _, %v, want _, Contains(Message(), %q)", err, msgWant)
229	}
230}
231
232func (s) TestCloseWithPendingRPC(t *testing.T) {
233	r := manual.NewBuilderWithScheme("whatever")
234
235	test, err := startTestServers(1)
236	if err != nil {
237		t.Fatalf("failed to start servers: %v", err)
238	}
239	defer test.cleanup()
240
241	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
242	if err != nil {
243		t.Fatalf("failed to dial: %v", err)
244	}
245	testc := testpb.NewTestServiceClient(cc)
246
247	var wg sync.WaitGroup
248	for i := 0; i < 3; i++ {
249		wg.Add(1)
250		go func() {
251			defer wg.Done()
252			// This RPC blocks until cc is closed.
253			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
254			if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) == codes.DeadlineExceeded {
255				t.Errorf("RPC failed because of deadline after cc is closed; want error the client connection is closing")
256			}
257			cancel()
258		}()
259	}
260	cc.Close()
261	wg.Wait()
262}
263
264func (s) TestNewAddressWhileBlocking(t *testing.T) {
265	r := manual.NewBuilderWithScheme("whatever")
266
267	test, err := startTestServers(1)
268	if err != nil {
269		t.Fatalf("failed to start servers: %v", err)
270	}
271	defer test.cleanup()
272
273	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
274	if err != nil {
275		t.Fatalf("failed to dial: %v", err)
276	}
277	defer cc.Close()
278	testc := testpb.NewTestServiceClient(cc)
279	// The first RPC should fail because there's no address.
280	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
281	defer cancel()
282	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
283		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
284	}
285
286	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
287	// The second RPC should succeed.
288	ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
289	defer cancel()
290	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
291		t.Fatalf("EmptyCall() = _, %v, want _, nil", err)
292	}
293
294	r.UpdateState(resolver.State{Addresses: []resolver.Address{}})
295
296	var wg sync.WaitGroup
297	for i := 0; i < 3; i++ {
298		wg.Add(1)
299		go func() {
300			defer wg.Done()
301			// This RPC blocks until NewAddress is called.
302			testc.EmptyCall(context.Background(), &testpb.Empty{})
303		}()
304	}
305	time.Sleep(50 * time.Millisecond)
306	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: test.addresses[0]}}})
307	wg.Wait()
308}
309
310func (s) TestOneServerDown(t *testing.T) {
311	r := manual.NewBuilderWithScheme("whatever")
312
313	backendCount := 3
314	test, err := startTestServers(backendCount)
315	if err != nil {
316		t.Fatalf("failed to start servers: %v", err)
317	}
318	defer test.cleanup()
319
320	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
321	if err != nil {
322		t.Fatalf("failed to dial: %v", err)
323	}
324	defer cc.Close()
325	testc := testpb.NewTestServiceClient(cc)
326	// The first RPC should fail because there's no address.
327	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
328	defer cancel()
329	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
330		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
331	}
332
333	var resolvedAddrs []resolver.Address
334	for i := 0; i < backendCount; i++ {
335		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
336	}
337
338	r.UpdateState(resolver.State{Addresses: resolvedAddrs})
339	var p peer.Peer
340	// Make sure connections to all servers are up.
341	for si := 0; si < backendCount; si++ {
342		var connected bool
343		for i := 0; i < 1000; i++ {
344			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
345				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
346			}
347			if p.Addr.String() == test.addresses[si] {
348				connected = true
349				break
350			}
351			time.Sleep(time.Millisecond)
352		}
353		if !connected {
354			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
355		}
356	}
357
358	for i := 0; i < 3*backendCount; i++ {
359		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
360			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
361		}
362		if p.Addr.String() != test.addresses[i%backendCount] {
363			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
364		}
365	}
366
367	// Stop one server, RPCs should roundrobin among the remaining servers.
368	backendCount--
369	test.servers[backendCount].Stop()
370	// Loop until see server[backendCount-1] twice without seeing server[backendCount].
371	var targetSeen int
372	for i := 0; i < 1000; i++ {
373		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
374			targetSeen = 0
375			t.Logf("EmptyCall() = _, %v, want _, <nil>", err)
376			// Due to a race, this RPC could possibly get the connection that
377			// was closing, and this RPC may fail. Keep trying when this
378			// happens.
379			continue
380		}
381		switch p.Addr.String() {
382		case test.addresses[backendCount-1]:
383			targetSeen++
384		case test.addresses[backendCount]:
385			// Reset targetSeen if peer is server[backendCount].
386			targetSeen = 0
387		}
388		// Break to make sure the last picked address is server[-1], so the following for loop won't be flaky.
389		if targetSeen >= 2 {
390			break
391		}
392	}
393	if targetSeen != 2 {
394		t.Fatal("Failed to see server[backendCount-1] twice without seeing server[backendCount]")
395	}
396	for i := 0; i < 3*backendCount; i++ {
397		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
398			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
399		}
400		if p.Addr.String() != test.addresses[i%backendCount] {
401			t.Errorf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
402		}
403	}
404}
405
406func (s) TestAllServersDown(t *testing.T) {
407	r := manual.NewBuilderWithScheme("whatever")
408
409	backendCount := 3
410	test, err := startTestServers(backendCount)
411	if err != nil {
412		t.Fatalf("failed to start servers: %v", err)
413	}
414	defer test.cleanup()
415
416	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(roundrobin.Name))
417	if err != nil {
418		t.Fatalf("failed to dial: %v", err)
419	}
420	defer cc.Close()
421	testc := testpb.NewTestServiceClient(cc)
422	// The first RPC should fail because there's no address.
423	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
424	defer cancel()
425	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || status.Code(err) != codes.DeadlineExceeded {
426		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
427	}
428
429	var resolvedAddrs []resolver.Address
430	for i := 0; i < backendCount; i++ {
431		resolvedAddrs = append(resolvedAddrs, resolver.Address{Addr: test.addresses[i]})
432	}
433
434	r.UpdateState(resolver.State{Addresses: resolvedAddrs})
435	var p peer.Peer
436	// Make sure connections to all servers are up.
437	for si := 0; si < backendCount; si++ {
438		var connected bool
439		for i := 0; i < 1000; i++ {
440			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
441				t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
442			}
443			if p.Addr.String() == test.addresses[si] {
444				connected = true
445				break
446			}
447			time.Sleep(time.Millisecond)
448		}
449		if !connected {
450			t.Fatalf("Connection to %v was not up after more than 1 second", test.addresses[si])
451		}
452	}
453
454	for i := 0; i < 3*backendCount; i++ {
455		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
456			t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
457		}
458		if p.Addr.String() != test.addresses[i%backendCount] {
459			t.Fatalf("Index %d: want peer %v, got peer %v", i, test.addresses[i%backendCount], p.Addr.String())
460		}
461	}
462
463	// All servers are stopped, failfast RPC should fail with unavailable.
464	for i := 0; i < backendCount; i++ {
465		test.servers[i].Stop()
466	}
467	time.Sleep(100 * time.Millisecond)
468	for i := 0; i < 1000; i++ {
469		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); status.Code(err) == codes.Unavailable {
470			return
471		}
472		time.Sleep(time.Millisecond)
473	}
474	t.Fatalf("Failfast RPCs didn't fail with Unavailable after all servers are stopped")
475}
476