1/*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package test
20
21import (
22	"context"
23	"errors"
24	"fmt"
25	"net"
26	"sync"
27	"testing"
28	"time"
29
30	"google.golang.org/grpc"
31	"google.golang.org/grpc/codes"
32	"google.golang.org/grpc/connectivity"
33	_ "google.golang.org/grpc/health"
34	healthgrpc "google.golang.org/grpc/health/grpc_health_v1"
35	healthpb "google.golang.org/grpc/health/grpc_health_v1"
36	"google.golang.org/grpc/internal"
37	"google.golang.org/grpc/internal/channelz"
38	"google.golang.org/grpc/resolver"
39	"google.golang.org/grpc/resolver/manual"
40	"google.golang.org/grpc/status"
41	testpb "google.golang.org/grpc/test/grpc_testing"
42)
43
// testHealthCheckFunc holds internal.HealthCheckFunc; the wrapper built by
// setupHealthCheckWrapper forwards every invocation to it.
var testHealthCheckFunc = internal.HealthCheckFunc
45
46func newTestHealthServer() *testHealthServer {
47	return newTestHealthServerWithWatchFunc(defaultWatchFunc)
48}
49
50func newTestHealthServerWithWatchFunc(f func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error) *testHealthServer {
51	return &testHealthServer{
52		watchFunc: f,
53		update:    make(chan struct{}, 1),
54		status:    make(map[string]healthpb.HealthCheckResponse_ServingStatus),
55	}
56}
57
58// defaultWatchFunc will send a HealthCheckResponse to the client whenever SetServingStatus is called.
59func defaultWatchFunc(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
60	if in.Service != "foo" {
61		return status.Error(codes.FailedPrecondition,
62			"the defaultWatchFunc only handles request with service name to be \"foo\"")
63	}
64	var done bool
65	for {
66		select {
67		case <-stream.Context().Done():
68			done = true
69		case <-s.update:
70		}
71		if done {
72			break
73		}
74		s.mu.Lock()
75		resp := &healthpb.HealthCheckResponse{
76			Status: s.status[in.Service],
77		}
78		s.mu.Unlock()
79		stream.SendMsg(resp)
80	}
81	return nil
82}
83
84type testHealthServer struct {
85	healthpb.UnimplementedHealthServer
86	watchFunc func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error
87	mu        sync.Mutex
88	status    map[string]healthpb.HealthCheckResponse_ServingStatus
89	update    chan struct{}
90}
91
92func (s *testHealthServer) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
93	return &healthpb.HealthCheckResponse{
94		Status: healthpb.HealthCheckResponse_SERVING,
95	}, nil
96}
97
98func (s *testHealthServer) Watch(in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
99	return s.watchFunc(s, in, stream)
100}
101
102// SetServingStatus is called when need to reset the serving status of a service
103// or insert a new service entry into the statusMap.
104func (s *testHealthServer) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) {
105	s.mu.Lock()
106	s.status[service] = status
107	select {
108	case <-s.update:
109	default:
110	}
111	s.update <- struct{}{}
112	s.mu.Unlock()
113}
114
115func setupHealthCheckWrapper() (hcEnterChan chan struct{}, hcExitChan chan struct{}, wrapper internal.HealthChecker) {
116	hcEnterChan = make(chan struct{})
117	hcExitChan = make(chan struct{})
118	wrapper = func(ctx context.Context, newStream func(string) (interface{}, error), update func(state connectivity.State), service string) error {
119		close(hcEnterChan)
120		defer close(hcExitChan)
121		return testHealthCheckFunc(ctx, newStream, update, service)
122	}
123	return
124}
125
// svrConfig carries the options accepted by setupServer.
type svrConfig struct {
	// specialWatchFunc, when non-nil, is used as the health server's Watch
	// implementation instead of defaultWatchFunc.
	specialWatchFunc func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error
}
129
130func setupServer(sc *svrConfig) (s *grpc.Server, lis net.Listener, ts *testHealthServer, deferFunc func(), err error) {
131	s = grpc.NewServer()
132	lis, err = net.Listen("tcp", "localhost:0")
133	if err != nil {
134		return nil, nil, nil, func() {}, fmt.Errorf("failed to listen due to err %v", err)
135	}
136	if sc.specialWatchFunc != nil {
137		ts = newTestHealthServerWithWatchFunc(sc.specialWatchFunc)
138	} else {
139		ts = newTestHealthServer()
140	}
141	healthgrpc.RegisterHealthServer(s, ts)
142	testpb.RegisterTestServiceServer(s, &testServer{})
143	go s.Serve(lis)
144	return s, lis, ts, s.Stop, nil
145}
146
// clientConfig carries the options accepted by setupClient.
type clientConfig struct {
	// balancerName is passed to grpc.WithBalancerName.
	balancerName string
	// testHealthCheckFuncWrapper, when non-nil, is installed as the client's
	// health check function through internal.WithHealthCheckFunc.
	testHealthCheckFuncWrapper internal.HealthChecker
	// extraDialOption is appended after the dial options setupClient builds.
	extraDialOption []grpc.DialOption
}
152
153func setupClient(c *clientConfig) (cc *grpc.ClientConn, r *manual.Resolver, deferFunc func(), err error) {
154	r, rcleanup := manual.GenerateAndRegisterManualResolver()
155	var opts []grpc.DialOption
156	opts = append(opts, grpc.WithInsecure(), grpc.WithBalancerName(c.balancerName))
157	if c.testHealthCheckFuncWrapper != nil {
158		opts = append(opts, internal.WithHealthCheckFunc.(func(internal.HealthChecker) grpc.DialOption)(c.testHealthCheckFuncWrapper))
159	}
160	opts = append(opts, c.extraDialOption...)
161	cc, err = grpc.Dial(r.Scheme()+":///test.server", opts...)
162	if err != nil {
163		rcleanup()
164		return nil, nil, nil, fmt.Errorf("dial failed due to err: %v", err)
165	}
166	return cc, r, func() { cc.Close(); rcleanup() }, nil
167}
168
169func (s) TestHealthCheckWatchStateChange(t *testing.T) {
170	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
171	defer deferFunc()
172	if err != nil {
173		t.Fatal(err)
174	}
175
176	// The table below shows the expected series of addrConn connectivity transitions when server
177	// updates its health status. As there's only one addrConn corresponds with the ClientConn in this
178	// test, we use ClientConn's connectivity state as the addrConn connectivity state.
179	//+------------------------------+-------------------------------------------+
180	//| Health Check Returned Status | Expected addrConn Connectivity Transition |
181	//+------------------------------+-------------------------------------------+
182	//| NOT_SERVING                  | ->TRANSIENT FAILURE                       |
183	//| SERVING                      | ->READY                                   |
184	//| SERVICE_UNKNOWN              | ->TRANSIENT FAILURE                       |
185	//| SERVING                      | ->READY                                   |
186	//| UNKNOWN                      | ->TRANSIENT FAILURE                       |
187	//+------------------------------+-------------------------------------------+
188	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_NOT_SERVING)
189
190	cc, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
191	if err != nil {
192		t.Fatal(err)
193	}
194	defer deferFunc()
195
196	r.UpdateState(resolver.State{
197		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
198		ServiceConfig: parseCfg(r, `{
199	"healthCheckConfig": {
200		"serviceName": "foo"
201	}
202}`)})
203	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
204	defer cancel()
205	if ok := cc.WaitForStateChange(ctx, connectivity.Idle); !ok {
206		t.Fatal("ClientConn is still in IDLE state when the context times out.")
207	}
208	if ok := cc.WaitForStateChange(ctx, connectivity.Connecting); !ok {
209		t.Fatal("ClientConn is still in CONNECTING state when the context times out.")
210	}
211	if s := cc.GetState(); s != connectivity.TransientFailure {
212		t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s)
213	}
214
215	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
216	if ok := cc.WaitForStateChange(ctx, connectivity.TransientFailure); !ok {
217		t.Fatal("ClientConn is still in TRANSIENT FAILURE state when the context times out.")
218	}
219	if s := cc.GetState(); s != connectivity.Ready {
220		t.Fatalf("ClientConn is in %v state, want READY", s)
221	}
222
223	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVICE_UNKNOWN)
224	if ok := cc.WaitForStateChange(ctx, connectivity.Ready); !ok {
225		t.Fatal("ClientConn is still in READY state when the context times out.")
226	}
227	if s := cc.GetState(); s != connectivity.TransientFailure {
228		t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s)
229	}
230
231	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
232	if ok := cc.WaitForStateChange(ctx, connectivity.TransientFailure); !ok {
233		t.Fatal("ClientConn is still in TRANSIENT FAILURE state when the context times out.")
234	}
235	if s := cc.GetState(); s != connectivity.Ready {
236		t.Fatalf("ClientConn is in %v state, want READY", s)
237	}
238
239	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_UNKNOWN)
240	if ok := cc.WaitForStateChange(ctx, connectivity.Ready); !ok {
241		t.Fatal("ClientConn is still in READY state when the context times out.")
242	}
243	if s := cc.GetState(); s != connectivity.TransientFailure {
244		t.Fatalf("ClientConn is in %v state, want TRANSIENT FAILURE", s)
245	}
246}
247
248// If Watch returns Unimplemented, then the ClientConn should go into READY state.
249func (s) TestHealthCheckHealthServerNotRegistered(t *testing.T) {
250	s := grpc.NewServer()
251	lis, err := net.Listen("tcp", "localhost:0")
252	if err != nil {
253		t.Fatalf("failed to listen due to err: %v", err)
254	}
255	go s.Serve(lis)
256	defer s.Stop()
257
258	cc, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
259	if err != nil {
260		t.Fatal(err)
261	}
262	defer deferFunc()
263
264	r.UpdateState(resolver.State{
265		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
266		ServiceConfig: parseCfg(r, `{
267	"healthCheckConfig": {
268		"serviceName": "foo"
269	}
270}`)})
271	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
272	defer cancel()
273
274	if ok := cc.WaitForStateChange(ctx, connectivity.Idle); !ok {
275		t.Fatal("ClientConn is still in IDLE state when the context times out.")
276	}
277	if ok := cc.WaitForStateChange(ctx, connectivity.Connecting); !ok {
278		t.Fatal("ClientConn is still in CONNECTING state when the context times out.")
279	}
280	if s := cc.GetState(); s != connectivity.Ready {
281		t.Fatalf("ClientConn is in %v state, want READY", s)
282	}
283}
284
285// In the case of a goaway received, the health check stream should be terminated and health check
286// function should exit.
287func (s) TestHealthCheckWithGoAway(t *testing.T) {
288	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
289
290	s, lis, ts, deferFunc, err := setupServer(&svrConfig{})
291	defer deferFunc()
292	if err != nil {
293		t.Fatal(err)
294	}
295
296	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
297
298	cc, r, deferFunc, err := setupClient(&clientConfig{
299		balancerName:               "round_robin",
300		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
301	})
302	if err != nil {
303		t.Fatal(err)
304	}
305	defer deferFunc()
306
307	tc := testpb.NewTestServiceClient(cc)
308	r.UpdateState(resolver.State{
309		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
310		ServiceConfig: parseCfg(r, `{
311	"healthCheckConfig": {
312		"serviceName": "foo"
313	}
314}`)})
315
316	// make some rpcs to make sure connection is working.
317	if err := verifyResultWithDelay(func() (bool, error) {
318		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
319			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
320		}
321		return true, nil
322	}); err != nil {
323		t.Fatal(err)
324	}
325
326	// the stream rpc will persist through goaway event.
327	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
328	defer cancel()
329	stream, err := tc.FullDuplexCall(ctx, grpc.WaitForReady(true))
330	if err != nil {
331		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
332	}
333	respParam := []*testpb.ResponseParameters{{Size: 1}}
334	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
335	if err != nil {
336		t.Fatal(err)
337	}
338	req := &testpb.StreamingOutputCallRequest{
339		ResponseParameters: respParam,
340		Payload:            payload,
341	}
342	if err := stream.Send(req); err != nil {
343		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
344	}
345	if _, err := stream.Recv(); err != nil {
346		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
347	}
348
349	select {
350	case <-hcExitChan:
351		t.Fatal("Health check function has exited, which is not expected.")
352	default:
353	}
354
355	// server sends GoAway
356	go s.GracefulStop()
357
358	select {
359	case <-hcExitChan:
360	case <-time.After(5 * time.Second):
361		select {
362		case <-hcEnterChan:
363		default:
364			t.Fatal("Health check function has not entered after 5s.")
365		}
366		t.Fatal("Health check function has not exited after 5s.")
367	}
368
369	// The existing RPC should be still good to proceed.
370	if err := stream.Send(req); err != nil {
371		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
372	}
373	if _, err := stream.Recv(); err != nil {
374		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
375	}
376}
377
378func (s) TestHealthCheckWithConnClose(t *testing.T) {
379	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
380
381	s, lis, ts, deferFunc, err := setupServer(&svrConfig{})
382	defer deferFunc()
383	if err != nil {
384		t.Fatal(err)
385	}
386
387	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
388
389	cc, r, deferFunc, err := setupClient(&clientConfig{
390		balancerName:               "round_robin",
391		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
392	})
393	if err != nil {
394		t.Fatal(err)
395	}
396	defer deferFunc()
397
398	tc := testpb.NewTestServiceClient(cc)
399
400	r.UpdateState(resolver.State{
401		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
402		ServiceConfig: parseCfg(r, `{
403	"healthCheckConfig": {
404		"serviceName": "foo"
405	}
406}`)})
407
408	// make some rpcs to make sure connection is working.
409	if err := verifyResultWithDelay(func() (bool, error) {
410		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
411			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
412		}
413		return true, nil
414	}); err != nil {
415		t.Fatal(err)
416	}
417
418	select {
419	case <-hcExitChan:
420		t.Fatal("Health check function has exited, which is not expected.")
421	default:
422	}
423	// server closes the connection
424	s.Stop()
425
426	select {
427	case <-hcExitChan:
428	case <-time.After(5 * time.Second):
429		select {
430		case <-hcEnterChan:
431		default:
432			t.Fatal("Health check function has not entered after 5s.")
433		}
434		t.Fatal("Health check function has not exited after 5s.")
435	}
436}
437
438// addrConn drain happens when addrConn gets torn down due to its address being no longer in the
439// address list returned by the resolver.
440func (s) TestHealthCheckWithAddrConnDrain(t *testing.T) {
441	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
442
443	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
444	defer deferFunc()
445	if err != nil {
446		t.Fatal(err)
447	}
448
449	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
450
451	cc, r, deferFunc, err := setupClient(&clientConfig{
452		balancerName:               "round_robin",
453		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
454	})
455	if err != nil {
456		t.Fatal(err)
457	}
458	defer deferFunc()
459
460	tc := testpb.NewTestServiceClient(cc)
461	sc := parseCfg(r, `{
462	"healthCheckConfig": {
463		"serviceName": "foo"
464	}
465}`)
466	r.UpdateState(resolver.State{
467		Addresses:     []resolver.Address{{Addr: lis.Addr().String()}},
468		ServiceConfig: sc,
469	})
470
471	// make some rpcs to make sure connection is working.
472	if err := verifyResultWithDelay(func() (bool, error) {
473		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
474			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
475		}
476		return true, nil
477	}); err != nil {
478		t.Fatal(err)
479	}
480
481	// the stream rpc will persist through goaway event.
482	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
483	defer cancel()
484	stream, err := tc.FullDuplexCall(ctx, grpc.WaitForReady(true))
485	if err != nil {
486		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
487	}
488	respParam := []*testpb.ResponseParameters{{Size: 1}}
489	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(1))
490	if err != nil {
491		t.Fatal(err)
492	}
493	req := &testpb.StreamingOutputCallRequest{
494		ResponseParameters: respParam,
495		Payload:            payload,
496	}
497	if err := stream.Send(req); err != nil {
498		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
499	}
500	if _, err := stream.Recv(); err != nil {
501		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
502	}
503
504	select {
505	case <-hcExitChan:
506		t.Fatal("Health check function has exited, which is not expected.")
507	default:
508	}
509	// trigger teardown of the ac
510	r.UpdateState(resolver.State{Addresses: []resolver.Address{}, ServiceConfig: sc})
511
512	select {
513	case <-hcExitChan:
514	case <-time.After(5 * time.Second):
515		select {
516		case <-hcEnterChan:
517		default:
518			t.Fatal("Health check function has not entered after 5s.")
519		}
520		t.Fatal("Health check function has not exited after 5s.")
521	}
522
523	// The existing RPC should be still good to proceed.
524	if err := stream.Send(req); err != nil {
525		t.Fatalf("%v.Send(_) = %v, want <nil>", stream, err)
526	}
527	if _, err := stream.Recv(); err != nil {
528		t.Fatalf("%v.Recv() = _, %v, want _, <nil>", stream, err)
529	}
530}
531
532// ClientConn close will lead to its addrConns being torn down.
533func (s) TestHealthCheckWithClientConnClose(t *testing.T) {
534	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
535
536	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
537	defer deferFunc()
538	if err != nil {
539		t.Fatal(err)
540	}
541
542	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
543
544	cc, r, deferFunc, err := setupClient(&clientConfig{
545		balancerName:               "round_robin",
546		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
547	})
548	if err != nil {
549		t.Fatal(err)
550	}
551	defer deferFunc()
552
553	tc := testpb.NewTestServiceClient(cc)
554	r.UpdateState(resolver.State{
555		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
556		ServiceConfig: parseCfg(r, `{
557	"healthCheckConfig": {
558		"serviceName": "foo"
559	}
560}`)})
561
562	// make some rpcs to make sure connection is working.
563	if err := verifyResultWithDelay(func() (bool, error) {
564		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
565			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
566		}
567		return true, nil
568	}); err != nil {
569		t.Fatal(err)
570	}
571
572	select {
573	case <-hcExitChan:
574		t.Fatal("Health check function has exited, which is not expected.")
575	default:
576	}
577
578	// trigger addrConn teardown
579	cc.Close()
580
581	select {
582	case <-hcExitChan:
583	case <-time.After(5 * time.Second):
584		select {
585		case <-hcEnterChan:
586		default:
587			t.Fatal("Health check function has not entered after 5s.")
588		}
589		t.Fatal("Health check function has not exited after 5s.")
590	}
591}
592
593// This test is to test the logic in the createTransport after the health check function returns which
594// closes the skipReset channel(since it has not been closed inside health check func) to unblock
595// onGoAway/onClose goroutine.
596func (s) TestHealthCheckWithoutSetConnectivityStateCalledAddrConnShutDown(t *testing.T) {
597	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
598
599	_, lis, ts, deferFunc, err := setupServer(&svrConfig{
600		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
601			if in.Service != "delay" {
602				return status.Error(codes.FailedPrecondition,
603					"this special Watch function only handles request with service name to be \"delay\"")
604			}
605			// Do nothing to mock a delay of health check response from server side.
606			// This case is to help with the test that covers the condition that setConnectivityState is not
607			// called inside HealthCheckFunc before the func returns.
608			select {
609			case <-stream.Context().Done():
610			case <-time.After(5 * time.Second):
611			}
612			return nil
613		},
614	})
615	defer deferFunc()
616	if err != nil {
617		t.Fatal(err)
618	}
619
620	ts.SetServingStatus("delay", healthpb.HealthCheckResponse_SERVING)
621
622	_, r, deferFunc, err := setupClient(&clientConfig{
623		balancerName:               "round_robin",
624		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
625	})
626	if err != nil {
627		t.Fatal(err)
628	}
629	defer deferFunc()
630
631	// The serviceName "delay" is specially handled at server side, where response will not be sent
632	// back to client immediately upon receiving the request (client should receive no response until
633	// test ends).
634	sc := parseCfg(r, `{
635	"healthCheckConfig": {
636		"serviceName": "delay"
637	}
638}`)
639	r.UpdateState(resolver.State{
640		Addresses:     []resolver.Address{{Addr: lis.Addr().String()}},
641		ServiceConfig: sc,
642	})
643
644	select {
645	case <-hcExitChan:
646		t.Fatal("Health check function has exited, which is not expected.")
647	default:
648	}
649
650	select {
651	case <-hcEnterChan:
652	case <-time.After(5 * time.Second):
653		t.Fatal("Health check function has not been invoked after 5s.")
654	}
655	// trigger teardown of the ac, ac in SHUTDOWN state
656	r.UpdateState(resolver.State{Addresses: []resolver.Address{}, ServiceConfig: sc})
657
658	// The health check func should exit without calling the setConnectivityState func, as server hasn't sent
659	// any response.
660	select {
661	case <-hcExitChan:
662	case <-time.After(5 * time.Second):
663		t.Fatal("Health check function has not exited after 5s.")
664	}
665	// The deferred leakcheck will check whether there's leaked goroutine, which is an indication
666	// whether we closes the skipReset channel to unblock onGoAway/onClose goroutine.
667}
668
669// This test is to test the logic in the createTransport after the health check function returns which
670// closes the allowedToReset channel(since it has not been closed inside health check func) to unblock
671// onGoAway/onClose goroutine.
672func (s) TestHealthCheckWithoutSetConnectivityStateCalled(t *testing.T) {
673	hcEnterChan, hcExitChan, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
674
675	s, lis, ts, deferFunc, err := setupServer(&svrConfig{
676		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
677			if in.Service != "delay" {
678				return status.Error(codes.FailedPrecondition,
679					"this special Watch function only handles request with service name to be \"delay\"")
680			}
681			// Do nothing to mock a delay of health check response from server side.
682			// This case is to help with the test that covers the condition that setConnectivityState is not
683			// called inside HealthCheckFunc before the func returns.
684			select {
685			case <-stream.Context().Done():
686			case <-time.After(5 * time.Second):
687			}
688			return nil
689		},
690	})
691	defer deferFunc()
692	if err != nil {
693		t.Fatal(err)
694	}
695
696	ts.SetServingStatus("delay", healthpb.HealthCheckResponse_SERVING)
697
698	_, r, deferFunc, err := setupClient(&clientConfig{
699		balancerName:               "round_robin",
700		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
701	})
702	if err != nil {
703		t.Fatal(err)
704	}
705	defer deferFunc()
706
707	// The serviceName "delay" is specially handled at server side, where response will not be sent
708	// back to client immediately upon receiving the request (client should receive no response until
709	// test ends).
710	r.UpdateState(resolver.State{
711		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
712		ServiceConfig: parseCfg(r, `{
713	"healthCheckConfig": {
714		"serviceName": "delay"
715	}
716}`)})
717
718	select {
719	case <-hcExitChan:
720		t.Fatal("Health check function has exited, which is not expected.")
721	default:
722	}
723
724	select {
725	case <-hcEnterChan:
726	case <-time.After(5 * time.Second):
727		t.Fatal("Health check function has not been invoked after 5s.")
728	}
729	// trigger transport being closed
730	s.Stop()
731
732	// The health check func should exit without calling the setConnectivityState func, as server hasn't sent
733	// any response.
734	select {
735	case <-hcExitChan:
736	case <-time.After(5 * time.Second):
737		t.Fatal("Health check function has not exited after 5s.")
738	}
739	// The deferred leakcheck will check whether there's leaked goroutine, which is an indication
740	// whether we closes the allowedToReset channel to unblock onGoAway/onClose goroutine.
741}
742
743func testHealthCheckDisableWithDialOption(t *testing.T, addr string) {
744	hcEnterChan, _, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
745
746	cc, r, deferFunc, err := setupClient(&clientConfig{
747		balancerName:               "round_robin",
748		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
749		extraDialOption:            []grpc.DialOption{grpc.WithDisableHealthCheck()},
750	})
751	if err != nil {
752		t.Fatal(err)
753	}
754	defer deferFunc()
755
756	tc := testpb.NewTestServiceClient(cc)
757
758	r.UpdateState(resolver.State{
759		Addresses: []resolver.Address{{Addr: addr}},
760		ServiceConfig: parseCfg(r, `{
761	"healthCheckConfig": {
762		"serviceName": "foo"
763	}
764}`)})
765
766	// send some rpcs to make sure transport has been created and is ready for use.
767	if err := verifyResultWithDelay(func() (bool, error) {
768		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
769			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
770		}
771		return true, nil
772	}); err != nil {
773		t.Fatal(err)
774	}
775
776	select {
777	case <-hcEnterChan:
778		t.Fatal("Health check function has exited, which is not expected.")
779	default:
780	}
781}
782
783func testHealthCheckDisableWithBalancer(t *testing.T, addr string) {
784	hcEnterChan, _, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
785
786	cc, r, deferFunc, err := setupClient(&clientConfig{
787		balancerName:               "pick_first",
788		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
789	})
790	if err != nil {
791		t.Fatal(err)
792	}
793	defer deferFunc()
794
795	tc := testpb.NewTestServiceClient(cc)
796
797	r.UpdateState(resolver.State{
798		Addresses: []resolver.Address{{Addr: addr}},
799		ServiceConfig: parseCfg(r, `{
800	"healthCheckConfig": {
801		"serviceName": "foo"
802	}
803}`)})
804
805	// send some rpcs to make sure transport has been created and is ready for use.
806	if err := verifyResultWithDelay(func() (bool, error) {
807		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
808			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
809		}
810		return true, nil
811	}); err != nil {
812		t.Fatal(err)
813	}
814
815	select {
816	case <-hcEnterChan:
817		t.Fatal("Health check function has started, which is not expected.")
818	default:
819	}
820}
821
822func testHealthCheckDisableWithServiceConfig(t *testing.T, addr string) {
823	hcEnterChan, _, testHealthCheckFuncWrapper := setupHealthCheckWrapper()
824
825	cc, r, deferFunc, err := setupClient(&clientConfig{
826		balancerName:               "round_robin",
827		testHealthCheckFuncWrapper: testHealthCheckFuncWrapper,
828	})
829	if err != nil {
830		t.Fatal(err)
831	}
832	defer deferFunc()
833
834	tc := testpb.NewTestServiceClient(cc)
835
836	r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: addr}}})
837
838	// send some rpcs to make sure transport has been created and is ready for use.
839	if err := verifyResultWithDelay(func() (bool, error) {
840		if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
841			return false, fmt.Errorf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
842		}
843		return true, nil
844	}); err != nil {
845		t.Fatal(err)
846	}
847
848	select {
849	case <-hcEnterChan:
850		t.Fatal("Health check function has started, which is not expected.")
851	default:
852	}
853}
854
855func (s) TestHealthCheckDisable(t *testing.T) {
856	_, lis, ts, deferFunc, err := setupServer(&svrConfig{})
857	defer deferFunc()
858	if err != nil {
859		t.Fatal(err)
860	}
861	ts.SetServingStatus("foo", healthpb.HealthCheckResponse_SERVING)
862
863	// test client side disabling configuration.
864	testHealthCheckDisableWithDialOption(t, lis.Addr().String())
865	testHealthCheckDisableWithBalancer(t, lis.Addr().String())
866	testHealthCheckDisableWithServiceConfig(t, lis.Addr().String())
867}
868
869func (s) TestHealthCheckChannelzCountingCallSuccess(t *testing.T) {
870	_, lis, _, deferFunc, err := setupServer(&svrConfig{
871		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
872			if in.Service != "channelzSuccess" {
873				return status.Error(codes.FailedPrecondition,
874					"this special Watch function only handles request with service name to be \"channelzSuccess\"")
875			}
876			return status.Error(codes.OK, "fake success")
877		},
878	})
879	defer deferFunc()
880	if err != nil {
881		t.Fatal(err)
882	}
883
884	_, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
885	if err != nil {
886		t.Fatal(err)
887	}
888	defer deferFunc()
889
890	r.UpdateState(resolver.State{
891		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
892		ServiceConfig: parseCfg(r, `{
893	"healthCheckConfig": {
894		"serviceName": "channelzSuccess"
895	}
896}`)})
897
898	if err := verifyResultWithDelay(func() (bool, error) {
899		cm, _ := channelz.GetTopChannels(0, 0)
900		if len(cm) == 0 {
901			return false, errors.New("channelz.GetTopChannels return 0 top channel")
902		}
903		if len(cm[0].SubChans) == 0 {
904			return false, errors.New("there is 0 subchannel")
905		}
906		var id int64
907		for k := range cm[0].SubChans {
908			id = k
909			break
910		}
911		scm := channelz.GetSubChannel(id)
912		if scm == nil || scm.ChannelData == nil {
913			return false, errors.New("nil subchannel metric or nil subchannel metric ChannelData returned")
914		}
915		// exponential backoff retry may result in more than one health check call.
916		if scm.ChannelData.CallsStarted > 0 && scm.ChannelData.CallsSucceeded > 0 && scm.ChannelData.CallsFailed == 0 {
917			return true, nil
918		}
919		return false, fmt.Errorf("got %d CallsStarted, %d CallsSucceeded, want >0 >0", scm.ChannelData.CallsStarted, scm.ChannelData.CallsSucceeded)
920	}); err != nil {
921		t.Fatal(err)
922	}
923}
924
925func (s) TestHealthCheckChannelzCountingCallFailure(t *testing.T) {
926	_, lis, _, deferFunc, err := setupServer(&svrConfig{
927		specialWatchFunc: func(s *testHealthServer, in *healthpb.HealthCheckRequest, stream healthgrpc.Health_WatchServer) error {
928			if in.Service != "channelzFailure" {
929				return status.Error(codes.FailedPrecondition,
930					"this special Watch function only handles request with service name to be \"channelzFailure\"")
931			}
932			return status.Error(codes.Internal, "fake failure")
933		},
934	})
935	if err != nil {
936		t.Fatal(err)
937	}
938	defer deferFunc()
939
940	_, r, deferFunc, err := setupClient(&clientConfig{balancerName: "round_robin"})
941	if err != nil {
942		t.Fatal(err)
943	}
944	defer deferFunc()
945
946	r.UpdateState(resolver.State{
947		Addresses: []resolver.Address{{Addr: lis.Addr().String()}},
948		ServiceConfig: parseCfg(r, `{
949	"healthCheckConfig": {
950		"serviceName": "channelzFailure"
951	}
952}`)})
953
954	if err := verifyResultWithDelay(func() (bool, error) {
955		cm, _ := channelz.GetTopChannels(0, 0)
956		if len(cm) == 0 {
957			return false, errors.New("channelz.GetTopChannels return 0 top channel")
958		}
959		if len(cm[0].SubChans) == 0 {
960			return false, errors.New("there is 0 subchannel")
961		}
962		var id int64
963		for k := range cm[0].SubChans {
964			id = k
965			break
966		}
967		scm := channelz.GetSubChannel(id)
968		if scm == nil || scm.ChannelData == nil {
969			return false, errors.New("nil subchannel metric or nil subchannel metric ChannelData returned")
970		}
971		// exponential backoff retry may result in more than one health check call.
972		if scm.ChannelData.CallsStarted > 0 && scm.ChannelData.CallsFailed > 0 && scm.ChannelData.CallsSucceeded == 0 {
973			return true, nil
974		}
975		return false, fmt.Errorf("got %d CallsStarted, %d CallsFailed, want >0, >0", scm.ChannelData.CallsStarted, scm.ChannelData.CallsFailed)
976	}); err != nil {
977		t.Fatal(err)
978	}
979}
980