1/*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package test
20
21import (
22	"context"
23	"errors"
24	"fmt"
25	"net"
26	"sync"
27	"testing"
28	"time"
29
30	"google.golang.org/grpc"
31	"google.golang.org/grpc/codes"
32	"google.golang.org/grpc/connectivity"
33	_ "google.golang.org/grpc/health"
34	healthgrpc "google.golang.org/grpc/health/grpc_health_v1"
35	healthpb "google.golang.org/grpc/health/grpc_health_v1"
36	"google.golang.org/grpc/internal"
37	"google.golang.org/grpc/internal/channelz"
38	"google.golang.org/grpc/internal/grpctest"
39	"google.golang.org/grpc/resolver"
40	"google.golang.org/grpc/resolver/manual"
41	"google.golang.org/grpc/status"
42	testpb "google.golang.org/grpc/test/grpc_testing"
43)
44
45var testHealthCheckFunc = internal.HealthCheckFunc
46
47func newTestHealthServer() *testHealthServer {
48	return newTestHealthServerWithWatchFunc(defaultWatchFunc)
49}
50
51func newTestHealthServerWithWatchFunc(f func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error) *testHealthServer {
52	return &testHealthServer{
53		watchFunc: f,
54		update:    make(chan struct{}, 1),
55		status:    make(map[string]healthpb.HealthCheckResponse_ServingStatus),
56	}
57}
58
59// defaultWatchFunc will send a HealthCheckResponse to the client whenever SetServingStatus is called.
60func defaultWatchFunc(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
61	if in.Service != "foo" {
62		return status.Error(codes.FailedPrecondition,
63			"the defaultWatchFunc only handles request with service name to be \"foo\"")
64	}
65	var done bool
66	for {
67		select {
68		case <-stream.Context().Done():
69			done = true
70		case <-s.update:
71		}
72		if done {
73			break
74		}
75		s.mu.Lock()
76		resp := &healthpb.HealthCheckResponse{
77			Status: s.status[in.Service],
78		}
79		s.mu.Unlock()
80		stream.SendMsg(resp)
81	}
82	return nil
83}
84
85type testHealthServer struct {
86	healthpb.UnimplementedHealthServer
87	watchFunc func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error
88	mu        sync.Mutex
89	status    map[string]healthpb.HealthCheckResponse_ServingStatus
90	update    chan struct{}
91}
92
93func (s *testHealthServer) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
94	return &healthpb.HealthCheckResponse{
95		Status: healthpb.HealthCheckResponse_SERVING,
96	}, nil
97}
98
99func (s *testHealthServer) Watch(in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
100	return s.watchFunc(s, in, stream)
101}
102
103// SetServingStatus is called when need to reset the serving status of a service
104// or insert a new service entry into the statusMap.
105func (s *testHealthServer) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) {
106	s.mu.Lock()
107	s.status[service] = status
108	select {
109	case <-s.update:
110	default:
111	}
112	s.update <- struct{}{}
113	s.mu.Unlock()
114}
115
116func setupHealthCheckWrapper() (hcEnterChan chan struct{}, hcExitChan chan struct{}, wrapper internal.HealthChecker) {
117	hcEnterChan = make(chan struct{})
118	hcExitChan = make(chan struct{})
119	wrapper = func(ctx context.Context, newStream func(string) (interface{}, error), update func(connectivity.State, error), service string) error {
120		close(hcEnterChan)
121		defer close(hcExitChan)
122		return testHealthCheckFunc(ctx, newStream, update, service)
123	}
124	return
125}
126
127type svrConfig struct {
128	specialWatchFunc func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error
129}
130
131func setupServer(sc *svrConfig) (s *grpc.Server, lis net.Listener, ts *testHealthServer, deferFunc func(), err error) {
132	s = grpc.NewServer()
133	lis, err = net.Listen("tcp", "localhost:0")
134	if err != nil {
135		return nil, nil, nil, func() {}, fmt.Errorf("failed to listen due to err %v", err)
136	}
137	if sc.specialWatchFunc != nil {
138		ts = newTestHealthServerWithWatchFunc(sc.specialWatchFunc)
139	} else {
140		ts = newTestHealthServer()
141	}
142	healthgrpc.RegisterHealthServer(s, ts)
143	testpb.RegisterTestServiceServer(s, &testServer{})
144	go s.Serve(lis)
145	return s, lis, ts, s.Stop, nil
146}
147
148type clientConfig struct {
149	balancerName               string
150	testHealthCheckFuncWrapper internal.HealthChecker
151	extraDialOption            []grpc.DialOption
152}
153
154func setupClient(c *clientConfig) (cc *grpc.ClientConn, r *manual.Resolver, deferFunc func(), err error) {
155	r = manual.NewBuilderWithScheme("whatever")
156	var opts []grpc.DialOption
157	opts = append(opts, grpc.WithInsecure(), grpc.WithResolvers(r), grpc.WithBalancerName(c.balancerName))
158	if c.testHealthCheckFuncWrapper != nil {
159		opts = append(opts, internal.WithHealthCheckFunc.(func(internal.HealthChecker) grpc.DialOption)(c.testHealthCheckFuncWrapper))
160	}
161	opts = append(opts, c.extraDialOption...)
162	cc, err = grpc.Dial(r.Scheme()+":///test.server", opts...)
163	if err != nil {
164
165		return nil, nil, nil, fmt.Errorf("dial failed due to err: %v", err)
166	}
167	return cc, r, func() { cc.Close() }, nil
168}
169
170func (s) TestHealthCheckWatchStateChange(t *testing.T) {
171	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
172	defer deferFunc()
173	if err != nil {
174		t.Fatal(err)
175	}
176
177	// The table below shows the expected series of addrConn connectivity transitions when server
178	// updates its health status. As there's only one addrConn corresponds with the ClientConn in this
179	// test, we use ClientConn's connectivity state as the addrConn connectivity state.
180	//+------------------------------+-------------------------------------------+
181	//| Health Check Returned Status | Expected addrConn Connectivity Transition |
182	//+------------------------------+-------------------------------------------+
183	//| NOT_SERVING                  | ->TRANSIENT FAILURE                       |
184	//| SERVING                      | ->READY                                   |
185	//| SERVICE_UNKNOWN              | ->TRANSIENT FAILURE                       |
186	//| SERVING                      | ->READY                                   |
187	//| UNKNOWN                      | ->TRANSIENT FAILURE                       |
188	//+------------------------------+-------------------------------------------+
189	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_NOT_SERVING)
190
191	cc, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
192	if err != nil {
193		t.Fatal(err)
194	}
195	defer deferFunc()
196
197	r.UpdateState(resolver.State{
198		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
199		ServiceConfig: parseCfg(r, `{
200	"healthCheckConfig": {
201		"serviceName": "foo"
202	}
203}`)})
204	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
205	defer cancel()
206	if ok := cc.WaitForStateChange(ctx, connectivity.Idle); !ok {
207		t.Fatal("ClientConn is still in IDLE state when the context times out.")
208	}
209	if ok := cc.WaitForStateChange(ctx, connectivity.Connecting); !ok {
210		t.Fatal("ClientConn is still in CONNECTING state when the context times out.")
211	}
212	if s := cc.GetState(); s != connectivity.TransientFailure {
213		t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s)
214	}
215
216	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
217	if ok := cc.WaitForStateChange(ctx, connectivity.TransientFailure); !ok {
218		t.Fatal("ClientConn is still in TRANSIENT FAILURE state when the context times out.")
219	}
220	if s := cc.GetState(); s != connectivity.Ready {
221		t.Fatalf("ClientConn is in %v state, want READY", s)
222	}
223
224	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVICE_UNKNOWN)
225	if ok := cc.WaitForStateChange(ctx, connectivity.Ready); !ok {
226		t.Fatal("ClientConn is still in READY state when the context times out.")
227	}
228	if s := cc.GetState(); s != connectivity.TransientFailure {
229		t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s)
230	}
231
232	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
233	if ok := cc.WaitForStateChange(ctx, connectivity.TransientFailure); !ok {
234		t.Fatal("ClientConn is still in TRANSIENT FAILURE state when the context times out.")
235	}
236	if s := cc.GetState(); s != connectivity.Ready {
237		t.Fatalf("ClientConn is in %v state, want READY", s)
238	}
239
240	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_UNKNOWN)
241	if ok := cc.WaitForStateChange(ctx, connectivity.Ready); !ok {
242		t.Fatal("ClientConn is still in READY state when the context times out.")
243	}
244	if s := cc.GetState(); s != connectivity.TransientFailure {
245		t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s)
246	}
247}
248
249// If Watch returns Unimplemented, then the ClientConn should go into READY state.
250func (s) TestHealthCheckHealthServerNotRegistered(t *testing.T) {
251	grpctest.TLogger.ExpectError("Subchannel health check is unimplemented at server side, thus health check is disabled")
252	s := grpc.NewServer()
253	lis, err := net.Listen("tcp", "localhost:0")
254	if err != nil {
255		t.Fatalf("failed to listen due to err: %v", err)
256	}
257	go s.Serve(lis)
258	defer s.Stop()
259
260	cc, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
261	if err != nil {
262		t.Fatal(err)
263	}
264	defer deferFunc()
265
266	r.UpdateState(resolver.State{
267		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
268		ServiceConfig: parseCfg(r, `{
269	"healthCheckConfig": {
270		"serviceName": "foo"
271	}
272}`)})
273	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
274	defer cancel()
275
276	if ok := cc.WaitForStateChange(ctx, connectivity.Idle); !ok {
277		t.Fatal("ClientConn is still in IDLE state when the context times out.")
278	}
279	if ok := cc.WaitForStateChange(ctx, connectivity.Connecting); !ok {
280		t.Fatal("ClientConn is still in CONNECTING state when the context times out.")
281	}
282	if s := cc.GetState(); s != connectivity.Ready {
283		t.Fatalf("ClientConn is in %v state, want READY", s)
284	}
285}
286
287// In the case of a goaway received, the health check stream should be terminated and health check
288// function should exit.
289func (s) TestHealthCheckWithGoAway(t *testing.T) {
290	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
291
292	s, lis, ts, deferFunc, err := setupServer(&svrConfig{})
293	defer deferFunc()
294	if err != nil {
295		t.Fatal(err)
296	}
297
298	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
299
300	cc, r, deferFunc, err := setupClient(&clientConfig{
301		balancerName:               "round_robin",
302		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
303	})
304	if err != nil {
305		t.Fatal(err)
306	}
307	defer deferFunc()
308
309	tc := testpb.NewTestServiceClient(cc)
310	r.UpdateState(resolver.State{
311		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
312		ServiceConfig: parseCfg(r, `{
313	"healthCheckConfig": {
314		"serviceName": "foo"
315	}
316}`)})
317
318	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
319	defer cancel()
320
321	// make some rpcs to make sure connection is working.
322	if err := verifyResultWithDelay(func() (bool, error) {
323		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
324			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
325		}
326		return true, nil
327	}); err != nil {
328		t.Fatal(err)
329	}
330
331	// the stream rpc will persist through goaway event.
332	stream, err := tc.FullDuplexCall(ctx, grpc.WaitForReady(true))
333	if err != nil {
334		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
335	}
336	respParam := []*testpb.ResponseParameters{{Size: 1}}
337	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
338	if err != nil {
339		t.Fatal(err)
340	}
341	req := &testpb.StreamingOutputCallRequest{
342		ResponseParameters: respParam,
343		Payload:            payload,
344	}
345	if err := stream.Send(req); err != nil {
346		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
347	}
348	if _, err := stream.Recv(); err != nil {
349		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
350	}
351
352	select {
353	case <-hcExitChan:
354		t.Fatal("Health check function has exited, which is not expected.")
355	default:
356	}
357
358	// server sends GoAway
359	go s.GracefulStop()
360
361	select {
362	case <-hcExitChan:
363	case <-time.After(5 * time.Second):
364		select {
365		case <-hcEnterChan:
366		default:
367			t.Fatal("Health check function has not entered after 5s.")
368		}
369		t.Fatal("Health check function has not exited after 5s.")
370	}
371
372	// The existing RPC should be still good to proceed.
373	if err := stream.Send(req); err != nil {
374		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
375	}
376	if _, err := stream.Recv(); err != nil {
377		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
378	}
379}
380
381func (s) TestHealthCheckWithConnClose(t *testing.T) {
382	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
383
384	s, lis, ts, deferFunc, err := setupServer(&svrConfig{})
385	defer deferFunc()
386	if err != nil {
387		t.Fatal(err)
388	}
389
390	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
391
392	cc, r, deferFunc, err := setupClient(&clientConfig{
393		balancerName:               "round_robin",
394		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
395	})
396	if err != nil {
397		t.Fatal(err)
398	}
399	defer deferFunc()
400
401	tc := testpb.NewTestServiceClient(cc)
402
403	r.UpdateState(resolver.State{
404		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
405		ServiceConfig: parseCfg(r, `{
406	"healthCheckConfig": {
407		"serviceName": "foo"
408	}
409}`)})
410
411	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
412	defer cancel()
413	// make some rpcs to make sure connection is working.
414	if err := verifyResultWithDelay(func() (bool, error) {
415		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
416			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
417		}
418		return true, nil
419	}); err != nil {
420		t.Fatal(err)
421	}
422
423	select {
424	case <-hcExitChan:
425		t.Fatal("Health check function has exited, which is not expected.")
426	default:
427	}
428	// server closes the connection
429	s.Stop()
430
431	select {
432	case <-hcExitChan:
433	case <-time.After(5 * time.Second):
434		select {
435		case <-hcEnterChan:
436		default:
437			t.Fatal("Health check function has not entered after 5s.")
438		}
439		t.Fatal("Health check function has not exited after 5s.")
440	}
441}
442
443// addrConn drain happens when addrConn gets torn down due to its address being no longer in the
444// address list returned by the resolver.
445func (s) TestHealthCheckWithAddrConnDrain(t *testing.T) {
446	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
447
448	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
449	defer deferFunc()
450	if err != nil {
451		t.Fatal(err)
452	}
453
454	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
455
456	cc, r, deferFunc, err := setupClient(&clientConfig{
457		balancerName:               "round_robin",
458		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
459	})
460	if err != nil {
461		t.Fatal(err)
462	}
463	defer deferFunc()
464
465	tc := testpb.NewTestServiceClient(cc)
466	sc := parseCfg(r, `{
467	"healthCheckConfig": {
468		"serviceName": "foo"
469	}
470}`)
471	r.UpdateState(resolver.State{
472		Addresses:     []resolver.Address{{Addr: lis.Addr().String()}},
473		ServiceConfig: sc,
474	})
475
476	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
477	defer cancel()
478	// make some rpcs to make sure connection is working.
479	if err := verifyResultWithDelay(func() (bool, error) {
480		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
481			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
482		}
483		return true, nil
484	}); err != nil {
485		t.Fatal(err)
486	}
487
488	// the stream rpc will persist through goaway event.
489	stream, err := tc.FullDuplexCall(ctx, grpc.WaitForReady(true))
490	if err != nil {
491		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
492	}
493	respParam := []*testpb.ResponseParameters{{Size: 1}}
494	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
495	if err != nil {
496		t.Fatal(err)
497	}
498	req := &testpb.StreamingOutputCallRequest{
499		ResponseParameters: respParam,
500		Payload:            payload,
501	}
502	if err := stream.Send(req); err != nil {
503		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
504	}
505	if _, err := stream.Recv(); err != nil {
506		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
507	}
508
509	select {
510	case <-hcExitChan:
511		t.Fatal("Health check function has exited, which is not expected.")
512	default:
513	}
514	// trigger teardown of the ac
515	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: "fake address"}}, ServiceConfig: sc})
516
517	select {
518	case <-hcExitChan:
519	case <-time.After(5 * time.Second):
520		select {
521		case <-hcEnterChan:
522		default:
523			t.Fatal("Health check function has not entered after 5s.")
524		}
525		t.Fatal("Health check function has not exited after 5s.")
526	}
527
528	// The existing RPC should be still good to proceed.
529	if err := stream.Send(req); err != nil {
530		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
531	}
532	if _, err := stream.Recv(); err != nil {
533		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
534	}
535}
536
537// ClientConn close will lead to its addrConns being torn down.
538func (s) TestHealthCheckWithClientConnClose(t *testing.T) {
539	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
540
541	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
542	defer deferFunc()
543	if err != nil {
544		t.Fatal(err)
545	}
546
547	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
548
549	cc, r, deferFunc, err := setupClient(&clientConfig{
550		balancerName:               "round_robin",
551		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
552	})
553	if err != nil {
554		t.Fatal(err)
555	}
556	defer deferFunc()
557
558	tc := testpb.NewTestServiceClient(cc)
559	r.UpdateState(resolver.State{
560		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
561		ServiceConfig: parseCfg(r, `{
562	"healthCheckConfig": {
563		"serviceName": "foo"
564	}
565}`)})
566
567	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
568	defer cancel()
569	// make some rpcs to make sure connection is working.
570	if err := verifyResultWithDelay(func() (bool, error) {
571		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
572			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
573		}
574		return true, nil
575	}); err != nil {
576		t.Fatal(err)
577	}
578
579	select {
580	case <-hcExitChan:
581		t.Fatal("Health check function has exited, which is not expected.")
582	default:
583	}
584
585	// trigger addrConn teardown
586	cc.Close()
587
588	select {
589	case <-hcExitChan:
590	case <-time.After(5 * time.Second):
591		select {
592		case <-hcEnterChan:
593		default:
594			t.Fatal("Health check function has not entered after 5s.")
595		}
596		t.Fatal("Health check function has not exited after 5s.")
597	}
598}
599
600// This test is to test the logic in the createTransport after the health check function returns which
601// closes the skipReset channel(since it has not been closed inside health check func) to unblock
602// onGoAway/onClose goroutine.
603func (s) TestHealthCheckWithoutSetConnectivityStateCalledAddrConnShutDown(t *testing.T) {
604	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
605
606	_, lis, ts, deferFunc, err := setupServer(&svrConfig{
607		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
608			if in.Service != "delay" {
609				return status.Error(codes.FailedPrecondition,
610					"this special Watch function only handles request with service name to be \"delay\"")
611			}
612			// Do nothing to mock a delay of health check response from server side.
613			// This case is to help with the test that covers the condition that setConnectivityState is not
614			// called inside HealthCheckFunc before the func returns.
615			select {
616			case <-stream.Context().Done():
617			case <-time.After(5 * time.Second):
618			}
619			return nil
620		},
621	})
622	defer deferFunc()
623	if err != nil {
624		t.Fatal(err)
625	}
626
627	ts.SetServingStatus("delay", healthpb.HealthCheckResponse_SERVING)
628
629	_, r, deferFunc, err := setupClient(&clientConfig{
630		balancerName:               "round_robin",
631		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
632	})
633	if err != nil {
634		t.Fatal(err)
635	}
636	defer deferFunc()
637
638	// The serviceName "delay" is specially handled at server side, where response will not be sent
639	// back to client immediately upon receiving the request (client should receive no response until
640	// test ends).
641	sc := parseCfg(r, `{
642	"healthCheckConfig": {
643		"serviceName": "delay"
644	}
645}`)
646	r.UpdateState(resolver.State{
647		Addresses:     []resolver.Address{{Addr: lis.Addr().String()}},
648		ServiceConfig: sc,
649	})
650
651	select {
652	case <-hcExitChan:
653		t.Fatal("Health check function has exited, which is not expected.")
654	default:
655	}
656
657	select {
658	case <-hcEnterChan:
659	case <-time.After(5 * time.Second):
660		t.Fatal("Health check function has not been invoked after 5s.")
661	}
662	// trigger teardown of the ac, ac in SHUTDOWN state
663	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: "fake address"}}, ServiceConfig: sc})
664
665	// The health check func should exit without calling the setConnectivityState func, as server hasn't sent
666	// any response.
667	select {
668	case <-hcExitChan:
669	case <-time.After(5 * time.Second):
670		t.Fatal("Health check function has not exited after 5s.")
671	}
672	// The deferred leakcheck will check whether there's leaked goroutine, which is an indication
673	// whether we closes the skipReset channel to unblock onGoAway/onClose goroutine.
674}
675
676// This test is to test the logic in the createTransport after the health check function returns which
677// closes the allowedToReset channel(since it has not been closed inside health check func) to unblock
678// onGoAway/onClose goroutine.
679func (s) TestHealthCheckWithoutSetConnectivityStateCalled(t *testing.T) {
680	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
681
682	s, lis, ts, deferFunc, err := setupServer(&svrConfig{
683		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
684			if in.Service != "delay" {
685				return status.Error(codes.FailedPrecondition,
686					"this special Watch function only handles request with service name to be \"delay\"")
687			}
688			// Do nothing to mock a delay of health check response from server side.
689			// This case is to help with the test that covers the condition that setConnectivityState is not
690			// called inside HealthCheckFunc before the func returns.
691			select {
692			case <-stream.Context().Done():
693			case <-time.After(5 * time.Second):
694			}
695			return nil
696		},
697	})
698	defer deferFunc()
699	if err != nil {
700		t.Fatal(err)
701	}
702
703	ts.SetServingStatus("delay", healthpb.HealthCheckResponse_SERVING)
704
705	_, r, deferFunc, err := setupClient(&clientConfig{
706		balancerName:               "round_robin",
707		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
708	})
709	if err != nil {
710		t.Fatal(err)
711	}
712	defer deferFunc()
713
714	// The serviceName "delay" is specially handled at server side, where response will not be sent
715	// back to client immediately upon receiving the request (client should receive no response until
716	// test ends).
717	r.UpdateState(resolver.State{
718		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
719		ServiceConfig: parseCfg(r, `{
720	"healthCheckConfig": {
721		"serviceName": "delay"
722	}
723}`)})
724
725	select {
726	case <-hcExitChan:
727		t.Fatal("Health check function has exited, which is not expected.")
728	default:
729	}
730
731	select {
732	case <-hcEnterChan:
733	case <-time.After(5 * time.Second):
734		t.Fatal("Health check function has not been invoked after 5s.")
735	}
736	// trigger transport being closed
737	s.Stop()
738
739	// The health check func should exit without calling the setConnectivityState func, as server hasn't sent
740	// any response.
741	select {
742	case <-hcExitChan:
743	case <-time.After(5 * time.Second):
744		t.Fatal("Health check function has not exited after 5s.")
745	}
746	// The deferred leakcheck will check whether there's leaked goroutine, which is an indication
747	// whether we closes the allowedToReset channel to unblock onGoAway/onClose goroutine.
748}
749
750func testHealthCheckDisableWithDialOption(t *testing.T, addr string) {
751	hcEnterChan, _, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
752
753	cc, r, deferFunc, err := setupClient(&clientConfig{
754		balancerName:               "round_robin",
755		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
756		extraDialOption:            []grpc.DialOption{grpc.WithDisableHealthCheck()},
757	})
758	if err != nil {
759		t.Fatal(err)
760	}
761	defer deferFunc()
762
763	tc := testpb.NewTestServiceClient(cc)
764
765	r.UpdateState(resolver.State{
766		Addresses: []resolver.Address{{Addr: addr}},
767		ServiceConfig: parseCfg(r, `{
768	"healthCheckConfig": {
769		"serviceName": "foo"
770	}
771}`)})
772
773	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
774	defer cancel()
775	// send some rpcs to make sure transport has been created and is ready for use.
776	if err := verifyResultWithDelay(func() (bool, error) {
777		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
778			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
779		}
780		return true, nil
781	}); err != nil {
782		t.Fatal(err)
783	}
784
785	select {
786	case <-hcEnterChan:
787		t.Fatal("Health check function has exited, which is not expected.")
788	default:
789	}
790}
791
792func testHealthCheckDisableWithBalancer(t *testing.T, addr string) {
793	hcEnterChan, _, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
794
795	cc, r, deferFunc, err := setupClient(&clientConfig{
796		balancerName:               "pick_first",
797		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
798	})
799	if err != nil {
800		t.Fatal(err)
801	}
802	defer deferFunc()
803
804	tc := testpb.NewTestServiceClient(cc)
805
806	r.UpdateState(resolver.State{
807		Addresses: []resolver.Address{{Addr: addr}},
808		ServiceConfig: parseCfg(r, `{
809	"healthCheckConfig": {
810		"serviceName": "foo"
811	}
812}`)})
813
814	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
815	defer cancel()
816	// send some rpcs to make sure transport has been created and is ready for use.
817	if err := verifyResultWithDelay(func() (bool, error) {
818		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
819			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
820		}
821		return true, nil
822	}); err != nil {
823		t.Fatal(err)
824	}
825
826	select {
827	case <-hcEnterChan:
828		t.Fatal("Health check function has started, which is not expected.")
829	default:
830	}
831}
832
833func testHealthCheckDisableWithServiceConfig(t *testing.T, addr string) {
834	hcEnterChan, _, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
835
836	cc, r, deferFunc, err := setupClient(&clientConfig{
837		balancerName:               "round_robin",
838		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
839	})
840	if err != nil {
841		t.Fatal(err)
842	}
843	defer deferFunc()
844
845	tc := testpb.NewTestServiceClient(cc)
846
847	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: addr}}})
848
849	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
850	defer cancel()
851	// send some rpcs to make sure transport has been created and is ready for use.
852	if err := verifyResultWithDelay(func() (bool, error) {
853		if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
854			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
855		}
856		return true, nil
857	}); err != nil {
858		t.Fatal(err)
859	}
860
861	select {
862	case <-hcEnterChan:
863		t.Fatal("Health check function has started, which is not expected.")
864	default:
865	}
866}
867
868func (s) TestHealthCheckDisable(t *testing.T) {
869	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
870	defer deferFunc()
871	if err != nil {
872		t.Fatal(err)
873	}
874	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
875
876	// test client side disabling configuration.
877	testHealthCheckDisableWithDialOption(t, lis.Addr().String())
878	testHealthCheckDisableWithBalancer(t, lis.Addr().String())
879	testHealthCheckDisableWithServiceConfig(t, lis.Addr().String())
880}
881
882func (s) TestHealthCheckChannelzCountingCallSuccess(t *testing.T) {
883	_, lis, _, deferFunc, err := setupServer(&svrConfig{
884		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
885			if in.Service != "channelzSuccess" {
886				return status.Error(codes.FailedPrecondition,
887					"this special Watch function only handles request with service name to be \"channelzSuccess\"")
888			}
889			return status.Error(codes.OK, "fake success")
890		},
891	})
892	defer deferFunc()
893	if err != nil {
894		t.Fatal(err)
895	}
896
897	_, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
898	if err != nil {
899		t.Fatal(err)
900	}
901	defer deferFunc()
902
903	r.UpdateState(resolver.State{
904		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
905		ServiceConfig: parseCfg(r, `{
906	"healthCheckConfig": {
907		"serviceName": "channelzSuccess"
908	}
909}`)})
910
911	if err := verifyResultWithDelay(func() (bool, error) {
912		cm, _ := channelz.GetTopChannels(0, 0)
913		if len(cm) == 0 {
914			return false, errors.New("channelz.GetTopChannels return 0 top channel")
915		}
916		if len(cm[0].SubChans) == 0 {
917			return false, errors.New("there is 0 subchannel")
918		}
919		var id int64
920		for k := range cm[0].SubChans {
921			id = k
922			break
923		}
924		scm := channelz.GetSubChannel(id)
925		if scm == nil || scm.ChannelData == nil {
926			return false, errors.New("nil subchannel metric or nil subchannel metric ChannelData returned")
927		}
928		// exponential backoff retry may result in more than one health check call.
929		if scm.ChannelData.CallsStarted > 0 && scm.ChannelData.CallsSucceeded > 0 && scm.ChannelData.CallsFailed == 0 {
930			return true, nil
931		}
932		return false, fmt.Errorf("got %d CallsStarted, %d CallsSucceeded, want >0 >0", scm.ChannelData.CallsStarted, scm.ChannelData.CallsSucceeded)
933	}); err != nil {
934		t.Fatal(err)
935	}
936}
937
938func (s) TestHealthCheckChannelzCountingCallFailure(t *testing.T) {
939	_, lis, _, deferFunc, err := setupServer(&svrConfig{
940		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
941			if in.Service != "channelzFailure" {
942				return status.Error(codes.FailedPrecondition,
943					"this special Watch function only handles request with service name to be \"channelzFailure\"")
944			}
945			return status.Error(codes.Internal, "fake failure")
946		},
947	})
948	if err != nil {
949		t.Fatal(err)
950	}
951	defer deferFunc()
952
953	_, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
954	if err != nil {
955		t.Fatal(err)
956	}
957	defer deferFunc()
958
959	r.UpdateState(resolver.State{
960		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
961		ServiceConfig: parseCfg(r, `{
962	"healthCheckConfig": {
963		"serviceName": "channelzFailure"
964	}
965}`)})
966
967	if err := verifyResultWithDelay(func() (bool, error) {
968		cm, _ := channelz.GetTopChannels(0, 0)
969		if len(cm) == 0 {
970			return false, errors.New("channelz.GetTopChannels return 0 top channel")
971		}
972		if len(cm[0].SubChans) == 0 {
973			return false, errors.New("there is 0 subchannel")
974		}
975		var id int64
976		for k := range cm[0].SubChans {
977			id = k
978			break
979		}
980		scm := channelz.GetSubChannel(id)
981		if scm == nil || scm.ChannelData == nil {
982			return false, errors.New("nil subchannel metric or nil subchannel metric ChannelData returned")
983		}
984		// exponential backoff retry may result in more than one health check call.
985		if scm.ChannelData.CallsStarted > 0 && scm.ChannelData.CallsFailed > 0 && scm.ChannelData.CallsSucceeded == 0 {
986			return true, nil
987		}
988		return false, fmt.Errorf("got %d CallsStarted, %d CallsFailed, want >0, >0", scm.ChannelData.CallsStarted, scm.ChannelData.CallsFailed)
989	}); err != nil {
990		t.Fatal(err)
991	}
992}
993