1// Copyright 2015 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package v3rpc
16
17import (
18	"context"
19	"io"
20	"sync"
21	"time"
22
23	"github.com/coreos/etcd/auth"
24	"github.com/coreos/etcd/etcdserver"
25	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
26	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
27	"github.com/coreos/etcd/mvcc"
28	"github.com/coreos/etcd/mvcc/mvccpb"
29)
30
// watchServer implements pb.WatchServer on top of an mvcc.WatchableKV.
// A single watchServer serves every client watch RPC; per-stream state
// is kept in serverWatchStream.
type watchServer struct {
	clusterID int64
	memberID  int64
	raftTimer etcdserver.RaftTimer // supplies the raft term stamped onto response headers
	watchable mvcc.WatchableKV     // backing store that registers watchers and produces events

	ag AuthGetter // resolves auth info for permission checks on watch creation
}
39
40func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer {
41	return &watchServer{
42		clusterID: int64(s.Cluster().ID()),
43		memberID:  int64(s.ID()),
44		raftTimer: s,
45		watchable: s.Watchable(),
46		ag:        s,
47	}
48}
49
var (
	// progressReportInterval is how often watchers that asked for
	// progress notification receive a synthetic progress response.
	// External tests read it with GetProgressReportInterval() and
	// shrink it with SetProgressReportInterval() to finish fast.
	progressReportInterval   = 10 * time.Minute
	progressReportIntervalMu sync.RWMutex
)

// GetProgressReportInterval returns the current progress report interval.
func GetProgressReportInterval() (interval time.Duration) {
	progressReportIntervalMu.RLock()
	interval = progressReportInterval
	progressReportIntervalMu.RUnlock()
	return interval
}

// SetProgressReportInterval replaces the progress report interval;
// intended for tests that need fast progress notifications.
func SetProgressReportInterval(newTimeout time.Duration) {
	progressReportIntervalMu.Lock()
	progressReportInterval = newTimeout
	progressReportIntervalMu.Unlock()
}
69
const (
	// ctrlStreamBufLen is the capacity of serverWatchStream.ctrlStream.
	// We send ctrl response inside the read loop. We do not want
	// the send to block the read, but we still want the ctrl responses
	// we send to be serialized. Thus we use a buffered chan to solve
	// the problem. A small buffer should be OK for most cases, since
	// we expect the ctrl requests to be infrequent.
	ctrlStreamBufLen = 16
)
78
// serverWatchStream is an etcd server side stream. It receives requests
// from client side gRPC stream. It receives watch events from mvcc.WatchStream,
// and creates responses that forwarded to gRPC stream.
// It also forwards control message like watch created and canceled.
type serverWatchStream struct {
	clusterID int64
	memberID  int64
	raftTimer etcdserver.RaftTimer // supplies the raft term stamped onto response headers

	watchable mvcc.WatchableKV // store used for watcher registration and prevKV range lookups

	gRPCStream  pb.Watch_WatchServer // client-facing stream; written only by sendLoop
	watchStream mvcc.WatchStream     // per-stream handle into the watchable store
	// ctrlStream carries created/canceled control responses from
	// recvLoop to sendLoop (buffered; see ctrlStreamBufLen).
	ctrlStream  chan *pb.WatchResponse

	// mu protects progress, prevKV
	mu sync.Mutex
	// progress tracks the watchID that stream might need to send
	// progress to.
	// TODO: combine progress and prevKV into a single struct?
	progress map[mvcc.WatchID]bool
	// prevKV tracks which watchIDs asked for the previous key-value
	// to be attached to each event.
	prevKV   map[mvcc.WatchID]bool

	// closec indicates the stream is closed.
	closec chan struct{}

	// wg waits for the send loop to complete
	wg sync.WaitGroup

	ag AuthGetter // resolves auth info for watch permission checks
}
110
// Watch implements pb.WatchServer. It binds the incoming gRPC stream to
// a new serverWatchStream, runs sendLoop in one goroutine and recvLoop
// in another, then blocks until recvLoop returns an error or the
// stream's context is done. sws.close() is always called before return.
func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
	sws := serverWatchStream{
		clusterID: ws.clusterID,
		memberID:  ws.memberID,
		raftTimer: ws.raftTimer,

		watchable: ws.watchable,

		gRPCStream:  stream,
		watchStream: ws.watchable.NewWatchStream(),
		// chan for sending control response like watcher created and canceled.
		ctrlStream: make(chan *pb.WatchResponse, ctrlStreamBufLen),
		progress:   make(map[mvcc.WatchID]bool),
		prevKV:     make(map[mvcc.WatchID]bool),
		closec:     make(chan struct{}),

		ag: ws.ag,
	}

	// sendLoop is the only writer to the gRPC stream; wg lets close()
	// wait for it to finish.
	sws.wg.Add(1)
	go func() {
		sws.sendLoop()
		sws.wg.Done()
	}()

	// errc is buffered so the recvLoop goroutine can exit even if no
	// one reads the error (e.g. when the context-done case wins below).
	errc := make(chan error, 1)
	// Ideally recvLoop would also use sws.wg to signal its completion
	// but when stream.Context().Done() is closed, the stream's recv
	// may continue to block since it uses a different context, leading to
	// deadlock when calling sws.close().
	go func() {
		if rerr := sws.recvLoop(); rerr != nil {
			if isClientCtxErr(stream.Context().Err(), rerr) {
				plog.Debugf("failed to receive watch request from gRPC stream (%q)", rerr.Error())
			} else {
				plog.Warningf("failed to receive watch request from gRPC stream (%q)", rerr.Error())
			}
			errc <- rerr
		}
	}()
	select {
	case err = <-errc:
		// recvLoop is done, so it is safe for this (sole) sender to
		// close ctrlStream; sendLoop exits when it drains the channel.
		close(sws.ctrlStream)
	case <-stream.Context().Done():
		err = stream.Context().Err()
		// the only server-side cancellation is noleader for now.
		if err == context.Canceled {
			err = rpctypes.ErrGRPCNoLeader
		}
		// NOTE(review): in this path ctrlStream is never closed and
		// recvLoop may still be running — confirm recvLoop cannot block
		// forever on a ctrlStream send after sendLoop has exited.
	}
	sws.close()
	return err
}
164
165func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) bool {
166	authInfo, err := sws.ag.AuthInfoFromCtx(sws.gRPCStream.Context())
167	if err != nil {
168		return false
169	}
170	if authInfo == nil {
171		// if auth is enabled, IsRangePermitted() can cause an error
172		authInfo = &auth.AuthInfo{}
173	}
174
175	return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd) == nil
176}
177
// recvLoop reads watch requests from the gRPC stream until the client
// half-closes it (io.EOF, returns nil) or Recv fails (returns the
// error). Create requests register a watcher on the mvcc watch stream;
// cancel requests tear one down. All control responses are forwarded to
// sendLoop via ctrlStream so that only one goroutine writes to gRPC.
func (sws *serverWatchStream) recvLoop() error {
	for {
		req, err := sws.gRPCStream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		switch uv := req.RequestUnion.(type) {
		case *pb.WatchRequest_CreateRequest:
			if uv.CreateRequest == nil {
				break
			}

			// Normalize the requested key range before registering.
			creq := uv.CreateRequest
			if len(creq.Key) == 0 {
				// \x00 is the smallest key
				creq.Key = []byte{0}
			}
			if len(creq.RangeEnd) == 0 {
				// force nil since watchstream.Watch distinguishes
				// between nil and []byte{} for single key / >=
				creq.RangeEnd = nil
			}
			if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
				// support  >= key queries
				creq.RangeEnd = []byte{}
			}

			// Reject the watch (WatchId -1, Canceled) if the client
			// lacks range permission on the requested keys.
			if !sws.isWatchPermitted(creq) {
				wr := &pb.WatchResponse{
					Header:       sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId:      -1,
					Canceled:     true,
					Created:      true,
					CancelReason: rpctypes.ErrGRPCPermissionDenied.Error(),
				}

				select {
				case sws.ctrlStream <- wr:
				case <-sws.closec:
				}
				return nil
			}

			filters := FiltersFromRequest(creq)

			wsrev := sws.watchStream.Rev()
			rev := creq.StartRevision
			if rev == 0 {
				// StartRevision 0 means "from now on": start just
				// past the current store revision.
				rev = wsrev + 1
			}
			id := sws.watchStream.Watch(creq.Key, creq.RangeEnd, rev, filters...)
			if id != -1 {
				// Record per-watch options; id == -1 means the
				// registration failed and the response below reports
				// Canceled instead.
				sws.mu.Lock()
				if creq.ProgressNotify {
					sws.progress[id] = true
				}
				if creq.PrevKv {
					sws.prevKV[id] = true
				}
				sws.mu.Unlock()
			}
			wr := &pb.WatchResponse{
				Header:   sws.newResponseHeader(wsrev),
				WatchId:  int64(id),
				Created:  true,
				Canceled: id == -1,
			}
			select {
			case sws.ctrlStream <- wr:
			case <-sws.closec:
				return nil
			}
		case *pb.WatchRequest_CancelRequest:
			if uv.CancelRequest != nil {
				id := uv.CancelRequest.WatchId
				err := sws.watchStream.Cancel(mvcc.WatchID(id))
				if err == nil {
					// NOTE(review): this send does not select on
					// closec like the create path does; confirm it
					// cannot block if sendLoop has already exited.
					sws.ctrlStream <- &pb.WatchResponse{
						Header:   sws.newResponseHeader(sws.watchStream.Rev()),
						WatchId:  id,
						Canceled: true,
					}
					sws.mu.Lock()
					delete(sws.progress, mvcc.WatchID(id))
					delete(sws.prevKV, mvcc.WatchID(id))
					sws.mu.Unlock()
				}
			}
		default:
			// we probably should not shutdown the entire stream when
			// receive an invalid command.
			// so just do nothing instead.
			continue
		}
	}
}
278
// sendLoop is the sole writer to the gRPC stream. It multiplexes three
// inputs: watch events from the mvcc watch stream, control responses
// from ctrlStream, and the progress-report ticker. Event responses for
// a watch whose Created control response has not yet been sent are
// buffered in `pending` and flushed right after the announcement.
// It returns when either channel closes, a Send fails, or closec fires.
func (sws *serverWatchStream) sendLoop() {
	// watch ids that are currently active
	ids := make(map[mvcc.WatchID]struct{})
	// watch responses pending on a watch id creation message
	pending := make(map[mvcc.WatchID][]*pb.WatchResponse)

	interval := GetProgressReportInterval()
	progressTicker := time.NewTicker(interval)

	defer func() {
		progressTicker.Stop()
		// drain the chan to clean up pending events
		for ws := range sws.watchStream.Chan() {
			mvcc.ReportEventReceived(len(ws.Events))
		}
		for _, wrs := range pending {
			for _, ws := range wrs {
				mvcc.ReportEventReceived(len(ws.Events))
			}
		}
	}()

	for {
		select {
		case wresp, ok := <-sws.watchStream.Chan():
			if !ok {
				return
			}

			// TODO: evs is []mvccpb.Event type
			// either return []*mvccpb.Event from the mvcc package
			// or define protocol buffer with []mvccpb.Event.
			evs := wresp.Events
			events := make([]*mvccpb.Event, len(evs))
			sws.mu.Lock()
			needPrevKV := sws.prevKV[wresp.WatchID]
			sws.mu.Unlock()
			for i := range evs {
				events[i] = &evs[i]

				if needPrevKV {
					// Look up the value as of the revision just before
					// this event's modification.
					opt := mvcc.RangeOptions{Rev: evs[i].Kv.ModRevision - 1}
					r, err := sws.watchable.Range(evs[i].Kv.Key, nil, opt)
					if err == nil && len(r.KVs) != 0 {
						events[i].PrevKv = &(r.KVs[0])
					}
				}
			}

			// A non-zero CompactRevision means the watcher's start
			// revision has been compacted away; report as canceled.
			canceled := wresp.CompactRevision != 0
			wr := &pb.WatchResponse{
				Header:          sws.newResponseHeader(wresp.Revision),
				WatchId:         int64(wresp.WatchID),
				Events:          events,
				CompactRevision: wresp.CompactRevision,
				Canceled:        canceled,
			}

			if _, hasId := ids[wresp.WatchID]; !hasId {
				// buffer if id not yet announced
				wrs := append(pending[wresp.WatchID], wr)
				pending[wresp.WatchID] = wrs
				continue
			}

			mvcc.ReportEventReceived(len(evs))
			if err := sws.gRPCStream.Send(wr); err != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
					plog.Debugf("failed to send watch response to gRPC stream (%q)", err.Error())
				} else {
					plog.Warningf("failed to send watch response to gRPC stream (%q)", err.Error())
				}
				return
			}

			sws.mu.Lock()
			if len(evs) > 0 && sws.progress[wresp.WatchID] {
				// elide next progress update if sent a key update
				sws.progress[wresp.WatchID] = false
			}
			sws.mu.Unlock()

		case c, ok := <-sws.ctrlStream:
			if !ok {
				return
			}

			if err := sws.gRPCStream.Send(c); err != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
					plog.Debugf("failed to send watch control response to gRPC stream (%q)", err.Error())
				} else {
					plog.Warningf("failed to send watch control response to gRPC stream (%q)", err.Error())
				}
				return
			}

			// track id creation
			wid := mvcc.WatchID(c.WatchId)
			if c.Canceled {
				delete(ids, wid)
				continue
			}
			if c.Created {
				// flush buffered events
				ids[wid] = struct{}{}
				for _, v := range pending[wid] {
					mvcc.ReportEventReceived(len(v.Events))
					if err := sws.gRPCStream.Send(v); err != nil {
						if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
							plog.Debugf("failed to send pending watch response to gRPC stream (%q)", err.Error())
						} else {
							plog.Warningf("failed to send pending watch response to gRPC stream (%q)", err.Error())
						}
						return
					}
				}
				delete(pending, wid)
			}
		case <-progressTicker.C:
			// Request a progress notification for every watch that is
			// still marked true (i.e. sent no key update since the last
			// tick), then re-arm the mark for the next interval.
			sws.mu.Lock()
			for id, ok := range sws.progress {
				if ok {
					sws.watchStream.RequestProgress(id)
				}
				sws.progress[id] = true
			}
			sws.mu.Unlock()
		case <-sws.closec:
			return
		}
	}
}
411
// close tears the stream down. Order matters: closing the mvcc watch
// stream eventually closes its event channel (letting sendLoop's drain
// in its defer terminate), closec signals sendLoop to stop selecting,
// and wg.Wait blocks until sendLoop has fully returned.
func (sws *serverWatchStream) close() {
	sws.watchStream.Close()
	close(sws.closec)
	sws.wg.Wait()
}
417
418func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {
419	return &pb.ResponseHeader{
420		ClusterId: uint64(sws.clusterID),
421		MemberId:  uint64(sws.memberID),
422		Revision:  rev,
423		RaftTerm:  sws.raftTimer.Term(),
424	}
425}
426
427func filterNoDelete(e mvccpb.Event) bool {
428	return e.Type == mvccpb.DELETE
429}
430
431func filterNoPut(e mvccpb.Event) bool {
432	return e.Type == mvccpb.PUT
433}
434
435func FiltersFromRequest(creq *pb.WatchCreateRequest) []mvcc.FilterFunc {
436	filters := make([]mvcc.FilterFunc, 0, len(creq.Filters))
437	for _, ft := range creq.Filters {
438		switch ft {
439		case pb.WatchCreateRequest_NOPUT:
440			filters = append(filters, filterNoPut)
441		case pb.WatchCreateRequest_NODELETE:
442			filters = append(filters, filterNoDelete)
443		default:
444		}
445	}
446	return filters
447}
448