1/* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19package grpclb 20 21import ( 22 "context" 23 "errors" 24 "fmt" 25 "io" 26 "net" 27 "strconv" 28 "strings" 29 "sync" 30 "sync/atomic" 31 "testing" 32 "time" 33 34 durationpb "github.com/golang/protobuf/ptypes/duration" 35 "google.golang.org/grpc" 36 "google.golang.org/grpc/balancer" 37 lbgrpc "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1" 38 lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1" 39 "google.golang.org/grpc/codes" 40 "google.golang.org/grpc/credentials" 41 _ "google.golang.org/grpc/grpclog/glogger" 42 "google.golang.org/grpc/internal/leakcheck" 43 "google.golang.org/grpc/metadata" 44 "google.golang.org/grpc/peer" 45 "google.golang.org/grpc/resolver" 46 "google.golang.org/grpc/resolver/manual" 47 "google.golang.org/grpc/status" 48 testpb "google.golang.org/grpc/test/grpc_testing" 49) 50 51var ( 52 lbServerName = "bar.com" 53 beServerName = "foo.com" 54 lbToken = "iamatoken" 55 56 // Resolver replaces localhost with fakeName in Next(). 57 // Dialer replaces fakeName with localhost when dialing. 58 // This will test that custom dialer is passed from Dial to grpclb. 
59 fakeName = "fake.Name" 60) 61 62type serverNameCheckCreds struct { 63 mu sync.Mutex 64 sn string 65 expected string 66} 67 68func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 69 if _, err := io.WriteString(rawConn, c.sn); err != nil { 70 fmt.Printf("Failed to write the server name %s to the client %v", c.sn, err) 71 return nil, nil, err 72 } 73 return rawConn, nil, nil 74} 75func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { 76 c.mu.Lock() 77 defer c.mu.Unlock() 78 b := make([]byte, len(c.expected)) 79 errCh := make(chan error, 1) 80 go func() { 81 _, err := rawConn.Read(b) 82 errCh <- err 83 }() 84 select { 85 case err := <-errCh: 86 if err != nil { 87 fmt.Printf("Failed to read the server name from the server %v", err) 88 return nil, nil, err 89 } 90 case <-ctx.Done(): 91 return nil, nil, ctx.Err() 92 } 93 if c.expected != string(b) { 94 fmt.Printf("Read the server name %s want %s", string(b), c.expected) 95 return nil, nil, errors.New("received unexpected server name") 96 } 97 return rawConn, nil, nil 98} 99func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo { 100 c.mu.Lock() 101 defer c.mu.Unlock() 102 return credentials.ProtocolInfo{} 103} 104func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials { 105 c.mu.Lock() 106 defer c.mu.Unlock() 107 return &serverNameCheckCreds{ 108 expected: c.expected, 109 } 110} 111func (c *serverNameCheckCreds) OverrideServerName(s string) error { 112 c.mu.Lock() 113 defer c.mu.Unlock() 114 c.expected = s 115 return nil 116} 117 118// fakeNameDialer replaces fakeName with localhost when dialing. 119// This will test that custom dialer is passed from Dial to grpclb. 
120func fakeNameDialer(ctx context.Context, addr string) (net.Conn, error) { 121 addr = strings.Replace(addr, fakeName, "localhost", 1) 122 return (&net.Dialer{}).DialContext(ctx, "tcp", addr) 123} 124 125// merge merges the new client stats into current stats. 126// 127// It's a test-only method. rpcStats is defined in grpclb_picker. 128func (s *rpcStats) merge(cs *lbpb.ClientStats) { 129 atomic.AddInt64(&s.numCallsStarted, cs.NumCallsStarted) 130 atomic.AddInt64(&s.numCallsFinished, cs.NumCallsFinished) 131 atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, cs.NumCallsFinishedWithClientFailedToSend) 132 atomic.AddInt64(&s.numCallsFinishedKnownReceived, cs.NumCallsFinishedKnownReceived) 133 s.mu.Lock() 134 for _, perToken := range cs.CallsFinishedWithDrop { 135 s.numCallsDropped[perToken.LoadBalanceToken] += perToken.NumCalls 136 } 137 s.mu.Unlock() 138} 139 140func mapsEqual(a, b map[string]int64) bool { 141 if len(a) != len(b) { 142 return false 143 } 144 for k, v1 := range a { 145 if v2, ok := b[k]; !ok || v1 != v2 { 146 return false 147 } 148 } 149 return true 150} 151 152func atomicEqual(a, b *int64) bool { 153 return atomic.LoadInt64(a) == atomic.LoadInt64(b) 154} 155 156// equal compares two rpcStats. 157// 158// It's a test-only method. rpcStats is defined in grpclb_picker. 
159func (s *rpcStats) equal(o *rpcStats) bool { 160 if !atomicEqual(&s.numCallsStarted, &o.numCallsStarted) { 161 return false 162 } 163 if !atomicEqual(&s.numCallsFinished, &o.numCallsFinished) { 164 return false 165 } 166 if !atomicEqual(&s.numCallsFinishedWithClientFailedToSend, &o.numCallsFinishedWithClientFailedToSend) { 167 return false 168 } 169 if !atomicEqual(&s.numCallsFinishedKnownReceived, &o.numCallsFinishedKnownReceived) { 170 return false 171 } 172 s.mu.Lock() 173 defer s.mu.Unlock() 174 o.mu.Lock() 175 defer o.mu.Unlock() 176 return mapsEqual(s.numCallsDropped, o.numCallsDropped) 177} 178 179type remoteBalancer struct { 180 sls chan *lbpb.ServerList 181 statsDura time.Duration 182 done chan struct{} 183 stats *rpcStats 184} 185 186func newRemoteBalancer(intervals []time.Duration) *remoteBalancer { 187 return &remoteBalancer{ 188 sls: make(chan *lbpb.ServerList, 1), 189 done: make(chan struct{}), 190 stats: newRPCStats(), 191 } 192} 193 194func (b *remoteBalancer) stop() { 195 close(b.sls) 196 close(b.done) 197} 198 199func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error { 200 req, err := stream.Recv() 201 if err != nil { 202 return err 203 } 204 initReq := req.GetInitialRequest() 205 if initReq.Name != beServerName { 206 return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name) 207 } 208 resp := &lbpb.LoadBalanceResponse{ 209 LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{ 210 InitialResponse: &lbpb.InitialLoadBalanceResponse{ 211 ClientStatsReportInterval: &durationpb.Duration{ 212 Seconds: int64(b.statsDura.Seconds()), 213 Nanos: int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9), 214 }, 215 }, 216 }, 217 } 218 if err := stream.Send(resp); err != nil { 219 return err 220 } 221 go func() { 222 for { 223 var ( 224 req *lbpb.LoadBalanceRequest 225 err error 226 ) 227 if req, err = stream.Recv(); err != nil { 228 return 229 } 230 
b.stats.merge(req.GetClientStats()) 231 } 232 }() 233 for { 234 select { 235 case v := <-b.sls: 236 resp = &lbpb.LoadBalanceResponse{ 237 LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{ 238 ServerList: v, 239 }, 240 } 241 case <-stream.Context().Done(): 242 return stream.Context().Err() 243 } 244 if err := stream.Send(resp); err != nil { 245 return err 246 } 247 } 248} 249 250type testServer struct { 251 testpb.TestServiceServer 252 253 addr string 254 fallback bool 255} 256 257const testmdkey = "testmd" 258 259func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { 260 md, ok := metadata.FromIncomingContext(ctx) 261 if !ok { 262 return nil, status.Error(codes.Internal, "failed to receive metadata") 263 } 264 if !s.fallback && (md == nil || md["lb-token"][0] != lbToken) { 265 return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md) 266 } 267 grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr)) 268 return &testpb.Empty{}, nil 269} 270 271func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { 272 return nil 273} 274 275func startBackends(sn string, fallback bool, lis ...net.Listener) (servers []*grpc.Server) { 276 for _, l := range lis { 277 creds := &serverNameCheckCreds{ 278 sn: sn, 279 } 280 s := grpc.NewServer(grpc.Creds(creds)) 281 testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String(), fallback: fallback}) 282 servers = append(servers, s) 283 go func(s *grpc.Server, l net.Listener) { 284 s.Serve(l) 285 }(s, l) 286 } 287 return 288} 289 290func stopBackends(servers []*grpc.Server) { 291 for _, s := range servers { 292 s.Stop() 293 } 294} 295 296type testServers struct { 297 lbAddr string 298 ls *remoteBalancer 299 lb *grpc.Server 300 backends []*grpc.Server 301 beIPs []net.IP 302 bePorts []int 303 304 lbListener net.Listener 305 beListeners []net.Listener 306} 307 308func newLoadBalancer(numberOfBackends int) (tss *testServers, 
cleanup func(), err error) { 309 var ( 310 beListeners []net.Listener 311 ls *remoteBalancer 312 lb *grpc.Server 313 beIPs []net.IP 314 bePorts []int 315 ) 316 for i := 0; i < numberOfBackends; i++ { 317 // Start a backend. 318 beLis, e := net.Listen("tcp", "localhost:0") 319 if e != nil { 320 err = fmt.Errorf("failed to listen %v", err) 321 return 322 } 323 beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP) 324 bePorts = append(bePorts, beLis.Addr().(*net.TCPAddr).Port) 325 326 beListeners = append(beListeners, newRestartableListener(beLis)) 327 } 328 backends := startBackends(beServerName, false, beListeners...) 329 330 // Start a load balancer. 331 lbLis, err := net.Listen("tcp", "localhost:0") 332 if err != nil { 333 err = fmt.Errorf("failed to create the listener for the load balancer %v", err) 334 return 335 } 336 lbLis = newRestartableListener(lbLis) 337 lbCreds := &serverNameCheckCreds{ 338 sn: lbServerName, 339 } 340 lb = grpc.NewServer(grpc.Creds(lbCreds)) 341 ls = newRemoteBalancer(nil) 342 lbgrpc.RegisterLoadBalancerServer(lb, ls) 343 go func() { 344 lb.Serve(lbLis) 345 }() 346 347 tss = &testServers{ 348 lbAddr: net.JoinHostPort(fakeName, strconv.Itoa(lbLis.Addr().(*net.TCPAddr).Port)), 349 ls: ls, 350 lb: lb, 351 backends: backends, 352 beIPs: beIPs, 353 bePorts: bePorts, 354 355 lbListener: lbLis, 356 beListeners: beListeners, 357 } 358 cleanup = func() { 359 defer stopBackends(backends) 360 defer func() { 361 ls.stop() 362 lb.Stop() 363 }() 364 } 365 return 366} 367 368func TestGRPCLB(t *testing.T) { 369 defer leakcheck.Check(t) 370 371 r, cleanup := manual.GenerateAndRegisterManualResolver() 372 defer cleanup() 373 374 tss, cleanup, err := newLoadBalancer(1) 375 if err != nil { 376 t.Fatalf("failed to create new load balancer: %v", err) 377 } 378 defer cleanup() 379 380 be := &lbpb.Server{ 381 IpAddress: tss.beIPs[0], 382 Port: int32(tss.bePorts[0]), 383 LoadBalanceToken: lbToken, 384 } 385 var bes []*lbpb.Server 386 bes = append(bes, be) 387 sl 
:= &lbpb.ServerList{ 388 Servers: bes, 389 } 390 tss.ls.sls <- sl 391 creds := serverNameCheckCreds{ 392 expected: beServerName, 393 } 394 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 395 defer cancel() 396 cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, 397 grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) 398 if err != nil { 399 t.Fatalf("Failed to dial to the backend %v", err) 400 } 401 defer cc.Close() 402 testC := testpb.NewTestServiceClient(cc) 403 404 r.UpdateState(resolver.State{Addresses: []resolver.Address{{ 405 Addr: tss.lbAddr, 406 Type: resolver.GRPCLB, 407 ServerName: lbServerName, 408 }}}) 409 410 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { 411 t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) 412 } 413} 414 415// The remote balancer sends response with duplicates to grpclb client. 416func TestGRPCLBWeighted(t *testing.T) { 417 defer leakcheck.Check(t) 418 419 r, cleanup := manual.GenerateAndRegisterManualResolver() 420 defer cleanup() 421 422 tss, cleanup, err := newLoadBalancer(2) 423 if err != nil { 424 t.Fatalf("failed to create new load balancer: %v", err) 425 } 426 defer cleanup() 427 428 beServers := []*lbpb.Server{{ 429 IpAddress: tss.beIPs[0], 430 Port: int32(tss.bePorts[0]), 431 LoadBalanceToken: lbToken, 432 }, { 433 IpAddress: tss.beIPs[1], 434 Port: int32(tss.bePorts[1]), 435 LoadBalanceToken: lbToken, 436 }} 437 portsToIndex := make(map[int]int) 438 for i := range beServers { 439 portsToIndex[tss.bePorts[i]] = i 440 } 441 442 creds := serverNameCheckCreds{ 443 expected: beServerName, 444 } 445 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 446 defer cancel() 447 cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, 448 grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) 449 if err != nil { 450 t.Fatalf("Failed to dial to 
the backend %v", err) 451 } 452 defer cc.Close() 453 testC := testpb.NewTestServiceClient(cc) 454 455 r.UpdateState(resolver.State{Addresses: []resolver.Address{{ 456 Addr: tss.lbAddr, 457 Type: resolver.GRPCLB, 458 ServerName: lbServerName, 459 }}}) 460 461 sequences := []string{"00101", "00011"} 462 for _, seq := range sequences { 463 var ( 464 bes []*lbpb.Server 465 p peer.Peer 466 result string 467 ) 468 for _, s := range seq { 469 bes = append(bes, beServers[s-'0']) 470 } 471 tss.ls.sls <- &lbpb.ServerList{Servers: bes} 472 473 for i := 0; i < 1000; i++ { 474 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { 475 t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) 476 } 477 result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port]) 478 } 479 // The generated result will be in format of "0010100101". 480 if !strings.Contains(result, strings.Repeat(seq, 2)) { 481 t.Errorf("got result sequence %q, want patten %q", result, seq) 482 } 483 } 484} 485 486func TestDropRequest(t *testing.T) { 487 defer leakcheck.Check(t) 488 489 r, cleanup := manual.GenerateAndRegisterManualResolver() 490 defer cleanup() 491 492 tss, cleanup, err := newLoadBalancer(2) 493 if err != nil { 494 t.Fatalf("failed to create new load balancer: %v", err) 495 } 496 defer cleanup() 497 tss.ls.sls <- &lbpb.ServerList{ 498 Servers: []*lbpb.Server{{ 499 IpAddress: tss.beIPs[0], 500 Port: int32(tss.bePorts[0]), 501 LoadBalanceToken: lbToken, 502 Drop: false, 503 }, { 504 IpAddress: tss.beIPs[1], 505 Port: int32(tss.bePorts[1]), 506 LoadBalanceToken: lbToken, 507 Drop: false, 508 }, { 509 Drop: true, 510 }}, 511 } 512 creds := serverNameCheckCreds{ 513 expected: beServerName, 514 } 515 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 516 defer cancel() 517 cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, 518 grpc.WithTransportCredentials(&creds), 
grpc.WithContextDialer(fakeNameDialer)) 519 if err != nil { 520 t.Fatalf("Failed to dial to the backend %v", err) 521 } 522 defer cc.Close() 523 testC := testpb.NewTestServiceClient(cc) 524 525 r.UpdateState(resolver.State{Addresses: []resolver.Address{{ 526 Addr: tss.lbAddr, 527 Type: resolver.GRPCLB, 528 ServerName: lbServerName, 529 }}}) 530 531 // Wait for the 1st, non-fail-fast RPC to succeed. This ensures both server 532 // connections are made, because the first one has Drop set to true. 533 var i int 534 for i = 0; i < 1000; i++ { 535 if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err == nil { 536 break 537 } 538 time.Sleep(time.Millisecond) 539 } 540 if i >= 1000 { 541 t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err) 542 } 543 select { 544 case <-ctx.Done(): 545 t.Fatal("timed out", ctx.Err()) 546 default: 547 } 548 for _, failfast := range []bool{true, false} { 549 for i := 0; i < 3; i++ { 550 // 1st RPCs pick the second item in server list. They should succeed 551 // since they choose the non-drop-request backend according to the 552 // round robin policy. 553 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(!failfast)); err != nil { 554 t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) 555 } 556 // 2st RPCs should fail, because they pick last item in server list, 557 // with Drop set to true. 558 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(!failfast)); status.Code(err) != codes.Unavailable { 559 t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable) 560 } 561 // 3rd RPCs pick the first item in server list. They should succeed 562 // since they choose the non-drop-request backend according to the 563 // round robin policy. 
564 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(!failfast)); err != nil { 565 t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) 566 } 567 } 568 } 569 tss.backends[0].Stop() 570 // This last pick was backend 0. Closing backend 0 doesn't reset drop index 571 // (for level 1 picking), so the following picks will be (backend1, drop, 572 // backend1), instead of (backend, backend, drop) if drop index was reset. 573 time.Sleep(time.Second) 574 for i := 0; i < 3; i++ { 575 var p peer.Peer 576 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { 577 t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) 578 } 579 if want := tss.bePorts[1]; p.Addr.(*net.TCPAddr).Port != want { 580 t.Errorf("got peer: %v, want peer port: %v", p.Addr, want) 581 } 582 583 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); status.Code(err) != codes.Unavailable { 584 t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable) 585 } 586 587 if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { 588 t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err) 589 } 590 if want := tss.bePorts[1]; p.Addr.(*net.TCPAddr).Port != want { 591 t.Errorf("got peer: %v, want peer port: %v", p.Addr, want) 592 } 593 } 594} 595 596// When the balancer in use disconnects, grpclb should connect to the next address from resolved balancer address list. 
func TestBalancerDisconnects(t *testing.T) {
	defer leakcheck.Check(t)

	r, cleanup := manual.GenerateAndRegisterManualResolver()
	defer cleanup()

	// Start two independent balancer+backend pairs; the resolver will list
	// both balancer addresses.
	var (
		tests []*testServers
		lbs   []*grpc.Server
	)
	for i := 0; i < 2; i++ {
		tss, cleanup, err := newLoadBalancer(1)
		if err != nil {
			t.Fatalf("failed to create new load balancer: %v", err)
		}
		defer cleanup()

		be := &lbpb.Server{
			IpAddress:        tss.beIPs[0],
			Port:             int32(tss.bePorts[0]),
			LoadBalanceToken: lbToken,
		}
		var bes []*lbpb.Server
		bes = append(bes, be)
		sl := &lbpb.ServerList{
			Servers: bes,
		}
		tss.ls.sls <- sl

		tests = append(tests, tss)
		lbs = append(lbs, tss.lb)
	}

	creds := serverNameCheckCreds{
		expected: beServerName,
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	testC := testpb.NewTestServiceClient(cc)

	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
		Addr:       tests[0].lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}, {
		Addr:       tests[1].lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}}})

	// The first balancer in the list should be in use initially.
	var p peer.Peer
	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
	}
	if p.Addr.(*net.TCPAddr).Port != tests[0].bePorts[0] {
		t.Fatalf("got peer: %v, want peer port: %v", p.Addr, tests[0].bePorts[0])
	}

	lbs[0].Stop()
	// Stop balancer[0], balancer[1] should be used by grpclb.
	// Check peer address to see if that happened.
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.(*net.TCPAddr).Port == tests[1].bePorts[0] {
			return
		}
		time.Sleep(time.Millisecond)
	}
	t.Fatalf("No RPC sent to second backend after 1 second")
}

// TestFallback exercises the fallback path: with an unreachable balancer
// the fallback (resolver-provided) backend is used; once a balancer is
// reachable its backends take over; if balancer and backend connections die
// the client falls back again; and it recovers when they restart.
func TestFallback(t *testing.T) {
	balancer.Register(newLBBuilderWithFallbackTimeout(100 * time.Millisecond))
	defer balancer.Register(newLBBuilder())

	defer leakcheck.Check(t)

	r, cleanup := manual.GenerateAndRegisterManualResolver()
	defer cleanup()

	tss, cleanup, err := newLoadBalancer(1)
	if err != nil {
		t.Fatalf("failed to create new load balancer: %v", err)
	}
	defer cleanup()

	// Start a standalone backend.
	beLis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen %v", err)
	}
	defer beLis.Close()
	standaloneBEs := startBackends(beServerName, true, beLis)
	defer stopBackends(standaloneBEs)

	be := &lbpb.Server{
		IpAddress:        tss.beIPs[0],
		Port:             int32(tss.bePorts[0]),
		LoadBalanceToken: lbToken,
	}
	var bes []*lbpb.Server
	bes = append(bes, be)
	sl := &lbpb.ServerList{
		Servers: bes,
	}
	tss.ls.sls <- sl
	creds := serverNameCheckCreds{
		expected: beServerName,
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	testC := testpb.NewTestServiceClient(cc)

	// Balancer address is invalid, so the fallback backend should be used.
	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
		Addr:       "invalid.address",
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}, {
		Addr:       beLis.Addr().String(),
		Type:       resolver.Backend,
		ServerName: beServerName,
	}}})

	var p peer.Peer
	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
		t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
	}
	if p.Addr.String() != beLis.Addr().String() {
		t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr())
	}

	// Now point at the real balancer; its backend should take over.
	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
		Addr:       tss.lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}, {
		Addr:       beLis.Addr().String(),
		Type:       resolver.Backend,
		ServerName: beServerName,
	}}})

	var backendUsed bool
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
			backendUsed = true
			break
		}
		time.Sleep(time.Millisecond)
	}
	if !backendUsed {
		t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
	}

	// Close backend and remote balancer connections, should use fallback.
	tss.beListeners[0].(*restartableListener).stopPreviousConns()
	tss.lbListener.(*restartableListener).stopPreviousConns()
	time.Sleep(time.Second)

	var fallbackUsed bool
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.String() == beLis.Addr().String() {
			fallbackUsed = true
			break
		}
		time.Sleep(time.Millisecond)
	}
	if !fallbackUsed {
		t.Fatalf("No RPC sent to fallback after 1 second")
	}

	// Restart backend and remote balancer, should not use backends.
	tss.beListeners[0].(*restartableListener).restart()
	tss.lbListener.(*restartableListener).restart()
	tss.ls.sls <- sl

	time.Sleep(time.Second)

	var backendUsed2 bool
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
			backendUsed2 = true
			break
		}
		time.Sleep(time.Millisecond)
	}
	if !backendUsed2 {
		t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
	}
}

// The remote balancer sends response with duplicates to grpclb client.
// NOTE(review): this comment looks copy-pasted from TestGRPCLBWeighted; the
// next test exercises the pick-first mode — confirm and reword.
func TestGRPCLBPickFirst(t *testing.T) {
	balancer.Register(newLBBuilderWithPickFirst())
	defer balancer.Register(newLBBuilder())

	defer leakcheck.Check(t)

	r, cleanup := manual.GenerateAndRegisterManualResolver()
	defer cleanup()

	tss, cleanup, err := newLoadBalancer(3)
	if err != nil {
		t.Fatalf("failed to create new load balancer: %v", err)
	}
	defer cleanup()

	beServers := []*lbpb.Server{{
		IpAddress:        tss.beIPs[0],
		Port:             int32(tss.bePorts[0]),
		LoadBalanceToken: lbToken,
	}, {
		IpAddress:        tss.beIPs[1],
		Port:             int32(tss.bePorts[1]),
		LoadBalanceToken: lbToken,
	}, {
		IpAddress:        tss.beIPs[2],
		Port:             int32(tss.bePorts[2]),
		LoadBalanceToken: lbToken,
	}}
	portsToIndex := make(map[int]int)
	for i := range beServers {
		portsToIndex[tss.bePorts[i]] = i
	}

	creds := serverNameCheckCreds{
		expected: beServerName,
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	testC := testpb.NewTestServiceClient(cc)

	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
		Addr:       tss.lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}}})

	var p peer.Peer

	// With pick-first, every RPC in a round must hit the same backend; the
	// first RPC after each server-list update records the chosen port.
	portPicked1 := 0
	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[1:2]}
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
		}
		if portPicked1 == 0 {
			portPicked1 = p.Addr.(*net.TCPAddr).Port
			continue
		}
		if portPicked1 != p.Addr.(*net.TCPAddr).Port {
			t.Fatalf("Different backends are picked for RPCs: %v vs %v", portPicked1, p.Addr.(*net.TCPAddr).Port)
		}
	}

	portPicked2 := portPicked1
	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[:1]}
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
		}
		if portPicked2 == portPicked1 {
			portPicked2 = p.Addr.(*net.TCPAddr).Port
			continue
		}
		if portPicked2 != p.Addr.(*net.TCPAddr).Port {
			t.Fatalf("Different backends are picked for RPCs: %v vs %v", portPicked2, p.Addr.(*net.TCPAddr).Port)
		}
	}

	portPicked := portPicked2
	tss.ls.sls <- &lbpb.ServerList{Servers: beServers[1:]}
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil {
			t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, <nil>", err)
		}
		if portPicked == portPicked2 {
			portPicked = p.Addr.(*net.TCPAddr).Port
			continue
		}
		if portPicked != p.Addr.(*net.TCPAddr).Port {
			t.Fatalf("Different backends are picked for RPCs: %v vs %v", portPicked, p.Addr.(*net.TCPAddr).Port)
		}
	}
}

// failPreRPCCred is a PerRPCCredentials implementation that fails any RPC
// whose URI contains failtosendURI, producing "client failed to send" stats.
type failPreRPCCred struct{}

func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	if strings.Contains(uri[0], failtosendURI) {
		return nil, fmt.Errorf("rpc should fail to send")
	}
	return nil, nil
}

func (failPreRPCCred) RequireTransportSecurity() bool {
	return false
}

// checkStats returns a non-nil error if stats does not equal expected.
func checkStats(stats, expected *rpcStats) error {
	if !stats.equal(expected) {
		return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
	}
	return nil
}

// runAndGetStats starts one backend behind a remote balancer (optionally
// with an extra drop entry in the server list), runs runRPCs against the
// resulting ClientConn, waits for a stats report cycle, and returns the
// stats the fake balancer accumulated.
func runAndGetStats(t *testing.T, drop bool, runRPCs func(*grpc.ClientConn)) *rpcStats {
	defer leakcheck.Check(t)

	r, cleanup := manual.GenerateAndRegisterManualResolver()
	defer cleanup()

	tss, cleanup, err := newLoadBalancer(1)
	if err != nil {
		t.Fatalf("failed to create new load balancer: %v", err)
	}
	defer cleanup()
	servers := []*lbpb.Server{{
		IpAddress:        tss.beIPs[0],
		Port:             int32(tss.bePorts[0]),
		LoadBalanceToken: lbToken,
	}}
	if drop {
		servers = append(servers, &lbpb.Server{
			LoadBalanceToken: lbToken,
			Drop:             drop,
		})
	}
	tss.ls.sls <- &lbpb.ServerList{Servers: servers}
	tss.ls.statsDura = 100 * time.Millisecond
	creds := serverNameCheckCreds{expected: beServerName}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
		grpc.WithTransportCredentials(&creds),
		grpc.WithPerRPCCredentials(failPreRPCCred{}),
		grpc.WithContextDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()

	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
		Addr:       tss.lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}}})

	runRPCs(cc)
	// Allow several stats report intervals to elapse before reading.
	time.Sleep(1 * time.Second)
	stats := tss.ls.stats
	return stats
}

const (
	countRPC      = 40
	failtosendURI = "failtosend"
)

func TestGRPCLBStatsUnarySuccess(t *testing.T) {
	defer leakcheck.Check(t)
	stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
		testC := testpb.NewTestServiceClient(cc)
		// The first non-failfast RPC succeeds, all connections are up.
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		for i := 0; i < countRPC-1; i++ {
			testC.EmptyCall(context.Background(), &testpb.Empty{})
		}
	})

	if err := checkStats(stats, &rpcStats{
		numCallsStarted:               int64(countRPC),
		numCallsFinished:              int64(countRPC),
		numCallsFinishedKnownReceived: int64(countRPC),
	}); err != nil {
		t.Fatal(err)
	}
}

func TestGRPCLBStatsUnaryDrop(t *testing.T) {
	defer leakcheck.Check(t)
	stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
		testC := testpb.NewTestServiceClient(cc)
		// The first non-failfast RPC succeeds, all connections are up.
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		for i := 0; i < countRPC-1; i++ {
			testC.EmptyCall(context.Background(), &testpb.Empty{})
		}
	})

	// With one real backend and one drop entry in round robin, half the
	// RPCs are dropped.
	if err := checkStats(stats, &rpcStats{
		numCallsStarted:               int64(countRPC),
		numCallsFinished:              int64(countRPC),
		numCallsFinishedKnownReceived: int64(countRPC) / 2,
		numCallsDropped:               map[string]int64{lbToken: int64(countRPC) / 2},
	}); err != nil {
		t.Fatal(err)
	}
}

func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
	defer leakcheck.Check(t)
	stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
		testC := testpb.NewTestServiceClient(cc)
		// The first non-failfast RPC succeeds, all connections are up.
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		// These invokes target failtosendURI, so failPreRPCCred rejects
		// them before they are sent.
		for i := 0; i < countRPC-1; i++ {
			cc.Invoke(context.Background(), failtosendURI, &testpb.Empty{}, nil)
		}
	})

	if err := checkStats(stats, &rpcStats{
		numCallsStarted:                        int64(countRPC),
		numCallsFinished:                       int64(countRPC),
		numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
		numCallsFinishedKnownReceived:          1,
	}); err != nil {
		t.Fatal(err)
	}
}

func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
	defer leakcheck.Check(t)
	stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
		testC := testpb.NewTestServiceClient(cc)
		// The first non-failfast RPC succeeds, all connections are up.
		stream, err := testC.FullDuplexCall(context.Background(), grpc.WaitForReady(true))
		if err != nil {
			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		for {
			if _, err = stream.Recv(); err == io.EOF {
				break
			}
		}
		for i := 0; i < countRPC-1; i++ {
			stream, err = testC.FullDuplexCall(context.Background())
			if err == nil {
				// Wait for stream to end if err is nil.
				for {
					if _, err = stream.Recv(); err == io.EOF {
						break
					}
				}
			}
		}
	})

	if err := checkStats(stats, &rpcStats{
		numCallsStarted:               int64(countRPC),
		numCallsFinished:              int64(countRPC),
		numCallsFinishedKnownReceived: int64(countRPC),
	}); err != nil {
		t.Fatal(err)
	}
}

func TestGRPCLBStatsStreamingDrop(t *testing.T) {
	defer leakcheck.Check(t)
	stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
		testC := testpb.NewTestServiceClient(cc)
		// The first non-failfast RPC succeeds, all connections are up.
		stream, err := testC.FullDuplexCall(context.Background(), grpc.WaitForReady(true))
		if err != nil {
			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		for {
			if _, err = stream.Recv(); err == io.EOF {
				break
			}
		}
		for i := 0; i < countRPC-1; i++ {
			stream, err = testC.FullDuplexCall(context.Background())
			if err == nil {
				// Wait for stream to end if err is nil.
				for {
					if _, err = stream.Recv(); err == io.EOF {
						break
					}
				}
			}
		}
	})

	if err := checkStats(stats, &rpcStats{
		numCallsStarted:               int64(countRPC),
		numCallsFinished:              int64(countRPC),
		numCallsFinishedKnownReceived: int64(countRPC) / 2,
		numCallsDropped:               map[string]int64{lbToken: int64(countRPC) / 2},
	}); err != nil {
		t.Fatal(err)
	}
}

func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
	defer leakcheck.Check(t)
	stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
		testC := testpb.NewTestServiceClient(cc)
		// The first non-failfast RPC succeeds, all connections are up.
		stream, err := testC.FullDuplexCall(context.Background(), grpc.WaitForReady(true))
		if err != nil {
			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		for {
			if _, err = stream.Recv(); err == io.EOF {
				break
			}
		}
		// Streams to failtosendURI are rejected by failPreRPCCred before
		// being sent.
		for i := 0; i < countRPC-1; i++ {
			cc.NewStream(context.Background(), &grpc.StreamDesc{}, failtosendURI)
		}
	})

	if err := checkStats(stats, &rpcStats{
		numCallsStarted:                        int64(countRPC),
		numCallsFinished:                       int64(countRPC),
		numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
		numCallsFinishedKnownReceived:          1,
	}); err != nil {
		t.Fatal(err)
	}
}