1// Copyright 2018 Envoyproxy Authors
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15// Package server provides an implementation of a streaming xDS server.
16package server
17
18import (
19	"context"
20	"errors"
21	"strconv"
22	"sync/atomic"
23
24	"github.com/golang/protobuf/ptypes/any"
25	"google.golang.org/grpc"
26	"google.golang.org/grpc/codes"
27	"google.golang.org/grpc/status"
28
29	v2 "github.com/envoyproxy/go-control-plane/envoy/api/v2"
30	v2grpc "github.com/envoyproxy/go-control-plane/envoy/api/v2"
31	core "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"
32	discoverygrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v2"
33	"github.com/envoyproxy/go-control-plane/pkg/cache"
34)
35
// Server is a collection of handlers for streaming discovery requests.
//
// It aggregates the generated gRPC service interfaces for the endpoint,
// cluster, route, listener, aggregated, secret and runtime discovery
// services, so a single value can be registered for all of them.
type Server interface {
	v2grpc.EndpointDiscoveryServiceServer
	v2grpc.ClusterDiscoveryServiceServer
	v2grpc.RouteDiscoveryServiceServer
	v2grpc.ListenerDiscoveryServiceServer
	discoverygrpc.AggregatedDiscoveryServiceServer
	discoverygrpc.SecretDiscoveryServiceServer
	discoverygrpc.RuntimeDiscoveryServiceServer

	// Fetch is the universal fetch method. The per-type FetchXxx handlers
	// set the request's type URL and then delegate to it.
	Fetch(context.Context, *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error)
}
49
// Callbacks is a collection of callbacks inserted into the server operation.
// The callbacks are invoked synchronously, on the goroutine that serves the
// stream or fetch request. A nil Callbacks is allowed; the server checks for
// nil before every invocation.
type Callbacks interface {
	// OnStreamOpen is called once an xDS stream is open, with the stream's context, its stream ID
	// and the stream's default type URL (cache.AnyType for ADS).
	// Returning an error will end processing and close the stream. OnStreamClosed will still be called.
	OnStreamOpen(context.Context, int64, string) error
	// OnStreamClosed is called immediately prior to closing an xDS stream with a stream ID.
	// It runs after all watches of the stream have been cancelled.
	OnStreamClosed(int64)
	// OnStreamRequest is called once a request is received on a stream, after the node and
	// type URL defaults have been filled in on the request.
	// Returning an error will end processing and close the stream. OnStreamClosed will still be called.
	OnStreamRequest(int64, *v2.DiscoveryRequest) error
	// OnStreamResponse is called immediately prior to sending a response on a stream.
	// It receives the stream ID, the request the cache response was produced for, and the
	// response with its nonce already assigned.
	OnStreamResponse(int64, *v2.DiscoveryRequest, *v2.DiscoveryResponse)
	// OnFetchRequest is called for each Fetch request. Returning an error will end processing of the
	// request and respond with an error.
	OnFetchRequest(context.Context, *v2.DiscoveryRequest) error
	// OnFetchResponse is called immediately prior to sending a response.
	OnFetchResponse(*v2.DiscoveryRequest, *v2.DiscoveryResponse)
}
69
70// NewServer creates handlers from a config watcher and callbacks.
71func NewServer(ctx context.Context, config cache.Cache, callbacks Callbacks) Server {
72	return &server{cache: config, callbacks: callbacks, ctx: ctx}
73}
74
// server is the Server implementation returned by NewServer.
type server struct {
	// cache answers Fetch calls and creates the watches used by streams.
	cache cache.Cache
	// callbacks is optional; every use is guarded by a nil check.
	callbacks Callbacks

	// streamCount for counting bi-di streams; incremented atomically to
	// hand out stream IDs.
	streamCount int64
	// ctx ends stream processing when it is done.
	ctx context.Context
}
83
// stream is the common shape of the generated bi-directional xDS stream
// types: a gRPC server stream that receives discovery requests and sends
// discovery responses. It lets one handler serve every resource type.
type stream interface {
	grpc.ServerStream

	Send(*v2.DiscoveryResponse) error
	Recv() (*v2.DiscoveryRequest, error)
}
90
// watches for all xDS resource types.
//
// It holds the per-stream state for each resource type: the channel on which
// the cache delivers responses, the function that cancels that watch, and the
// nonce of the last response sent for the type.
type watches struct {
	// response channels returned by the cache; nil until a watch is created
	endpoints chan cache.Response
	clusters  chan cache.Response
	routes    chan cache.Response
	listeners chan cache.Response
	secrets   chan cache.Response
	runtimes  chan cache.Response

	// cancel functions for the watches above; nil until a watch is created
	endpointCancel func()
	clusterCancel  func()
	routeCancel    func()
	listenerCancel func()
	secretCancel   func()
	runtimeCancel  func()

	// nonce of the most recent response sent per type; "" before the first one
	endpointNonce string
	clusterNonce  string
	routeNonce    string
	listenerNonce string
	secretNonce   string
	runtimeNonce  string
}
114
115// Cancel all watches
116func (values watches) Cancel() {
117	if values.endpointCancel != nil {
118		values.endpointCancel()
119	}
120	if values.clusterCancel != nil {
121		values.clusterCancel()
122	}
123	if values.routeCancel != nil {
124		values.routeCancel()
125	}
126	if values.listenerCancel != nil {
127		values.listenerCancel()
128	}
129	if values.secretCancel != nil {
130		values.secretCancel()
131	}
132	if values.runtimeCancel != nil {
133		values.runtimeCancel()
134	}
135}
136
137func createResponse(resp *cache.Response, typeURL string) (*v2.DiscoveryResponse, error) {
138	if resp == nil {
139		return nil, errors.New("missing response")
140	}
141
142	var resources []*any.Any
143	if resp.ResourceMarshaled {
144		resources = make([]*any.Any, len(resp.MarshaledResources))
145	} else {
146		resources = make([]*any.Any, len(resp.Resources))
147	}
148
149	for i := 0; i < len(resources); i++ {
150		// Envoy relies on serialized protobuf bytes for detecting changes to the resources.
151		// This requires deterministic serialization.
152		if resp.ResourceMarshaled {
153			resources[i] = &any.Any{
154				TypeUrl: typeURL,
155				Value:   resp.MarshaledResources[i],
156			}
157		} else {
158			marshaledResource, err := cache.MarshalResource(resp.Resources[i])
159			if err != nil {
160				return nil, err
161			}
162
163			resources[i] = &any.Any{
164				TypeUrl: typeURL,
165				Value:   marshaledResource,
166			}
167		}
168	}
169	out := &v2.DiscoveryResponse{
170		VersionInfo: resp.Version,
171		Resources:   resources,
172		TypeUrl:     typeURL,
173	}
174	return out, nil
175}
176
177// process handles a bi-di stream request
178func (s *server) process(stream stream, reqCh <-chan *v2.DiscoveryRequest, defaultTypeURL string) error {
179	// increment stream count
180	streamID := atomic.AddInt64(&s.streamCount, 1)
181
182	// unique nonce generator for req-resp pairs per xDS stream; the server
183	// ignores stale nonces. nonce is only modified within send() function.
184	var streamNonce int64
185
186	// a collection of watches per request type
187	var values watches
188	defer func() {
189		values.Cancel()
190		if s.callbacks != nil {
191			s.callbacks.OnStreamClosed(streamID)
192		}
193	}()
194
195	// sends a response by serializing to protobuf Any
196	send := func(resp cache.Response, typeURL string) (string, error) {
197		out, err := createResponse(&resp, typeURL)
198		if err != nil {
199			return "", err
200		}
201
202		// increment nonce
203		streamNonce = streamNonce + 1
204		out.Nonce = strconv.FormatInt(streamNonce, 10)
205		if s.callbacks != nil {
206			s.callbacks.OnStreamResponse(streamID, &resp.Request, out)
207		}
208		return out.Nonce, stream.Send(out)
209	}
210
211	if s.callbacks != nil {
212		if err := s.callbacks.OnStreamOpen(stream.Context(), streamID, defaultTypeURL); err != nil {
213			return err
214		}
215	}
216
217	// node may only be set on the first discovery request
218	var node = &core.Node{}
219
220	for {
221		select {
222		case <-s.ctx.Done():
223			return nil
224		// config watcher can send the requested resources types in any order
225		case resp, more := <-values.endpoints:
226			if !more {
227				return status.Errorf(codes.Unavailable, "endpoints watch failed")
228			}
229			nonce, err := send(resp, cache.EndpointType)
230			if err != nil {
231				return err
232			}
233			values.endpointNonce = nonce
234
235		case resp, more := <-values.clusters:
236			if !more {
237				return status.Errorf(codes.Unavailable, "clusters watch failed")
238			}
239			nonce, err := send(resp, cache.ClusterType)
240			if err != nil {
241				return err
242			}
243			values.clusterNonce = nonce
244
245		case resp, more := <-values.routes:
246			if !more {
247				return status.Errorf(codes.Unavailable, "routes watch failed")
248			}
249			nonce, err := send(resp, cache.RouteType)
250			if err != nil {
251				return err
252			}
253			values.routeNonce = nonce
254
255		case resp, more := <-values.listeners:
256			if !more {
257				return status.Errorf(codes.Unavailable, "listeners watch failed")
258			}
259			nonce, err := send(resp, cache.ListenerType)
260			if err != nil {
261				return err
262			}
263			values.listenerNonce = nonce
264
265		case resp, more := <-values.secrets:
266			if !more {
267				return status.Errorf(codes.Unavailable, "secrets watch failed")
268			}
269			nonce, err := send(resp, cache.SecretType)
270			if err != nil {
271				return err
272			}
273			values.secretNonce = nonce
274
275		case resp, more := <-values.runtimes:
276			if !more {
277				return status.Errorf(codes.Unavailable, "runtimes watch failed")
278			}
279			nonce, err := send(resp, cache.RuntimeType)
280			if err != nil {
281				return err
282			}
283			values.runtimeNonce = nonce
284
285		case req, more := <-reqCh:
286			// input stream ended or errored out
287			if !more {
288				return nil
289			}
290			if req == nil {
291				return status.Errorf(codes.Unavailable, "empty request")
292			}
293
294			// node field in discovery request is delta-compressed
295			if req.Node != nil {
296				node = req.Node
297			} else {
298				req.Node = node
299			}
300
301			// nonces can be reused across streams; we verify nonce only if nonce is not initialized
302			nonce := req.GetResponseNonce()
303
304			// type URL is required for ADS but is implicit for xDS
305			if defaultTypeURL == cache.AnyType {
306				if req.TypeUrl == "" {
307					return status.Errorf(codes.InvalidArgument, "type URL is required for ADS")
308				}
309			} else if req.TypeUrl == "" {
310				req.TypeUrl = defaultTypeURL
311			}
312
313			if s.callbacks != nil {
314				if err := s.callbacks.OnStreamRequest(streamID, req); err != nil {
315					return err
316				}
317			}
318
319			// cancel existing watches to (re-)request a newer version
320			switch {
321			case req.TypeUrl == cache.EndpointType && (values.endpointNonce == "" || values.endpointNonce == nonce):
322				if values.endpointCancel != nil {
323					values.endpointCancel()
324				}
325				values.endpoints, values.endpointCancel = s.cache.CreateWatch(*req)
326			case req.TypeUrl == cache.ClusterType && (values.clusterNonce == "" || values.clusterNonce == nonce):
327				if values.clusterCancel != nil {
328					values.clusterCancel()
329				}
330				values.clusters, values.clusterCancel = s.cache.CreateWatch(*req)
331			case req.TypeUrl == cache.RouteType && (values.routeNonce == "" || values.routeNonce == nonce):
332				if values.routeCancel != nil {
333					values.routeCancel()
334				}
335				values.routes, values.routeCancel = s.cache.CreateWatch(*req)
336			case req.TypeUrl == cache.ListenerType && (values.listenerNonce == "" || values.listenerNonce == nonce):
337				if values.listenerCancel != nil {
338					values.listenerCancel()
339				}
340				values.listeners, values.listenerCancel = s.cache.CreateWatch(*req)
341			case req.TypeUrl == cache.SecretType && (values.secretNonce == "" || values.secretNonce == nonce):
342				if values.secretCancel != nil {
343					values.secretCancel()
344				}
345				values.secrets, values.secretCancel = s.cache.CreateWatch(*req)
346			case req.TypeUrl == cache.RuntimeType && (values.runtimeNonce == "" || values.runtimeNonce == nonce):
347				if values.runtimeCancel != nil {
348					values.runtimeCancel()
349				}
350				values.runtimes, values.runtimeCancel = s.cache.CreateWatch(*req)
351			}
352		}
353	}
354}
355
356// handler converts a blocking read call to channels and initiates stream processing
357func (s *server) handler(stream stream, typeURL string) error {
358	// a channel for receiving incoming requests
359	reqCh := make(chan *v2.DiscoveryRequest)
360	reqStop := int32(0)
361	go func() {
362		for {
363			req, err := stream.Recv()
364			if atomic.LoadInt32(&reqStop) != 0 {
365				return
366			}
367			if err != nil {
368				close(reqCh)
369				return
370			}
371			reqCh <- req
372		}
373	}()
374
375	err := s.process(stream, reqCh, typeURL)
376
377	// prevents writing to a closed channel if send failed on blocked recv
378	// TODO(kuat) figure out how to unblock recv through gRPC API
379	atomic.StoreInt32(&reqStop, 1)
380
381	return err
382}
383
// StreamAggregatedResources serves an aggregated discovery stream by running
// the shared stream handler with cache.AnyType as the default type URL, which
// makes every request on the stream state its own type URL.
func (s *server) StreamAggregatedResources(stream discoverygrpc.AggregatedDiscoveryService_StreamAggregatedResourcesServer) error {
	return s.handler(stream, cache.AnyType)
}
387
// StreamEndpoints serves an endpoint discovery stream by running the shared
// stream handler with cache.EndpointType as the default type URL.
func (s *server) StreamEndpoints(stream v2grpc.EndpointDiscoveryService_StreamEndpointsServer) error {
	return s.handler(stream, cache.EndpointType)
}
391
// StreamClusters serves a cluster discovery stream by running the shared
// stream handler with cache.ClusterType as the default type URL.
func (s *server) StreamClusters(stream v2grpc.ClusterDiscoveryService_StreamClustersServer) error {
	return s.handler(stream, cache.ClusterType)
}
395
// StreamRoutes serves a route discovery stream by running the shared stream
// handler with cache.RouteType as the default type URL.
func (s *server) StreamRoutes(stream v2grpc.RouteDiscoveryService_StreamRoutesServer) error {
	return s.handler(stream, cache.RouteType)
}
399
// StreamListeners serves a listener discovery stream by running the shared
// stream handler with cache.ListenerType as the default type URL.
func (s *server) StreamListeners(stream v2grpc.ListenerDiscoveryService_StreamListenersServer) error {
	return s.handler(stream, cache.ListenerType)
}
403
// StreamSecrets serves a secret discovery stream by running the shared stream
// handler with cache.SecretType as the default type URL.
func (s *server) StreamSecrets(stream discoverygrpc.SecretDiscoveryService_StreamSecretsServer) error {
	return s.handler(stream, cache.SecretType)
}
407
// StreamRuntime serves a runtime discovery stream by running the shared
// stream handler with cache.RuntimeType as the default type URL.
func (s *server) StreamRuntime(stream discoverygrpc.RuntimeDiscoveryService_StreamRuntimeServer) error {
	return s.handler(stream, cache.RuntimeType)
}
411
412// Fetch is the universal fetch method.
413func (s *server) Fetch(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
414	if s.callbacks != nil {
415		if err := s.callbacks.OnFetchRequest(ctx, req); err != nil {
416			return nil, err
417		}
418	}
419	resp, err := s.cache.Fetch(ctx, *req)
420	if err != nil {
421		return nil, err
422	}
423	out, err := createResponse(resp, req.TypeUrl)
424	if s.callbacks != nil {
425		s.callbacks.OnFetchResponse(req, out)
426	}
427	return out, err
428}
429
// FetchEndpoints answers a one-shot endpoint discovery request. A nil request
// is rejected with codes.Unavailable; otherwise the request's type URL is
// overwritten with cache.EndpointType and the call is delegated to Fetch.
func (s *server) FetchEndpoints(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
	if req == nil {
		return nil, status.Errorf(codes.Unavailable, "empty request")
	}
	req.TypeUrl = cache.EndpointType
	return s.Fetch(ctx, req)
}
437
// FetchClusters answers a one-shot cluster discovery request. A nil request
// is rejected with codes.Unavailable; otherwise the request's type URL is
// overwritten with cache.ClusterType and the call is delegated to Fetch.
func (s *server) FetchClusters(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
	if req == nil {
		return nil, status.Errorf(codes.Unavailable, "empty request")
	}
	req.TypeUrl = cache.ClusterType
	return s.Fetch(ctx, req)
}
445
// FetchRoutes answers a one-shot route discovery request. A nil request is
// rejected with codes.Unavailable; otherwise the request's type URL is
// overwritten with cache.RouteType and the call is delegated to Fetch.
func (s *server) FetchRoutes(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
	if req == nil {
		return nil, status.Errorf(codes.Unavailable, "empty request")
	}
	req.TypeUrl = cache.RouteType
	return s.Fetch(ctx, req)
}
453
// FetchListeners answers a one-shot listener discovery request. A nil request
// is rejected with codes.Unavailable; otherwise the request's type URL is
// overwritten with cache.ListenerType and the call is delegated to Fetch.
func (s *server) FetchListeners(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
	if req == nil {
		return nil, status.Errorf(codes.Unavailable, "empty request")
	}
	req.TypeUrl = cache.ListenerType
	return s.Fetch(ctx, req)
}
461
// FetchSecrets answers a one-shot secret discovery request. A nil request is
// rejected with codes.Unavailable; otherwise the request's type URL is
// overwritten with cache.SecretType and the call is delegated to Fetch.
func (s *server) FetchSecrets(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
	if req == nil {
		return nil, status.Errorf(codes.Unavailable, "empty request")
	}
	req.TypeUrl = cache.SecretType
	return s.Fetch(ctx, req)
}
469
// FetchRuntime answers a one-shot runtime discovery request. A nil request is
// rejected with codes.Unavailable; otherwise the request's type URL is
// overwritten with cache.RuntimeType and the call is delegated to Fetch.
func (s *server) FetchRuntime(ctx context.Context, req *v2.DiscoveryRequest) (*v2.DiscoveryResponse, error) {
	if req == nil {
		return nil, status.Errorf(codes.Unavailable, "empty request")
	}
	req.TypeUrl = cache.RuntimeType
	return s.Fetch(ctx, req)
}
477
478func (s *server) DeltaAggregatedResources(_ discoverygrpc.AggregatedDiscoveryService_DeltaAggregatedResourcesServer) error {
479	return errors.New("not implemented")
480}
481
482func (s *server) DeltaEndpoints(_ v2grpc.EndpointDiscoveryService_DeltaEndpointsServer) error {
483	return errors.New("not implemented")
484}
485
486func (s *server) DeltaClusters(_ v2grpc.ClusterDiscoveryService_DeltaClustersServer) error {
487	return errors.New("not implemented")
488}
489
490func (s *server) DeltaRoutes(_ v2grpc.RouteDiscoveryService_DeltaRoutesServer) error {
491	return errors.New("not implemented")
492}
493
494func (s *server) DeltaListeners(_ v2grpc.ListenerDiscoveryService_DeltaListenersServer) error {
495	return errors.New("not implemented")
496}
497
498func (s *server) DeltaSecrets(_ discoverygrpc.SecretDiscoveryService_DeltaSecretsServer) error {
499	return errors.New("not implemented")
500}
501
502func (s *server) DeltaRuntime(_ discoverygrpc.RuntimeDiscoveryService_DeltaRuntimeServer) error {
503	return errors.New("not implemented")
504}
505