1// +build go1.12
2
3/*
4 *
5 * Copyright 2019 gRPC authors.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *     http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 */
20
21package xds
22
23import (
24	"context"
25	"errors"
26	"io"
27	"net"
28	"testing"
29	"time"
30
31	"github.com/golang/protobuf/proto"
32	anypb "github.com/golang/protobuf/ptypes/any"
33	durationpb "github.com/golang/protobuf/ptypes/duration"
34	structpb "github.com/golang/protobuf/ptypes/struct"
35	wrpb "github.com/golang/protobuf/ptypes/wrappers"
36	"google.golang.org/grpc"
37	"google.golang.org/grpc/balancer"
38	"google.golang.org/grpc/balancer/xds/internal"
39	cdspb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/cds"
40	addresspb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/core/address"
41	basepb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/core/base"
42	discoverypb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/discovery"
43	edspb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/eds"
44	endpointpb "google.golang.org/grpc/balancer/xds/internal/proto/envoy/api/v2/endpoint/endpoint"
45	adsgrpc "google.golang.org/grpc/balancer/xds/internal/proto/envoy/service/discovery/v2/ads"
46	lrsgrpc "google.golang.org/grpc/balancer/xds/internal/proto/envoy/service/load_stats/v2/lrs"
47	"google.golang.org/grpc/codes"
48	"google.golang.org/grpc/resolver"
49	"google.golang.org/grpc/status"
50)
51
52var (
53	testServiceName = "test/foo"
54	testCDSReq      = &discoverypb.DiscoveryRequest{
55		Node: &basepb.Node{
56			Metadata: &structpb.Struct{
57				Fields: map[string]*structpb.Value{
58					internal.GrpcHostname: {
59						Kind: &structpb.Value_StringValue{StringValue: testServiceName},
60					},
61				},
62			},
63		},
64		TypeUrl: cdsType,
65	}
66	testEDSReq = &discoverypb.DiscoveryRequest{
67		Node: &basepb.Node{
68			Metadata: &structpb.Struct{
69				Fields: map[string]*structpb.Value{
70					endpointRequired: {
71						Kind: &structpb.Value_BoolValue{BoolValue: true},
72					},
73				},
74			},
75		},
76		ResourceNames: []string{testServiceName},
77		TypeUrl:       edsType,
78	}
79	testEDSReqWithoutEndpoints = &discoverypb.DiscoveryRequest{
80		Node: &basepb.Node{
81			Metadata: &structpb.Struct{
82				Fields: map[string]*structpb.Value{
83					endpointRequired: {
84						Kind: &structpb.Value_BoolValue{BoolValue: false},
85					},
86				},
87			},
88		},
89		ResourceNames: []string{testServiceName},
90		TypeUrl:       edsType,
91	}
92	testCluster = &cdspb.Cluster{
93		Name:                 testServiceName,
94		ClusterDiscoveryType: &cdspb.Cluster_Type{Type: cdspb.Cluster_EDS},
95		LbPolicy:             cdspb.Cluster_ROUND_ROBIN,
96	}
97	marshaledCluster, _ = proto.Marshal(testCluster)
98	testCDSResp         = &discoverypb.DiscoveryResponse{
99		Resources: []*anypb.Any{
100			{
101				TypeUrl: cdsType,
102				Value:   marshaledCluster,
103			},
104		},
105		TypeUrl: cdsType,
106	}
107	testClusterLoadAssignment = &edspb.ClusterLoadAssignment{
108		ClusterName: testServiceName,
109		Endpoints: []*endpointpb.LocalityLbEndpoints{
110			{
111				Locality: &basepb.Locality{
112					Region:  "asia-east1",
113					Zone:    "1",
114					SubZone: "sa",
115				},
116				LbEndpoints: []*endpointpb.LbEndpoint{
117					{
118						HostIdentifier: &endpointpb.LbEndpoint_Endpoint{
119							Endpoint: &endpointpb.Endpoint{
120								Address: &addresspb.Address{
121									Address: &addresspb.Address_SocketAddress{
122										SocketAddress: &addresspb.SocketAddress{
123											Address: "1.1.1.1",
124											PortSpecifier: &addresspb.SocketAddress_PortValue{
125												PortValue: 10001,
126											},
127											ResolverName: "dns",
128										},
129									},
130								},
131								HealthCheckConfig: nil,
132							},
133						},
134						Metadata: &basepb.Metadata{
135							FilterMetadata: map[string]*structpb.Struct{
136								"xx.lb": {
137									Fields: map[string]*structpb.Value{
138										"endpoint_name": {
139											Kind: &structpb.Value_StringValue{
140												StringValue: "some.endpoint.name",
141											},
142										},
143									},
144								},
145							},
146						},
147					},
148				},
149				LoadBalancingWeight: &wrpb.UInt32Value{
150					Value: 1,
151				},
152				Priority: 0,
153			},
154		},
155	}
156	marshaledClusterLoadAssignment, _ = proto.Marshal(testClusterLoadAssignment)
157	testEDSResp                       = &discoverypb.DiscoveryResponse{
158		Resources: []*anypb.Any{
159			{
160				TypeUrl: edsType,
161				Value:   marshaledClusterLoadAssignment,
162			},
163		},
164		TypeUrl: edsType,
165	}
166	testClusterLoadAssignmentWithoutEndpoints = &edspb.ClusterLoadAssignment{
167		ClusterName: testServiceName,
168		Endpoints: []*endpointpb.LocalityLbEndpoints{
169			{
170				Locality: &basepb.Locality{
171					SubZone: "sa",
172				},
173				LoadBalancingWeight: &wrpb.UInt32Value{
174					Value: 128,
175				},
176				Priority: 0,
177			},
178		},
179		Policy: nil,
180	}
181	marshaledClusterLoadAssignmentWithoutEndpoints, _ = proto.Marshal(testClusterLoadAssignmentWithoutEndpoints)
182	testEDSRespWithoutEndpoints                       = &discoverypb.DiscoveryResponse{
183		Resources: []*anypb.Any{
184			{
185				TypeUrl: edsType,
186				Value:   marshaledClusterLoadAssignmentWithoutEndpoints,
187			},
188		},
189		TypeUrl: edsType,
190	}
191)
192
193type testTrafficDirector struct {
194	reqChan  chan *request
195	respChan chan *response
196}
197
198type request struct {
199	req *discoverypb.DiscoveryRequest
200	err error
201}
202
203type response struct {
204	resp *discoverypb.DiscoveryResponse
205	err  error
206}
207
208func (ttd *testTrafficDirector) StreamAggregatedResources(s adsgrpc.AggregatedDiscoveryService_StreamAggregatedResourcesServer) error {
209	for {
210		req, err := s.Recv()
211		if err != nil {
212			ttd.reqChan <- &request{
213				req: nil,
214				err: err,
215			}
216			if err == io.EOF {
217				return nil
218			}
219			return err
220		}
221		ttd.reqChan <- &request{
222			req: req,
223			err: nil,
224		}
225		if req.TypeUrl == edsType {
226			break
227		}
228	}
229
230	for {
231		select {
232		case resp := <-ttd.respChan:
233			if resp.err != nil {
234				return resp.err
235			}
236			if err := s.Send(resp.resp); err != nil {
237				return err
238			}
239		case <-s.Context().Done():
240			return s.Context().Err()
241		}
242	}
243}
244
245func (ttd *testTrafficDirector) DeltaAggregatedResources(adsgrpc.AggregatedDiscoveryService_DeltaAggregatedResourcesServer) error {
246	return status.Error(codes.Unimplemented, "")
247}
248
249func (ttd *testTrafficDirector) sendResp(resp *response) {
250	ttd.respChan <- resp
251}
252
253func (ttd *testTrafficDirector) getReq() *request {
254	return <-ttd.reqChan
255}
256
257func newTestTrafficDirector() *testTrafficDirector {
258	return &testTrafficDirector{
259		reqChan:  make(chan *request, 10),
260		respChan: make(chan *response, 10),
261	}
262}
263
264type testConfig struct {
265	doCDS                bool
266	expectedRequests     []*discoverypb.DiscoveryRequest
267	responsesToSend      []*discoverypb.DiscoveryResponse
268	expectedADSResponses []proto.Message
269	adsErr               error
270	svrErr               error
271}
272
273func setupServer(t *testing.T) (addr string, td *testTrafficDirector, lrss *lrsServer, cleanup func()) {
274	lis, err := net.Listen("tcp", "localhost:0")
275	if err != nil {
276		t.Fatalf("listen failed due to: %v", err)
277	}
278	svr := grpc.NewServer()
279	td = newTestTrafficDirector()
280	lrss = &lrsServer{
281		drops: make(map[string]uint64),
282		reportingInterval: &durationpb.Duration{
283			Seconds: 60 * 60, // 1 hour, each test can override this to a shorter duration.
284			Nanos:   0,
285		},
286	}
287	adsgrpc.RegisterAggregatedDiscoveryServiceServer(svr, td)
288	lrsgrpc.RegisterLoadReportingServiceServer(svr, lrss)
289	go svr.Serve(lis)
290	return lis.Addr().String(), td, lrss, func() {
291		svr.Stop()
292		lis.Close()
293	}
294}
295
296func (s) TestXdsClientResponseHandling(t *testing.T) {
297	for _, test := range []*testConfig{
298		{
299			doCDS:                true,
300			expectedRequests:     []*discoverypb.DiscoveryRequest{testCDSReq, testEDSReq},
301			responsesToSend:      []*discoverypb.DiscoveryResponse{testCDSResp, testEDSResp},
302			expectedADSResponses: []proto.Message{testCluster, testClusterLoadAssignment},
303		},
304		{
305			doCDS:                false,
306			expectedRequests:     []*discoverypb.DiscoveryRequest{testEDSReqWithoutEndpoints},
307			responsesToSend:      []*discoverypb.DiscoveryResponse{testEDSRespWithoutEndpoints},
308			expectedADSResponses: []proto.Message{testClusterLoadAssignmentWithoutEndpoints},
309		},
310	} {
311		testXdsClientResponseHandling(t, test)
312	}
313}
314
315func testXdsClientResponseHandling(t *testing.T, test *testConfig) {
316	addr, td, _, cleanup := setupServer(t)
317	defer cleanup()
318	adsChan := make(chan proto.Message, 10)
319	newADS := func(ctx context.Context, i proto.Message) error {
320		adsChan <- i
321		return nil
322	}
323	client := newXDSClient(addr, test.doCDS, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil, newADS, func(context.Context) {}, func() {})
324	defer client.close()
325	go client.run()
326
327	for _, expectedReq := range test.expectedRequests {
328		req := td.getReq()
329		if req.err != nil {
330			t.Fatalf("ads RPC failed with err: %v", req.err)
331		}
332		if !proto.Equal(req.req, expectedReq) {
333			t.Fatalf("got ADS request %T %v, expected: %T %v", req.req, req.req, expectedReq, expectedReq)
334		}
335	}
336
337	for i, resp := range test.responsesToSend {
338		td.sendResp(&response{resp: resp})
339		ads := <-adsChan
340		if !proto.Equal(ads, test.expectedADSResponses[i]) {
341			t.Fatalf("received unexpected ads response, got %v, want %v", ads, test.expectedADSResponses[i])
342		}
343	}
344}
345
346func (s) TestXdsClientLoseContact(t *testing.T) {
347	for _, test := range []*testConfig{
348		{
349			doCDS:           true,
350			responsesToSend: []*discoverypb.DiscoveryResponse{},
351		},
352		{
353			doCDS:           false,
354			responsesToSend: []*discoverypb.DiscoveryResponse{testEDSRespWithoutEndpoints},
355		},
356	} {
357		testXdsClientLoseContactRemoteClose(t, test)
358	}
359
360	for _, test := range []*testConfig{
361		{
362			doCDS:           false,
363			responsesToSend: []*discoverypb.DiscoveryResponse{testCDSResp}, // CDS response when in custom mode.
364		},
365		{
366			doCDS:           true,
367			responsesToSend: []*discoverypb.DiscoveryResponse{{}}, // response with 0 resources is an error case.
368		},
369		{
370			doCDS:           true,
371			responsesToSend: []*discoverypb.DiscoveryResponse{testCDSResp},
372			adsErr:          errors.New("some ads parsing error from xdsBalancer"),
373		},
374	} {
375		testXdsClientLoseContactADSRelatedErrorOccur(t, test)
376	}
377}
378
379func testXdsClientLoseContactRemoteClose(t *testing.T, test *testConfig) {
380	addr, td, _, cleanup := setupServer(t)
381	defer cleanup()
382	adsChan := make(chan proto.Message, 10)
383	newADS := func(ctx context.Context, i proto.Message) error {
384		adsChan <- i
385		return nil
386	}
387	contactChan := make(chan *loseContact, 10)
388	loseContactFunc := func(context.Context) {
389		contactChan <- &loseContact{}
390	}
391	client := newXDSClient(addr, test.doCDS, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil, newADS, loseContactFunc, func() {})
392	defer client.close()
393	go client.run()
394
395	// make sure server side get the request (i.e stream created successfully on client side)
396	td.getReq()
397
398	for _, resp := range test.responsesToSend {
399		td.sendResp(&response{resp: resp})
400		// make sure client side receives it
401		<-adsChan
402	}
403	cleanup()
404
405	select {
406	case <-contactChan:
407	case <-time.After(2 * time.Second):
408		t.Fatal("time out when expecting lost contact signal")
409	}
410}
411
412func testXdsClientLoseContactADSRelatedErrorOccur(t *testing.T, test *testConfig) {
413	addr, td, _, cleanup := setupServer(t)
414	defer cleanup()
415
416	adsChan := make(chan proto.Message, 10)
417	newADS := func(ctx context.Context, i proto.Message) error {
418		adsChan <- i
419		return test.adsErr
420	}
421	contactChan := make(chan *loseContact, 10)
422	loseContactFunc := func(context.Context) {
423		contactChan <- &loseContact{}
424	}
425	client := newXDSClient(addr, test.doCDS, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil, newADS, loseContactFunc, func() {})
426	defer client.close()
427	go client.run()
428
429	// make sure server side get the request (i.e stream created successfully on client side)
430	td.getReq()
431
432	for _, resp := range test.responsesToSend {
433		td.sendResp(&response{resp: resp})
434	}
435
436	select {
437	case <-contactChan:
438	case <-time.After(2 * time.Second):
439		t.Fatal("time out when expecting lost contact signal")
440	}
441}
442
443func (s) TestXdsClientExponentialRetry(t *testing.T) {
444	cfg := &testConfig{
445		svrErr: status.Errorf(codes.Aborted, "abort the stream to trigger retry"),
446	}
447	addr, td, _, cleanup := setupServer(t)
448	defer cleanup()
449
450	adsChan := make(chan proto.Message, 10)
451	newADS := func(ctx context.Context, i proto.Message) error {
452		adsChan <- i
453		return nil
454	}
455	contactChan := make(chan *loseContact, 10)
456	loseContactFunc := func(context.Context) {
457		contactChan <- &loseContact{}
458	}
459	client := newXDSClient(addr, cfg.doCDS, balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}, nil, newADS, loseContactFunc, func() {})
460	defer client.close()
461	go client.run()
462
463	var secondRetry, thirdRetry time.Time
464	for i := 0; i < 3; i++ {
465		// make sure server side get the request (i.e stream created successfully on client side)
466		td.getReq()
467		td.sendResp(&response{err: cfg.svrErr})
468
469		select {
470		case <-contactChan:
471			if i == 1 {
472				secondRetry = time.Now()
473			}
474			if i == 2 {
475				thirdRetry = time.Now()
476			}
477		case <-time.After(2 * time.Second):
478			t.Fatal("time out when expecting lost contact signal")
479		}
480	}
481	if thirdRetry.Sub(secondRetry) < 1*time.Second {
482		t.Fatalf("interval between second and third retry is %v, expected > 1s", thirdRetry.Sub(secondRetry))
483	}
484}
485