// Copyright 2015 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package v3rpc

import (
	"context"
	"io"
	"math/rand"
	"sync"
	"time"

	pb "go.etcd.io/etcd/api/v3/etcdserverpb"
	"go.etcd.io/etcd/api/v3/mvccpb"
	"go.etcd.io/etcd/api/v3/v3rpc/rpctypes"
	"go.etcd.io/etcd/server/v3/auth"
	"go.etcd.io/etcd/server/v3/etcdserver"
	"go.etcd.io/etcd/server/v3/mvcc"

	"go.uber.org/zap"
)

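// minWatchProgressInterval is the minimum interval the server allows for
// the WatchProgressNotifyInterval configuration; NewWatchServer raises any
// smaller configured value to this floor.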
const minWatchProgressInterval = 100 * time.Millisecond

type watchServer struct {
	lg *zap.Logger

	clusterID int64
	memberID  int64

	maxRequestBytes int

	sg        etcdserver.RaftStatusGetter
	watchable mvcc.WatchableKV
	ag        AuthGetter
}

// NewWatchServer returns a new watch server.
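//
// Typical wiring (a sketch; grpcServer is an assumed *grpc.Server):
//
//	pb.RegisterWatchServer(grpcServer, NewWatchServer(s))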
func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer {
	srv := &watchServer{
		lg: s.Cfg.Logger,

		clusterID: int64(s.Cluster().ID()),
		memberID:  int64(s.ID()),

		maxRequestBytes: int(s.Cfg.MaxRequestBytes + grpcOverheadBytes),

		sg:        s,
		watchable: s.Watchable(),
		ag:        s,
	}
	if srv.lg == nil {
		srv.lg = zap.NewNop()
	}
	if s.Cfg.WatchProgressNotifyInterval > 0 {
		if s.Cfg.WatchProgressNotifyInterval < minWatchProgressInterval {
			srv.lg.Warn(
				"adjusting watch progress notify interval to minimum period",
				zap.Duration("min-watch-progress-notify-interval", minWatchProgressInterval),
			)
			s.Cfg.WatchProgressNotifyInterval = minWatchProgressInterval
		}
		SetProgressReportInterval(s.Cfg.WatchProgressNotifyInterval)
	}
	return srv
}

var (
	// External tests can read this with GetProgressReportInterval()
	// and change it to a small value to finish fast with
	// SetProgressReportInterval().
	progressReportInterval   = 10 * time.Minute
	progressReportIntervalMu sync.RWMutex
)

// GetProgressReportInterval returns the current progress report interval (for testing).
func GetProgressReportInterval() time.Duration {
	progressReportIntervalMu.RLock()
	interval := progressReportInterval
	progressReportIntervalMu.RUnlock()

	// add rand(1/10*progressReportInterval) as jitter so that etcdserver will not
	// send progress notifications to watchers around the same time even when watchers
	// are created around the same time (which is common when a client restarts itself).
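	// For example, with the default 10-minute interval the jitter lies in
	// [0, 1m), spreading notifications across [10m, 11m).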
	jitter := time.Duration(rand.Int63n(int64(interval) / 10))

	return interval + jitter
}

// SetProgressReportInterval updates the current progress report interval (for testing).
func SetProgressReportInterval(newTimeout time.Duration) {
	progressReportIntervalMu.Lock()
	progressReportInterval = newTimeout
	progressReportIntervalMu.Unlock()
}

// We send ctrl responses inside the read loop. We do not want sends to
// block reads, but we still want the ctrl responses we send to be
// serialized. Thus we use a buffered chan to solve the problem.
// A small buffer should be OK for most cases, since we expect
// ctrl requests to be infrequent.
const ctrlStreamBufLen = 16

// serverWatchStream is an etcd server side stream. It receives requests
// from the client-side gRPC stream, receives watch events from mvcc.WatchStream,
// and creates responses that are forwarded to the gRPC stream.
// It also forwards control messages such as watch created and canceled.
type serverWatchStream struct {
	lg *zap.Logger

	clusterID int64
	memberID  int64

	maxRequestBytes int

	sg        etcdserver.RaftStatusGetter
	watchable mvcc.WatchableKV
	ag        AuthGetter

	gRPCStream  pb.Watch_WatchServer
	watchStream mvcc.WatchStream
	ctrlStream  chan *pb.WatchResponse

	// mu protects progress, prevKV, fragment
	mu sync.RWMutex
	// progress tracks the watch IDs the stream might need to send progress updates to
	// TODO: combine progress and prevKV into a single struct?
	progress map[mvcc.WatchID]bool
	// prevKV records the watch IDs that need the previous key-value pair returned
	prevKV map[mvcc.WatchID]bool
	// fragment records the watch IDs whose responses may be fragmented
	fragment map[mvcc.WatchID]bool

	// closec indicates the stream is closed.
	closec chan struct{}

	// wg waits for the send loop to complete
	wg sync.WaitGroup
}

func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
	sws := serverWatchStream{
		lg: ws.lg,

		clusterID: ws.clusterID,
		memberID:  ws.memberID,

		maxRequestBytes: ws.maxRequestBytes,

		sg:        ws.sg,
		watchable: ws.watchable,
		ag:        ws.ag,

		gRPCStream:  stream,
		watchStream: ws.watchable.NewWatchStream(),
		// chan for sending control responses like watcher created and canceled.
		ctrlStream: make(chan *pb.WatchResponse, ctrlStreamBufLen),

		progress: make(map[mvcc.WatchID]bool),
		prevKV:   make(map[mvcc.WatchID]bool),
		fragment: make(map[mvcc.WatchID]bool),

		closec: make(chan struct{}),
	}

	sws.wg.Add(1)
	go func() {
		sws.sendLoop()
		sws.wg.Done()
	}()

	errc := make(chan error, 1)
	// Ideally recvLoop would also use sws.wg to signal its completion,
	// but when stream.Context().Done() is closed, the stream's recv
	// may continue to block since it uses a different context, leading to
	// deadlock when calling sws.close().
	go func() {
		if rerr := sws.recvLoop(); rerr != nil {
			if isClientCtxErr(stream.Context().Err(), rerr) {
				sws.lg.Debug("failed to receive watch request from gRPC stream", zap.Error(rerr))
			} else {
				sws.lg.Warn("failed to receive watch request from gRPC stream", zap.Error(rerr))
				streamFailures.WithLabelValues("receive", "watch").Inc()
			}
			errc <- rerr
		}
	}()

	// TODO: There's a race here. When a stream is closed (e.g. due to a cancellation),
	// the underlying error (e.g. a gRPC stream error) may be returned and handled
	// through errc if the recv goroutine finishes before the send goroutine.
	// When the recv goroutine wins, the stream error is retained. When recv loses
	// the race, the underlying error is lost (unless the root error is propagated
	// through Context.Err(), which is not always the case, as callers would have
	// to implement a custom context to do so). The stdlib context package builtins
	// may be insufficient to carry semantically useful errors around and should be
	// revisited.
	select {
	case err = <-errc:
		if err == context.Canceled {
			err = rpctypes.ErrGRPCWatchCanceled
		}
		close(sws.ctrlStream)
	case <-stream.Context().Done():
		err = stream.Context().Err()
		if err == context.Canceled {
			err = rpctypes.ErrGRPCWatchCanceled
		}
	}

	sws.close()
	return err
}

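// isWatchPermitted checks whether the stream's auth info permits reading
// the key range [Key, RangeEnd) of the watch create request.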
func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) bool {
	authInfo, err := sws.ag.AuthInfoFromCtx(sws.gRPCStream.Context())
	if err != nil {
		return false
	}
	if authInfo == nil {
		// if auth is enabled, IsRangePermitted() can cause an error
		authInfo = &auth.AuthInfo{}
	}
	return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd) == nil
}

func (sws *serverWatchStream) recvLoop() error {
	for {
		req, err := sws.gRPCStream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		switch uv := req.RequestUnion.(type) {
		case *pb.WatchRequest_CreateRequest:
			if uv.CreateRequest == nil {
				break
			}

			creq := uv.CreateRequest
			if len(creq.Key) == 0 {
				// \x00 is the smallest key
				creq.Key = []byte{0}
			}
			if len(creq.RangeEnd) == 0 {
				// force nil since watchstream.Watch distinguishes
				// between nil and []byte{} for single key / >=
				creq.RangeEnd = nil
			}
			if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
				// support >= key queries
				creq.RangeEnd = []byte{}
			}
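
			// To summarize the normalization above:
			//   key == ""           -> key = "\x00" (watch from the smallest key)
			//   range_end == ""     -> nil          (single-key watch)
			//   range_end == "\x00" -> []byte{}     (watch all keys >= key)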

			if !sws.isWatchPermitted(creq) {
				wr := &pb.WatchResponse{
					Header:       sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId:      creq.WatchId,
					Canceled:     true,
					Created:      true,
					CancelReason: rpctypes.ErrGRPCPermissionDenied.Error(),
				}

				select {
				case sws.ctrlStream <- wr:
					continue
				case <-sws.closec:
					return nil
				}
			}

			filters := FiltersFromRequest(creq)

			wsrev := sws.watchStream.Rev()
			rev := creq.StartRevision
			if rev == 0 {
				rev = wsrev + 1
			}
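			// rev is now wsrev+1 when StartRevision was unset, so the watcher
			// sees only future events; an explicit StartRevision replays
			// history from that revision onward.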
			id, err := sws.watchStream.Watch(mvcc.WatchID(creq.WatchId), creq.Key, creq.RangeEnd, rev, filters...)
			if err == nil {
				sws.mu.Lock()
				if creq.ProgressNotify {
					sws.progress[id] = true
				}
				if creq.PrevKv {
					sws.prevKV[id] = true
				}
				if creq.Fragment {
					sws.fragment[id] = true
				}
				sws.mu.Unlock()
			}
			wr := &pb.WatchResponse{
				Header:   sws.newResponseHeader(wsrev),
				WatchId:  int64(id),
				Created:  true,
				Canceled: err != nil,
			}
			if err != nil {
				wr.CancelReason = err.Error()
			}
			select {
			case sws.ctrlStream <- wr:
			case <-sws.closec:
				return nil
			}

		case *pb.WatchRequest_CancelRequest:
			if uv.CancelRequest != nil {
				id := uv.CancelRequest.WatchId
				err := sws.watchStream.Cancel(mvcc.WatchID(id))
				if err == nil {
					sws.ctrlStream <- &pb.WatchResponse{
						Header:   sws.newResponseHeader(sws.watchStream.Rev()),
						WatchId:  id,
						Canceled: true,
					}
					sws.mu.Lock()
					delete(sws.progress, mvcc.WatchID(id))
					delete(sws.prevKV, mvcc.WatchID(id))
					delete(sws.fragment, mvcc.WatchID(id))
					sws.mu.Unlock()
				}
			}
		case *pb.WatchRequest_ProgressRequest:
			if uv.ProgressRequest != nil {
				sws.ctrlStream <- &pb.WatchResponse{
					Header:  sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId: -1, // response is not associated with any WatchId and will be broadcast to all watch channels
				}
			}
		default:
			// We probably should not shut down the entire stream when we
			// receive an invalid command, so just do nothing instead.
			continue
		}
	}
}


func (sws *serverWatchStream) sendLoop() {
	// watch ids that are currently active
	ids := make(map[mvcc.WatchID]struct{})
	// watch responses pending on a watch id creation message
	pending := make(map[mvcc.WatchID][]*pb.WatchResponse)

	interval := GetProgressReportInterval()
	progressTicker := time.NewTicker(interval)

	defer func() {
		progressTicker.Stop()
		// drain the chan to clean up pending events
		for ws := range sws.watchStream.Chan() {
			mvcc.ReportEventReceived(len(ws.Events))
		}
		for _, wrs := range pending {
			for _, ws := range wrs {
				mvcc.ReportEventReceived(len(ws.Events))
			}
		}
	}()

	for {
		select {
		case wresp, ok := <-sws.watchStream.Chan():
			if !ok {
				return
			}

			// TODO: evs is []mvccpb.Event type
			// either return []*mvccpb.Event from the mvcc package
			// or define protocol buffer with []mvccpb.Event.
			evs := wresp.Events
			events := make([]*mvccpb.Event, len(evs))
			sws.mu.RLock()
			needPrevKV := sws.prevKV[wresp.WatchID]
			sws.mu.RUnlock()
			for i := range evs {
				events[i] = &evs[i]
				if needPrevKV && !IsCreateEvent(evs[i]) {
					opt := mvcc.RangeOptions{Rev: evs[i].Kv.ModRevision - 1}
					r, err := sws.watchable.Range(context.TODO(), evs[i].Kv.Key, nil, opt)
					if err == nil && len(r.KVs) != 0 {
						events[i].PrevKv = &(r.KVs[0])
					}
				}
			}

			canceled := wresp.CompactRevision != 0
			wr := &pb.WatchResponse{
				Header:          sws.newResponseHeader(wresp.Revision),
				WatchId:         int64(wresp.WatchID),
				Events:          events,
				CompactRevision: wresp.CompactRevision,
				Canceled:        canceled,
			}

			if _, okID := ids[wresp.WatchID]; !okID {
				// buffer if id not yet announced
				wrs := append(pending[wresp.WatchID], wr)
				pending[wresp.WatchID] = wrs
				continue
			}

			mvcc.ReportEventReceived(len(evs))

			sws.mu.RLock()
			fragmented, ok := sws.fragment[wresp.WatchID]
			sws.mu.RUnlock()

			var serr error
			if !fragmented && !ok {
				serr = sws.gRPCStream.Send(wr)
			} else {
				serr = sendFragments(wr, sws.maxRequestBytes, sws.gRPCStream.Send)
			}

			if serr != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), serr) {
					sws.lg.Debug("failed to send watch response to gRPC stream", zap.Error(serr))
				} else {
					sws.lg.Warn("failed to send watch response to gRPC stream", zap.Error(serr))
					streamFailures.WithLabelValues("send", "watch").Inc()
				}
				return
			}

			sws.mu.Lock()
			if len(evs) > 0 && sws.progress[wresp.WatchID] {
				// elide next progress update if sent a key update
				sws.progress[wresp.WatchID] = false
			}
			sws.mu.Unlock()

		case c, ok := <-sws.ctrlStream:
			if !ok {
				return
			}

			if err := sws.gRPCStream.Send(c); err != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
					sws.lg.Debug("failed to send watch control response to gRPC stream", zap.Error(err))
				} else {
					sws.lg.Warn("failed to send watch control response to gRPC stream", zap.Error(err))
					streamFailures.WithLabelValues("send", "watch").Inc()
				}
				return
			}

			// track id creation
			wid := mvcc.WatchID(c.WatchId)
			if c.Canceled {
				delete(ids, wid)
				continue
			}
			if c.Created {
				// flush buffered events
				ids[wid] = struct{}{}
				for _, v := range pending[wid] {
					mvcc.ReportEventReceived(len(v.Events))
					if err := sws.gRPCStream.Send(v); err != nil {
						if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
							sws.lg.Debug("failed to send pending watch response to gRPC stream", zap.Error(err))
						} else {
							sws.lg.Warn("failed to send pending watch response to gRPC stream", zap.Error(err))
							streamFailures.WithLabelValues("send", "watch").Inc()
						}
						return
					}
				}
				delete(pending, wid)
			}

		case <-progressTicker.C:
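			// Request progress only for watchers that stayed idle since the
			// last tick (progress[id] still true); delivering events resets
			// the flag, and every entry is re-armed for the next tick.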
			sws.mu.Lock()
			for id, ok := range sws.progress {
				if ok {
					sws.watchStream.RequestProgress(id)
				}
				sws.progress[id] = true
			}
			sws.mu.Unlock()

		case <-sws.closec:
			return
		}
	}
}

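// IsCreateEvent returns true if the event marks the creation of a key:
// a PUT whose create revision equals its mod revision.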
func IsCreateEvent(e mvccpb.Event) bool {
	return e.Type == mvccpb.PUT && e.Kv.CreateRevision == e.Kv.ModRevision
}

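// sendFragments splits wr into responses that each fit within
// maxRequestBytes and sends them via sendFunc. Every response except the
// last is sent with Fragment set, so the receiver knows to reassemble the
// pieces into a single logical watch response.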
func sendFragments(
	wr *pb.WatchResponse,
	maxRequestBytes int,
	sendFunc func(*pb.WatchResponse) error) error {
	// no need to fragment if the total response size is smaller than the
	// max request limit, or if the response contains fewer than two events
	if wr.Size() < maxRequestBytes || len(wr.Events) < 2 {
		return sendFunc(wr)
	}

	ow := *wr
	ow.Events = make([]*mvccpb.Event, 0)
	ow.Fragment = true

	var idx int
	for {
		cur := ow
		for _, ev := range wr.Events[idx:] {
			cur.Events = append(cur.Events, ev)
			if len(cur.Events) > 1 && cur.Size() >= maxRequestBytes {
				cur.Events = cur.Events[:len(cur.Events)-1]
				break
			}
			idx++
		}
		if idx == len(wr.Events) {
			// the last response is not a fragment
			cur.Fragment = false
		}
		if err := sendFunc(&cur); err != nil {
			return err
		}
		if !cur.Fragment {
			break
		}
	}
	return nil
}

func (sws *serverWatchStream) close() {
	sws.watchStream.Close()
	close(sws.closec)
	sws.wg.Wait()
}

func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {
	return &pb.ResponseHeader{
		ClusterId: uint64(sws.clusterID),
		MemberId:  uint64(sws.memberID),
		Revision:  rev,
		RaftTerm:  sws.sg.Term(),
	}
}

func filterNoDelete(e mvccpb.Event) bool {
	return e.Type == mvccpb.DELETE
}

func filterNoPut(e mvccpb.Event) bool {
	return e.Type == mvccpb.PUT
}

// FiltersFromRequest returns "mvcc.FilterFunc" from a given watch create request.
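// A filter returning true drops the event: e.g. requesting the NOPUT filter
// installs filterNoPut, so PUT events are not delivered to the watcher.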
func FiltersFromRequest(creq *pb.WatchCreateRequest) []mvcc.FilterFunc {
	filters := make([]mvcc.FilterFunc, 0, len(creq.Filters))
	for _, ft := range creq.Filters {
		switch ft {
		case pb.WatchCreateRequest_NOPUT:
			filters = append(filters, filterNoPut)
		case pb.WatchCreateRequest_NODELETE:
			filters = append(filters, filterNoDelete)
		default:
		}
	}
	return filters
}