// Copyright 2015 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package v3rpc

import (
	"context"
	"io"
	"math/rand"
	"sync"
	"time"

	"go.etcd.io/etcd/auth"
	"go.etcd.io/etcd/etcdserver"
	"go.etcd.io/etcd/etcdserver/api/v3rpc/rpctypes"
	pb "go.etcd.io/etcd/etcdserver/etcdserverpb"
	"go.etcd.io/etcd/mvcc"
	"go.etcd.io/etcd/mvcc/mvccpb"

	"go.uber.org/zap"
)

type watchServer struct {
	lg *zap.Logger

	clusterID int64
	memberID  int64

	maxRequestBytes int

	sg        etcdserver.RaftStatusGetter
	watchable mvcc.WatchableKV
	ag        AuthGetter
}

// NewWatchServer returns a new watch server.
func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer {
	srv := &watchServer{
		lg: s.Cfg.Logger,

		clusterID: int64(s.Cluster().ID()),
		memberID:  int64(s.ID()),

		maxRequestBytes: int(s.Cfg.MaxRequestBytes + grpcOverheadBytes),

		sg:        s,
		watchable: s.Watchable(),
		ag:        s,
	}
	if srv.lg == nil {
		srv.lg = zap.NewNop()
	}
	return srv
}
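
// A minimal sketch of how this server is typically wired up, assuming an
// already-initialized *etcdserver.EtcdServer "s" and a *grpc.Server "gs"
// (both hypothetical names here); RegisterWatchServer is the generated
// gRPC registration helper in etcdserverpb:
//
//	gs := grpc.NewServer()
//	pb.RegisterWatchServer(gs, NewWatchServer(s))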

var (
	// External tests can read this with GetProgressReportInterval()
	// and change it to a small value via SetProgressReportInterval()
	// to finish fast.
	progressReportInterval   = 10 * time.Minute
	progressReportIntervalMu sync.RWMutex
)

// GetProgressReportInterval returns the current progress report interval (for testing).
func GetProgressReportInterval() time.Duration {
	progressReportIntervalMu.RLock()
	interval := progressReportInterval
	progressReportIntervalMu.RUnlock()

	// add rand(1/10*progressReportInterval) as jitter so that etcdserver will not
	// send progress notifications to watchers around the same time even when watchers
	// are created around the same time (which is common when a client restarts itself).
	jitter := time.Duration(rand.Int63n(int64(interval) / 10))

	return interval + jitter
}

// SetProgressReportInterval updates the current progress report interval (for testing).
func SetProgressReportInterval(newTimeout time.Duration) {
	progressReportIntervalMu.Lock()
	progressReportInterval = newTimeout
	progressReportIntervalMu.Unlock()
}
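
// A minimal sketch of the intended test usage (assumption: a test that
// wants progress notifications to fire quickly). Note that
// GetProgressReportInterval() adds jitter, so it does not read back the
// configured value exactly:
//
//	SetProgressReportInterval(100 * time.Millisecond)
//	defer SetProgressReportInterval(10 * time.Minute) // restore the default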

// We send ctrl responses inside the read loop. We do not want the send
// to block the read, but we still want the ctrl responses we send to be
// serialized. Thus we use a buffered chan to solve the problem.
// A small buffer should be OK for most cases, since we expect ctrl
// requests to be infrequent.
const ctrlStreamBufLen = 16

// serverWatchStream is an etcd server side stream. It receives requests
// from the client-side gRPC stream and watch events from mvcc.WatchStream,
// and creates responses that are forwarded to the gRPC stream.
// It also forwards control messages like watcher created and canceled.
type serverWatchStream struct {
	lg *zap.Logger

	clusterID int64
	memberID  int64

	maxRequestBytes int

	sg        etcdserver.RaftStatusGetter
	watchable mvcc.WatchableKV
	ag        AuthGetter

	gRPCStream  pb.Watch_WatchServer
	watchStream mvcc.WatchStream
	ctrlStream  chan *pb.WatchResponse

	// mu protects progress, prevKV, fragment
	mu sync.RWMutex
	// tracks the watch IDs that the stream might need to send progress updates for
	// TODO: combine progress and prevKV into a single struct?
	progress map[mvcc.WatchID]bool
	// records watch IDs that need to return the previous key-value pair
	prevKV map[mvcc.WatchID]bool
	// records fragmented watch IDs
	fragment map[mvcc.WatchID]bool

	// closec indicates the stream is closed.
	closec chan struct{}

	// wg waits for the send loop to complete
	wg sync.WaitGroup
}

func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
	sws := serverWatchStream{
		lg: ws.lg,

		clusterID: ws.clusterID,
		memberID:  ws.memberID,

		maxRequestBytes: ws.maxRequestBytes,

		sg:        ws.sg,
		watchable: ws.watchable,
		ag:        ws.ag,

		gRPCStream:  stream,
		watchStream: ws.watchable.NewWatchStream(),
		// chan for sending control responses like watcher created and canceled.
		ctrlStream: make(chan *pb.WatchResponse, ctrlStreamBufLen),

		progress: make(map[mvcc.WatchID]bool),
		prevKV:   make(map[mvcc.WatchID]bool),
		fragment: make(map[mvcc.WatchID]bool),

		closec: make(chan struct{}),
	}

	sws.wg.Add(1)
	go func() {
		sws.sendLoop()
		sws.wg.Done()
	}()

	errc := make(chan error, 1)
	// Ideally recvLoop would also use sws.wg to signal its completion
	// but when stream.Context().Done() is closed, the stream's recv
	// may continue to block since it uses a different context, leading to
	// deadlock when calling sws.close().
	go func() {
		if rerr := sws.recvLoop(); rerr != nil {
			if isClientCtxErr(stream.Context().Err(), rerr) {
				sws.lg.Debug("failed to receive watch request from gRPC stream", zap.Error(rerr))
			} else {
				sws.lg.Warn("failed to receive watch request from gRPC stream", zap.Error(rerr))
				streamFailures.WithLabelValues("receive", "watch").Inc()
			}
			errc <- rerr
		}
	}()

	select {
	case err = <-errc:
		close(sws.ctrlStream)

	case <-stream.Context().Done():
		err = stream.Context().Err()
		// the only server-side cancellation is noleader for now.
		if err == context.Canceled {
			err = rpctypes.ErrGRPCNoLeader
		}
	}

	sws.close()
	return err
}

func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) bool {
	authInfo, err := sws.ag.AuthInfoFromCtx(sws.gRPCStream.Context())
	if err != nil {
		return false
	}
	if authInfo == nil {
		// if auth is enabled, IsRangePermitted() can cause an error
		authInfo = &auth.AuthInfo{}
	}
	return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd) == nil
}

func (sws *serverWatchStream) recvLoop() error {
	for {
		req, err := sws.gRPCStream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		switch uv := req.RequestUnion.(type) {
		case *pb.WatchRequest_CreateRequest:
			if uv.CreateRequest == nil {
				break
			}

			creq := uv.CreateRequest
			if len(creq.Key) == 0 {
				// \x00 is the smallest key
				creq.Key = []byte{0}
			}
			if len(creq.RangeEnd) == 0 {
				// force nil since watchstream.Watch distinguishes
				// between nil and []byte{} for single key / >=
				creq.RangeEnd = nil
			}
			if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
				// support >= key queries
				creq.RangeEnd = []byte{}
			}

			if !sws.isWatchPermitted(creq) {
				wr := &pb.WatchResponse{
					Header:       sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId:      creq.WatchId,
					Canceled:     true,
					Created:      true,
					CancelReason: rpctypes.ErrGRPCPermissionDenied.Error(),
				}

				select {
				case sws.ctrlStream <- wr:
					continue
				case <-sws.closec:
					return nil
				}
			}

			filters := FiltersFromRequest(creq)

			wsrev := sws.watchStream.Rev()
			rev := creq.StartRevision
			if rev == 0 {
				rev = wsrev + 1
			}
			id, err := sws.watchStream.Watch(mvcc.WatchID(creq.WatchId), creq.Key, creq.RangeEnd, rev, filters...)
			if err == nil {
				sws.mu.Lock()
				if creq.ProgressNotify {
					sws.progress[id] = true
				}
				if creq.PrevKv {
					sws.prevKV[id] = true
				}
				if creq.Fragment {
					sws.fragment[id] = true
				}
				sws.mu.Unlock()
			}
			wr := &pb.WatchResponse{
				Header:   sws.newResponseHeader(wsrev),
				WatchId:  int64(id),
				Created:  true,
				Canceled: err != nil,
			}
			if err != nil {
				wr.CancelReason = err.Error()
			}
			select {
			case sws.ctrlStream <- wr:
			case <-sws.closec:
				return nil
			}

		case *pb.WatchRequest_CancelRequest:
			if uv.CancelRequest != nil {
				id := uv.CancelRequest.WatchId
				err := sws.watchStream.Cancel(mvcc.WatchID(id))
				if err == nil {
					sws.ctrlStream <- &pb.WatchResponse{
						Header:   sws.newResponseHeader(sws.watchStream.Rev()),
						WatchId:  id,
						Canceled: true,
					}
					sws.mu.Lock()
					delete(sws.progress, mvcc.WatchID(id))
					delete(sws.prevKV, mvcc.WatchID(id))
					delete(sws.fragment, mvcc.WatchID(id))
					sws.mu.Unlock()
				}
			}
		case *pb.WatchRequest_ProgressRequest:
			if uv.ProgressRequest != nil {
				sws.ctrlStream <- &pb.WatchResponse{
					Header:  sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId: -1, // response is not associated with any WatchId and will be broadcast to all watch channels
				}
			}
		default:
			// we probably should not shut down the entire stream when
			// we receive an invalid command; just do nothing instead.
			continue
		}
	}
}
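
// A minimal sketch of a create request this loop handles, e.g. a watch
// over all keys >= "foo" (hypothetical client-side construction; real
// clients go through clientv3):
//
//	req := &pb.WatchRequest{RequestUnion: &pb.WatchRequest_CreateRequest{
//		CreateRequest: &pb.WatchCreateRequest{
//			Key:      []byte("foo"),
//			RangeEnd: []byte{0}, // a single \x00 byte selects all keys >= Key
//		},
//	}}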

func (sws *serverWatchStream) sendLoop() {
	// watch ids that are currently active
	ids := make(map[mvcc.WatchID]struct{})
	// watch responses pending on a watch id creation message
	pending := make(map[mvcc.WatchID][]*pb.WatchResponse)

	interval := GetProgressReportInterval()
	progressTicker := time.NewTicker(interval)

	defer func() {
		progressTicker.Stop()
		// drain the chan to clean up pending events
		for ws := range sws.watchStream.Chan() {
			mvcc.ReportEventReceived(len(ws.Events))
		}
		for _, wrs := range pending {
			for _, ws := range wrs {
				mvcc.ReportEventReceived(len(ws.Events))
			}
		}
	}()

	for {
		select {
		case wresp, ok := <-sws.watchStream.Chan():
			if !ok {
				return
			}

			// TODO: evs is []mvccpb.Event type
			// either return []*mvccpb.Event from the mvcc package
			// or define protocol buffer with []mvccpb.Event.
			evs := wresp.Events
			events := make([]*mvccpb.Event, len(evs))
			sws.mu.RLock()
			needPrevKV := sws.prevKV[wresp.WatchID]
			sws.mu.RUnlock()
			for i := range evs {
				events[i] = &evs[i]
				if needPrevKV {
					opt := mvcc.RangeOptions{Rev: evs[i].Kv.ModRevision - 1}
					r, err := sws.watchable.Range(evs[i].Kv.Key, nil, opt)
					if err == nil && len(r.KVs) != 0 {
						events[i].PrevKv = &(r.KVs[0])
					}
				}
			}

			canceled := wresp.CompactRevision != 0
			wr := &pb.WatchResponse{
				Header:          sws.newResponseHeader(wresp.Revision),
				WatchId:         int64(wresp.WatchID),
				Events:          events,
				CompactRevision: wresp.CompactRevision,
				Canceled:        canceled,
			}

			if _, okID := ids[wresp.WatchID]; !okID {
				// buffer if id not yet announced
				wrs := append(pending[wresp.WatchID], wr)
				pending[wresp.WatchID] = wrs
				continue
			}

			mvcc.ReportEventReceived(len(evs))

			sws.mu.RLock()
			fragmented, ok := sws.fragment[wresp.WatchID]
			sws.mu.RUnlock()

			var serr error
			if !fragmented && !ok {
				serr = sws.gRPCStream.Send(wr)
			} else {
				serr = sendFragments(wr, sws.maxRequestBytes, sws.gRPCStream.Send)
			}

			if serr != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), serr) {
					sws.lg.Debug("failed to send watch response to gRPC stream", zap.Error(serr))
				} else {
					sws.lg.Warn("failed to send watch response to gRPC stream", zap.Error(serr))
					streamFailures.WithLabelValues("send", "watch").Inc()
				}
				return
			}

			sws.mu.Lock()
			if len(evs) > 0 && sws.progress[wresp.WatchID] {
				// elide the next progress update if a key update was sent
				sws.progress[wresp.WatchID] = false
			}
			sws.mu.Unlock()

		case c, ok := <-sws.ctrlStream:
			if !ok {
				return
			}

			if err := sws.gRPCStream.Send(c); err != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
					sws.lg.Debug("failed to send watch control response to gRPC stream", zap.Error(err))
				} else {
					sws.lg.Warn("failed to send watch control response to gRPC stream", zap.Error(err))
					streamFailures.WithLabelValues("send", "watch").Inc()
				}
				return
			}

			// track id creation
			wid := mvcc.WatchID(c.WatchId)
			if c.Canceled {
				delete(ids, wid)
				continue
			}
			if c.Created {
				// flush buffered events
				ids[wid] = struct{}{}
				for _, v := range pending[wid] {
					mvcc.ReportEventReceived(len(v.Events))
					if err := sws.gRPCStream.Send(v); err != nil {
						if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
							sws.lg.Debug("failed to send pending watch response to gRPC stream", zap.Error(err))
						} else {
							sws.lg.Warn("failed to send pending watch response to gRPC stream", zap.Error(err))
							streamFailures.WithLabelValues("send", "watch").Inc()
						}
						return
					}
				}
				delete(pending, wid)
			}

		case <-progressTicker.C:
			sws.mu.Lock()
			for id, ok := range sws.progress {
				if ok {
					sws.watchStream.RequestProgress(id)
				}
				sws.progress[id] = true
			}
			sws.mu.Unlock()

		case <-sws.closec:
			return
		}
	}
}

func sendFragments(
	wr *pb.WatchResponse,
	maxRequestBytes int,
	sendFunc func(*pb.WatchResponse) error) error {
	// no need to fragment if the total response size is smaller than
	// the max request limit or the response contains only one event
	if wr.Size() < maxRequestBytes || len(wr.Events) < 2 {
		return sendFunc(wr)
	}

	ow := *wr
	ow.Events = make([]*mvccpb.Event, 0)
	ow.Fragment = true

	var idx int
	for {
		cur := ow
		for _, ev := range wr.Events[idx:] {
			cur.Events = append(cur.Events, ev)
			if len(cur.Events) > 1 && cur.Size() >= maxRequestBytes {
				cur.Events = cur.Events[:len(cur.Events)-1]
				break
			}
			idx++
		}
		if idx == len(wr.Events) {
			// last response has no more fragment
			cur.Fragment = false
		}
		if err := sendFunc(&cur); err != nil {
			return err
		}
		if !cur.Fragment {
			break
		}
	}
	return nil
}
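
// A minimal sketch of how sendFragments behaves, using a hypothetical
// collecting sendFunc in place of a real gRPC stream send; "bigResp" and
// "limit" are assumed inputs:
//
//	var parts []*pb.WatchResponse
//	collect := func(fr *pb.WatchResponse) error {
//		parts = append(parts, fr)
//		return nil
//	}
//	_ = sendFragments(bigResp, limit, collect)
//	// every part except the last has Fragment == true; the receiver
//	// reassembles by concatenating Events until Fragment == false.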

func (sws *serverWatchStream) close() {
	sws.watchStream.Close()
	close(sws.closec)
	sws.wg.Wait()
}

func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {
	return &pb.ResponseHeader{
		ClusterId: uint64(sws.clusterID),
		MemberId:  uint64(sws.memberID),
		Revision:  rev,
		RaftTerm:  sws.sg.Term(),
	}
}
// filterNoDelete reports whether an event is a DELETE event; it is used
// to drop DELETE events when the NODELETE filter is requested.
func filterNoDelete(e mvccpb.Event) bool {
	return e.Type == mvccpb.DELETE
}

// filterNoPut reports whether an event is a PUT event; it is used
// to drop PUT events when the NOPUT filter is requested.
func filterNoPut(e mvccpb.Event) bool {
	return e.Type == mvccpb.PUT
}

// FiltersFromRequest returns "mvcc.FilterFunc" from a given watch create request.
func FiltersFromRequest(creq *pb.WatchCreateRequest) []mvcc.FilterFunc {
	filters := make([]mvcc.FilterFunc, 0, len(creq.Filters))
	for _, ft := range creq.Filters {
		switch ft {
		case pb.WatchCreateRequest_NOPUT:
			filters = append(filters, filterNoPut)
		case pb.WatchCreateRequest_NODELETE:
			filters = append(filters, filterNoDelete)
		default:
		}
	}
	return filters
}
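
// A minimal sketch of how the returned filters are meant to be applied;
// the assumption (based on the NOPUT/NODELETE mapping above) is that an
// mvcc.FilterFunc returning true means the event is filtered out:
//
//	creq := &pb.WatchCreateRequest{
//		Filters: []pb.WatchCreateRequest_FilterType{pb.WatchCreateRequest_NOPUT},
//	}
//	for _, f := range FiltersFromRequest(creq) {
//		dropPut := f(mvccpb.Event{Type: mvccpb.PUT}) // true: PUT events are dropped
//		_ = dropPut
//	}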