1/*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package grpclb
20
21import (
22	"context"
23	"errors"
24	"fmt"
25	"io"
26	"net"
27	"strconv"
28	"strings"
29	"sync"
30	"sync/atomic"
31	"testing"
32	"time"
33
34	"google.golang.org/grpc"
35	"google.golang.org/grpc/balancer"
36	grpclbstate "google.golang.org/grpc/balancer/grpclb/state"
37	"google.golang.org/grpc/codes"
38	"google.golang.org/grpc/credentials"
39	"google.golang.org/grpc/internal/grpctest"
40	"google.golang.org/grpc/metadata"
41	"google.golang.org/grpc/peer"
42	"google.golang.org/grpc/resolver"
43	"google.golang.org/grpc/resolver/manual"
44	"google.golang.org/grpc/status"
45
46	durationpb "github.com/golang/protobuf/ptypes/duration"
47	lbgrpc "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
48	lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
49	testpb "google.golang.org/grpc/test/grpc_testing"
50)
51
var (
	// Server names the test credentials exchange during handshakes: the
	// fake remote balancer authenticates as lbServerName, the backends as
	// beServerName. lbToken is the per-server token that non-fallback
	// backends require in the "lb-token" request metadata.
	lbServerName = "lb.server.com"
	beServerName = "backends.com"
	lbToken      = "iamatoken"

	// Resolver replaces localhost with fakeName in Next().
	// Dialer replaces fakeName with localhost when dialing.
	// This will test that custom dialer is passed from Dial to grpclb.
	fakeName = "fake.Name"
)
62
// s embeds grpctest.Tester so that every (s) Test* method in this file is
// executed with the harness's per-test setup and teardown checks.
type s struct {
	grpctest.Tester
}
66
// Test is the single go-test entry point; it dispatches to all (s) Test*
// methods declared in this file via the grpctest harness.
func Test(t *testing.T) {
	grpctest.RunSubTests(t, s{})
}
70
// serverNameCheckCreds is a test-only TransportCredentials implementation.
// The server side writes its configured name (sn) in plaintext on the raw
// connection; the client side reads it back and fails the handshake if it
// does not match the dial authority. No encryption or real authentication
// is performed.
type serverNameCheckCreds struct {
	mu sync.Mutex // guards the client-side handshake read
	sn string     // server name announced by ServerHandshake
}
75
76func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
77	if _, err := io.WriteString(rawConn, c.sn); err != nil {
78		fmt.Printf("Failed to write the server name %s to the client %v", c.sn, err)
79		return nil, nil, err
80	}
81	return rawConn, nil, nil
82}
83func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
84	c.mu.Lock()
85	defer c.mu.Unlock()
86	b := make([]byte, len(authority))
87	errCh := make(chan error, 1)
88	go func() {
89		_, err := rawConn.Read(b)
90		errCh <- err
91	}()
92	select {
93	case err := <-errCh:
94		if err != nil {
95			fmt.Printf("test-creds: failed to read expected authority name from the server: %v\n", err)
96			return nil, nil, err
97		}
98	case <-ctx.Done():
99		return nil, nil, ctx.Err()
100	}
101	if authority != string(b) {
102		fmt.Printf("test-creds: got authority from ClientConn %q, expected by server %q\n", authority, string(b))
103		return nil, nil, errors.New("received unexpected server name")
104	}
105	return rawConn, nil, nil
106}
// Info implements credentials.TransportCredentials. It returns an empty
// ProtocolInfo because these test credentials add no real security protocol.
func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo {
	return credentials.ProtocolInfo{}
}
// Clone implements credentials.TransportCredentials.
// NOTE(review): the expected server name (sn) is not copied into the clone,
// so the clone accepts whatever name the server announces only via the
// authority comparison — confirm this is intended before reusing the type
// outside these tests.
func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials {
	return &serverNameCheckCreds{}
}
// OverrideServerName implements credentials.TransportCredentials; it is a
// no-op for these test credentials.
func (c *serverNameCheckCreds) OverrideServerName(s string) error {
	return nil
}
116
117// fakeNameDialer replaces fakeName with localhost when dialing.
118// This will test that custom dialer is passed from Dial to grpclb.
119func fakeNameDialer(ctx context.Context, addr string) (net.Conn, error) {
120	addr = strings.Replace(addr, fakeName, "localhost", 1)
121	return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
122}
123
124// merge merges the new client stats into current stats.
125//
126// It's a test-only method. rpcStats is defined in grpclb_picker.
127func (s *rpcStats) merge(cs *lbpb.ClientStats) {
128	atomic.AddInt64(&s.numCallsStarted, cs.NumCallsStarted)
129	atomic.AddInt64(&s.numCallsFinished, cs.NumCallsFinished)
130	atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, cs.NumCallsFinishedWithClientFailedToSend)
131	atomic.AddInt64(&s.numCallsFinishedKnownReceived, cs.NumCallsFinishedKnownReceived)
132	s.mu.Lock()
133	for _, perToken := range cs.CallsFinishedWithDrop {
134		s.numCallsDropped[perToken.LoadBalanceToken] += perToken.NumCalls
135	}
136	s.mu.Unlock()
137}
138
// mapsEqual reports whether a and b hold exactly the same key/value pairs.
func mapsEqual(a, b map[string]int64) bool {
	if len(a) != len(b) {
		return false
	}
	// Lengths match, so checking every entry of b against a is sufficient.
	for k, want := range b {
		got, present := a[k]
		if !present || got != want {
			return false
		}
	}
	return true
}
150
// atomicEqual reports whether two atomically-accessed counters currently
// hold the same value.
func atomicEqual(a, b *int64) bool {
	va := atomic.LoadInt64(a)
	vb := atomic.LoadInt64(b)
	return va == vb
}
154
155// equal compares two rpcStats.
156//
157// It's a test-only method. rpcStats is defined in grpclb_picker.
158func (s *rpcStats) equal(o *rpcStats) bool {
159	if !atomicEqual(&s.numCallsStarted, &o.numCallsStarted) {
160		return false
161	}
162	if !atomicEqual(&s.numCallsFinished, &o.numCallsFinished) {
163		return false
164	}
165	if !atomicEqual(&s.numCallsFinishedWithClientFailedToSend, &o.numCallsFinishedWithClientFailedToSend) {
166		return false
167	}
168	if !atomicEqual(&s.numCallsFinishedKnownReceived, &o.numCallsFinishedKnownReceived) {
169		return false
170	}
171	s.mu.Lock()
172	defer s.mu.Unlock()
173	o.mu.Lock()
174	defer o.mu.Unlock()
175	return mapsEqual(s.numCallsDropped, o.numCallsDropped)
176}
177
// String implements fmt.Stringer for rpcStats. Counters are read with
// atomic loads and the drop map under s.mu, but the snapshot as a whole is
// not taken atomically across fields.
func (s *rpcStats) String() string {
	s.mu.Lock()
	defer s.mu.Unlock()
	return fmt.Sprintf("Started: %v, Finished: %v, FinishedWithClientFailedToSend: %v, FinishedKnownReceived: %v, Dropped: %v",
		atomic.LoadInt64(&s.numCallsStarted),
		atomic.LoadInt64(&s.numCallsFinished),
		atomic.LoadInt64(&s.numCallsFinishedWithClientFailedToSend),
		atomic.LoadInt64(&s.numCallsFinishedKnownReceived),
		s.numCallsDropped)
}
188
// remoteBalancer is the fake grpclb remote balancer served by the tests in
// this file.
type remoteBalancer struct {
	lbgrpc.UnimplementedLoadBalancerServer
	// sls delivers server lists for BalanceLoad to push to the client.
	sls chan *lbpb.ServerList
	// statsDura is the client-stats reporting interval advertised in the
	// initial BalanceLoad response.
	statsDura time.Duration
	// done is closed by stop().
	done chan struct{}
	// stats accumulates client-reported load stats (see merge).
	stats *rpcStats
	// statsChan, when non-nil, also receives each raw ClientStats message.
	statsChan chan *lbpb.ClientStats
	// fbChan signals BalanceLoad to send an explicit fallback response.
	fbChan chan struct{}

	// customUserAgent, when non-empty, must be a prefix of the client's
	// user-agent for BalanceLoad to proceed.
	customUserAgent string
}
200
201func newRemoteBalancer(customUserAgent string, statsChan chan *lbpb.ClientStats) *remoteBalancer {
202	return &remoteBalancer{
203		sls:             make(chan *lbpb.ServerList, 1),
204		done:            make(chan struct{}),
205		stats:           newRPCStats(),
206		statsChan:       statsChan,
207		fbChan:          make(chan struct{}),
208		customUserAgent: customUserAgent,
209	}
210}
211
// stop closes the fake balancer's channels. Tests must not send on sls
// after calling stop — a send on a closed channel panics.
func (b *remoteBalancer) stop() {
	close(b.sls)
	close(b.done)
}
216
// fallbackNow instructs BalanceLoad to send an explicit FallbackResponse to
// the client. The channel is unbuffered, so this blocks until BalanceLoad
// is ready to receive the signal.
func (b *remoteBalancer) fallbackNow() {
	b.fbChan <- struct{}{}
}
220
// BalanceLoad implements the grpclb LoadBalancer service for the fake
// remote balancer.
//
// It validates the client's user-agent (when customUserAgent is set) and
// the initial request's service name, replies with an initial response
// carrying the client-stats reporting interval, then streams server lists
// from b.sls and explicit fallback responses triggered via b.fbChan until
// the stream's context is done. Client stats arriving on the stream are
// merged into b.stats (and forwarded to b.statsChan when set) by a
// background goroutine.
func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error {
	md, ok := metadata.FromIncomingContext(stream.Context())
	if !ok {
		return status.Error(codes.Internal, "failed to receive metadata")
	}
	if b.customUserAgent != "" {
		// The gRPC client appends its own version string to the configured
		// user-agent, so only a prefix match is required.
		ua := md["user-agent"]
		if len(ua) == 0 || !strings.HasPrefix(ua[0], b.customUserAgent) {
			return status.Errorf(codes.InvalidArgument, "received unexpected user-agent: %v, want prefix %q", ua, b.customUserAgent)
		}
	}

	// The first message on the stream must be the initial request naming
	// the service being balanced.
	req, err := stream.Recv()
	if err != nil {
		return err
	}
	initReq := req.GetInitialRequest()
	if initReq.Name != beServerName {
		return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name)
	}
	// Initial response: tell the client how often to report load stats.
	resp := &lbpb.LoadBalanceResponse{
		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
			InitialResponse: &lbpb.InitialLoadBalanceResponse{
				ClientStatsReportInterval: &durationpb.Duration{
					Seconds: int64(b.statsDura.Seconds()),
					Nanos:   int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
				},
			},
		},
	}
	if err := stream.Send(resp); err != nil {
		return err
	}
	// Drain client stats reports in the background; the goroutine exits
	// when Recv fails (stream closed or canceled).
	go func() {
		for {
			var (
				req *lbpb.LoadBalanceRequest
				err error
			)
			if req, err = stream.Recv(); err != nil {
				return
			}
			b.stats.merge(req.GetClientStats())
			if b.statsChan != nil && req.GetClientStats() != nil {
				b.statsChan <- req.GetClientStats()
			}
		}
	}()
	// Main send loop: forward server lists and fallback signals until the
	// stream context is done.
	for {
		select {
		case v := <-b.sls:
			resp = &lbpb.LoadBalanceResponse{
				LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
					ServerList: v,
				},
			}
		case <-b.fbChan:
			resp = &lbpb.LoadBalanceResponse{
				LoadBalanceResponseType: &lbpb.LoadBalanceResponse_FallbackResponse{
					FallbackResponse: &lbpb.FallbackResponse{},
				},
			}
		case <-stream.Context().Done():
			return stream.Context().Err()
		}
		if err := stream.Send(resp); err != nil {
			return err
		}
	}
}
291
// testServer is a minimal TestService backend. addr is echoed back to
// clients in a trailer (see EmptyCall); fallback disables the lb-token
// metadata check for backends reached without the remote balancer.
type testServer struct {
	testpb.UnimplementedTestServiceServer

	addr     string
	fallback bool
}
298
// testmdkey is the trailer metadata key under which EmptyCall reports the
// serving backend's address.
const testmdkey = "testmd"
300
301func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
302	md, ok := metadata.FromIncomingContext(ctx)
303	if !ok {
304		return nil, status.Error(codes.Internal, "failed to receive metadata")
305	}
306	if !s.fallback && (md == nil || len(md["lb-token"]) == 0 || md["lb-token"][0] != lbToken) {
307		return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md)
308	}
309	grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
310	return &testpb.Empty{}, nil
311}
312
// FullDuplexCall is a stub; the grpclb tests in this file only exercise
// EmptyCall.
func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
	return nil
}
316
317func startBackends(sn string, fallback bool, lis ...net.Listener) (servers []*grpc.Server) {
318	for _, l := range lis {
319		creds := &serverNameCheckCreds{
320			sn: sn,
321		}
322		s := grpc.NewServer(grpc.Creds(creds))
323		testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String(), fallback: fallback})
324		servers = append(servers, s)
325		go func(s *grpc.Server, l net.Listener) {
326			s.Serve(l)
327		}(s, l)
328	}
329	return
330}
331
332func stopBackends(servers []*grpc.Server) {
333	for _, s := range servers {
334		s.Stop()
335	}
336}
337
// testServers bundles everything newLoadBalancer starts: the fake remote
// balancer ls, served by lb at lbAddr (which uses fakeName and therefore
// requires fakeNameDialer), plus the test backends with their IPs, ports,
// and restartable listeners.
type testServers struct {
	lbAddr   string
	ls       *remoteBalancer
	lb       *grpc.Server
	backends []*grpc.Server
	beIPs    []net.IP
	bePorts  []int

	lbListener  net.Listener
	beListeners []net.Listener
}
349
350func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) {
351	var (
352		beListeners []net.Listener
353		ls          *remoteBalancer
354		lb          *grpc.Server
355		beIPs       []net.IP
356		bePorts     []int
357	)
358	for i := 0; i < numberOfBackends; i++ {
359		// Start a backend.
360		beLis, e := net.Listen("tcp", "localhost:0")
361		if e != nil {
362			err = fmt.Errorf("failed to listen %v", err)
363			return
364		}
365		beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP)
366		bePorts = append(bePorts, beLis.Addr().(*net.TCPAddr).Port)
367
368		beListeners = append(beListeners, newRestartableListener(beLis))
369	}
370	backends := startBackends(beServerName, false, beListeners...)
371
372	// Start a load balancer.
373	lbLis, err := net.Listen("tcp", "localhost:0")
374	if err != nil {
375		err = fmt.Errorf("failed to create the listener for the load balancer %v", err)
376		return
377	}
378	lbLis = newRestartableListener(lbLis)
379	lbCreds := &serverNameCheckCreds{
380		sn: lbServerName,
381	}
382	lb = grpc.NewServer(grpc.Creds(lbCreds))
383	ls = newRemoteBalancer(customUserAgent, statsChan)
384	lbgrpc.RegisterLoadBalancerServer(lb, ls)
385	go func() {
386		lb.Serve(lbLis)
387	}()
388
389	tss = &testServers{
390		lbAddr:   net.JoinHostPort(fakeName, strconv.Itoa(lbLis.Addr().(*net.TCPAddr).Port)),
391		ls:       ls,
392		lb:       lb,
393		backends: backends,
394		beIPs:    beIPs,
395		bePorts:  bePorts,
396
397		lbListener:  lbLis,
398		beListeners: beListeners,
399	}
400	cleanup = func() {
401		defer stopBackends(backends)
402		defer func() {
403			ls.stop()
404			lb.Stop()
405		}()
406	}
407	return
408}
409
// grpclbConfig is the service config JSON that selects the grpclb policy.
var grpclbConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}`
411
// TestGRPCLB is the basic happy path: the remote balancer hands out one
// backend, and an RPC through the grpclb policy must reach it. It also
// verifies the custom user-agent is propagated to the balancer stream.
func (s) TestGRPCLB(t *testing.T) {
	r := manual.NewBuilderWithScheme("whatever")

	const testUserAgent = "test-user-agent"
	tss, cleanup, err := newLoadBalancer(1, testUserAgent, nil)
	if err != nil {
		t.Fatalf("failed to create new load balancer: %v", err)
	}
	defer cleanup()

	// Queue a server list containing the single backend before the client
	// connects (sls has capacity 1).
	be := &lbpb.Server{
		IpAddress:        tss.beIPs[0],
		Port:             int32(tss.bePorts[0]),
		LoadBalanceToken: lbToken,
	}
	var bes []*lbpb.Server
	bes = append(bes, be)
	sl := &lbpb.ServerList{
		Servers: bes,
	}
	tss.ls.sls <- sl
	creds := serverNameCheckCreds{}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer),
		grpc.WithUserAgent(testUserAgent))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	testC := testpb.NewTestServiceClient(cc)

	// Deliver the balancer address through grpclb attributes rather than
	// the legacy resolver.GRPCLB address type.
	rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)},
		&grpclbstate.State{BalancerAddresses: []resolver.Address{{
			Addr:       tss.lbAddr,
			Type:       resolver.Backend,
			ServerName: lbServerName,
		}}})
	r.UpdateState(rs)

	ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err != nil {
		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
	}
}
459
460// The remote balancer sends response with duplicates to grpclb client.
461func (s) TestGRPCLBWeighted(t *testing.T) {
462	r := manual.NewBuilderWithScheme("whatever")
463
464	tss, cleanup, err := newLoadBalancer(2, "", nil)
465	if err != nil {
466		t.Fatalf("failed to create new load balancer: %v", err)
467	}
468	defer cleanup()
469
470	beServers := []*lbpb.Server{{
471		IpAddress:        tss.beIPs[0],
472		Port:             int32(tss.bePorts[0]),
473		LoadBalanceToken: lbToken,
474	}, {
475		IpAddress:        tss.beIPs[1],
476		Port:             int32(tss.bePorts[1]),
477		LoadBalanceToken: lbToken,
478	}}
479	portsToIndex := make(map[int]int)
480	for i := range beServers {
481		portsToIndex[tss.bePorts[i]] = i
482	}
483
484	creds := serverNameCheckCreds{}
485	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
486	defer cancel()
487	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
488		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
489	if err != nil {
490		t.Fatalf("Failed to dial to the backend %v", err)
491	}
492	defer cc.Close()
493	testC := testpb.NewTestServiceClient(cc)
494
495	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
496		Addr:       tss.lbAddr,
497		Type:       resolver.GRPCLB,
498		ServerName: lbServerName,
499	}}})
500
501	sequences := []string{"00101", "00011"}
502	for _, seq := range sequences {
503		var (
504			bes    []*lbpb.Server
505			p      peer.Peer
506			result string
507		)
508		for _, s := range seq {
509			bes = append(bes, beServers[s-'0'])
510		}
511		tss.ls.sls <- &lbpb.ServerList{Servers: bes}
512
513		for i := 0; i < 1000; i++ {
514			if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
515				t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
516			}
517			result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
518		}
519		// The generated result will be in format of "0010100101".
520		if !strings.Contains(result, strings.Repeat(seq, 2)) {
521			t.Errorf("got result sequence %q, want patten %q", result, seq)
522		}
523	}
524}
525
// TestDropRequest tests that drop entries in the server list are honored in
// round-robin order: with two real backends and one drop entry, every third
// pick fails with Unavailable, and the drop index is not reset when a
// backend connection goes away.
func (s) TestDropRequest(t *testing.T) {
	r := manual.NewBuilderWithScheme("whatever")

	tss, cleanup, err := newLoadBalancer(2, "", nil)
	if err != nil {
		t.Fatalf("failed to create new load balancer: %v", err)
	}
	defer cleanup()
	// Server list: backend0, backend1, then a pure drop entry.
	tss.ls.sls <- &lbpb.ServerList{
		Servers: []*lbpb.Server{{
			IpAddress:        tss.beIPs[0],
			Port:             int32(tss.bePorts[0]),
			LoadBalanceToken: lbToken,
			Drop:             false,
		}, {
			IpAddress:        tss.beIPs[1],
			Port:             int32(tss.bePorts[1]),
			LoadBalanceToken: lbToken,
			Drop:             false,
		}, {
			Drop: true,
		}},
	}
	creds := serverNameCheckCreds{}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	testC := testpb.NewTestServiceClient(cc)

	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
		Addr:       tss.lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}}})

	var (
		i int
		p peer.Peer
	)
	const (
		// Poll to wait for something to happen. Total timeout 1 second. Sleep 1
		// ms each loop, and do at most 1000 loops.
		sleepEachLoop = time.Millisecond
		loopCount     = int(time.Second / sleepEachLoop)
	)
	// Make a non-fail-fast RPC and wait for it to succeed.
	for i = 0; i < loopCount; i++ {
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err == nil {
			break
		}
		time.Sleep(sleepEachLoop)
	}
	if i >= loopCount {
		// NOTE(review): err here is the (nil) outer err from newLoadBalancer;
		// the per-RPC error is scoped to the if above and not visible here.
		t.Fatalf("timeout waiting for the first connection to become ready. EmptyCall(_, _) = _, %v, want _, <nil>", err)
	}

	// Make RPCs until the peer is different. So we know both connections are
	// READY.
	for i = 0; i < loopCount; i++ {
		var temp peer.Peer
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&temp)); err == nil {
			if temp.Addr.(*net.TCPAddr).Port != p.Addr.(*net.TCPAddr).Port {
				break
			}
		}
		time.Sleep(sleepEachLoop)
	}
	if i >= loopCount {
		t.Fatalf("timeout waiting for the second connection to become ready")
	}

	// More RPCs until drop happens. So we know the picker index, and the
	// expected behavior of following RPCs.
	for i = 0; i < loopCount; i++ {
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) == codes.Unavailable {
			break
		}
		time.Sleep(sleepEachLoop)
	}
	if i >= loopCount {
		// NOTE(review): as above, err prints the stale outer value here.
		t.Fatalf("timeout waiting for drop. EmptyCall(_, _) = _, %v, want _, <Unavailable>", err)
	}

	select {
	case <-ctx.Done():
		t.Fatal("timed out", ctx.Err())
	default:
	}
	// From a known picker position, the 3-entry list must yield the pattern
	// (ok, ok, dropped) for both wait-for-ready and fail-fast RPCs.
	for _, failfast := range []bool{true, false} {
		for i := 0; i < 3; i++ {
			// 1st RPCs pick the first item in server list. They should succeed
			// since they choose the non-drop-request backend according to the
			// round robin policy.
			if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(!failfast)); err != nil {
				t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
			}
			// 2nd RPCs pick the second item in server list. They should succeed
			// since they choose the non-drop-request backend according to the
			// round robin policy.
			if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(!failfast)); err != nil {
				t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
			}
			// 3rd RPCs should fail, because they pick last item in server list,
			// with Drop set to true.
			if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(!failfast)); status.Code(err) != codes.Unavailable {
				t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
			}
		}
	}

	// Make one more RPC to move the picker index one step further, so it's not
	// 0. The following RPCs will test that drop index is not reset. If picker
	// index is at 0, we cannot tell whether it's reset or not.
	if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
		t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
	}

	tss.backends[0].Stop()
	// This last pick was backend 0. Closing backend 0 doesn't reset drop index
	// (for level 1 picking), so the following picks will be (backend1, drop,
	// backend1), instead of (backend, backend, drop) if drop index was reset.
	time.Sleep(time.Second)
	for i := 0; i < 3; i++ {
		var p peer.Peer
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if want := tss.bePorts[1]; p.Addr.(*net.TCPAddr).Port != want {
			t.Errorf("got peer: %v, want peer port: %v", p.Addr, want)
		}

		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) != codes.Unavailable {
			t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
		}

		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if want := tss.bePorts[1]; p.Addr.(*net.TCPAddr).Port != want {
			t.Errorf("got peer: %v, want peer port: %v", p.Addr, want)
		}
	}
}
674
// When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list.
func (s) TestBalancerDisconnects(t *testing.T) {
	r := manual.NewBuilderWithScheme("whatever")

	var (
		tests []*testServers
		lbs   []*grpc.Server
	)
	// Start two independent balancer+backend pairs; each balancer hands
	// out its own single backend, so the serving backend identifies which
	// balancer is in use.
	for i := 0; i < 2; i++ {
		tss, cleanup, err := newLoadBalancer(1, "", nil)
		if err != nil {
			t.Fatalf("failed to create new load balancer: %v", err)
		}
		defer cleanup()

		be := &lbpb.Server{
			IpAddress:        tss.beIPs[0],
			Port:             int32(tss.bePorts[0]),
			LoadBalanceToken: lbToken,
		}
		var bes []*lbpb.Server
		bes = append(bes, be)
		sl := &lbpb.ServerList{
			Servers: bes,
		}
		tss.ls.sls <- sl

		tests = append(tests, tss)
		lbs = append(lbs, tss.lb)
	}

	creds := serverNameCheckCreds{}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	testC := testpb.NewTestServiceClient(cc)

	// Resolve both balancer addresses; grpclb should use the first.
	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
		Addr:       tests[0].lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}, {
		Addr:       tests[1].lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}}})

	var p peer.Peer
	if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
	}
	if p.Addr.(*net.TCPAddr).Port != tests[0].bePorts[0] {
		t.Fatalf("got peer: %v, want peer port: %v", p.Addr, tests[0].bePorts[0])
	}

	lbs[0].Stop()
	// Stop balancer[0], balancer[1] should be used by grpclb.
	// Check peer address to see if that happened.
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.(*net.TCPAddr).Port == tests[1].bePorts[0] {
			return
		}
		time.Sleep(time.Millisecond)
	}
	t.Fatalf("No RPC sent to second backend after 1 second")
}
749
// TestFallback verifies fallback-to-backends behavior: a bad balancer
// address triggers fallback after the (shortened) fallback timeout, a
// working balancer moves traffic to remote-balancer backends, killing both
// connections falls back again, and restarting them recovers.
func (s) TestFallback(t *testing.T) {
	// Shorten the fallback timeout for the test, and restore the default
	// builder on exit.
	balancer.Register(newLBBuilderWithFallbackTimeout(100 * time.Millisecond))
	defer balancer.Register(newLBBuilder())

	r := manual.NewBuilderWithScheme("whatever")

	tss, cleanup, err := newLoadBalancer(1, "", nil)
	if err != nil {
		t.Fatalf("failed to create new load balancer: %v", err)
	}
	defer cleanup()

	// Start a standalone backend.
	beLis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen %v", err)
	}
	defer beLis.Close()
	standaloneBEs := startBackends(beServerName, true, beLis)
	defer stopBackends(standaloneBEs)

	be := &lbpb.Server{
		IpAddress:        tss.beIPs[0],
		Port:             int32(tss.bePorts[0]),
		LoadBalanceToken: lbToken,
	}
	var bes []*lbpb.Server
	bes = append(bes, be)
	sl := &lbpb.ServerList{
		Servers: bes,
	}
	tss.ls.sls <- sl
	creds := serverNameCheckCreds{}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	testC := testpb.NewTestServiceClient(cc)

	// Unreachable balancer address plus a fallback backend: the fallback
	// backend must serve once the fallback timeout fires.
	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
		Addr:       "invalid.address",
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}, {
		Addr: beLis.Addr().String(),
		Type: resolver.Backend,
	}}})

	var p peer.Peer
	if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
		t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
	}
	if p.Addr.String() != beLis.Addr().String() {
		t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr())
	}

	// Switch to the working balancer; traffic should move to its backend.
	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
		Addr:       tss.lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}, {
		Addr: beLis.Addr().String(),
		Type: resolver.Backend,
	}}})

	var backendUsed bool
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
			backendUsed = true
			break
		}
		time.Sleep(time.Millisecond)
	}
	if !backendUsed {
		t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
	}

	// Close backend and remote balancer connections, should use fallback.
	tss.beListeners[0].(*restartableListener).stopPreviousConns()
	tss.lbListener.(*restartableListener).stopPreviousConns()

	var fallbackUsed bool
	for i := 0; i < 2000; i++ {
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			// Because we are hard-closing the connection, above, it's possible
			// for the first RPC attempt to be sent on the old connection,
			// which will lead to an Unavailable error when it is closed.
			// Ignore unavailable errors.
			if status.Code(err) != codes.Unavailable {
				t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
			}
		}
		if p.Addr.String() == beLis.Addr().String() {
			fallbackUsed = true
			break
		}
		time.Sleep(time.Millisecond)
	}
	if !fallbackUsed {
		t.Fatalf("No RPC sent to fallback after 2 seconds")
	}

	// Restart backend and remote balancer, should not use backends.
	tss.beListeners[0].(*restartableListener).restart()
	tss.lbListener.(*restartableListener).restart()
	tss.ls.sls <- sl

	var backendUsed2 bool
	for i := 0; i < 2000; i++ {
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
			backendUsed2 = true
			break
		}
		time.Sleep(time.Millisecond)
	}
	if !backendUsed2 {
		t.Fatalf("No RPC sent to backend behind remote balancer after 2 seconds")
	}
}
879
// TestExplicitFallback verifies that an explicit FallbackResponse from the
// remote balancer moves traffic to the fallback backend, and that a
// subsequent server list moves it back to the remote-balancer backend.
func (s) TestExplicitFallback(t *testing.T) {
	r := manual.NewBuilderWithScheme("whatever")

	tss, cleanup, err := newLoadBalancer(1, "", nil)
	if err != nil {
		t.Fatalf("failed to create new load balancer: %v", err)
	}
	defer cleanup()

	// Start a standalone backend.
	beLis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen %v", err)
	}
	defer beLis.Close()
	standaloneBEs := startBackends(beServerName, true, beLis)
	defer stopBackends(standaloneBEs)

	be := &lbpb.Server{
		IpAddress:        tss.beIPs[0],
		Port:             int32(tss.bePorts[0]),
		LoadBalanceToken: lbToken,
	}
	var bes []*lbpb.Server
	bes = append(bes, be)
	sl := &lbpb.ServerList{
		Servers: bes,
	}
	tss.ls.sls <- sl
	creds := serverNameCheckCreds{}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	testC := testpb.NewTestServiceClient(cc)

	// Resolve both the balancer address and a fallback backend address.
	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
		Addr:       tss.lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}, {
		Addr: beLis.Addr().String(),
		Type: resolver.Backend,
	}}})

	var p peer.Peer
	var backendUsed bool
	for i := 0; i < 2000; i++ {
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
			backendUsed = true
			break
		}
		time.Sleep(time.Millisecond)
	}
	if !backendUsed {
		t.Fatalf("No RPC sent to backend behind remote balancer after 2 seconds")
	}

	// Send fallback signal from remote balancer; should use fallback.
	tss.ls.fallbackNow()

	var fallbackUsed bool
	for i := 0; i < 2000; i++ {
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.String() == beLis.Addr().String() {
			fallbackUsed = true
			break
		}
		time.Sleep(time.Millisecond)
	}
	if !fallbackUsed {
		t.Fatalf("No RPC sent to fallback after 2 seconds")
	}

	// Send another server list; should use backends again.
	tss.ls.sls <- sl

	backendUsed = false
	for i := 0; i < 2000; i++ {
		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
			backendUsed = true
			break
		}
		time.Sleep(time.Millisecond)
	}
	if !backendUsed {
		t.Fatalf("No RPC sent to backend behind remote balancer after 2 seconds")
	}
}
981
982func (s) TestFallBackWithNoServerAddress(t *testing.T) {
983	resolveNowCh := make(chan struct{}, 1)
984	r := manual.NewBuilderWithScheme("whatever")
985	r.ResolveNowCallback = func(resolver.ResolveNowOptions) {
986		select {
987		case <-resolveNowCh:
988		default:
989		}
990		resolveNowCh <- struct{}{}
991	}
992
993	tss, cleanup, err := newLoadBalancer(1, "", nil)
994	if err != nil {
995		t.Fatalf("failed to create new load balancer: %v", err)
996	}
997	defer cleanup()
998
999	// Start a standalone backend.
1000	beLis, err := net.Listen("tcp", "localhost:0")
1001	if err != nil {
1002		t.Fatalf("Failed to listen %v", err)
1003	}
1004	defer beLis.Close()
1005	standaloneBEs := startBackends(beServerName, true, beLis)
1006	defer stopBackends(standaloneBEs)
1007
1008	be := &lbpb.Server{
1009		IpAddress:        tss.beIPs[0],
1010		Port:             int32(tss.bePorts[0]),
1011		LoadBalanceToken: lbToken,
1012	}
1013	var bes []*lbpb.Server
1014	bes = append(bes, be)
1015	sl := &lbpb.ServerList{
1016		Servers: bes,
1017	}
1018	creds := serverNameCheckCreds{}
1019	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
1020	defer cancel()
1021	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
1022		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
1023	if err != nil {
1024		t.Fatalf("Failed to dial to the backend %v", err)
1025	}
1026	defer cc.Close()
1027	testC := testpb.NewTestServiceClient(cc)
1028
1029	// Select grpclb with service config.
1030	const pfc = `{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"round_robin":{}}]}}]}`
1031	scpr := r.CC.ParseServiceConfig(pfc)
1032	if scpr.Err != nil {
1033		t.Fatalf("Error parsing config %q: %v", pfc, scpr.Err)
1034	}
1035
1036	for i := 0; i < 2; i++ {
1037		// Send an update with only backend address. grpclb should enter fallback
1038		// and use the fallback backend.
1039		r.UpdateState(resolver.State{
1040			Addresses: []resolver.Address{{
1041				Addr: beLis.Addr().String(),
1042				Type: resolver.Backend,
1043			}},
1044			ServiceConfig: scpr,
1045		})
1046
1047		select {
1048		case <-resolveNowCh:
1049			t.Errorf("unexpected resolveNow when grpclb gets no balancer address 1111, %d", i)
1050		case <-time.After(time.Second):
1051		}
1052
1053		var p peer.Peer
1054		rpcCtx, rpcCancel := context.WithTimeout(context.Background(), time.Second)
1055		defer rpcCancel()
1056		if _, err := testC.EmptyCall(rpcCtx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
1057			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
1058		}
1059		if p.Addr.String() != beLis.Addr().String() {
1060			t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr())
1061		}
1062
1063		select {
1064		case <-resolveNowCh:
1065			t.Errorf("unexpected resolveNow when grpclb gets no balancer address 2222, %d", i)
1066		case <-time.After(time.Second):
1067		}
1068
1069		tss.ls.sls <- sl
1070		// Send an update with balancer address. The backends behind grpclb should
1071		// be used.
1072		r.UpdateState(resolver.State{
1073			Addresses: []resolver.Address{{
1074				Addr:       tss.lbAddr,
1075				Type:       resolver.GRPCLB,
1076				ServerName: lbServerName,
1077			}, {
1078				Addr: beLis.Addr().String(),
1079				Type: resolver.Backend,
1080			}},
1081			ServiceConfig: scpr,
1082		})
1083
1084		var backendUsed bool
1085		for i := 0; i < 1000; i++ {
1086			if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
1087				t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
1088			}
1089			if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
1090				backendUsed = true
1091				break
1092			}
1093			time.Sleep(time.Millisecond)
1094		}
1095		if !backendUsed {
1096			t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
1097		}
1098	}
1099}
1100
1101func (s) TestGRPCLBPickFirst(t *testing.T) {
1102	r := manual.NewBuilderWithScheme("whatever")
1103
1104	tss, cleanup, err := newLoadBalancer(3, "", nil)
1105	if err != nil {
1106		t.Fatalf("failed to create new load balancer: %v", err)
1107	}
1108	defer cleanup()
1109
1110	beServers := []*lbpb.Server{{
1111		IpAddress:        tss.beIPs[0],
1112		Port:             int32(tss.bePorts[0]),
1113		LoadBalanceToken: lbToken,
1114	}, {
1115		IpAddress:        tss.beIPs[1],
1116		Port:             int32(tss.bePorts[1]),
1117		LoadBalanceToken: lbToken,
1118	}, {
1119		IpAddress:        tss.beIPs[2],
1120		Port:             int32(tss.bePorts[2]),
1121		LoadBalanceToken: lbToken,
1122	}}
1123	portsToIndex := make(map[int]int)
1124	for i := range beServers {
1125		portsToIndex[tss.bePorts[i]] = i
1126	}
1127
1128	creds := serverNameCheckCreds{}
1129	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
1130	defer cancel()
1131	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
1132		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
1133	if err != nil {
1134		t.Fatalf("Failed to dial to the backend %v", err)
1135	}
1136	defer cc.Close()
1137	testC := testpb.NewTestServiceClient(cc)
1138
1139	var (
1140		p      peer.Peer
1141		result string
1142	)
1143	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[0:3]}
1144
1145	// Start with sub policy pick_first.
1146	const pfc = `{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"pick_first":{}}]}}]}`
1147	scpr := r.CC.ParseServiceConfig(pfc)
1148	if scpr.Err != nil {
1149		t.Fatalf("Error parsing config %q: %v", pfc, scpr.Err)
1150	}
1151
1152	r.UpdateState(resolver.State{
1153		Addresses: []resolver.Address{{
1154			Addr:       tss.lbAddr,
1155			Type:       resolver.GRPCLB,
1156			ServerName: lbServerName,
1157		}},
1158		ServiceConfig: scpr,
1159	})
1160
1161	result = ""
1162	for i := 0; i < 1000; i++ {
1163		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
1164			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
1165		}
1166		result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
1167	}
1168	if seq := "00000"; !strings.Contains(result, strings.Repeat(seq, 100)) {
1169		t.Errorf("got result sequence %q, want patten %q", result, seq)
1170	}
1171
1172	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[2:]}
1173	result = ""
1174	for i := 0; i < 1000; i++ {
1175		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
1176			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
1177		}
1178		result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
1179	}
1180	if seq := "22222"; !strings.Contains(result, strings.Repeat(seq, 100)) {
1181		t.Errorf("got result sequence %q, want patten %q", result, seq)
1182	}
1183
1184	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[1:]}
1185	result = ""
1186	for i := 0; i < 1000; i++ {
1187		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
1188			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
1189		}
1190		result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
1191	}
1192	if seq := "22222"; !strings.Contains(result, strings.Repeat(seq, 100)) {
1193		t.Errorf("got result sequence %q, want patten %q", result, seq)
1194	}
1195
1196	// Switch sub policy to roundrobin.
1197	grpclbServiceConfigEmpty := r.CC.ParseServiceConfig(`{}`)
1198	if grpclbServiceConfigEmpty.Err != nil {
1199		t.Fatalf("Error parsing config %q: %v", `{}`, grpclbServiceConfigEmpty.Err)
1200	}
1201
1202	r.UpdateState(resolver.State{
1203		Addresses: []resolver.Address{{
1204			Addr:       tss.lbAddr,
1205			Type:       resolver.GRPCLB,
1206			ServerName: lbServerName,
1207		}},
1208		ServiceConfig: grpclbServiceConfigEmpty,
1209	})
1210
1211	result = ""
1212	for i := 0; i < 1000; i++ {
1213		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
1214			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
1215		}
1216		result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
1217	}
1218	if seq := "121212"; !strings.Contains(result, strings.Repeat(seq, 100)) {
1219		t.Errorf("got result sequence %q, want patten %q", result, seq)
1220	}
1221
1222	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[0:3]}
1223	result = ""
1224	for i := 0; i < 1000; i++ {
1225		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
1226			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
1227		}
1228		result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
1229	}
1230	if seq := "012012012"; !strings.Contains(result, strings.Repeat(seq, 2)) {
1231		t.Errorf("got result sequence %q, want patten %q", result, seq)
1232	}
1233}
1234
1235type failPreRPCCred struct{}
1236
1237func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
1238	if strings.Contains(uri[0], failtosendURI) {
1239		return nil, fmt.Errorf("rpc should fail to send")
1240	}
1241	return nil, nil
1242}
1243
1244func (failPreRPCCred) RequireTransportSecurity() bool {
1245	return false
1246}
1247
1248func checkStats(stats, expected *rpcStats) error {
1249	if !stats.equal(expected) {
1250		return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
1251	}
1252	return nil
1253}
1254
1255func runAndCheckStats(t *testing.T, drop bool, statsChan chan *lbpb.ClientStats, runRPCs func(*grpc.ClientConn), statsWant *rpcStats) error {
1256	r := manual.NewBuilderWithScheme("whatever")
1257
1258	tss, cleanup, err := newLoadBalancer(1, "", statsChan)
1259	if err != nil {
1260		t.Fatalf("failed to create new load balancer: %v", err)
1261	}
1262	defer cleanup()
1263	servers := []*lbpb.Server{{
1264		IpAddress:        tss.beIPs[0],
1265		Port:             int32(tss.bePorts[0]),
1266		LoadBalanceToken: lbToken,
1267	}}
1268	if drop {
1269		servers = append(servers, &lbpb.Server{
1270			LoadBalanceToken: lbToken,
1271			Drop:             drop,
1272		})
1273	}
1274	tss.ls.sls <- &lbpb.ServerList{Servers: servers}
1275	tss.ls.statsDura = 100 * time.Millisecond
1276	creds := serverNameCheckCreds{}
1277
1278	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
1279	defer cancel()
1280	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
1281		grpc.WithTransportCredentials(&creds),
1282		grpc.WithPerRPCCredentials(failPreRPCCred{}),
1283		grpc.WithContextDialer(fakeNameDialer))
1284	if err != nil {
1285		t.Fatalf("Failed to dial to the backend %v", err)
1286	}
1287	defer cc.Close()
1288
1289	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
1290		Addr:       tss.lbAddr,
1291		Type:       resolver.GRPCLB,
1292		ServerName: lbServerName,
1293	}}})
1294
1295	runRPCs(cc)
1296	end := time.Now().Add(time.Second)
1297	for time.Now().Before(end) {
1298		if err := checkStats(tss.ls.stats, statsWant); err == nil {
1299			time.Sleep(200 * time.Millisecond) // sleep for two intervals to make sure no new stats are reported.
1300			break
1301		}
1302	}
1303	return checkStats(tss.ls.stats, statsWant)
1304}
1305
const (
	// countRPC is the total number of RPCs issued by each stats test.
	countRPC      = 40
	// failtosendURI marks RPCs that failPreRPCCred fails before sending.
	failtosendURI = "failtosend"
)
1310
1311func (s) TestGRPCLBStatsUnarySuccess(t *testing.T) {
1312	if err := runAndCheckStats(t, false, nil, func(cc *grpc.ClientConn) {
1313		testC := testpb.NewTestServiceClient(cc)
1314		ctx, cancel := context.WithTimeout(context.Background(), defaultFallbackTimeout)
1315		defer cancel()
1316		// The first non-failfast RPC succeeds, all connections are up.
1317		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
1318			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
1319		}
1320		for i := 0; i < countRPC-1; i++ {
1321			testC.EmptyCall(ctx, &testpb.Empty{})
1322		}
1323	}, &rpcStats{
1324		numCallsStarted:               int64(countRPC),
1325		numCallsFinished:              int64(countRPC),
1326		numCallsFinishedKnownReceived: int64(countRPC),
1327	}); err != nil {
1328		t.Fatal(err)
1329	}
1330}
1331
1332func (s) TestGRPCLBStatsUnaryDrop(t *testing.T) {
1333	if err := runAndCheckStats(t, true, nil, func(cc *grpc.ClientConn) {
1334		testC := testpb.NewTestServiceClient(cc)
1335		ctx, cancel := context.WithTimeout(context.Background(), defaultFallbackTimeout)
1336		defer cancel()
1337		// The first non-failfast RPC succeeds, all connections are up.
1338		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
1339			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
1340		}
1341		for i := 0; i < countRPC-1; i++ {
1342			testC.EmptyCall(ctx, &testpb.Empty{})
1343		}
1344	}, &rpcStats{
1345		numCallsStarted:               int64(countRPC),
1346		numCallsFinished:              int64(countRPC),
1347		numCallsFinishedKnownReceived: int64(countRPC) / 2,
1348		numCallsDropped:               map[string]int64{lbToken: int64(countRPC) / 2},
1349	}); err != nil {
1350		t.Fatal(err)
1351	}
1352}
1353
1354func (s) TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
1355	if err := runAndCheckStats(t, false, nil, func(cc *grpc.ClientConn) {
1356		testC := testpb.NewTestServiceClient(cc)
1357		ctx, cancel := context.WithTimeout(context.Background(), defaultFallbackTimeout)
1358		defer cancel()
1359		// The first non-failfast RPC succeeds, all connections are up.
1360		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
1361			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
1362		}
1363		for i := 0; i < countRPC-1; i++ {
1364			cc.Invoke(ctx, failtosendURI, &testpb.Empty{}, nil)
1365		}
1366	}, &rpcStats{
1367		numCallsStarted:                        int64(countRPC),
1368		numCallsFinished:                       int64(countRPC),
1369		numCallsFinishedWithClientFailedToSend: int64(countRPC) - 1,
1370		numCallsFinishedKnownReceived:          1,
1371	}); err != nil {
1372		t.Fatal(err)
1373	}
1374}
1375
1376func (s) TestGRPCLBStatsStreamingSuccess(t *testing.T) {
1377	if err := runAndCheckStats(t, false, nil, func(cc *grpc.ClientConn) {
1378		testC := testpb.NewTestServiceClient(cc)
1379		ctx, cancel := context.WithTimeout(context.Background(), defaultFallbackTimeout)
1380		defer cancel()
1381		// The first non-failfast RPC succeeds, all connections are up.
1382		stream, err := testC.FullDuplexCall(ctx, grpc.WaitForReady(true))
1383		if err != nil {
1384			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
1385		}
1386		for {
1387			if _, err = stream.Recv(); err == io.EOF {
1388				break
1389			}
1390		}
1391		for i := 0; i < countRPC-1; i++ {
1392			stream, err = testC.FullDuplexCall(ctx)
1393			if err == nil {
1394				// Wait for stream to end if err is nil.
1395				for {
1396					if _, err = stream.Recv(); err == io.EOF {
1397						break
1398					}
1399				}
1400			}
1401		}
1402	}, &rpcStats{
1403		numCallsStarted:               int64(countRPC),
1404		numCallsFinished:              int64(countRPC),
1405		numCallsFinishedKnownReceived: int64(countRPC),
1406	}); err != nil {
1407		t.Fatal(err)
1408	}
1409}
1410
1411func (s) TestGRPCLBStatsStreamingDrop(t *testing.T) {
1412	if err := runAndCheckStats(t, true, nil, func(cc *grpc.ClientConn) {
1413		testC := testpb.NewTestServiceClient(cc)
1414		ctx, cancel := context.WithTimeout(context.Background(), defaultFallbackTimeout)
1415		defer cancel()
1416		// The first non-failfast RPC succeeds, all connections are up.
1417		stream, err := testC.FullDuplexCall(ctx, grpc.WaitForReady(true))
1418		if err != nil {
1419			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
1420		}
1421		for {
1422			if _, err = stream.Recv(); err == io.EOF {
1423				break
1424			}
1425		}
1426		for i := 0; i < countRPC-1; i++ {
1427			stream, err = testC.FullDuplexCall(ctx)
1428			if err == nil {
1429				// Wait for stream to end if err is nil.
1430				for {
1431					if _, err = stream.Recv(); err == io.EOF {
1432						break
1433					}
1434				}
1435			}
1436		}
1437	}, &rpcStats{
1438		numCallsStarted:               int64(countRPC),
1439		numCallsFinished:              int64(countRPC),
1440		numCallsFinishedKnownReceived: int64(countRPC) / 2,
1441		numCallsDropped:               map[string]int64{lbToken: int64(countRPC) / 2},
1442	}); err != nil {
1443		t.Fatal(err)
1444	}
1445}
1446
1447func (s) TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
1448	if err := runAndCheckStats(t, false, nil, func(cc *grpc.ClientConn) {
1449		testC := testpb.NewTestServiceClient(cc)
1450		ctx, cancel := context.WithTimeout(context.Background(), defaultFallbackTimeout)
1451		defer cancel()
1452		// The first non-failfast RPC succeeds, all connections are up.
1453		stream, err := testC.FullDuplexCall(ctx, grpc.WaitForReady(true))
1454		if err != nil {
1455			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
1456		}
1457		for {
1458			if _, err = stream.Recv(); err == io.EOF {
1459				break
1460			}
1461		}
1462		for i := 0; i < countRPC-1; i++ {
1463			cc.NewStream(ctx, &grpc.StreamDesc{}, failtosendURI)
1464		}
1465	}, &rpcStats{
1466		numCallsStarted:                        int64(countRPC),
1467		numCallsFinished:                       int64(countRPC),
1468		numCallsFinishedWithClientFailedToSend: int64(countRPC) - 1,
1469		numCallsFinishedKnownReceived:          1,
1470	}); err != nil {
1471		t.Fatal(err)
1472	}
1473}
1474
// TestGRPCLBStatsQuashEmpty verifies load reporting when no RPCs are issued:
// the initial load report should be all zero, and no further report should
// arrive within 5x the 100ms update interval.
func (s) TestGRPCLBStatsQuashEmpty(t *testing.T) {
	ch := make(chan *lbpb.ClientStats)
	defer close(ch)
	if err := runAndCheckStats(t, false, ch, func(cc *grpc.ClientConn) {
		// Perform no RPCs; wait for load reports to start, which should be
		// zero, then expect no other load report within 5x the update
		// interval.
		select {
		case st := <-ch:
			if !isZeroStats(st) {
				t.Errorf("got stats %v; want all zero", st)
			}
		case <-time.After(5 * time.Second):
			t.Errorf("did not get initial stats report after 5 seconds")
			return
		}

		select {
		case st := <-ch:
			t.Errorf("got unexpected stats report: %v", st)
		case <-time.After(500 * time.Millisecond):
			// Success.
		}
		// Drain any further reports in the background so sends on the
		// unbuffered channel do not block; the goroutine exits when the
		// deferred close(ch) above runs.
		go func() {
			for range ch { // Drain statsChan until it is closed.
			}
		}()
	}, &rpcStats{
		numCallsStarted:               0,
		numCallsFinished:              0,
		numCallsFinishedKnownReceived: 0,
	}); err != nil {
		t.Fatal(err)
	}
}
1510