1// Copyright 2015 CoreOS, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package v3rpc
16
17import (
18	"io"
19	"time"
20
21	"github.com/coreos/etcd/etcdserver"
22	pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
23	"github.com/coreos/etcd/storage"
24	"github.com/coreos/etcd/storage/storagepb"
25)
26
27type watchServer struct {
28	clusterID int64
29	memberID  int64
30	raftTimer etcdserver.RaftTimer
31	watchable storage.Watchable
32}
33
34func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer {
35	return &watchServer{
36		clusterID: int64(s.Cluster().ID()),
37		memberID:  int64(s.ID()),
38		raftTimer: s,
39		watchable: s.Watchable(),
40	}
41}
42
// ProgressReportInterval is the period of the progress ticker that each
// stream's send loop runs. It is exported for testing purposes only: an
// external test can lower it to a small value so that it finishes fast.
var ProgressReportInterval = 10 * time.Minute
48
// ctrlStreamBufLen is the capacity of a stream's control-response channel.
//
// Control responses are produced inside the read loop. Sending them must
// not block reading, yet they still have to go out in the order they were
// produced, so they are queued on a buffered channel. A small buffer
// should be enough in most cases because control requests are expected to
// be infrequent.
const ctrlStreamBufLen = 16
57
58// serverWatchStream is an etcd server side stream. It receives requests
59// from client side gRPC stream. It receives watch events from storage.WatchStream,
60// and creates responses that forwarded to gRPC stream.
61// It also forwards control message like watch created and canceled.
62type serverWatchStream struct {
63	clusterID int64
64	memberID  int64
65	raftTimer etcdserver.RaftTimer
66
67	gRPCStream  pb.Watch_WatchServer
68	watchStream storage.WatchStream
69	ctrlStream  chan *pb.WatchResponse
70
71	// progress tracks the watchID that stream might need to send
72	// progress to.
73	progress map[storage.WatchID]bool
74
75	// closec indicates the stream is closed.
76	closec chan struct{}
77}
78
79func (ws *watchServer) Watch(stream pb.Watch_WatchServer) error {
80	sws := serverWatchStream{
81		clusterID:   ws.clusterID,
82		memberID:    ws.memberID,
83		raftTimer:   ws.raftTimer,
84		gRPCStream:  stream,
85		watchStream: ws.watchable.NewWatchStream(),
86		// chan for sending control response like watcher created and canceled.
87		ctrlStream: make(chan *pb.WatchResponse, ctrlStreamBufLen),
88		progress:   make(map[storage.WatchID]bool),
89		closec:     make(chan struct{}),
90	}
91	defer sws.close()
92
93	go sws.sendLoop()
94	return sws.recvLoop()
95}
96
97func (sws *serverWatchStream) recvLoop() error {
98	for {
99		req, err := sws.gRPCStream.Recv()
100		if err == io.EOF {
101			return nil
102		}
103		if err != nil {
104			return err
105		}
106
107		switch uv := req.RequestUnion.(type) {
108		case *pb.WatchRequest_CreateRequest:
109			if uv.CreateRequest == nil {
110				break
111			}
112
113			creq := uv.CreateRequest
114			if len(creq.Key) == 0 {
115				// \x00 is the smallest key
116				creq.Key = []byte{0}
117			}
118			if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
119				// support  >= key queries
120				creq.RangeEnd = []byte{}
121			}
122			wsrev := sws.watchStream.Rev()
123			rev := creq.StartRevision
124			if rev == 0 {
125				rev = wsrev + 1
126			}
127			id := sws.watchStream.Watch(creq.Key, creq.RangeEnd, rev)
128			if id != -1 && creq.ProgressNotify {
129				sws.progress[id] = true
130			}
131			sws.ctrlStream <- &pb.WatchResponse{
132				Header:   sws.newResponseHeader(wsrev),
133				WatchId:  int64(id),
134				Created:  true,
135				Canceled: id == -1,
136			}
137		case *pb.WatchRequest_CancelRequest:
138			if uv.CancelRequest != nil {
139				id := uv.CancelRequest.WatchId
140				err := sws.watchStream.Cancel(storage.WatchID(id))
141				if err == nil {
142					sws.ctrlStream <- &pb.WatchResponse{
143						Header:   sws.newResponseHeader(sws.watchStream.Rev()),
144						WatchId:  id,
145						Canceled: true,
146					}
147					delete(sws.progress, storage.WatchID(id))
148				}
149			}
150			// TODO: do we need to return error back to client?
151		default:
152			panic("not implemented")
153		}
154	}
155}
156
157func (sws *serverWatchStream) sendLoop() {
158	// watch ids that are currently active
159	ids := make(map[storage.WatchID]struct{})
160	// watch responses pending on a watch id creation message
161	pending := make(map[storage.WatchID][]*pb.WatchResponse)
162
163	progressTicker := time.NewTicker(ProgressReportInterval)
164	defer progressTicker.Stop()
165
166	for {
167		select {
168		case wresp, ok := <-sws.watchStream.Chan():
169			if !ok {
170				return
171			}
172
173			// TODO: evs is []storagepb.Event type
174			// either return []*storagepb.Event from storage package
175			// or define protocol buffer with []storagepb.Event.
176			evs := wresp.Events
177			events := make([]*storagepb.Event, len(evs))
178			for i := range evs {
179				events[i] = &evs[i]
180			}
181
182			wr := &pb.WatchResponse{
183				Header:          sws.newResponseHeader(wresp.Revision),
184				WatchId:         int64(wresp.WatchID),
185				Events:          events,
186				CompactRevision: wresp.CompactRevision,
187			}
188
189			if _, hasId := ids[wresp.WatchID]; !hasId {
190				// buffer if id not yet announced
191				wrs := append(pending[wresp.WatchID], wr)
192				pending[wresp.WatchID] = wrs
193				continue
194			}
195
196			storage.ReportEventReceived()
197			if err := sws.gRPCStream.Send(wr); err != nil {
198				return
199			}
200
201			if _, ok := sws.progress[wresp.WatchID]; ok {
202				sws.progress[wresp.WatchID] = false
203			}
204
205		case c, ok := <-sws.ctrlStream:
206			if !ok {
207				return
208			}
209
210			if err := sws.gRPCStream.Send(c); err != nil {
211				return
212			}
213
214			// track id creation
215			wid := storage.WatchID(c.WatchId)
216			if c.Canceled {
217				delete(ids, wid)
218				continue
219			}
220			if c.Created {
221				// flush buffered events
222				ids[wid] = struct{}{}
223				for _, v := range pending[wid] {
224					storage.ReportEventReceived()
225					if err := sws.gRPCStream.Send(v); err != nil {
226						return
227					}
228				}
229				delete(pending, wid)
230			}
231		case <-progressTicker.C:
232			for id, ok := range sws.progress {
233				if ok {
234					sws.watchStream.RequestProgress(id)
235				}
236				sws.progress[id] = true
237			}
238		case <-sws.closec:
239			// drain the chan to clean up pending events
240			for range sws.watchStream.Chan() {
241				storage.ReportEventReceived()
242			}
243			for _, wrs := range pending {
244				for range wrs {
245					storage.ReportEventReceived()
246				}
247			}
248		}
249	}
250}
251
// close shuts the stream down: it closes the storage watch stream, then
// closes closec to signal the send loop, then closes ctrlStream.
//
// Watch defers this call, so it runs only after recvLoop has returned.
// recvLoop is the only sender on ctrlStream in this file, so nothing sends
// on the channel after it is closed here.
func (sws *serverWatchStream) close() {
	// Closed before closec is signalled; the send loop's drain on closec
	// ranges over the watch stream's channel (presumably Close is what
	// ends that range; confirm against the storage package).
	sws.watchStream.Close()
	close(sws.closec)
	close(sws.ctrlStream)
}
257
258func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {
259	return &pb.ResponseHeader{
260		ClusterId: uint64(sws.clusterID),
261		MemberId:  uint64(sws.memberID),
262		Revision:  rev,
263		RaftTerm:  sws.raftTimer.Term(),
264	}
265}
266