1// Copyright 2017 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// Package pstest provides a fake Cloud PubSub service for testing. It implements a 16// simplified form of the service, suitable for unit tests. It may behave 17// differently from the actual service in ways in which the service is 18// non-deterministic or unspecified: timing, delivery order, etc. 19// 20// This package is EXPERIMENTAL and is subject to change without notice. 21// 22// See the example for usage. 23package pstest 24 25import ( 26 "context" 27 "fmt" 28 "io" 29 "path" 30 "sort" 31 "strings" 32 "sync" 33 "sync/atomic" 34 "time" 35 36 "cloud.google.com/go/internal/testutil" 37 "github.com/golang/protobuf/ptypes" 38 durpb "github.com/golang/protobuf/ptypes/duration" 39 emptypb "github.com/golang/protobuf/ptypes/empty" 40 pb "google.golang.org/genproto/googleapis/pubsub/v1" 41 "google.golang.org/grpc/codes" 42 "google.golang.org/grpc/status" 43) 44 45// For testing. Note that even though changes to the now variable are atomic, a call 46// to the stored function can race with a change to that function. This could be a 47// problem if tests are run in parallel, or even if concurrent parts of the same test 48// change the value of the variable. 49var now atomic.Value 50 51func init() { 52 now.Store(time.Now) 53 ResetMinAckDeadline() 54} 55 56func timeNow() time.Time { 57 return now.Load().(func() time.Time)() 58} 59 60// Server is a fake Pub/Sub server. 61type Server struct { 62 srv *testutil.Server 63 Addr string // The address that the server is listening on. 64 GServer GServer // Not intended to be used directly. 65} 66 67// GServer is the underlying service implementor. It is not intended to be used 68// directly. 69type GServer struct { 70 pb.PublisherServer 71 pb.SubscriberServer 72 73 mu sync.Mutex 74 topics map[string]*topic 75 subs map[string]*subscription 76 msgs []*Message // all messages ever published 77 msgsByID map[string]*Message 78 wg sync.WaitGroup 79 nextID int 80 streamTimeout time.Duration 81} 82 83// NewServer creates a new fake server running in the current process. 84func NewServer() *Server { 85 srv, err := testutil.NewServer() 86 if err != nil { 87 panic(fmt.Sprintf("pstest.NewServer: %v", err)) 88 } 89 s := &Server{ 90 srv: srv, 91 Addr: srv.Addr, 92 GServer: GServer{ 93 topics: map[string]*topic{}, 94 subs: map[string]*subscription{}, 95 msgsByID: map[string]*Message{}, 96 }, 97 } 98 pb.RegisterPublisherServer(srv.Gsrv, &s.GServer) 99 pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer) 100 srv.Start() 101 return s 102} 103 104// Publish behaves as if the Publish RPC was called with a message with the given 105// data and attrs. It returns the ID of the message. 106// The topic will be created if it doesn't exist. 107// 108// Publish panics if there is an error, which is appropriate for testing. 109func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string { 110 const topicPattern = "projects/*/topics/*" 111 ok, err := path.Match(topicPattern, topic) 112 if err != nil { 113 panic(err) 114 } 115 if !ok { 116 panic(fmt.Sprintf("topic name must be of the form %q", topicPattern)) 117 } 118 _, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic}) 119 req := &pb.PublishRequest{ 120 Topic: topic, 121 Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs}}, 122 } 123 res, err := s.GServer.Publish(context.TODO(), req) 124 if err != nil { 125 panic(fmt.Sprintf("pstest.Server.Publish: %v", err)) 126 } 127 return res.MessageIds[0] 128} 129 130// SetStreamTimeout sets the amount of time a stream will be active before it shuts 131// itself down. This mimics the real service's behavior of closing streams after 30 132// minutes. If SetStreamTimeout is never called or is passed zero, streams never shut 133// down. 134func (s *Server) SetStreamTimeout(d time.Duration) { 135 s.GServer.mu.Lock() 136 defer s.GServer.mu.Unlock() 137 s.GServer.streamTimeout = d 138} 139 140// A Message is a message that was published to the server. 141type Message struct { 142 ID string 143 Data []byte 144 Attributes map[string]string 145 PublishTime time.Time 146 Deliveries int // number of times delivery of the message was attempted 147 Acks int // number of acks received from clients 148 149 // protected by server mutex 150 deliveries int 151 acks int 152 Modacks []Modack // modacks received by server for this message 153 154} 155 156// Modack represents a modack sent to the server. 157type Modack struct { 158 AckID string 159 AckDeadline int32 160 ReceivedAt time.Time 161} 162 163// Messages returns information about all messages ever published. 164func (s *Server) Messages() []*Message { 165 s.GServer.mu.Lock() 166 defer s.GServer.mu.Unlock() 167 168 var msgs []*Message 169 for _, m := range s.GServer.msgs { 170 m.Deliveries = m.deliveries 171 m.Acks = m.acks 172 msgs = append(msgs, m) 173 } 174 return msgs 175} 176 177// Message returns the message with the given ID, or nil if no message 178// with that ID was published. 179func (s *Server) Message(id string) *Message { 180 s.GServer.mu.Lock() 181 defer s.GServer.mu.Unlock() 182 183 m := s.GServer.msgsByID[id] 184 if m != nil { 185 m.Deliveries = m.deliveries 186 m.Acks = m.acks 187 } 188 return m 189} 190 191// Wait blocks until all server activity has completed. 192func (s *Server) Wait() { 193 s.GServer.wg.Wait() 194} 195 196// ClearMessages removes all published messages 197// from internal containers. 198func (s *Server) ClearMessages() { 199 s.GServer.mu.Lock() 200 s.GServer.msgs = nil 201 s.GServer.msgsByID = make(map[string]*Message) 202 s.GServer.mu.Unlock() 203} 204 205// Close shuts down the server and releases all resources. 206func (s *Server) Close() error { 207 s.srv.Close() 208 s.GServer.mu.Lock() 209 defer s.GServer.mu.Unlock() 210 for _, sub := range s.GServer.subs { 211 sub.stop() 212 } 213 return nil 214} 215 216func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) { 217 s.mu.Lock() 218 defer s.mu.Unlock() 219 220 if s.topics[t.Name] != nil { 221 return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name) 222 } 223 top := newTopic(t) 224 s.topics[t.Name] = top 225 return top.proto, nil 226} 227 228func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) { 229 s.mu.Lock() 230 defer s.mu.Unlock() 231 232 if t := s.topics[req.Topic]; t != nil { 233 return t.proto, nil 234 } 235 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 236} 237 238func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) { 239 s.mu.Lock() 240 defer s.mu.Unlock() 241 242 t := s.topics[req.Topic.Name] 243 if t == nil { 244 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name) 245 } 246 for _, path := range req.UpdateMask.Paths { 247 switch path { 248 case "labels": 249 t.proto.Labels = req.Topic.Labels 250 case "message_storage_policy": 251 t.proto.MessageStoragePolicy = req.Topic.MessageStoragePolicy 252 default: 253 return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path) 254 } 255 } 256 return t.proto, nil 257} 258 259func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) { 260 s.mu.Lock() 261 defer s.mu.Unlock() 262 263 var names []string 264 for n := range s.topics { 265 if strings.HasPrefix(n, req.Project) { 266 names = append(names, n) 267 } 268 } 269 sort.Strings(names) 270 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 271 if err != nil { 272 return nil, err 273 } 274 res := &pb.ListTopicsResponse{NextPageToken: nextToken} 275 for i := from; i < to; i++ { 276 res.Topics = append(res.Topics, s.topics[names[i]].proto) 277 } 278 return res, nil 279} 280 281func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) { 282 s.mu.Lock() 283 defer s.mu.Unlock() 284 285 var names []string 286 for name, sub := range s.subs { 287 if sub.topic.proto.Name == req.Topic { 288 names = append(names, name) 289 } 290 } 291 sort.Strings(names) 292 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 293 if err != nil { 294 return nil, err 295 } 296 return &pb.ListTopicSubscriptionsResponse{ 297 Subscriptions: names[from:to], 298 NextPageToken: nextToken, 299 }, nil 300} 301 302func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) { 303 s.mu.Lock() 304 defer s.mu.Unlock() 305 306 t := s.topics[req.Topic] 307 if t == nil { 308 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 309 } 310 t.stop() 311 delete(s.topics, req.Topic) 312 return &emptypb.Empty{}, nil 313} 314 315func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) { 316 s.mu.Lock() 317 defer s.mu.Unlock() 318 319 if ps.Name == "" { 320 return nil, status.Errorf(codes.InvalidArgument, "missing name") 321 } 322 if s.subs[ps.Name] != nil { 323 return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name) 324 } 325 if ps.Topic == "" { 326 return nil, status.Errorf(codes.InvalidArgument, "missing topic") 327 } 328 top := s.topics[ps.Topic] 329 if top == nil { 330 return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic) 331 } 332 if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil { 333 return nil, err 334 } 335 if ps.MessageRetentionDuration == nil { 336 ps.MessageRetentionDuration = defaultMessageRetentionDuration 337 } 338 if err := checkMRD(ps.MessageRetentionDuration); err != nil { 339 return nil, err 340 } 341 if ps.PushConfig == nil { 342 ps.PushConfig = &pb.PushConfig{} 343 } 344 345 sub := newSubscription(top, &s.mu, ps) 346 top.subs[ps.Name] = sub 347 s.subs[ps.Name] = sub 348 sub.start(&s.wg) 349 return ps, nil 350} 351 352// Can be set for testing. 353var minAckDeadlineSecs int32 354 355// SetMinAckDeadline changes the minack deadline to n. Must be 356// greater than or equal to 1 second. Remember to reset this value 357// to the default after your test changes it. Example usage: 358// pstest.SetMinAckDeadlineSecs(1) 359// defer pstest.ResetMinAckDeadlineSecs() 360func SetMinAckDeadline(n time.Duration) { 361 if n < time.Second { 362 panic("SetMinAckDeadline expects a value greater than 1 second") 363 } 364 365 minAckDeadlineSecs = int32(n / time.Second) 366} 367 368// ResetMinAckDeadline resets the minack deadline to the default. 369func ResetMinAckDeadline() { 370 minAckDeadlineSecs = 10 371} 372 373func checkAckDeadline(ads int32) error { 374 if ads < minAckDeadlineSecs || ads > 600 { 375 // PubSub service returns Unknown. 376 return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads) 377 } 378 return nil 379} 380 381const ( 382 minMessageRetentionDuration = 10 * time.Minute 383 maxMessageRetentionDuration = 168 * time.Hour 384) 385 386var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration) 387 388func checkMRD(pmrd *durpb.Duration) error { 389 mrd, err := ptypes.Duration(pmrd) 390 if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration { 391 return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd) 392 } 393 return nil 394} 395 396func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) { 397 s.mu.Lock() 398 defer s.mu.Unlock() 399 sub, err := s.findSubscription(req.Subscription) 400 if err != nil { 401 return nil, err 402 } 403 return sub.proto, nil 404} 405 406func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) { 407 if req.Subscription == nil { 408 return nil, status.Errorf(codes.InvalidArgument, "missing subscription") 409 } 410 s.mu.Lock() 411 defer s.mu.Unlock() 412 sub, err := s.findSubscription(req.Subscription.Name) 413 if err != nil { 414 return nil, err 415 } 416 for _, path := range req.UpdateMask.Paths { 417 switch path { 418 case "push_config": 419 sub.proto.PushConfig = req.Subscription.PushConfig 420 421 case "ack_deadline_seconds": 422 a := req.Subscription.AckDeadlineSeconds 423 if err := checkAckDeadline(a); err != nil { 424 return nil, err 425 } 426 sub.proto.AckDeadlineSeconds = a 427 428 case "retain_acked_messages": 429 sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages 430 431 case "message_retention_duration": 432 if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil { 433 return nil, err 434 } 435 sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration 436 437 case "labels": 438 sub.proto.Labels = req.Subscription.Labels 439 440 case "expiration_policy": 441 sub.proto.ExpirationPolicy = req.Subscription.ExpirationPolicy 442 443 default: 444 return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path) 445 } 446 } 447 return sub.proto, nil 448} 449 450func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) { 451 s.mu.Lock() 452 defer s.mu.Unlock() 453 454 var names []string 455 for name := range s.subs { 456 if strings.HasPrefix(name, req.Project) { 457 names = append(names, name) 458 } 459 } 460 sort.Strings(names) 461 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 462 if err != nil { 463 return nil, err 464 } 465 res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken} 466 for i := from; i < to; i++ { 467 res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto) 468 } 469 return res, nil 470} 471 472func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) { 473 s.mu.Lock() 474 defer s.mu.Unlock() 475 sub, err := s.findSubscription(req.Subscription) 476 if err != nil { 477 return nil, err 478 } 479 sub.stop() 480 delete(s.subs, req.Subscription) 481 sub.topic.deleteSub(sub) 482 return &emptypb.Empty{}, nil 483} 484 485func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) { 486 s.mu.Lock() 487 defer s.mu.Unlock() 488 489 if req.Topic == "" { 490 return nil, status.Errorf(codes.InvalidArgument, "missing topic") 491 } 492 top := s.topics[req.Topic] 493 if top == nil { 494 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 495 } 496 var ids []string 497 for _, pm := range req.Messages { 498 id := fmt.Sprintf("m%d", s.nextID) 499 s.nextID++ 500 pm.MessageId = id 501 pubTime := timeNow() 502 tsPubTime, err := ptypes.TimestampProto(pubTime) 503 if err != nil { 504 return nil, status.Errorf(codes.Internal, err.Error()) 505 } 506 pm.PublishTime = tsPubTime 507 m := &Message{ 508 ID: id, 509 Data: pm.Data, 510 Attributes: pm.Attributes, 511 PublishTime: pubTime, 512 } 513 top.publish(pm, m) 514 ids = append(ids, id) 515 s.msgs = append(s.msgs, m) 516 s.msgsByID[id] = m 517 } 518 return &pb.PublishResponse{MessageIds: ids}, nil 519} 520 521type topic struct { 522 proto *pb.Topic 523 subs map[string]*subscription 524} 525 526func newTopic(pt *pb.Topic) *topic { 527 return &topic{ 528 proto: pt, 529 subs: map[string]*subscription{}, 530 } 531} 532 533func (t *topic) stop() { 534 for _, sub := range t.subs { 535 sub.proto.Topic = "_deleted-topic_" 536 sub.stop() 537 } 538} 539 540func (t *topic) deleteSub(sub *subscription) { 541 delete(t.subs, sub.proto.Name) 542} 543 544func (t *topic) publish(pm *pb.PubsubMessage, m *Message) { 545 for _, s := range t.subs { 546 s.msgs[pm.MessageId] = &message{ 547 publishTime: m.PublishTime, 548 proto: &pb.ReceivedMessage{ 549 AckId: pm.MessageId, 550 Message: pm, 551 }, 552 deliveries: &m.deliveries, 553 acks: &m.acks, 554 streamIndex: -1, 555 } 556 } 557} 558 559type subscription struct { 560 topic *topic 561 mu *sync.Mutex // the server mutex, here for convenience 562 proto *pb.Subscription 563 ackTimeout time.Duration 564 msgs map[string]*message // unacked messages by message ID 565 streams []*stream 566 done chan struct{} 567} 568 569func newSubscription(t *topic, mu *sync.Mutex, ps *pb.Subscription) *subscription { 570 at := time.Duration(ps.AckDeadlineSeconds) * time.Second 571 if at == 0 { 572 at = 10 * time.Second 573 } 574 return &subscription{ 575 topic: t, 576 mu: mu, 577 proto: ps, 578 ackTimeout: at, 579 msgs: map[string]*message{}, 580 done: make(chan struct{}), 581 } 582} 583 584func (s *subscription) start(wg *sync.WaitGroup) { 585 wg.Add(1) 586 go func() { 587 defer wg.Done() 588 for { 589 select { 590 case <-s.done: 591 return 592 case <-time.After(10 * time.Millisecond): 593 s.deliver() 594 } 595 } 596 }() 597} 598 599func (s *subscription) stop() { 600 close(s.done) 601} 602 603func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) { 604 s.mu.Lock() 605 defer s.mu.Unlock() 606 607 sub, err := s.findSubscription(req.Subscription) 608 if err != nil { 609 return nil, err 610 } 611 for _, id := range req.AckIds { 612 sub.ack(id) 613 } 614 return &emptypb.Empty{}, nil 615} 616 617func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) { 618 s.mu.Lock() 619 defer s.mu.Unlock() 620 sub, err := s.findSubscription(req.Subscription) 621 if err != nil { 622 return nil, err 623 } 624 now := time.Now() 625 for _, id := range req.AckIds { 626 s.msgsByID[id].Modacks = append(s.msgsByID[id].Modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now}) 627 } 628 dur := secsToDur(req.AckDeadlineSeconds) 629 for _, id := range req.AckIds { 630 sub.modifyAckDeadline(id, dur) 631 } 632 return &emptypb.Empty{}, nil 633} 634 635func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) { 636 s.mu.Lock() 637 sub, err := s.findSubscription(req.Subscription) 638 if err != nil { 639 s.mu.Unlock() 640 return nil, err 641 } 642 max := int(req.MaxMessages) 643 if max < 0 { 644 s.mu.Unlock() 645 return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative") 646 } 647 if max == 0 { // MaxMessages not specified; use a default. 648 max = 1000 649 } 650 msgs := sub.pull(max) 651 s.mu.Unlock() 652 // Implement the spec from the pubsub proto: 653 // "If ReturnImmediately set to true, the system will respond immediately even if 654 // it there are no messages available to return in the `Pull` response. 655 // Otherwise, the system may wait (for a bounded amount of time) until at 656 // least one message is available, rather than returning no messages." 657 if len(msgs) == 0 && !req.ReturnImmediately { 658 // Wait for a short amount of time for a message. 659 // TODO: signal when a message arrives, so we don't wait the whole time. 660 select { 661 case <-ctx.Done(): 662 return nil, ctx.Err() 663 case <-time.After(500 * time.Millisecond): 664 s.mu.Lock() 665 msgs = sub.pull(max) 666 s.mu.Unlock() 667 } 668 } 669 return &pb.PullResponse{ReceivedMessages: msgs}, nil 670} 671 672func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error { 673 // Receive initial message configuring the pull. 674 req, err := sps.Recv() 675 if err != nil { 676 return err 677 } 678 s.mu.Lock() 679 sub, err := s.findSubscription(req.Subscription) 680 s.mu.Unlock() 681 if err != nil { 682 return err 683 } 684 // Create a new stream to handle the pull. 685 st := sub.newStream(sps, s.streamTimeout) 686 err = st.pull(&s.wg) 687 sub.deleteStream(st) 688 return err 689} 690 691func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) { 692 // Only handle time-based seeking for now. 693 // This fake doesn't deal with snapshots. 694 var target time.Time 695 switch v := req.Target.(type) { 696 case nil: 697 return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type") 698 case *pb.SeekRequest_Time: 699 var err error 700 target, err = ptypes.Timestamp(v.Time) 701 if err != nil { 702 return nil, status.Errorf(codes.InvalidArgument, "bad Time target: %v", err) 703 } 704 default: 705 return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v) 706 } 707 708 // The entire server must be locked while doing the work below, 709 // because the messages don't have any other synchronization. 710 s.mu.Lock() 711 defer s.mu.Unlock() 712 sub, err := s.findSubscription(req.Subscription) 713 if err != nil { 714 return nil, err 715 } 716 // Drop all messages from sub that were published before the target time. 717 for id, m := range sub.msgs { 718 if m.publishTime.Before(target) { 719 delete(sub.msgs, id) 720 (*m.acks)++ 721 } 722 } 723 // Un-ack any already-acked messages after this time; 724 // redelivering them to the subscription is the closest analogue here. 725 for _, m := range s.msgs { 726 if m.PublishTime.Before(target) { 727 continue 728 } 729 sub.msgs[m.ID] = &message{ 730 publishTime: m.PublishTime, 731 proto: &pb.ReceivedMessage{ 732 AckId: m.ID, 733 // This was not preserved! 734 //Message: pm, 735 }, 736 deliveries: &m.deliveries, 737 acks: &m.acks, 738 streamIndex: -1, 739 } 740 } 741 return &pb.SeekResponse{}, nil 742} 743 744// Gets a subscription that must exist. 745// Must be called with the lock held. 746func (s *GServer) findSubscription(name string) (*subscription, error) { 747 if name == "" { 748 return nil, status.Errorf(codes.InvalidArgument, "missing subscription") 749 } 750 sub := s.subs[name] 751 if sub == nil { 752 return nil, status.Errorf(codes.NotFound, "subscription %s", name) 753 } 754 return sub, nil 755} 756 757// Must be called with the lock held. 758func (s *subscription) pull(max int) []*pb.ReceivedMessage { 759 now := timeNow() 760 s.maintainMessages(now) 761 var msgs []*pb.ReceivedMessage 762 for _, m := range s.msgs { 763 if m.outstanding() { 764 continue 765 } 766 (*m.deliveries)++ 767 m.ackDeadline = now.Add(s.ackTimeout) 768 msgs = append(msgs, m.proto) 769 if len(msgs) >= max { 770 break 771 } 772 } 773 return msgs 774} 775 776func (s *subscription) deliver() { 777 s.mu.Lock() 778 defer s.mu.Unlock() 779 780 now := timeNow() 781 s.maintainMessages(now) 782 // Try to deliver each remaining message. 783 curIndex := 0 784 for _, m := range s.msgs { 785 if m.outstanding() { 786 continue 787 } 788 // If the message was never delivered before, start with the stream at 789 // curIndex. If it was delivered before, start with the stream after the one 790 // that owned it. 791 if m.streamIndex < 0 { 792 delIndex, ok := s.tryDeliverMessage(m, curIndex, now) 793 if !ok { 794 break 795 } 796 curIndex = delIndex + 1 797 m.streamIndex = curIndex 798 } else { 799 delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now) 800 if !ok { 801 break 802 } 803 m.streamIndex = delIndex 804 } 805 } 806} 807 808// tryDeliverMessage attempts to deliver m to the stream at index i. If it can't, it 809// tries streams i+1, i+2, ..., wrapping around. Once it's tried all streams, it 810// exits. 811// 812// It returns the index of the stream it delivered the message to, or 0, false if 813// it didn't deliver the message. 814// 815// Must be called with the lock held. 816func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) { 817 for i := 0; i < len(s.streams); i++ { 818 idx := (i + start) % len(s.streams) 819 820 st := s.streams[idx] 821 select { 822 case <-st.done: 823 s.streams = deleteStreamAt(s.streams, idx) 824 i-- 825 826 case st.msgc <- m.proto: 827 (*m.deliveries)++ 828 m.ackDeadline = now.Add(st.ackTimeout) 829 return idx, true 830 831 default: 832 } 833 } 834 return 0, false 835} 836 837var retentionDuration = 10 * time.Minute 838 839// Must be called with the lock held. 840func (s *subscription) maintainMessages(now time.Time) { 841 for id, m := range s.msgs { 842 // Mark a message as re-deliverable if its ack deadline has expired. 843 if m.outstanding() && now.After(m.ackDeadline) { 844 m.makeAvailable() 845 } 846 pubTime, err := ptypes.Timestamp(m.proto.Message.PublishTime) 847 if err != nil { 848 panic(err) 849 } 850 // Remove messages that have been undelivered for a long time. 851 if !m.outstanding() && now.Sub(pubTime) > retentionDuration { 852 delete(s.msgs, id) 853 } 854 } 855} 856 857func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream { 858 st := &stream{ 859 sub: s, 860 done: make(chan struct{}), 861 msgc: make(chan *pb.ReceivedMessage), 862 gstream: gs, 863 ackTimeout: s.ackTimeout, 864 timeout: timeout, 865 } 866 s.mu.Lock() 867 s.streams = append(s.streams, st) 868 s.mu.Unlock() 869 return st 870} 871 872func (s *subscription) deleteStream(st *stream) { 873 s.mu.Lock() 874 defer s.mu.Unlock() 875 var i int 876 for i = 0; i < len(s.streams); i++ { 877 if s.streams[i] == st { 878 break 879 } 880 } 881 if i < len(s.streams) { 882 s.streams = deleteStreamAt(s.streams, i) 883 } 884} 885func deleteStreamAt(s []*stream, i int) []*stream { 886 // Preserve order for round-robin delivery. 887 return append(s[:i], s[i+1:]...) 888} 889 890type message struct { 891 proto *pb.ReceivedMessage 892 publishTime time.Time 893 ackDeadline time.Time 894 deliveries *int 895 acks *int 896 streamIndex int // index of stream that currently owns msg, for round-robin delivery 897} 898 899// A message is outstanding if it is owned by some stream. 900func (m *message) outstanding() bool { 901 return !m.ackDeadline.IsZero() 902} 903 904func (m *message) makeAvailable() { 905 m.ackDeadline = time.Time{} 906} 907 908type stream struct { 909 sub *subscription 910 done chan struct{} // closed when the stream is finished 911 msgc chan *pb.ReceivedMessage 912 gstream pb.Subscriber_StreamingPullServer 913 ackTimeout time.Duration 914 timeout time.Duration 915} 916 917// pull manages the StreamingPull interaction for the life of the stream. 918func (st *stream) pull(wg *sync.WaitGroup) error { 919 errc := make(chan error, 2) 920 wg.Add(2) 921 go func() { 922 defer wg.Done() 923 errc <- st.sendLoop() 924 }() 925 go func() { 926 defer wg.Done() 927 errc <- st.recvLoop() 928 }() 929 var tchan <-chan time.Time 930 if st.timeout > 0 { 931 tchan = time.After(st.timeout) 932 } 933 // Wait until one of the goroutines returns an error, or we time out. 934 var err error 935 select { 936 case err = <-errc: 937 if err == io.EOF { 938 err = nil 939 } 940 case <-tchan: 941 } 942 close(st.done) // stop the other goroutine 943 return err 944} 945 946func (st *stream) sendLoop() error { 947 for { 948 select { 949 case <-st.done: 950 return nil 951 case rm := <-st.msgc: 952 res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}} 953 if err := st.gstream.Send(res); err != nil { 954 return err 955 } 956 } 957 } 958} 959 960func (st *stream) recvLoop() error { 961 for { 962 req, err := st.gstream.Recv() 963 if err != nil { 964 return err 965 } 966 st.sub.handleStreamingPullRequest(st, req) 967 } 968} 969 970func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) { 971 // Lock the entire server. 972 s.mu.Lock() 973 defer s.mu.Unlock() 974 975 for _, ackID := range req.AckIds { 976 s.ack(ackID) 977 } 978 for i, id := range req.ModifyDeadlineAckIds { 979 s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i])) 980 } 981 if req.StreamAckDeadlineSeconds > 0 { 982 st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds) 983 } 984} 985 986// Must be called with the lock held. 987func (s *subscription) ack(id string) { 988 m := s.msgs[id] 989 if m != nil { 990 (*m.acks)++ 991 delete(s.msgs, id) 992 } 993} 994 995// Must be called with the lock held. 996func (s *subscription) modifyAckDeadline(id string, d time.Duration) { 997 m := s.msgs[id] 998 if m == nil { // already acked: ignore. 999 return 1000 } 1001 if d == 0 { // nack 1002 m.makeAvailable() 1003 } else { // extend the deadline by d 1004 m.ackDeadline = timeNow().Add(d) 1005 } 1006} 1007 1008func secsToDur(secs int32) time.Duration { 1009 return time.Duration(secs) * time.Second 1010} 1011