1// Copyright 2015 The etcd Authors 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package v3rpc 16 17import ( 18 "context" 19 "io" 20 "math/rand" 21 "sync" 22 "time" 23 24 "go.etcd.io/etcd/auth" 25 "go.etcd.io/etcd/etcdserver" 26 "go.etcd.io/etcd/etcdserver/api/v3rpc/rpctypes" 27 pb "go.etcd.io/etcd/etcdserver/etcdserverpb" 28 "go.etcd.io/etcd/mvcc" 29 "go.etcd.io/etcd/mvcc/mvccpb" 30 31 "go.uber.org/zap" 32) 33 34type watchServer struct { 35 lg *zap.Logger 36 37 clusterID int64 38 memberID int64 39 40 maxRequestBytes int 41 42 sg etcdserver.RaftStatusGetter 43 watchable mvcc.WatchableKV 44 ag AuthGetter 45} 46 47// NewWatchServer returns a new watch server. 48func NewWatchServer(s *etcdserver.EtcdServer) pb.WatchServer { 49 srv := &watchServer{ 50 lg: s.Cfg.Logger, 51 52 clusterID: int64(s.Cluster().ID()), 53 memberID: int64(s.ID()), 54 55 maxRequestBytes: int(s.Cfg.MaxRequestBytes + grpcOverheadBytes), 56 57 sg: s, 58 watchable: s.Watchable(), 59 ag: s, 60 } 61 if srv.lg == nil { 62 srv.lg = zap.NewNop() 63 } 64 return srv 65} 66 67var ( 68 // External test can read this with GetProgressReportInterval() 69 // and change this to a small value to finish fast with 70 // SetProgressReportInterval(). 71 progressReportInterval = 10 * time.Minute 72 progressReportIntervalMu sync.RWMutex 73) 74 75// GetProgressReportInterval returns the current progress report interval (for testing). 
76func GetProgressReportInterval() time.Duration { 77 progressReportIntervalMu.RLock() 78 interval := progressReportInterval 79 progressReportIntervalMu.RUnlock() 80 81 // add rand(1/10*progressReportInterval) as jitter so that etcdserver will not 82 // send progress notifications to watchers around the same time even when watchers 83 // are created around the same time (which is common when a client restarts itself). 84 jitter := time.Duration(rand.Int63n(int64(interval) / 10)) 85 86 return interval + jitter 87} 88 89// SetProgressReportInterval updates the current progress report interval (for testing). 90func SetProgressReportInterval(newTimeout time.Duration) { 91 progressReportIntervalMu.Lock() 92 progressReportInterval = newTimeout 93 progressReportIntervalMu.Unlock() 94} 95 96// We send ctrl response inside the read loop. We do not want 97// send to block read, but we still want ctrl response we sent to 98// be serialized. Thus we use a buffered chan to solve the problem. 99// A small buffer should be OK for most cases, since we expect the 100// ctrl requests are infrequent. 101const ctrlStreamBufLen = 16 102 103// serverWatchStream is an etcd server side stream. It receives requests 104// from client side gRPC stream. It receives watch events from mvcc.WatchStream, 105// and creates responses that forwarded to gRPC stream. 106// It also forwards control message like watch created and canceled. 107type serverWatchStream struct { 108 lg *zap.Logger 109 110 clusterID int64 111 memberID int64 112 113 maxRequestBytes int 114 115 sg etcdserver.RaftStatusGetter 116 watchable mvcc.WatchableKV 117 ag AuthGetter 118 119 gRPCStream pb.Watch_WatchServer 120 watchStream mvcc.WatchStream 121 ctrlStream chan *pb.WatchResponse 122 123 // mu protects progress, prevKV, fragment 124 mu sync.RWMutex 125 // tracks the watchID that stream might need to send progress to 126 // TODO: combine progress and prevKV into a single struct? 
	progress map[mvcc.WatchID]bool
	// record watch IDs that need return previous key-value pair
	prevKV map[mvcc.WatchID]bool
	// records fragmented watch IDs
	fragment map[mvcc.WatchID]bool

	// closec indicates the stream is closed.
	closec chan struct{}

	// wg waits for the send loop to complete
	wg sync.WaitGroup
}

// Watch serves one bidirectional watch RPC: it builds a serverWatchStream
// over the gRPC stream, runs the send loop in a goroutine, and blocks until
// the receive loop errors out or the stream context is done.
func (ws *watchServer) Watch(stream pb.Watch_WatchServer) (err error) {
	sws := serverWatchStream{
		lg: ws.lg,

		clusterID: ws.clusterID,
		memberID:  ws.memberID,

		maxRequestBytes: ws.maxRequestBytes,

		sg:        ws.sg,
		watchable: ws.watchable,
		ag:        ws.ag,

		gRPCStream:  stream,
		watchStream: ws.watchable.NewWatchStream(),
		// chan for sending control response like watcher created and canceled.
		ctrlStream: make(chan *pb.WatchResponse, ctrlStreamBufLen),

		progress: make(map[mvcc.WatchID]bool),
		prevKV:   make(map[mvcc.WatchID]bool),
		fragment: make(map[mvcc.WatchID]bool),

		closec: make(chan struct{}),
	}

	// sws.close() waits on wg, so the send loop must signal completion.
	sws.wg.Add(1)
	go func() {
		sws.sendLoop()
		sws.wg.Done()
	}()

	errc := make(chan error, 1)
	// Ideally recvLoop would also use sws.wg to signal its completion
	// but when stream.Context().Done() is closed, the stream's recv
	// may continue to block since it uses a different context, leading to
	// deadlock when calling sws.close().
	go func() {
		if rerr := sws.recvLoop(); rerr != nil {
			// client-caused context errors are expected; log them quietly
			if isClientCtxErr(stream.Context().Err(), rerr) {
				sws.lg.Debug("failed to receive watch request from gRPC stream", zap.Error(rerr))
			} else {
				sws.lg.Warn("failed to receive watch request from gRPC stream", zap.Error(rerr))
				streamFailures.WithLabelValues("receive", "watch").Inc()
			}
			errc <- rerr
		}
	}()

	select {
	case err = <-errc:
		close(sws.ctrlStream)

	case <-stream.Context().Done():
		err = stream.Context().Err()
		// the only server-side cancellation is noleader for now.
		if err == context.Canceled {
			err = rpctypes.ErrGRPCNoLeader
		}
	}

	sws.close()
	return err
}

// isWatchPermitted reports whether the stream's auth identity is allowed to
// range over the key interval requested by wcr.
func (sws *serverWatchStream) isWatchPermitted(wcr *pb.WatchCreateRequest) bool {
	authInfo, err := sws.ag.AuthInfoFromCtx(sws.gRPCStream.Context())
	if err != nil {
		return false
	}
	if authInfo == nil {
		// if auth is enabled, IsRangePermitted() can cause an error
		authInfo = &auth.AuthInfo{}
	}
	return sws.ag.AuthStore().IsRangePermitted(authInfo, wcr.Key, wcr.RangeEnd) == nil
}

// recvLoop reads watch requests (create/cancel/progress) from the gRPC
// stream and applies them to the mvcc watch stream until the client closes
// the stream (io.EOF) or the stream errors.
func (sws *serverWatchStream) recvLoop() error {
	for {
		req, err := sws.gRPCStream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		switch uv := req.RequestUnion.(type) {
		case *pb.WatchRequest_CreateRequest:
			if uv.CreateRequest == nil {
				break
			}

			creq := uv.CreateRequest
			if len(creq.Key) == 0 {
				// \x00 is the smallest key
				creq.Key = []byte{0}
			}
			if len(creq.RangeEnd) == 0 {
				// force nil since watchstream.Watch distinguishes
				// between nil and []byte{} for single key / >=
				creq.RangeEnd = nil
			}
			if len(creq.RangeEnd) == 1 && creq.RangeEnd[0] == 0 {
				// support >= key queries
				creq.RangeEnd = []byte{}
			}

			if !sws.isWatchPermitted(creq) {
				// reject the watch with a canceled+created response
				wr := &pb.WatchResponse{
					Header:
					        sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId:      creq.WatchId,
					Canceled:     true,
					Created:      true,
					CancelReason: rpctypes.ErrGRPCPermissionDenied.Error(),
				}

				select {
				case sws.ctrlStream <- wr:
					continue
				case <-sws.closec:
					return nil
				}
			}

			filters := FiltersFromRequest(creq)

			wsrev := sws.watchStream.Rev()
			rev := creq.StartRevision
			if rev == 0 {
				// StartRevision 0 means "current"; watch from the next revision
				rev = wsrev + 1
			}
			id, err := sws.watchStream.Watch(mvcc.WatchID(creq.WatchId), creq.Key, creq.RangeEnd, rev, filters...)
			if err == nil {
				// record per-watch options for the send loop
				sws.mu.Lock()
				if creq.ProgressNotify {
					sws.progress[id] = true
				}
				if creq.PrevKv {
					sws.prevKV[id] = true
				}
				if creq.Fragment {
					sws.fragment[id] = true
				}
				sws.mu.Unlock()
			}
			wr := &pb.WatchResponse{
				Header:   sws.newResponseHeader(wsrev),
				WatchId:  int64(id),
				Created:  true,
				Canceled: err != nil,
			}
			if err != nil {
				wr.CancelReason = err.Error()
			}
			select {
			case sws.ctrlStream <- wr:
			case <-sws.closec:
				return nil
			}

		case *pb.WatchRequest_CancelRequest:
			if uv.CancelRequest != nil {
				id := uv.CancelRequest.WatchId
				err := sws.watchStream.Cancel(mvcc.WatchID(id))
				if err == nil {
					sws.ctrlStream <- &pb.WatchResponse{
						Header:   sws.newResponseHeader(sws.watchStream.Rev()),
						WatchId:  id,
						Canceled: true,
					}
					// drop the per-watch options recorded at creation
					sws.mu.Lock()
					delete(sws.progress, mvcc.WatchID(id))
					delete(sws.prevKV, mvcc.WatchID(id))
					delete(sws.fragment, mvcc.WatchID(id))
					sws.mu.Unlock()
				}
			}
		case *pb.WatchRequest_ProgressRequest:
			if uv.ProgressRequest != nil {
				sws.ctrlStream <- &pb.WatchResponse{
					Header:  sws.newResponseHeader(sws.watchStream.Rev()),
					WatchId: -1, // response is not associated with any WatchId and will be broadcast to all watch channels
				}
			}
		default:
			// we probably should not shutdown the entire stream when
			// receiving an invalid command.
			// so just do nothing instead.
			continue
		}
	}
}

// sendLoop forwards watch events and control responses to the gRPC stream.
// It buffers events for watch IDs whose creation has not been announced yet,
// and returns when the watch channel closes, a send fails, or closec fires.
func (sws *serverWatchStream) sendLoop() {
	// watch ids that are currently active
	ids := make(map[mvcc.WatchID]struct{})
	// watch responses pending on a watch id creation message
	pending := make(map[mvcc.WatchID][]*pb.WatchResponse)

	interval := GetProgressReportInterval()
	progressTicker := time.NewTicker(interval)

	defer func() {
		progressTicker.Stop()
		// drain the chan to clean up pending events
		for ws := range sws.watchStream.Chan() {
			mvcc.ReportEventReceived(len(ws.Events))
		}
		for _, wrs := range pending {
			for _, ws := range wrs {
				mvcc.ReportEventReceived(len(ws.Events))
			}
		}
	}()

	for {
		select {
		case wresp, ok := <-sws.watchStream.Chan():
			if !ok {
				return
			}

			// TODO: evs is []mvccpb.Event type
			// either return []*mvccpb.Event from the mvcc package
			// or define protocol buffer with []mvccpb.Event.
			evs := wresp.Events
			events := make([]*mvccpb.Event, len(evs))
			sws.mu.RLock()
			needPrevKV := sws.prevKV[wresp.WatchID]
			sws.mu.RUnlock()
			for i := range evs {
				events[i] = &evs[i]
				if needPrevKV {
					// read the key at the revision just before this event's
					// modification to populate PrevKv for the client
					opt := mvcc.RangeOptions{Rev: evs[i].Kv.ModRevision - 1}
					r, err := sws.watchable.Range(evs[i].Kv.Key, nil, opt)
					if err == nil && len(r.KVs) != 0 {
						events[i].PrevKv = &(r.KVs[0])
					}
				}
			}

			canceled := wresp.CompactRevision != 0
			wr := &pb.WatchResponse{
				Header:          sws.newResponseHeader(wresp.Revision),
				WatchId:         int64(wresp.WatchID),
				Events:          events,
				CompactRevision: wresp.CompactRevision,
				Canceled:        canceled,
			}

			if _, okID := ids[wresp.WatchID]; !okID {
				// buffer if id not yet announced
				wrs := append(pending[wresp.WatchID], wr)
				pending[wresp.WatchID] = wrs
				continue
			}

			mvcc.ReportEventReceived(len(evs))

			sws.mu.RLock()
			fragmented, ok := sws.fragment[wresp.WatchID]
			sws.mu.RUnlock()

			var serr error
			// fragment only when the watch opted in via creq.Fragment
			if !fragmented && !ok {
				serr = sws.gRPCStream.Send(wr)
			} else {
				serr = sendFragments(wr, sws.maxRequestBytes, sws.gRPCStream.Send)
			}

			if serr != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), serr) {
					sws.lg.Debug("failed to send watch response to gRPC stream", zap.Error(serr))
				} else {
					sws.lg.Warn("failed to send watch response to gRPC stream", zap.Error(serr))
					streamFailures.WithLabelValues("send", "watch").Inc()
				}
				return
			}

			sws.mu.Lock()
			if len(evs) > 0 && sws.progress[wresp.WatchID] {
				// elide next progress update if sent a key update
				sws.progress[wresp.WatchID] = false
			}
			sws.mu.Unlock()

		case c, ok := <-sws.ctrlStream:
			if !ok {
				return
			}

			if err := sws.gRPCStream.Send(c); err != nil {
				if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
					sws.lg.Debug("failed to send watch control response to gRPC stream", zap.Error(err))
				} else {
					sws.lg.Warn("failed to send watch control response to gRPC stream", zap.Error(err))
					streamFailures.WithLabelValues("send", "watch").Inc()
				}
				return
			}

			// track id creation
			wid := mvcc.WatchID(c.WatchId)
			if c.Canceled {
				delete(ids, wid)
				continue
			}
			if c.Created {
				// flush buffered events
				ids[wid] = struct{}{}
				for _, v := range pending[wid] {
					mvcc.ReportEventReceived(len(v.Events))
					if err := sws.gRPCStream.Send(v); err != nil {
						if isClientCtxErr(sws.gRPCStream.Context().Err(), err) {
							sws.lg.Debug("failed to send pending watch response to gRPC stream", zap.Error(err))
						} else {
							sws.lg.Warn("failed to send pending watch response to gRPC stream", zap.Error(err))
							streamFailures.WithLabelValues("send", "watch").Inc()
						}
						return
					}
				}
				delete(pending, wid)
			}

		case <-progressTicker.C:
			sws.mu.Lock()
			for id, ok := range sws.progress {
				if ok {
					sws.watchStream.RequestProgress(id)
				}
				// re-arm: the next tick requests progress again unless a key
				// update elides it first (see the send path above)
				sws.progress[id] = true
			}
			sws.mu.Unlock()

		case <-sws.closec:
			return
		}
	}
}

// sendFragments splits wr into multiple responses each smaller than
// maxRequestBytes and sends them via sendFunc; only the final fragment
// is sent with Fragment == false.
func sendFragments(
	wr *pb.WatchResponse,
	maxRequestBytes int,
	sendFunc func(*pb.WatchResponse) error) error {
	// no need to fragment if total request size is smaller
	// than max request limit or response contains only one event
	if wr.Size() < maxRequestBytes || len(wr.Events) < 2 {
		return sendFunc(wr)
	}

	// ow is the template each fragment copies: same header, empty events
	ow := *wr
	ow.Events = make([]*mvccpb.Event, 0)
	ow.Fragment = true

	var idx int
	for {
		cur := ow
		for _, ev := range wr.Events[idx:] {
			cur.Events = append(cur.Events, ev)
			if len(cur.Events) > 1 && cur.Size() >= maxRequestBytes {
				// drop the event that pushed this fragment over the limit;
				// it will start the next fragment
				cur.Events = cur.Events[:len(cur.Events)-1]
				break
			}
			idx++
		}
		if idx == len(wr.Events) {
			// last response has no more fragment
			cur.Fragment = false
		}
		if err := sendFunc(&cur); err != nil {
			return err
		}
		if !cur.Fragment {
			break
		}
	}
	return nil
}

// close shuts down the watch stream, signals closec, and waits for the
// send loop goroutine to finish.
func (sws *serverWatchStream) close() {
	sws.watchStream.Close()
	close(sws.closec)
	sws.wg.Wait()
}

// newResponseHeader builds a response header carrying the stream's cluster
// and member IDs, the given revision, and the current raft term.
func (sws *serverWatchStream) newResponseHeader(rev int64) *pb.ResponseHeader {
	return &pb.ResponseHeader{
		ClusterId: uint64(sws.clusterID),
		MemberId:  uint64(sws.memberID),
		Revision:  rev,
		RaftTerm:  sws.sg.Term(),
	}
}

// filterNoDelete reports true for DELETE events (so they are filtered out).
func filterNoDelete(e mvccpb.Event) bool {
	return e.Type == mvccpb.DELETE
}

// filterNoPut reports true for PUT events (so they are filtered out).
func filterNoPut(e mvccpb.Event) bool {
	return e.Type == mvccpb.PUT
}

// FiltersFromRequest returns "mvcc.FilterFunc" from a given watch create request.
545func FiltersFromRequest(creq *pb.WatchCreateRequest) []mvcc.FilterFunc { 546 filters := make([]mvcc.FilterFunc, 0, len(creq.Filters)) 547 for _, ft := range creq.Filters { 548 switch ft { 549 case pb.WatchCreateRequest_NOPUT: 550 filters = append(filters, filterNoPut) 551 case pb.WatchCreateRequest_NODELETE: 552 filters = append(filters, filterNoDelete) 553 default: 554 } 555 } 556 return filters 557} 558