1/*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package grpclb
20
21import (
22	"context"
23	"errors"
24	"fmt"
25	"io"
26	"net"
27	"strconv"
28	"strings"
29	"sync"
30	"sync/atomic"
31	"testing"
32	"time"
33
34	durationpb "github.com/golang/protobuf/ptypes/duration"
35	"google.golang.org/grpc"
36	"google.golang.org/grpc/balancer"
37	lbgrpc "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
38	lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
39	"google.golang.org/grpc/codes"
40	"google.golang.org/grpc/credentials"
41	_ "google.golang.org/grpc/grpclog/glogger"
42	"google.golang.org/grpc/internal/leakcheck"
43	"google.golang.org/grpc/metadata"
44	"google.golang.org/grpc/peer"
45	"google.golang.org/grpc/resolver"
46	"google.golang.org/grpc/resolver/manual"
47	"google.golang.org/grpc/status"
48	testpb "google.golang.org/grpc/test/grpc_testing"
49)
50
// Names and tokens shared by the test servers and clients in this file.
var (
	// lbServerName is the name the remote balancer's credentials write to
	// the client during the handshake.
	lbServerName = "bar.com"
	// beServerName is the name announced by backend credentials and the
	// target name the tests dial.
	beServerName = "foo.com"
	// lbToken is the load balance token put into server lists; non-fallback
	// backends require it in the "lb-token" request metadata.
	lbToken = "iamatoken"

	// fakeName is used as the host of the remote balancer address, and
	// fakeNameDialer maps it back to localhost when dialing. The balancer is
	// therefore reachable only if the custom dialer passed to Dial is
	// propagated to grpclb.
	fakeName = "fake.Name"
)
61
// serverNameCheckCreds is a test TransportCredentials. The server side writes
// its name (sn) to the raw connection during the handshake. The client side
// reads it back and compares it with expected.
type serverNameCheckCreds struct {
	// mu is held by the methods that read or write expected.
	mu sync.Mutex
	// sn is written by ServerHandshake; expected is what ClientHandshake
	// requires to read.
	sn, expected string
}
67
68func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
69	if _, err := io.WriteString(rawConn, c.sn); err != nil {
70		fmt.Printf("Failed to write the server name %s to the client %v", c.sn, err)
71		return nil, nil, err
72	}
73	return rawConn, nil, nil
74}
75func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
76	c.mu.Lock()
77	defer c.mu.Unlock()
78	b := make([]byte, len(c.expected))
79	errCh := make(chan error, 1)
80	go func() {
81		_, err := rawConn.Read(b)
82		errCh <- err
83	}()
84	select {
85	case err := <-errCh:
86		if err != nil {
87			fmt.Printf("Failed to read the server name from the server %v", err)
88			return nil, nil, err
89		}
90	case <-ctx.Done():
91		return nil, nil, ctx.Err()
92	}
93	if c.expected != string(b) {
94		fmt.Printf("Read the server name %s want %s", string(b), c.expected)
95		return nil, nil, errors.New("received unexpected server name")
96	}
97	return rawConn, nil, nil
98}
99func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo {
100	c.mu.Lock()
101	defer c.mu.Unlock()
102	return credentials.ProtocolInfo{}
103}
104func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials {
105	c.mu.Lock()
106	defer c.mu.Unlock()
107	return &serverNameCheckCreds{
108		expected: c.expected,
109	}
110}
111func (c *serverNameCheckCreds) OverrideServerName(s string) error {
112	c.mu.Lock()
113	defer c.mu.Unlock()
114	c.expected = s
115	return nil
116}
117
118// fakeNameDialer replaces fakeName with localhost when dialing.
119// This will test that custom dialer is passed from Dial to grpclb.
120func fakeNameDialer(ctx context.Context, addr string) (net.Conn, error) {
121	addr = strings.Replace(addr, fakeName, "localhost", 1)
122	return (&net.Dialer{}).DialContext(ctx, "tcp", addr)
123}
124
125// merge merges the new client stats into current stats.
126//
127// It's a test-only method. rpcStats is defined in grpclb_picker.
128func (s *rpcStats) merge(cs *lbpb.ClientStats) {
129	atomic.AddInt64(&s.numCallsStarted, cs.NumCallsStarted)
130	atomic.AddInt64(&s.numCallsFinished, cs.NumCallsFinished)
131	atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, cs.NumCallsFinishedWithClientFailedToSend)
132	atomic.AddInt64(&s.numCallsFinishedKnownReceived, cs.NumCallsFinishedKnownReceived)
133	s.mu.Lock()
134	for _, perToken := range cs.CallsFinishedWithDrop {
135		s.numCallsDropped[perToken.LoadBalanceToken] += perToken.NumCalls
136	}
137	s.mu.Unlock()
138}
139
// mapsEqual reports whether a and b hold exactly the same key/value pairs.
// A nil map and an empty map compare equal.
func mapsEqual(a, b map[string]int64) bool {
	if len(a) != len(b) {
		return false
	}
	for key, want := range a {
		got, found := b[key]
		if !found || got != want {
			return false
		}
	}
	return true
}
151
// atomicEqual atomically loads *a and *b (separately) and reports whether the
// loaded values are equal.
func atomicEqual(a, b *int64) bool {
	lhs := atomic.LoadInt64(a)
	rhs := atomic.LoadInt64(b)
	return lhs == rhs
}
155
156// equal compares two rpcStats.
157//
158// It's a test-only method. rpcStats is defined in grpclb_picker.
159func (s *rpcStats) equal(o *rpcStats) bool {
160	if !atomicEqual(&s.numCallsStarted, &o.numCallsStarted) {
161		return false
162	}
163	if !atomicEqual(&s.numCallsFinished, &o.numCallsFinished) {
164		return false
165	}
166	if !atomicEqual(&s.numCallsFinishedWithClientFailedToSend, &o.numCallsFinishedWithClientFailedToSend) {
167		return false
168	}
169	if !atomicEqual(&s.numCallsFinishedKnownReceived, &o.numCallsFinishedKnownReceived) {
170		return false
171	}
172	s.mu.Lock()
173	defer s.mu.Unlock()
174	o.mu.Lock()
175	defer o.mu.Unlock()
176	return mapsEqual(s.numCallsDropped, o.numCallsDropped)
177}
178
179type remoteBalancer struct {
180	sls       chan *lbpb.ServerList
181	statsDura time.Duration
182	done      chan struct{}
183	stats     *rpcStats
184}
185
186func newRemoteBalancer(intervals []time.Duration) *remoteBalancer {
187	return &remoteBalancer{
188		sls:   make(chan *lbpb.ServerList, 1),
189		done:  make(chan struct{}),
190		stats: newRPCStats(),
191	}
192}
193
194func (b *remoteBalancer) stop() {
195	close(b.sls)
196	close(b.done)
197}
198
199func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error {
200	req, err := stream.Recv()
201	if err != nil {
202		return err
203	}
204	initReq := req.GetInitialRequest()
205	if initReq.Name != beServerName {
206		return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name)
207	}
208	resp := &lbpb.LoadBalanceResponse{
209		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
210			InitialResponse: &lbpb.InitialLoadBalanceResponse{
211				ClientStatsReportInterval: &durationpb.Duration{
212					Seconds: int64(b.statsDura.Seconds()),
213					Nanos:   int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
214				},
215			},
216		},
217	}
218	if err := stream.Send(resp); err != nil {
219		return err
220	}
221	go func() {
222		for {
223			var (
224				req *lbpb.LoadBalanceRequest
225				err error
226			)
227			if req, err = stream.Recv(); err != nil {
228				return
229			}
230			b.stats.merge(req.GetClientStats())
231		}
232	}()
233	for {
234		select {
235		case v := <-b.sls:
236			resp = &lbpb.LoadBalanceResponse{
237				LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
238					ServerList: v,
239				},
240			}
241		case <-stream.Context().Done():
242			return stream.Context().Err()
243		}
244		if err := stream.Send(resp); err != nil {
245			return err
246		}
247	}
248}
249
250type testServer struct {
251	testpb.TestServiceServer
252
253	addr     string
254	fallback bool
255}
256
// testmdkey is the trailer key under which EmptyCall reports the serving
// backend's address.
const (
	testmdkey = "testmd"
)
258
259func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
260	md, ok := metadata.FromIncomingContext(ctx)
261	if !ok {
262		return nil, status.Error(codes.Internal, "failed to receive metadata")
263	}
264	if !s.fallback && (md == nil || md["lb-token"][0] != lbToken) {
265		return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md)
266	}
267	grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
268	return &testpb.Empty{}, nil
269}
270
271func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
272	return nil
273}
274
275func startBackends(sn string, fallback bool, lis ...net.Listener) (servers []*grpc.Server) {
276	for _, l := range lis {
277		creds := &serverNameCheckCreds{
278			sn: sn,
279		}
280		s := grpc.NewServer(grpc.Creds(creds))
281		testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String(), fallback: fallback})
282		servers = append(servers, s)
283		go func(s *grpc.Server, l net.Listener) {
284			s.Serve(l)
285		}(s, l)
286	}
287	return
288}
289
290func stopBackends(servers []*grpc.Server) {
291	for _, s := range servers {
292		s.Stop()
293	}
294}
295
296type testServers struct {
297	lbAddr   string
298	ls       *remoteBalancer
299	lb       *grpc.Server
300	backends []*grpc.Server
301	beIPs    []net.IP
302	bePorts  []int
303
304	lbListener  net.Listener
305	beListeners []net.Listener
306}
307
308func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), err error) {
309	var (
310		beListeners []net.Listener
311		ls          *remoteBalancer
312		lb          *grpc.Server
313		beIPs       []net.IP
314		bePorts     []int
315	)
316	for i := 0; i < numberOfBackends; i++ {
317		// Start a backend.
318		beLis, e := net.Listen("tcp", "localhost:0")
319		if e != nil {
320			err = fmt.Errorf("failed to listen %v", err)
321			return
322		}
323		beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP)
324		bePorts = append(bePorts, beLis.Addr().(*net.TCPAddr).Port)
325
326		beListeners = append(beListeners, newRestartableListener(beLis))
327	}
328	backends := startBackends(beServerName, false, beListeners...)
329
330	// Start a load balancer.
331	lbLis, err := net.Listen("tcp", "localhost:0")
332	if err != nil {
333		err = fmt.Errorf("failed to create the listener for the load balancer %v", err)
334		return
335	}
336	lbLis = newRestartableListener(lbLis)
337	lbCreds := &serverNameCheckCreds{
338		sn: lbServerName,
339	}
340	lb = grpc.NewServer(grpc.Creds(lbCreds))
341	ls = newRemoteBalancer(nil)
342	lbgrpc.RegisterLoadBalancerServer(lb, ls)
343	go func() {
344		lb.Serve(lbLis)
345	}()
346
347	tss = &testServers{
348		lbAddr:   net.JoinHostPort(fakeName, strconv.Itoa(lbLis.Addr().(*net.TCPAddr).Port)),
349		ls:       ls,
350		lb:       lb,
351		backends: backends,
352		beIPs:    beIPs,
353		bePorts:  bePorts,
354
355		lbListener:  lbLis,
356		beListeners: beListeners,
357	}
358	cleanup = func() {
359		defer stopBackends(backends)
360		defer func() {
361			ls.stop()
362			lb.Stop()
363		}()
364	}
365	return
366}
367
368func TestGRPCLB(t *testing.T) {
369	defer leakcheck.Check(t)
370
371	r, cleanup := manual.GenerateAndRegisterManualResolver()
372	defer cleanup()
373
374	tss, cleanup, err := newLoadBalancer(1)
375	if err != nil {
376		t.Fatalf("failed to create new load balancer: %v", err)
377	}
378	defer cleanup()
379
380	be := &lbpb.Server{
381		IpAddress:        tss.beIPs[0],
382		Port:             int32(tss.bePorts[0]),
383		LoadBalanceToken: lbToken,
384	}
385	var bes []*lbpb.Server
386	bes = append(bes, be)
387	sl := &lbpb.ServerList{
388		Servers: bes,
389	}
390	tss.ls.sls <- sl
391	creds := serverNameCheckCreds{
392		expected: beServerName,
393	}
394	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
395	defer cancel()
396	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
397		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
398	if err != nil {
399		t.Fatalf("Failed to dial to the backend %v", err)
400	}
401	defer cc.Close()
402	testC := testpb.NewTestServiceClient(cc)
403
404	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
405		Addr:       tss.lbAddr,
406		Type:       resolver.GRPCLB,
407		ServerName: lbServerName,
408	}}})
409
410	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
411		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
412	}
413}
414
415// The remote balancer sends response with duplicates to grpclb client.
416func TestGRPCLBWeighted(t *testing.T) {
417	defer leakcheck.Check(t)
418
419	r, cleanup := manual.GenerateAndRegisterManualResolver()
420	defer cleanup()
421
422	tss, cleanup, err := newLoadBalancer(2)
423	if err != nil {
424		t.Fatalf("failed to create new load balancer: %v", err)
425	}
426	defer cleanup()
427
428	beServers := []*lbpb.Server{{
429		IpAddress:        tss.beIPs[0],
430		Port:             int32(tss.bePorts[0]),
431		LoadBalanceToken: lbToken,
432	}, {
433		IpAddress:        tss.beIPs[1],
434		Port:             int32(tss.bePorts[1]),
435		LoadBalanceToken: lbToken,
436	}}
437	portsToIndex := make(map[int]int)
438	for i := range beServers {
439		portsToIndex[tss.bePorts[i]] = i
440	}
441
442	creds := serverNameCheckCreds{
443		expected: beServerName,
444	}
445	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
446	defer cancel()
447	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
448		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
449	if err != nil {
450		t.Fatalf("Failed to dial to the backend %v", err)
451	}
452	defer cc.Close()
453	testC := testpb.NewTestServiceClient(cc)
454
455	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
456		Addr:       tss.lbAddr,
457		Type:       resolver.GRPCLB,
458		ServerName: lbServerName,
459	}}})
460
461	sequences := []string{"00101", "00011"}
462	for _, seq := range sequences {
463		var (
464			bes    []*lbpb.Server
465			p      peer.Peer
466			result string
467		)
468		for _, s := range seq {
469			bes = append(bes, beServers[s-'0'])
470		}
471		tss.ls.sls <- &lbpb.ServerList{Servers: bes}
472
473		for i := 0; i < 1000; i++ {
474			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
475				t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
476			}
477			result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
478		}
479		// The generated result will be in format of "0010100101".
480		if !strings.Contains(result, strings.Repeat(seq, 2)) {
481			t.Errorf("got result sequence %q, want patten %q", result, seq)
482		}
483	}
484}
485
486func TestDropRequest(t *testing.T) {
487	defer leakcheck.Check(t)
488
489	r, cleanup := manual.GenerateAndRegisterManualResolver()
490	defer cleanup()
491
492	tss, cleanup, err := newLoadBalancer(2)
493	if err != nil {
494		t.Fatalf("failed to create new load balancer: %v", err)
495	}
496	defer cleanup()
497	tss.ls.sls <- &lbpb.ServerList{
498		Servers: []*lbpb.Server{{
499			IpAddress:        tss.beIPs[0],
500			Port:             int32(tss.bePorts[0]),
501			LoadBalanceToken: lbToken,
502			Drop:             false,
503		}, {
504			IpAddress:        tss.beIPs[1],
505			Port:             int32(tss.bePorts[1]),
506			LoadBalanceToken: lbToken,
507			Drop:             false,
508		}, {
509			Drop: true,
510		}},
511	}
512	creds := serverNameCheckCreds{
513		expected: beServerName,
514	}
515	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
516	defer cancel()
517	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
518		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
519	if err != nil {
520		t.Fatalf("Failed to dial to the backend %v", err)
521	}
522	defer cc.Close()
523	testC := testpb.NewTestServiceClient(cc)
524
525	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
526		Addr:       tss.lbAddr,
527		Type:       resolver.GRPCLB,
528		ServerName: lbServerName,
529	}}})
530
531	// Wait for the 1st, non-fail-fast RPC to succeed. This ensures both server
532	// connections are made, because the first one has Drop set to true.
533	var i int
534	for i = 0; i < 1000; i++ {
535		if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err == nil {
536			break
537		}
538		time.Sleep(time.Millisecond)
539	}
540	if i >= 1000 {
541		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
542	}
543	select {
544	case <-ctx.Done():
545		t.Fatal("timed out", ctx.Err())
546	default:
547	}
548	for _, failfast := range []bool{true, false} {
549		for i := 0; i < 3; i++ {
550			// 1st RPCs pick the second item in server list. They should succeed
551			// since they choose the non-drop-request backend according to the
552			// round robin policy.
553			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(!failfast)); err != nil {
554				t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
555			}
556			// 2st RPCs should fail, because they pick last item in server list,
557			// with Drop set to true.
558			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(!failfast)); status.Code(err) != codes.Unavailable {
559				t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
560			}
561			// 3rd RPCs pick the first item in server list. They should succeed
562			// since they choose the non-drop-request backend according to the
563			// round robin policy.
564			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(!failfast)); err != nil {
565				t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
566			}
567		}
568	}
569	tss.backends[0].Stop()
570	// This last pick was backend 0. Closing backend 0 doesn't reset drop index
571	// (for level 1 picking), so the following picks will be (backend1, drop,
572	// backend1), instead of (backend, backend, drop) if drop index was reset.
573	time.Sleep(time.Second)
574	for i := 0; i < 3; i++ {
575		var p peer.Peer
576		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
577			t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
578		}
579		if want := tss.bePorts[1]; p.Addr.(*net.TCPAddr).Port != want {
580			t.Errorf("got peer: %v, want peer port: %v", p.Addr, want)
581		}
582
583		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) != codes.Unavailable {
584			t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
585		}
586
587		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
588			t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
589		}
590		if want := tss.bePorts[1]; p.Addr.(*net.TCPAddr).Port != want {
591			t.Errorf("got peer: %v, want peer port: %v", p.Addr, want)
592		}
593	}
594}
595
596// When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list.
597func TestBalancerDisconnects(t *testing.T) {
598	defer leakcheck.Check(t)
599
600	r, cleanup := manual.GenerateAndRegisterManualResolver()
601	defer cleanup()
602
603	var (
604		tests []*testServers
605		lbs   []*grpc.Server
606	)
607	for i := 0; i < 2; i++ {
608		tss, cleanup, err := newLoadBalancer(1)
609		if err != nil {
610			t.Fatalf("failed to create new load balancer: %v", err)
611		}
612		defer cleanup()
613
614		be := &lbpb.Server{
615			IpAddress:        tss.beIPs[0],
616			Port:             int32(tss.bePorts[0]),
617			LoadBalanceToken: lbToken,
618		}
619		var bes []*lbpb.Server
620		bes = append(bes, be)
621		sl := &lbpb.ServerList{
622			Servers: bes,
623		}
624		tss.ls.sls <- sl
625
626		tests = append(tests, tss)
627		lbs = append(lbs, tss.lb)
628	}
629
630	creds := serverNameCheckCreds{
631		expected: beServerName,
632	}
633	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
634	defer cancel()
635	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
636		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
637	if err != nil {
638		t.Fatalf("Failed to dial to the backend %v", err)
639	}
640	defer cc.Close()
641	testC := testpb.NewTestServiceClient(cc)
642
643	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
644		Addr:       tests[0].lbAddr,
645		Type:       resolver.GRPCLB,
646		ServerName: lbServerName,
647	}, {
648		Addr:       tests[1].lbAddr,
649		Type:       resolver.GRPCLB,
650		ServerName: lbServerName,
651	}}})
652
653	var p peer.Peer
654	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
655		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
656	}
657	if p.Addr.(*net.TCPAddr).Port != tests[0].bePorts[0] {
658		t.Fatalf("got peer: %v, want peer port: %v", p.Addr, tests[0].bePorts[0])
659	}
660
661	lbs[0].Stop()
662	// Stop balancer[0], balancer[1] should be used by grpclb.
663	// Check peer address to see if that happened.
664	for i := 0; i < 1000; i++ {
665		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
666			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
667		}
668		if p.Addr.(*net.TCPAddr).Port == tests[1].bePorts[0] {
669			return
670		}
671		time.Sleep(time.Millisecond)
672	}
673	t.Fatalf("No RPC sent to second backend after 1 second")
674}
675
676func TestFallback(t *testing.T) {
677	balancer.Register(newLBBuilderWithFallbackTimeout(100 * time.Millisecond))
678	defer balancer.Register(newLBBuilder())
679
680	defer leakcheck.Check(t)
681
682	r, cleanup := manual.GenerateAndRegisterManualResolver()
683	defer cleanup()
684
685	tss, cleanup, err := newLoadBalancer(1)
686	if err != nil {
687		t.Fatalf("failed to create new load balancer: %v", err)
688	}
689	defer cleanup()
690
691	// Start a standalone backend.
692	beLis, err := net.Listen("tcp", "localhost:0")
693	if err != nil {
694		t.Fatalf("Failed to listen %v", err)
695	}
696	defer beLis.Close()
697	standaloneBEs := startBackends(beServerName, true, beLis)
698	defer stopBackends(standaloneBEs)
699
700	be := &lbpb.Server{
701		IpAddress:        tss.beIPs[0],
702		Port:             int32(tss.bePorts[0]),
703		LoadBalanceToken: lbToken,
704	}
705	var bes []*lbpb.Server
706	bes = append(bes, be)
707	sl := &lbpb.ServerList{
708		Servers: bes,
709	}
710	tss.ls.sls <- sl
711	creds := serverNameCheckCreds{
712		expected: beServerName,
713	}
714	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
715	defer cancel()
716	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
717		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
718	if err != nil {
719		t.Fatalf("Failed to dial to the backend %v", err)
720	}
721	defer cc.Close()
722	testC := testpb.NewTestServiceClient(cc)
723
724	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
725		Addr:       "invalid.address",
726		Type:       resolver.GRPCLB,
727		ServerName: lbServerName,
728	}, {
729		Addr:       beLis.Addr().String(),
730		Type:       resolver.Backend,
731		ServerName: beServerName,
732	}}})
733
734	var p peer.Peer
735	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
736		t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
737	}
738	if p.Addr.String() != beLis.Addr().String() {
739		t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr())
740	}
741
742	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
743		Addr:       tss.lbAddr,
744		Type:       resolver.GRPCLB,
745		ServerName: lbServerName,
746	}, {
747		Addr:       beLis.Addr().String(),
748		Type:       resolver.Backend,
749		ServerName: beServerName,
750	}}})
751
752	var backendUsed bool
753	for i := 0; i < 1000; i++ {
754		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
755			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
756		}
757		if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
758			backendUsed = true
759			break
760		}
761		time.Sleep(time.Millisecond)
762	}
763	if !backendUsed {
764		t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
765	}
766
767	// Close backend and remote balancer connections, should use fallback.
768	tss.beListeners[0].(*restartableListener).stopPreviousConns()
769	tss.lbListener.(*restartableListener).stopPreviousConns()
770	time.Sleep(time.Second)
771
772	var fallbackUsed bool
773	for i := 0; i < 1000; i++ {
774		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
775			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
776		}
777		if p.Addr.String() == beLis.Addr().String() {
778			fallbackUsed = true
779			break
780		}
781		time.Sleep(time.Millisecond)
782	}
783	if !fallbackUsed {
784		t.Fatalf("No RPC sent to fallback after 1 second")
785	}
786
787	// Restart backend and remote balancer, should not use backends.
788	tss.beListeners[0].(*restartableListener).restart()
789	tss.lbListener.(*restartableListener).restart()
790	tss.ls.sls <- sl
791
792	time.Sleep(time.Second)
793
794	var backendUsed2 bool
795	for i := 0; i < 1000; i++ {
796		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
797			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
798		}
799		if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
800			backendUsed2 = true
801			break
802		}
803		time.Sleep(time.Millisecond)
804	}
805	if !backendUsed2 {
806		t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
807	}
808}
809
810// The remote balancer sends response with duplicates to grpclb client.
811func TestGRPCLBPickFirst(t *testing.T) {
812	balancer.Register(newLBBuilderWithPickFirst())
813	defer balancer.Register(newLBBuilder())
814
815	defer leakcheck.Check(t)
816
817	r, cleanup := manual.GenerateAndRegisterManualResolver()
818	defer cleanup()
819
820	tss, cleanup, err := newLoadBalancer(3)
821	if err != nil {
822		t.Fatalf("failed to create new load balancer: %v", err)
823	}
824	defer cleanup()
825
826	beServers := []*lbpb.Server{{
827		IpAddress:        tss.beIPs[0],
828		Port:             int32(tss.bePorts[0]),
829		LoadBalanceToken: lbToken,
830	}, {
831		IpAddress:        tss.beIPs[1],
832		Port:             int32(tss.bePorts[1]),
833		LoadBalanceToken: lbToken,
834	}, {
835		IpAddress:        tss.beIPs[2],
836		Port:             int32(tss.bePorts[2]),
837		LoadBalanceToken: lbToken,
838	}}
839	portsToIndex := make(map[int]int)
840	for i := range beServers {
841		portsToIndex[tss.bePorts[i]] = i
842	}
843
844	creds := serverNameCheckCreds{
845		expected: beServerName,
846	}
847	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
848	defer cancel()
849	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
850		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
851	if err != nil {
852		t.Fatalf("Failed to dial to the backend %v", err)
853	}
854	defer cc.Close()
855	testC := testpb.NewTestServiceClient(cc)
856
857	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
858		Addr:       tss.lbAddr,
859		Type:       resolver.GRPCLB,
860		ServerName: lbServerName,
861	}}})
862
863	var p peer.Peer
864
865	portPicked1 := 0
866	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[1:2]}
867	for i := 0; i < 1000; i++ {
868		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
869			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
870		}
871		if portPicked1 == 0 {
872			portPicked1 = p.Addr.(*net.TCPAddr).Port
873			continue
874		}
875		if portPicked1 != p.Addr.(*net.TCPAddr).Port {
876			t.Fatalf("Different backends are picked for RPCs: %v vs %v", portPicked1, p.Addr.(*net.TCPAddr).Port)
877		}
878	}
879
880	portPicked2 := portPicked1
881	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[:1]}
882	for i := 0; i < 1000; i++ {
883		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
884			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
885		}
886		if portPicked2 == portPicked1 {
887			portPicked2 = p.Addr.(*net.TCPAddr).Port
888			continue
889		}
890		if portPicked2 != p.Addr.(*net.TCPAddr).Port {
891			t.Fatalf("Different backends are picked for RPCs: %v vs %v", portPicked2, p.Addr.(*net.TCPAddr).Port)
892		}
893	}
894
895	portPicked := portPicked2
896	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[1:]}
897	for i := 0; i < 1000; i++ {
898		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
899			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
900		}
901		if portPicked == portPicked2 {
902			portPicked = p.Addr.(*net.TCPAddr).Port
903			continue
904		}
905		if portPicked != p.Addr.(*net.TCPAddr).Port {
906			t.Fatalf("Different backends are picked for RPCs: %v vs %v", portPicked, p.Addr.(*net.TCPAddr).Port)
907		}
908	}
909}
910
// failPreRPCCred is a per-RPC credential that makes RPCs whose URI contains
// failtosendURI fail before they are sent. It attaches no metadata otherwise.
type failPreRPCCred struct{}
912
913func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
914	if strings.Contains(uri[0], failtosendURI) {
915		return nil, fmt.Errorf("rpc should fail to send")
916	}
917	return nil, nil
918}
919
920func (failPreRPCCred) RequireTransportSecurity() bool {
921	return false
922}
923
924func checkStats(stats, expected *rpcStats) error {
925	if !stats.equal(expected) {
926		return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
927	}
928	return nil
929}
930
931func runAndGetStats(t *testing.T, drop bool, runRPCs func(*grpc.ClientConn)) *rpcStats {
932	defer leakcheck.Check(t)
933
934	r, cleanup := manual.GenerateAndRegisterManualResolver()
935	defer cleanup()
936
937	tss, cleanup, err := newLoadBalancer(1)
938	if err != nil {
939		t.Fatalf("failed to create new load balancer: %v", err)
940	}
941	defer cleanup()
942	servers := []*lbpb.Server{{
943		IpAddress:        tss.beIPs[0],
944		Port:             int32(tss.bePorts[0]),
945		LoadBalanceToken: lbToken,
946	}}
947	if drop {
948		servers = append(servers, &lbpb.Server{
949			LoadBalanceToken: lbToken,
950			Drop:             drop,
951		})
952	}
953	tss.ls.sls <- &lbpb.ServerList{Servers: servers}
954	tss.ls.statsDura = 100 * time.Millisecond
955	creds := serverNameCheckCreds{expected: beServerName}
956
957	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
958	defer cancel()
959	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
960		grpc.WithTransportCredentials(&creds),
961		grpc.WithPerRPCCredentials(failPreRPCCred{}),
962		grpc.WithContextDialer(fakeNameDialer))
963	if err != nil {
964		t.Fatalf("Failed to dial to the backend %v", err)
965	}
966	defer cc.Close()
967
968	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
969		Addr:       tss.lbAddr,
970		Type:       resolver.GRPCLB,
971		ServerName: lbServerName,
972	}}})
973
974	runRPCs(cc)
975	time.Sleep(1 * time.Second)
976	stats := tss.ls.stats
977	return stats
978}
979
// countRPC is the number of RPCs each unary stats test issues.
const countRPC = 40

// failtosendURI is passed as the method to cc.Invoke by the stats tests;
// failPreRPCCred fails any RPC whose URI contains it.
const failtosendURI = "failtosend"
984
// TestGRPCLBStatsUnarySuccess verifies that when every unary RPC reaches the
// backend, the client load report counts all countRPC calls as started,
// finished and known-received.
func TestGRPCLBStatsUnarySuccess(t *testing.T) {
	defer leakcheck.Check(t)
	rpcs := func(cc *grpc.ClientConn) {
		client := testpb.NewTestServiceClient(cc)
		// This wait-for-ready call must succeed before the fail-fast calls
		// below are issued.
		_, err := client.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true))
		if err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", client, err)
		}
		// Results are deliberately ignored here; the outcome of these calls
		// is verified through the load-report stats instead.
		for remaining := countRPC - 1; remaining > 0; remaining-- {
			_, _ = client.EmptyCall(context.Background(), &testpb.Empty{})
		}
	}
	want := &rpcStats{
		numCallsStarted:               int64(countRPC),
		numCallsFinished:              int64(countRPC),
		numCallsFinishedKnownReceived: int64(countRPC),
	}
	if err := checkStats(runAndGetStats(t, false, rpcs), want); err != nil {
		t.Fatal(err)
	}
}
1006
// TestGRPCLBStatsUnaryDrop verifies the client load report when the balancer's
// server list also contains a drop entry. The expected stats assume half of the
// countRPC unary calls are dropped (attributed to lbToken) and the other half
// reach the backend.
func TestGRPCLBStatsUnaryDrop(t *testing.T) {
	defer leakcheck.Check(t)
	rpcs := func(cc *grpc.ClientConn) {
		client := testpb.NewTestServiceClient(cc)
		// This wait-for-ready call must succeed before the fail-fast calls
		// below are issued.
		_, err := client.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true))
		if err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", client, err)
		}
		// Some of these calls are expected to be dropped, so errors are
		// deliberately ignored; the stats check below verifies the outcome.
		for remaining := countRPC - 1; remaining > 0; remaining-- {
			_, _ = client.EmptyCall(context.Background(), &testpb.Empty{})
		}
	}
	half := int64(countRPC) / 2
	want := &rpcStats{
		numCallsStarted:               int64(countRPC),
		numCallsFinished:              int64(countRPC),
		numCallsFinishedKnownReceived: half,
		numCallsDropped:               map[string]int64{lbToken: half},
	}
	if err := checkStats(runAndGetStats(t, true, rpcs), want); err != nil {
		t.Fatal(err)
	}
}
1029
// TestGRPCLBStatsUnaryFailedToSend verifies the client load report when all but
// the first unary call target failtosendURI. Those calls are expected to be
// counted as "client failed to send". Presumably they are rejected by the
// failPreRPCCred per-RPC credentials that runAndGetStats installs; TODO confirm.
// Only the first call is expected to be known-received.
func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
	defer leakcheck.Check(t)
	rpcs := func(cc *grpc.ClientConn) {
		client := testpb.NewTestServiceClient(cc)
		// This wait-for-ready call must succeed before the failing calls
		// below are issued.
		_, err := client.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true))
		if err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", client, err)
		}
		for remaining := countRPC - 1; remaining > 0; remaining-- {
			// Expected to fail; the error is deliberately discarded and the
			// outcome is checked through the reported stats.
			_ = cc.Invoke(context.Background(), failtosendURI, &testpb.Empty{}, nil)
		}
	}
	want := &rpcStats{
		numCallsStarted:                        int64(countRPC),
		numCallsFinished:                       int64(countRPC),
		numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
		numCallsFinishedKnownReceived:          1,
	}
	if err := checkStats(runAndGetStats(t, false, rpcs), want); err != nil {
		t.Fatal(err)
	}
}
1052
// TestGRPCLBStatsStreamingSuccess verifies that when every streaming RPC
// reaches the backend, the client load report counts all countRPC streams as
// started, finished and known-received.
func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
	defer leakcheck.Check(t)
	stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
		testC := testpb.NewTestServiceClient(cc)
		// The first non-failfast RPC succeeds, all connections are up.
		stream, err := testC.FullDuplexCall(context.Background(), grpc.WaitForReady(true))
		if err != nil {
			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		// Drain the stream until it ends. Recv only returns io.EOF on
		// successful completion, and after any other error it never does.
		// Stopping on any error keeps a failed stream from hanging the test.
		for err == nil {
			_, err = stream.Recv()
		}
		if err != io.EOF {
			t.Fatalf("stream.Recv() = _, %v, want _, io.EOF", err)
		}
		for i := 0; i < countRPC-1; i++ {
			stream, err = testC.FullDuplexCall(context.Background())
			if err != nil {
				// Best effort: the stats check below verifies the outcome.
				continue
			}
			// Wait for the stream to end, stopping on any error (not just
			// io.EOF) so a failed stream cannot spin forever.
			for err == nil {
				_, err = stream.Recv()
			}
		}
	})

	if err := checkStats(stats, &rpcStats{
		numCallsStarted:               int64(countRPC),
		numCallsFinished:              int64(countRPC),
		numCallsFinishedKnownReceived: int64(countRPC),
	}); err != nil {
		t.Fatal(err)
	}
}
1088
// TestGRPCLBStatsStreamingDrop verifies the client load report when the
// balancer's server list also contains a drop entry. The expected stats assume
// half of the countRPC streams are dropped (attributed to lbToken) and the other
// half reach the backend.
func TestGRPCLBStatsStreamingDrop(t *testing.T) {
	defer leakcheck.Check(t)
	stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
		testC := testpb.NewTestServiceClient(cc)
		// The first non-failfast RPC succeeds, all connections are up.
		stream, err := testC.FullDuplexCall(context.Background(), grpc.WaitForReady(true))
		if err != nil {
			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		// Drain the stream until it ends. Recv only returns io.EOF on
		// successful completion, and after any other error it never does.
		// Stopping on any error keeps a failed stream from hanging the test.
		for err == nil {
			_, err = stream.Recv()
		}
		if err != io.EOF {
			t.Fatalf("stream.Recv() = _, %v, want _, io.EOF", err)
		}
		for i := 0; i < countRPC-1; i++ {
			stream, err = testC.FullDuplexCall(context.Background())
			if err != nil {
				// Dropped streams fail here; the stats check verifies them.
				continue
			}
			// Wait for the stream to end, stopping on any error (not just
			// io.EOF) so a failed stream cannot spin forever.
			for err == nil {
				_, err = stream.Recv()
			}
		}
	})

	if err := checkStats(stats, &rpcStats{
		numCallsStarted:               int64(countRPC),
		numCallsFinished:              int64(countRPC),
		numCallsFinishedKnownReceived: int64(countRPC) / 2,
		numCallsDropped:               map[string]int64{lbToken: int64(countRPC) / 2},
	}); err != nil {
		t.Fatal(err)
	}
}
1125
// TestGRPCLBStatsStreamingFailedToSend verifies the client load report when all
// but the first stream target failtosendURI. Those streams are expected to be
// counted as "client failed to send". Presumably they are rejected by the
// failPreRPCCred per-RPC credentials that runAndGetStats installs; TODO confirm.
// Only the first stream is expected to be known-received.
func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
	defer leakcheck.Check(t)
	stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
		testC := testpb.NewTestServiceClient(cc)
		// The first non-failfast RPC succeeds, all connections are up.
		stream, err := testC.FullDuplexCall(context.Background(), grpc.WaitForReady(true))
		if err != nil {
			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		// Drain the stream until it ends. Recv only returns io.EOF on
		// successful completion, and after any other error it never does.
		// Stopping on any error keeps a failed stream from hanging the test.
		for err == nil {
			_, err = stream.Recv()
		}
		if err != io.EOF {
			t.Fatalf("stream.Recv() = _, %v, want _, io.EOF", err)
		}
		for i := 0; i < countRPC-1; i++ {
			// Expected to fail; the result is deliberately discarded and the
			// outcome is checked through the reported stats.
			_, _ = cc.NewStream(context.Background(), &grpc.StreamDesc{}, failtosendURI)
		}
	})

	if err := checkStats(stats, &rpcStats{
		numCallsStarted:                        int64(countRPC),
		numCallsFinished:                       int64(countRPC),
		numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
		numCallsFinishedKnownReceived:          1,
	}); err != nil {
		t.Fatal(err)
	}
}
1154